Source code for hyrax.models.model_registry

import logging
from pathlib import Path
from typing import Any, cast

import torch.nn as nn
from torch import Tensor

from hyrax.plugin_utils import get_or_load_class, update_registry

[docs] logger = logging.getLogger(__name__)
[docs] MODEL_REGISTRY: dict[str, type[nn.Module]] = {}
[docs] def _torch_save(self: nn.Module, save_path: Path): import torch torch.save(self.state_dict(), save_path)
[docs] def _torch_load(self: nn.Module, load_path: Path): import torch state_dict = torch.load(load_path, weights_only=True) self.load_state_dict(state_dict, assign=True)
[docs] def _torch_criterion(self: nn.Module): """Load the criterion class using the name defined in the config and instantiate it with the arguments defined in the config.""" config = cast(dict[str, Any], self.config) # Load the class and get any parameters from the config dictionary criterion_cls = get_or_load_class(config["criterion"]) criterion_name = config["criterion"]["name"] arguments = {} if criterion_name in config: arguments = config[criterion_name] # Print some information about the criterion function and parameters used log_string = f"Using criterion: {criterion_name} " if arguments: log_string += f"with arguments: {arguments}." else: log_string += "with default arguments." logger.info(log_string) return criterion_cls(**arguments)
[docs] def _torch_optimizer(self: nn.Module): """Load the optimizer class using the name defined in the config and instantiate it with the arguments defined in the config.""" config = cast(dict[str, Any], self.config) # Load the class and get any parameters from the config dictionary optimizer_cls = get_or_load_class(config["optimizer"]) optimizer_name = config["optimizer"]["name"] arguments = {} if optimizer_name in config: arguments = config[optimizer_name] # Print some information about the optimizer function and parameters used log_string = f"Using optimizer: {optimizer_name} " if arguments: log_string += f"with arguments: {arguments}." else: log_string += "with default arguments." logger.info(log_string) return optimizer_cls(self.parameters(), **arguments)
[docs] def hyrax_model(cls): """Decorator to register a model with the model registry, and to add common interface functions Returns ------- type The class with additional interface functions. """ if issubclass(cls, nn.Module): cls.save = _torch_save cls.load = _torch_load cls._criterion = _torch_criterion if not hasattr(cls, "_criterion") else cls._criterion cls._optimizer = _torch_optimizer if not hasattr(cls, "_optimizer") else cls._optimizer original_init = cls.__init__ def wrapped_init(self, dataset, *args, **kwargs): # Model constructors need a shape, but only a model can tell how to take a # data dict and extract the tensor. We fixup the call here so original_init # is passed a valid shape. # # TODO: We may want to stop passing 'shape' to model __init__(). This will # come at the cost of allowing models to dynamically size their layers/architecture # given shape information. The typical solution to this limitation seems to be a # static-size model crop/resize transforms defined with the model and executed # during data loading by the driver code. # Get a sample item of data if dataset.is_map(): sample = dataset[0] elif dataset.is_iterable(): sample = next(iter(dataset)) else: msg = f"{dataset.__class__.__name} must define __getitem__ or __iter__." return NotImplementedError(msg) # Perform conversion to tensor(s) if necessary if isinstance(sample, dict): sample = self.__class__.to_tensor(sample) # If its a tuple or list extract first element because it is (data, label) if isinstance(sample, (tuple, list)): sample = sample[0] if not isinstance(sample, Tensor): msg = "{self.__class__.__name__}.to_tensor() is not returning a tensor when run with " msg += "data from {dataset.__class__.__name__}." raise RuntimeError(msg) kwargs.update({"shape": sample.shape}) original_init(self, *args, **kwargs) self.criterion = self._criterion() self.optimizer = self._optimizer() cls.__init__ = wrapped_init def default_to_tensor(data_dict): if isinstance(data_dict.get("image"), Tensor): if "label" in data_dict: return (data_dict["image"], data_dict["label"]) else: return data_dict["image"] else: msg = "Hyrax couldn't find an image in the data dictionaries from your dataset.\n" msg += f"We recommend you implement a function on {cls.__name__} to unpack the appropriate\n" msg += "value(s) from the dictionary your dataset is returning:\n\n" msg += f"class {cls.__name__}:\n\n" msg += " @staticmethod\n" msg += " def to_tensor(data_dict) -> Tensor:\n" msg += " <Your implementation goes here>\n\n" raise RuntimeError(msg) if not hasattr(cls, "to_tensor"): cls.to_tensor = staticmethod(default_to_tensor) if not isinstance(vars(cls)["to_tensor"], staticmethod): msg = f"You must implement to_tensor() in {cls.__name__} as\n\n" msg += "@staticmethod\n" msg += "to_tensor(data_dict: dict) -> torch.Tensor:\n" msg += " <Your implementation goes here>\n" raise RuntimeError(msg) required_methods = ["train_step", "forward", "__init__", "to_tensor"] for name in required_methods: if not hasattr(cls, name): logger.error(f"Hyrax model {cls.__name__} missing required method {name}.") update_registry(MODEL_REGISTRY, cls.__name__, cls) return cls
[docs] def fetch_model_class(runtime_config: dict) -> type[nn.Module]: """Fetch the model class from the model registry. Parameters ---------- runtime_config : dict The runtime configuration dictionary. Returns ------- type The model class. Raises ------ ValueError If a built in model was requested, but not found in the model registry. ValueError If no model was specified in the runtime configuration. """ model_config = runtime_config["model"] model_cls = None try: model_cls = cast(type[nn.Module], get_or_load_class(model_config, MODEL_REGISTRY)) except ValueError as exc: raise ValueError("Error fetching model class") from exc return model_cls