Source code for hyrax.models.model_registry

import logging
from pathlib import Path
from typing import Any, cast

import numpy as np
import torch.nn as nn

from hyrax.plugin_utils import get_or_load_class, load_to_tensor, save_to_tensor, update_registry

# Logger shared by every helper in this module.
logger = logging.getLogger(__name__)

# Registered model classes keyed by name; entries are added by the ``hyrax_model``
# decorator (through ``update_registry``) and looked up by ``fetch_model_class``.
MODEL_REGISTRY: dict[str, type[nn.Module]] = dict()
def _torch_save(self: nn.Module, save_path: Path):
    """Serialize the model's weights and record its ``to_tensor`` function.

    The model's ``state_dict`` is written to ``save_path`` with ``torch.save``; afterwards
    ``save_to_tensor`` is given the model's ``to_tensor`` callable and the same path.

    Parameters
    ----------
    save_path : Path
        Destination file for the serialized weights.
    """
    import torch

    weights = self.state_dict()
    torch.save(weights, save_path)

    # Persist the data-unpacking function alongside the weights so it can be restored on load.
    save_to_tensor(self.to_tensor, save_path)
def _torch_load(self: nn.Module, load_path: Path):
    """Load weights from ``load_path`` into the model and restore its ``to_tensor`` function.

    Weights are loaded with ``torch.load(..., weights_only=True)`` onto the device reported by
    ``ignite.distributed.device()`` and applied with ``load_state_dict(..., assign=True)``.
    ``load_to_tensor`` is then queried with the directory containing ``load_path``; if it
    yields nothing, a warning is logged and the model keeps its current ``to_tensor``.

    Parameters
    ----------
    load_path : Path
        File containing the serialized ``state_dict``.
    """
    import ignite.distributed as idist
    import torch

    # ignite's device detection accounts for distributed training and device availability,
    # which lets a model trained on GPU be loaded on a CPU-only machine.
    target_device = idist.device()
    weights = torch.load(load_path, weights_only=True, map_location=target_device)
    self.load_state_dict(weights, assign=True)

    restored = load_to_tensor(load_path.parent)
    if not restored:
        logger.warning(
            f"Could not find to_tensor function in {load_path}. Using the model's existing to_tensor method."
        )
        return

    # NOTE(review): this stores a staticmethod object as an *instance* attribute; calling it
    # directly relies on staticmethod objects being callable (Python 3.10+) — confirm the
    # minimum supported Python version.
    self.to_tensor = restored if isinstance(restored, staticmethod) else staticmethod(restored)
def _torch_criterion(self: nn.Module):
    """Load the criterion class using the name defined in the config and instantiate it with the arguments defined in the config.

    The class is resolved from ``self.config["criterion"]["name"]`` with
    ``get_or_load_class``. If the config also has a top-level entry keyed by that same
    name, its contents are passed as keyword arguments; otherwise no arguments are passed.

    Returns
    -------
    object or None
        The instantiated criterion, or ``None`` (after a warning) when the configured
        criterion name is empty.
    """
    config = cast(dict[str, Any], self.config)

    name = config["criterion"]["name"]
    if not name:
        logger.warning("No criterion specified in config or self.criterion in model.")
        return None

    criterion_cls = get_or_load_class(name)
    kwargs = config[name] if name in config else {}

    # Report which criterion was chosen and which arguments it will receive.
    suffix = f"with arguments: {kwargs}." if kwargs else "with default arguments."
    logger.info(f"Setting model's self.criterion from config: {name} " + suffix)

    return criterion_cls(**kwargs)
def _torch_optimizer(self: nn.Module):
    """Load the optimizer class using the name defined in the config and instantiate it with the arguments defined in the config.

    The class is resolved from ``self.config["optimizer"]["name"]`` with
    ``get_or_load_class`` and constructed with the model's ``parameters()``. If the config
    also has a top-level entry keyed by that same name, its contents are passed as extra
    keyword arguments.

    Returns
    -------
    object or None
        The instantiated optimizer, or ``None`` (after a warning) when the configured
        optimizer name is empty.
    """
    config = cast(dict[str, Any], self.config)

    name = config["optimizer"]["name"]
    if not name:
        logger.warning("No optimizer specified in config or self.optimizer in model.")
        return None

    optimizer_cls = get_or_load_class(name)
    kwargs = config[name] if name in config else {}

    # Report which optimizer was chosen and which arguments it will receive.
    suffix = f"with arguments: {kwargs}." if kwargs else "with default arguments."
    logger.info(f"Setting model's self.optimizer from config: {name} " + suffix)

    return optimizer_cls(self.parameters(), **kwargs)
def hyrax_model(cls):
    """Decorator to register a model with the model registry, and to add common interface functions.

    For ``torch.nn.Module`` subclasses, ``save``/``load`` methods are attached and
    ``__init__`` is wrapped. After the original constructor runs, the wrapper builds
    ``self.optimizer`` and ``self.criterion`` from the config, unless the model already
    defined them; if the model did define them, the model's objects are kept and logged.

    For every class:

    - a default ``to_tensor`` staticmethod is installed when none exists; it returns
      ``(image, label)`` from ``data_dict["data"]``;
    - missing required methods are logged as errors;
    - the class is registered in ``MODEL_REGISTRY`` under ``cls.__name__``.

    Parameters
    ----------
    cls : type
        The model class being decorated.

    Returns
    -------
    type
        The class with additional interface functions.

    Raises
    ------
    RuntimeError
        If ``to_tensor`` exists (defined on ``cls`` or inherited) but is not a staticmethod.
    """
    import functools
    import inspect

    if issubclass(cls, nn.Module):
        cls.save = _torch_save
        cls.load = _torch_load

        original_init = cls.__init__

        # functools.wraps keeps the original __init__'s name/docstring and exposes its
        # signature (via __wrapped__) to introspection tools.
        @functools.wraps(original_init)
        def wrapped_init(self, config, *args, **kwargs):
            original_init(self, config, *args, **kwargs)

            # An optimizer set by the model's own __init__ takes precedence over the config.
            if not hasattr(self, "optimizer"):
                self.optimizer = _torch_optimizer(self)
            else:
                if config["optimizer"]["name"]:
                    logger.warning(
                        "Both model and config define an optimizer. "
                        "Hyrax will use self.optimizer defined in the model."
                    )
                opt_name = f"{type(self.optimizer).__module__}.{type(self.optimizer).__qualname__}"
                logger.info(f"Using self.optimizer defined in model: {opt_name}")

            # Same precedence rule for the criterion.
            if not hasattr(self, "criterion"):
                self.criterion = _torch_criterion(self)
            else:
                if config["criterion"]["name"]:
                    logger.warning(
                        "Both model and config define a criterion. "
                        "Hyrax will use self.criterion defined in the model."
                    )
                crit_name = f"{type(self.criterion).__module__}.{type(self.criterion).__qualname__}"
                logger.info(f"Using self.criterion defined in model: {crit_name}")

        cls.__init__ = wrapped_init

    def default_to_tensor(data_dict):
        # Fallback used only when the class provides no to_tensor of its own.
        if "data" not in data_dict:
            msg = "Hyrax couldn't find a 'data' key in the data dictionaries from your dataset.\n"
            msg += f"We recommend you implement a function on {cls.__name__} to unpack the appropriate\n"
            msg += "value(s) from the dictionary your dataset is returning:\n\n"
            msg += f"class {cls.__name__}:\n\n"
            msg += " @staticmethod\n"
            msg += " def to_tensor(data_dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += " <Your implementation goes here>\n\n"
            raise RuntimeError(msg)
        data = data_dict.get("data")
        image = data.get("image", np.array([]))
        label = data.get("label", np.array([]))
        return (image, label)

    if not hasattr(cls, "to_tensor"):
        cls.to_tensor = staticmethod(default_to_tensor)

    # getattr_static walks the MRO without invoking descriptors, so it returns the raw
    # staticmethod object even when to_tensor is inherited. vars(cls) only sees attributes
    # defined directly on cls and would raise KeyError for an inherited to_tensor.
    if not isinstance(inspect.getattr_static(cls, "to_tensor"), staticmethod):
        msg = f"You must implement to_tensor() in {cls.__name__} as\n\n"
        msg += "@staticmethod\n"
        msg += "to_tensor(data_dict: dict) -> torch.Tensor:\n"
        msg += " <Your implementation goes here>\n"
        raise RuntimeError(msg)

    required_methods = ["train_step", "forward", "__init__", "to_tensor"]
    for name in required_methods:
        if not hasattr(cls, name):
            logger.error(f"Hyrax model {cls.__name__} missing required method {name}.")

    update_registry(MODEL_REGISTRY, cls.__name__, cls)
    return cls
def fetch_model_class(runtime_config: dict) -> type[nn.Module]:
    """Fetch the model class from the model registry.

    Parameters
    ----------
    runtime_config : dict
        The runtime configuration dictionary. ``runtime_config["model"]["name"]`` holds
        either the name of a registered model (e.g. ``'HyraxCNN'``) or a dotted path to a
        custom model class (e.g. ``'my_package.my_module.MyModelClass'``).

    Returns
    -------
    type
        The model class.

    Raises
    ------
    RuntimeError
        If no model name was specified in the runtime configuration.

    Notes
    -----
    A non-empty name is resolved by ``get_or_load_class`` against ``MODEL_REGISTRY``.
    Any error for a name that cannot be resolved propagates from that call unchanged.
    """
    model_name = runtime_config["model"]["name"]

    if not model_name:
        model_list = "\n".join(f" - {model}" for model in sorted(MODEL_REGISTRY))
        logger.error(
            "No model name was provided in the configuration. "
            "You must specify a model to use before running Hyrax.\n\n"
            "To set a model, use: h.set_config('model.name', '<model_name>')\n"
            "<model_name> can be one of the following registered models or a path to a custom model class "
            "e.g. 'HyraxCNN' or 'my_package.my_module.MyModelClass'.\n\n"
            f"Currently registered models:\n{model_list}"
        )
        raise RuntimeError(
            "A model class name or path must be provided. "
            "e.g. 'HyraxCNN' or 'my_package.my_module.MyModelClass'."
        )

    return cast(type[nn.Module], get_or_load_class(model_name, MODEL_REGISTRY))