hyrax.models.model_registry
===========================

.. py:module:: hyrax.models.model_registry


Attributes
----------

.. autoapisummary::

   hyrax.models.model_registry.logger
   hyrax.models.model_registry.MODEL_REGISTRY


Functions
---------

.. autoapisummary::

   hyrax.models.model_registry._torch_save
   hyrax.models.model_registry._torch_load
   hyrax.models.model_registry._torch_criterion
   hyrax.models.model_registry._torch_optimizer
   hyrax.models.model_registry._torch_schedulers
   hyrax.models.model_registry.hyrax_model
   hyrax.models.model_registry.fetch_model_class


Module Contents
---------------

.. py:data:: logger

.. py:data:: MODEL_REGISTRY
   :type:  dict[str, type[torch.nn.Module]]

.. py:function:: _torch_save(self: torch.nn.Module, save_path: pathlib.Path)

.. py:function:: _torch_load(self: torch.nn.Module, load_path: pathlib.Path)

.. py:function:: _torch_criterion(self: torch.nn.Module)

   Load the criterion class using the name defined in the config and
   instantiate it with the arguments defined in the config.
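
   The load-by-name pattern described above can be sketched roughly as follows.
   This is a minimal illustration, not Hyrax's actual implementation: the
   ``load_by_name`` helper and the ``name``/``kwargs`` config keys are
   hypothetical, and a standard-library class stands in for a real criterion
   such as ``torch.nn.CrossEntropyLoss`` so the snippet runs without PyTorch.

   .. code-block:: python

      import importlib
      from typing import Any


      def load_by_name(dotted_name: str, **kwargs: Any) -> Any:
          """Import ``package.module.ClassName`` and instantiate it with ``kwargs``."""
          module_name, _, class_name = dotted_name.rpartition(".")
          if not module_name:
              raise ValueError(f"Expected a fully qualified name, got {dotted_name!r}")
          module = importlib.import_module(module_name)
          cls = getattr(module, class_name)
          return cls(**kwargs)


      # Hypothetical config section; the key names are illustrative only.
      config = {"criterion": {"name": "datetime.timedelta", "kwargs": {"hours": 2}}}

      section = config["criterion"]
      criterion = load_by_name(section["name"], **section["kwargs"])

   Keeping the class name and its arguments in the config means a different
   criterion can be selected without changing the model code.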


.. py:function:: _torch_optimizer(self: torch.nn.Module)

   Load the optimizer class using the name defined in the config and
   instantiate it with the arguments defined in the config.


.. py:function:: _torch_schedulers(self: torch.nn.Module)

   Load the scheduler classes using the names defined in the config and
   instantiate each one with the arguments defined in the config.


.. py:function:: hyrax_model(cls)

   Decorator that registers a model class with the model registry and adds common interface functions.

   :param cls: The model class to register.
   :type cls: type

   :returns: The class with additional interface functions.
   :rtype: type
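
   A registering decorator of this kind can be sketched as below. The names
   ``REGISTRY``, ``register_model``, ``_describe``, and ``TinyModel`` are
   hypothetical stand-ins chosen for illustration; the sketch assumes classes
   are keyed by their class name and does not reproduce the interface
   functions that Hyrax attaches.

   .. code-block:: python

      # Illustrative stand-in for a registry such as MODEL_REGISTRY.
      REGISTRY: dict[str, type] = {}


      def _describe(self) -> str:
          """Example of a shared method attached to every registered class."""
          return f"{type(self).__name__} (registered)"


      def register_model(cls: type) -> type:
          """Attach a shared method to ``cls`` and record it under its class name."""
          cls.describe = _describe
          REGISTRY[cls.__name__] = cls
          # Returning the class lets the function be used as a decorator.
          return cls


      @register_model
      class TinyModel:
          pass


      model = TinyModel()
      print(model.describe())  # prints "TinyModel (registered)"

   Because the decorator returns the same class object, the decorated class
   can still be imported and instantiated directly.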


.. py:function:: fetch_model_class(runtime_config: dict) -> type[torch.nn.Module]

   Fetch the model class from the model registry.

   :param runtime_config: The runtime configuration dictionary.
   :type runtime_config: dict

   :returns: The model class.
   :rtype: type

   :raises ValueError: If a built-in model was requested but not found in the model registry.
   :raises ValueError: If no model was specified in the runtime configuration.
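
   The two error cases above can be illustrated with a small, self-contained
   sketch. It is not the Hyrax implementation: the ``fetch_class`` helper, the
   ``REGISTRY`` dictionary, and the ``{"model": {"name": ...}}`` config layout
   are assumptions made for the example, and only a plain registry lookup is
   shown.

   .. code-block:: python

      class TinyModel:
          pass


      # Illustrative stand-in for a registry such as MODEL_REGISTRY.
      REGISTRY: dict[str, type] = {"TinyModel": TinyModel}


      def fetch_class(runtime_config: dict) -> type:
          """Look up a class by the name given in a (hypothetical) config layout."""
          name = runtime_config.get("model", {}).get("name")
          if not name:
              raise ValueError("No model was specified in the runtime configuration.")
          if name not in REGISTRY:
              raise ValueError(f"Model {name!r} was not found in the registry.")
          return REGISTRY[name]


      model_cls = fetch_class({"model": {"name": "TinyModel"}})

      try:
          fetch_class({"model": {}})
      except ValueError as err:
          missing_error = str(err)

   Raising ``ValueError`` in both cases lets a caller handle a missing or
   misspelled model name with a single ``except`` clause.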


