hyrax.models.model_registry

Attributes

logger

MODEL_REGISTRY

Functions

_torch_save(self, save_path)

_torch_load(self, load_path)

_torch_criterion(self)

Load the criterion class using the name defined in the config and instantiate it with the arguments defined in the config.

_torch_optimizer(self)

Load the optimizer class using the name defined in the config and instantiate it with the arguments defined in the config.

_torch_schedulers(self)

Load the scheduler classes using the names defined in the config and instantiate them with the arguments defined in the config.

hyrax_model(cls)

Decorator to register a model with the model registry and to add common interface functions.

fetch_model_class(runtime_config) → type[torch.nn.Module]

Fetch the model class from the model registry.

Module Contents

logger

MODEL_REGISTRY: dict[str, type[torch.nn.Module]]

_torch_save(self: torch.nn.Module, save_path: pathlib.Path)

_torch_load(self: torch.nn.Module, load_path: pathlib.Path)

_torch_criterion(self: torch.nn.Module)

Load the criterion class using the name defined in the config and instantiate it with the arguments defined in the config.
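A minimal sketch of the config-driven lookup described above. This is not hyrax's own code. It assumes the config section holds a `name` key naming a class on some module and passes every other key as a keyword argument. For a loss, that module would plausibly be `torch.nn`. The layout and the helper name `build_from_config` are hypothetical. A stdlib class stands in for the loss so the sketch runs without PyTorch.

```python
import datetime
from typing import Any


def build_from_config(section: dict[str, Any], namespace: Any) -> Any:
    """Look up section["name"] on `namespace` and call it with the remaining keys."""
    kwargs = dict(section)  # copy so the caller's config is not mutated
    name = kwargs.pop("name")
    cls = getattr(namespace, name, None)
    if cls is None:
        raise ValueError(f"{name!r} not found in the configured namespace")
    return cls(**kwargs)


# With PyTorch this might look like: build_from_config({"name": "CrossEntropyLoss"}, torch.nn)
# Here datetime.timedelta stands in for a loss class so no torch install is needed.
delay = build_from_config({"name": "timedelta", "seconds": 90}, datetime)
print(delay.total_seconds())  # 90.0
```

The helper copies the section before popping `name`, so the caller's configuration is left unchanged and can be reused.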

_torch_optimizer(self: torch.nn.Module)

Load the optimizer class using the name defined in the config and instantiate it with the arguments defined in the config.
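Optimizers differ from losses in one respect: `torch.optim` classes take the parameters to optimize as their first positional argument. The sketch below shows that step under a hypothetical config layout, with a `name` key plus keyword arguments. `build_optimizer` and `FakeSGD` are illustrative stand-ins and are not hyrax or PyTorch APIs.

```python
from collections.abc import Iterable
from types import SimpleNamespace
from typing import Any


def build_optimizer(section: dict[str, Any], params: Iterable[Any], namespace: Any) -> Any:
    """Instantiate the optimizer named in `section`, passing model parameters first."""
    kwargs = dict(section)
    cls = getattr(namespace, kwargs.pop("name"))
    # torch.optim optimizers take the parameters to optimize as their first argument.
    return cls(params, **kwargs)


class FakeSGD:
    """Stand-in for torch.optim.SGD so the sketch runs without torch."""

    def __init__(self, params, lr=0.01, momentum=0.0):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum


# With PyTorch, `params` would be model.parameters() and `namespace` would be torch.optim.
fake_optim = SimpleNamespace(SGD=FakeSGD)
opt = build_optimizer({"name": "SGD", "lr": 0.1}, params=[1.0, 2.0], namespace=fake_optim)
print(opt.lr, opt.params)  # 0.1 [1.0, 2.0]
```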

_torch_schedulers(self: torch.nn.Module)

Load the scheduler classes using the names defined in the config and instantiate them with the arguments defined in the config.
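Several schedulers may be configured. A sketch can therefore iterate over a list of entries and pass the same optimizer to each one as its first argument, which is what `torch.optim.lr_scheduler` classes expect. The list-of-dicts layout, `build_schedulers`, and the `Fake*` classes are hypothetical. This page does not say how hyrax orders or combines multiple schedulers.

```python
from types import SimpleNamespace
from typing import Any


def build_schedulers(entries: list[dict[str, Any]], optimizer: Any, namespace: Any) -> list[Any]:
    """Instantiate each configured scheduler, all wrapping the same optimizer."""
    schedulers = []
    for entry in entries:
        kwargs = dict(entry)  # copy so the config entry is not mutated
        cls = getattr(namespace, kwargs.pop("name"))
        # lr_scheduler classes take the optimizer as their first argument.
        schedulers.append(cls(optimizer, **kwargs))
    return schedulers


class FakeStepLR:
    """Stand-in for torch.optim.lr_scheduler.StepLR (no torch required)."""

    def __init__(self, optimizer, step_size, gamma=0.1):
        self.optimizer = optimizer
        self.step_size = step_size
        self.gamma = gamma


class FakeExponentialLR:
    """Stand-in for torch.optim.lr_scheduler.ExponentialLR (no torch required)."""

    def __init__(self, optimizer, gamma):
        self.optimizer = optimizer
        self.gamma = gamma


fake_sched = SimpleNamespace(StepLR=FakeStepLR, ExponentialLR=FakeExponentialLR)
optimizer = object()  # placeholder for a real optimizer instance
scheds = build_schedulers(
    [{"name": "StepLR", "step_size": 10}, {"name": "ExponentialLR", "gamma": 0.9}],
    optimizer,
    fake_sched,
)
print([type(s).__name__ for s in scheds])  # ['FakeStepLR', 'FakeExponentialLR']
```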

hyrax_model(cls)

Decorator to register a model with the model registry and to add common interface functions.

Returns:

The class with additional interface functions.

Return type:

type
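A sketch of the general class-decorator pattern described above. The decorator records the class in a name-to-class dictionary, attaches shared functions as methods, and returns the class. `REGISTRY`, `register_model`, and `describe` are hypothetical. Keying by class name is an assumption, and this page does not list the exact interface functions hyrax attaches.

```python
from typing import TypeVar

T = TypeVar("T", bound=type)

# Hypothetical stand-in for MODEL_REGISTRY: maps a name to a model class.
REGISTRY: dict[str, type] = {}


def _describe(self) -> str:
    """Example shared helper attached to every registered class."""
    return f"{type(self).__name__} (registered model)"


def register_model(cls: T) -> T:
    """Register `cls` under its class name and attach shared interface functions."""
    REGISTRY[cls.__name__] = cls
    # A plain function assigned as a class attribute becomes a bound method on instances.
    cls.describe = _describe
    return cls


@register_model
class TinyModel:
    pass


print(TinyModel().describe())  # TinyModel (registered model)
```

Because the decorator returns the same class object, `TinyModel` can still be used and subclassed as usual after registration.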

fetch_model_class(runtime_config: dict) → type[torch.nn.Module]

Fetch the model class from the model registry.

Parameters:

runtime_config (dict) – The runtime configuration dictionary.

Returns:

The model class.

Return type:

type

Raises:
  • ValueError – If a built-in model was requested but not found in the model registry.

  • ValueError – If no model was specified in the runtime configuration.
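A sketch of the lookup and the two documented `ValueError` cases. It assumes the model name lives under a hypothetical `model.name` key and treats every requested name as a registry (built-in) name. It does not cover how hyrax handles models from outside the registry. `lookup_model_class` is an illustrative name and is not part of hyrax.

```python
from typing import Any


def lookup_model_class(runtime_config: dict[str, Any], registry: dict[str, type]) -> type:
    """Return the registered class named in the config, mirroring the documented errors."""
    # Hypothetical config layout: {"model": {"name": "<registered class name>"}}
    model_name = runtime_config.get("model", {}).get("name")
    if not model_name:
        raise ValueError("No model was specified in the runtime configuration.")
    if model_name not in registry:
        raise ValueError(f"Model {model_name!r} was not found in the model registry.")
    return registry[model_name]


class TinyModel:
    pass


registry = {"TinyModel": TinyModel}
print(lookup_model_class({"model": {"name": "TinyModel"}}, registry).__name__)  # TinyModel
```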