hyrax.datasets

Hyrax has several built-in datasets for astronomical data. For many uses, these datasets can be configured out of the box for a given project.

FitsImageDataset is a generic container for FITS image cutout data indexed by a user-provided catalog file. It attempts to cover common usage patterns, such as multiple images of the same object differentiated by telescope filter; however, for advanced usage, extending the class as a custom dataset may be a better fit.

LSSTDataset is an alpha-quality container for LSST cutout images. It is currently limited to deep_coadd images and can only run in a Rubin Observatory RSP environment where the LSST Pipeline tools and a data butler with the appropriate images are available.

DownloadedLSSTDataset is a subclass of LSSTDataset that generates cutouts from the butler and saves them as .pt files on first access. On subsequent access, it loads the cutouts directly from these files, which can significantly speed up data loading times. It inherits from LSSTDataset to access the data butler and catalog functionality.

HSCDataset works similarly to FitsImageDataset, but is specialized to Hyper Suprime-Cam (HSC) cutout images downloaded with the hyrax download verb. It contains additional integrity checks and is tightly integrated with the download and rebuild_manifest verbs. In the future, this class and the downloader may become a separate package.

HyraxCifarDataset gives access to the standard CIFAR10 labeled image dataset, automatically downloading the dataset if it is not present. This dataset is useful for testing hyrax and occasionally individual models, but it is not an astronomical dataset.

HyraxRandomDataset is a utility dataset that generates random data with a configurable shape, which makes it easy to test new models with simple random data. It is highly configurable, so it can simulate input data for models that are still under development.

Each of these datasets can be used as a starting point for a custom dataset by inheriting from it (e.g. from FitsImageDataset), or you can write an entirely custom dataset following the dataset class reference and/or the dataset class example notebook.

The remaining classes in this module exist primarily for Hyrax interface purposes:

InferenceDataset is a dataset class that represents an infer or umap result, and may be returned from those verbs to provide data access.

HyraxDataset is a base class for all datasets in Hyrax and must be within the inheritance hierarchy of all custom datasets. It is not usable on its own, but provides various fall-back functionality to make custom datasets easier to write. See the dataset class reference and example notebook for more information.

Classes

FitsImageDataset

Dataset for Fits Images, typically cutouts.

LSSTDataset

LSSTDataset: A dataset to access deep_coadd images from LSST pipelines via the butler.

DownloadedLSSTDataset

DownloadedLSSTDataset: A dataset that inherits from LSSTDataset and downloads cutouts from the LSST butler, saving them as .pt files.

HSCDataset

Dataset for sets of HSC cutouts created by the hyrax download verb.

HyraxCifarDataset

Map-style CIFAR-10 dataset for Hyrax.

HyraxRandomDataset

This dataset is a stand-in for a map-style dataset.

HyraxRandomDatasetBase

This is the base class for the random datasets provided by Hyrax.

InferenceDataset

This is a dataset class to represent situations where we wish to treat the output of inference as a dataset.

ResultDataset

Reader for Lance-based inference results.

ResultDatasetWriter

Writer for Lance-based inference results.

HyraxDataset

Base class for all Hyrax datasets; see the dataset class reference for how to make a Hyrax dataset.

HyraxCSVDataset

A Hyrax Dataset for CSV files.

HyraxHATSDataset

Generic Hyrax dataset for HATS catalogs loaded through LSDB.

MultimodalUniverseDataset

Load a MultimodalUniverse dataset through Hugging Face datasets.

NestedPandasDataset

A minimal Hyrax wrapper around nested_pandas.read_parquet.

LanceDBDataset

A minimal Hyrax wrapper around a LanceDB table.

DataCache

DataCache tracks and manages a caching layer which can be used most effectively if the entirety of a

Functions

create_results_writer(result_dir)

Create a writer for results (Lance format).

load_results_dataset(config[, results_dir, verb])

Load a results dataset, auto-detecting format.

Package Contents

class FitsImageDataset(config: dict, data_location=None)

Bases: hyrax.datasets.dataset_registry.HyraxDataset, hyrax.datasets.dataset_registry.HyraxImageDataset, torch.utils.data.Dataset

Dataset for Fits Images, typically cutouts.

__init__()

Initialize a FitsImageDataset.

Most work is done in _init_from_path and the functions it calls, so that subclasses can override behavior.

Parameters:
  • config (dict) – Nested configuration dictionary for hyrax

  • data_location (Optional[Union[Path, str]]) – The directory location of the data that this dataset class will access

_called_from_test = False
_config
data_location = None
object_id_column_name
filter_column_name
filename_column_name
_init_from_path(path: pathlib.Path | str)

__init__ helper. Initialize the dataset from a path. This involves several filesystem scan operations and will ultimately open and read the header info of every FITS file in the given directory.

Parameters:

path (Union[Path, str]) – Path or string specifying the directory path that is the root of all filenames in the catalog table

_set_crop_transform()

Returns the crop transform on the image.

If overridden, the subclass must:

  1. Set self.cutout_shape to a tuple of ints representing the size of the cutouts that will be returned, at some point in the init flow.

  2. Update the crop transform using self.set_crop_transform() from the HyraxImageDataset mixin.

_read_filter_catalog(filter_catalog_path: pathlib.Path | None)
_parse_filter_catalog(table) None

Sets self.files by parsing the catalog.

Subclasses may override this function to control parsing of the table more directly, but the overriding class must create the files dict which has type dict[object_id -> dict[filter -> filename]] with object_id, filter, and filename all strings. In the case of no filter distinction, a single flag value may be used for the filter dict keys in the inner dicts.

Parameters:

table (Table) – The catalog we read in
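The files structure that _parse_filter_catalog must produce can be illustrated without hyrax. The sketch below is a minimal stand-in, assuming catalog rows arrive as plain (object_id, filter, filename) tuples rather than an astropy Table; build_files_dict and the NO_FILTER flag value are hypothetical names, not part of hyrax.

```python
# Hypothetical flag used as the filter key when the catalog has no filter column.
NO_FILTER = "none"


def build_files_dict(rows):
    """Build dict[object_id -> dict[filter -> filename]] from catalog rows.

    Each row is an (object_id, filter, filename) tuple; a filter of None marks
    a catalog without filter distinction and is replaced by NO_FILTER.
    """
    files = {}
    for object_id, filter_name, filename in rows:
        key = NO_FILTER if filter_name is None else str(filter_name)
        # All three values are stored as strings, as the files dict requires.
        files.setdefault(str(object_id), {})[key] = str(filename)
    return files


files = build_files_dict([
    ("1", "g", "1_g.fits"),
    ("1", "r", "1_r.fits"),
    ("2", None, "2.fits"),
])
print(files)
# {'1': {'g': '1_g.fits', 'r': '1_r.fits'}, '2': {'none': '2.fits'}}
```

Using a single flag key keeps the inner dict shape identical whether or not the catalog distinguishes filters, so downstream iteration code does not need a special case.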

_before_preload() None
_prepare_metadata()
shape() tuple[int, int, int]

Shape of the individual cutouts this will give to a model

Returns:

Tuple describing the dimensions of the 3-dimensional tensor handed back to models. The first index is the number of filters, the second index is the width of each image, and the third index is the height of each image.

Return type:

tuple[int,int,int]

__len__() int

Returns the number of objects in this dataset.

Returns:

Number of objects in this dataset.

Return type:

int

get_object_id(idx: int) str

Get the object ID at the given index

Parameters:

idx (int) – Index of the object ID to return

Returns:

The object ID at the given index

Return type:

str

get_image(idx: int)

Get the image at the given index as a PyTorch Tensor.

Parameters:

idx (int) – Index of the image to return

Returns:

The image at the given index as a PyTorch Tensor.

Return type:

torch.Tensor

__getitem__(idx: int)
__contains__(object_id: str) bool

Enables queries of the form object_id in dataset. Used by testing code.

Parameters:

object_id (str) – The object ID to look up in the dataset

Returns:

True if the given object_id is in the dataset

Return type:

bool

_get_file(index: int) pathlib.Path

Private indexing method across all files.

Returns the file path corresponding to the given index.

The index is zero-based and follows the same total order as the _all_files() and _object_files() iterators. This is useful if you have an np.array() or list built from _all_files() and you need to select an individual item.

Only valid after self.object_ids, self.files, self.path, and self.num_filters have been initialized in __init__

Parameters:

index (int) – Index, see above for order semantics

Returns:

The path to the file

Return type:

Path
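A sketch of how a flat index can map onto the (object, filter) total order, assuming objects are visited in sorted object ID order and filters in sorted name order within each object (the ordering hyrax actually uses is not documented here), and that every object has the same number of filters. all_files and get_file are hypothetical helpers, not the hyrax methods.

```python
from collections.abc import Iterator


def all_files(files: dict[str, dict[str, str]]) -> Iterator[str]:
    """Yield every filename in a strict total order: objects, then filters."""
    for object_id in sorted(files):
        for filter_name in sorted(files[object_id]):
            yield files[object_id][filter_name]


def get_file(files: dict[str, dict[str, str]], index: int, num_filters: int) -> str:
    """Return the filename at a zero-based index in the all_files() order.

    Assumes every object has exactly num_filters files, so the object position
    is index // num_filters and the filter position is index % num_filters.
    """
    object_id = sorted(files)[index // num_filters]
    filter_name = sorted(files[object_id])[index % num_filters]
    return files[object_id][filter_name]


files = {
    "b": {"r": "b_r.fits", "g": "b_g.fits"},
    "a": {"r": "a_r.fits", "g": "a_g.fits"},
}
ordered = list(all_files(files))
# The arithmetic index agrees with the iterator order for every position.
assert all(get_file(files, i, 2) == name for i, name in enumerate(ordered))
```

The point of a strict total order is exactly this agreement: any array built by iterating all_files() can be indexed with the same integer that get_file() accepts.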

_all_ids(log_every=None) collections.abc.Generator[str]

Private read-only iterator over all object_ids that enforces a strict total order across objects. Will not work prior to self.files initialization in __init__

Yields:

Iterator[str] – Object IDs currently in the dataset

_all_files()

Private read-only iterator over all files that enforces a strict total order across objects and filters. Will not work prior to self.files and self.path initialization in __init__.

Yields:

Path – The path to the file.

_filter_filename(object_id)

Private read-only iterator over all files for a given object. This enforces a strict total order across filters. Will not work prior to self.files initialization in __init__

Yields:

filter_name, file name – The name of a filter and the file name for the fits file. The file name is relative to self.path

_object_files(object_id)

Private read-only iterator over all files for a given object. This enforces a strict total order across filters. Will not work prior to self.files and self.path initialization in __init__.

Yields:

Path – The path to the file.

_file_to_path(filename: str) pathlib.Path

Turns a filename into a full path suitable for open. Equivalent to:

Path(self.path) / Path(filename)

Parameters:

filename (str) – The filename string

Returns:

A full path that is openable.

Return type:

Path

_read_object_id(object_id: str)
_apply_transforms(data: list[numpy.typing.ArrayLike])
_load_tensor_for_cache(object_id: str)

Implementation of TensorCacheMixin abstract method.

_object_id_to_tensor(object_id: str)

Converts an object_id to a pytorch tensor with dimensions (self.num_filters, self.cutout_shape[0], self.cutout_shape[1]). This is done by reading the file and slicing away any excess pixels at the far corners of the image from (0,0).

The current implementation reads the files once the first time they are accessed, and then keeps them in a dict for future accesses.

Parameters:

object_id (str) – The object_id requested

Returns:

A tensor with dimension (self.num_filters, self.cutout_shape[0], self.cutout_shape[1])

Return type:

torch.Tensor
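The read-once-then-cache behavior and the crop anchored at (0, 0) can be sketched with NumPy. This is an illustration only: it returns a NumPy array instead of a torch.Tensor, the reader callable stands in for FITS file reading, and stack_cropped and CutoutCache are hypothetical names.

```python
import numpy as np


def stack_cropped(images: list[np.ndarray], cutout_shape: tuple[int, int]) -> np.ndarray:
    """Crop each 2D filter image to cutout_shape, keeping the (0, 0) corner,
    and stack the results into shape (num_filters, cutout_shape[0], cutout_shape[1])."""
    rows, cols = cutout_shape
    cropped = []
    for image in images:
        if image.shape[0] < rows or image.shape[1] < cols:
            raise ValueError(f"image {image.shape} is smaller than cutout {cutout_shape}")
        # Keep the pixels nearest (0, 0); excess pixels at the far edges are dropped.
        cropped.append(image[:rows, :cols])
    return np.stack(cropped)


class CutoutCache:
    """Read each object's images once, then serve later requests from a dict."""

    def __init__(self, reader, cutout_shape):
        self._reader = reader  # hypothetical callable: object_id -> list of 2D arrays
        self._cutout_shape = cutout_shape
        self._cache = {}

    def get(self, object_id: str) -> np.ndarray:
        if object_id not in self._cache:
            images = self._reader(object_id)
            self._cache[object_id] = stack_cropped(images, self._cutout_shape)
        return self._cache[object_id]
```

Because all filters are cropped to the same cutout_shape, images of slightly different sizes still stack into one regular array.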

class LSSTDataset(config, data_location=None)

Bases: hyrax.datasets.dataset_registry.HyraxDataset, hyrax.datasets.dataset_registry.HyraxImageDataset, torch.utils.data.Dataset

LSSTDataset: A dataset to access deep_coadd images from LSST pipelines via the butler. Must be run in an RSP.

__init__()

Initialize the dataset with either a HATS catalog or an astropy table.

Config can specify either:
  • config["data_set"]["hats_catalog"]: path to a HATS catalog

  • config["data_set"]["astropy_table"]: path to any file readable by Astropy Table

BANDS = ['u', 'g', 'r', 'i', 'z', 'y']
object_id_autodetect_names = ['object_id', 'objectId']
catalog
sh_deg
sw_deg
oid_column_name
_butler_available()
_get_butler_thread_safe()

Thread safe butler creation

This function ensures that there is one and only one butler created per thread and that threads always use their assigned butler.

This is necessary because child classes of this one use butlers, and butler objects are not safe for multithreaded access.

Returns:

The butler assigned to the current thread.

Return type:

butler
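A minimal sketch of the one-resource-per-thread pattern using threading.local. It is not necessarily how hyrax implements _get_butler_thread_safe; the factory argument is a hypothetical stand-in for Butler construction, and ThreadLocalResource is an illustrative name.

```python
import threading


class ThreadLocalResource:
    """Create at most one resource per thread and always give a thread its own."""

    def __init__(self, factory):
        self._factory = factory  # hypothetical zero-argument constructor, e.g. for a Butler
        self._local = threading.local()

    def get(self):
        # threading.local gives each thread a separate attribute namespace, so the
        # first call in a thread creates the resource and later calls reuse it.
        resource = getattr(self._local, "resource", None)
        if resource is None:
            resource = self._factory()
            self._local.resource = resource
        return resource


pool = ThreadLocalResource(object)
main_resource = pool.get()
worker_results = []
worker = threading.Thread(target=lambda: worker_results.append(pool.get()))
worker.start()
worker.join()
```

The main thread keeps getting the same object on repeated calls, while the worker thread receives a distinct one, so no non-thread-safe object is ever shared.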

_detect_object_id_column_name()

Set up file naming strategy based on catalog columns.

_load_catalog(data_set_config)

Load the catalog from either a HATS catalog or an astropy table.

_load_hats_catalog(hats_path)

Load catalog from HATS format using LSDB.

_load_astropy_catalog(table_path)

Load catalog from astropy table format or pickled astropy table.

__len__()
get_image(idxs)

Get image cutouts for the given indices.

Parameters:

idxs (int or list of int) – The index or indices of the cutouts to retrieve.

Returns:

Single cutout tensor or list of cutout tensors.

Return type:

list or torch.Tensor

__getitem__(idxs)

Get default data fields for this dataset.

Parameters:

idxs (int or list of int) – The index or indices of the cutouts to retrieve.

Returns:

A dictionary containing the default data fields.

Return type:

dict

_parse_box(patch, row)

Return a Box2I representing the desired cutout in pixel space, given a “row” of catalog data which includes the semi-height (sh) and semi-width (sw) in degrees desired for the cutout.

_parse_sphere_point(row)

Return a SpherePoint with the RA and Dec given in the “row” of catalog data. The row must include RA and Dec as “ra” and “dec” columns, respectively.

_get_tract_patch(row)

Return (tractInfo, patchInfo) for a given row.

In the case of overlap, this function returns only the single principal tract and patch.

_request_patch(tract_index, patch_index)

Request a patch from the butler. This will be a list of lsst.afw.image objects, each corresponding to one of the configured bands.

Uses functools.lru_cache for basic in-memory caching.

_fetch_single_cutout(row)

Make a single cutout, returning a torch tensor.

Does not handle edge-of-tract or edge-of-patch cases; it only works near the center of a patch.

class DownloadedLSSTDataset(config, data_location)

Bases: hyrax.datasets.lsst_dataset.LSSTDataset

DownloadedLSSTDataset: A dataset that inherits from LSSTDataset and downloads cutouts from the LSST butler, saving them as .pt files during first access. On subsequent accesses, it loads cutouts directly from these cached files.

This class also creates a manifest file with the shape of each cutout and the corresponding filename.

Public Methods:
download_cutouts(indices=None, sync_filesystem=True, max_workers=None, force_retry=False):

Download cutouts with parallel processing. Automatically resumes from previous progress. Use max_workers to control thread count, force_retry to re-attempt failed downloads.

manifest_stats():

Returns dict with download statistics: total, successful, failed, pending counts and manifest file path.

download_progress():

Returns detailed progress metrics including completion percentage and failure rates.

reset_failed_downloads():

Resets all failed download attempts to allow retry without force_retry flag. Returns count of reset entries.

save_manifest_now():

Forces immediate manifest save (normally saved periodically during downloads).

cache_info():

Returns LRU cache statistics for patch fetching performance monitoring.

clear_cache():

Clears the patch LRU cache to free memory.

Usage Example:

import hyrax

# Initialize Hyrax
h = hyrax.Hyrax()
a = h.prepare()

# Download all cutouts (resumes automatically).
# WARNING: The LRU caching scheme is slightly complicated, so it is recommended
# to use the default max_workers=1 for the first download. Simply using more
# workers may not always speed up the download process.
a.download_cutouts(max_workers=4)

# Check progress
a.download_progress()

# Retry failed downloads
a.download_cutouts(force_retry=True)

# Access cutouts (loads from cache)
cutout = a[0]        # Single cutout
cutouts = a[0:10]    # Multiple cutouts

File Organization:
  • Cutouts saved as: cutout_{object_id}.pt or cutout_{index:04d}.pt

  • Manifest saved as: manifest.fits (Astropy) or manifest.parquet (HATS)

  • All files stored in the data_location provided during initialization
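The two cutout naming patterns above can be expressed as format strings. This sketch assumes the object ID pattern is used whenever an object ID is available and the zero-padded index pattern otherwise; the actual choice is made by _setup_naming_strategy from the catalog columns, and cutout_filename and cutout_path are hypothetical helpers.

```python
from pathlib import Path
from typing import Optional, Union


def cutout_filename(index: int, object_id: Optional[str] = None) -> str:
    """Name a cutout after its object ID when one is available, otherwise after
    its catalog index zero-padded to at least four digits."""
    if object_id is not None:
        return f"cutout_{object_id}.pt"
    return f"cutout_{index:04d}.pt"


def cutout_path(data_location: Union[str, Path], index: int, object_id: Optional[str] = None) -> Path:
    """Join a cutout filename onto the dataset's data_location directory."""
    return Path(data_location) / cutout_filename(index, object_id)


print(cutout_filename(7))           # cutout_0007.pt
print(cutout_filename(7, "12345"))  # cutout_12345.pt
```

Note that {index:04d} is a minimum width, so indices of 10000 and above simply produce longer names rather than being truncated.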

__init__()

Initialize the dataset with either a HATS catalog or an astropy table.

Config can specify either:
  • config["data_set"]["hats_catalog"]: path to a HATS catalog

  • config["data_set"]["astropy_table"]: path to any file readable by Astropy Table

download_dir
catalog_object_ids
_manifest_lock
_updates_since_save = 0
_save_interval = 1000
_band_failure_stats
_band_failure_lock
_manifest_filter_object_ids = None
_catalog_to_manifest_index_map = None
_manifest_to_catalog_index_map = None
get_objectId(idx)

Get object ID for a given index based on naming strategy.

_setup_naming_strategy()

Set up file naming strategy based on catalog columns.

_initialize_manifest()

Create a new manifest, or load and merge with an existing manifest, with band filtering validation.

The manifest is always an astropy Table with at least the following columns:
  • cutout_shape: np.array of dimensions, e.g. [3,150,150]

  • filename: string containing the basename of the saved file that holds the tensor for the object

  • downloaded_bands: string containing a comma-separated list of the downloaded bands. Order is expected to be consistent between rows.

When this astropy table is loaded into memory, multiple sources are consulted:
  • The manifest on the filesystem, which is the source of truth for which files have been downloaded. If it is not found, it is created.

  • The bands given in the catalog passed in.

_load_existing_manifest()

Load existing manifest file.

_update_manifest_from_catalog(existing_manifest)

Using object_id as a unique key, adds manifest entries to existing_manifest, using self.catalog as the source of any new objects.

self.catalog is not altered by this operation.

Entries in existing_manifest are not altered by this operation. New entries are added to the end of existing_manifest with a state indicating they have not been downloaded.

_build_catalog_to_manifest_index_map()

Build efficient mapping from catalog indices to manifest indices.
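One way to build such a map, assuming object_id is the unique key shared by catalog and manifest (as _update_manifest_from_catalog describes) and that both are available as plain lists of IDs. The helper name is hypothetical; a None entry plays the same role as the None return of _get_manifest_index_for_catalog_index.

```python
from typing import Optional


def build_catalog_to_manifest_index_map(
    catalog_ids: list[str], manifest_ids: list[str]
) -> list[Optional[int]]:
    """Map each catalog row to the manifest row with the same object_id.

    Entries are None when the manifest has no row for that object.
    """
    # A single dict built from the manifest makes each lookup O(1), so the
    # whole map costs O(len(catalog) + len(manifest)) rather than a nested scan.
    position = {object_id: i for i, object_id in enumerate(manifest_ids)}
    return [position.get(object_id) for object_id in catalog_ids]


print(build_catalog_to_manifest_index_map(["c", "a", "z"], ["a", "b", "c"]))
# [2, 0, None]
```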

_add_manifest_columns_to_table(table)

Add cutout_shape, filename, and downloaded_bands columns to manifest.

_longest_object_id_idx()
_get_available_bands_from_manifest(manifest)

Get available bands by finding entries with complete band coverage.

Uses cutout_shape[0] to determine the expected number of bands, then finds entries where downloaded_bands has that many entries (i.e., complete downloads).
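The completeness rule above can be sketched on plain dicts standing in for manifest rows. This assumes downloaded_bands is a comma-separated string, that failed rows have no cutout_shape, and (per the manifest description) that band order is consistent between rows, so the first complete row is representative. available_bands is a hypothetical helper.

```python
def available_bands(manifest: list[dict]) -> list[str]:
    """Return the band list of the first entry with complete band coverage."""
    for entry in manifest:
        shape = entry.get("cutout_shape")
        if shape is None:
            continue  # never downloaded, or all bands failed
        bands = [b for b in (entry.get("downloaded_bands") or "").split(",") if b]
        # Complete coverage: one downloaded band per leading tensor dimension.
        if len(bands) == shape[0]:
            return bands
    return []


manifest = [
    {"cutout_shape": None, "downloaded_bands": ""},
    {"cutout_shape": (3, 150, 150), "downloaded_bands": "g,r"},    # one band failed
    {"cutout_shape": (3, 150, 150), "downloaded_bands": "g,r,i"},  # complete
]
print(available_bands(manifest))  # ['g', 'r', 'i']
```

Comparing against cutout_shape[0] works because a partially failed cutout still has the full band dimension in its tensor, while its downloaded_bands list is shorter.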

_setup_band_filtering(requested_bands, original_band_order)

Setup band filtering to extract only requested bands from cached cutouts.

_get_cutout_path_from_idx(idx)

Generate cutout file path for a given index.

This simply applies a pattern to the filename using the object_id column. No guarantees are made about the file itself.

_get_cutout_path_from_manifest(idx)

Get the cutout path by consulting the manifest

The download thread ensures that the filename is not written to the manifest until all the bands that we intend to download are downloaded.

This function is intended to be a thread-safe way to get valid cutout paths. If the file exists and is believed to be correctly downloaded, it returns the path; otherwise it returns None.

Parameters:

idx (int) – The catalog index of the relevant cutout

Returns:

Path to the cutout, or None.

Return type:

Path or None

_update_manifest_entry(idx, cutout_shape=None, filename='Attempted', downloaded_bands=None)

Thread-safe manifest update with periodic saves.

Parameters:
  • idx – Index in the manifest

  • cutout_shape – Shape tuple of the cutout tensor, or None for failed downloads

  • filename – Basename of the saved file, or “Attempted” only when ALL bands fail

  • downloaded_bands – List of band names successfully downloaded in tensor order
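The locking and periodic-save pattern described above can be sketched with threading.Lock. This is a simplified stand-in, not hyrax's implementation: entries are plain dicts in a list rather than astropy Table rows, and save_fn is a hypothetical hook in place of _save_manifest. The counter mirrors the _updates_since_save and _save_interval attributes listed on the class.

```python
import threading


class ManifestSketch:
    """Thread-safe manifest updates with a save every save_interval updates."""

    def __init__(self, size, save_interval=1000, save_fn=None):
        self.entries = [None] * size  # None = not yet attempted
        self.saves = 0
        self._lock = threading.Lock()
        self._updates_since_save = 0
        self._save_interval = save_interval
        # Hypothetical persistence hook; a real manifest would write a file.
        self._save_fn = save_fn if save_fn is not None else (lambda snapshot: None)

    def update_entry(self, idx, cutout_shape=None, filename="Attempted", downloaded_bands=None):
        with self._lock:
            self.entries[idx] = {
                "cutout_shape": cutout_shape,
                "filename": filename,
                "downloaded_bands": ",".join(downloaded_bands or []),
            }
            self._updates_since_save += 1
            # Saving inside the lock means no update can interleave with a save.
            if self._updates_since_save >= self._save_interval:
                self._save_fn(list(self.entries))
                self.saves += 1
                self._updates_since_save = 0
```

Batching saves this way trades a small window of unsaved progress for far fewer disk writes during long downloads.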

_save_manifest()

Save the manifest.

_sync_manifest_with_filesystem()

Sync the manifest with the files actually downloaded to disk.

This updates the manifest to reflect what is on the filesystem. For existing cutouts, this loads every file using torch.load.

static _request_patch_cached(tract_index, patch_index, butler, skymap_name, bands_tuple)

Cached patch fetching using static method.

Static method means no ‘self’ in cache key, making it truly global. Thread-safe because each call creates its own Butler instance.
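Why a static method matters for functools.lru_cache: on an ordinary method, self becomes part of every cache key, so each dataset instance gets separate entries (and the cache keeps those instances alive). The sketch below, which omits the butler and skymap_name arguments of the real method, shows a static cached function shared by all instances. All arguments must be hashable, which is presumably why bands are passed as bands_tuple rather than a list. PatchFetcherSketch is a hypothetical name, and cache_info() and cache_clear() here are the standard lru_cache hooks.

```python
import functools


class PatchFetcherSketch:
    fetch_count = 0  # number of uncached fetches, for illustration only

    @staticmethod
    @functools.lru_cache(maxsize=32)
    def request_patch_cached(tract_index, patch_index, bands_tuple):
        # Stand-in for a butler request: returns one placeholder per band.
        PatchFetcherSketch.fetch_count += 1
        return tuple(f"{band}:{tract_index}/{patch_index}" for band in bands_tuple)


first = PatchFetcherSketch()
second = PatchFetcherSketch()
first.request_patch_cached(1, 2, ("g", "r"))
second.request_patch_cached(1, 2, ("g", "r"))  # served from the shared cache
print(PatchFetcherSketch.fetch_count)  # 1
print(PatchFetcherSketch.request_patch_cached.cache_info())
```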

_fetch_single_cutout(row, idx=None, manifest_idx=None)

Fetch cutout, using saved cutout if available, with optional band filtering.

_fetch_cutout_with_cache(row)

Generate cutout using cached patch fetching with NaN filling for failed bands.

__len__()

Return length of current catalog, not the full manifest.

_get_manifest_index_for_catalog_index(catalog_idx)

Map catalog index to manifest index. None return indicates no such item in manifest.

get_image(idxs)

Fetch image cutout(s) for given index or indices, using caching and band filtering.

Parameters:

idxs (int or slice or list) – Index or indices to fetch.

Returns:

Single cutout tensor or list of cutout tensors.

Return type:

torch.Tensor or list of torch.Tensor

__getitem__(idxs) dict

Modified to pass index for saving cutouts.

Parameters:

idxs (int or slice or list) – Index or indices to fetch.

Returns:

Dictionary with key ‘data’ containing another dict of default data fields to return. Currently only ‘image’ is supported.

Return type:

dict

download_cutouts(indices=None, sync_filesystem=True, max_workers=None, force_retry=False)

Download cutouts using multiple threads with caching.

Parameters:
  • indices – List of indices to download, or None for all

  • sync_filesystem – Whether to sync manifest with existing files on disk

  • max_workers – Maximum number of worker threads, or None to use default

  • force_retry – Whether to retry previously failed downloads
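A sketch of the resume and retry selection combined with a thread pool. It assumes never-attempted entries are recorded as None and fully failed ones as "Attempted" (the marker _update_manifest_entry uses when all bands fail); fetch_one, indices_to_download and download_all are hypothetical stand-ins, not hyrax functions.

```python
from concurrent.futures import ThreadPoolExecutor

ATTEMPTED = "Attempted"  # manifest marker for a download where all bands failed


def indices_to_download(filenames, indices=None, force_retry=False):
    """Pick indices that still need downloading.

    filenames[i] is None for never-attempted entries, ATTEMPTED for failed
    ones, and a real filename once the cutout is saved.
    """
    candidates = range(len(filenames)) if indices is None else indices
    todo = []
    for i in candidates:
        name = filenames[i]
        # Saved cutouts are skipped, which is what makes the download resumable.
        if name is None or (name == ATTEMPTED and force_retry):
            todo.append(i)
    return todo


def download_all(filenames, fetch_one, indices=None, max_workers=None, force_retry=False):
    """Run fetch_one(i) -> filename-or-None in a thread pool, recording results."""
    todo = indices_to_download(filenames, indices, force_retry)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map yields results in the same order as todo.
        for i, result in zip(todo, pool.map(fetch_one, todo)):
            filenames[i] = result if result is not None else ATTEMPTED
    return todo
```

Passing max_workers=None lets ThreadPoolExecutor choose its own default thread count, which is one way the "None to use default" parameter above could behave.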

_download_single_cutout(catalog_idx, manifest_idx)

Helper method to download a single cutout.

cache_info()

Get cache statistics.

clear_cache()

Clear the LRU cache.

manifest_stats()

Get manifest statistics including downloaded bands information.

band_filtering_info()

Get information about current band filtering configuration.

save_manifest_now()

Force immediate manifest save.

static _determine_numprocs_download()

Determine number of threads for downloading.

reset_failed_downloads()

Reset failed download attempts to allow retry.

download_progress()

Get detailed download progress information.

download_summary()

Get detailed download and band analysis, accounting for band filtering.

class HSCDataset(config: dict, data_location=None)

Bases: hyrax.datasets.fits_image_dataset.FitsImageDataset

Dataset for sets of HSC cutouts created by the hyrax download verb.

__init__()
_called_from_test = False
filters_config
_read_filter_catalog(filter_catalog_path: pathlib.Path | None)
_parse_filter_catalog(table) None

Sets self.files by parsing the catalog.

Subclasses may override this function to control parsing of the table more directly, but the overriding class must create the files dict which has type dict[object_id -> dict[filter -> filename]] with object_id, filter, and filename all strings. In the case of no filter distinction, a single flag value may be used for the filter dict keys in the inner dicts.

Parameters:

table (Table) – The catalog we read in

_set_crop_transform()

Returns the crop transform on the image.

If overridden, the subclass must:

  1. Set self.cutout_shape to a tuple of ints representing the size of the cutouts that will be returned, at some point in the init flow.

  2. Update the crop transform using self.set_crop_transform() from the HyraxImageDataset mixin.

_before_preload()
_scan_file_names(filters: list[str] | None, filter_obj_ids: list[str] | None = None) hyrax.datasets.fits_image_dataset.files_dict

Class initialization helper.

Parameters:
  • filters (list[str], optional) – List of filters that we should look for in the data corpus

  • filter_obj_ids (list[str], optional) – Restrict the file scan to file names with the provided object IDs, skipping other files. When not provided, all file names in the configured data directory that match the pattern from hyrax download are parsed.

Returns:

Nested dictionary where the first level maps object_id -> dict, and the second level maps filter_name -> file name. Corresponds to self.files

Return type:

dict[str,dict[str,str]]

static _determine_numprocs() int
static _fixup_limit(nproc: int, res, est_limit, est_procs) int
_scan_file_dimensions() dim_dict
static _scan_file_dimension(processing_unit: tuple[str, list[str]]) tuple[str, list[tuple[int, int]]]
static _fits_file_dims(filepath) tuple[int, int]
_prune_objects(filters_ref: list[str], cutout_shape: tuple[int, int] | None)

Class initialization helper. Prunes objects from the list of objects.

  1. Removes any objects which do not have all the filters specified in filters_ref

  2. If a cutout_shape was provided in the constructor, prunes files that are too small for the chosen cutout size

This function deletes from self.files and self.dims via _prune_object

Parameters:
  • files (dict[str,dict[str,str]]) – Nested dictionary where the first level maps object_id -> dict, and the second level maps filter_name -> file name. This is created by _scan_files()

  • filters_ref (list[str]) – List of the filter names

  • cutout_shape (tuple[int, int]) – Cutout shape tuple provided from constructor
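The two pruning rules can be sketched on plain dicts. This assumes dims maps each object_id to a list of per-file (dim0, dim1) tuples, and it mirrors the mark-then-delete split of _mark_for_prune and _prune_object by collecting reasons first and deleting afterwards, so the dict is never mutated while being iterated. prune_objects here is a hypothetical standalone function.

```python
def prune_objects(files, dims, filters_ref, cutout_shape=None):
    """Remove objects missing any filter in filters_ref, or with any image
    smaller than cutout_shape. Returns {object_id: reason} for pruned objects."""
    pruned = {}
    for object_id, per_filter in files.items():
        missing = [f for f in filters_ref if f not in per_filter]
        if missing:
            pruned[object_id] = f"missing filters {missing}"
        elif cutout_shape is not None and any(
            d[0] < cutout_shape[0] or d[1] < cutout_shape[1] for d in dims[object_id]
        ):
            pruned[object_id] = f"smaller than cutout_shape {cutout_shape}"
    # Delete only after the scan finishes.
    for object_id in pruned:
        del files[object_id]
        dims.pop(object_id, None)
    return pruned


files = {"a": {"g": "a_g", "r": "a_r"}, "b": {"g": "b_g"}, "c": {"g": "c_g", "r": "c_r"}}
dims = {"a": [(64, 64), (64, 64)], "b": [(64, 64)], "c": [(30, 64), (64, 64)]}
reasons = prune_objects(files, dims, ["g", "r"], (32, 32))
```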

_mark_for_prune(object_id, reason)
_prune_object(object_id, reason: str)
_check_file_dimensions() tuple[int, int]

Class initialization helper. Find the maximal pixel size that all images can support

It is assumed that all the cutouts will be of very similar size; however, HSC’s cutout server does not return exactly the same number of pixels for every query, even when it is given the same angular spread for every cutout.

Machine learning models expect all images to be the same size.

This function warns on significant differences (>2px) on any dimension between the largest and smallest images.

Returns:

The minimum width and height in pixels of the entire dataset. In other words: the maximal image size in pixels that can be generated from ALL cutout images via cropping.

Return type:

tuple(int,int)
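A sketch of the minimum-size computation and the 2 px warning threshold described above, assuming the same dims layout (object_id mapped to a list of per-file (dim0, dim1) tuples). check_file_dimensions here is a hypothetical standalone function, not the hyrax method.

```python
import warnings


def check_file_dimensions(dims: dict[str, list[tuple[int, int]]]) -> tuple[int, int]:
    """Return the largest (dim0, dim1) that every image can be cropped to."""
    all_dims = [d for per_object in dims.values() for d in per_object]
    if not all_dims:
        raise ValueError("no image dimensions to check")
    min0 = min(d[0] for d in all_dims)
    min1 = min(d[1] for d in all_dims)
    max0 = max(d[0] for d in all_dims)
    max1 = max(d[1] for d in all_dims)
    # Small pixel-count jitter from the cutout server is expected; flag anything larger.
    if max0 - min0 > 2 or max1 - min1 > 2:
        warnings.warn(f"Image sizes vary from {(min0, min1)} to {(max0, max1)} pixels")
    return min0, min1
```

The per-dimension minimum is exactly the largest crop that every image can supply, which is why it is used as the common cutout size.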

_rebuild_manifest(config)
__contains__(object_id: str) bool

Enables queries of the form object_id in dataset. Used by testing code.

Parameters:

object_id (str) – The object ID to look up in the dataset

Returns:

True if the given object_id is in the dataset

Return type:

bool

_all_files_full()

Private read-only iterator over all files that enforces a strict total order across objects and filters. Will not work prior to self.files and self.path initialization in __init__.

Yields:

Tuple[object_id, filter, filename, dim] – The members of this tuple are:
  • The object_id as a string

  • The filter name as a string

  • The filename relative to self.path

  • A tuple containing the dimensions of the FITS file in pixels

_object_files(object_id)

Private read-only iterator over all files for a given object. This enforces a strict total order across filters. Will not work prior to self.files and self.path initialization in __init__.

Guaranteed to only return files that have filters in self.filters_ref.

Yields:

Path – The path to the file.

display(index)
class HyraxCifarDataset(config: dict, data_location: pathlib.Path = None)

Bases: hyrax.datasets.dataset_registry.HyraxDataset

Map style CIFAR 10 dataset for Hyrax

This utilizes the CIFAR dataset from torchvision for retrieving the dataset.

__init__()

Overall initialization for all datasets, which saves the config.

Subclasses of HyraxDataset ought to call this at the end of their __init__, like:

from hyrax.datasets import HyraxDataset

class MyDataset(HyraxDataset):
    def __init__(self, config):
        # <your code>
        super().__init__(config)

If per-tensor metadata is available, it is recommended that dataset authors create an astropy Table of that data, in the same order as their data, and pass it as metadata_table as shown below:

from hyrax.datasets import HyraxDataset
from astropy.table import Table

class MyDataset(HyraxDataset):
    def __init__(self, config):
        # <your code>
        # catalog_data is a placeholder for your catalog data, one row per item
        metadata_table = Table(catalog_data)
        super().__init__(config, metadata_table)

Parameters:
  • config (dict, Optional) – The runtime configuration for hyrax

  • metadata_table (Optional[Table], optional) – An Astropy Table containing the metadata columns desired for visualization, with rows in the same order in which your data will be enumerated.

  • object_id_column_name (Optional[str], optional) – The name of the column containing object IDs. If None, uses the default from config or creates one from the ids() method.

data_location = None
training_data
cifar
id_width = 0
get_image(idx)

Get the image at the given index as a NumPy array.

get_label(idx)

Get the label at the given index.

get_object_id(idx)

Get the object ID for the item as a string.

__len__()
class HyraxRandomDataset(config, data_location)

Bases: HyraxRandomDatasetBase, hyrax.datasets.dataset_registry.HyraxDataset, torch.utils.data.Dataset

This dataset is a stand-in for a map-style dataset. It produces random numpy arrays along with sequential numeric IDs and, optionally, labels randomly selected from the provided list of possible labels.

__init__(config, data_location)

Initialize the dataset using the parameters defined in the configuration.

The data_location parameter is included for API consistency with other dataset classes, but is not used by this implementation. All parameters are controlled by the following keys under the ["data_set"]["HyraxRandomDataset"] table in the configuration:

  • size: The number of random data samples to produce.

  • shape: The shape of each random data sample as a tuple (e.g. (3, 29, 29) = 3 layers of 2D data, each layer is 29x29 elements).

  • seed: The random seed to use for reproducibility.

  • provided_labels: A list of possible labels to randomly select from. If this is provided, the dataset will randomly select a label for each data sample.

  • metadata_fields: A list of metadata field names. Used to create a metadata table with columns corresponding to each field name. All data is numeric.

  • number_invalid_values: The number of invalid values to insert into the data.

  • invalid_value_type: The type of invalid value to insert into the data. Valid values are “nan”, “inf”, “-inf”, “none”, or a float value.
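A NumPy sketch of what these keys control, assuming values are uniform in [0, 1) and that labels are drawn independently per sample (neither distribution is stated above). make_random_dataset is a hypothetical helper, not the hyrax class; it omits metadata_fields and the "none" invalid value type, which cannot be stored in a float array.

```python
import numpy as np

INVALID_VALUES = {"nan": np.nan, "inf": np.inf, "-inf": -np.inf}


def make_random_dataset(size, shape, seed=0, provided_labels=None,
                        number_invalid_values=0, invalid_value_type="nan"):
    """Return (data, ids, labels) for `size` random samples of `shape`."""
    rng = np.random.default_rng(seed)  # same seed -> same data, positions and labels
    data = rng.random((size, *shape), dtype=np.float32)
    if number_invalid_values:
        if isinstance(invalid_value_type, str):
            value = INVALID_VALUES[invalid_value_type]  # KeyError for unsupported strings
        else:
            value = float(invalid_value_type)
        flat = data.reshape(-1)  # a view, so writes land in `data`
        positions = rng.choice(flat.size, size=number_invalid_values, replace=False)
        flat[positions] = value
    ids = [str(i) for i in range(size)]  # sequential numeric IDs
    labels = None
    if provided_labels:
        labels = [str(label) for label in rng.choice(provided_labels, size=size)]
    return data, ids, labels


data, ids, labels = make_random_dataset(
    5, (3, 4, 4), seed=1, provided_labels=["cat", "dog"], number_invalid_values=2
)
```

Drawing invalid positions without replacement guarantees that exactly number_invalid_values entries are replaced.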

__getitem__(idx: int) dict

Get a data sample by index.

The returned dictionary will contain the following keys:

  • index: The index of the data sample.

  • object_id: The ID of the data sample.

  • image: The data sample as a numpy array.

  • label: The label of the data sample (if provided).

Parameters:

idx (int) – The index of the data sample to retrieve.

Returns:

A dictionary containing the data sample and its metadata.

Return type:

dict

__len__()

Get the total number of samples in this dataset. This should return the same value as the size parameter in the configuration.

class HyraxRandomDatasetBase(config, data_location)

This is the base class for the random datasets provided by Hyrax.

Warning

Direct use of HyraxRandomDatasetBase is not advised. When working with Hyrax, prefer to use HyraxRandomDataset.

__init__(config, data_location)

Initialize the dataset using the parameters defined in the configuration.

The data_location parameter is included for API consistency with other dataset classes, but is not used by this implementation. All parameters are controlled by the following keys under the ["data_set"]["HyraxRandomDataset"] table in the configuration:

  • size: The number of random data samples to produce.

  • shape: The shape of each random data sample as a tuple (e.g. (3, 29, 29) = 3 layers of 2D data, each layer is 29x29 elements).

  • seed: The random seed to use for reproducibility.

  • provided_labels: A list of possible labels to randomly select from. If this is provided, the dataset will randomly select a label for each data sample.

  • metadata_fields: A list of metadata field names. Used to create a metadata table with columns corresponding to each field name. All data is numeric.

  • number_invalid_values: The number of invalid values to insert into the data.

  • invalid_value_type: The type of invalid value to insert into the data. Valid values are “nan”, “inf”, “-inf”, “none”, or a float value.
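The number_invalid_values and invalid_value_type options can be pictured as overwriting randomly chosen elements of the data with a sentinel. This sketch is an assumption about the behavior, not Hyrax's code: it treats number_invalid_values as a total count across the whole array, the helper name insert_invalid_values is hypothetical, and it leaves out "none" because a floating-point NumPy array cannot hold Python None.

```python
import numpy as np

# Hypothetical mapping from the documented string options to float sentinels.
# "none" is omitted here: a floating-point array cannot store Python None.
_SENTINELS = {"nan": np.nan, "inf": np.inf, "-inf": -np.inf}


def insert_invalid_values(data, number_invalid_values, invalid_value_type, seed=0):
    """Return a copy of data with number_invalid_values elements replaced."""
    if isinstance(invalid_value_type, str):
        value = _SENTINELS[invalid_value_type]  # KeyError for unsupported strings
    else:
        value = float(invalid_value_type)
    out = np.array(data, dtype=np.float64, copy=True)
    rng = np.random.default_rng(seed)
    # Distinct flat positions, so exactly number_invalid_values elements change
    positions = rng.choice(out.size, size=number_invalid_values, replace=False)
    out.flat[positions] = value
    return out


clean = np.zeros((5, 3, 4, 4))
dirty = insert_invalid_values(clean, number_invalid_values=7, invalid_value_type="nan")
print(int(np.isnan(dirty).sum()))  # 7
```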

data: numpy.ndarray

The random data samples produced by the dataset.

id_list: list

A list of sequential numeric IDs for each data sample.

provided_labels: list

A list of labels randomly selected from the provided list of possible labels.

data_location
get_image(idx: int) numpy.ndarray

Get the image at the given index as a NumPy array.

get_label(idx: int) str

Get the label at the given index.

get_object_id(idx: int) str

Get the object ID of the item at the given index.

class InferenceDataset(config, results_dir: pathlib.Path | str | None = None, verb: str | None = None)

Bases: hyrax.datasets.dataset_registry.HyraxDataset, torch.utils.data.Dataset

This is a dataset class for situations where we wish to treat the output of inference as a dataset, e.g. when performing UMAP or visualization operations.

Initialize an InferenceDataset object.

As a user of this code, you should almost never need to create this class directly. Instances of this class are returned by the umap and infer verbs; prefer those over creating your own.

If you do end up creating your own instance, you will need a Hyrax config and to know where the results you are interested in are stored.

Parameters:
  • config (dict) – The hyrax config dictionary

  • results_dir (Optional[Union[Path, str]], optional) –

    The results subdirectory of the inference or umap results you want to access, by default None. If no results subdirectory is provided, the constructor will try the following in order:

    1. Use the directory specified in config['results']['inference_dir'] if set and the directory exists

    2. Look in the results configured in config['general']['results_dir'] (./results/ by default), then use the most recent results directory corresponding to the verb specified.

  • verb (Optional[str], optional) – The name of the verb that generated the results, only important when the most recent results are being fetched. If no verb is provided, “infer” will be assumed.

Raises:

RuntimeError – When the provided results directory is corrupt, or cannot be found.
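The lookup order described above can be sketched with pathlib. This is a hypothetical helper, not Hyrax's implementation: the name resolve_results_dir is invented, and it assumes that result subdirectories contain the verb name in their directory name and that "most recent" means latest modification time. Hyrax's actual naming and ordering rules may differ.

```python
import os
import tempfile
from pathlib import Path


def resolve_results_dir(config, results_dir=None, verb=None):
    """Hypothetical sketch of the lookup order documented for InferenceDataset."""
    if results_dir is not None:
        # An explicit directory wins; this sketch treats it as a plain path
        path = Path(results_dir)
        if not path.is_dir():
            raise RuntimeError(f"Results directory {path} cannot be found")
        return path

    # 1. config['results']['inference_dir'], if it is set and the directory exists
    inference_dir = config.get("results", {}).get("inference_dir")
    if inference_dir and Path(inference_dir).is_dir():
        return Path(inference_dir)

    # 2. The most recent matching subdirectory of config['general']['results_dir']
    verb = verb or "infer"
    root = Path(config.get("general", {}).get("results_dir") or "./results/")
    candidates = [p for p in root.glob(f"*{verb}*") if p.is_dir()] if root.is_dir() else []
    if not candidates:
        raise RuntimeError(f"No '{verb}' results found in {root}")
    # Assumption: "most recent" means the latest modification time
    return max(candidates, key=lambda p: p.stat().st_mtime)


with tempfile.TemporaryDirectory() as tmp:
    older = Path(tmp, "run-infer-1")
    newer = Path(tmp, "run-infer-2")
    older.mkdir()
    newer.mkdir()
    os.utime(older, (1_000, 1_000))
    os.utime(newer, (2_000, 2_000))
    chosen = resolve_results_dir({"general": {"results_dir": tmp}})
    print(chosen.name)  # run-infer-2
```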

results_dir
batch_index
length
cached_batch_num: int | None = None
shape_element
_original_dataset_config
original_dataset
_shape()

The shape of the dataset (discovered from files).

Returns:

Tuple with the shape of an individual element of the dataset

Return type:

Tuple

get_object_id(idx) str

Returns the ID at a particular index.

IDs are provided by the primary dataset’s primary ID column.

ids() list[str]

Returns the IDs of the dataset.

IDs flow from the primary dataset and the primary ID column.

For an InferenceDataset instance, self.ids() is canonically the same as [self.get_object_id(i) for i in range(len(self))].

_ids() collections.abc.Generator[str]

IDs of this dataset, returned as a generator of strings.

These are the IDs of the dataset originally used to generate this dataset.

Returns:

Generator that yields the string ids of this dataset

Return type:

Generator[str]

Yields:

Generator[str] – Yields the string ids of this dataset

__getitem__(idx: int | numpy.ndarray)

Implements the [] operator

Parameters:

idx (Union[int, np.ndarray]) – Either an index or a numpy array of indexes. These are NOT the ID values of the dataset, but rather a zero-based index starting at the beginning of the inference dataset.

Returns:

Either the tensor corresponding to a single result, or a tensor with a multiplicity of results if multiple indexes were passed.

Return type:

torch.Tensor
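Accepting either a single index or an array of indexes is a common NumPy pattern. The sketch below shows one way to implement it; the class ToyResults and its in-memory storage are hypothetical, and it returns NumPy arrays rather than torch tensors to keep the example dependency-light.

```python
import numpy as np


class ToyResults:
    """Hypothetical container illustrating int-or-array indexing."""

    def __init__(self, results):
        # results: array of shape (N, *element_shape)
        self._results = np.asarray(results)

    def __len__(self):
        return len(self._results)

    def __getitem__(self, idx):
        if isinstance(idx, (int, np.integer)):
            # A single zero-based position returns one element of shape element_shape
            return self._results[idx]
        # An array of positions returns a stacked array, in the order requested
        return self._results[np.asarray(idx, dtype=np.intp)]


res = ToyResults(np.arange(12, dtype=np.float32).reshape(6, 2))
print(res[1])  # [2. 3.]
print(res[np.array([4, 0])].shape)  # (2, 2)
```

Positions are zero-based offsets into the results, not object IDs, matching the parameter description above.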

__len__() int

Returns the length of the dataset.

Returns:

Length of the dataset.

Return type:

int

property original_config: dict

Get the original configuration for the dataset used to generate this inference dataset.

Since this sort of dataset is definitionally an intermediate product, this returns the runtime config used to construct that dataset rather than this one.

Returns:

Configuration that can be used to create the original dataset that was used as input for whatever inference process created this dataset.

Return type:

dict

metadata_fields() list[str]

Get the metadata fields associated with the original dataset used to generate this one.

Returns:

List of valid field names for metadata queries

Return type:

list[str]

metadata(idxs: numpy.typing.ArrayLike, fields: list[str]) numpy.typing.ArrayLike

Get metadata associated with the data in the InferenceDataset. This metadata comes from the original dataset, but is indexed according to the InferenceDataset.

Parameters:
  • idxs (npt.ArrayLike) – Indexes in the InferenceDataset for which metadata is desired

  • fields (list[str]) – Metadata fields requested

Returns:

An array where the rows correspond to the passed list of indexes and the columns correspond to the fields passed. Order is preserved: metadata[i] corresponds to idxs[i].

Return type:

npt.ArrayLike
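One way to picture how metadata from the original dataset is re-indexed for an InferenceDataset is to route each inference index through its object ID. This is a hypothetical sketch with a NumPy structured array standing in for the original metadata table; the names original_ids, inference_ids, and metadata_for are illustrative, and Hyrax may map indexes differently.

```python
import numpy as np

# Original dataset: IDs and a metadata table in the original order
original_ids = np.array(["a", "b", "c", "d"])
original_meta = np.array(
    [(10.0, 0.1), (20.0, 0.2), (30.0, 0.3), (40.0, 0.4)],
    dtype=[("ra", "f8"), ("redshift", "f8")],
)

# Inference results may be stored in a different order (e.g. by batch)
inference_ids = np.array(["c", "a", "d", "b"])


def metadata_for(idxs, fields):
    """Return metadata rows for inference indexes, preserving the order of idxs."""
    position_of = {oid: i for i, oid in enumerate(original_ids)}
    original_positions = [position_of[inference_ids[i]] for i in idxs]
    return original_meta[original_positions][fields]


rows = metadata_for([0, 3], ["redshift"])
print(rows["redshift"])  # [0.3 0.2]
```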

_load_from_batch_file(batch_num: int, ids=Union[int, np.ndarray]) numpy.ndarray

Returns an array of tensors for the given IDs within the given batch number.

class ResultDataset(config: dict, data_location: pathlib.Path | str)

Bases: hyrax.datasets.dataset_registry.HyraxDataset

Reader for Lance-based inference results.

Provides HyraxQL-compatible getters for results stored in Lance format.

Initialize the dataset.

Parameters:
  • config (dict) – Hyrax configuration dictionary

  • data_location (Union[Path, str]) – Path to results directory containing lance_db/

data_location
lance_dir
db
table
lance_dataset
tensor_shape
tensor_dtype
__len__() int

Return the number of records in the dataset.

__getitem__(idx: int | numpy.ndarray)

Get data by index.

Parameters:

idx (Union[int, np.ndarray]) – Single index or array of indices

Returns:

Data tensor(s)

Return type:

np.ndarray

Raises:

IndexError – If index is out of range

__get_all__()

Get all data tensors in the dataset.

This is a specialized method that is meant for internal use (e.g. visualize_v2). It retrieves all tensors efficiently by assuming column names and accessing the array buffer directly, without creating Python objects for each row.

Returns:

All data tensors

Return type:

np.ndarray

get_data(idx: int)

Get data tensor at index (HyraxQL getter).

Parameters:

idx (int) – Index of the data item

Returns:

Data tensor

Return type:

np.ndarray

get_object_id(idx: int) str

Get object ID at index (HyraxQL getter).

Parameters:

idx (int) – Index of the data item

Returns:

Object ID

Return type:

str

ids() list[str]

Generate all object IDs.

Returns:

Object IDs in order

Return type:

list[str]

class ResultDatasetWriter(result_dir: str | pathlib.Path)

Writer for Lance-based inference results.

Writes inference results incrementally to Lance format using table.add() for each batch, avoiding memory accumulation.

Initialize the writer.

Parameters:

result_dir (Union[str, Path]) – Directory where Lance database will be created

result_dir
lance_dir
db = None
table = None
schema = None
tensor_dtype = None
tensor_shape = None
batch_count = 0
write_batch(object_ids: numpy.ndarray, data: list[numpy.ndarray])

Write a batch of results incrementally.

Parameters:
  • object_ids (np.ndarray) – Array of object IDs (will be converted to strings)

  • data (list[np.ndarray]) – List of numpy arrays (tensors) to write

commit()

Finalize the write by optimizing the table.

_create_schema(sample_tensor: numpy.ndarray)

Create PyArrow schema with tensor metadata.

Parameters:

sample_tensor (np.ndarray) – Sample tensor to determine dtype and shape
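Storing fixed-shape tensors in a columnar table commonly means flattening each tensor and recording its dtype and shape once, which is what the tensor_dtype and tensor_shape attributes suggest. The sketch below shows that round trip in plain NumPy; it illustrates the general technique under that assumption, is not the Lance schema Hyrax actually builds, and its helper names are hypothetical.

```python
import numpy as np


def describe_tensor(sample_tensor):
    """Capture the metadata needed to rebuild tensors from flat rows."""
    return {"dtype": str(sample_tensor.dtype), "shape": tuple(sample_tensor.shape)}


def flatten_batch(tensors, meta):
    """Flatten every tensor in a batch into a 1-D row, checking it matches meta."""
    rows = []
    for t in tensors:
        if tuple(t.shape) != meta["shape"] or str(t.dtype) != meta["dtype"]:
            raise ValueError("all tensors must share the first tensor's shape and dtype")
        rows.append(t.ravel())
    return rows


def rebuild_all(rows, meta):
    """Concatenate flat rows into one buffer, then reshape to (N, *shape)."""
    buffer = np.concatenate(rows).astype(meta["dtype"], copy=False)
    return buffer.reshape((len(rows), *meta["shape"]))


batch = [np.full((2, 3), i, dtype=np.float32) for i in range(4)]
meta = describe_tensor(batch[0])
restored = rebuild_all(flatten_batch(batch, meta), meta)
print(restored.shape, restored.dtype)  # (4, 2, 3) float32
```

Reading everything back as one contiguous buffer and reshaping it avoids building a Python object per row, loosely in the spirit of the __get_all__ description.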

create_results_writer(result_dir: str | pathlib.Path)

Create a writer for results (Lance format).

This factory creates a ResultDatasetWriter for writing inference results to Lance format. All new results are written in Lance format.

Parameters:

result_dir (Union[str, Path]) – Directory where results should be saved

Returns:

Writer instance for Lance storage

Return type:

ResultDatasetWriter

load_results_dataset(config: dict, results_dir: pathlib.Path | str | None = None, verb: str | None = None)

Load a results dataset, auto-detecting format.

This factory auto-detects whether the results are in Lance or .npy format and returns the appropriate dataset class.

Parameters:
  • config (dict) – The hyrax config dictionary

  • results_dir (Union[Path, str, None], optional) – The results subdirectory to load from

  • verb (Union[str, None], optional) – The name of the verb that generated the results (for auto-discovery)

Returns:

The appropriate dataset instance based on detected format

Return type:

Union[ResultDataset, InferenceDataset]
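Format auto-detection can be sketched as a check on the directory contents. The sketch assumes, based on the ResultDataset parameter description, that Lance results live in a lance_db/ subdirectory and that older results are .npy files directly in the results directory; the function name detect_results_format and the exact checks are hypothetical.

```python
import tempfile
from pathlib import Path


def detect_results_format(results_dir):
    """Return "lance" or "npy", or raise if neither layout is found (hypothetical)."""
    path = Path(results_dir)
    if (path / "lance_db").is_dir():
        return "lance"
    if any(path.glob("*.npy")):
        return "npy"
    raise RuntimeError(f"Could not detect a results format in {path}")


with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "lance_db").mkdir()
    fmt = detect_results_format(tmp)
    print(fmt)  # lance
```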

class HyraxDataset(config: dict, metadata_table=None, object_id_column_name=None)

How to make a Hyrax dataset:

from hyrax.datasets import HyraxDataset

class MyDataset(HyraxDataset):
    def __init__(self, config: dict):
        super().__init__(config)

    def __len__(self) -> int:
        # Return the number of items in your dataset
        raise NotImplementedError

Optional interfaces:

metadata -> Subclasses may pass an astropy Table of metadata to __init__ in the superclass. This table of metadata will be available through the metadata_fields and metadata functions. If desired, a subclass may override these functions directly rather than using the astropy Table interface.

Further documentation is in the "Build a dataset class in a notebook" example notebook.

__init__()

Overall initialization for all datasets, which saves the config.

Subclasses of HyraxDataset should call this at the end of their __init__, like:

from hyrax.datasets import HyraxDataset

class MyDataset(HyraxDataset):
    def __init__(self, config: dict):
        # <your code>
        super().__init__(config)

If per-tensor metadata is available, it is recommended that dataset authors create an astropy Table of that data, in the same order as their data, and pass that metadata_table as shown below:

from astropy.table import Table
from hyrax.datasets import HyraxDataset

class MyDataset(HyraxDataset):
    def __init__(self, config: dict):
        # <your code>, including building catalog_data with one row per data item
        metadata_table = Table(catalog_data)
        super().__init__(config, metadata_table)

Parameters:
  • config (dict, Optional) – The runtime configuration for hyrax

  • metadata_table (Optional[Table], optional) – An Astropy Table with 1. the metadata columns desired for visualization AND 2. in the order your data will be enumerated.

  • object_id_column_name (Optional[str], optional) – The name of the column containing object IDs. If None, uses the default from config or creates one from the ids() method.

_config
_metadata_table = None
property config
classmethod __init_subclass__()
metadata_fields() list[str]

Returns a list of metadata fields supported by this object

Returns:

The column names of the metadata table passed. An empty list if no metadata was provided during construction of the HyraxDataset (or derived class).

Return type:

list[str]

metadata(idxs: numpy.typing.ArrayLike, fields: list[str]) numpy.typing.ArrayLike

Returns a table representing the metadata given an array of indexes and a list of fields.

Parameters:
  • idxs (npt.ArrayLike) – The indexes of the relevant tensor objects

  • fields (list[str]) – The names of the fields you would like returned. All values must be among those returned by metadata_fields()

Returns:

A numpy record array of your metadata, with only the columns specified. Roughly equivalent to: metadata_table[idxs][fields].as_array() where metadata_table is the astropy table that the HyraxDataset (or derived class) was constructed with.

Return type:

npt.ArrayLike

Raises:

RuntimeError – When none of the provided fields are among those returned by metadata_fields().
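The documented equivalence, metadata_table[idxs][fields].as_array(), selects rows first, then columns, and yields a record array. A NumPy structured array behaves the same way for this purpose, so the sketch below uses one in place of an astropy Table to stay dependency-light; the table contents and the metadata_like helper are made up for illustration and do not model the RuntimeError case.

```python
import numpy as np

# Stand-in for the metadata table passed to HyraxDataset.__init__:
# one row per data item, in the order the data is enumerated.
metadata_table = np.array(
    [(101, 1.5, "g"), (102, 2.5, "r"), (103, 3.5, "i")],
    dtype=[("object_id", "i8"), ("mag", "f8"), ("band", "U1")],
)


def metadata_fields():
    # Column names of the metadata table
    return list(metadata_table.dtype.names)


def metadata_like(idxs, fields):
    """Select rows by index (order preserved), then keep only the requested columns."""
    return metadata_table[np.asarray(idxs)][fields]


print(metadata_fields())  # ['object_id', 'mag', 'band']
print(metadata_like([2, 0], ["band"])["band"])  # ['i' 'g']
```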

class HyraxCSVDataset(config: dict, data_location: pathlib.Path = None)

Bases: hyrax.datasets.dataset_registry.HyraxDataset

A Hyrax Dataset for CSV files.

This class reads a CSV file using pandas with memory mapping enabled. It dynamically creates getter methods for each column in the CSV file, allowing users to request data from specific columns.

Note

Column names found in the CSV file are used to create the getter methods. If a column name contains characters that are invalid for method names, those characters are replaced with underscores.

Examples

Example data_request configuration:

{
    "train": {
        "data": {
            "dataset_class": "HyraxCSVDataset",
            "data_location": "</path/to/data.csv>",
            "fields": ["<column1>", "<column2>", ...],
            "primary_id_field": "<column name that contains a unique ID>",
        },
    },
    "validate": { "<similar to above>" },
    "infer": { "<similar to above>" },
}
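Creating one get_<column> method per column can be done with setattr and a closure. This is a sketch of the general technique, not HyraxCSVDataset's code: it uses the csv module on an in-memory string instead of pandas with memory mapping, values stay as strings, and the class name ToyCSVDataset is hypothetical. Characters that are invalid in a method name are replaced with underscores, as the note above describes.

```python
import csv
import io
import re


class ToyCSVDataset:
    """Hypothetical sketch: one getter per CSV column, named get_<column>."""

    def __init__(self, text):
        reader = csv.DictReader(io.StringIO(text))
        self.rows = list(reader)
        self.column_names = list(reader.fieldnames or [])
        for column in self.column_names:
            # Replace characters that are not valid in a Python identifier
            safe = re.sub(r"\W", "_", column)
            setattr(self, f"get_{safe}", self._make_getter(column))

    def _make_getter(self, column):
        # Bind column now; a bare lambda in the loop would capture only the last column
        def getter(idx):
            return self.rows[idx][column]

        return getter

    def __len__(self):
        return len(self.rows)


ds = ToyCSVDataset("object_id,flux-g,ra\n7,1.5,10.0\n8,2.5,11.0\n")
print(ds.get_flux_g(1))  # 2.5
```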
data_location = None
column_names
mem_mapped_csv = None
__getitem__(idx)

Currently required by Hyrax machinery, but likely to be phased out.

__len__() int

Return the number of records in the CSV.

sample_data()

Return the first record, in dictionary form, as the sample.

class HyraxHATSDataset(config: dict, data_location: pathlib.Path = None)

Bases: hyrax.datasets.dataset_registry.HyraxDataset

Generic Hyrax dataset for HATS catalogs loaded through LSDB.

Notes

This phase-1 implementation materializes the LSDB catalog to a pandas DataFrame and dynamically creates get_<column> methods for requested columns.

data_location = None
dataframe
column_names
_requested_columns_from_config(config: dict) list[str]
_open_catalog_kwargs_from_config(config: dict) dict
__len__() int
class MultimodalUniverseDataset(config: dict, data_location: pathlib.Path | str | None = None)

Bases: hyrax.datasets.dataset_registry.HyraxDataset

Load a MultimodalUniverse dataset through Hugging Face datasets.

This dataset class is intentionally generic so one configuration pattern can be used for image, spectra, and time-series MMU datasets.

Examples

Example data_request configuration:

{
    "infer": {
        "mmu": {
            "dataset_class": "MultimodalUniverseDataset",
            "data_location": "hf://MultimodalUniverse/plasticc",
            "primary_id_field": "object_id",
            "dataset_config": {
                "MultimodalUniverseDataset": {
                    "split": "train",
                    "max_samples": 32,
                }
            },
        }
    }
}
data_location = ''
split
max_samples
streaming
dataset
_column_name_map
_normalize_data_location(data_location: str) str
_load_dataset(dataset_source: str)
_limit_non_streaming_dataset(dataset: Any, max_samples: int)
_build_column_name_map() dict[str, str]

Returns a map from sanitized column names to the original column names.

It’s possible for a column name to contain punctuation or start with a number. In these cases we also allow column access via a sanitized name where all punctuation is replaced with the underscore character, and any field starting with a number is replaced by field_.

Every field is entered in the dictionary regardless of whether it needed sanitization. For fields that need no sanitization, the sanitized name is exactly the field name.

_sanitize_name(column_name: str) str

Take a column name that may contain punctuation and return a version with underscores replacing the punctuation.
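The column-name map can be sketched as follows. The regular expression and the handling of leading digits are assumptions (this sketch prefixes such names with field_); only the overall behavior, punctuation replaced by underscores and an entry for every field, comes from the descriptions above. The function names are hypothetical.

```python
import re


def sanitize_name(column_name):
    """Replace any character that is not a letter, digit, or underscore with '_'."""
    return re.sub(r"[^0-9A-Za-z_]", "_", column_name)


def build_column_name_map(column_names):
    """Map sanitized names back to original names, including unchanged ones."""
    name_map = {}
    for name in column_names:
        sanitized = sanitize_name(name)
        if sanitized[:1].isdigit():
            # Identifiers cannot start with a digit (assumed prefix: field_)
            sanitized = f"field_{sanitized}"
        name_map[sanitized] = name
    return name_map


print(build_column_name_map(["object_id", "flux.g", "2mass_j"]))
# {'object_id': 'object_id', 'flux_g': 'flux.g', 'field_2mass_j': '2mass_j'}
```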

_register_getters() None
__len__() int
class NestedPandasDataset(config: dict, data_location: pathlib.Path | str | None = None)

Bases: hyrax.datasets.dataset_registry.HyraxDataset

A minimal Hyrax wrapper around nested_pandas.read_parquet.

data_location = ''
read_kwargs
nested_frame
_load_nested_frame(read_kwargs: dict)
_all_available_fields() list[str]
_register_getters() None
__len__() int
class LanceDBDataset(config: dict, data_location: pathlib.Path | str | None = None)

Bases: hyrax.datasets.dataset_registry.HyraxDataset

A minimal Hyrax wrapper around a LanceDB table.

data_location = ''
table_name
connect_kwargs
open_table_kwargs
db
table
lance_dataset
_row_cache: collections.OrderedDict
_all_available_fields() list[str]
_get_row(idx: int)

Return the PyArrow record-batch for idx, using a small FIFO row cache.

Caching avoids redundant lance_dataset.take calls when multiple get_<field> accessors are invoked for the same sample index, which is the common pattern when DataProvider resolves all fields for a single item. The cache holds at most _ROW_CACHE_SIZE rows; the oldest entry is evicted once that limit is reached.
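A FIFO row cache of this kind fits naturally on collections.OrderedDict. The sketch below is a hypothetical stand-alone version: fetch_row stands in for lance_dataset.take, the class name RowCache is invented, and the cache size of 4 is arbitrary rather than Hyrax's _ROW_CACHE_SIZE.

```python
from collections import OrderedDict

ROW_CACHE_SIZE = 4  # arbitrary for this sketch


class RowCache:
    """First-in, first-out cache: hits do not refresh an entry's position."""

    def __init__(self, fetch_row, max_size=ROW_CACHE_SIZE):
        self._fetch_row = fetch_row
        self._max_size = max_size
        self._rows = OrderedDict()
        self.fetches = 0

    def get_row(self, idx):
        if idx in self._rows:
            return self._rows[idx]  # cache hit, no reordering (FIFO, not LRU)
        row = self._fetch_row(idx)
        self.fetches += 1
        self._rows[idx] = row
        if len(self._rows) > self._max_size:
            self._rows.popitem(last=False)  # evict the oldest insertion
        return row


cache = RowCache(fetch_row=lambda i: {"id": i, "flux": i * 0.5})
# Several get_<field> accessors for the same sample share a single fetch
values = [cache.get_row(0)[field] for field in ("id", "flux")]
print(cache.fetches)  # 1
```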

_resolve_table_name(configured_table_name) str
_register_getters() None
__len__() int
class DataCache(config, data_provider: hyrax.datasets.data_provider.DataProvider)

DataCache tracks and manages a caching layer, which is most effective when an entire training (or inference) epoch fits in system RAM.

Two configs control this functionality:

h.config["data_set"]["use_cache"] determines whether data dictionaries are served out of a cache. When set, the first epoch of training fills the cache with tensors, and subsequent epochs are served out of the cache.

h.config["data_set"]["preload_cache"] starts a thread that iterates over the dataset/dataloader class to completion. The thread pre-loads the cache with tensors independently of the training process. The intent is that this thread proceeds faster than the first epoch of training and speeds up the first epoch as well.

This class caches the output of DataProvider before it is batched. Users can control the size of the cached data by selecting only particular fields in their data_request specification.

The class logs to the tensorboard logger in the DataProvider (when configured).

Initialize the DataCache with a Hyrax config.

Parameters:
  • config (dict) – The Hyrax configuration that defines the data_request.

  • data_provider (DataProvider) – The DataProvider object which we are caching for.
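The use_cache and preload_cache behavior can be pictured with a dictionary keyed by index plus an optional background thread. This is a simplified, hypothetical sketch: ToyDataCache and load_item are invented names, load_item stands in for the DataProvider's data resolution, and the real DataCache additionally tracks cache size in bytes, logs to TensorBoard, and preloads with multiple threads.

```python
import threading


class ToyDataCache:
    """Hypothetical sketch: serve items from memory after their first load."""

    def __init__(self, load_item, length, use_cache=True, preload_cache=False):
        self._load_item = load_item
        self._length = length
        self._use_cache = use_cache
        self._cache = {}
        self._lock = threading.Lock()
        # The thread is created here but started separately (see start_preload_thread)
        self._thread = None
        if use_cache and preload_cache:
            self._thread = threading.Thread(target=self._preload, daemon=True)

    def start_preload_thread(self):
        # Starting outside __init__ means the thread always sees a fully built object
        if self._thread is not None:
            self._thread.start()

    def _preload(self):
        # Walk the whole dataset once, filling the cache independently of training
        for idx in range(self._length):
            self.get(idx)

    def try_fetch(self, idx):
        with self._lock:
            return self._cache.get(idx)  # None on a cache miss

    def get(self, idx):
        item = self.try_fetch(idx) if self._use_cache else None
        if item is None:
            item = self._load_item(idx)
            if self._use_cache:
                with self._lock:
                    self._cache[idx] = item
        return item


loads = []


def load_item(idx):
    loads.append(idx)  # records every real (uncached) load
    return {"x": idx}


cache = ToyDataCache(load_item, length=3)
for epoch in range(2):
    for i in range(3):
        cache.get(i)
print(loads)  # [0, 1, 2]
```

Only the first epoch triggers real loads; the second is served entirely from the cache.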

_max_length
_resolve_data_func
_data_provider
_use_cache
_preload_cache
_data_size_bytes = 0
_insert_count = 0
logging_interval = 1000
_cache_map
_preload_thread = None
start_preload_thread()

Start the cache preload thread if configured

This exists to separate initialization from thread start in DataProvider’s constructor, so the thread started can always count on a fully initialized DataProvider.

_idx_check(idx)
try_fetch(idx: int) dict | None

Try to fetch a data_dict from the cache.

Parameters:

idx (int) – The DataProvider index of the data dict

Returns:

The data dict from the cache, None on a cache miss.

Return type:

Optional[dict]

insert_into_cache(idx: int, data: dict[str, dict[str, Any]])

Insert a data dict into the cache

Parameters:
  • idx (int) – Index of the data dict

  • data (dict[str, dict[str, Any]]) – The data dict

static _data_size(data, seen: set[int] | None = None) int
_preload_tensor_cache()

Preload all tensors in the dataset using multiple threads.

_lazy_map_executor(executor: concurrent.futures.Executor, idxs: collections.abc.Iterable[int])

Lazy evaluation version of concurrent.futures.Executor.map().

This limits memory usage during preloading by keeping only a small number of data dictionaries in memory at once.

Parameters:
  • executor (concurrent.futures.Executor) – An executor for running futures

  • idxs (Iterable[int]) – An iterable list of DataProvider indexes

Yields:

Iterator[torch.Tensor] – An iterator over torch tensors, lazily loaded
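A lazy alternative to Executor.map can keep a bounded number of futures in flight, submitting the next item only as earlier results are consumed. The sketch below shows that pattern with only the standard library. Unlike the documented method, which takes just an executor and indexes, this hypothetical lazy_map takes the mapped function explicitly, and its max_in_flight parameter is an assumption.

```python
from collections import deque
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def lazy_map(
    executor: Executor,
    fn: Callable[[T], R],
    items: Iterable[T],
    max_in_flight: int = 4,
) -> Iterator[R]:
    """Like executor.map, but at most max_in_flight results are pending at a time."""
    pending = deque()
    for item in items:
        pending.append(executor.submit(fn, item))
        if len(pending) >= max_in_flight:
            # Block on the oldest future so results come back in input order
            yield pending.popleft().result()
    # Drain whatever is still in flight once the input is exhausted
    while pending:
        yield pending.popleft().result()


with ThreadPoolExecutor(max_workers=2) as pool:
    squares = list(lazy_map(pool, lambda x: x * x, range(10), max_in_flight=3))
print(squares)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

Because the function is a generator, nothing is submitted until the consumer starts iterating, which is what keeps memory bounded during preloading.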