import logging
from pathlib import Path
from typing import Any, cast
import numpy as np
import torch.nn as nn
import torch.optim as optim
from hyrax.plugin_utils import (
get_or_load_class,
load_prepare_inputs,
load_to_tensor,
save_prepare_inputs,
update_registry,
)
from hyrax.trace import get_trace
[docs]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
[docs]
# Registry of model classes keyed by name. ``hyrax_model`` adds entries (keyed by
# ``cls.__name__``) and ``fetch_model_class`` resolves configured names against it.
MODEL_REGISTRY: dict[str, type[nn.Module]] = {}
[docs]
def _torch_save(self: nn.Module, save_path: Path):
    """Save the model weights and the model's input-preparation function.

    The model's ``state_dict`` is written to ``save_path``. The function that
    prepares inputs is persisted as well:

    * If the model class carries the ``_has_legacy_to_tensor`` marker, the source
      of its ``to_tensor`` static method is written to ``prepare_inputs.py`` in
      ``save_path.parent`` with the function renamed to ``prepare_inputs``, and a
      deprecation warning is logged.
    * Otherwise ``self.prepare_inputs`` is handed to ``save_prepare_inputs``.

    Parameters
    ----------
    self : nn.Module
        The model instance being saved.
    save_path : Path
        Destination file for the model weights.

    Raises
    ------
    AttributeError
        If the legacy marker is set but no ``to_tensor`` can be found on the
        class or any of its bases.
    """
    import inspect
    import textwrap

    import torch

    # save the model weights
    torch.save(self.state_dict(), save_path)

    model_cls = self.__class__
    if not hasattr(model_cls, "_has_legacy_to_tensor"):
        # Modern model with prepare_inputs - save as-is
        save_prepare_inputs(self.prepare_inputs, save_path)
        return

    # The legacy marker is found with hasattr() and is therefore inherited by
    # subclasses, so to_tensor must be resolved along the MRO as well.
    # ``vars(model_cls)`` only sees the class's own namespace and would raise
    # KeyError for a subclass of a legacy model. getattr_static returns the raw
    # descriptor without triggering the descriptor protocol.
    to_tensor_attr = inspect.getattr_static(model_cls, "to_tensor")
    # Unwrap the staticmethod descriptor to get the underlying function.
    to_tensor_func = getattr(to_tensor_attr, "__func__", to_tensor_attr)
    to_tensor_source = inspect.getsource(to_tensor_func)

    # Replace function name from "to_tensor" to "prepare_inputs"
    prepared_source = to_tensor_source.replace("def to_tensor(", "def prepare_inputs(", 1)

    # Add common imports that prepare_inputs/to_tensor functions typically need
    # This ensures the saved function can be loaded independently
    imports = "import numpy as np\nimport torch\n\n"

    # Python source files are UTF-8 by default, so write the file explicitly as
    # UTF-8 rather than relying on the platform's locale encoding.
    with open(save_path.parent / "prepare_inputs.py", "w", encoding="utf-8") as f:
        f.write(imports + textwrap.dedent(prepared_source))

    logger.warning(
        "Model uses deprecated to_tensor method. Saving as prepare_inputs for forward compatibility. "
        "Please rename to_tensor to prepare_inputs in your model class to ensure reproducibility."
    )
[docs]
def _torch_load(self: nn.Module, load_path: Path):
    """Restore model weights and the input-preparation function from disk.

    Weights are read from ``load_path`` and mapped onto the device reported by
    ``ignite.distributed.device()``. The directory holding the weights is then
    searched for a saved ``prepare_inputs`` function, falling back to a legacy
    ``to_tensor`` function; whichever is found replaces ``self.prepare_inputs``.
    If neither is found the model keeps its existing method and a warning is
    logged.

    Parameters
    ----------
    self : nn.Module
        The model instance to load into.
    load_path : Path
        Path of the saved weights file.
    """
    import ignite.distributed as idist
    import torch

    # Let ignite choose the device; it accounts for distributed runs and availability.
    target_device = idist.device()
    weights = torch.load(load_path, weights_only=True, map_location=target_device)
    self.load_state_dict(weights)

    # Prefer the new name; remember what the model had so a change can be detected below.
    loaded_fn = load_prepare_inputs(load_path.parent)
    previous_fn = getattr(self, "prepare_inputs", None)

    if loaded_fn:
        # prepare_inputs.py was found and loaded
        self.prepare_inputs = loaded_fn
    elif legacy_fn := load_to_tensor(load_path.parent):
        # Old-style model directory that only contains to_tensor.py
        logger.warning(
            f"Found to_tensor function in {load_path.parent}. "
            "to_tensor is deprecated, please re-save your model with the new version "
            "to use prepare_inputs."
        )
        self.prepare_inputs = legacy_fn
    else:
        logger.warning(
            f"Could not find prepare_inputs or to_tensor function in {load_path.parent}. "
            "Using the model's existing methods."
        )

    # Re-instrument only when the function was actually swapped and tracing is
    # active; an untouched function has already been instrumented.
    current_fn = getattr(self, "prepare_inputs", None)
    if previous_fn != current_fn:
        tracer = get_trace()
        if tracer:
            tracer.instrument_prepare_inputs(self)
[docs]
def _torch_criterion(self: nn.Module):
    """Build the criterion named in the model's config.

    The class name is read from ``config["criterion"]["name"]``. Keyword
    arguments, if any, are read from the top-level config entry whose key is
    that same name.

    Returns
    -------
    object or None
        The instantiated criterion, or ``None`` (after logging a warning) when
        no criterion name is configured.
    """
    cfg = cast(dict[str, Any], self.config)

    criterion_name = cfg["criterion"]["name"]
    if not criterion_name:
        logger.warning("No criterion specified in config or self.criterion in model.")
        return None

    loss_cls = get_or_load_class(criterion_name)

    # Constructor arguments live under a config section named after the class.
    arguments = cfg[criterion_name] if criterion_name in cfg else {}

    # Report which criterion was chosen and how it is parameterised.
    detail = f"with arguments: {arguments}." if arguments else "with default arguments."
    logger.info(f"Setting model's self.criterion from config: {criterion_name} " + detail)

    return loss_cls(**arguments)
[docs]
def _torch_optimizer(self: nn.Module):
    """Build the optimizer named in the model's config.

    The class name is read from ``config["optimizer"]["name"]``. Keyword
    arguments, if any, are read from the top-level config entry whose key is
    that same name. The optimizer is constructed over ``self.parameters()``.

    Returns
    -------
    object or None
        The instantiated optimizer, or ``None`` (after logging a warning) when
        no optimizer name is configured.
    """
    cfg = cast(dict[str, Any], self.config)

    optimizer_name = cfg["optimizer"]["name"]
    if not optimizer_name:
        logger.warning("No optimizer specified in config or self.optimizer in model.")
        return None

    opt_cls = get_or_load_class(optimizer_name)

    # Constructor arguments live under a config section named after the class.
    arguments = cfg[optimizer_name] if optimizer_name in cfg else {}

    # Report which optimizer was chosen and how it is parameterised.
    detail = f"with arguments: {arguments}." if arguments else "with default arguments."
    logger.info(f"Setting model's self.optimizer from config: {optimizer_name} " + detail)

    return opt_cls(self.parameters(), **arguments)
[docs]
def _torch_schedulers(self: nn.Module):
    """Build the learning-rate scheduler named in the model's config.

    The class name is read from ``config["scheduler"]["name"]``. Keyword
    arguments, if any, are read from the top-level config entry whose key is
    that same name. The scheduler is constructed around ``self.optimizer``.

    Returns
    -------
    object or None
        The instantiated scheduler, or ``None`` (after logging a warning) when
        no scheduler name is configured.

    Raises
    ------
    RuntimeError
        If ``self.optimizer`` is not a ``torch.optim.Optimizer``.
    """
    cfg = cast(dict[str, Any], self.config)

    scheduler_name = cfg["scheduler"]["name"]
    if not scheduler_name:
        logger.warning("No scheduler specified in config or self.scheduler in model.")
        return None

    sched_cls = get_or_load_class(scheduler_name)

    # Constructor arguments live under a config section named after the class.
    arguments = cfg[scheduler_name] if scheduler_name in cfg else {}

    # Report which scheduler was chosen and how it is parameterised.
    detail = f"with arguments: {arguments}." if arguments else "with default arguments."
    logger.info(f"Setting model's self.scheduler from config: {scheduler_name}\n" + detail)

    # A scheduler can only drive a real torch optimizer.
    if not isinstance(self.optimizer, optim.Optimizer):
        raise RuntimeError("Model optimizer must be a torch.optim.Optimizer")

    return sched_cls(self.optimizer, **arguments)
[docs]
def hyrax_model(cls):
    """Decorator to register a model with the model registry, and to add common interface functions

    The decorator:

    * attaches ``save`` and ``load`` to classes deriving from ``torch.nn.Module``;
    * wraps ``__init__`` so that ``optimizer``, ``criterion`` and ``scheduler``
      are built from the config whenever the model's own ``__init__`` did not
      set them;
    * makes sure the class exposes a static ``prepare_inputs`` (its own, an
      alias of a legacy ``to_tensor``, or a default implementation);
    * registers the class in ``MODEL_REGISTRY`` under ``cls.__name__``.

    Parameters
    ----------
    cls : type
        The model class being decorated.

    Returns
    -------
    type
        The class with additional interface functions.

    Raises
    ------
    RuntimeError
        If ``prepare_inputs`` (or the legacy ``to_tensor``) exists on the class
        but is not declared as a ``staticmethod``.
    """
    import inspect
    from functools import wraps

    if issubclass(cls, nn.Module):
        cls.save = _torch_save
        cls.load = _torch_load

    original_init = cls.__init__

    # Order matters: the scheduler is constructed around self.optimizer.
    # Each entry is (attribute name, builder from config, indefinite article).
    config_built_attrs = (
        ("optimizer", _torch_optimizer, "an"),
        ("criterion", _torch_criterion, "a"),
        ("scheduler", _torch_schedulers, "a"),
    )

    @wraps(original_init)  # keep the model's own __init__ name/docstring/signature
    def wrapped_init(self, config, *args, **kwargs):
        original_init(self, config, *args, **kwargs)

        for attr, build_from_config, article in config_built_attrs:
            if not hasattr(self, attr):
                setattr(self, attr, build_from_config(self))
                continue

            # The model set the attribute itself; it wins over the config.
            if config[attr]["name"]:
                logger.warning(
                    f"Both model and config define {article} {attr}. "
                    f"Hyrax will use self.{attr} defined in the model."
                )
            value = getattr(self, attr)
            value_name = f"{type(value).__module__}.{type(value).__qualname__}"
            logger.info(f"Using self.{attr} defined in model: {value_name}")

    cls.__init__ = wrapped_init

    def default_prepare_inputs(data_dict):
        """Extract image and label arrays from the batch dictionary.

        This static method is the interface between the data pipeline and the
        model. Override it on the model class to reshape or select fields from
        the collated batch to match the inputs your model expects. This is the
        default implementation used when a model does not define its own
        ``prepare_inputs``.

        Hyrax will convert the returned arrays to PyTorch tensors and move them
        to the appropriate device automatically.

        Parameters
        ----------
        data_dict : dict
            The collated batch dictionary produced by the data pipeline.
            Expected to contain a ``"data"`` key with ``"image"`` and optionally
            ``"label"`` fields.

        Returns
        -------
        inputs : tuple of numpy.ndarray
            A tuple of ``(image, label)`` arrays.

        Raises
        ------
        RuntimeError
            If ``data_dict`` has no ``"data"`` key.
        """
        if "data" not in data_dict:
            msg = "Hyrax couldn't find a 'data' key in the data dictionaries from your dataset.\n"
            msg += f"We recommend you implement a function on {cls.__name__} to unpack the appropriate\n"
            msg += "value(s) from the dictionary your dataset is returning:\n\n"
            msg += f"class {cls.__name__}:\n\n"
            msg += " @staticmethod\n"
            msg += " def prepare_inputs(data_dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += " <Your implementation goes here>\n\n"
            raise RuntimeError(msg)

        data = data_dict.get("data")
        # Missing fields fall back to empty arrays.
        image = data.get("image", np.array([]))
        label = data.get("label", np.array([]))
        return (image, label)

    # hasattr() also sees attributes inherited from a base class, so the raw
    # descriptor must be resolved along the MRO too. ``vars(cls)[...]`` only
    # covers the class's own namespace and raises KeyError when the method is
    # inherited. getattr_static returns the descriptor without binding it.
    if hasattr(cls, "prepare_inputs"):
        # Model has new prepare_inputs method - ensure it's a staticmethod
        if not isinstance(inspect.getattr_static(cls, "prepare_inputs"), staticmethod):
            msg = f"You must implement prepare_inputs() in {cls.__name__} as\n\n"
            msg += "@staticmethod\n"
            msg += "def prepare_inputs(data_dict: dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += " <Your implementation goes here>\n"
            raise RuntimeError(msg)
    elif hasattr(cls, "to_tensor"):
        # Model only has old to_tensor method - make prepare_inputs an alias
        to_tensor_attr = inspect.getattr_static(cls, "to_tensor")
        if not isinstance(to_tensor_attr, staticmethod):
            msg = f"You must rename to_tensor() to prepare_inputs() in {cls.__name__} as\n\n"
            msg += "@staticmethod\n"
            msg += "def prepare_inputs(data_dict: dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += " <Your implementation goes here>\n"
            raise RuntimeError(msg)

        # Create an alias that's also a staticmethod, built from the function
        # wrapped by the original staticmethod descriptor.
        cls.prepare_inputs = staticmethod(to_tensor_attr.__func__)
        # Mark for save logic to know this needs function renaming
        cls._has_legacy_to_tensor = True
    else:
        # No method defined - use defaults
        cls.prepare_inputs = staticmethod(default_prepare_inputs)

    # Missing interface methods are reported but do not prevent registration.
    required_methods = ["train_batch", "infer_batch", "__init__", "prepare_inputs"]
    for name in required_methods:
        if not hasattr(cls, name):
            logger.error(f"Hyrax model {cls.__name__} missing required method {name}.")

    update_registry(MODEL_REGISTRY, cls.__name__, cls)
    return cls
[docs]
def fetch_model_class(runtime_config: dict) -> type[nn.Module]:
    """Fetch the model class from the model registry.

    Parameters
    ----------
    runtime_config : dict
        The runtime configuration dictionary. The model is read from
        ``runtime_config["model"]["name"]``.

    Returns
    -------
    type
        The model class.

    Raises
    ------
    RuntimeError
        If no model was specified in the runtime configuration.

    Notes
    -----
    Any exception raised by ``get_or_load_class`` while resolving the name
    (for example when the requested model cannot be found) propagates to the
    caller unchanged.
    """
    # Any falsy value (None, empty string, False) means "no model configured".
    model_name = runtime_config["model"]["name"]

    if not model_name:
        # List what is registered so the user can pick a valid name.
        model_list = "\n".join(f" - {model}" for model in sorted(MODEL_REGISTRY))
        logger.error(
            "No model name was provided in the configuration. "
            "You must specify a model to use before running Hyrax.\n\n"
            "To set a model, use: h.set_config('model.name', '<model_name>')\n"
            "<model_name> can be one of the following registered models or a path to a custom model class "
            "e.g. 'HyraxCNN' or 'my_package.my_module.MyModelClass'.\n\n"
            f"Currently registered models:\n{model_list}"
        )
        raise RuntimeError(
            "A model class name or path must be provided. "
            "e.g. 'HyraxCNN' or 'my_package.my_module.MyModelClass'."
        )

    return cast(type[nn.Module], get_or_load_class(model_name, MODEL_REGISTRY))