hyrax.pytorch_ignite
====================

.. py:module:: hyrax.pytorch_ignite


Attributes
----------

.. autoapisummary::

   hyrax.pytorch_ignite.logger
   hyrax.pytorch_ignite._LEGACY_SPLIT_KEYS


Classes
-------

.. autoapisummary::

   hyrax.pytorch_ignite.SubsetSequentialSampler
   hyrax.pytorch_ignite.HyraxEvents


Functions
---------

.. autoapisummary::

   hyrax.pytorch_ignite.setup_dataset
   hyrax.pytorch_ignite.setup_model
   hyrax.pytorch_ignite.dist_data_loader
   hyrax.pytorch_ignite._inner_loop
   hyrax.pytorch_ignite._create_process_func
   hyrax.pytorch_ignite.create_engine
   hyrax.pytorch_ignite.extract_model_method
   hyrax.pytorch_ignite.create_evaluator
   hyrax.pytorch_ignite.create_validator
   hyrax.pytorch_ignite.create_tester
   hyrax.pytorch_ignite.attach_best_checkpoint
   hyrax.pytorch_ignite.create_trainer
   hyrax.pytorch_ignite.create_save_batch_callback
   hyrax.pytorch_ignite.fixup_engine


Module Contents
---------------

.. py:data:: logger

.. py:data:: _LEGACY_SPLIT_KEYS
   :value: ('train_size', 'validate_size', 'test_size')


.. py:class:: SubsetSequentialSampler(indices: collections.abc.Sequence[int], generator=None)

   Bases: :py:obj:`torch.utils.data.Sampler`\ [\ :py:obj:`int`\ ]


   Samples elements sequentially from a given list of indices, without replacement.

   :param indices: The indices to sample from, yielded in the given order.
   :type indices: collections.abc.Sequence[int]


   .. py:attribute:: indices
      :type:  collections.abc.Sequence[int]


   .. py:attribute:: generator
      :value: None



   .. py:method:: __iter__() -> collections.abc.Iterator[int]


   .. py:method:: __len__() -> int
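
   The documented behaviour is simple: yield each given index once, in order. The plain-Python sketch below shows it.
   The class name ``SequentialIndexSampler`` is hypothetical. The sketch deliberately does not subclass
   ``torch.utils.data.Sampler``, so it runs without PyTorch. It is not Hyrax's implementation.

   .. code-block:: python

      from collections.abc import Iterator, Sequence


      class SequentialIndexSampler:
          """Hypothetical stand-in for SubsetSequentialSampler (no torch dependency)."""

          def __init__(self, indices: Sequence[int]) -> None:
              self.indices = indices

          def __iter__(self) -> Iterator[int]:
              # Yield each stored index exactly once, in the order given.
              return iter(self.indices)

          def __len__(self) -> int:
              return len(self.indices)


      sampler = SequentialIndexSampler([7, 2, 5])
      print(list(sampler))  # [7, 2, 5]
      print(len(sampler))  # 3

   The order is identical on every pass, so a sampler like this suits verbs that need deterministic output order.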


.. py:function:: setup_dataset(config: dict, *, splits: tuple[str, Ellipsis] | None = None, shuffle: bool = True) -> dict[str, hyrax.datasets.data_provider.DataProvider]

   Create DataProvider instances for each requested data group.

   :param config: The runtime configuration.
   :type config: dict
   :param splits: When provided, only create DataProvider instances for the listed groups.
                  When ``None`` every group in the data_request is loaded.
   :type splits: tuple[str, ...] | None, optional
   :param shuffle: Unused; kept for backward-compatibility with call sites that still pass
                   it.  Split shuffling is now handled by ``splitting_utils.create_splits``.
   :type shuffle: bool, optional

   :returns: Mapping of data group names to DataProvider instances.
   :rtype: dict[str, DataProvider]
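
   The effect of ``splits`` can be pictured as a filter over the group names in the data request. In this sketch
   ``select_groups`` is hypothetical and a plain list stands in for the data request held in ``config``. Names that
   are not configured are silently ignored here. The real function may construct providers differently or reject
   unknown names.

   .. code-block:: python

      def select_groups(group_names, splits=None):
          """Hypothetical helper: names of the groups setup_dataset would build."""
          if splits is None:
              # No filter: every group in the data request is loaded.
              return list(group_names)
          # Keep only the requested groups, preserving the configured order.
          return [name for name in group_names if name in splits]


      groups = ["train", "validate", "infer"]
      print(select_groups(groups))  # ['train', 'validate', 'infer']
      print(select_groups(groups, splits=("validate",)))  # ['validate']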


.. py:function:: setup_model(config: dict, dataset: hyrax.datasets.data_provider.DataProvider) -> torch.nn.Module

   Create a model object based on the configuration.

   :param config: The runtime configuration
   :type config: dict
   :param dataset: The dataset object that will provide data to the model for training or
                   inference. Here it is only used to provide a data sample to the model so
                   that it can resize itself at runtime if necessary.
   :type dataset: DataProvider

   :returns: An instance of the model class specified in the configuration
   :rtype: torch.nn.Module


.. py:function:: dist_data_loader(dataset: torch.utils.data.Dataset, config: dict, shuffle: bool = False) -> torch.utils.data.DataLoader

   Create a PyTorch Ignite distributed data loader.

   It is recommended that each verb needing dataloaders only call this function once.

   :param dataset: A Hyrax dataset instance.  When *dataset* is a :class:`DataProvider`
                   with ``split_indices`` set (by :func:`~hyrax.splitting_utils.create_splits`),
                   the loader is restricted to those indices via a :class:`~torch.utils.data.Subset`.
                   When ``split_weights`` is also set, a
                   :class:`~torch.utils.data.WeightedRandomSampler` is used so that
                   under-represented classes are over-sampled to achieve the configured
                   class distribution.
   :type dataset: hyrax.datasets.dataset_registry.HyraxDataset
   :param config: Hyrax runtime configuration
   :type config: dict
   :param shuffle: If ``True`` and no weights are present, a
                   :class:`~torch.utils.data.SubsetRandomSampler` is used for uniform
                   shuffling.  If ``False`` and no weights, a sequential sampler preserves
                   deterministic order.  Ignored when ``split_weights`` is set (weighted
                   sampling always draws with replacement).  Defaults to ``False`` so
                   non-training verbs preserve deterministic order.
   :type shuffle: bool, optional

   :returns: The distributed dataloader.
   :rtype: DataLoader
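
   For a dataset with ``split_indices`` set, the sampler rules above reduce to a short decision function.
   ``choose_sampler`` is hypothetical and returns a class name instead of building the sampler, so it runs without
   PyTorch. The sequential case is assumed to use :class:`SubsetSequentialSampler`, because this module defines it.

   .. code-block:: python

      def choose_sampler(has_weights: bool, shuffle: bool) -> str:
          """Hypothetical summary of how a split's sampler is chosen."""
          if has_weights:
              # Weighted sampling always draws with replacement; ``shuffle`` is ignored.
              return "WeightedRandomSampler"
          if shuffle:
              # Uniformly random order over the split's indices.
              return "SubsetRandomSampler"
          # Deterministic order; assumed to be this module's SubsetSequentialSampler.
          return "SubsetSequentialSampler"


      for weights, shuffle in [(True, True), (False, True), (False, False)]:
          print(weights, shuffle, choose_sampler(weights, shuffle))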


.. py:function:: _inner_loop(func, prepare_inputs, device, config, engine, batch)

   Wrap the model-specific function ``func`` so that each batch is moved to the appropriate device before ``func`` runs.


.. py:function:: _create_process_func(funcname, device, model, config)

.. py:function:: create_engine(funcname: str, device: torch.device, model: torch.nn.Module, config: dict) -> ignite.engine.Engine

   Create the PyTorch Ignite engine used by either an evaluator or a trainer.

   This function automatically unwraps a distributed model to find the requested method. It also constructs the
   functions that transfer data to the device on every batch, so model code can stay the same no matter where
   the model is run.

   :param funcname: Name of the model method called once per batch in the core of the engine loop.
   :type funcname: str
   :param device: The device the engine will run the model on.
   :type device: torch.device
   :param model: The model the engine will use.
   :type model: torch.nn.Module
   :param config: The runtime config in use.
   :type config: dict

   :returns: The engine that calls ``funcname`` on every batch.
   :rtype: pytorch-ignite.Engine
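
   Ignite invokes an engine's process function as ``process_function(engine, batch)``. ``_inner_loop`` takes
   ``engine`` and ``batch`` as its last two parameters. One plausible construction is therefore to bind its first four
   arguments with :func:`functools.partial`. This sketch rests on that assumption. ``inner_loop``, ``to_device`` and
   ``train_batch`` are hypothetical stand-ins. The prepare-then-move order is also assumed, and the device transfer
   is only simulated.

   .. code-block:: python

      import functools


      def to_device(batch, device):
          # Hypothetical stand-in for moving tensors to ``device``.
          return {"data": batch, "device": device}


      def inner_loop(func, prepare_inputs, device, config, engine, batch):
          # ``config`` and ``engine`` are accepted but unused in this sketch.
          batch = prepare_inputs(batch)
          batch = to_device(batch, device)
          return func(batch)


      def train_batch(batch):
          # Toy model method: report where the batch ended up.
          return f"{len(batch['data'])} items on {batch['device']}"


      # Bind everything except (engine, batch), the arguments Ignite supplies.
      process_function = functools.partial(inner_loop, train_batch, list, "cuda:0", {})
      print(process_function(None, (1, 2, 3)))  # 3 items on cuda:0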


.. py:function:: extract_model_method(model, method_name)

   Extract a method from a model, which may be wrapped in a DistributedDataParallel
   or DataParallel object. For instance, ``method_name`` could be ``train_batch`` or
   ``infer_batch``.

   :param model: The model to extract the method from
   :type model: nn.Module, DistributedDataParallel, or DataParallel
   :param method_name: Name of the method to extract
   :type method_name: str

   :returns: The method extracted from the model
   :rtype: Callable
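
   Both ``torch.nn.DataParallel`` and ``torch.nn.parallel.DistributedDataParallel`` keep the wrapped model in their
   ``module`` attribute. The unwrapping step can therefore be sketched by duck-typing on that attribute.
   ``extract_method``, ``ToyModel`` and ``ToyWrapper`` are hypothetical. The real function may check the wrapper
   type explicitly, which avoids misfiring on a plain model that happens to have a submodule named ``module``.

   .. code-block:: python

      def extract_method(model, method_name):
          """Sketch: unwrap a DataParallel/DDP-style wrapper, then look up a method."""
          # Parallel wrappers store the original model in ``module``.
          inner = getattr(model, "module", model)
          return getattr(inner, method_name)


      class ToyModel:
          def infer_batch(self, batch):
              return [x * 2 for x in batch]


      class ToyWrapper:
          # Hypothetical stand-in for a parallel wrapper.
          def __init__(self, module):
              self.module = module


      fn = extract_method(ToyWrapper(ToyModel()), "infer_batch")
      print(fn([1, 2]))  # [2, 4]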


.. py:function:: create_evaluator(model: torch.nn.Module, save_function: collections.abc.Callable[[torch.Tensor, torch.Tensor], Any], config: dict) -> ignite.engine.Engine

   Create an evaluator engine.

   The primary purpose of this function is to attach the appropriate handlers to an evaluator engine.

   :param model: The model to evaluate
   :type model: torch.nn.Module
   :param save_function: A function which will receive Engine.state.output at the end of each iteration. The intent
                         is for the results of evaluation to be saved.
   :type save_function: Callable[[torch.Tensor, torch.Tensor], Any]
   :param config: The runtime config in use
   :type config: dict

   :returns: Engine object which when run will evaluate the model.
   :rtype: pytorch-ignite.Engine


.. py:function:: create_validator(model: torch.nn.Module, config: dict, validation_data_loader: torch.utils.data.DataLoader, trainer: ignite.engine.Engine) -> ignite.engine.Engine

   This function creates a PyTorch Ignite engine object that will be used to
   validate the model.

   :param model: The model to validate
   :type model: torch.nn.Module
   :param config: Hyrax runtime configuration
   :type config: dict
   :param validation_data_loader: The data loader for the validation data
   :type validation_data_loader: DataLoader
   :param trainer: The engine object that will be used to train the model. We will use specific
                   hooks in the trainer to determine when to run the validation engine.
   :type trainer: pytorch-ignite.Engine

   :returns: Engine object that will be used to validate the model.
   :rtype: pytorch-ignite.Engine


.. py:function:: create_tester(model: torch.nn.Module, config: dict) -> ignite.engine.Engine

   This function creates a PyTorch Ignite engine object that will be used to
   test the model and compute metrics without updating model weights.

   :param model: The model to test
   :type model: torch.nn.Module
   :param config: Hyrax runtime configuration
   :type config: dict

   :returns: Engine object that will be used to test the model and compute metrics.
   :rtype: pytorch-ignite.Engine


.. py:function:: attach_best_checkpoint(engine: ignite.engine.Engine, model: torch.nn.Module, trainer: ignite.engine.Engine, results_directory: pathlib.Path) -> None

   Attach a best-checkpoint handler to ``engine``, scored on ``engine.state.output["loss"]``.

   Call this function *after* both ``create_trainer`` and (optionally) ``create_validator``
   have been called so that handler registration order is correct.  When a validator is
   available, pass it as ``engine`` so that checkpointing is driven by validation loss.
   When no validator is available, pass the trainer as ``engine`` so that checkpointing
   falls back to training loss, preserving the previous behaviour.

   The saved checkpoint format is identical to the one produced by ``create_trainer``, so
   existing resume logic is fully backward-compatible.

   :param engine: The engine whose ``output["loss"]`` is used as the checkpoint score.  Pass the
                  validator when one exists; otherwise pass the trainer. If the engine has a
                  ``hyrax_label`` attribute, it will be included in the checkpoint filename.
   :type engine: pytorch-ignite.Engine
   :param model: The model being trained.  Must expose ``model.optimizer`` and optionally
                 ``model.scheduler``.
   :type model: torch.nn.Module
   :param trainer: The training engine.  Used to derive the global step counter and to attach the
                   end-of-training log handler.
   :type trainer: pytorch-ignite.Engine
   :param results_directory: Directory where checkpoint files are written.
   :type results_directory: Path
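
   Ignite's ``Checkpoint`` handler keeps the objects with the *highest* ``score_function`` value. Scoring on a loss
   therefore usually means negating it. The plain-Python sketch below illustrates that convention. It assumes,
   without confirmation from this page, that Hyrax negates ``engine.state.output["loss"]`` in the same way.
   ``BestTracker`` is a hypothetical stand-in for the handler with a single retained checkpoint.

   .. code-block:: python

      def loss_score(output: dict) -> float:
          # Higher scores are kept, so a lower loss must map to a higher score.
          return -output["loss"]


      class BestTracker:
          """Hypothetical stand-in for a checkpoint handler that keeps one file."""

          def __init__(self):
              self.best_score = None
              self.best_step = None

          def __call__(self, output: dict, global_step: int) -> bool:
              score = loss_score(output)
              if self.best_score is None or score > self.best_score:
                  self.best_score, self.best_step = score, global_step
                  return True  # a new best checkpoint would be written here
              return False


      tracker = BestTracker()
      for step, loss in enumerate([0.9, 0.5, 0.7, 0.4], start=1):
          tracker({"loss": loss}, step)
      print(tracker.best_step, -tracker.best_score)  # 4 0.4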


.. py:function:: create_trainer(model: torch.nn.Module, config: dict, results_directory: pathlib.Path) -> ignite.engine.Engine

   This function was originally copied from:
   https://github.com/pytorch-ignite/examples/blob/main/tutorials/intermediate/cifar10-distributed.py#L164

   It was substantially trimmed down to make it easier to understand.

   :param model: The model to train
   :type model: torch.nn.Module
   :param config: Hyrax runtime configuration
   :type config: dict
   :param results_directory: The directory where training results will be saved
   :type results_directory: Path

   :returns: Engine object that will be used to train the model.
   :rtype: pytorch-ignite.Engine


.. py:function:: create_save_batch_callback(results_dir)

   Create a callback function for saving batch results during inference or testing.

   This factory function creates a closure that captures the output directory,
   then returns a callback that can be used by pytorch_ignite engines to save
   model outputs batch by batch.

   :param results_dir: Directory where results should be saved
   :type results_dir: Path

   :returns: A callback function with signature (batch, batch_results) that saves results
   :rtype: callable
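
   The factory/closure pattern described above can be sketched with the standard library. Only the pattern comes
   from the description. The following are illustrative assumptions, not Hyrax's on-disk format: the function name
   ``make_save_batch_callback``, the per-call counter, the ``batch_<n>.json`` file naming and the JSON encoding.
   Real batches hold tensors, which JSON cannot store directly.

   .. code-block:: python

      import itertools
      import json
      import tempfile
      from pathlib import Path


      def make_save_batch_callback(results_dir):
          """Hypothetical sketch of the closure-factory pattern."""
          results_dir = Path(results_dir)
          counter = itertools.count()  # captured by the closure, like results_dir

          def save_batch(batch, batch_results):
              # One file per call; the naming scheme and format are assumptions.
              path = results_dir / f"batch_{next(counter)}.json"
              path.write_text(json.dumps({"batch": batch, "results": batch_results}))
              return path

          return save_batch


      with tempfile.TemporaryDirectory() as tmp:
          save = make_save_batch_callback(tmp)
          first = save([0, 1], [0.1, 0.2])
          second = save([2, 3], [0.3, 0.4])
          saved = json.loads(first.read_text())
          print(first.name, second.name)  # batch_0.json batch_1.json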


.. py:class:: HyraxEvents(value: str, event_filter: collections.abc.Callable | None = None, name: str | None = None)

   Bases: :py:obj:`ignite.engine.EventEnum`


   Workaround event for a PyTorch Ignite bug. See :func:`fixup_engine` for details.


   .. py:attribute:: HYRAX_EPOCH_COMPLETED
      :value: 'HyraxEpochCompleted'



.. py:function:: fixup_engine(engine: ignite.engine.Engine)

   Workaround for a PyTorch Ignite bug (https://github.com/pytorch/ignite/issues/3372) where
   ``engine.state.output`` is not available at ``EPOCH_COMPLETED`` or later events (``COMPLETED``, etc.).

   We create a new event, ``HYRAX_EPOCH_COMPLETED``, which triggers at ``ITERATION_COMPLETED``, but only on the final
   iteration of an epoch. This is just before the erroneous state reset.

   This hack relies on PyTorch Ignite internal state, but can be removed as soon as our fix is mainlined
   (https://github.com/pytorch/ignite/pull/3373) in version 0.6.0, estimated for August 2025.
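
   The "final iteration" condition can be written as a predicate on Ignite's global iteration counter. The sketch is
   plain Python. With Ignite, a predicate like this could be passed to ``Events.ITERATION_COMPLETED`` as an
   ``event_filter`` taking ``(engine, event)``. Whether Hyrax fires ``HYRAX_EPOCH_COMPLETED`` this way is an
   assumption, and ``is_last_iteration`` is a hypothetical name.

   .. code-block:: python

      def is_last_iteration(iteration: int, epoch_length: int) -> bool:
          # engine.state.iteration counts across epochs, so the final iteration
          # of every epoch is a multiple of the epoch length.
          return iteration % epoch_length == 0


      # Two epochs of three iterations each: the filter passes at 3 and 6.
      fired = [i for i in range(1, 7) if is_last_iteration(i, epoch_length=3)]
      print(fired)  # [3, 6]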


