"""Source code for ffcv.transforms.replace_label.

Replace label: a pipeline operation that overwrites the labels of selected samples.
"""
from typing import Tuple

import numpy as np
from dataclasses import replace
from typing import Callable, Optional, Tuple
from ..pipeline.allocation_query import AllocationQuery
from ..pipeline.operation import Operation
from ..pipeline.state import State
from ..pipeline.compiler import Compiler

class ReplaceLabel(Operation):
    """Overwrite the label of every sample whose dataset index is listed.

    Parameters
    ----------
    indices : Sequence[int]
        Dataset indices of the samples whose label should be replaced.
        A sorted copy is stored so membership can be tested by binary search.
    new_label : int
        Label written for every matching sample.
    """

    def __init__(self, indices, new_label: int):
        super().__init__()
        # np.sort returns a new sorted array; the caller's sequence is left as-is.
        self.indices = np.sort(indices)
        self.new_label = new_label

    def generate_code(self) -> Callable:
        """Build the per-batch function that performs the relabeling.

        Returns
        -------
        Callable
            ``replace_label(labels, temp_array, indices)``. For each batch
            position it looks up that sample's index (``indices[pos]``) in the
            sorted target indices. On a hit it sets ``labels[pos]`` to the new
            label, mutating ``labels`` in place, and finally returns ``labels``.
            ``temp_array`` is not used.
        """
        targets = self.indices
        n_targets = len(targets)
        replacement = self.new_label
        # NOTE(review): presumably a parallel range when JIT-compiled and a
        # plain range otherwise -- confirm against Compiler.get_iterator.
        loop = Compiler.get_iterator()

        def replace_label(labels, temp_array, indices):
            for pos in loop(labels.shape[0]):
                wanted = indices[pos]
                # Leftmost insertion point in the sorted targets; `wanted` is
                # present iff that slot is in range and holds an equal value.
                slot = np.searchsorted(targets, wanted)
                if slot < n_targets and targets[slot] == wanted:
                    labels[pos] = replacement
            return labels

        # Flags consumed by the pipeline (their semantics live outside this
        # file): with_indices presumably asks for the per-sample dataset
        # indices argument, is_parallel presumably allows parallel iteration.
        replace_label.with_indices = True
        replace_label.is_parallel = True
        return replace_label

    def declare_state_and_memory(
        self, previous_state: State
    ) -> Tuple[State, Optional[AllocationQuery]]:
        """Require JIT mode and request no extra memory.

        Returns a copy of ``previous_state`` with ``jit_mode=True`` paired
        with ``None`` in place of an allocation query.
        """
        next_state = replace(previous_state, jit_mode=True)
        return next_state, None