Source code for hyrax.models.model_registry

import logging
from pathlib import Path
from typing import Any, cast

import numpy as np
import torch.nn as nn
import torch.optim as optim

from hyrax.plugin_utils import (
    get_or_load_class,
    load_prepare_inputs,
    load_to_tensor,
    save_prepare_inputs,
    update_registry,
)
from hyrax.trace import get_trace

# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)

# Maps a registered model's class name to the model class. Entries are added
# by the ``hyrax_model`` decorator through ``update_registry`` and looked up
# by ``fetch_model_class``.
MODEL_REGISTRY: dict[str, type[nn.Module]] = {}
def _torch_save(self: nn.Module, save_path: Path):
    """Save the model weights and the model's ``prepare_inputs`` function.

    The model's ``state_dict`` is written to ``save_path`` with ``torch.save``.
    If the model class carries the ``_has_legacy_to_tensor`` marker, the source
    code of its ``to_tensor`` static method is written to ``prepare_inputs.py``
    in the same directory as ``save_path``, with the function renamed to
    ``prepare_inputs``. Otherwise ``save_prepare_inputs`` is called with the
    model's ``prepare_inputs`` and ``save_path``.

    Parameters
    ----------
    self : nn.Module
        The model instance being saved.
    save_path : Path
        Destination file for the model weights.
    """
    import inspect
    import textwrap

    import torch

    # save the model weights
    torch.save(self.state_dict(), save_path)

    # If model originally had to_tensor, save it with renamed function
    if hasattr(self.__class__, "_has_legacy_to_tensor"):
        # Fetch the raw staticmethod descriptor. getattr_static searches the
        # whole MRO, so this also works when to_tensor (and the legacy marker)
        # are inherited from a base class; vars(cls)["to_tensor"] would raise
        # KeyError in that case.
        to_tensor_staticmethod = inspect.getattr_static(self.__class__, "to_tensor")
        to_tensor_func = to_tensor_staticmethod.__func__
        to_tensor_source = inspect.getsource(to_tensor_func)

        # Replace function name from "to_tensor" to "prepare_inputs"
        prepared_source = to_tensor_source.replace("def to_tensor(", "def prepare_inputs(", 1)

        # Add common imports that prepare_inputs/to_tensor functions typically need
        # This ensures the saved function can be loaded independently
        imports = "import numpy as np\nimport torch\n\n"

        # Python source files are decoded as UTF-8 by default, so write the
        # file as UTF-8 explicitly instead of relying on the locale encoding.
        with open(save_path.parent / "prepare_inputs.py", "w", encoding="utf-8") as f:
            f.write(imports + textwrap.dedent(prepared_source))

        logger.warning(
            "Model uses deprecated to_tensor method. Saving as prepare_inputs for forward compatibility. "
            "Please rename to_tensor to prepare_inputs in your model class to ensure reproducibility."
        )
    else:
        # Modern model with prepare_inputs - save as-is
        save_prepare_inputs(self.prepare_inputs, save_path)
def _torch_load(self: nn.Module, load_path: Path):
    """Restore model weights and the saved ``prepare_inputs`` function.

    Weights are read from ``load_path`` and assigned to the model. The
    directory containing ``load_path`` is then searched for a saved
    ``prepare_inputs`` function, falling back to a saved ``to_tensor``
    function. Whichever is found replaces ``self.prepare_inputs``. If neither
    is found, the model keeps its existing methods and a warning is logged.

    Parameters
    ----------
    self : nn.Module
        The model instance to load into.
    load_path : Path
        File containing the saved model weights.
    """
    import ignite.distributed as idist
    import torch

    # Use ignite's device detection which handles distributed training and device availability
    target_device = idist.device()
    weights = torch.load(load_path, weights_only=True, map_location=target_device)
    self.load_state_dict(weights, assign=True)

    # New-style name first; the legacy loader is only consulted if this fails.
    loaded_fn = load_prepare_inputs(load_path.parent)
    previous_fn = getattr(self, "prepare_inputs", None)

    if loaded_fn:
        # Successfully loaded prepare_inputs.py
        self.prepare_inputs = loaded_fn
    elif legacy_fn := load_to_tensor(load_path.parent):
        # Backward compatibility: loading old model with to_tensor.py
        logger.warning(
            f"Found to_tensor function in {load_path.parent}. "
            "to_tensor is deprecated, please re-save your model with the new version "
            "to use prepare_inputs."
        )
        self.prepare_inputs = legacy_fn
    else:
        logger.warning(
            f"Could not find prepare_inputs or to_tensor function in {load_path.parent}. "
            "Using the model's existing methods."
        )

    # Only instrument when prepare_inputs was actually replaced; an unchanged
    # function is already instrumented if tracing is on.
    if not (previous_fn != getattr(self, "prepare_inputs", None)):
        return

    active_trace = get_trace()
    if active_trace:
        active_trace.instrument_prepare_inputs(self)
def _torch_criterion(self: nn.Module):
    """Instantiate the criterion named in the model's config.

    The class name is read from ``config["criterion"]["name"]`` and resolved
    with ``get_or_load_class``. Constructor keyword arguments are taken from
    the config entry whose key equals the criterion name, when such an entry
    exists.

    Returns
    -------
    object or None
        The instantiated criterion, or ``None`` (after logging a warning)
        when no criterion name is configured.
    """
    settings = cast(dict[str, Any], self.config)

    name = settings["criterion"]["name"]
    if not name:
        logger.warning("No criterion specified in config or self.criterion in model.")
        return None

    criterion_cls = get_or_load_class(name)
    kwargs = settings[name] if name in settings else {}

    # Report which criterion is used and with what parameters.
    suffix = f"with arguments: {kwargs}." if kwargs else "with default arguments."
    logger.info(f"Setting model's self.criterion from config: {name} " + suffix)

    return criterion_cls(**kwargs)
def _torch_optimizer(self: nn.Module):
    """Instantiate the optimizer named in the model's config.

    The class name is read from ``config["optimizer"]["name"]`` and resolved
    with ``get_or_load_class``. Constructor keyword arguments are taken from
    the config entry whose key equals the optimizer name, when such an entry
    exists. The optimizer is constructed over ``self.parameters()``.

    Returns
    -------
    object or None
        The instantiated optimizer, or ``None`` (after logging a warning)
        when no optimizer name is configured.
    """
    settings = cast(dict[str, Any], self.config)

    name = settings["optimizer"]["name"]
    if not name:
        logger.warning("No optimizer specified in config or self.optimizer in model.")
        return None

    optimizer_cls = get_or_load_class(name)
    kwargs = settings[name] if name in settings else {}

    # Report which optimizer is used and with what parameters.
    suffix = f"with arguments: {kwargs}." if kwargs else "with default arguments."
    logger.info(f"Setting model's self.optimizer from config: {name} " + suffix)

    return optimizer_cls(self.parameters(), **kwargs)
def _torch_schedulers(self: nn.Module):
    """Instantiate the learning-rate scheduler named in the model's config.

    The class name is read from ``config["scheduler"]["name"]`` and resolved
    with ``get_or_load_class``. Constructor keyword arguments are taken from
    the config entry whose key equals the scheduler name, when such an entry
    exists. The scheduler is constructed around ``self.optimizer``.

    Returns
    -------
    object or None
        The instantiated scheduler, or ``None`` (after logging a warning)
        when no scheduler name is configured.

    Raises
    ------
    RuntimeError
        If a scheduler is configured but ``self.optimizer`` is not a
        ``torch.optim.Optimizer``.
    """
    settings = cast(dict[str, Any], self.config)

    name = settings["scheduler"]["name"]
    if not name:
        logger.warning("No scheduler specified in config or self.scheduler in model.")
        return None

    scheduler_cls = get_or_load_class(name)
    kwargs = settings[name] if name in settings else {}

    # Report which scheduler is used and with what parameters.
    suffix = f"with arguments: {kwargs}." if kwargs else "with default arguments."
    logger.info(f"Setting model's self.scheduler from config: {name}\n" + suffix)

    # A scheduler can only wrap a real torch optimizer.
    if not isinstance(self.optimizer, optim.Optimizer):
        raise RuntimeError("Model optimizer must be a torch.optim.Optimizer")

    return scheduler_cls(self.optimizer, **kwargs)
def hyrax_model(cls):
    """Decorator to register a model with the model registry, and to add common interface functions

    The decorator:

    * attaches ``save`` and ``load`` to ``nn.Module`` subclasses,
    * wraps ``__init__`` so that ``self.optimizer``, ``self.criterion`` and
      ``self.scheduler`` are created from the config when the model's own
      ``__init__`` did not set them,
    * ensures the class has a static ``prepare_inputs`` method, aliasing a
      legacy ``to_tensor`` or installing a default implementation,
    * logs an error for each missing required method, and
    * adds the class to ``MODEL_REGISTRY`` under its class name.

    Parameters
    ----------
    cls : type
        The model class being decorated.

    Returns
    -------
    type
        The class with additional interface functions.

    Raises
    ------
    RuntimeError
        If ``prepare_inputs`` (or the legacy ``to_tensor``) exists on the
        class but is not a ``staticmethod``.
    """
    import functools
    import inspect

    if issubclass(cls, nn.Module):
        cls.save = _torch_save
        cls.load = _torch_load

    original_init = cls.__init__

    # functools.wraps keeps the model's own __init__ name, docstring and
    # signature visible to introspection tools.
    @functools.wraps(original_init)
    def wrapped_init(self, config, *args, **kwargs):
        original_init(self, config, *args, **kwargs)

        # For each component, anything set by the model's own __init__ takes
        # precedence over the config; the config is only a fallback.
        if not hasattr(self, "optimizer"):
            self.optimizer = _torch_optimizer(self)
        else:
            if config["optimizer"]["name"]:
                logger.warning(
                    "Both model and config define an optimizer. "
                    "Hyrax will use self.optimizer defined in the model."
                )
            opt_name = f"{type(self.optimizer).__module__}.{type(self.optimizer).__qualname__}"
            logger.info(f"Using self.optimizer defined in model: {opt_name}")

        if not hasattr(self, "criterion"):
            self.criterion = _torch_criterion(self)
        else:
            if config["criterion"]["name"]:
                logger.warning(
                    "Both model and config define a criterion. "
                    "Hyrax will use self.criterion defined in the model."
                )
            crit_name = f"{type(self.criterion).__module__}.{type(self.criterion).__qualname__}"
            logger.info(f"Using self.criterion defined in model: {crit_name}")

        # The scheduler is built last because it wraps self.optimizer.
        if not hasattr(self, "scheduler"):
            self.scheduler = _torch_schedulers(self)
        else:
            if config["scheduler"]["name"]:
                logger.warning(
                    "Both model and config define a scheduler. "
                    "Hyrax will use self.scheduler defined in the model."
                )
            sched_name = f"{type(self.scheduler).__module__}.{type(self.scheduler).__qualname__}"
            logger.info(f"Using self.scheduler defined in model: {sched_name}")

    cls.__init__ = wrapped_init

    def default_prepare_inputs(data_dict):
        """Extract image and label arrays from the batch dictionary.

        This static method is the interface between the data pipeline and the
        model. Override it on the model class to reshape or select fields from
        the collated batch to match the inputs your model expects. This is the
        default implementation used when a model does not define its own
        ``prepare_inputs``.

        Hyrax will convert the returned arrays to PyTorch tensors and move them
        to the appropriate device automatically.

        Parameters
        ----------
        data_dict : dict
            The collated batch dictionary produced by the data pipeline.
            Expected to contain a ``"data"`` key with ``"image"`` and
            optionally ``"label"`` fields.

        Returns
        -------
        inputs : tuple of numpy.ndarray
            A tuple of ``(image, label)`` arrays. A missing field is returned
            as an empty array.

        Raises
        ------
        RuntimeError
            If ``data_dict`` has no ``"data"`` key.
        """
        if "data" not in data_dict:
            msg = "Hyrax couldn't find a 'data' key in the data dictionaries from your dataset.\n"
            msg += f"We recommend you implement a function on {cls.__name__} to unpack the appropriate\n"
            msg += "value(s) from the dictionary your dataset is returning:\n\n"
            msg += f"class {cls.__name__}:\n\n"
            msg += "    @staticmethod\n"
            msg += "    def prepare_inputs(data_dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += "        <Your implementation goes here>\n\n"
            raise RuntimeError(msg)

        data = data_dict.get("data")
        image = data.get("image", np.array([]))
        label = data.get("label", np.array([]))

        return (image, label)

    # Check if the class has prepare_inputs or to_tensor
    has_prepare_inputs = hasattr(cls, "prepare_inputs")
    has_to_tensor = hasattr(cls, "to_tensor")

    if has_prepare_inputs:
        # Model has new prepare_inputs method - ensure it's a staticmethod.
        # getattr_static returns the raw descriptor and searches the MRO, so a
        # prepare_inputs inherited from a base class is handled too
        # (vars(cls)["prepare_inputs"] would raise KeyError for it).
        if not isinstance(inspect.getattr_static(cls, "prepare_inputs"), staticmethod):
            msg = f"You must implement prepare_inputs() in {cls.__name__} as\n\n"
            msg += "@staticmethod\n"
            msg += "def prepare_inputs(data_dict: dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += "    <Your implementation goes here>\n"
            raise RuntimeError(msg)
    elif has_to_tensor:
        # Model only has old to_tensor method - make prepare_inputs an alias
        to_tensor_attr = inspect.getattr_static(cls, "to_tensor")
        if not isinstance(to_tensor_attr, staticmethod):
            msg = f"You must rename to_tensor() to prepare_inputs() in {cls.__name__} as\n\n"
            msg += "@staticmethod\n"
            msg += "def prepare_inputs(data_dict: dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += "    <Your implementation goes here>\n"
            raise RuntimeError(msg)

        # Create an alias that's also a staticmethod, built from the function
        # underlying the to_tensor staticmethod descriptor.
        cls.prepare_inputs = staticmethod(to_tensor_attr.__func__)
        # Mark for save logic to know this needs function renaming
        cls._has_legacy_to_tensor = True
    else:
        # No method defined - use defaults
        cls.prepare_inputs = staticmethod(default_prepare_inputs)

    # Missing required methods are reported but do not prevent registration.
    required_methods = ["train_batch", "infer_batch", "__init__", "prepare_inputs"]
    for name in required_methods:
        if not hasattr(cls, name):
            logger.error(f"Hyrax model {cls.__name__} missing required method {name}.")

    update_registry(MODEL_REGISTRY, cls.__name__, cls)
    return cls
def fetch_model_class(runtime_config: dict) -> type[nn.Module]:
    """Fetch the model class from the model registry.

    Parameters
    ----------
    runtime_config : dict
        The runtime configuration dictionary. The model name is read from
        ``runtime_config["model"]["name"]``.

    Returns
    -------
    type
        The model class resolved by ``get_or_load_class`` from the configured
        name and ``MODEL_REGISTRY``.

    Raises
    ------
    RuntimeError
        If no model was specified in the runtime configuration. The currently
        registered models are logged before raising.
    """
    requested = runtime_config["model"]["name"]

    # Happy path: a name was configured, so resolve it and return.
    if requested:
        return cast(type[nn.Module], get_or_load_class(requested, MODEL_REGISTRY))

    # No name configured: list what is available, then fail.
    model_list = "\n".join(f" - {model}" for model in sorted(MODEL_REGISTRY.keys()))
    logger.error(
        "No model name was provided in the configuration. "
        "You must specify a model to use before running Hyrax.\n\n"
        "To set a model, use: h.set_config('model.name', '<model_name>')\n"
        "<model_name> can be one of the following registered models or a path to a custom model class "
        "e.g. 'HyraxCNN' or 'my_package.my_module.MyModelClass'.\n\n"
        f"Currently registered models:\n{model_list}"
    )
    raise RuntimeError(
        "A model class name or path must be provided. "
        "e.g. 'HyraxCNN' or 'my_package.my_module.MyModelClass'."
    )