hyrax.data_sets.data_provider

Attributes

logger

Classes

DataProvider

This class presents itself as a PyTorch Dataset, but acts like a GraphQL gateway that fetches data from multiple datasets.

Functions

generate_data_request_from_config(config)

This function handles the backward compatibility issue of defining the requested dataset in the [data_set] table in the config.

Module Contents

logger
generate_data_request_from_config(config)

This function handles the backward compatibility issue of defining the requested dataset in the [data_set] table in the config. If a [model_inputs] table is not defined, we will assemble a data_request dictionary from the values defined elsewhere in the configuration file.

NOTE: We should anticipate deprecating the ability to define a data_request in [data_set]; when that happens, this function can be removed.

Parameters:

config (dict) – The Hyrax configuration that is passed to each dataset instance.

Returns:

A dictionary where keys are dataset names and values are lists of fields

Return type:

dict

class DataProvider(config: dict, request: dict)

This class presents itself as a PyTorch Dataset, but acts like a GraphQL gateway that fetches data from multiple datasets based on the model_inputs dictionary provided during initialization.

This class allows for flexible data retrieval from multiple dataset classes, each of which can have different fields requested.

Additionally, the user can provide specific configuration options for each dataset class that will be merged with the original configuration provided during initialization.

Initialize the DataProvider with a Hyrax config and extract (or create) the data_request.

Parameters:
  • config (dict) – The Hyrax configuration that defines the data_request.

  • request (dict) – A dictionary that defines the data request.
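The gateway idea can be illustrated with a minimal, self-contained sketch. It is not Hyrax's implementation: the `ToyDataProvider` and `ListDataset` classes and the `get_<field>` getter naming convention are hypothetical, chosen only to show how one index fans out to several datasets while only the requested fields are fetched.

```python
from typing import Any


class ListDataset:
    """Hypothetical in-memory dataset exposing one get_<field>(idx) method per column."""

    def __init__(self, **columns: list) -> None:
        self._columns = columns
        for name, values in columns.items():
            # Bind each column to a get_<name> callable, e.g. get_image(idx).
            setattr(self, f"get_{name}", lambda idx, v=values: v[idx])

    def __len__(self) -> int:
        return len(next(iter(self._columns.values())))


class ToyDataProvider:
    """Hypothetical gateway: friendly name -> dataset, friendly name -> requested fields."""

    def __init__(self, datasets: dict, request: dict) -> None:
        self.prepped_datasets = datasets
        self.requested_fields = request

    def __len__(self) -> int:
        # No primary dataset in this sketch, so use the first dataset's length.
        return len(next(iter(self.prepped_datasets.values())))

    def resolve_data(self, idx: int) -> dict[str, dict[str, Any]]:
        # Fetch only the requested fields, grouped by friendly dataset name.
        return {
            name: {
                field: getattr(self.prepped_datasets[name], f"get_{field}")(idx)
                for field in fields
            }
            for name, fields in self.requested_fields.items()
        }

    def __getitem__(self, idx: int) -> dict[str, dict[str, Any]]:
        return self.resolve_data(idx)


provider = ToyDataProvider(
    datasets={
        "images": ListDataset(image=[[0, 1], [2, 3]], label=[7, 9]),
        "catalog": ListDataset(ra=[10.0, 20.0], dec=[-1.0, 1.0]),
    },
    request={"images": ["image"], "catalog": ["ra"]},
)
print(provider[1])  # {'images': {'image': [2, 3]}, 'catalog': {'ra': 20.0}}
```

Note that the unrequested `label` and `dec` columns are never read, which is the point of the gateway design.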

config
data_request
prepped_datasets
dataset_getters
all_metadata_fields
requested_fields
custom_collate_functions
primary_dataset = None
primary_dataset_id_field_name = None
pull_up_primary_dataset_methods()

If a primary dataset is defined, we will pull up some of its methods to the DataProvider level so that they can be called directly on the DataProvider instance.
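One way to "pull up" methods is to copy bound methods from the primary dataset onto the provider instance with setattr. This is a hypothetical sketch: the page does not list which methods Hyrax pulls up, so the `PULLED_UP_METHODS` tuple and the `Catalog` and `ProviderWithPrimary` classes are illustrative assumptions.

```python
class Catalog:
    """Hypothetical primary dataset with a method worth exposing."""

    def __init__(self, ids: list) -> None:
        self._ids = ids

    def describe(self) -> str:
        return f"Catalog with {len(self._ids)} objects"


# Hypothetical: names of primary-dataset methods to expose on the provider.
PULLED_UP_METHODS = ("describe",)


class ProviderWithPrimary:
    def __init__(self, primary_dataset=None) -> None:
        self.primary_dataset = primary_dataset
        self.pull_up_primary_dataset_methods()

    def pull_up_primary_dataset_methods(self) -> None:
        if self.primary_dataset is None:
            return  # Nothing to pull up without a primary dataset.
        for name in PULLED_UP_METHODS:
            method = getattr(self.primary_dataset, name, None)
            # Copy only callables, and never shadow the provider's own attributes.
            if callable(method) and not hasattr(self, name):
                setattr(self, name, method)


pulled = ProviderWithPrimary(Catalog(["a", "b", "c"]))
print(pulled.describe())  # Catalog with 3 objects
```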

__getitem__(idx) → dict

This method returns data for a given index.

It is also a wrapper that allows this class to be treated as a PyTorch Dataset.

Parameters:

idx (int) – The index of the data item to retrieve.

Returns:

A dictionary containing the requested data from the prepared datasets.

Return type:

dict

__len__() → int

Returns the length of the dataset. If the primary dataset is defined, its length is returned; otherwise, the length of the first dataset in self.prepped_datasets is used.
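The fallback rule can be sketched as a small helper that prefers the primary dataset and otherwise takes the first entry of prepped_datasets, relying on Python dicts preserving insertion order. The `LengthSketch` class is hypothetical, and so is the assumption that `primary_dataset` stores the primary dataset's friendly name; only the attribute names mirror those listed above.

```python
from collections.abc import Sized
from typing import Optional


class LengthSketch:
    def __init__(self, prepped_datasets: dict, primary_dataset: Optional[str] = None) -> None:
        self.prepped_datasets = prepped_datasets
        # Assumption: primary_dataset holds the friendly name of the primary dataset.
        self.primary_dataset = primary_dataset

    def _primary_or_first_dataset(self) -> Sized:
        if self.primary_dataset is not None:
            return self.prepped_datasets[self.primary_dataset]
        # Dicts preserve insertion order, so this is the first dataset requested.
        return next(iter(self.prepped_datasets.values()))

    def __len__(self) -> int:
        return len(self._primary_or_first_dataset())


datasets = {"images": [0] * 5, "catalog": [0] * 8}
print(len(LengthSketch(datasets)))                             # 5
print(len(LengthSketch(datasets, primary_dataset="catalog")))  # 8
```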

__repr__() → str
fields() → dict

Print all the available fields for each dataset in the DataProvider.

Returns:

A dictionary mapping friendly dataset names to their available fields.

Return type:

dict

is_iterable()

Returns False: DataProvider datasets are always map-style datasets.

is_map()

Returns True: DataProvider datasets are always map-style datasets.

prepare_datasets()

Instantiate each of the requested datasets based on the model_inputs configuration dictionary. Store the prepared instances in the self.prepped_datasets dictionary.
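A minimal sketch of this step, assuming a hypothetical `DATASET_REGISTRY` that maps dataset_class strings to Python classes and a simplified shallow per-dataset config merge. Hyrax's own class lookup and ConfigManager.merge_configs are not reproduced here, and `ToyDataset` is likewise hypothetical.

```python
import copy
from typing import Any, Optional


class ToyDataset:
    """Hypothetical dataset that just records how it was constructed."""

    def __init__(self, config: dict, data_location: Optional[str] = None) -> None:
        self.config = config
        self.data_location = data_location


# Hypothetical registry: dataset_class name -> class.
DATASET_REGISTRY: dict = {"ToyDataset": ToyDataset}


def prepare_datasets(config: dict, data_request: dict) -> dict[str, Any]:
    prepped = {}
    for friendly_name, definition in data_request.items():
        cls = DATASET_REGISTRY[definition["dataset_class"]]
        # Shallow merge for brevity; nested tables would need a recursive merge.
        dataset_config = copy.deepcopy(config)
        dataset_config.update(definition.get("dataset_config", {}))
        prepped[friendly_name] = cls(dataset_config, definition.get("data_location"))
    return prepped


prepped = prepare_datasets(
    {"seed": 1},
    {"train": {"dataset_class": "ToyDataset", "data_location": "/tmp/a",
               "dataset_config": {"seed": 2}}},
)
print(prepped["train"].config)  # {'seed': 2}
```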

static _apply_configurations(base_config: dict, dataset_definition: dict) → dict

Merge the original base config with the dataset-specific config.

This function uses ConfigManager.merge_configs to merge the dataset-specific configuration into a copy of the original base config.

If no dataset_config is provided in the dataset_definition dict, the original base config will be returned unmodified.

Example of a dataset definition dictionary:

"my_dataset": {
    "dataset_class": "MyDataset",
    "data_location": "/path/to/data",
    "dataset_config": {
        "param1": "value1",
        "param2": "value2"
    },
    "fields": ["field1", "field2"]
}

or equivalently in a .toml file:

[model_inputs]
[model_inputs.my_dataset]
dataset_class = "MyDataset"
data_location = "/path/to/data"
fields = ["field1", "field2"]
[model_inputs.my_dataset.dataset_config]
param1 = "value1"
param2 = "value2"

In this example, the dataset_config dictionary will be merged into the original base config, overriding the values of param1 and param2 when creating an instance of MyDataset.

Parameters:
  • base_config (dict) – The original base configuration dictionary. A copy of it is created, the dataset_config from dataset_definition is merged into the copy, and the copy is returned.

  • dataset_definition (dict) – A dictionary defining the dataset, including any dataset-specific configuration options in a nested dataset_config dictionary.

Returns:

A final configuration dictionary to be passed when creating an instance of the dataset class.

Return type:

dict
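The merge semantics can be sketched with a recursive merge applied to a deep copy of the base config. The `merge_configs` function below is a hypothetical stand-in for ConfigManager.merge_configs, whose exact behavior (for example, how it treats keys absent from the base config) may differ.

```python
import copy
from typing import Any


def merge_configs(base: dict, override: dict) -> dict[str, Any]:
    """Hypothetical recursive merge: override values win, nested dicts are merged."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def apply_configurations(base_config: dict, dataset_definition: dict) -> dict:
    dataset_config = dataset_definition.get("dataset_config")
    if not dataset_config:
        # No dataset-specific overrides: return the base config unmodified.
        return base_config
    return merge_configs(base_config, dataset_config)


base = {"param1": "default", "nested": {"a": 1, "b": 2}}
definition = {"dataset_class": "MyDataset",
              "dataset_config": {"param1": "value1", "nested": {"b": 3}}}
result = apply_configurations(base, definition)
print(result)  # {'param1': 'value1', 'nested': {'a': 1, 'b': 3}}
```

Because the merge works on a deep copy, the base config shared by all datasets is never mutated by one dataset's overrides.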

sample_data() → dict

Returns a data sample. This is primarily used when instantiating a model, so that any runtime resizing can be handled properly.

Returns:

A dictionary containing the data for index 0.

Return type:

dict

ids()

Returns the IDs of the dataset.

If the primary dataset is defined, its IDs are returned; otherwise, the IDs of the first dataset in self.prepped_datasets.keys() are returned.

resolve_data(idx: int) → dict

This method requests the field data from the prepared datasets by index.

Parameters:

idx (int) – The index of the data item to retrieve.

Returns:

A dictionary containing the requested data from the prepared datasets.

Return type:

dict

metadata(idxs=None, fields=None) → numpy.ndarray

Fetch the requested metadata fields for the given indices.

Example:

# Fetch the metadata_1 and metadata_2 fields from the dataset with the
# friendly name "random_1".

metadata = data_provider.metadata(
    idxs=[0, 1, 2],
    fields=["metadata_1_random_1", "metadata_2_random_1"]
)

Parameters:
  • idxs (list of int, optional) – A list of indices for which to fetch metadata. If None, no metadata will be returned.

  • fields (list of str, optional) – A list of metadata fields to fetch. If None, no metadata will be returned.

Returns:

A structured NumPy array containing the requested metadata fields. The dtype names of the array will be the metadata field names, modified to include the friendly name of the dataset they come from. For example, if the “RA” field comes from a dataset with the friendly name “cifar”, the returned field name will be “RA_cifar”.

Return type:

np.ndarray
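The shape of the return value can be illustrated with NumPy structured arrays. In this sketch the per-dataset metadata tables in `METADATA` and the `split_field` helper are hypothetical. Suffixed names are resolved by matching the longest known friendly name, because friendly names such as "random_1" can themselves contain underscores; the real method may resolve names differently.

```python
import numpy as np

# Hypothetical per-dataset metadata tables, keyed by friendly name.
METADATA = {
    "cifar": np.array([(10.0, 1), (20.0, 2), (30.0, 3)],
                      dtype=[("RA", "f8"), ("label", "i8")]),
    "random_1": np.array([(0.5,), (0.25,), (0.125,)], dtype=[("metadata_1", "f8")]),
}


def split_field(field: str) -> tuple[str, str]:
    """Split '<column>_<friendly_name>' using the longest matching friendly name."""
    for name in sorted(METADATA, key=len, reverse=True):
        if field.endswith("_" + name):
            return field[: -len(name) - 1], name
    raise KeyError(f"No dataset friendly name matches {field!r}")


def metadata(idxs=None, fields=None) -> np.ndarray:
    if idxs is None or fields is None:
        # Assumption: "no metadata" is represented here as an empty array.
        return np.array([])
    idxs = np.asarray(idxs)
    parts = [split_field(f) for f in fields]
    # Output dtype: source column dtypes under their suffixed names.
    dtype = [(f, METADATA[name].dtype[col]) for f, (col, name) in zip(fields, parts)]
    out = np.empty(len(idxs), dtype=dtype)
    for f, (col, name) in zip(fields, parts):
        out[f] = METADATA[name][col][idxs]
    return out


result = metadata(idxs=[0, 2], fields=["RA_cifar", "metadata_1_random_1"])
print(result.dtype.names)  # ('RA_cifar', 'metadata_1_random_1')
```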

metadata_fields(friendly_name=None) → list[str]

Returns a list of metadata fields that are available across all prepared datasets.

The field names will be modified to include the friendly name of the dataset they come from. For example, if the “RA” field comes from a dataset with the friendly name “cifar”, the returned field name will be “RA_cifar”.

NOTE: If a specific dataset friendly_name is provided, only the metadata fields for that dataset will be returned, and the field names will not include the friendly name suffix.

Parameters:

friendly_name (str, optional) – If provided, only the metadata fields for the specified friendly name will be returned. If not provided, metadata fields from all datasets will be returned.

Returns:

The names of the available metadata fields. An empty list if no metadata was provided during construction of the DataProvider.

Return type:

list[str]
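The naming rule can be sketched as follows, assuming each prepared dataset can report its raw metadata column names (represented here by a hypothetical `raw_fields` mapping). Returning an empty list for an unknown friendly name is also an assumption of this sketch.

```python
from typing import Optional

# Hypothetical raw metadata column names for each prepared dataset.
raw_fields = {"cifar": ["RA", "label"], "random_1": ["metadata_1", "metadata_2"]}


def metadata_fields(friendly_name: Optional[str] = None) -> list[str]:
    if friendly_name is not None:
        # A single dataset: column names are returned without a suffix.
        return list(raw_fields.get(friendly_name, []))
    # All datasets: suffix each column with its dataset's friendly name.
    return [f"{field}_{name}" for name, fields in raw_fields.items() for field in fields]


print(metadata_fields())         # ['RA_cifar', 'label_cifar', 'metadata_1_random_1', 'metadata_2_random_1']
print(metadata_fields("cifar"))  # ['RA', 'label']
```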

_primary_or_first_dataset()

Returns the primary dataset instance if it exists; otherwise, returns the first dataset in self.prepped_datasets.

collate(batch: list[dict]) → dict

Custom collate function to be used outside the context of a PyTorch DataLoader.

This function takes a list of data samples (each sample is a dictionary) and combines them into a single batch dictionary.

Parameters:

batch (list of dict) – A list of data samples, where each sample is a dictionary.

Returns:

A dictionary where each key corresponds to a field and the value is a list of values for that field across the batch.

Return type:

dict
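A minimal sketch of the documented behavior: the first sample's keys define the batch keys, and each value becomes a list across the batch. Recursing into nested dictionaries is an assumption, added because samples may be grouped by dataset friendly name; the real method may also stack arrays rather than build plain lists.

```python
def collate(batch: list[dict]) -> dict:
    """Combine a list of sample dicts into one dict of per-key lists."""
    if not batch:
        return {}
    out = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        if isinstance(values[0], dict):
            # Assumption: nested dicts (e.g. one per friendly name) are collated recursively.
            out[key] = collate(values)
        else:
            out[key] = values
    return out


batch = [{"images": {"label": 7}, "id": "a"}, {"images": {"label": 9}, "id": "b"}]
print(collate(batch))  # {'images': {'label': [7, 9]}, 'id': ['a', 'b']}
```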