"""
Wrapper for a torch.nn.Module
"""
import torch as ch
from numpy.random import permutation, rand
from typing import Callable, Optional, Tuple
from ..pipeline.allocation_query import AllocationQuery
from ..pipeline.operation import Operation
from ..pipeline.state import State
class ModuleWrapper(Operation):
    """Transform using the given torch.nn.Module

    Parameters
    ----------
    module: torch.nn.Module
        The module for transformation
    """

    def __init__(self, module: ch.nn.Module):
        super().__init__()
        # Kept as a plain attribute; the closure built by generate_code()
        # reads it through ``self`` each time it is called.
        self.module = module

    def generate_code(self) -> Callable:
        """Build the callable that applies the wrapped module.

        Returns
        -------
        Callable
            A two-argument function ``apply_module(inp, _)`` that returns
            ``self.module(inp)``. The second positional argument is accepted
            but never used.
        """
        def apply_module(inp, _):
            # NOTE(review): the ignored second argument is presumably a
            # buffer supplied by the pipeline -- confirm against Operation.
            res = self.module(inp)
            return res

        return apply_module

    def declare_state_and_memory(
        self, previous_state: State
    ) -> Tuple[State, Optional[AllocationQuery]]:
        """Pass the incoming state through and request no allocation.

        Parameters
        ----------
        previous_state: State
            The state handed to this operation.

        Returns
        -------
        Tuple[State, Optional[AllocationQuery]]
            ``previous_state`` itself (unmodified) and ``None`` in place of
            an allocation query.
        """
        return previous_state, None