"""Wrapper for a torch.nn.Module."""
import torch as ch
from numpy.random import permutation, rand
from typing import Callable, Optional, Tuple
from ..pipeline.allocation_query import AllocationQuery
from ..pipeline.operation import Operation
from ..pipeline.state import State

class ModuleWrapper(Operation):
    """Pipeline operation that runs its input through a torch.nn.Module.

    Parameters
    ----------
    module: torch.nn.Module
        The module used for the transformation.
    """

    def __init__(self, module: ch.nn.Module):
        super().__init__()
        # Kept as an attribute; the generated function reads it on every call.
        self.module = module

    def generate_code(self) -> Callable:
        """Build the function that performs this operation.

        Returns
        -------
        Callable
            A function of two positional arguments. It calls the wrapped
            module on the first argument and returns the result; the second
            argument is ignored.
        """

        def apply_module(inp, _):
            # Look the module up through ``self`` at call time rather than
            # capturing it when the function is generated.
            return self.module(inp)

        return apply_module

    def declare_state_and_memory(
        self, previous_state: State
    ) -> Tuple[State, Optional[AllocationQuery]]:
        """Return the incoming state unchanged, with no allocation query.

        Parameters
        ----------
        previous_state: State
            The state handed to this operation.

        Returns
        -------
        Tuple[State, Optional[AllocationQuery]]
            ``previous_state`` itself and ``None`` in place of an
            ``AllocationQuery`` (presumably meaning no memory is requested --
            confirm against ``Operation``).
        """
        return (previous_state, None)