from typing import Callable, TYPE_CHECKING, Tuple, Type
import warnings
import json
from dataclasses import replace
import numpy as np
import torch as ch
from .base import Field, ARG_TYPE
from ..pipeline.operation import Operation
from ..pipeline.state import State
from ..pipeline.compiler import Compiler
from ..pipeline.allocation_query import AllocationQuery
from ..libffcv import memcpy
if TYPE_CHECKING:
from ..memory_managers.base import MemoryManager
class NDArrayDecoder(Operation):
    """
    Default decoder for :class:`~ffcv.fields.NDArrayField`.

    Each sample's stored bytes are copied verbatim into the matching slot of
    the pre-allocated output buffer; no per-element conversion happens.
    """

    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, AllocationQuery]:
        """Switch to JIT mode with this field's shape/dtype and request a
        buffer of that same shape/dtype.

        NOTE(review): the batch dimension is presumably added by the pipeline
        when it honors the AllocationQuery (the decoder indexes
        ``destination[ix]``) — confirm against AllocationQuery.
        """
        field = self.field
        next_state = replace(previous_state,
                             jit_mode=True,
                             shape=field.shape,
                             dtype=field.dtype)
        buffer_request = AllocationQuery(field.shape, field.dtype)
        return next_state, buffer_request

    def generate_code(self) -> Callable:
        """Build the (to-be-compiled) batch decoding function."""
        # Bind everything the closure needs up front so the compiled function
        # only captures plain callables.
        loop_range = Compiler.get_iterator()
        read_sample = self.memory_read
        copy_bytes = Compiler.compile(memcpy)

        def decoder(indices, destination, metadata, storage_state):
            batch_size = indices.shape[0]
            for slot in loop_range(batch_size):
                # metadata holds, per sample id, the value read_sample needs
                # to locate that sample's bytes.
                raw_bytes = read_sample(metadata[indices[slot]], storage_state)
                # Reinterpret the output slot as bytes and copy into it.
                copy_bytes(raw_bytes, destination[slot].view(np.uint8))
            return destination

        return decoder
# Fixed-size header written at the start of an NDArrayField's serialized
# arguments (see NDArrayField.to_binary / NDArrayField.from_binary).
NDArrayArgsType = np.dtype([
    # Shape, zero-padded to 32 entries. 32 was NumPy's dimension limit before
    # NumPy 2.0 (which raised it to 64); the width is part of the serialized
    # layout, so it must not change.
    ('shape', '<u8', (32,)),
    # Byte length of the ASCII JSON dtype description that follows the header.
    ('type_length', '<u8'),
])
class NDArrayField(Field):
    """A subclass of :class:`~ffcv.fields.Field` supporting
    multi-dimensional fixed size matrices of any numpy type.

    Each sample is stored as its raw bytes (``prod(shape)`` elements of
    ``dtype``). The per-sample metadata is a single ``<u8`` value: the first
    value returned by ``malloc`` in :meth:`encode`.

    Parameters
    ----------
    dtype : np.dtype
        Element type. Anything ``np.dtype(...)`` accepts (e.g. ``np.float32``)
        is normalized to a ``np.dtype`` instance.
    shape : Tuple[int, ...]
        Fixed shape of every sample. :meth:`to_binary` can serialize at most
        32 dimensions.
    """

    def __init__(self, dtype: np.dtype, shape: Tuple[int, ...]):
        # Normalize so scalar types such as np.float32 work too; .itemsize and
        # .descr (used in to_binary) require a dtype instance.
        dtype = np.dtype(dtype)
        self.dtype = dtype
        self.shape = shape
        # int(): np.prod(()) is the float 1.0, which would make the
        # allocation size a float for 0-d (scalar) fields.
        self.element_size = dtype.itemsize * int(np.prod(shape))
        if dtype == np.uint16:
            warnings.warn("Pytorch currently doesn't support uint16, "
                          "we recommend storing as int16 and reinterpret "
                          "your data later in your pipeline")

    @property
    def metadata_type(self) -> np.dtype:
        """Per-sample metadata: one little-endian uint64 (the storage handle)."""
        return np.dtype('<u8')

    @staticmethod
    def from_binary(binary: ARG_TYPE) -> Field:
        """Rebuild a field from the bytes produced by :meth:`to_binary`.

        Layout: an ``NDArrayArgsType`` header, then ``type_length`` bytes of
        ASCII JSON holding ``dtype.descr``.
        """
        header_size = NDArrayArgsType.itemsize
        header = binary[:header_size].view(NDArrayArgsType)[0]
        type_length = int(header['type_length'])
        type_data = binary[header_size:header_size + type_length].tobytes().decode('ascii')
        # JSON turns the descr tuples into lists; np.dtype expects tuples.
        type_desc = [tuple(x) for x in json.loads(type_data)]
        assert len(type_desc) == 1
        # A plain dtype's descr is like [('', '<f4')]; np.dtype names that
        # single field 'f0', and indexing it recovers the plain dtype.
        dtype = np.dtype(type_desc)['f0']
        # The stored shape is zero-padded; strip the padding. The emptiness
        # check keeps 0-d shapes from raising IndexError. (A genuine trailing
        # dimension of size 0 is indistinguishable from padding.)
        shape = [int(dim) for dim in header['shape']]
        while shape and shape[-1] == 0:
            shape.pop()
        return NDArrayField(dtype, tuple(shape))

    def to_binary(self) -> ARG_TYPE:
        """Serialize shape and dtype into a single ``ARG_TYPE`` record.

        Raises
        ------
        ValueError
            If the shape has more dimensions than the header can hold, or the
            encoded header + dtype description does not fit in ``ARG_TYPE``.
        """
        result = np.zeros(1, dtype=ARG_TYPE)[0]
        header = np.zeros(1, dtype=NDArrayArgsType)
        dims = np.array(self.shape).astype('<u8')
        max_dims = NDArrayArgsType['shape'].shape[0]
        if len(dims) > max_dims:
            raise ValueError(f"NDArrayField supports at most {max_dims} "
                             f"dimensions, got shape {self.shape}")
        header['shape'][0][:len(dims)] = dims
        encoded_type = json.dumps(self.dtype.descr)
        encoded_type = np.frombuffer(encoded_type.encode('ascii'), dtype='<u1')
        header['type_length'][0] = len(encoded_type)
        to_write = np.concatenate([header.view('<u1'), encoded_type])
        capacity = len(result[0])
        if to_write.shape[0] > capacity:
            raise ValueError(f"Serialized NDArrayField arguments need "
                             f"{to_write.shape[0]} bytes but only {capacity} "
                             f"are available (dtype {self.dtype!r})")
        result[0][:to_write.shape[0]] = to_write
        return result

    def encode(self, destination, field, malloc):
        """Copy one sample's raw bytes into storage obtained from ``malloc``.

        The sample's bytes are written as-is (no cast to ``self.dtype``).

        Raises
        ------
        ValueError
            If the sample's byte size differs from ``self.element_size``.
        """
        # ascontiguousarray: for an already 1-D but strided input (e.g.
        # arr[::2]) reshape(-1) returns a strided view, and .view('<u1')
        # rejects a non-contiguous last axis.
        raw = np.ascontiguousarray(field).reshape(-1).view('<u1')
        # Checked explicitly: a 1-byte input would otherwise be silently
        # broadcast over the whole region.
        if raw.size != self.element_size:
            raise ValueError(f"Expected {self.element_size} bytes "
                             f"(shape {self.shape}, dtype {self.dtype}), "
                             f"got {raw.size}")
        destination[0], data_region = malloc(self.element_size)
        data_region[:] = raw

    def get_decoder_class(self) -> Type[Operation]:
        """Operation used to decode this field (:class:`NDArrayDecoder`)."""
        return NDArrayDecoder
class TorchTensorField(NDArrayField):
    """A subclass of :class:`~ffcv.fields.Field` supporting
    multi-dimensional fixed size matrices of any torch type.

    The torch dtype is mapped to its NumPy equivalent through
    ``Tensor.numpy()``, so torch dtypes without a NumPy counterpart
    (e.g. ``bfloat16``) fail at construction. After construction
    ``self.dtype`` holds the NumPy dtype.
    """

    def __init__(self, dtype: ch.dtype, shape: Tuple[int, ...]):
        # Map the torch dtype to NumPy via an empty tensor. NDArrayField's
        # __init__ assigns self.dtype / self.shape, so they are not set here.
        np_dtype = ch.zeros(0, dtype=dtype).numpy().dtype
        super().__init__(np_dtype, shape)

    def encode(self, destination, field, malloc):
        """Encode one tensor by delegating to NDArrayField on its NumPy form."""
        # detach()/cpu(): Tensor.numpy() raises for tensors that require grad
        # or live on a non-CPU device. contiguous(): hand NumPy a dense buffer
        # so the byte view taken by the parent encode is valid. For CPU,
        # grad-free, contiguous tensors these are no-ops.
        field = field.detach().cpu().contiguous().numpy()
        return super().encode(destination, field, malloc)