# Source code for ffcv.transforms.common

from typing import Callable, Optional, Tuple
from ..pipeline.allocation_query import AllocationQuery
from ..pipeline.operation import Operation
from ..pipeline.state import State
from dataclasses import replace

class Squeeze(Operation):
    """Drop size-1 dimensions from the input, in place.

    The generated code calls the input's ``squeeze_`` method, so it operates
    on tensors and mutates them rather than allocating new output.

    Parameters
    ----------
    *dims : int
        Dimensions to squeeze. When none are given, ``squeeze_()`` is called
        with no arguments, which removes every size-1 dimension.

    Notes
    -----
    NOTE(review): ``declare_state_and_memory`` removes *every* size-1 entry
    from the declared shape, even when ``dims`` restricts which dimensions
    the generated code squeezes. The two may disagree — confirm whether
    ``dims`` index the declared (per-sample?) shape or the runtime input.
    """

    def __init__(self, *dims: int):
        super().__init__()
        # The varargs arrive as a tuple and are forwarded verbatim to squeeze_.
        self.dims = dims

    def generate_code(self) -> Callable:
        """Return the function that squeezes its first argument in place."""

        def squeeze(tensor, _unused):
            # squeeze_ mutates ``tensor``; that same object is handed back.
            # The second argument is accepted for the call signature but unused.
            tensor.squeeze_(*self.dims)
            return tensor

        return squeeze

    def declare_state_and_memory(
        self, previous_state: State
    ) -> Tuple[State, Optional[AllocationQuery]]:
        """Declare the output shape with all size-1 entries removed.

        Returns a copy of ``previous_state`` whose ``shape`` is a list of the
        non-1 sizes, and ``None`` because no extra memory is requested.
        """
        kept_sizes = [size for size in previous_state.shape if size != 1]
        return replace(previous_state, shape=kept_sizes), None