hyrax.pytorch_ignite
Attributes
Classes
- SubsetSequentialSampler – Samples elements sequentially from a given list of indices, without replacement.
- HyraxEvents – Workaround event for a PyTorch Ignite bug. See fixup_engine for details.
Functions
- is_iterable_dataset_requested – This function checks each of the datasets included in the data_request.
- setup_dataset – This function creates an instance of the requested dataset specified in the runtime configuration.
- setup_model – Create a model object based on the configuration.
- load_collate_function – Load a collate function if one is specified in the config. Otherwise return None.
- dist_data_loader – Create PyTorch Ignite distributed data loaders.
- create_splits – Returns train, test, and validation indexes constructed to be used with the passed-in dataset.
- _handle_nans – The default _handle_nans function. Will print a warning and return batch.
- _handle_nans_tensor – The implementation of _handle_nans when expecting batch to be a tensor.
- _handle_nans_tuple – This is the tuple-specific implementation of _handle_nans. Each tensor element of the tuple will have nan-handling applied.
- _inner_loop – This wraps a model-specific function (func) to move data to the appropriate device.
- create_engine – Unified creation of the PyTorch Ignite engine object for either an evaluator or trainer.
- extract_model_method – Extract a method from a model, which may be wrapped in a DistributedDataParallel or DataParallel object.
- create_evaluator – Creates an evaluator engine.
- create_validator – This function creates a PyTorch Ignite engine object that will be used to validate the model.
- create_trainer – This function was originally copied from a pytorch-ignite example and substantially trimmed down.
- fixup_engine – Workaround for this PyTorch Ignite bug (https://github.com/pytorch/ignite/issues/3372) where engine.state.output is not available at EPOCH_COMPLETED or later events.
Module Contents
- class SubsetSequentialSampler(indices: collections.abc.Sequence[int], generator=None)
Bases: torch.utils.data.Sampler[int]
Samples elements sequentially from a given list of indices, without replacement.
- Parameters:
indices (Sequence[int]) – The sequence of dataset indices to sample, yielded in the order given
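A sequential subset sampler only needs to yield the stored indices in order and report how many there are. The sketch below is a plain-Python stand-in, not the Hyrax class: the real class subclasses torch.utils.data.Sampler[int] and also accepts a generator argument, which this sketch omits so it runs without PyTorch.

```python
from collections.abc import Iterator, Sequence


class SubsetSequentialSamplerSketch:
    """Yield the given indices in their original order, each exactly once.

    Illustrative stand-in: the documented class subclasses
    torch.utils.data.Sampler[int]; this version needs no PyTorch.
    """

    def __init__(self, indices: Sequence[int]) -> None:
        self.indices = indices

    def __iter__(self) -> Iterator[int]:
        # No shuffling: indices are produced exactly as stored.
        return iter(self.indices)

    def __len__(self) -> int:
        return len(self.indices)


sampler = SubsetSequentialSamplerSketch([4, 0, 7, 2])
print(list(sampler))  # [4, 0, 7, 2]
```

Because torch.utils.data.DataLoader accepts any iterable of indices as its sampler argument, an object like this can restrict a loader to one split while keeping the order deterministic, which is useful for evaluation.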
- is_iterable_dataset_requested(data_request: dict) -> bool
This function checks each of the datasets included in the data_request. If any of them are iterable-style datasets, we return True.
- setup_dataset(config: dict, tensorboardx_logger: tensorboardX.SummaryWriter | None = None) -> torch.utils.data.Dataset
This function creates an instance of the requested dataset specified in the runtime configuration. There are two modes encapsulated here:
1) If the request includes an iterable-style dataset, ensure that only one dataset was requested, then return an instance of that dataset.
2) If the request is for one or more map-style datasets, create an instance of a DataProvider and return that as the dataset.
- Parameters:
config (dict) – The runtime configuration
tensorboardx_logger (SummaryWriter, optional) – If Tensorboard is in use, the tensorboard logger so the dataset can log things
- Returns:
An instance of the dataset class specified in the configuration
- Return type:
Dataset
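The two modes of setup_dataset, together with the check done by is_iterable_dataset_requested, can be sketched as a small dispatch. This is an illustration under assumptions, not the Hyrax code: the shape of data_request (a mapping from a name to a dataset class), the stand-in base classes, DataProviderSketch, and the RuntimeError are all hypothetical. Real code would test against torch.utils.data.IterableDataset instead.

```python
from collections.abc import Iterator


# Hypothetical stand-ins for map-style and iterable-style torch datasets.
class MapStyleDataset:
    def __getitem__(self, idx: int) -> int:
        return idx

    def __len__(self) -> int:
        return 10


class IterableStyleDataset:
    def __iter__(self) -> Iterator[int]:
        return iter(range(3))


class DataProviderSketch:
    """Hypothetical wrapper bundling several map-style datasets."""

    def __init__(self, datasets: dict) -> None:
        self.datasets = datasets


def is_iterable_dataset_requested_sketch(data_request: dict) -> bool:
    # Assumed shape: {name: dataset_class}. True if any class is iterable-style.
    return any(issubclass(cls, IterableStyleDataset) for cls in data_request.values())


def setup_dataset_sketch(data_request: dict):
    if is_iterable_dataset_requested_sketch(data_request):
        # Mode 1: an iterable-style dataset must be the only one requested.
        if len(data_request) != 1:
            raise RuntimeError("Only one iterable-style dataset may be requested")
        (cls,) = data_request.values()
        return cls()
    # Mode 2: wrap one or more map-style datasets in a single provider object.
    return DataProviderSketch({name: cls() for name, cls in data_request.items()})
```

The single-dataset restriction in mode 1 follows from the fact that an iterable-style dataset has no index space, so it cannot be merged with others by index the way map-style datasets can.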
- setup_model(config: dict, dataset: torch.utils.data.Dataset) -> torch.nn.Module
Create a model object based on the configuration.
- Parameters:
config (dict) – The runtime configuration
dataset (Dataset) – The dataset object that will provide data to the model for training or inference. Here it is only used to provide a data sample to the model so that it can resize itself at runtime if necessary.
- Returns:
An instance of the model class specified in the configuration
- Return type:
torch.nn.Module
- load_collate_function(data_loader_kwargs: dict) -> Callable | None
Load a collate function if one is specified in the config. Otherwise return None. Returning None will cause the DataLoader to use PyTorch’s default collate function.
- Parameters:
data_loader_kwargs (dict) – The configuration dictionary that will be passed as kwargs to the DataLoader
- Returns:
The collate function if specified, else None
- Return type:
Optional[Callable]
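One common way to let a config name a collate function is to store a dotted import path and resolve it with importlib. The sketch below assumes that pattern; the key name "collate_fn" and the choice to pop it out of the kwargs are assumptions, not documented Hyrax behavior.

```python
import importlib
from typing import Callable, Optional


def load_collate_function_sketch(data_loader_kwargs: dict) -> Optional[Callable]:
    # Assumption: a dotted path such as "my_package.collate.pad_batch" is stored
    # under the hypothetical key "collate_fn".
    path = data_loader_kwargs.pop("collate_fn", None)
    if not path:
        # Returning None lets DataLoader fall back to PyTorch's default collate.
        return None
    module_name, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)
```

Popping the key, rather than reading it, keeps the string path from being forwarded to DataLoader, which expects collate_fn to be a callable.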
- dist_data_loader(dataset: torch.utils.data.Dataset, config: dict, split: str | list[str] | bool = False)
Create PyTorch Ignite distributed data loaders.
It is recommended that each verb needing dataloaders only call this function once.
- Parameters:
dataset (HyraxDataset) – A Hyrax dataset instance
config (dict) – Hyrax runtime configuration
split (Union[str, list[str]], Optional) – The name(s) of the split(s) to use from the dataset. If this is False or not passed, a single data loader covering the entire dataset is returned.
- Returns:
Dataloader (or an ignite-wrapped equivalent) – The distributed dataloader, formed by calling ignite.distributed.auto_dataloader.
For multiple splits, a dictionary is returned whose keys are the split names and whose values are either a Dataloader as described above, or None if the split was not configured.
If an iterable dataset is passed, multiple splits cannot be created with a PyTorch sampler object, so the same dataloader, representing the entire iterable dataset, is returned for every split.
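The return shapes described above can be summarized in a short sketch. It only mirrors the documented cases: make_loader is a hypothetical stand-in for a call to ignite.distributed.auto_dataloader with a sampler, split_indices is an assumed precomputed mapping, and returning a bare loader (not a dict) for a single string split is an assumption.

```python
from typing import Callable, Union


def dist_data_loader_sketch(
    dataset,
    split_indices: dict,
    split: Union[str, list, bool] = False,
    is_iterable: bool = False,
    make_loader: Callable = lambda dataset, indices: (dataset, indices),
):
    """Mirror the documented return shapes; make_loader is a hypothetical stand-in."""
    if not split:
        # False or not passed: one loader over the whole dataset.
        return make_loader(dataset, None)
    names = [split] if isinstance(split, str) else list(split)
    loaders = {}
    for name in names:
        if is_iterable:
            # Samplers cannot subset an iterable dataset, so every split gets all of it.
            loaders[name] = make_loader(dataset, None)
        elif split_indices.get(name) is None:
            loaders[name] = None  # split not configured
        else:
            loaders[name] = make_loader(dataset, split_indices[name])
    # Assumption: a single string split returns the loader itself.
    return loaders[split] if isinstance(split, str) else loaders
```

Building all requested loaders in one call is consistent with the recommendation that each verb call dist_data_loader only once.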
- create_splits(data_set: torch.utils.data.Dataset, config: dict)
Returns train, test, and validation indexes constructed to be used with the passed-in dataset. The allocation of indexes in the underlying dataset to samplers depends on the data_set section of the config dict.
- Parameters:
data_set (Dataset) – The data set to use
config (dict) – Configuration that defines dataset splits
split (str) – Name of the split to use.
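Index-based splitting is often done by shuffling all indexes once with a seeded generator and cutting the shuffled list into consecutive, disjoint pieces. The sketch below shows that technique; the per-split fractions, the seed argument, and rounding down are assumptions, since the exact keys of the data_set config section are not shown here.

```python
import random


def create_splits_sketch(num_items: int, fractions: dict, seed: int = 0) -> dict:
    """Partition range(num_items) into disjoint index lists, one per split.

    Assumption: fractions looks like {"train": 0.6, "test": 0.2, "validate": 0.2};
    the real config keys may differ.
    """
    if sum(fractions.values()) > 1.0 + 1e-9:
        raise ValueError("split fractions must sum to at most 1")
    indices = list(range(num_items))
    # Seeded shuffle so the same config always yields the same splits.
    random.Random(seed).shuffle(indices)
    splits, start = {}, 0
    for name, frac in fractions.items():
        count = int(frac * num_items)  # round down; any leftover indexes go unused
        splits[name] = indices[start:start + count]
        start += count
    return splits
```

Each resulting index list could then back a SubsetSequentialSampler or a random subset sampler for the corresponding split.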
- _handle_nans(batch, config)
The default _handle_nans implementation. It prints a warning and returns batch unchanged.
- _handle_nans_tensor(batch, config)
The implementation of _handle_nans when expecting batch to be a tensor.
- _handle_nans_tuple(batch, config)
This is the tuple-specific implementation of _handle_nans. Each tensor element of the tuple will have nan-handling applied. Non-tensor elements are returned unchanged.
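A default implementation plus type-specific variants maps naturally onto functools.singledispatch, though whether Hyrax uses it is not stated here. The sketch below uses that technique with a Python list of floats standing in for a tensor, so it runs without PyTorch; the "nan_fill" config key is hypothetical (real tensor code might call torch.nan_to_num instead).

```python
import math
import warnings
from functools import singledispatch


@singledispatch
def handle_nans_sketch(batch, config: dict):
    # Default: unknown batch type, so warn and pass it through untouched.
    warnings.warn(f"No NaN handling for type {type(batch).__name__}")
    return batch


@handle_nans_sketch.register
def _(batch: list, config: dict):
    # Stand-in for the tensor case: a flat list of floats plays the role of a tensor.
    # Hypothetical config key "nan_fill" supplies the replacement value.
    fill = config.get("nan_fill", 0.0)
    return [fill if isinstance(x, float) and math.isnan(x) else x for x in batch]


@handle_nans_sketch.register
def _(batch: tuple, config: dict):
    # Handle each "tensor" (list) element; other elements are returned unchanged.
    return tuple(handle_nans_sketch(x, config) if isinstance(x, list) else x for x in batch)
```

Dispatching on type keeps the data-loading path free of isinstance chains: supporting a new batch type means registering one more function.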
- _inner_loop(func, to_tensor, device, config, engine, batch)
This wraps a model-specific function (func) to move data to the appropriate device.
- create_engine(funcname: str, device: torch.device, model: torch.nn.Module, config: dict) -> ignite.engine.Engine
Unified creation of the PyTorch Ignite engine object for either an evaluator or trainer.
This function will automatically unwrap a distributed model to find the necessary function, and construct the necessary functions to transfer data to the device on every batch, so model code can be the same no matter where the model is being run.
- Parameters:
funcname (str) – The name of the model method to call in the core of the engine loop; it is called once per batch
device (torch.device) – The device the engine will run the model on
model (torch.nn.Module) – The model the engine will be using
config (dict) – The runtime config in use
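The signature of _inner_loop suggests a common pattern: bind the first four arguments with functools.partial so the result has the (engine, batch) signature that an ignite Engine expects of its process function. The sketch below assumes that pattern; the tuple-based "device move" is a stand-in for tensor.to(device), and both function names are hypothetical.

```python
from functools import partial


def inner_loop_sketch(func, to_tensor, device, config, engine, batch):
    # Convert the raw batch, "move" it to the device, then call the model method.
    batch = to_tensor(batch)
    batch = [(device, x) for x in batch]  # stand-in for tensor.to(device)
    return func(batch)


def create_process_function_sketch(func, to_tensor, device, config):
    # Bind everything except (engine, batch), the two arguments an ignite
    # Engine passes to its process function on every iteration.
    return partial(inner_loop_sketch, func, to_tensor, device, config)
```

With ignite installed, the returned callable could be passed directly as Engine(process_function), which is how model code stays identical whether it runs on CPU or on an accelerator.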
- extract_model_method(model, method_name)
Extract a method from a model, which may be wrapped in a DistributedDataParallel or DataParallel object. For instance, method_name could be train_step or forward.
- Parameters:
model (nn.Module, DistributedDataParallel, or DataParallel) – The model to extract the method from
method_name (str) – Name of the method to extract
- Returns:
The method extracted from the model
- Return type:
Callable
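DataParallel and DistributedDataParallel both keep the wrapped model in their .module attribute, so custom methods such as train_step have to be looked up on that inner object. A minimal sketch, assuming a duck-typed check for .module rather than an isinstance test against the torch wrapper classes:

```python
def extract_model_method_sketch(model, method_name: str):
    # DataParallel and DistributedDataParallel expose the wrapped model as `.module`;
    # unwrap it so custom methods such as train_step are reachable.
    wrapped = getattr(model, "module", None)
    target = wrapped if wrapped is not None else model
    return getattr(target, method_name)
```

The duck-typed check avoids importing torch here, but it would also unwrap a user model that happens to define an attribute named module; checking isinstance(model, (DataParallel, DistributedDataParallel)) is the stricter alternative.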
- create_evaluator(model: torch.nn.Module, save_function: Callable[[torch.Tensor, torch.Tensor], Any], config: dict) -> ignite.engine.Engine
Creates an evaluator engine. The primary purpose of this function is to attach the appropriate handlers to an evaluator engine.
- Parameters:
model (torch.nn.Module) – The model to evaluate
save_function (Callable[[torch.Tensor], Any]) – A function which will receive Engine.state.output at the end of each iteration. The intent is for the results of evaluation to be saved.
config (dict) – The runtime config in use
- Returns:
Engine object which when run will evaluate the model.
- Return type:
ignite.engine.Engine
- create_validator(model: torch.nn.Module, config: dict, results_directory: pathlib.Path, tensorboardx_logger: tensorboardX.SummaryWriter, validation_data_loader: torch.utils.data.DataLoader, trainer: ignite.engine.Engine) -> ignite.engine.Engine
This function creates a PyTorch Ignite engine object that will be used to validate the model.
- Parameters:
model (torch.nn.Module) – The model to train
config (dict) – Hyrax runtime configuration
results_directory (Path) – The directory where training results will be saved
tensorboardx_logger (SummaryWriter) – The tensorboard logger object
validation_data_loader (DataLoader) – The data loader for the validation data
trainer (Engine) – The engine object that will be used to train the model. We will use specific hooks in the trainer to determine when to run the validation engine.
- Returns:
Engine object that will be used to validate the model.
- Return type:
ignite.engine.Engine
- create_trainer(model: torch.nn.Module, config: dict, results_directory: pathlib.Path, tensorboardx_logger: tensorboardX.SummaryWriter) -> ignite.engine.Engine
This function was originally copied from https://github.com/pytorch-ignite/examples/blob/main/tutorials/intermediate/cifar10-distributed.py#L164 and was then substantially trimmed down to make it easier to understand.
- Parameters:
model (torch.nn.Module) – The model to train
config (dict) – Hyrax runtime configuration
results_directory (Path) – The directory where training results will be saved
tensorboardx_logger (SummaryWriter) – The tensorboard logger object
- Returns:
Engine object that will be used to train the model.
- Return type:
ignite.engine.Engine
- class HyraxEvents(value: str, event_filter: Callable | None = None, name: str | None = None)
Bases: ignite.engine.EventEnum
Workaround event for a PyTorch Ignite bug. See fixup_engine for details.
- fixup_engine(engine: ignite.engine.Engine) -> ignite.engine.Engine
Workaround for this PyTorch Ignite bug (https://github.com/pytorch/ignite/issues/3372), where engine.state.output is not available at EPOCH_COMPLETED or later events (COMPLETED, etc.).
We create a new event, HYRAX_EPOCH_COMPLETED, which triggers at ITERATION_COMPLETED, but only on the final iteration of each epoch. This is just before the erroneous state reset.
This hack relies on PyTorch Ignite internal state, but it can be removed as soon as our fix (https://github.com/pytorch/ignite/pull/3373) is merged upstream, which is expected in version 0.6.0 (estimated August 2025).
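The timing trick behind HYRAX_EPOCH_COMPLETED reduces to one test: ignite counts iterations globally from 1, so the last iteration of each epoch is a multiple of the epoch length. The sketch below only simulates that timing in plain Python to show that engine.state.output is still available when the custom event fires. It is an assumption about how the check might be written. In ignite itself, one plausible wiring is to call engine.register_events with the custom events and then call engine.fire_event from an ITERATION_COMPLETED handler; the helper names below are hypothetical.

```python
def is_last_iteration_of_epoch(iteration: int, epoch_length: int) -> bool:
    # Global iteration numbers start at 1, so each epoch ends on a multiple
    # of epoch_length.
    return iteration % epoch_length == 0


def simulate_run(num_epochs: int, epoch_length: int) -> list:
    """Record where a HYRAX_EPOCH_COMPLETED-style event would fire (illustrative)."""
    fired = []
    iteration = 0
    for epoch in range(1, num_epochs + 1):
        for _ in range(epoch_length):
            iteration += 1
            output = f"output@{iteration}"  # the per-iteration output is still valid here
            # ITERATION_COMPLETED-time check: fire only on the epoch's last batch.
            if is_last_iteration_of_epoch(iteration, epoch_length):
                fired.append((epoch, output))
    return fired


print(simulate_run(2, 3))  # [(1, 'output@3'), (2, 'output@6')]
```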