Source code for hyrax.models.hyrax_loopback

import logging

import torch.nn as nn

from .model_registry import hyrax_model

[docs] logger = logging.getLogger()
@hyrax_model
[docs] class HyraxLoopback(nn.Module): """Simple model for testing which returns its own input""" def __init__(self, config, data_sample=None): from functools import partial super().__init__() # The optimizer needs at least one weight, so we add a dummy module here
[docs] self.unused_module = nn.Linear(1, 1)
[docs] self.config = config
def load(self, weight_file): """Boilerplate function to load weights. However, this model has no weights so we do nothing.""" pass # We override this way rather than defining a method because # Torch has some __init__ related cleverness which stomps our # load definition when performed in the usual fashion.
[docs] self.load = partial(load, self)
[docs] def forward(self, x): """We simply return our input""" if isinstance(x, (tuple, list)): # if x is a tuple, extract the first element (it should be a tensor) x, _ = x return x
[docs] def train_step(self, batch): """Training is a noop""" logger.debug(f"Batch length: {len(batch)}") return {"loss": 0.0}