hyrax.pytorch_ignite

Attributes

Classes

SubsetSequentialSampler

Samples elements sequentially from a given list of indices, without replacement.

HyraxEvents

Workaround event for a pytorch ignite bug. See fixup_engine for details.

Functions

setup_dataset(→ dict[str, ...)

This function creates an instance of the requested dataset(s) specified in the runtime configuration for the given splits (data_groups).

setup_model(→ torch.nn.Module)

Create a model object based on the configuration.

dist_data_loader(dataset, config[, split, shuffle])

Create Pytorch Ignite distributed data loaders

create_splits(data_set, config)

Returns train, test, and validation indexes constructed to be used with the passed-in dataset.

create_splits_from_fractions(→ dict[str, list[int]])

Partition a shared set of indices across dataset groups using the split_fraction defined on each DataProvider.

_inner_loop(func, prepare_inputs, device, config, ...)

This wraps a model-specific function (func) to move data to the appropriate device.

_create_process_func(funcname, device, model, config)

create_engine(→ ignite.engine.Engine)

Unified creation of the pytorch engine object for either an evaluator or trainer.

extract_model_method(model, method_name)

Extract a method from a model, which may be wrapped in a DistributedDataParallel or DataParallel object.

create_evaluator(→ ignite.engine.Engine)

Creates an evaluator engine

create_validator(→ ignite.engine.Engine)

This function creates a Pytorch Ignite engine object that will be used to validate the model.

create_tester(→ ignite.engine.Engine)

This function creates a Pytorch Ignite engine object that will be used to test the model and compute metrics without updating model weights.

attach_best_checkpoint(→ None)

Attach a best-checkpoint handler to engine, scored on engine.state.output["loss"].

create_trainer(→ ignite.engine.Engine)

This function was originally copied from pytorch-ignite/examples.

create_save_batch_callback(results_dir)

Create a callback function for saving batch results during inference or testing.

fixup_engine(engine)

Workaround for this pytorch ignite bug (pytorch/ignite#3372) where engine.state.output is not available at EPOCH_COMPLETED or later times (COMPLETED, etc.).

Module Contents

logger
class SubsetSequentialSampler(indices: collections.abc.Sequence[int], generator=None)

Bases: torch.utils.data.Sampler[int]

Samples elements sequentially from a given list of indices, without replacement.

Parameters:

indices (sequence) – A sequence of indices

indices: collections.abc.Sequence[int]
generator = None
__iter__() → collections.abc.Iterator[int]
__len__() → int
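
The sampling behaviour described above can be illustrated without torch. The following is a minimal sketch, not the Hyrax implementation: the class name SequentialSubsetSketch is hypothetical, it does not subclass torch.utils.data.Sampler, and it ignores the generator argument.

```python
from collections.abc import Iterator, Sequence


class SequentialSubsetSketch:
    """Hypothetical stand-in: yields a fixed list of indices in order."""

    def __init__(self, indices: Sequence[int]) -> None:
        self.indices = indices

    def __iter__(self) -> Iterator[int]:
        # Each index is yielded exactly once, in the order it was supplied.
        return iter(self.indices)

    def __len__(self) -> int:
        return len(self.indices)


sampler = SequentialSubsetSketch([4, 2, 9])
print(list(sampler))  # [4, 2, 9]
print(len(sampler))   # 3
```

Because the order is fixed by the caller, two passes over the same sampler visit the same elements in the same order, which is the property that deterministic inference relies on.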
setup_dataset(config: dict, *, splits: tuple[str, ...] | None = None, shuffle: bool = True) → dict[str, hyrax.datasets.data_provider.DataProvider]

This function creates an instance of the requested dataset(s) specified in the runtime configuration for the given splits (data_groups).

It will create an instance of a DataProvider, and return that as the dataset.

Parameters:
  • config (dict) – The runtime configuration

  • splits (tuple[str, ...] | None, optional) – When provided, only create DataProvider instances for the groups whose names appear in splits. Groups present in the data_request but not listed here are silently skipped. When None (the default) every group in the data_request is loaded — preserving backward compatibility.

  • shuffle (bool, optional) – Whether to shuffle indices when computing split_fraction-based partitions via create_splits_from_fractions(). Defaults to True. Set to False for inference / test verbs where deterministic ordering is required.

Returns:

A dictionary mapping data group names to DataProvider instances.

Return type:

dict[str, DataProvider]
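
The effect of the splits argument can be sketched as a simple filter over group names. This is an illustration only: select_groups is a hypothetical helper, the group specs are empty placeholders, and no DataProvider is constructed.

```python
from __future__ import annotations


def select_groups(data_request: dict, splits: tuple[str, ...] | None = None) -> dict:
    """Keep only the groups named in `splits`; keep everything when it is None."""
    if splits is None:
        return dict(data_request)
    # Groups not listed in `splits` are skipped without raising.
    return {name: spec for name, spec in data_request.items() if name in splits}


request = {"train": {}, "validate": {}, "infer": {}}  # placeholder group specs
print(list(select_groups(request, splits=("train", "validate"))))  # ['train', 'validate']
print(list(select_groups(request)))  # ['train', 'validate', 'infer']
```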

setup_model(config: dict, dataset: hyrax.datasets.data_provider.DataProvider) → torch.nn.Module

Create a model object based on the configuration.

Parameters:
  • config (dict) – The runtime configuration

  • dataset (DataProvider) – The dataset object that will provide data to the model for training or inference. Here it is only used to provide a data sample to the model so that it can resize itself at runtime if necessary.

Returns:

An instance of the model class specified in the configuration

Return type:

torch.nn.Module

dist_data_loader(dataset: torch.utils.data.Dataset, config: dict, split: str | list[str] | bool = False, shuffle: bool = False)

Create Pytorch Ignite distributed data loaders

It is recommended that each verb needing dataloaders only call this function once.

Parameters:
  • dataset (hyrax.datasets.dataset_registry.HyraxDataset) – A Hyrax dataset instance

  • config (dict) – Hyrax runtime configuration

  • split (str | list[str] | bool, optional) – The name(s) of the split to use from the data set. If this is False or not passed, a single data loader covering the entire dataset is returned.

  • shuffle (bool, optional) – If True, selected training indices are sampled with SubsetRandomSampler. If False, selected indices are sampled with SubsetSequentialSampler. Defaults to False so non-training verbs preserve deterministic order.

Returns:

  • Dataloader (or an ignite-wrapped equivalent) – This is the distributed dataloader, formed by calling ignite.distributed.auto_dataloader.

  • For multiple splits, we return a dictionary where the keys are the names of the splits and the value is either a Dataloader as described above or the value None if the split was not configured.

create_splits(data_set: torch.utils.data.Dataset, config: dict)

Returns train, test, and validation indexes constructed to be used with the passed in dataset. The allocation of indexes in the underlying dataset to samplers depends on the data_set section of the config dict.

Deprecated: This function and the associated configuration style using config["data_set"]["train_size"], config["data_set"]["validate_size"], and config["data_set"]["test_size"] are deprecated and will be removed in a future release. Please migrate to defining separate dataset groups in [data_request] with split_fraction for each group.

Parameters:
  • data_set (Dataset) – The data set to use

  • config (dict) – Configuration that defines dataset splits

  • split (str) – Name of the split to use.

create_splits_from_fractions(dataset_providers: dict[str, Any], config: dict, *, shuffle: bool = True) → dict[str, list[int]]

Partition a shared set of indices across dataset groups using the split_fraction defined on each DataProvider.

All providers in dataset_providers are expected to wrap the same underlying data source (same data_location). The full index range [0, len) of the first provider is shuffled deterministically (when shuffle is True) using config["data_set"]["seed"], then sliced into contiguous, non-overlapping segments proportional to each provider’s split_fraction.

Parameters:
  • dataset_providers (dict[str, Any]) – Mapping of group name (e.g. "train", "validate") to a DataProvider instance whose split_fraction is set.

  • config (dict) – The Hyrax runtime configuration. Only config["data_set"]["seed"] is used here.

  • shuffle (bool, optional) – Whether to shuffle the index array before slicing. Defaults to True. Set to False for inference / test workloads where deterministic sequential ordering is required.

Returns:

Mapping of group name → list of indices assigned to that group.

Return type:

dict[str, list[int]]

Raises:

RuntimeError – If any provider is missing a split_fraction, if the fractions sum to more than 1.0, or if providers have mismatched lengths.
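
The partitioning scheme described above (seeded shuffle, then contiguous non-overlapping slices) can be sketched with the standard library. This is not the Hyrax implementation: split_by_fraction is a hypothetical name, the seed is passed directly rather than read from config["data_set"]["seed"], the use of random.Random and of round() for segment sizes are assumptions, and only the fraction-sum check is reproduced.

```python
from __future__ import annotations

import random


def split_by_fraction(
    length: int, fractions: dict[str, float], seed: int, shuffle: bool = True
) -> dict[str, list[int]]:
    """Slice the index range [0, length) into one contiguous segment per group."""
    if sum(fractions.values()) > 1.0 + 1e-9:
        raise RuntimeError("split fractions sum to more than 1.0")

    indices = list(range(length))
    if shuffle:
        # A dedicated, seeded generator makes the shuffle reproducible.
        random.Random(seed).shuffle(indices)

    result: dict[str, list[int]] = {}
    start = 0
    for name, fraction in fractions.items():
        # Rounding policy is an assumption of this sketch.
        stop = start + round(length * fraction)
        result[name] = indices[start:stop]
        start = stop
    return result


fractions = {"train": 0.6, "validate": 0.2, "test": 0.2}
print(split_by_fraction(10, fractions, seed=0, shuffle=False))
# {'train': [0, 1, 2, 3, 4, 5], 'validate': [6, 7], 'test': [8, 9]}
```

With shuffle=True the segments contain different indices, but the same seed always produces the same assignment, and no index appears in more than one group.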

_inner_loop(func, prepare_inputs, device, config, engine, batch)

This wraps a model-specific function (func) to move data to the appropriate device.

_create_process_func(funcname, device, model, config)
create_engine(funcname: str, device: torch.device, model: torch.nn.Module, config: dict) → ignite.engine.Engine

Unified creation of the pytorch engine object for either an evaluator or trainer.

This function will automatically unwrap a distributed model to find the necessary function, and construct the necessary functions to transfer data to the device on every batch, so model code can be the same no matter where the model is being run.

Parameters:
  • funcname (str) – The name of the method on the model that the engine calls once per batch in the core of its loop

  • device (torch.device) – The device the engine will run the model on

  • model (torch.nn.Module) – The Model the engine will be using

  • config (dict) – The runtime config in use
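
One plausible way to build such a per-batch function is to bind everything except (engine, batch), the two arguments an ignite Engine passes to its process function. The sketch below is an assumption about the shape of that wiring, not the Hyrax code: the function names are hypothetical, the role given to prepare_inputs is a guess, and the device transfer is replaced by a stand-in that tags each item with the device name.

```python
from functools import partial


def inner_loop_sketch(func, prepare_inputs, device, config, engine, batch):
    # Convert the raw batch, "move" it to the device, then call the model method.
    inputs = prepare_inputs(batch)
    moved = [(device, item) for item in inputs]  # stand-in for tensor.to(device)
    return func(moved)


def make_process_function(func, prepare_inputs, device, config):
    # Bind the first four arguments; the result takes only (engine, batch).
    return partial(inner_loop_sketch, func, prepare_inputs, device, config)


process = make_process_function(
    func=lambda items: len(items),            # placeholder for a model method
    prepare_inputs=lambda batch: list(batch),  # placeholder conversion step
    device="cpu",
    config={},
)
print(process(None, (1, 2, 3)))  # 3
```

Keeping the device handling in the wrapper is what lets the model method stay identical regardless of where it runs.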

extract_model_method(model, method_name)

Extract a method from a model, which may be wrapped in a DistributedDataParallel or DataParallel object. For instance, method_name could be train_batch or infer_batch.

Parameters:
  • model (nn.Module, DistributedDataParallel, or DataParallel) – The model to extract the method from

  • method_name (str) – Name of the method to extract

Returns:

The method extracted from the model

Return type:

Callable
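
Both DistributedDataParallel and DataParallel hold the wrapped model in a .module attribute, so the lookup reduces to unwrapping before calling getattr. The sketch below uses stand-in classes so it runs without torch; Wrapper, Model, and extract_method_sketch are hypothetical names, and the real function may differ in how it detects a wrapper.

```python
class Model:
    def train_batch(self, batch):
        return {"loss": 0.0}


class Wrapper:
    """Stand-in for a parallel wrapper that stores the model in `.module`."""

    def __init__(self, module):
        self.module = module


def extract_method_sketch(model, method_name):
    # Unwrap first, so the lookup works for wrapped and bare models alike.
    target = model.module if isinstance(model, Wrapper) else model
    return getattr(target, method_name)


bare = extract_method_sketch(Model(), "train_batch")
wrapped = extract_method_sketch(Wrapper(Model()), "train_batch")
print(bare(None), wrapped(None))  # {'loss': 0.0} {'loss': 0.0}
```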

create_evaluator(model: torch.nn.Module, save_function: collections.abc.Callable[[torch.Tensor, torch.Tensor], Any], config: dict) → ignite.engine.Engine

Creates an evaluator engine. The primary purpose of this function is to attach the appropriate handlers to an evaluator engine.

Parameters:
  • model (torch.nn.Module) – The model to evaluate

  • save_function (Callable[[torch.Tensor, torch.Tensor], Any]) – A function which will receive Engine.state.output at the end of each iteration. The intent is for the results of evaluation to be saved.

  • config (dict) – The runtime config in use

Returns:

Engine object which when run will evaluate the model.

Return type:

pytorch-ignite.Engine

create_validator(model: torch.nn.Module, config: dict, validation_data_loader: torch.utils.data.DataLoader, trainer: ignite.engine.Engine) → ignite.engine.Engine

This function creates a Pytorch Ignite engine object that will be used to validate the model.

Parameters:
  • model (torch.nn.Module) – The model to train

  • config (dict) – Hyrax runtime configuration

  • validation_data_loader (DataLoader) – The data loader for the validation data

  • trainer (pytorch-ignite.Engine) – The engine object that will be used to train the model. We will use specific hooks in the trainer to determine when to run the validation engine.

Returns:

Engine object that will be used to validate the model.

Return type:

pytorch-ignite.Engine

create_tester(model: torch.nn.Module, config: dict) → ignite.engine.Engine

This function creates a Pytorch Ignite engine object that will be used to test the model and compute metrics without updating model weights.

Parameters:
  • model (torch.nn.Module) – The model to test

  • config (dict) – Hyrax runtime configuration

Returns:

Engine object that will be used to test the model and compute metrics.

Return type:

pytorch-ignite.Engine

attach_best_checkpoint(engine: ignite.engine.Engine, model: torch.nn.Module, trainer: ignite.engine.Engine, results_directory: pathlib.Path) → None

Attach a best-checkpoint handler to engine, scored on engine.state.output["loss"].

Call this function after both create_trainer and (optionally) create_validator have been called so that handler registration order is correct. When a validator is available, pass it as engine so that checkpointing is driven by validation loss. When no validator is available, pass the trainer as engine so that checkpointing falls back to training loss — preserving the previous behaviour.

The saved checkpoint format is identical to the one produced by create_trainer, so existing resume logic is fully backward-compatible.

Parameters:
  • engine (pytorch-ignite.Engine) – The engine whose output["loss"] is used as the checkpoint score. Pass the validator when one exists; otherwise pass the trainer. If the engine has a hyrax_label attribute, it will be included in the checkpoint filename.

  • model (torch.nn.Module) – The model being trained. Must expose model.optimizer and optionally model.scheduler.

  • trainer (pytorch-ignite.Engine) – The training engine. Used to derive the global step counter and to attach the end-of-training log handler.

  • results_directory (Path) – Directory where checkpoint files are written.
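
The engine-selection rule and a loss-based score can be sketched as follows. Ignite's Checkpoint handler retains the objects with the highest scores, so a loss is commonly negated before being used as a score; whether Hyrax does exactly this is an assumption. Both function names are hypothetical, and SimpleNamespace objects stand in for real engines.

```python
from types import SimpleNamespace


def loss_score(engine) -> float:
    # Higher scores are kept, so a lower loss must map to a higher score.
    return -float(engine.state.output["loss"])


def pick_scoring_engine(trainer, validator=None):
    # Prefer validation loss; fall back to training loss without a validator.
    return validator if validator is not None else trainer


trainer = SimpleNamespace(state=SimpleNamespace(output={"loss": 0.5}))
validator = SimpleNamespace(state=SimpleNamespace(output={"loss": 0.25}))

print(loss_score(pick_scoring_engine(trainer, validator)))  # -0.25
print(loss_score(pick_scoring_engine(trainer)))             # -0.5
```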

create_trainer(model: torch.nn.Module, config: dict, results_directory: pathlib.Path) → ignite.engine.Engine

This function was originally copied from pytorch-ignite/examples.

It was substantially trimmed down to make it easier to understand.

Parameters:
  • model (torch.nn.Module) – The model to train

  • config (dict) – Hyrax runtime configuration

  • results_directory (Path) – The directory where training results will be saved

Returns:

Engine object that will be used to train the model.

Return type:

pytorch-ignite.Engine

create_save_batch_callback(results_dir)

Create a callback function for saving batch results during inference or testing.

This factory function creates a closure that captures the output directory, then returns a callback that can be used by pytorch_ignite engines to save model outputs batch by batch.

Parameters:

results_dir (Path) – Directory where results should be saved

Returns:

A callback function with signature (batch, batch_results) that saves results

Return type:

callable
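
The factory-and-closure pattern described above can be sketched with the standard library. This is an illustration under stated assumptions, not the Hyrax implementation: make_save_callback is a hypothetical name, and the JSON format, the batch_N.json file naming, and the running counter are all invented for the example.

```python
import json
import tempfile
from pathlib import Path


def make_save_callback(results_dir):
    """Return a (batch, batch_results) callback that writes one file per batch."""
    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    counter = 0  # captured by the closure; file naming is hypothetical

    def save_batch(batch, batch_results):
        nonlocal counter
        out_path = results_dir / f"batch_{counter}.json"
        payload = {"ids": list(batch), "results": list(batch_results)}
        out_path.write_text(json.dumps(payload))
        counter += 1
        return out_path

    return save_batch


with tempfile.TemporaryDirectory() as tmp:
    save = make_save_callback(Path(tmp) / "results")
    first = save(["a", "b"], [0.1, 0.2])
    second = save(["c"], [0.3])
    print(first.name, second.name)  # batch_0.json batch_1.json
```

The caller never passes the output directory again after creating the callback, which is what allows an engine to invoke it with only the per-batch arguments.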

class HyraxEvents(value: str, event_filter: collections.abc.Callable | None = None, name: str | None = None)

Bases: ignite.engine.EventEnum

Workaround event for a pytorch ignite bug. See fixup_engine for details.

HYRAX_EPOCH_COMPLETED = 'HyraxEpochCompleted'
fixup_engine(engine: ignite.engine.Engine)

Workaround for this pytorch ignite bug (pytorch/ignite#3372) where engine.state.output is not available at EPOCH_COMPLETED or later times (COMPLETED, etc.).

We create a new event HYRAX_EPOCH_COMPLETED which triggers at ITERATION_COMPLETED, but only on the final iteration. This is just before the erroneous state reset.

This hack relies on pytorch ignite internal state, but can be removed as soon as our fix is mainlined (pytorch/ignite#3373) in version 0.6.0, estimated for August 2025.
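
The trigger condition can be sketched as a predicate over the iteration counter. This sketch reads "the final iteration" as the last iteration of each epoch and assumes a cumulative, 1-based iteration count with a fixed epoch length; both are assumptions, and is_final_iteration_of_epoch is a hypothetical name rather than part of Hyrax or ignite.

```python
def is_final_iteration_of_epoch(iteration: int, epoch_length: int) -> bool:
    """True when a cumulative, 1-based iteration count closes an epoch."""
    return iteration % epoch_length == 0


# Two epochs of three iterations each: the predicate holds at 3 and 6.
fired = [i for i in range(1, 7) if is_final_iteration_of_epoch(i, epoch_length=3)]
print(fired)  # [3, 6]
```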