"""FFCV: source code for ``ffcv.fields.basics``.

Scalar integer (int64) and floating-point (float64) fields and their decoders.
"""

from typing import Callable, TYPE_CHECKING, Tuple, Type
from dataclasses import replace

import numpy as np

from .base import Field, ARG_TYPE
from ..pipeline.operation import Operation
from ..pipeline.state import State
from ..pipeline.allocation_query import AllocationQuery

if TYPE_CHECKING:
    from ..memory_managers.base import MemoryManager

class BasicDecoder(Operation):
    """Decoder for scalar fields stored directly in per-sample metadata.

    Subclasses must define a class attribute ``dtype`` (a ``np.dtype``); it is
    read by both methods below. This decoder can be extended to decode any
    fixed-length numpy data type.
    """

    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, AllocationQuery]:
        """Describe this decoder's output and the buffer it needs.

        Each sample decodes to exactly one value of ``self.dtype``, so the
        per-sample shape is ``(1,)``.

        Args:
            previous_state: pipeline state produced by the preceding stage.

        Returns:
            A ``(state, allocation)`` pair. ``state`` is a copy of
            ``previous_state`` with ``jit_mode=True`` and this decoder's
            shape/dtype. ``allocation`` requests a buffer of that shape/dtype.
        """
        sample_shape = (1,)
        # NOTE(review): jit_mode=True presumably makes the pipeline compile the
        # function returned by generate_code — confirm against Operation/State.
        next_state = replace(previous_state, jit_mode=True,
                             shape=sample_shape, dtype=self.dtype)
        allocation = AllocationQuery(sample_shape, dtype=self.dtype)
        return next_state, allocation

    def generate_code(self) -> Callable:
        """Return the batch decoding function.

        The returned function copies ``metadata[sample_id]`` into consecutive
        rows of ``destination`` and returns the filled prefix. The
        ``storage_state`` argument is accepted but unused.
        """
        def decode_batch(indices, destination, metadata, storage_state):
            # One output row per requested sample, in request order.
            for row, sample_id in enumerate(indices):
                destination[row] = metadata[sample_id]
            # Only the first len(indices) rows were written for this batch.
            return destination[:len(indices)]

        return decode_batch

class IntDecoder(BasicDecoder):
    """Decoder for signed integer scalars (int64)."""

    # Little-endian 64-bit signed integer; matches IntField.metadata_type.
    dtype = np.dtype('<i8')
class FloatDecoder(BasicDecoder):
    """Decoder for floating-point scalars (float64)."""

    # Little-endian 64-bit float; matches FloatField.metadata_type.
    dtype = np.dtype('<f8')
class FloatField(Field):
    """
    A subclass of :class:`~ffcv.fields.Field` supporting (scalar)
    floating-point (float64) values.
    """

    def __init__(self):
        pass

    @property
    def metadata_type(self) -> np.dtype:
        """Per-sample metadata layout: a single little-endian float64."""
        return np.dtype('<f8')

    @staticmethod
    def from_binary(binary: ARG_TYPE) -> Field:
        """Rebuild a FloatField from its serialized arguments.

        FloatField takes no constructor arguments, so ``binary`` is ignored.
        """
        return FloatField()

    def to_binary(self) -> ARG_TYPE:
        """Serialize this field's arguments.

        FloatField has no arguments, so a zero-filled ``ARG_TYPE`` record
        is returned.
        """
        return np.zeros(1, dtype=ARG_TYPE)[0]

    def encode(self, destination, field, malloc):
        """Write the scalar ``field`` into the sample's metadata slot.

        ``malloc`` is unused: the value is written only into ``destination``.
        """
        destination[0] = field

    def get_decoder_class(self) -> Type[Operation]:
        """Return the decoder operation for this field (``FloatDecoder``)."""
        return FloatDecoder
class IntField(Field):
    """
    A subclass of :class:`~ffcv.fields.Field` supporting (scalar) integer
    (int64) values.
    """

    @property
    def metadata_type(self) -> np.dtype:
        """Per-sample metadata layout: a single little-endian int64."""
        return np.dtype('<i8')

    @staticmethod
    def from_binary(binary: ARG_TYPE) -> Field:
        """Rebuild an IntField from its serialized arguments.

        IntField takes no constructor arguments, so ``binary`` is ignored.
        """
        return IntField()

    def to_binary(self) -> ARG_TYPE:
        """Serialize this field's arguments.

        IntField has no arguments, so a zero-filled ``ARG_TYPE`` record
        is returned.
        """
        return np.zeros(1, dtype=ARG_TYPE)[0]

    def encode(self, destination, field, malloc):
        """Write the scalar ``field`` into the sample's metadata slot.

        ``malloc`` is unused: nothing is allocated, and the value is written
        only into ``destination``.
        """
        destination[0] = field

    def get_decoder_class(self) -> Type[Operation]:
        """Return the decoder operation for this field (``IntDecoder``)."""
        return IntDecoder