hyrax.pytorch_ignite
====================

.. py:module:: hyrax.pytorch_ignite


Attributes
----------

.. autoapisummary::

   hyrax.pytorch_ignite.logger


Classes
-------

.. autoapisummary::

   hyrax.pytorch_ignite.SubsetSequentialSampler
   hyrax.pytorch_ignite.HyraxEvents


Functions
---------

.. autoapisummary::

   hyrax.pytorch_ignite.is_iterable_dataset_requested
   hyrax.pytorch_ignite.setup_dataset
   hyrax.pytorch_ignite.setup_model
   hyrax.pytorch_ignite.load_collate_function
   hyrax.pytorch_ignite.dist_data_loader
   hyrax.pytorch_ignite.create_splits
   hyrax.pytorch_ignite._handle_nans
   hyrax.pytorch_ignite._handle_nans_tensor
   hyrax.pytorch_ignite._handle_nans_numpy
   hyrax.pytorch_ignite._handle_nans_tuple
   hyrax.pytorch_ignite._handle_nans_logic_torch
   hyrax.pytorch_ignite._handle_nan_quantile_torch
   hyrax.pytorch_ignite._handle_nan_zero_torch
   hyrax.pytorch_ignite._handle_nans_logic_numpy
   hyrax.pytorch_ignite._handle_nan_quantile_numpy
   hyrax.pytorch_ignite._handle_nan_zero_numpy
   hyrax.pytorch_ignite._inner_loop
   hyrax.pytorch_ignite._create_process_func
   hyrax.pytorch_ignite.create_engine
   hyrax.pytorch_ignite.extract_model_method
   hyrax.pytorch_ignite.create_evaluator
   hyrax.pytorch_ignite.create_validator
   hyrax.pytorch_ignite.create_trainer
   hyrax.pytorch_ignite.fixup_engine


Module Contents
---------------

.. py:data:: logger

.. py:class:: SubsetSequentialSampler(indices: collections.abc.Sequence[int], generator=None)

   Bases: :py:obj:`torch.utils.data.Sampler`\ [\ :py:obj:`int`\ ]


   Samples elements sequentially from a given list of indices, without replacement.

   :param indices: The indices to sample from, yielded in the order given.
   :type indices: collections.abc.Sequence[int]
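
   A minimal, dependency-free sketch of the sampling behavior described above. It is not the
   hyrax source: the real class subclasses ``torch.utils.data.Sampler``, and the sketch assumes
   ``generator`` goes unused because a fixed, sequential order needs no random number generator.

   .. code-block:: python

      from collections.abc import Iterator, Sequence
      from typing import Optional


      class SubsetSequentialSampler:
          """Sketch only: yield the given indices in their original order, each once."""

          def __init__(self, indices: Sequence[int], generator: Optional[object] = None):
              self.indices = indices
              # Assumption: kept only to mirror the documented signature; no RNG is needed.
              self.generator = generator

          def __iter__(self) -> Iterator[int]:
              # No shuffling and no replacement: walk the sequence front to back.
              return iter(self.indices)

          def __len__(self) -> int:
              return len(self.indices)


      sampler = SubsetSequentialSampler([7, 2, 5])
      print(list(sampler), len(sampler))  # [7, 2, 5] 3

   With PyTorch, a sampler like this is passed to a loader as ``DataLoader(dataset, sampler=...)``,
   so only the listed indices are fetched, in the listed order. It is the ordered counterpart of
   ``torch.utils.data.SubsetRandomSampler``, which draws the same kind of index list in random order.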


   .. py:attribute:: indices
      :type:  collections.abc.Sequence[int]


   .. py:attribute:: generator
      :value: None



   .. py:method:: __iter__() -> collections.abc.Iterator[int]


   .. py:method:: __len__() -> int


.. py:function:: is_iterable_dataset_requested(data_request: dict) -> bool

   Check each of the datasets included in ``data_request`` and return True if any of
   them is an iterable-style dataset.


.. py:function:: setup_dataset(config: dict, tensorboardx_logger: tensorboardX.SummaryWriter | None = None) -> torch.utils.data.Dataset

   Create an instance of the dataset requested in the runtime configuration. There
   are two modes:

   1) If the request includes an iterable-style dataset, ensure that only one dataset
      was requested, and return an instance of that dataset.
   2) If the request is for one or more map-style datasets, create an instance of a
      DataProvider and return it as the dataset.

   :param config: The runtime configuration
   :type config: dict
   :param tensorboardx_logger: The TensorBoard logger, if TensorBoard is in use, so that the dataset can log to it
   :type tensorboardx_logger: SummaryWriter, optional

   :returns: The requested iterable-style dataset, or a DataProvider instance for the requested map-style dataset(s)
   :rtype: Dataset


.. py:function:: setup_model(config: dict, dataset: torch.utils.data.Dataset) -> torch.nn.Module

   Create a model object based on the configuration.

   :param config: The runtime configuration
   :type config: dict
   :param dataset: The dataset object that will provide data to the model for training or
                   inference. Here it is only used to provide a data sample to the model so
                   that it can resize itself at runtime if necessary.
   :type dataset: Dataset

   :returns: An instance of the model class specified in the configuration
   :rtype: torch.nn.Module


.. py:function:: load_collate_function(data_loader_kwargs: dict) -> collections.abc.Callable | None

   Load a collate function if one is specified in the config. Otherwise return None.
   Returning None will cause the DataLoader to use PyTorch's default collate function.

   :param data_loader_kwargs: The configuration dictionary that will be passed as kwargs to the DataLoader
   :type data_loader_kwargs: dict

   :returns: The collate function if specified, else None
   :rtype: Optional[Callable]
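
   A minimal sketch of how a collate function named in the config could be resolved. It assumes,
   hypothetically, that ``data_loader_kwargs`` holds a dotted import path such as
   ``"package.module.function"`` under a ``collate_fn`` key; that key name and path format are
   illustrative, not hyrax's documented config schema.

   .. code-block:: python

      import importlib
      from collections.abc import Callable
      from typing import Optional


      def load_collate_function(data_loader_kwargs: dict) -> Optional[Callable]:
          # Hypothetical key holding a dotted path like "mypackage.collate.my_collate".
          path = data_loader_kwargs.get("collate_fn")
          if not path:
              # None makes DataLoader fall back to PyTorch's default collate function.
              return None
          module_name, _, attr_name = path.rpartition(".")
          module = importlib.import_module(module_name)
          return getattr(module, attr_name)


      print(load_collate_function({}))  # None
      print(load_collate_function({"collate_fn": "json.dumps"}).__name__)  # dumps

   Resolving the path to a real callable matters because ``DataLoader`` expects ``collate_fn`` to be
   callable, not a string.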


.. py:function:: dist_data_loader(dataset: torch.utils.data.Dataset, config: dict, split: Union[str, list[str], bool] = False)

   Create PyTorch Ignite distributed data loaders.

   Each verb that needs data loaders should call this function only once.

   :param dataset: A Hyrax dataset instance
   :type dataset: hyrax.data_sets.data_set_registry.HyraxDataset
   :param config: Hyrax runtime configuration
   :type config: dict
   :param split: The name(s) of the split(s) to use from the dataset. If this is False or
                 omitted, a single data loader covering the entire dataset is returned.
   :type split: Union[str, list[str], bool], optional

   :returns: * *DataLoader (or an ignite-wrapped equivalent)* -- The distributed data loader, created by
               calling ``ignite.distributed.auto_dataloader``.
             * For multiple splits, a dictionary keyed by split name. Each value is either a data loader
               as described above, or None if that split was not configured.
             * If an iterable dataset is passed, multiple splits cannot be created with a PyTorch sampler
               object, so the same data loader, covering the entire iterable, is returned for every split.


.. py:function:: create_splits(data_set: torch.utils.data.Dataset, config: dict)

   Return train, test, and validation indices for the given dataset. How the indices of the
   underlying dataset are allocated to the samplers depends on the ``data_set`` section of the
   config dict.

   :param data_set: The data set to use
   :type data_set: Dataset
   :param config: Configuration that defines dataset splits
   :type config: dict
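
   A minimal sketch of one way such indices could be allocated: shuffle every index with a seeded
   random number generator, then cut the shuffled list into consecutive slices by fraction. The
   fraction parameters and ``seed`` are hypothetical stand-ins for whatever the ``data_set`` config
   section actually specifies, and the dataset is represented only by its length.

   .. code-block:: python

      import random


      def create_splits(n_samples: int, train_frac: float, test_frac: float,
                        validate_frac: float, seed=None) -> dict:
          # Hypothetical parameters standing in for the data_set config section.
          if min(train_frac, test_frac, validate_frac) < 0:
              raise ValueError("split fractions must be non-negative")
          # Small tolerance so fractions like 0.6 + 0.2 + 0.2 are not rejected by rounding.
          if train_frac + test_frac + validate_frac > 1.0 + 1e-9:
              raise ValueError("split fractions must sum to at most 1")

          indices = list(range(n_samples))
          random.Random(seed).shuffle(indices)  # reproducible for a fixed seed

          n_train = int(train_frac * n_samples)
          n_test = int(test_frac * n_samples)
          n_validate = int(validate_frac * n_samples)

          # Consecutive, non-overlapping slices of the shuffled indices.
          train = indices[:n_train]
          test = indices[n_train:n_train + n_test]
          validate = indices[n_train + n_test:n_train + n_test + n_validate]
          return {"train": train, "test": test, "validate": validate}


      splits = create_splits(8, 0.5, 0.25, 0.25, seed=42)
      print({name: len(idx) for name, idx in splits.items()})  # {'train': 4, 'test': 2, 'validate': 2}

   Index lists like these can then be handed to a sampler such as :py:class:`SubsetSequentialSampler`
   so that each data loader sees only its own split.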


.. py:function:: _handle_nans(batch, config)

   The default ``_handle_nans`` function. It issues a warning and returns ``batch`` unchanged.


.. py:function:: _handle_nans_tensor(batch, config)

   The implementation of ``_handle_nans`` when ``batch`` is expected to be a tensor.


.. py:function:: _handle_nans_numpy(batch, config)

.. py:function:: _handle_nans_tuple(batch, config)

   The tuple-specific implementation of ``_handle_nans``. NaN handling is applied to each tensor
   element of the tuple; non-tensor elements are returned unchanged.
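
   The split into a default ``_handle_nans`` plus type-specific variants resembles
   ``functools.singledispatch``, and the sketch below uses that mechanism on that assumption; it is not
   necessarily how hyrax wires the variants together. Plain lists of floats stand in for tensors so the
   example runs without PyTorch, NaNs are simply replaced with zero, and ``config`` is accepted but unused.

   .. code-block:: python

      import functools
      import math
      import warnings


      @functools.singledispatch
      def handle_nans(batch, config):
          # Fallback for unsupported batch types: warn and pass the batch through.
          warnings.warn(f"NaN handling not implemented for {type(batch).__name__}")
          return batch


      @handle_nans.register
      def _(batch: list, config):
          # Stand-in for the tensor implementation: replace each NaN with 0.0.
          return [0.0 if math.isnan(x) else x for x in batch]


      @handle_nans.register
      def _(batch: tuple, config):
          # Handle only the "tensor" (here: list) elements; leave everything else as-is.
          return tuple(handle_nans(item, config) if isinstance(item, list) else item for item in batch)


      print(handle_nans(([1.0, float("nan")], "label-7"), config={}))  # ([1.0, 0.0], 'label-7')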


.. py:function:: _handle_nans_logic_torch(batch, config)

.. py:function:: _handle_nan_quantile_torch(batch, quantile)

.. py:function:: _handle_nan_zero_torch(batch)

.. py:function:: _handle_nans_logic_numpy(batch, config)

.. py:function:: _handle_nan_quantile_numpy(batch, quantile)
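
   A dependency-free sketch of the two fill strategies these helpers are named for: replacing NaNs
   with zero, or with a quantile of the batch's non-NaN values. It assumes the quantile is taken over
   the whole (flattened) batch using linear interpolation, which matches the default method of
   ``numpy.quantile``; whether hyrax computes it per batch, per sample, or per channel is an open
   assumption here. The function names are hypothetical.

   .. code-block:: python

      import math


      def _linear_quantile(values, q):
          """Quantile by linear interpolation between the two closest ranks."""
          ordered = sorted(values)
          position = q * (len(ordered) - 1)
          lower = math.floor(position)
          upper = math.ceil(position)
          return ordered[lower] + (ordered[upper] - ordered[lower]) * (position - lower)


      def handle_nan_quantile(batch, quantile):
          if not 0.0 <= quantile <= 1.0:
              raise ValueError("quantile must be between 0 and 1")
          finite = [x for x in batch if not math.isnan(x)]
          if not finite:
              raise ValueError("cannot compute a quantile from an all-NaN batch")
          fill = _linear_quantile(finite, quantile)
          return [fill if math.isnan(x) else x for x in batch]


      def handle_nan_zero(batch):
          return [0.0 if math.isnan(x) else x for x in batch]


      nan = float("nan")
      print(handle_nan_quantile([1.0, nan, 3.0, 5.0], 0.5))  # [1.0, 3.0, 3.0, 5.0]
      print(handle_nan_zero([nan, 2.0]))  # [0.0, 2.0]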

.. py:function:: _handle_nan_zero_numpy(batch)

.. py:function:: _inner_loop(func, to_tensor, device, config, engine, batch)

   Wrap a model-specific function (``func``) so that each batch is moved to the appropriate device.
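
   The argument order lends itself to ``functools.partial``: binding ``func``, ``to_tensor``,
   ``device``, and ``config`` leaves a callable with the ``(engine, batch)`` signature that an Ignite
   ``Engine`` expects of its process function. The sketch below assumes this is how
   ``_create_process_func`` uses ``_inner_loop``; ``fake_to_tensor`` and ``fake_train_step`` are
   hypothetical stand-ins, and the example runs without PyTorch or Ignite.

   .. code-block:: python

      import functools


      def inner_loop(func, to_tensor, device, config, engine, batch):
          # Convert the batch and place it on the target device, then run the model step.
          batch = to_tensor(batch, device)
          return func(batch)


      def create_process_func(func, to_tensor, device, config):
          # Pre-bind everything except (engine, batch), the two arguments Ignite supplies.
          return functools.partial(inner_loop, func, to_tensor, device, config)


      def fake_to_tensor(batch, device):
          # Hypothetical stand-in for tensor conversion plus ``.to(device)``.
          return {"device": device, "values": list(batch)}


      def fake_train_step(batch):
          # Hypothetical model method; its return value would become engine.state.output.
          return {"loss": sum(batch["values"]), "device": batch["device"]}


      process = create_process_func(fake_train_step, fake_to_tensor, "cpu", config={})
      print(process(None, (1, 2, 3)))  # {'loss': 6, 'device': 'cpu'}

   In real use the resulting callable would be passed to ``ignite.engine.Engine(process)``, which calls
   it once per batch.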


.. py:function:: _create_process_func(funcname, device, model, config)

.. py:function:: create_engine(funcname: str, device: torch.device, model: torch.nn.Module, config: dict) -> ignite.engine.Engine

   Create the PyTorch Ignite engine object for either an evaluator or a trainer.

   This function automatically unwraps a distributed model to find the requested method, and builds the
   functions that transfer data to the device on every batch, so the model code stays the same no matter
   where the model runs.

   :param funcname: Name of the model method to call at the core of the engine loop; it is called once
                    per batch
   :type funcname: str
   :param device: The device the engine will run the model on
   :type device: torch.device
   :param model: The model the engine will use
   :type model: torch.nn.Module
   :param config: The runtime config in use
   :type config: dict


.. py:function:: extract_model_method(model, method_name)

   Extract a method from a model, which may be wrapped in a DistributedDataParallel
   or DataParallel object. For instance, ``method_name`` could be ``train_step`` or
   ``forward``.

   :param model: The model to extract the method from
   :type model: nn.Module, DistributedDataParallel, or DataParallel
   :param method_name: Name of the method to extract
   :type method_name: str

   :returns: The method extracted from the model
   :rtype: Callable
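
   A minimal sketch of the unwrapping step. ``torch.nn.parallel.DistributedDataParallel`` and
   ``torch.nn.DataParallel`` both keep the wrapped model in their ``.module`` attribute, and custom
   methods such as ``train_step`` are not reachable on the wrapper itself. The ``wrapper_types``
   parameter is a hypothetical addition so the sketch can run with stand-in classes instead of PyTorch.

   .. code-block:: python

      def extract_model_method(model, method_name, wrapper_types=()):
          # In real use (assumption): wrapper_types=(DistributedDataParallel, DataParallel).
          inner = model.module if isinstance(model, wrapper_types) else model
          return getattr(inner, method_name)


      class TinyModel:
          def train_step(self, batch):
              return sum(batch)


      class FakeWrapper:
          # Hypothetical stand-in for DDP/DataParallel: holds the real model in ``.module``.
          def __init__(self, module):
              self.module = module


      wrapped = FakeWrapper(TinyModel())
      step = extract_model_method(wrapped, "train_step", wrapper_types=(FakeWrapper,))
      print(step([1, 2, 3]))  # 6

   Checking the wrapper type, rather than merely looking for a ``.module`` attribute, avoids misfiring
   on an ordinary model that happens to have a submodule named ``module``.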


.. py:function:: create_evaluator(model: torch.nn.Module, save_function: collections.abc.Callable[[torch.Tensor, torch.Tensor], Any], config: dict) -> ignite.engine.Engine

   Create an evaluator engine.

   The primary purpose of this function is to attach the appropriate handlers to the evaluator engine.

   :param model: The model to evaluate
   :type model: torch.nn.Module
   :param save_function: A function that receives ``Engine.state.output`` at the end of each iteration,
                         intended to save the results of evaluation.
   :type save_function: Callable[[torch.Tensor, torch.Tensor], Any]
   :param config: The runtime config in use
   :type config: dict

   :returns: Engine object which when run will evaluate the model.
   :rtype: ignite.engine.Engine


.. py:function:: create_validator(model: torch.nn.Module, config: dict, results_directory: pathlib.Path, tensorboardx_logger: tensorboardX.SummaryWriter, validation_data_loader: torch.utils.data.DataLoader, trainer: ignite.engine.Engine) -> ignite.engine.Engine

   Create a PyTorch Ignite engine object that will be used to validate the model.

   :param model: The model to validate
   :type model: torch.nn.Module
   :param config: Hyrax runtime configuration
   :type config: dict
   :param results_directory: The directory where training results will be saved
   :type results_directory: Path
   :param tensorboardx_logger: The tensorboard logger object
   :type tensorboardx_logger: SummaryWriter
   :param validation_data_loader: The data loader for the validation data
   :type validation_data_loader: DataLoader
   :param trainer: The engine object that will be used to train the model. We will use specific
                   hooks in the trainer to determine when to run the validation engine.
   :type trainer: ignite.engine.Engine

   :returns: Engine object that will be used to validate the model.
   :rtype: ignite.engine.Engine


.. py:function:: create_trainer(model: torch.nn.Module, config: dict, results_directory: pathlib.Path, tensorboardx_logger: tensorboardX.SummaryWriter) -> ignite.engine.Engine

   Create a PyTorch Ignite engine object that will be used to train the model.

   This function was originally copied from
   https://github.com/pytorch-ignite/examples/blob/main/tutorials/intermediate/cifar10-distributed.py#L164
   and has since been substantially trimmed down to make it easier to understand.

   :param model: The model to train
   :type model: torch.nn.Module
   :param config: Hyrax runtime configuration
   :type config: dict
   :param results_directory: The directory where training results will be saved
   :type results_directory: Path
   :param tensorboardx_logger: The tensorboard logger object
   :type tensorboardx_logger: SummaryWriter

   :returns: Engine object that will be used to train the model.
   :rtype: ignite.engine.Engine


.. py:class:: HyraxEvents(value: str, event_filter: Optional[Callable] = None, name: Optional[str] = None)

   Bases: :py:obj:`ignite.engine.EventEnum`


   Workaround event for a PyTorch Ignite bug. See :py:func:`fixup_engine` for details.


   .. py:attribute:: HYRAX_EPOCH_COMPLETED
      :value: 'HyraxEpochCompleted'



.. py:function:: fixup_engine(engine: ignite.engine.Engine) -> ignite.engine.Engine

   Workaround for a PyTorch Ignite bug (https://github.com/pytorch/ignite/issues/3372) where
   ``engine.state.output`` is not available at ``EPOCH_COMPLETED`` or later events (``COMPLETED``, etc.).

   We create a new event, ``HYRAX_EPOCH_COMPLETED``, which triggers at ``ITERATION_COMPLETED``, but only on
   the final iteration of the epoch. This is just before the erroneous state reset.

   This hack relies on PyTorch Ignite internal state. It can be removed as soon as our fix
   (https://github.com/pytorch/ignite/pull/3373) is mainlined, which is expected in version 0.6.0
   (estimated August 2025).
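
   A minimal sketch of the "final iteration of the epoch" test, assuming Ignite's usual state
   semantics: ``engine.state.iteration`` counts iterations across all epochs starting from 1, and
   ``engine.state.epoch_length`` is the number of iterations per epoch, which may not yet be known
   (``None``) when the data has no length. The predicate ``is_final_iteration`` is a hypothetical
   helper written as a plain function so it runs without Ignite.

   .. code-block:: python

      from typing import Optional


      def is_final_iteration(iteration: int, epoch_length: Optional[int]) -> bool:
          # iteration is 1-based and cumulative across epochs, as in Ignite's engine state.
          if not epoch_length:
              return False  # epoch length unknown: cannot tell which iteration is last
          return iteration % epoch_length == 0


      # Simulate two epochs of four iterations each and record where the event would fire.
      fired = [it for it in range(1, 9) if is_final_iteration(it, epoch_length=4)]
      print(fired)  # [4, 8]

   In an engine, a check like this would sit in an ``Events.ITERATION_COMPLETED`` handler that calls
   ``engine.fire_event(HyraxEvents.HYRAX_EPOCH_COMPLETED)``, after the custom events have been
   registered with ``engine.register_events(*HyraxEvents)``.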


