"""
Cutout augmentation (https://arxiv.org/abs/1708.04552)
"""
import numpy as np
from typing import Callable, Optional, Tuple
from dataclasses import replace
from ffcv.pipeline.compiler import Compiler
from ..pipeline.allocation_query import AllocationQuery
from ..pipeline.operation import Operation
from ..pipeline.state import State
class Cutout(Operation):
    """Cutout data augmentation (https://arxiv.org/abs/1708.04552).

    Overwrites one randomly placed square per image with a constant fill
    value. The batch is modified in place.

    Parameters
    ----------
    crop_size : int
        Side length of the random square to cut out. Must be non-negative;
        ``0`` leaves the images untouched. If it exceeds an image dimension,
        the square is clamped to that dimension.
    fill : Tuple[int, int, int], optional
        An RGB color ((0, 0, 0) by default) to fill the cutout square with.
        Useful for when a normalization layer follows cutout, in which case
        you can set the fill such that the square is zero
        post-normalization.

    Raises
    ------
    ValueError
        If ``crop_size`` is negative.
    """

    def __init__(self, crop_size: int, fill: Tuple[int, int, int] = (0, 0, 0)):
        super().__init__()
        if crop_size < 0:
            # A negative size would silently yield empty slices (a no-op);
            # surface the misconfiguration instead.
            raise ValueError(
                f"crop_size must be non-negative, got {crop_size}")
        self.crop_size = crop_size
        # Stored as an array so it can be broadcast into the image slice.
        self.fill = np.array(fill)

    def generate_code(self) -> Callable:
        """Build the function that applies cutout to a batch of images.

        Returns
        -------
        Callable
            ``cutout_square(images, *_)``; it overwrites one random square in
            every ``images[i]`` with ``fill`` and returns the same array.
        """
        my_range = Compiler.get_iterator()
        # Copy the attributes into locals so the closure below does not
        # reference ``self``.
        crop_size = self.crop_size
        fill = self.fill

        def cutout_square(images, *_):
            # Axis 0 indexes the samples, axes 1 and 2 are the ones cut.
            # NOTE(review): assumes a channels-last (N, H, W, 3) layout so
            # that the 3-element ``fill`` broadcasts over the last axis --
            # confirm against the upstream decoder.
            #
            # Clamp the square to the image extent: np.random.randint needs
            # a positive upper bound, which ``dim - crop_size + 1`` is not
            # when crop_size > dim.
            cut_h = min(crop_size, images.shape[1])
            cut_w = min(crop_size, images.shape[2])
            # Exclusive upper bounds for the origin; identical for every
            # sample, so computed once outside the loop.
            max_row = images.shape[1] - cut_h + 1
            max_col = images.shape[2] - cut_w + 1

            for i in my_range(images.shape[0]):
                # Generate random origin
                row = np.random.randint(max_row)
                col = np.random.randint(max_col)
                # Fill the square in-place
                images[i, row:row + cut_h, col:col + cut_w] = fill
            return images

        # NOTE(review): presumably tells the pipeline compiler that the
        # per-sample loop may run in parallel -- confirm in Compiler.
        cutout_square.is_parallel = True
        return cutout_square

    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
        """Declare the pipeline state after this operation.

        Parameters
        ----------
        previous_state : State
            State produced by the preceding pipeline stage.

        Returns
        -------
        Tuple[State, Optional[AllocationQuery]]
            A copy of ``previous_state`` with ``jit_mode=True``, and ``None``
            in place of an allocation query (the operation works in place).
        """
        return replace(previous_state, jit_mode=True), None