hyrax.datasets.data_provider#
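Attributes#
_JOIN_CACHE_VERSION – Bump when the cache format changes to invalidate all existing caches.
Classes#
DataProvider – This class presents itself as a PyTorch Dataset, but acts like a GraphQL gateway that fetches data from multiple datasets based on the data_request dictionary provided during initialization.
Functions#
_handle_nans(batch, config) – The default _handle_nans function. Will print a warning and return batch.
_handle_nans_tuple(batch, config) – This is the tuple-specific implementation of _handle_nans. Each element of the tuple will have nan-handling applied.
_join_cache_fingerprint(dataset_len, getter, *, n_probes) – Compute a lightweight fingerprint for a dataset’s join-key column.
_join_cache_path(data_location, fingerprint) – Return the path where a join cache file would live, or None if caching is not possible.
_load_join_cache(data_location, dataset_len, getter) – Attempt to load a cached reverse-index map from disk.
_save_join_cache(data_location, dataset_len, getter, reverse_map) – Persist a reverse-index map to disk. Failures are silently ignored (caching is best-effort).
generate_data_request_from_config(config) – This function extracts the data request from the configuration.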
Module Contents#
- _handle_nans(batch, config)[source]#
The default _handle_nans implementation. Prints a warning and returns batch.
- _handle_nans_tuple(batch, config)[source]#
This is the tuple-specific implementation of _handle_nans. NaN handling is applied to each element of the tuple; non-NumPy elements are returned unchanged.
- _JOIN_CACHE_VERSION = 1[source]#
Bump when the cache format changes to invalidate all existing caches.
- _join_cache_fingerprint(dataset_len: int, getter, *, n_probes: int = 8) str[source]#
Compute a lightweight fingerprint for a dataset’s join-key column.
The fingerprint incorporates the dataset length and a small number of deterministically sampled key values. If any of these change, the fingerprint changes and the cache is considered stale.
- Parameters:
dataset_len (int) – Number of items in the dataset.
getter (callable) – The get_<join_field> method; called with integer indices.
n_probes (int) – How many key values to sample. Kept small so the cost is negligible.
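The fingerprint idea described above can be sketched as follows. This is only an illustration under stated assumptions: the hash function (SHA-256), the evenly spaced probe positions, and the name sketch_fingerprint are all hypothetical; this page does not say how the real function picks its probes or hashes them.

```python
import hashlib


def sketch_fingerprint(dataset_len, getter, *, n_probes=8):
    """Hash the dataset length plus a few deterministically chosen key values.

    Hypothetical stand-in for _join_cache_fingerprint; the real probing
    scheme and hash are not documented here.
    """
    hasher = hashlib.sha256()
    hasher.update(str(dataset_len).encode("utf-8"))
    if dataset_len > 0:
        # Evenly spaced probe indices; a set collapses duplicates on tiny datasets.
        n = min(n_probes, dataset_len)
        probes = sorted({(i * (dataset_len - 1)) // max(n - 1, 1) for i in range(n)})
        for idx in probes:
            hasher.update(b"|")
            hasher.update(str(getter(idx)).encode("utf-8"))
    return hasher.hexdigest()


keys = ["a1", "b2", "c3", "d4"]
fp_before = sketch_fingerprint(len(keys), keys.__getitem__)
keys[0] = "zz"  # a probed key changes, so the fingerprint should change
fp_after = sketch_fingerprint(len(keys), keys.__getitem__)
```

Because only a handful of keys are read, a change to an unprobed key would go unnoticed in this sketch; that is the trade-off a lightweight fingerprint accepts.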
- _join_cache_path(data_location: str | None, fingerprint: str) pathlib.Path | None[source]#
Return the path where a join cache file would live, or None if caching is not possible (e.g. no data_location).
- _load_join_cache(data_location: str | None, dataset_len: int, getter) dict[str, int] | None[source]#
Attempt to load a cached reverse-index map from disk.
Returns None on any cache miss (file absent, fingerprint mismatch, corrupt data, permission error, etc.).
- _save_join_cache(data_location: str | None, dataset_len: int, getter, reverse_map: dict[str, int]) None[source]#
Persist a reverse-index map to disk. Failures are silently ignored (caching is best-effort).
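A best-effort load/save pair along these lines might look like the sketch below. The JSON payload layout, the file name, and the function names sketch_save_cache and sketch_load_cache are assumptions made for illustration; the actual on-disk format used by Hyrax is not described on this page.

```python
import json
import tempfile
from pathlib import Path

SKETCH_CACHE_VERSION = 1  # plays the role of _JOIN_CACHE_VERSION in this sketch


def sketch_save_cache(path, fingerprint, reverse_map):
    """Write the map to disk; swallow OS errors because caching is best-effort."""
    if path is None:
        return
    try:
        payload = {"version": SKETCH_CACHE_VERSION, "fingerprint": fingerprint, "map": reverse_map}
        Path(path).write_text(json.dumps(payload))
    except OSError:
        pass


def sketch_load_cache(path, fingerprint):
    """Return the cached map, or None on any kind of miss."""
    if path is None:
        return None
    try:
        payload = json.loads(Path(path).read_text())
        if payload["version"] != SKETCH_CACHE_VERSION or payload["fingerprint"] != fingerprint:
            return None
        return {str(k): int(v) for k, v in payload["map"].items()}
    except (OSError, ValueError, KeyError, TypeError, AttributeError):
        # Absent file, corrupt JSON, or unexpected structure all count as a miss.
        return None


with tempfile.TemporaryDirectory() as tmp:
    cache_file = Path(tmp) / "join_cache.json"  # hypothetical file name
    sketch_save_cache(cache_file, "abc", {"obj_1": 0, "obj_2": 1})
    hit = sketch_load_cache(cache_file, "abc")
    stale = sketch_load_cache(cache_file, "different")
    missing = sketch_load_cache(Path(tmp) / "absent.json", "abc")
```

Treating every failure as a miss keeps the caller simple: a None result always means "rebuild the map".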
- generate_data_request_from_config(config)[source]#
This function extracts the data request from the configuration.
If [data_request] is not defined, an error will be raised.
- Parameters:
config (dict) – The Hyrax configuration that is passed to each dataset instance.
- Returns:
A dictionary where keys are dataset names and values are lists of fields.
- Return type:
dict
- Raises:
RuntimeError – If data_request is not provided in the configuration.
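The documented behavior (return the request, or raise when it is missing) can be sketched as below. The function name sketch_data_request is hypothetical, and whether the real function validates or normalizes the request further is not stated on this page.

```python
def sketch_data_request(config):
    """Return the data_request table from a config dict, or raise if absent."""
    request = config.get("data_request")
    if not request:
        raise RuntimeError("No [data_request] table found in the configuration.")
    return request


example_config = {
    "data_request": {
        "my_dataset": {
            "dataset_class": "MyDataset",
            "data_location": "/path/to/data",
            "fields": ["field1", "field2"],
        }
    }
}
request = sketch_data_request(example_config)
```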
- class DataProvider(config: dict, request: dict)[source]#
This class presents itself as a PyTorch Dataset, but acts like a GraphQL gateway that fetches data from multiple datasets based on the data_request dictionary provided during initialization.
This class allows for flexible data retrieval from multiple dataset classes, each of which can have different fields requested.
Additionally, the user can provide specific configuration options for each dataset class that will be merged with the original configuration provided during initialization.
Initialize the DataProvider with a Hyrax config and extract (or create) the data_request.
- Parameters:
config (dict) – The Hyrax configuration that defines the data_request.
request (dict) – A dictionary that defines the data request.
- pull_up_primary_dataset_methods()[source]#
If a primary dataset is defined, we will pull up some of its methods to the DataProvider level so that they can be called directly on the DataProvider instance.
- __getitem__(idx) dict[source]#
This method returns data for a given index.
It is also a wrapper that allows this class to be treated as a PyTorch Dataset.
- Parameters:
idx (int) – The index of the data item to retrieve.
- Returns:
A dictionary containing the requested data from the prepared datasets.
- Return type:
dict
- __len__() int[source]#
Returns the length of the dataset. If a primary dataset is defined, its length is returned; otherwise the length of the first dataset in self.prepped_datasets is used.
- fields() dict[source]#
Print all the available fields for each dataset in the DataProvider.
- Returns:
A dictionary mapping friendly dataset names to their available fields.
- Return type:
dict
- prepare_datasets()[source]#
Instantiate each of the requested datasets based on the data_request configuration dictionary. Store the prepared instances in the self.prepped_datasets dictionary.
- _build_join_indices()[source]#
Build reverse-index mappings for datasets that declare a join_field.
For each joined secondary dataset, a dict {str(key): int(index)} is built by iterating over all items in that dataset. At runtime, resolve_data uses these maps to translate primary object IDs to secondary indices (left outer join: unmatched primaries get None).
Reverse maps for independent secondaries are built in parallel using threads. Built maps are persisted to a cache file next to the dataset’s data_location; a fingerprint check on subsequent runs avoids the O(N) rebuild.
This method is called once during __init__. Runtime lookups in resolve_data remain O(1) dict access.
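A minimal sketch of building reverse maps in parallel, assuming each secondary's join keys are available as a plain list. The dataset names, the helper build_reverse_map, and the use of ThreadPoolExecutor are illustrative choices, not the confirmed implementation.

```python
from concurrent.futures import ThreadPoolExecutor


def build_reverse_map(keys):
    """Map str(key) -> index with one O(N) pass over the join-key column."""
    return {str(key): idx for idx, key in enumerate(keys)}


# Hypothetical join-key columns for two independent secondary datasets.
secondaries = {
    "spectra": ["obj_3", "obj_1"],
    "catalog": [10, 20, 30],
}

with ThreadPoolExecutor() as pool:
    # One task per secondary; the maps do not depend on each other.
    futures = {name: pool.submit(build_reverse_map, keys) for name, keys in secondaries.items()}
    join_maps = {name: future.result() for name, future in futures.items()}
```

In this sketch a duplicated key keeps the last index seen; how the real method treats duplicates is not documented here.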
- static _apply_configurations(base_config: dict, dataset_definition: dict) dict[source]#
Merge the original base config with the dataset-specific config.
This function uses ConfigManager.merge_configs to merge the dataset-specific configuration into a copy of the original base config.
If no dataset_config is provided in the dataset_definition dict, the original base config will be returned unmodified.
Data request dictionary examples:
1) Requesting a built-in Hyrax dataset, “MyDataset”
    "my_dataset": {
        "dataset_class": "MyDataset",
        "data_location": "/path/to/data",
        "dataset_config": {
            "MyDataset": {
                "param1": "value1",
                "param2": "value2"
            }
        },
        "fields": ["field1", "field2"]
    }
or equivalently in a .toml file:
    [data_request]
    [data_request.my_dataset]
    dataset_class = "MyDataset"
    data_location = "/path/to/data"
    fields = ["field1", "field2"]
    [data_request.my_dataset.dataset_config.MyDataset]
    param1 = "value1"
    param2 = "value2"
Here the dataset_config dictionary will be merged into the original base config, overriding the values of param1 and param2 when creating an instance of MyDataset.
2) Requesting an external dataset (not built-in), “ExternalDataset”. Note that the dictionary nesting under “dataset_config” will match the dictionary structure in the external dataset’s default_config file.
    "my_dataset": {
        "dataset_class": "ExternalDataset",
        "data_location": "/path/to/data",
        "dataset_config": {
            "external_example": {
                "ExternalDataset": {
                    "param1": "value1",
                    "param2": "value2"
                },
            },
        },
        "fields": ["field1", "field2"]
    }
or equivalently in a .toml file:
    [data_request]
    [data_request.my_dataset]
    dataset_class = "ExternalDataset"
    data_location = "/path/to/data"
    fields = ["field1", "field2"]
    [data_request.my_dataset.dataset_config.external_example.ExternalDataset]
    param1 = "value1"
    param2 = "value2"
Here the dataset_config dictionary will be merged into the original base config, overriding the values of param1 and param2 when creating an instance of ExternalDataset.
- Parameters:
base_config (dict) – The original base configuration dictionary. A copy of this is created, the dataset_definition dict is merged into the copy, and the copy is returned.
dataset_definition (dict) – A dictionary defining the dataset, including any dataset-specific configuration options in a nested dataset_config dictionary.
- Returns:
A final configuration dictionary to be passed when creating an instance of the dataset class.
- Return type:
dict
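The merge behavior described for _apply_configurations can be approximated with a recursive dictionary merge. This sketch assumes ConfigManager.merge_configs behaves like a nested override; its exact semantics are not documented on this page, and sketch_merge and sketch_apply_configurations are hypothetical names.

```python
import copy


def sketch_merge(base, overrides):
    """Recursively merge overrides into a deep copy of base."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = sketch_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def sketch_apply_configurations(base_config, dataset_definition):
    """Return base_config as-is when no dataset_config is present, else a merged copy."""
    dataset_config = dataset_definition.get("dataset_config")
    if not dataset_config:
        return base_config
    return sketch_merge(base_config, dataset_config)


base = {"MyDataset": {"param1": "default", "param3": "keep"}, "general": {"log": "info"}}
definition = {
    "dataset_class": "MyDataset",
    "dataset_config": {"MyDataset": {"param1": "value1"}},
}
final = sketch_apply_configurations(base, definition)
```

Merging into a copy means each dataset can carry its own overrides without affecting the configuration seen by the others.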
- sample_data() dict[source]#
Returns a data sample. Primarily this will be used for instantiating a model so that any runtime resizing can be handled properly.
- Returns:
A dictionary containing the data for index 0.
- Return type:
dict
- get_object_id(idx: int) str[source]#
Returns the ID at a particular index.
IDs are provided by the primary dataset’s primary ID column.
- ids() list[str][source]#
Returns the IDs of the primary dataset.
- Returns:
A list of string IDs corresponding to the primary dataset, ordered by index.
- Return type:
list of str
- resolve_data(idx: int) dict[str, dict[str, Any] | str | None][source]#
This method requests the field data from the prepared datasets by index.
For joined secondary datasets (those with join_field), the primary dataset’s object ID is looked up in the secondary’s reverse map. If a match exists the secondary’s data is returned normally; if no match exists the friendly-name entry is set to None (left outer join).
- Parameters:
idx (int) – The index of the data item to retrieve.
- Returns:
A dictionary containing the requested data from the prepared datasets. Each key is a dataset friendly name mapped to a dict of field values, or None when a joined secondary has no match for this item. If a primary dataset is configured, the top-level "object_id" key holds a string representation of the primary ID.
- Return type:
dict[str, dict[str, Any] | str | None]
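The left outer join lookup can be illustrated with plain dictionaries. The data layout below (ids, rows, reverse_map) and the friendly names are hypothetical stand-ins for real dataset objects; only the lookup pattern is meant to mirror the description above.

```python
def sketch_resolve(idx, primary, secondaries):
    """Fetch the primary row, then look up each joined secondary by object ID."""
    object_id = str(primary["ids"][idx])
    result = {"object_id": object_id, "primary": primary["rows"][idx]}
    for name, secondary in secondaries.items():
        match = secondary["reverse_map"].get(object_id)  # O(1) dict lookup
        # Left outer join: an unmatched primary yields None for this secondary.
        result[name] = secondary["rows"][match] if match is not None else None
    return result


primary = {"ids": [101, 102], "rows": [{"image": "img_a"}, {"image": "img_b"}]}
secondaries = {"spectra": {"reverse_map": {"102": 0}, "rows": [{"flux": 1.5}]}}

unmatched = sketch_resolve(0, primary, secondaries)
matched = sketch_resolve(1, primary, secondaries)
```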
- metadata(idxs=None, fields=None) numpy.ndarray[source]#
Fetch the requested metadata fields for the given indices.
Example:
    # Fetch the metadata_1 and metadata_2 fields from the dataset with the
    # friendly name "random_1".
    metadata = data_provider.metadata(
        idxs=[0, 1, 2],
        fields=["metadata_1_random_1", "metadata_2_random_1"],
    )
- Parameters:
idxs (list of int, optional) – A list of indices for which to fetch metadata. If None, no metadata will be returned.
fields (list of str, optional) – A list of metadata fields to fetch. If None, no metadata will be returned.
- Returns:
A structured NumPy array containing the requested metadata fields. The dtype names of the array will be the metadata field names, modified to include the friendly name of the dataset they come from. For example, if the “RA” field comes from a dataset with the friendly name “cifar”, the returned field name will be “RA_cifar”.
- Return type:
np.ndarray
- metadata_fields(friendly_name=None) list[str][source]#
Returns a list of metadata fields that are available across all prepared datasets.
The field names will be modified to include the friendly name of the dataset they come from. For example, if the “RA” field comes from a dataset with the friendly name “cifar”, the returned field name will be “RA_cifar”.
NOTE: If a specific dataset friendly_name is provided, only the metadata fields for that dataset will be returned, and the field names will not include the friendly name suffix.
- Parameters:
friendly_name (str, optional) – If provided, only the metadata fields for the specified friendly name will be returned. If not provided, metadata fields from all datasets will be returned.
- Returns:
The column names of the metadata table passed. Empty list if no metadata was provided during construction of the DataProvider.
- Return type:
list[str]
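The naming rule above (suffix with the friendly name unless a single dataset is requested) can be sketched as follows. The input mapping and the function name sketch_metadata_fields are assumptions for illustration only.

```python
def sketch_metadata_fields(fields_by_dataset, friendly_name=None):
    """List metadata field names, suffixed with the dataset's friendly name.

    When friendly_name is given, only that dataset's fields are returned,
    without the suffix.
    """
    if friendly_name is not None:
        return list(fields_by_dataset[friendly_name])
    return [
        f"{field}_{name}"
        for name, fields in fields_by_dataset.items()
        for field in fields
    ]


available = {"cifar": ["RA", "Dec"], "random_1": ["metadata_1"]}
all_fields = sketch_metadata_fields(available)
cifar_only = sketch_metadata_fields(available, friendly_name="cifar")
```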
- _translate_metadata_indices(idxs, friendly_name)[source]#
Translate primary indices to real dataset indices for metadata.
For joined secondaries, looks up the matching secondary index via the join map. Indices with no match in the secondary are omitted (the caller receives fewer rows than requested for that secondary).
Returns a tuple (translated_idxs, mask) where mask is a boolean list of the same length as idxs indicating which positions had a valid match. Non-joined datasets always return a full-True mask.
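The index translation and mask can be sketched as below, assuming the primary IDs are available as a list and the join map is a plain dict. The function name and argument layout are hypothetical.

```python
def sketch_translate(idxs, primary_ids, reverse_map=None):
    """Translate primary indices to secondary indices, plus a match mask.

    reverse_map is None for a non-joined dataset, in which case the indices
    pass through unchanged and the mask is all True.
    """
    if reverse_map is None:
        return list(idxs), [True] * len(idxs)
    translated, mask = [], []
    for idx in idxs:
        match = reverse_map.get(str(primary_ids[idx]))
        mask.append(match is not None)
        if match is not None:
            translated.append(match)  # unmatched positions are omitted
    return translated, mask


primary_ids = ["a", "b", "c"]
joined = sketch_translate([0, 1, 2], primary_ids, {"c": 0, "a": 5})
plain = sketch_translate([0, 1, 2], primary_ids)
```

The mask lets a caller place the shorter list of fetched rows back into their original positions.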
- _primary_or_first_dataset()[source]#
Returns the primary dataset instance if it exists, otherwise returns the first dataset in self.prepped_datasets.
- collate(batch: list[dict]) dict[source]#
Custom collate function to be used outside the context of a PyTorch DataLoader.
This function takes a list of data samples (each sample is a dictionary) and combines them into a single batch dictionary.
- Parameters:
batch (list of dict) – A list of data samples, where each sample is a dictionary.
- Returns:
A dictionary where each key corresponds to a field and the value is a list of values for that field across the batch.
- Return type:
dict
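A minimal sketch of this kind of collation, assuming flat sample dictionaries. The real method may handle nested per-dataset dictionaries or convert values to arrays; neither is shown here, and sketch_collate is a hypothetical name.

```python
def sketch_collate(batch):
    """Combine a list of sample dicts into one dict of per-field lists."""
    collated = {}
    for sample in batch:
        for key, value in sample.items():
            collated.setdefault(key, []).append(value)
    return collated


samples = [
    {"image": 1, "label": "a"},
    {"image": 2, "label": "b"},
]
collated = sketch_collate(samples)
```

If a key is missing from some samples, this sketch produces lists of unequal length, so it assumes every sample carries the same fields.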
- handle_nans(batch_dict)[source]#
Apply NaN handling to a batch dictionary.
- Parameters:
batch_dict (dict[str, np.ndarray]) – Dictionary from data column to an entire batch of data in np.ndarray form
- Returns:
The same batch dict but with NaNs altered according to the Hyrax configuration.
- Return type:
dict[str, np.ndarray]