import logging
from pathlib import Path
from typing import Any, cast
import numpy as np
import torch.nn as nn
import torch.optim as optim
from hyrax.plugin_utils import (
get_or_load_class,
load_prepare_inputs,
load_to_tensor,
save_prepare_inputs,
update_registry,
)
from hyrax.trace import get_trace
[docs]
# Module-level logger, named after this module, used by every function below.
logger = logging.getLogger(__name__)
[docs]
# Registry of model classes keyed by name. ``hyrax_model`` adds each decorated class here
# (under ``cls.__name__``) and ``fetch_model_class`` passes it to ``get_or_load_class``.
MODEL_REGISTRY: dict[str, type[nn.Module]] = {}
[docs]
def _torch_save(self: nn.Module, save_path: Path):
    """Save the model weights together with the model's input-preparation function.

    The model's ``state_dict`` is written to ``save_path``. What happens next depends
    on whether the model class carries the ``_has_legacy_to_tensor`` marker:

    * With the marker, the source of the class's ``to_tensor`` static method is
      rewritten so that it defines ``prepare_inputs`` and is written to
      ``prepare_inputs.py`` in ``save_path.parent``. A deprecation warning is logged.
    * Without it, ``self.prepare_inputs`` is handed to ``save_prepare_inputs``
      along with ``save_path``.

    Parameters
    ----------
    self : nn.Module
        The model instance being saved.
    save_path : Path
        File that receives the weights.
    """
    import inspect
    import textwrap

    import torch

    # save the model weights
    torch.save(self.state_dict(), save_path)

    # If model originally had to_tensor, save it with renamed function
    if hasattr(self.__class__, "_has_legacy_to_tensor"):
        # The marker is visible on subclasses through inheritance, but ``to_tensor`` may be
        # defined on a base class, so ``vars(self.__class__)`` is not enough. getattr_static
        # walks the MRO and returns the raw staticmethod descriptor without invoking it.
        to_tensor_staticmethod = inspect.getattr_static(self.__class__, "to_tensor")
        to_tensor_func = to_tensor_staticmethod.__func__
        to_tensor_source = inspect.getsource(to_tensor_func)

        # Replace function name from "to_tensor" to "prepare_inputs" (first occurrence only)
        prepared_source = to_tensor_source.replace("def to_tensor(", "def prepare_inputs(", 1)

        # Add common imports that prepare_inputs/to_tensor functions typically need
        # This ensures the saved function can be loaded independently
        imports = "import numpy as np\nimport torch\n\n"

        # The file is Python source, so write it as UTF-8 rather than the platform default
        # encoding; otherwise non-ASCII characters in the function could be mis-encoded.
        with open(save_path.parent / "prepare_inputs.py", "w", encoding="utf-8") as f:
            f.write(imports + textwrap.dedent(prepared_source))

        logger.warning(
            "Model uses deprecated to_tensor method. Saving as prepare_inputs for forward compatibility. "
            "Please rename to_tensor to prepare_inputs in your model class to ensure reproducibility."
        )
    else:
        # Modern model with prepare_inputs - save as-is
        save_prepare_inputs(self.prepare_inputs, save_path)
[docs]
def _torch_load(self: nn.Module, load_path: Path):
    """Restore model weights and the saved input-preparation function.

    The state dict at ``load_path`` is loaded onto the device reported by
    ``ignite.distributed.device()``. The directory holding ``load_path`` is then
    searched for a saved ``prepare_inputs`` function first and a legacy
    ``to_tensor`` function second. Whichever is found replaces
    ``self.prepare_inputs``. If neither is found, the model keeps its existing
    method and a warning is logged.

    Parameters
    ----------
    self : nn.Module
        The model instance to populate.
    load_path : Path
        File containing the saved weights.
    """
    import ignite.distributed as idist
    import torch

    # Let ignite choose the device, then map the stored tensors onto it.
    target_device = idist.device()
    weights = torch.load(load_path, weights_only=True, map_location=target_device)
    self.load_state_dict(weights)

    # Prefer the new-style prepare_inputs; remember what the model had beforehand.
    loaded_fn = load_prepare_inputs(load_path.parent)
    previous_fn = getattr(self, "prepare_inputs", None)

    if loaded_fn:
        self.prepare_inputs = loaded_fn
    elif legacy_fn := load_to_tensor(load_path.parent):
        # Old checkpoints only ship the legacy function; keep them working.
        logger.warning(
            f"Found to_tensor function in {load_path.parent}. "
            "to_tensor is deprecated, please re-save your model with the new version "
            "to use prepare_inputs."
        )
        self.prepare_inputs = legacy_fn
    else:
        logger.warning(
            f"Could not find prepare_inputs or to_tensor function in {load_path.parent}. "
            "Using the model's existing methods."
        )

    # Only a replaced prepare_inputs needs instrumenting; an untouched one was
    # already instrumented if tracing is on.
    was_replaced = previous_fn != getattr(self, "prepare_inputs", None)
    if was_replaced and (tracer := get_trace()):
        tracer.instrument_prepare_inputs(self)
[docs]
def _torch_criterion(self: nn.Module):
    """Build the criterion named in the model's config.

    The name is read from ``self.config["criterion"]["name"]`` and resolved to a
    class with ``get_or_load_class``. Keyword arguments come from the config entry
    keyed by that same name, when such an entry exists.

    Returns
    -------
    object or None
        The instantiated criterion, or ``None`` (after a warning is logged) when
        the configured name is empty.
    """
    cfg = cast(dict[str, Any], self.config)

    name = cfg["criterion"]["name"]
    if not name:
        logger.warning("No criterion specified in config or self.criterion in model.")
        return None

    # Resolve the class first, then pick up its optional argument table.
    factory = get_or_load_class(name)
    kwargs = cfg[name] if name in cfg else {}

    # Report which criterion is used and how it is parameterised.
    detail = f"with arguments: {kwargs}." if kwargs else "with default arguments."
    logger.info(f"Setting model's self.criterion from config: {name} " + detail)

    return factory(**kwargs)
[docs]
def _torch_optimizer(self: nn.Module):
    """Build the optimizer named in the model's config.

    The name is read from ``self.config["optimizer"]["name"]`` and resolved to a
    class with ``get_or_load_class``. The class is called with
    ``self.parameters()`` plus keyword arguments taken from the config entry keyed
    by that same name, when such an entry exists.

    Returns
    -------
    object or None
        The instantiated optimizer, or ``None`` (after a warning is logged) when
        the configured name is empty.
    """
    cfg = cast(dict[str, Any], self.config)

    name = cfg["optimizer"]["name"]
    if not name:
        logger.warning("No optimizer specified in config or self.optimizer in model.")
        return None

    # Resolve the class first, then pick up its optional argument table.
    factory = get_or_load_class(name)
    kwargs = cfg[name] if name in cfg else {}

    # Report which optimizer is used and how it is parameterised.
    detail = f"with arguments: {kwargs}." if kwargs else "with default arguments."
    logger.info(f"Setting model's self.optimizer from config: {name} " + detail)

    return factory(self.parameters(), **kwargs)
[docs]
def _torch_schedulers(self: nn.Module):
    """Build the learning-rate scheduler named in the model's config.

    The name is read from ``self.config["scheduler"]["name"]`` and resolved to a
    class with ``get_or_load_class``. The class is called with ``self.optimizer``
    plus keyword arguments taken from the config entry keyed by that same name,
    when such an entry exists.

    Returns
    -------
    object or None
        The instantiated scheduler, or ``None`` (after a warning is logged) when
        the configured name is empty.

    Raises
    ------
    RuntimeError
        If ``self.optimizer`` is not a ``torch.optim.Optimizer``.
    """
    cfg = cast(dict[str, Any], self.config)

    name = cfg["scheduler"]["name"]
    if not name:
        logger.warning("No scheduler specified in config or self.scheduler in model.")
        return None

    # Resolve the class first, then pick up its optional argument table.
    factory = get_or_load_class(name)
    kwargs = cfg[name] if name in cfg else {}

    # Report which scheduler is used and how it is parameterised.
    detail = f"with arguments: {kwargs}." if kwargs else "with default arguments."
    logger.info(f"Setting model's self.scheduler from config: {name}\n" + detail)

    # The scheduler is constructed around the optimizer, so validate it up front.
    if not isinstance(self.optimizer, optim.Optimizer):
        raise RuntimeError("Model optimizer must be a torch.optim.Optimizer")

    return factory(self.optimizer, **kwargs)
[docs]
def hyrax_model(cls):
    """Decorator to register a model with the model registry, and to add common interface functions.

    The decorator:

    * attaches ``save`` and ``load`` to ``nn.Module`` subclasses,
    * wraps ``__init__`` so that ``self.optimizer``, ``self.criterion`` and ``self.scheduler``
      are built from the config whenever the model's own ``__init__`` did not set them,
    * makes sure the class exposes a static ``prepare_inputs``, either its own, an alias of a
      legacy ``to_tensor``, or a default that unpacks ``data_dict["data"]``,
    * logs an error for every missing required method and registers the class in
      ``MODEL_REGISTRY`` under ``cls.__name__``.

    Parameters
    ----------
    cls : type
        The model class being decorated.

    Returns
    -------
    type
        The class with additional interface functions.

    Raises
    ------
    RuntimeError
        If ``prepare_inputs`` (or, for legacy models, ``to_tensor``) is not a ``staticmethod``.
    """
    import functools
    import inspect

    if issubclass(cls, nn.Module):
        cls.save = _torch_save
        cls.load = _torch_load

    original_init = cls.__init__

    # functools.wraps keeps the model's own __init__ name, docstring and signature visible
    # to introspection tools instead of exposing the generic wrapper.
    @functools.wraps(original_init)
    def wrapped_init(self, config, *args, **kwargs):
        original_init(self, config, *args, **kwargs)

        # For each component: build it from the config unless the model's __init__ already
        # defined it, in which case the model's own attribute wins.
        if not hasattr(self, "optimizer"):
            self.optimizer = _torch_optimizer(self)
        else:
            if config["optimizer"]["name"]:
                logger.warning(
                    "Both model and config define an optimizer. "
                    "Hyrax will use self.optimizer defined in the model."
                )
            opt_name = f"{type(self.optimizer).__module__}.{type(self.optimizer).__qualname__}"
            logger.info(f"Using self.optimizer defined in model: {opt_name}")

        if not hasattr(self, "criterion"):
            self.criterion = _torch_criterion(self)
        else:
            if config["criterion"]["name"]:
                logger.warning(
                    "Both model and config define a criterion. "
                    "Hyrax will use self.criterion defined in the model."
                )
            crit_name = f"{type(self.criterion).__module__}.{type(self.criterion).__qualname__}"
            logger.info(f"Using self.criterion defined in model: {crit_name}")

        # The scheduler comes last: building one from config requires self.optimizer.
        if not hasattr(self, "scheduler"):
            self.scheduler = _torch_schedulers(self)
        else:
            if config["scheduler"]["name"]:
                logger.warning(
                    "Both model and config define a scheduler. "
                    "Hyrax will use self.scheduler defined in the model."
                )
            sched_name = f"{type(self.scheduler).__module__}.{type(self.scheduler).__qualname__}"
            logger.info(f"Using self.scheduler defined in model: {sched_name}")

    cls.__init__ = wrapped_init

    def default_prepare_inputs(data_dict):
        """Fallback ``prepare_inputs``: return ``(image, label)`` from ``data_dict["data"]``.

        Missing ``image`` or ``label`` entries are replaced by an empty numpy array.
        Raises ``RuntimeError`` when ``data_dict`` has no ``"data"`` key.
        """
        if "data" not in data_dict:
            msg = "Hyrax couldn't find a 'data' key in the data dictionaries from your dataset.\n"
            msg += f"We recommend you implement a function on {cls.__name__} to unpack the appropriate\n"
            msg += "value(s) from the dictionary your dataset is returning:\n\n"
            msg += f"class {cls.__name__}:\n\n"
            msg += " @staticmethod\n"
            msg += " def prepare_inputs(data_dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += " <Your implementation goes here>\n\n"
            raise RuntimeError(msg)

        data = data_dict.get("data")
        image = data.get("image", np.array([]))
        label = data.get("label", np.array([]))
        return (image, label)

    # Check if the class has prepare_inputs or to_tensor
    has_prepare_inputs = hasattr(cls, "prepare_inputs")
    has_to_tensor = hasattr(cls, "to_tensor")

    # hasattr() also sees attributes inherited from base classes, whereas vars(cls) only holds
    # attributes defined directly on cls. inspect.getattr_static walks the MRO and returns the
    # raw descriptor, so an inherited staticmethod is validated instead of raising KeyError.
    if has_prepare_inputs:
        # Model has new prepare_inputs method - ensure it's a staticmethod
        if not isinstance(inspect.getattr_static(cls, "prepare_inputs", None), staticmethod):
            msg = f"You must implement prepare_inputs() in {cls.__name__} as\n\n"
            msg += "@staticmethod\n"
            msg += "def prepare_inputs(data_dict: dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += " <Your implementation goes here>\n"
            raise RuntimeError(msg)
    elif has_to_tensor:
        # Model only has old to_tensor method - make prepare_inputs an alias
        to_tensor_descriptor = inspect.getattr_static(cls, "to_tensor", None)
        if not isinstance(to_tensor_descriptor, staticmethod):
            msg = f"You must rename to_tensor() to prepare_inputs() in {cls.__name__} as\n\n"
            msg += "@staticmethod\n"
            msg += "def prepare_inputs(data_dict: dict) -> Tuple[npt.NDArray, ...]:\n"
            msg += " <Your implementation goes here>\n"
            raise RuntimeError(msg)

        # Create an alias that's also a staticmethod, built from the function wrapped by the
        # staticmethod descriptor.
        cls.prepare_inputs = staticmethod(to_tensor_descriptor.__func__)
        # Mark for save logic to know this needs function renaming
        cls._has_legacy_to_tensor = True
    else:
        # No method defined - use defaults
        cls.prepare_inputs = staticmethod(default_prepare_inputs)

    # A missing required method is reported through the logger only; decoration still succeeds.
    required_methods = ["train_batch", "infer_batch", "__init__", "prepare_inputs"]
    for name in required_methods:
        if not hasattr(cls, name):
            logger.error(f"Hyrax model {cls.__name__} missing required method {name}.")

    update_registry(MODEL_REGISTRY, cls.__name__, cls)
    return cls
[docs]
def fetch_model_class(runtime_config: dict) -> type[nn.Module]:
    """Look up the model class named in the runtime configuration.

    Parameters
    ----------
    runtime_config : dict
        The runtime configuration dictionary. ``runtime_config["model"]["name"]`` is read.

    Returns
    -------
    type
        The model class resolved by ``get_or_load_class`` against ``MODEL_REGISTRY``.

    Raises
    ------
    RuntimeError
        If no model name is set in the runtime configuration. Before raising, the list of
        currently registered models is logged as an error.

    Notes
    -----
    Exceptions raised by ``get_or_load_class`` are not caught here and propagate to the caller.
    """
    requested = runtime_config["model"]["name"] or None

    # Happy path: a name was configured, so resolve it and hand the class back.
    if requested:
        return cast(type[nn.Module], get_or_load_class(requested, MODEL_REGISTRY))

    # No name configured: list the registered models so the user can pick one.
    registered = "\n".join(f" - {model}" for model in sorted(MODEL_REGISTRY.keys()))
    logger.error(
        "No model name was provided in the configuration. "
        "You must specify a model to use before running Hyrax.\n\n"
        "To set a model, use: h.set_config('model.name', '<model_name>')\n"
        "<model_name> can be one of the following registered models or a path to a custom model class "
        "e.g. 'HyraxCNN' or 'my_package.my_module.MyModelClass'.\n\n"
        f"Currently registered models:\n{registered}"
    )
    raise RuntimeError(
        "A model class name or path must be provided. "
        "e.g. 'HyraxCNN' or 'my_package.my_module.MyModelClass'."
    )