"""Source code for hyrax.models.model_registry."""

import logging
from pathlib import Path
from typing import Any, cast

import numpy as np
import torch.nn as nn
import torch.optim as optim

from hyrax.plugin_utils import (
    get_or_load_class,
    load_prepare_inputs,
    load_to_tensor,
    save_prepare_inputs,
    update_registry,
)
from hyrax.trace import get_trace

# Per-module logger so this module's messages can be filtered/configured by name.
logger = logging.getLogger(name=__name__)

# Model classes keyed by class name. ``hyrax_model`` adds entries through
# ``update_registry`` and ``fetch_model_class`` passes this mapping to
# ``get_or_load_class`` when resolving a configured model name.
MODEL_REGISTRY: dict[str, type[nn.Module]] = dict()
def _torch_save(self: nn.Module, save_path: Path):
    """Save the model's weights and its input-preparation function.

    The ``state_dict`` is written to ``save_path`` with ``torch.save``. Then:

    * for legacy models (class flagged with ``_has_legacy_to_tensor``), the
      source of the static ``to_tensor`` method is renamed to
      ``prepare_inputs`` and written to ``prepare_inputs.py`` next to the
      weights, with ``numpy``/``torch`` imports prepended;
    * otherwise ``self.prepare_inputs`` is handed to ``save_prepare_inputs``.

    Parameters
    ----------
    save_path : Path
        Destination file for the weights. A ``str`` is also accepted. Its parent
        directory receives ``prepare_inputs.py`` in the legacy case.
    """
    import inspect
    import textwrap

    import torch

    # Accept plain strings as well as Path objects; Path(Path) is a no-op.
    save_path = Path(save_path)

    # Save the model weights.
    torch.save(self.state_dict(), save_path)

    model_cls = type(self)
    if hasattr(model_cls, "_has_legacy_to_tensor"):
        # ``inspect.getattr_static`` walks the MRO without invoking the descriptor
        # protocol, so it returns the raw ``staticmethod`` object even when
        # ``to_tensor`` is defined on a base class. ``vars(cls)[...]`` only sees the
        # class's own namespace and raised KeyError for such subclasses.
        to_tensor_func = inspect.getattr_static(model_cls, "to_tensor").__func__
        # NOTE(review): getsource() of a decorated function normally includes its
        # decorator lines (e.g. ``@staticmethod``); confirm load_prepare_inputs
        # copes with that in the written module.
        to_tensor_source = inspect.getsource(to_tensor_func)

        # Rename only the first occurrence, i.e. the function definition itself.
        prepared_source = to_tensor_source.replace("def to_tensor(", "def prepare_inputs(", 1)

        # Prepend imports that prepare_inputs/to_tensor functions typically need so
        # the saved file can be loaded independently.
        imports = "import numpy as np\nimport torch\n\n"
        # Explicit encoding: the platform default varies and source text may
        # contain non-ASCII characters.
        with open(save_path.parent / "prepare_inputs.py", "w", encoding="utf-8") as f:
            f.write(imports + textwrap.dedent(prepared_source))

        logger.warning(
            "Model uses deprecated to_tensor method. Saving as prepare_inputs for forward compatibility. "
            "Please rename to_tensor to prepare_inputs in your model class to ensure reproducibility."
        )
    else:
        # Modern model with prepare_inputs - save as-is.
        save_prepare_inputs(self.prepare_inputs, save_path)
def _torch_load(self: nn.Module, load_path: Path):
    """Restore the model's weights and its matching ``prepare_inputs`` function.

    The ``state_dict`` at ``load_path`` is loaded with ``weights_only=True``
    onto the device reported by ``ignite.distributed.device()``. The
    input-preparation function is then looked up in ``load_path.parent``:
    first via ``load_prepare_inputs``, then via the legacy ``load_to_tensor``.
    When neither returns a function, the model keeps its current
    ``prepare_inputs`` and a warning is logged.

    If ``prepare_inputs`` was replaced and tracing is enabled, the new function
    is instrumented through the active trace.

    Parameters
    ----------
    load_path : Path
        File containing the saved ``state_dict``.
    """
    import ignite.distributed as idist
    import torch

    model_dir = load_path.parent

    # ignite resolves the correct device for distributed and single-device runs.
    weights = torch.load(load_path, weights_only=True, map_location=idist.device())
    self.load_state_dict(weights)

    # Remember what was attached before, to decide later whether to instrument.
    previous_fn = getattr(self, "prepare_inputs", None)

    loaded_fn = load_prepare_inputs(model_dir)
    if loaded_fn:
        # Preferred path: a saved prepare_inputs.py was found.
        self.prepare_inputs = loaded_fn
    elif legacy_fn := load_to_tensor(model_dir):
        # Backward compatibility with models saved using the old to_tensor name.
        logger.warning(
            f"Found to_tensor function in {model_dir}. "
            "to_tensor is deprecated, please re-save your model with the new version "
            "to use prepare_inputs."
        )
        self.prepare_inputs = legacy_fn
    else:
        logger.warning(
            f"Could not find prepare_inputs or to_tensor function in {model_dir}. "
            "Using the model's existing methods."
        )

    # Only a newly attached function needs instrumenting; when tracing is on and
    # prepare_inputs is unchanged, it's already instrumented.
    if previous_fn != getattr(self, "prepare_inputs", None):
        trace = get_trace()
        if trace:
            trace.instrument_prepare_inputs(self)
def _torch_criterion(self: nn.Module):
    """Instantiate the criterion named in ``self.config["criterion"]["name"]``.

    The class is resolved with ``get_or_load_class``. If the config has a
    top-level entry keyed by that same name, its contents are passed as keyword
    arguments; otherwise the class is built with its defaults.

    Returns
    -------
    object or None
        The criterion instance, or ``None`` (with a warning) when the configured
        name is empty/falsy.
    """
    config = cast(dict[str, Any], self.config)

    criterion_name = config["criterion"]["name"]
    if not criterion_name:
        logger.warning("No criterion specified in config or self.criterion in model.")
        return None

    criterion_cls = get_or_load_class(criterion_name)
    arguments = config[criterion_name] if criterion_name in config else {}

    # Report which criterion was chosen and how it is parameterized.
    detail = f"with arguments: {arguments}." if arguments else "with default arguments."
    logger.info(f"Setting model's self.criterion from config: {criterion_name} {detail}")

    return criterion_cls(**arguments)
def _torch_optimizer(self: nn.Module):
    """Instantiate the optimizer named in ``self.config["optimizer"]["name"]``.

    The class is resolved with ``get_or_load_class`` and constructed over
    ``self.parameters()``. If the config has a top-level entry keyed by the
    optimizer's name, its contents are passed as keyword arguments; otherwise
    the class's defaults are used.

    Returns
    -------
    object or None
        The optimizer instance, or ``None`` (with a warning) when the configured
        name is empty/falsy.
    """
    config = cast(dict[str, Any], self.config)

    optimizer_name = config["optimizer"]["name"]
    if not optimizer_name:
        logger.warning("No optimizer specified in config or self.optimizer in model.")
        return None

    optimizer_cls = get_or_load_class(optimizer_name)
    arguments = config[optimizer_name] if optimizer_name in config else {}

    # Report which optimizer was chosen and how it is parameterized.
    detail = f"with arguments: {arguments}." if arguments else "with default arguments."
    logger.info(f"Setting model's self.optimizer from config: {optimizer_name} {detail}")

    return optimizer_cls(self.parameters(), **arguments)
def _torch_schedulers(self: nn.Module):
    """Instantiate the scheduler named in ``self.config["scheduler"]["name"]``.

    The class is resolved with ``get_or_load_class`` and constructed around
    ``self.optimizer``. If the config has a top-level entry keyed by the
    scheduler's name, its contents are passed as keyword arguments; otherwise
    the class's defaults are used.

    Returns
    -------
    object or None
        The scheduler instance, or ``None`` (with a warning) when the configured
        name is empty/falsy.

    Raises
    ------
    RuntimeError
        If ``self.optimizer`` is not a ``torch.optim.Optimizer``.
    """
    config = cast(dict[str, Any], self.config)

    # Load the class and get any parameters from the config dictionary.
    scheduler_name = config["scheduler"]["name"]
    if not scheduler_name:
        logger.warning("No scheduler specified in config or self.scheduler in model.")
        return None

    scheduler_cls = get_or_load_class(scheduler_name)

    # The scheduler wraps self.optimizer. Validate it before logging the
    # "Setting ..." message so a failure isn't preceded by a success-looking log line.
    if not isinstance(self.optimizer, optim.Optimizer):
        raise RuntimeError("Model optimizer must be a torch.optim.Optimizer")

    arguments = {}
    if scheduler_name in config:
        arguments = config[scheduler_name]

    # Single-line log record: a space (not a newline) separates the name from the
    # argument summary, matching the criterion/optimizer helpers.
    log_string = f"Setting model's self.scheduler from config: {scheduler_name} "
    if arguments:
        log_string += f"with arguments: {arguments}."
    else:
        log_string += "with default arguments."
    logger.info(log_string)

    return scheduler_cls(self.optimizer, **arguments)
def hyrax_model(cls):
    """Decorator to register a model with the model registry, and to add common interface functions.

    For ``torch.nn.Module`` subclasses this:

    * attaches ``save``/``load`` implementations;
    * wraps ``__init__`` so that, after the original initializer runs,
      ``self.optimizer``, ``self.criterion`` and ``self.scheduler`` are built from
      the config unless the model already defined them. If both the model and
      the config define one, a warning is logged and the model's own object wins.

    For every decorated class it also:

    * ensures a static ``prepare_inputs`` exists, aliasing a legacy static
      ``to_tensor`` or falling back to a default that returns
      ``(data["image"], data["label"])`` from ``data_dict["data"]``;
    * logs an error for each missing required method;
    * registers the class in ``MODEL_REGISTRY`` under ``cls.__name__``.

    Parameters
    ----------
    cls : type
        The model class being decorated.

    Returns
    -------
    type
        The class with additional interface functions.

    Raises
    ------
    RuntimeError
        If ``prepare_inputs`` (or, failing that, ``to_tensor``) is found on the
        class but is not a ``staticmethod``.
    """
    import functools
    import inspect

    if issubclass(cls, nn.Module):
        cls.save = _torch_save
        cls.load = _torch_load

        original_init = cls.__init__

        # functools.wraps keeps the original __init__'s name, docstring and
        # (via __wrapped__) its signature visible to introspection.
        @functools.wraps(original_init)
        def wrapped_init(self, config, *args, **kwargs):
            original_init(self, config, *args, **kwargs)

            # The _torch_* helpers read self.config; presumably the original
            # __init__ is expected to set it. TODO confirm for all models.
            if not hasattr(self, "optimizer"):
                self.optimizer = _torch_optimizer(self)
            else:
                if config["optimizer"]["name"]:
                    logger.warning(
                        "Both model and config define an optimizer. "
                        "Hyrax will use self.optimizer defined in the model."
                    )
                opt_name = f"{type(self.optimizer).__module__}.{type(self.optimizer).__qualname__}"
                logger.info(f"Using self.optimizer defined in model: {opt_name}")

            if not hasattr(self, "criterion"):
                self.criterion = _torch_criterion(self)
            else:
                if config["criterion"]["name"]:
                    logger.warning(
                        "Both model and config define a criterion. "
                        "Hyrax will use self.criterion defined in the model."
                    )
                crit_name = f"{type(self.criterion).__module__}.{type(self.criterion).__qualname__}"
                logger.info(f"Using self.criterion defined in model: {crit_name}")

            # Built last: the config-driven scheduler wraps self.optimizer.
            if not hasattr(self, "scheduler"):
                self.scheduler = _torch_schedulers(self)
            else:
                if config["scheduler"]["name"]:
                    logger.warning(
                        "Both model and config define a scheduler. "
                        "Hyrax will use self.scheduler defined in the model."
                    )
                sched_name = f"{type(self.scheduler).__module__}.{type(self.scheduler).__qualname__}"
                logger.info(f"Using self.scheduler defined in model: {sched_name}")

        cls.__init__ = wrapped_init

    def default_prepare_inputs(data_dict):
        # Fallback used when the class defines neither prepare_inputs nor to_tensor.
        if "data" not in data_dict:
            msg = "Hyrax couldn't find a 'data' key in the data dictionaries from your dataset.\n"
            msg += f"We recommend you implement a function on {cls.__name__} to unpack the appropriate\n"
            msg += "value(s) from the dictionary your dataset is returning:\n\n"
            msg += f"class {cls.__name__}:\n\n"
            msg += " @staticmethod\n"
            msg += " def prepare_inputs(data_dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += " <Your implementation goes here>\n\n"
            raise RuntimeError(msg)
        data = data_dict.get("data")
        image = data.get("image", np.array([]))
        label = data.get("label", np.array([]))
        return (image, label)

    # Check if the class has prepare_inputs or to_tensor. hasattr() follows the
    # MRO, so the raw attribute is fetched with inspect.getattr_static (which
    # also follows the MRO but does not invoke descriptors). vars(cls)[...] only
    # saw cls's own namespace and raised KeyError for inherited definitions.
    has_prepare_inputs = hasattr(cls, "prepare_inputs")
    has_to_tensor = hasattr(cls, "to_tensor")

    if has_prepare_inputs:
        # Model has the new prepare_inputs method - ensure it's a staticmethod.
        if not isinstance(inspect.getattr_static(cls, "prepare_inputs"), staticmethod):
            msg = f"You must implement prepare_inputs() in {cls.__name__} as\n\n"
            msg += "@staticmethod\n"
            msg += "def prepare_inputs(data_dict: dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += " <Your implementation goes here>\n"
            raise RuntimeError(msg)
    elif has_to_tensor:
        # Model only has the old to_tensor method - make prepare_inputs an alias.
        to_tensor_attr = inspect.getattr_static(cls, "to_tensor")
        if not isinstance(to_tensor_attr, staticmethod):
            msg = f"You must rename to_tensor() to prepare_inputs() in {cls.__name__} as\n\n"
            msg += "@staticmethod\n"
            msg += "def prepare_inputs(data_dict: dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += " <Your implementation goes here>\n"
            raise RuntimeError(msg)
        # Re-wrap the underlying function so the alias is also a staticmethod.
        cls.prepare_inputs = staticmethod(to_tensor_attr.__func__)
        # Tells _torch_save to write the to_tensor source renamed to prepare_inputs.
        cls._has_legacy_to_tensor = True
    else:
        # No method defined - use the default.
        cls.prepare_inputs = staticmethod(default_prepare_inputs)

    # Missing methods are reported but do not prevent registration.
    required_methods = ["train_batch", "infer_batch", "__init__", "prepare_inputs"]
    for name in required_methods:
        if not hasattr(cls, name):
            logger.error(f"Hyrax model {cls.__name__} missing required method {name}.")

    update_registry(MODEL_REGISTRY, cls.__name__, cls)
    return cls
def fetch_model_class(runtime_config: dict) -> type[nn.Module]:
    """Fetch the model class named by ``runtime_config["model"]["name"]``.

    The name may be a model registered in ``MODEL_REGISTRY`` (e.g. ``'HyraxCNN'``)
    or a path to a custom model class (e.g.
    ``'my_package.my_module.MyModelClass'``). Resolution is delegated to
    ``get_or_load_class``.

    Parameters
    ----------
    runtime_config : dict
        The runtime configuration dictionary.

    Returns
    -------
    type
        The model class.

    Raises
    ------
    RuntimeError
        If no model name was specified in the runtime configuration. The list of
        registered models is logged at error level first.

    Notes
    -----
    Failures for an unknown or unimportable name are raised by
    ``get_or_load_class``; their exception type is not determined here.
    """
    # Any falsy name ("" / None) means no model was configured.
    model_name = runtime_config["model"]["name"]

    if not model_name:
        model_list = "\n".join(f" - {model}" for model in sorted(MODEL_REGISTRY))
        logger.error(
            "No model name was provided in the configuration. "
            "You must specify a model to use before running Hyrax.\n\n"
            "To set a model, use: h.set_config('model.name', '<model_name>')\n"
            "<model_name> can be one of the following registered models or a path to a custom model class "
            "e.g. 'HyraxCNN' or 'my_package.my_module.MyModelClass'.\n\n"
            f"Currently registered models:\n{model_list}"
        )
        raise RuntimeError(
            "A model class name or path must be provided. "
            "e.g. 'HyraxCNN' or 'my_package.my_module.MyModelClass'."
        )

    return cast(type[nn.Module], get_or_load_class(model_name, MODEL_REGISTRY))