FFCV

Source code for ffcv.transforms.ops

"""
General operations:

- Collation
- Conversion to PyTorch Tensor
- Change device of Tensor
"""
import torch as ch
import numpy as np
from typing import Callable, Optional, Tuple
from ..pipeline.allocation_query import AllocationQuery
from ..pipeline.operation import Operation
from ..pipeline.state import State
from dataclasses import replace


class ToTensor(Operation):
    """Pipeline operation that turns a NumPy array into a PyTorch tensor.

    The conversion is done with ``torch.from_numpy``; the destination buffer
    handed to the generated function is not used.
    """

    def __init__(self):
        super().__init__()

    def generate_code(self) -> Callable:
        """Return the function that performs the conversion.

        Returns
        -------
        Callable
            ``to_tensor(inp, dst)``, which returns ``torch.from_numpy(inp)``.
        """

        def to_tensor(inp, dst):
            # ``dst`` is accepted but never read or written by this operation.
            tensor = ch.from_numpy(inp)
            return tensor

        return to_tensor

    def declare_state_and_memory(
        self, previous_state: State
    ) -> Tuple[State, Optional[AllocationQuery]]:
        """Describe the state after this operation; no memory is requested.

        Parameters
        ----------
        previous_state : State
            State produced by the preceding operation. Its ``dtype`` is passed
            to ``np.empty`` to discover the matching torch dtype.

        Returns
        -------
        Tuple[State, Optional[AllocationQuery]]
            A copy of ``previous_state`` with ``jit_mode=False`` and the torch
            dtype, paired with ``None`` (no allocation).
        """
        # Round-trip a zero-dimensional array through torch so that torch
        # itself tells us which dtype corresponds to the incoming one.
        probe = np.empty((), dtype=previous_state.dtype)
        torch_dtype = ch.from_numpy(probe).dtype
        next_state = replace(previous_state, jit_mode=False, dtype=torch_dtype)
        return next_state, None
class ToDevice(Operation):
    """Move tensor to device.

    Parameters
    ----------
    device: torch.device
        Device to move to.
    non_blocking: bool
        Asynchronous if copying from CPU to GPU.
    """

    def __init__(self, device, non_blocking=True):
        super().__init__()
        self.device = device
        self.non_blocking = non_blocking

    def generate_code(self) -> Callable:
        """Return the function that copies a batch into the device buffer.

        Returns
        -------
        Callable
            ``to_device(inp, dst)``: copies ``inp`` into the first
            ``inp.shape[0]`` rows of ``dst`` and returns that view of ``dst``.
        """

        def to_device(inp, dst):
            batch = inp.shape[0]

            # Trim the destination to the rows actually present BEFORE any
            # reshape. ``dst`` may hold more rows than ``inp``; reshaping the
            # untrimmed buffer to ``batch`` rows would then fail on an
            # element-count mismatch.
            dst = dst[:batch]

            if len(inp.shape) == 4 and inp.is_contiguous(memory_format=ch.channels_last):
                # Re-view the destination so its memory layout matches the
                # channels_last input: lay it out as (N, H, W, C), then expose
                # it with the same (N, C, H, W) logical shape as ``inp``.
                dst = dst.reshape(batch, inp.shape[2], inp.shape[3], inp.shape[1])
                dst = dst.permute(0, 3, 1, 2)

            dst.copy_(inp, non_blocking=self.non_blocking)
            return dst

        return to_device

    def declare_state_and_memory(
        self, previous_state: State
    ) -> Tuple[State, Optional[AllocationQuery]]:
        """Declare the new device and request a buffer on it.

        Parameters
        ----------
        previous_state : State
            State produced by the preceding operation.

        Returns
        -------
        Tuple[State, Optional[AllocationQuery]]
            ``previous_state`` with ``device`` replaced by ``self.device``,
            and an allocation query with the previous shape and dtype located
            on ``self.device``.
        """
        new_state = replace(previous_state, device=self.device)
        buffer_query = AllocationQuery(
            previous_state.shape,
            dtype=previous_state.dtype,
            device=self.device,
        )
        return new_state, buffer_query
class ToTorchImage(Operation):
    """Change tensor to PyTorch format for images (B x C x H x W).

    Parameters
    ----------
    channels_last : bool
        If true, return the permuted view as is (it must already be contiguous
        in ``torch.channels_last`` format). Otherwise copy the batch into an
        allocated contiguous buffer.
    convert_back_int16 : bool
        If true and the incoming dtype is ``torch.int16``, reinterpret the data
        as ``torch.float16``.
    """

    def __init__(self, channels_last=True, convert_back_int16=True):
        super().__init__()
        self.channels_last = channels_last
        self.convert_int16 = convert_back_int16
        # Switched on by ``declare_state_and_memory`` when an int16 input is seen.
        self.enable_int16conv = False

    def generate_code(self) -> Callable:
        """Return the function that permutes a batch to channel-first order.

        The int16 reinterpretation flag is captured when this method is
        called, so ``declare_state_and_memory`` presumably has to run first
        for the flag to be honoured (confirm against the pipeline).
        """
        reinterpret_as_half = self.enable_int16conv

        def to_torch_image(inp: ch.Tensor, dst):
            if reinterpret_as_half:
                # Same bytes, viewed as float16 instead of int16.
                inp = inp.view(dtype=ch.float16)

            # Permuted view of the same tensor; no data is moved here.
            permuted = inp.permute([0, 3, 1, 2])

            if not self.channels_last:
                # Fill the allocated memory with a contiguous copy of the batch.
                count = permuted.shape[0]
                dst[:count] = permuted.contiguous()
                return dst[:count]

            # The permuted view is expected to be channels_last-contiguous already.
            assert permuted.is_contiguous(memory_format=ch.channels_last)
            return permuted

        return to_torch_image

    def declare_state_and_memory(
        self, previous_state: State
    ) -> Tuple[State, Optional[AllocationQuery]]:
        """Declare the (C, H, W) output shape and, if needed, a buffer.

        Parameters
        ----------
        previous_state : State
            State produced by the preceding operation; its ``shape`` is
            unpacked as ``(H, W, C)``.

        Returns
        -------
        Tuple[State, Optional[AllocationQuery]]
            The new state with shape ``(C, H, W)`` and the possibly converted
            dtype, plus an allocation query of that shape when
            ``channels_last`` is false (``None`` otherwise).
        """
        height, width, channels = previous_state.shape
        out_shape = (channels, height, width)
        out_dtype = previous_state.dtype

        if out_dtype is ch.int16 and self.convert_int16:
            out_dtype = ch.float16
            # Remember the decision for ``generate_code``.
            self.enable_int16conv = True

        buffer_query = None
        if not self.channels_last:
            buffer_query = AllocationQuery(out_shape, dtype=out_dtype)

        return replace(previous_state, shape=out_shape, dtype=out_dtype), buffer_query
class Convert(Operation):
    """Convert to target data type.

    Parameters
    ----------
    target_dtype: numpy.dtype or torch.dtype
        Target data type.
    """

    def __init__(self, target_dtype):
        super().__init__()
        self.target_dtype = target_dtype

    def generate_code(self) -> Callable:
        """Return the conversion function, flagged with ``is_parallel = True``.

        Returns
        -------
        Callable
            ``convert(inp, dst)``, which returns ``inp.type(self.target_dtype)``
            and does not touch ``dst``.
        """

        def convert(inp, dst):
            converted = inp.type(self.target_dtype)
            return converted

        # Attribute set on the returned function for the code that consumes it.
        convert.is_parallel = True
        return convert

    # TODO: something weird about device to allocate on
    def declare_state_and_memory(
        self, previous_state: State
    ) -> Tuple[State, Optional[AllocationQuery]]:
        """Declare the new dtype; no memory is requested.

        Parameters
        ----------
        previous_state : State
            State produced by the preceding operation.

        Returns
        -------
        Tuple[State, Optional[AllocationQuery]]
            ``previous_state`` with ``dtype`` replaced by ``self.target_dtype``,
            paired with ``None``.
        """
        next_state = replace(previous_state, dtype=self.target_dtype)
        return next_state, None
class View(Operation):
    """View array using np.view or torch.view.

    Parameters
    ----------
    target_dtype: numpy.dtype or torch.dtype
        Target data type.
    """

    def __init__(self, target_dtype):
        super().__init__()
        self.target_dtype = target_dtype

    def generate_code(self) -> Callable:
        """Return the viewing function, flagged with ``is_parallel = True``.

        Returns
        -------
        Callable
            ``convert(inp, dst)``, which returns ``inp.view(self.target_dtype)``
            and does not touch ``dst``.
        """

        def convert(inp, dst):
            viewed = inp.view(self.target_dtype)
            return viewed

        # Attribute set on the returned function for the code that consumes it.
        convert.is_parallel = True
        return convert

    def declare_state_and_memory(
        self, previous_state: State
    ) -> Tuple[State, Optional[AllocationQuery]]:
        """Declare the new dtype and turn JIT mode off; no memory is requested.

        Parameters
        ----------
        previous_state : State
            State produced by the preceding operation.

        Returns
        -------
        Tuple[State, Optional[AllocationQuery]]
            ``previous_state`` with ``dtype`` replaced by ``self.target_dtype``
            and ``jit_mode`` set to ``False``, paired with ``None``.
        """
        next_state = replace(previous_state, dtype=self.target_dtype, jit_mode=False)
        return next_state, None