hyrax.data_sets

Hyrax has several built-in datasets that you can use for astronomical data. For many uses, these datasets work out of the box once they are configured for a given project.

FitsImageDataSet is a generic container for FITS image cutout data indexed by a user-provided catalog file. It attempts to cover common usage paradigms, such as multiple images of the same object differentiated by telescope filter; however, extending the class as a custom dataset may be better suited to advanced usage.

LSSTDataset is an alpha-quality container for LSST cutout images. It is currently limited to deep_coadd type images and restricted to running only in a Rubin Observatory RSP environment, where the LSST Pipeline tools and a data butler with the appropriate images are available.

DownloadedLSSTDataset is a subclass of LSSTDataset that generates cutouts from the butler and saves them as .pt files on first access. On subsequent access, it loads the cutouts directly from these files, which can significantly speed up data loading times. It inherits from LSSTDataset to access the data butler and catalog functionality.

HSCDataSet works similarly to FitsImageDataSet, but is specialized to Hyper Suprime-Cam (HSC) cutout images downloaded with the hyrax download verb. It contains additional integrity checks and is tightly integrated with the download and rebuild_manifest verbs. In the future, this class and the downloader may become a separate package.

HyraxCifarDataSet and HyraxCifarIterableDataSet give access to the standard CIFAR10 labeled image dataset, automatically downloading the dataset if it is not present. These datasets are useful for testing hyrax and occasionally individual models, but they are not astronomical datasets.

HyraxRandomDataset and HyraxRandomIterableDataset are utility datasets that generate random data with a specific shape. These datasets make it easy to test new models with simple random data. They are highly configurable such that it’s possible to simulate input data for models that are under development.

Each of these datasets can be used as a starting point for a custom dataset by inheriting from it (e.g. from FitsImageDataSet), or you can write an entirely custom dataset following the custom dataset instructions and/or the custom dataset example notebook.

The remaining classes in this module exist primarily for Hyrax interface purposes:

InferenceDataSet is a dataset class that represents an infer or umap result, and may be returned from those verbs to provide data access.

HyraxDataset is the base class for all datasets in Hyrax and must be within the inheritance hierarchy of every custom dataset. It is not usable on its own, but provides various fall-back functionality to make custom datasets easier to write. See the custom dataset instructions and example notebook for more information.


Classes

FitsImageDataSet

Dataset for Fits Images, typically cutouts.

LSSTDataset

LSSTDataset: A dataset to access deep_coadd images from lsst pipelines

DownloadedLSSTDataset

DownloadedLSSTDataset: A dataset that inherits from LSSTDataset and downloads cutouts from the LSST butler, caching them as .pt files.

HSCDataSet

Dataset for sets of HSC cutouts created by the hyrax download verb.

HyraxCifarDataSet

Map-style CIFAR-10 dataset for Hyrax

HyraxCifarIterableDataSet

Iterable-style CIFAR-10 dataset for Hyrax

HyraxRandomDataset

This dataset is a stand-in for a map-style dataset.

HyraxRandomIterableDataset

This dataset is a stand-in for an iterable-style, or streaming, dataset.

HyraxRandomDatasetBase

This is the base class for the random datasets provided by Hyrax.

InferenceDataSet

This is a dataset class to represent the situations where we wish to treat the output of inference as a dataset.

HyraxDataset

Base class for all datasets in Hyrax, with instructions on how to make a Hyrax dataset.

HyraxCifarBase

Base class for Hyrax Cifar datasets

HyraxCSVDataset

A Hyrax Dataset for CSV files.

Package Contents

class FitsImageDataSet(config: dict, data_location=None)[source]

Bases: hyrax.data_sets.data_set_registry.HyraxDataset, hyrax.data_sets.data_set_registry.HyraxImageDataset, hyrax.data_sets.tensor_cache_mixin.TensorCacheMixin, torch.utils.data.Dataset

Dataset for Fits Images, typically cutouts.

__init__()[source]

Initialize a FitsImageDataSet.

Most of the work is done in _init_from_path and the functions it calls, so that subclasses can override behavior.

Parameters:
  • config (dict) – Nested configuration dictionary for hyrax

  • data_location (Optional[Union[Path, str]]) – The directory location of the data that this dataset class will access

_called_from_test = False
_config
object_id_column_name
filter_column_name
filename_column_name
_init_from_path(path: pathlib.Path | str)[source]

__init__ helper. Initialize the data set from a path. This involves several filesystem scan operations and will ultimately open and read the header info of every FITS file in the given directory.

Parameters:

path (Union[Path, str]) – Path or string specifying the directory path that is the root of all filenames in the catalog table

_set_crop_transform()[source]

Returns the crop transform on the image.

If overridden, the subclass must, at some point in the init flow:

  1. Set self.cutout_shape to a tuple of ints representing the size of the cutouts that will be returned.

  2. Update the crop transform using self.set_crop_transform() from the HyraxImageDataset mixin.

_read_filter_catalog(filter_catalog_path: pathlib.Path | None)[source]
_parse_filter_catalog(table) None[source]

Sets self.files by parsing the catalog.

Subclasses may override this function to control parsing of the table more directly, but the overriding class must create the files dict which has type dict[object_id -> dict[filter -> filename]] with object_id, filter, and filename all strings. In the case of no filter distinction, a single flag value may be used for the filter dict keys in the inner dicts.

Parameters:

table (Table) – The catalog we read in
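The files mapping that _parse_filter_catalog must build can be illustrated with plain Python dictionaries. This is a minimal sketch, not Hyrax's implementation: catalog rows are modeled as dicts rather than an astropy Table, and the default column names and the NO_FILTER flag value are hypothetical stand-ins for object_id_column_name, filter_column_name, filename_column_name, and whatever flag value a subclass chooses.

```python
from collections import defaultdict

# Hypothetical flag used as the inner key when the catalog has no filter column.
NO_FILTER = "NO_FILTER"


def parse_filter_catalog(rows, object_id_col="object_id", filter_col="filter",
                         filename_col="filename"):
    """Build dict[object_id -> dict[filter -> filename]] from catalog rows.

    rows is assumed to be an iterable of dicts; the default column names are
    illustrative, not Hyrax configuration keys.
    """
    files = defaultdict(dict)
    for row in rows:
        object_id = str(row[object_id_col])
        # Without a filter column, every file gets the same flag key.
        filter_name = str(row[filter_col]) if filter_col in row else NO_FILTER
        files[object_id][filter_name] = str(row[filename_col])
    return dict(files)


catalog = [
    {"object_id": 1, "filter": "g", "filename": "obj1_g.fits"},
    {"object_id": 1, "filter": "r", "filename": "obj1_r.fits"},
    {"object_id": 2, "filter": "g", "filename": "obj2_g.fits"},
]
files = parse_filter_catalog(catalog)
# files == {"1": {"g": "obj1_g.fits", "r": "obj1_r.fits"}, "2": {"g": "obj2_g.fits"}}
```

All identifiers are converted to strings, matching the requirement that object_id, filter, and filename all be strings.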

_before_preload() None[source]
_prepare_metadata()[source]
shape() tuple[int, int, int][source]

Shape of the individual cutouts this will give to a model

Returns:

Tuple describing the dimensions of the 3-dimensional tensor handed back to models. The first index is the number of filters, the second is the width of each image, and the third is the height of each image.

Return type:

tuple[int,int,int]

__len__() int[source]

Returns number of objects in this loader

Returns:

number of objects in this data loader

Return type:

int

get_object_id(idx: int) str[source]

Get the object ID at the given index

Parameters:

idx (int) – Index of the object ID to return

Returns:

The object ID at the given index

Return type:

str

get_image(idx: int)[source]

Get the image at the given index as a PyTorch Tensor.

Parameters:

idx (int) – Index of the image to return

Returns:

The image at the given index as a PyTorch Tensor.

Return type:

torch.Tensor

__getitem__(idx: int)[source]
__contains__(object_id: str) bool[source]

Allows you to write "object_id in dataset" queries. Used by testing code.

Parameters:

object_id (str) – The object ID to look up in the dataset

Returns:

True if the given object_id is in the data set

Return type:

bool

_get_file(index: int) pathlib.Path[source]

Private indexing method across all files.

Returns the file path corresponding to the given index.

The index is zero-based and follows the same total order as the _all_files() and _object_files() iterators. Useful if you have an np.array() or list built from _all_files() and need to select an individual item.

Only valid after self.object_ids, self.files, self.path, and self.num_filters have been initialized in __init__

Parameters:

index (int) – Index, see above for order semantics

Returns:

The path to the file

Return type:

Path
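The flat-index semantics of _get_file can be sketched as follows, assuming (hypothetically) that objects and filters are both visited in sorted order and that every object has the same number of filters. The helper names all_files and get_file are illustrative; the point is that a flat index and the iteration order must agree.

```python
from pathlib import Path


def all_files(files, root):
    """Yield every file path, visiting objects and then filters in sorted order."""
    for object_id in sorted(files):
        for filter_name in sorted(files[object_id]):
            yield Path(root) / files[object_id][filter_name]


def get_file(files, root, index):
    """Return the path at a zero-based flat index in the all_files() order.

    Assumes every object has the same number of filters.
    """
    object_ids = sorted(files)
    num_filters = len(files[object_ids[0]])
    # Each object contributes num_filters consecutive entries.
    object_idx, filter_idx = divmod(index, num_filters)
    object_id = object_ids[object_idx]
    filter_name = sorted(files[object_id])[filter_idx]
    return Path(root) / files[object_id][filter_name]


files = {"b": {"g": "b_g.fits", "r": "b_r.fits"}, "a": {"r": "a_r.fits", "g": "a_g.fits"}}
path = get_file(files, "/data", 1)  # object "a", second filter "r"
```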

ids(log_every=None) collections.abc.Generator[str][source]

Public read-only iterator over all object_ids that enforces a strict total order across objects. Will not work prior to self.files initialization in __init__

Yields:

Iterator[str] – Object IDs currently in the dataset

_all_files()[source]

Private read-only iterator over all files that enforces a strict total order across objects and filters. Will not work prior to self.files and self.path initialization in __init__

Yields:

Path – The path to the file.

_filter_filename(object_id)[source]

Private read-only iterator over all files for a given object. This enforces a strict total order across filters. Will not work prior to self.files initialization in __init__

Yields:

filter_name, file name – The name of a filter and the file name for the fits file. The file name is relative to self.path

_object_files(object_id)[source]

Private read-only iterator over all files for a given object. This enforces a strict total order across filters. Will not work prior to self.files and self.path initialization in __init__

Yields:

Path – The path to the file.

_file_to_path(filename: str) pathlib.Path[source]

Turns a filename into a full path suitable for open. Equivalent to:

Path(self.path) / Path(filename)

Parameters:

filename (str) – The filename string

Returns:

A full path that is openable.

Return type:

Path

_read_object_id(object_id: str)[source]
_convert_to_torch(data: list[numpy.typing.ArrayLike])[source]
_load_tensor_for_cache(object_id: str)[source]

Implementation of TensorCacheMixin abstract method.

_object_id_to_tensor(object_id: str)[source]

Converts an object_id to a pytorch tensor with dimensions (self.num_filters, self.cutout_shape[0], self.cutout_shape[1]). This is done by reading the file and slicing away any excess pixels at the far corners of the image from (0,0).

The current implementation reads the files once the first time they are accessed, and then keeps them in a dict for future accesses.

Parameters:

object_id (str) – The object_id requested

Returns:

A tensor with dimension (self.num_filters, self.cutout_shape[0], self.cutout_shape[1])

Return type:

torch.Tensor
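The crop-and-cache behavior described for _object_id_to_tensor can be sketched without torch. In this sketch nested Python lists stand in for the tensor, read_filter_images is a hypothetical reader returning one 2D image per filter, and CroppingCache is an illustrative name; crop keeps the region anchored at (0,0) and discards the excess pixels at the far corner.

```python
def crop(image, cutout_shape):
    """Keep the cutout_shape region anchored at (0, 0), dropping far-corner pixels."""
    n_rows, n_cols = cutout_shape
    return [row[:n_cols] for row in image[:n_rows]]


class CroppingCache:
    """Hypothetical read-once cache of cropped per-filter image stacks."""

    def __init__(self, read_filter_images, cutout_shape):
        # read_filter_images(object_id) -> list of 2D images, one per filter
        self._read = read_filter_images
        self.cutout_shape = cutout_shape
        self._cache = {}
        self.reads = 0  # counts actual reads, to show the cache is used

    def object_id_to_stack(self, object_id):
        """Return a (num_filters, rows, cols) nested list for object_id."""
        if object_id not in self._cache:
            self.reads += 1
            images = self._read(object_id)
            self._cache[object_id] = [crop(img, self.cutout_shape) for img in images]
        return self._cache[object_id]
```

Keeping every read in a dict trades memory for speed, which matches the "read once, then keep" behavior described above.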

class LSSTDataset(config, data_location)[source]

Bases: hyrax.data_sets.data_set_registry.HyraxDataset, hyrax.data_sets.data_set_registry.HyraxImageDataset, torch.utils.data.Dataset

LSSTDataset: A dataset to access deep_coadd images from LSST pipelines via the butler. Must be run in an RSP.

__init__()[source]

Initialize the dataset with either a HATS catalog or an astropy table.

Config can specify either:

  • config["data_set"]["hats_catalog"]: path to a HATS catalog

  • config["data_set"]["astropy_table"]: path to any file readable by Astropy Table

BANDS = ['u', 'g', 'r', 'i', 'z', 'y']
catalog
sh_deg
sw_deg
_load_catalog(data_set_config)[source]

Load the catalog from either a HATS catalog or an astropy table.

_load_hats_catalog(hats_path)[source]

Load catalog from HATS format using LSDB.

_load_astropy_catalog(table_path)[source]

Load catalog from astropy table format or pickled astropy table.

__len__()[source]
get_image(idxs)[source]

Get image cutouts for the given indices.

Parameters:

idxs (int or list of int) – The index or indices of the cutouts to retrieve.

Returns:

Single cutout tensor or list of cutout tensors.

Return type:

list or torch.Tensor

__getitem__(idxs)[source]

Get the default data fields for this dataset.

Parameters:

idxs (int or list of int) – The index or indices of the cutouts to retrieve.

Returns:

A dictionary containing the default data fields.

Return type:

dict

_parse_box(patch, row)[source]

Return a Box2I representing the desired cutout in pixel space, given a “row” of catalog data which includes the semi-height (sh) and semi-width (sw) in degrees desired for the cutout.
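The degrees-to-pixels conversion behind _parse_box can be sketched with plain integers instead of lsst.geom types. The pixel_scale_arcsec default of 0.2 is an assumption (roughly the nominal LSST pixel scale), and the object's pixel center is passed in directly, whereas the real method would derive it from the patch WCS.

```python
def parse_box(center_x, center_y, sh_deg, sw_deg, pixel_scale_arcsec=0.2):
    """Return (min_x, min_y, max_x, max_y) pixel bounds for a cutout.

    center_x, center_y: pixel position of the object (the real method would
    obtain this from the patch WCS). pixel_scale_arcsec is an assumed constant.
    """
    # Convert semi-width and semi-height from degrees to whole pixels.
    half_w = round(sw_deg * 3600.0 / pixel_scale_arcsec)
    half_h = round(sh_deg * 3600.0 / pixel_scale_arcsec)
    cx, cy = round(center_x), round(center_y)
    return (cx - half_w, cy - half_h, cx + half_w, cy + half_h)


# A 10 arcsec semi-width/semi-height is about 50 pixels at 0.2 arcsec/pixel.
box = parse_box(100.4, 200.6, 10 / 3600, 10 / 3600)
# box == (50, 151, 150, 251)
```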

_parse_sphere_point(row)[source]

Return a SpherePoint with the RA and Dec given in the "row" of catalog data. The row must include the RA and Dec as "ra" and "dec" columns, respectively.

_get_tract_patch(row)[source]

Return (tractInfo, patchInfo) for a given row.

This function only returns the single principal tract and patch in the case of overlap.

_request_patch(tract_index, patch_index)[source]

Request a patch from the butler. This will be a list of lsst.afw.image objects, each corresponding to one of the configured bands.

Uses functools.lru_cache for basic in-memory caching.

_fetch_single_cutout(row)[source]

Make a single cutout, returning a torch tensor.

Does not handle edge-of-tract or edge-of-patch cases; it only works near the center of a patch.

class DownloadedLSSTDataset(config, data_location)[source]

Bases: hyrax.data_sets.lsst_dataset.LSSTDataset, hyrax.data_sets.tensor_cache_mixin.TensorCacheMixin

DownloadedLSSTDataset: A dataset that inherits from LSSTDataset and downloads cutouts from the LSST butler, saving them as .pt files during first access. On subsequent accesses, it loads cutouts directly from these cached files.

This class also creates a manifest file recording the shape of each cutout and the corresponding filename.

Public Methods:
download_cutouts(indices=None, sync_filesystem=True, max_workers=None, force_retry=False):

Download cutouts with parallel processing. Automatically resumes from previous progress. Use max_workers to control thread count, force_retry to re-attempt failed downloads.

manifest_stats():

Returns dict with download statistics: total, successful, failed, pending counts and manifest file path.

download_progress():

Returns detailed progress metrics including completion percentage and failure rates.

reset_failed_downloads():

Resets all failed download attempts to allow retry without force_retry flag. Returns count of reset entries.

save_manifest_now():

Forces immediate manifest save (normally saved periodically during downloads).

cache_info():

Returns LRU cache statistics for patch fetching performance monitoring.

clear_cache():

Clears the patch LRU cache to free memory.

Usage Example:

import hyrax

# Initialize Hyrax and prepare the dataset
h = hyrax.Hyrax()
a = h.prepare()

# Download all cutouts (resumes automatically from previous progress)
a.download_cutouts(max_workers=4)

# Check progress
a.download_progress()

# Retry failed downloads
a.download_cutouts(force_retry=True)

# Access cutouts (loaded from the cached .pt files)
cutout = a[0]      # Single cutout
cutouts = a[0:10]  # Multiple cutouts

WARNING: The LRU caching scheme is slightly complicated, so it is recommended to use max_workers=1 for the first download. Simply using more workers may not always speed up the download process.

File Organization:

  • Cutouts saved as: cutout_{object_id}.pt or cutout_{index:04d}.pt

  • Manifest saved as: manifest.fits (Astropy) or manifest.parquet (HATS)

  • All files stored in config["general"]["data_dir"]

__init__()[source]

Initialize the dataset with either a HATS catalog or an astropy table.

Config can specify either:

  • config["data_set"]["hats_catalog"]: path to a HATS catalog

  • config["data_set"]["astropy_table"]: path to any file readable by Astropy Table

download_dir
_config
_manifest_lock
_updates_since_save = 0
_save_interval = 1000
_band_failure_stats
_band_failure_lock
_manifest_filter_object_ids = None
_catalog_to_manifest_index_map = None
_manifest_to_catalog_index_map = None
get_objectId(idx)[source]

Get object ID for a given index based on naming strategy.

ids(log_every=None)[source]

Generator yielding object IDs for the entire dataset. Required by TensorCacheMixin

_setup_naming_strategy()[source]

Setup file naming strategy based on catalog columns.
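A naming strategy matching the documented file pattern (cutout_{object_id}.pt when object IDs are available, otherwise cutout_{index:04d}.pt) might be chosen once per catalog as follows. The column name objectId is an assumption suggested by get_objectId; Hyrax may check different columns.

```python
def setup_naming_strategy(column_names, id_column="objectId"):
    """Choose a cutout filename function once, based on the catalog columns."""
    if id_column in column_names:
        def name_for(row, index):
            return f"cutout_{row[id_column]}.pt"
    else:
        def name_for(row, index):
            # No ID column: fall back to the zero-padded catalog index.
            return f"cutout_{index:04d}.pt"
    return name_for


with_ids = setup_naming_strategy(["ra", "dec", "objectId"])
without_ids = setup_naming_strategy(["ra", "dec"])
# with_ids({"objectId": 12345}, 7) == "cutout_12345.pt"
# without_ids({}, 7) == "cutout_0007.pt"
```

Deciding once per catalog, rather than per row, keeps every file in a download directory on the same naming scheme.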

_initialize_manifest()[source]

Create new manifest or load/merge with existing manifest, with band filtering validation.

_load_existing_manifest()[source]

Load existing manifest file.

_merge_manifests(existing_manifest)[source]

Merge existing manifest with current catalog based on object_id.

_build_catalog_to_manifest_index_map(manifest)[source]

Build efficient mapping from catalog indices to manifest indices.
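One way to build the catalog-to-manifest mapping is a single pass over the manifest's object IDs followed by dictionary lookups. This sketch assumes object IDs are unique within the manifest and maps catalog rows with no manifest entry to None; both choices are illustrative rather than taken from Hyrax.

```python
def build_catalog_to_manifest_index_map(catalog_ids, manifest_ids):
    """Map each catalog index to the manifest row holding the same object ID."""
    # One pass over the manifest gives O(1) lookups per catalog row.
    manifest_position = {object_id: i for i, object_id in enumerate(manifest_ids)}
    return {
        catalog_idx: manifest_position.get(object_id)
        for catalog_idx, object_id in enumerate(catalog_ids)
    }


index_map = build_catalog_to_manifest_index_map([30, 10], [10, 20, 30])
# index_map == {0: 2, 1: 0}
```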

_add_manifest_columns()[source]

Add cutout_shape, filename, and downloaded_bands columns to manifest.

_get_available_bands_from_manifest(manifest)[source]

Get available bands by checking the first 10 successful downloads for consistency.

_setup_band_filtering(requested_bands, original_band_order)[source]

Setup band filtering to extract only requested bands from cached cutouts.
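Band filtering amounts to computing which channels of a cached cutout correspond to the requested bands. The sketch below assumes cutouts are indexed band-first and keeps the cached band order; whether Hyrax reorders channels to match the request is not stated above. Raising ValueError for unknown bands is also an assumption.

```python
def setup_band_filtering(requested_bands, original_band_order):
    """Return channel indices of the requested bands within cached cutouts."""
    missing = [b for b in requested_bands if b not in original_band_order]
    if missing:
        raise ValueError(f"Bands not present in cached cutouts: {missing}")
    # Keep the cached (original) band order for the selected channels.
    return [i for i, band in enumerate(original_band_order) if band in requested_bands]


def apply_band_filter(cutout, band_indices):
    """Select channels along the first axis of a band-first cutout."""
    return [cutout[i] for i in band_indices]


indices = setup_band_filtering(["r", "g"], ["u", "g", "r", "i", "z", "y"])
# indices == [1, 2]
```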

_get_cutout_path(idx)[source]

Generate cutout file path for a given index.

_update_manifest_entry(idx, cutout_shape=None, filename='Attempted', downloaded_bands=None)[source]

Thread-safe manifest update with periodic saves.

Parameters:
  • idx – Index in the manifest

  • cutout_shape – Shape tuple of the cutout tensor, or None for failed downloads

  • filename – Basename of the saved file, or “Attempted” only when ALL bands fail

  • downloaded_bands – List of band names successfully downloaded in tensor order
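The thread-safe update with periodic saves can be sketched with a single threading.Lock and a counter, mirroring the _updates_since_save and _save_interval attributes listed above. The Manifest class, its row layout, and the save_fn callback are hypothetical; a real implementation would write FITS or Parquet where save_fn is called.

```python
import threading


class Manifest:
    """Hypothetical lock-protected manifest that saves every save_interval updates."""

    def __init__(self, size, save_interval=1000, save_fn=None):
        self.rows = [{"cutout_shape": None, "filename": None, "downloaded_bands": None}
                     for _ in range(size)]
        self._lock = threading.Lock()
        self._updates_since_save = 0
        self._save_interval = save_interval
        self._save_fn = save_fn or (lambda rows: None)
        self.saves = 0

    def update_entry(self, idx, cutout_shape=None, filename="Attempted",
                     downloaded_bands=None):
        # All mutation and the save decision happen under one lock.
        with self._lock:
            self.rows[idx].update(cutout_shape=cutout_shape, filename=filename,
                                  downloaded_bands=downloaded_bands)
            self._updates_since_save += 1
            # Persist only every save_interval updates to limit disk I/O.
            if self._updates_since_save >= self._save_interval:
                self._save_fn(self.rows)
                self.saves += 1
                self._updates_since_save = 0
```

The trade-off is that up to save_interval - 1 updates can be lost on a crash, which is why an explicit "save now" call is useful.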

_save_manifest()[source]

Save manifest in appropriate format (FITS for Astropy, Parquet for HATS).

_sync_manifest_with_filesystem()[source]

Sync manifest with actual downloaded files on disk.

static _request_patch_cached(tract_index, patch_index, butler_repo, butler_collections, skymap_name, bands_tuple)[source]

Cached patch fetching using static method.

Because this is a static method, self is not part of the cache key, so the cache is shared globally across instances. It is thread-safe because each call creates its own Butler instance.
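The effect of caching a static method can be shown with functools.lru_cache: because self is not an argument, every instance shares one cache, and every argument must be hashable, which is presumably why the bands are passed as a tuple. PatchSource and the returned strings are hypothetical placeholders for the Butler fetch.

```python
import functools


class PatchSource:
    """Hypothetical sketch of a static, instance-independent patch cache."""

    fetch_count = 0  # counts real (uncached) fetches, for illustration only

    @staticmethod
    @functools.lru_cache(maxsize=8)
    def request_patch_cached(tract_index, patch_index, bands_tuple):
        # A real implementation would create a Butler here and fetch each band.
        PatchSource.fetch_count += 1
        return {band: f"tract{tract_index}-patch{patch_index}-{band}" for band in bands_tuple}


a, b = PatchSource(), PatchSource()
a.request_patch_cached(9813, 42, ("g", "r"))
b.request_patch_cached(9813, 42, ("g", "r"))  # served from the shared cache
```

Every caller receives the same cached object, so in a design like this the returned patch should be treated as read-only.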

_fetch_single_cutout(row, idx=None, manifest_idx=None)[source]

Fetch cutout, using saved cutout if available, with optional band filtering.

_fetch_cutout_with_cache(row)[source]

Generate cutout using cached patch fetching with NaN filling for failed bands.
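NaN filling for failed bands keeps every cutout the same shape while recording which bands actually succeeded. This sketch uses nested lists rather than tensors, and the function name and its (stack, downloaded_bands) return value are illustrative; the downloaded_bands list matches the manifest field described for _update_manifest_entry.

```python
import math


def stack_bands_with_nan_fill(band_images, bands, shape):
    """Stack per-band 2D images in band order, filling failed bands with NaN.

    band_images: dict band -> 2D list, with None (or no key) for a failed band.
    Returns (stack, downloaded_bands).
    """
    n_rows, n_cols = shape
    stack, downloaded = [], []
    for band in bands:
        image = band_images.get(band)
        if image is None:
            # Keep the channel so every cutout has the same shape.
            stack.append([[math.nan] * n_cols for _ in range(n_rows)])
        else:
            stack.append(image)
            downloaded.append(band)
    return stack, downloaded


stack, got = stack_bands_with_nan_fill({"g": [[1.0, 2.0], [3.0, 4.0]], "r": None},
                                       ["g", "r", "i"], (2, 2))
# got == ["g"]; stack[1] and stack[2] are all NaN
```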

_load_tensor_for_cache(object_id: str)[source]

Implementation of TensorCacheMixin abstract method.

__len__()[source]

Return length of current catalog, not the full manifest.

_get_manifest_index_for_catalog_index(catalog_idx)[source]

Map catalog index to manifest index when filtering is active.

get_image(idxs)[source]

Fetch image cutout(s) for given index or indices, using caching and band filtering.

Parameters:

idxs (int, slice, or list) – Index or indices to fetch.

Returns:

Single cutout tensor or list of cutout tensors.

Return type:

torch.Tensor or list of torch.Tensor

__getitem__(idxs) dict[source]

Get the default data fields for the given index or indices, passing the index through so that fetched cutouts can be saved.

Parameters:

idxs (int, slice, or list) – Index or indices to fetch.

Returns:

Dictionary with key ‘data’ containing another dict of default data fields to return. Currently only ‘image’ is supported.

Return type:

dict

download_cutouts(indices=None, sync_filesystem=True, max_workers=None, force_retry=False)[source]

Download cutouts using multiple threads with caching.

Parameters:
  • indices – List of indices to download, or None for all

  • sync_filesystem – Whether to sync manifest with existing files on disk

  • max_workers – Maximum number of worker threads, or None to use default

  • force_retry – Whether to retry previously failed downloads
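Resuming and force_retry can be sketched as a filter over the manifest followed by a thread pool. In this hypothetical version the manifest is reduced to a list of filenames, None marks an entry that was never attempted, and fetch stands in for fetching and saving one cutout; only the “Attempted” marker comes from the documentation above.

```python
from concurrent.futures import ThreadPoolExecutor

ATTEMPTED = "Attempted"  # manifest marker for a failed download


def select_pending(manifest_filenames, indices=None, force_retry=False):
    """Return manifest indices that still need downloading."""
    candidates = range(len(manifest_filenames)) if indices is None else indices
    pending = []
    for i in candidates:
        name = manifest_filenames[i]
        # Never-tried entries always run; failed ones only with force_retry.
        if name is None or (force_retry and name == ATTEMPTED):
            pending.append(i)
    return pending


def download_cutouts(manifest_filenames, fetch, indices=None, max_workers=1,
                     force_retry=False):
    """Download pending cutouts in parallel, recording results in place."""
    pending = select_pending(manifest_filenames, indices, force_retry)

    def work(i):
        try:
            manifest_filenames[i] = fetch(i)  # fetch returns the saved filename
        except Exception:
            manifest_filenames[i] = ATTEMPTED  # record any failure

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        list(pool.map(work, pending))
    return len(pending)
```

Because completed entries are skipped, calling the function again after an interruption naturally resumes where it left off.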

_download_single_cutout(catalog_idx, manifest_idx)[source]

Helper method to download a single cutout.

cache_info()[source]

Get cache statistics.

clear_cache()[source]

Clear the LRU cache.

manifest_stats()[source]

Get manifest statistics including downloaded bands information.

band_filtering_info()[source]

Get information about current band filtering configuration.

save_manifest_now()[source]

Force immediate manifest save.

static _determine_numprocs_download()[source]

Determine number of threads for downloading.

reset_failed_downloads()[source]

Reset failed download attempts to allow retry.

download_progress()[source]

Get detailed download progress information.

download_summary()[source]

Get detailed download and band analysis, accounting for band filtering.

class HSCDataSet(config: dict, data_location=None)[source]

Bases: hyrax.data_sets.fits_image_dataset.FitsImageDataSet

Dataset for sets of HSC cutouts created by the hyrax download verb.

__init__()[source]
_called_from_test = False
filters_config
_read_filter_catalog(filter_catalog_path: pathlib.Path | None)[source]
_parse_filter_catalog(table) None[source]

Sets self.files by parsing the catalog.

Subclasses may override this function to control parsing of the table more directly, but the overriding class must create the files dict which has type dict[object_id -> dict[filter -> filename]] with object_id, filter, and filename all strings. In the case of no filter distinction, a single flag value may be used for the filter dict keys in the inner dicts.

Parameters:

table (Table) – The catalog we read in

_set_crop_transform()[source]

Returns the crop transform on the image.

If overridden, the subclass must, at some point in the init flow:

  1. Set self.cutout_shape to a tuple of ints representing the size of the cutouts that will be returned.

  2. Update the crop transform using self.set_crop_transform() from the HyraxImageDataset mixin.

_before_preload()[source]
_scan_file_names(filters: list[str] | None, filter_obj_ids: list[str] | None = None) hyrax.data_sets.fits_image_dataset.files_dict[source]

Class initialization helper.

Parameters:
  • filters (list[str], optional) – List of filters that we should look for in the data corpus

  • filter_obj_ids (list[str], optional) – Restrict the file scan to file names with the provided object IDs, skipping other files. When not provided, all file names in the configured data directory that match the pattern from hyrax download are parsed.

Returns:

Nested dictionary where the first level maps object_id -> dict, and the second level maps filter_name -> file name. Corresponds to self.files

Return type:

dict[str,dict[str,str]]

static _determine_numprocs() int[source]
static _fixup_limit(nproc: int, res, est_limit, est_procs) int[source]
_scan_file_dimensions() dim_dict[source]
static _scan_file_dimension(processing_unit: tuple[str, list[str]]) tuple[str, list[tuple[int, int]]][source]
static _fits_file_dims(filepath) tuple[int, int][source]
_prune_objects(filters_ref: list[str], cutout_shape: tuple[int, int] | None)[source]

Class initialization helper. Prunes objects from self.files, the nested dictionary created by _scan_file_names():

  1. Removes any objects which do not have all the filters specified in filters_ref

  2. If a cutout_shape was provided in the constructor, prunes files that are too small for the chosen cutout size

This function deletes from self.files and self.dims via _prune_object

Parameters:
  • filters_ref (list[str]) – List of the filter names

  • cutout_shape (tuple[int, int]) – Cutout shape tuple provided from constructor
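The two pruning rules can be sketched with plain dictionaries. Here dims is assumed to hold one (rows, cols) tuple per file for each object, and reasons are collected first and applied afterward, loosely mirroring the separate _mark_for_prune and _prune_object steps; the reason strings are invented for illustration.

```python
def prune_objects(files, dims, filters_ref, cutout_shape=None):
    """Drop objects missing any reference filter or too small for cutout_shape.

    files: dict[object_id -> dict[filter -> filename]]
    dims:  dict[object_id -> list[(rows, cols)]], one entry per file
    Returns dict[object_id -> reason] for the pruned objects.
    """
    reasons = {}
    for object_id, filter_files in files.items():
        missing = [f for f in filters_ref if f not in filter_files]
        if missing:
            reasons[object_id] = f"missing filters {missing}"
        elif cutout_shape is not None and any(
            r < cutout_shape[0] or c < cutout_shape[1] for r, c in dims[object_id]
        ):
            reasons[object_id] = f"smaller than cutout shape {cutout_shape}"
    # Delete after iterating so the dicts are not modified mid-loop.
    for object_id in reasons:
        del files[object_id]
        dims.pop(object_id, None)
    return reasons
```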

_mark_for_prune(object_id, reason)[source]
_prune_object(object_id, reason: str)[source]
_check_file_dimensions() tuple[int, int][source]

Class initialization helper. Find the maximal pixel size that all images can support

It is assumed that all the cutouts will be of very similar size; however, HSC’s cutout server does not return exactly the same number of pixels for every query, even when it is given the same angular spread for every cutout.

Machine learning models expect all images to be the same size.

This function warns on significant differences (>2px) on any dimension between the largest and smallest images.

Returns:

The minimum width and height in pixels of the entire dataset. In other words: the maximal image size in pixels that can be generated from ALL cutout images via cropping.

Return type:

tuple(int,int)
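The dimension check reduces to taking the minimum height and width over all files and warning when the spread is too large. This sketch uses the 2-pixel threshold stated above; the function name and warning text are illustrative.

```python
import warnings


def check_file_dimensions(all_dims, tolerance_px=2):
    """Return the largest (rows, cols) that every image can be cropped to.

    Warns when the spread between the largest and smallest image exceeds
    tolerance_px in either dimension.
    """
    rows = [r for r, _ in all_dims]
    cols = [c for _, c in all_dims]
    if max(rows) - min(rows) > tolerance_px or max(cols) - min(cols) > tolerance_px:
        warnings.warn(
            f"Image sizes vary by more than {tolerance_px}px: "
            f"rows {min(rows)}-{max(rows)}, cols {min(cols)}-{max(cols)}"
        )
    # The minimum in each dimension is the largest crop all images support.
    return min(rows), min(cols)
```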

_rebuild_manifest(config)[source]
__contains__(object_id: str) bool[source]

Allows you to write "object_id in dataset" queries. Used by testing code.

Parameters:

object_id (str) – The object ID to look up in the dataset

Returns:

True if the given object_id is in the data set

Return type:

bool

_all_files_full()[source]

Private read-only iterator over all files that enforces a strict total order across objects and filters. Will not work prior to self.files and self.path initialization in __init__

Yields:

Tuple[object_id, filter, filename, dim] – Members of this tuple are:

  • The object_id as a string

  • The filter name as a string

  • The filename relative to self.path

  • A tuple containing the dimensions of the fits file in pixels.

_object_files(object_id)[source]

Private read-only iterator over all files for a given object. This enforces a strict total order across filters. Will not work prior to self.files and self.path initialization in __init__

Guaranteed to only return files that have filters in self.filters_ref.

Yields:

Path – The path to the file.

class HyraxCifarDataSet(config: dict, data_location: pathlib.Path = None)[source]

Bases: HyraxCifarBase, hyrax.data_sets.data_set_registry.HyraxDataset, torch.utils.data.Dataset

Map-style CIFAR-10 dataset for Hyrax

This is simply a version of CIFAR10 that is initialized using Hyrax config with a transformation that works well for example code.

We only use the training split of the data, because it is larger (50k images). Hyrax then divides that into train/test/validate splits according to the configuration.
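A split of the 50k training images into train/test/validate subsets could look like the following. The split names, the shuffle-then-slice approach, and leaving any remainder unassigned are assumptions for illustration; Hyrax's actual split configuration keys and algorithm are not described here.

```python
import random


def split_indices(n, fractions, seed=None):
    """Shuffle range(n) and cut it into named, non-overlapping splits.

    fractions: dict name -> fraction of n; any remainder is left unassigned.
    """
    if sum(fractions.values()) > 1.0 + 1e-9:
        raise ValueError("split fractions sum to more than 1")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded for reproducible splits
    splits, start = {}, 0
    for name, fraction in fractions.items():
        count = int(n * fraction)
        splits[name] = indices[start:start + count]
        start += count
    return splits


splits = split_indices(50_000, {"train": 0.6, "validate": 0.2, "test": 0.2}, seed=0)
```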

__init__()[source]

Overall initialization for all DataSets which saves the config

Subclasses of HyraxDataset ought to call this at the end of their __init__, like:

from hyrax.data_sets import HyraxDataset
from torch.utils.data import Dataset

class MyDataset(HyraxDataset, Dataset):
    def __init__(self, config):
        # ... your initialization code ...
        super().__init__(config)

If per-tensor metadata is available, it is recommended that dataset authors create an astropy Table of that data, in the same order as their data, and pass that metadata_table as shown below:

from hyrax.data_sets import HyraxDataset
from torch.utils.data import Dataset
from astropy.table import Table

class MyDataset(HyraxDataset, Dataset):
    def __init__(self, config):
        # ... your initialization code ...
        # catalog_data: your per-item catalog columns, in data order
        metadata_table = Table(catalog_data)
        super().__init__(config, metadata_table)
Parameters:
  • config (dict, Optional) – The runtime configuration for hyrax

  • metadata_table (Optional[Table], optional) – An Astropy Table containing the metadata columns desired for visualization, with rows in the same order in which your data will be enumerated.

  • object_id_column_name (Optional[str], optional) – The name of the column containing object IDs. If None, uses the default from config or creates one from the ids() method.

__len__()[source]
__getitem__(idx)[source]
class HyraxCifarIterableDataSet(config: dict, data_location: pathlib.Path = None)[source]

Bases: HyraxCifarBase, hyrax.data_sets.data_set_registry.HyraxDataset, torch.utils.data.IterableDataset

Iterable-style CIFAR-10 dataset for Hyrax

This is simply a version of CIFAR10 that is initialized using Hyrax config with a transformation that works well for example code. This version only supports iteration, and not map-style access

We only use the training split of the data, because it is larger (50k images). Hyrax then divides that into train/test/validate splits according to the configuration.

__init__()[source]

Overall initialization for all DataSets which saves the config

Subclasses of HyraxDataset ought to call this at the end of their __init__, like:

from hyrax.data_sets import HyraxDataset
from torch.utils.data import Dataset

class MyDataset(HyraxDataset, Dataset):
    def __init__(self, config):
        # ... your initialization code ...
        super().__init__(config)

If per-tensor metadata is available, it is recommended that dataset authors create an astropy Table of that data, in the same order as their data, and pass that metadata_table as shown below:

from hyrax.data_sets import HyraxDataset
from torch.utils.data import Dataset
from astropy.table import Table

class MyDataset(HyraxDataset, Dataset):
    def __init__(self, config):
        # ... your initialization code ...
        # catalog_data: your per-item catalog columns, in data order
        metadata_table = Table(catalog_data)
        super().__init__(config, metadata_table)
Parameters:
  • config (dict, Optional) – The runtime configuration for hyrax

  • metadata_table (Optional[Table], optional) – An Astropy Table containing the metadata columns desired for visualization, with rows in the same order in which your data will be enumerated.

  • object_id_column_name (Optional[str], optional) – The name of the column containing object IDs. If None, uses the default from config or creates one from the ids() method.

__iter__()[source]
class HyraxRandomDataset(config, data_location)[source]

Bases: HyraxRandomDatasetBase, hyrax.data_sets.data_set_registry.HyraxDataset, torch.utils.data.Dataset

This dataset is a stand-in for a map-style dataset. It will produce random numpy arrays along with sequential numeric ids and, optionally, labels randomly selected from the provided list of possible labels.

__init__(config, data_location)[source]

Initialize the dataset using the parameters defined in the configuration.

The data_location parameter is included for API consistency with other dataset classes, but it is not used by this implementation. All parameters are controlled by the following keys under the ["data_set"]["HyraxRandomDataset"] table in the configuration:

  • size: The number of random data samples to produce.

  • shape: The shape of each random data sample as a tuple (e.g. (3, 29, 29) = 3 layers of 2D data, each layer is 29x29 elements).

  • seed: The random seed to use for reproducibility.

  • provided_labels: A list of possible labels to randomly select from. If this is provided, the dataset will randomly select a label for each data sample.

  • metadata_fields: A list of metadata field names, used to create a metadata table with columns corresponding to each field name. All data is numeric.

  • number_invalid_values: The number of invalid values to insert into the data.

  • invalid_value_type: The type of invalid value to insert into the data. Valid values are “nan”, “inf”, “-inf”, “none”, or a float value.
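The configuration keys above can be mapped onto a small generator using only the standard library. This is a sketch under stated assumptions: samples are flat lists of length prod(shape) instead of numpy arrays, invalid values are scattered across the whole data set at random positions, "none" becomes Python None, and metadata_fields is omitted. Hyrax's actual placement strategy may differ.

```python
import math
import random


def make_random_data(size, shape, seed=None, provided_labels=None,
                     number_invalid_values=0, invalid_value_type="nan"):
    """Return (data, id_list, labels) for `size` flat random samples."""
    rng = random.Random(seed)  # one seeded generator for reproducibility
    n_elements = math.prod(shape)
    data = [[rng.random() for _ in range(n_elements)] for _ in range(size)]
    invalid = {"nan": math.nan, "inf": math.inf, "-inf": -math.inf, "none": None}
    # A float invalid_value_type is used as-is.
    fill = invalid.get(invalid_value_type, invalid_value_type)
    # Overwrite distinct (sample, element) positions with the invalid value.
    for flat in rng.sample(range(size * n_elements), number_invalid_values):
        sample, element = divmod(flat, n_elements)
        data[sample][element] = fill
    id_list = list(range(size))
    labels = [rng.choice(provided_labels) for _ in range(size)] if provided_labels else None
    return data, id_list, labels


data, ids, labels = make_random_data(5, (3, 4, 4), seed=42,
                                     provided_labels=["star", "galaxy"],
                                     number_invalid_values=3)
```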

__getitem__(idx: int) dict[source]

Get a data sample by index. The returned dictionary will contain the following keys:

  • index: The index of the data sample.

  • object_id: The ID of the data sample.

  • image: The data sample as a numpy array.

  • label: The label of the data sample (if provided).

Parameters:

idx (int) – The index of the data sample to retrieve.

Returns:

A dictionary containing the data sample and its metadata.

Return type:

dict

__len__()[source]

Get the total number of samples in this dataset. This should return the same value as the size parameter in the configuration.

ids()[source]

This function yields IDs for the dataset. It can be used as an iterable in a loop, or converted to a list by wrapping the function call in list(...).

class HyraxRandomIterableDataset(config, data_location)[source]

Bases: HyraxRandomDatasetBase, hyrax.data_sets.data_set_registry.HyraxDataset, torch.utils.data.IterableDataset

This dataset is a stand-in for an iterable-style, or streaming, dataset. It will produce random numpy arrays and, optionally, labels randomly selected from the provided list of possible labels.

Note

While ids will be generated automatically for this dataset, calling the ids method of this dataset will return the index instead of the id.

__init__(config, data_location)[source]

Initialize the dataset using the parameters defined in the configuration.

The data_location parameter is included for API consistency with other dataset classes, but it is not used by this implementation. All parameters are controlled by the following keys under the ["data_set"]["HyraxRandomDataset"] table in the configuration:

  • size: The number of random data samples to produce.

  • shape: The shape of each random data sample as a tuple (e.g. (3, 29, 29) = 3 layers of 2D data, each layer is 29x29 elements).

  • seed: The random seed to use for reproducibility.

  • provided_labels: A list of possible labels to randomly select from. If this is provided, the dataset will randomly select a label for each data sample.

  • metadata_fields: A list of metadata field names, used to create a metadata table with columns corresponding to each field name. All data is numeric.

  • number_invalid_values: The number of invalid values to insert into the data.

  • invalid_value_type: The type of invalid value to insert into the data. Valid values are “nan”, “inf”, “-inf”, “none”, or a float value.

__iter__()[source]

Yield the next data sample. The returned dictionary will have the following form:

  • data: A dictionary containing the following keys:

      – index: The index of the data sample.

      – object_id: The same value as index for this dataset.

      – image: The data sample as a NumPy array.

      – label: The label of the data sample (if labels were provided).

Returns:

A dictionary containing a data sample and its metadata.

Return type:

dict
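The dictionary form documented for __iter__ can be sketched as a plain Python generator. This is a minimal illustration, assuming the samples are already held in a NumPy array; the function name iterate_random_samples is hypothetical and the sketch does not subclass torch's IterableDataset.

```python
from collections.abc import Iterator
from typing import Optional

import numpy as np


def iterate_random_samples(data: np.ndarray, labels: Optional[list] = None) -> Iterator[dict]:
    """Hypothetical generator yielding samples in the documented dictionary form."""
    for index, image in enumerate(data):
        # For this dataset the object ID is simply the index
        sample = {"index": index, "object_id": index, "image": image}
        if labels is not None:
            sample["label"] = labels[index]
        yield {"data": sample}
```

Streaming consumers only ever call next() on such a generator, which is why an iterable-style dataset does not need to support random access by position.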

class HyraxRandomDatasetBase(config, data_location)[source]

This is the base class for the random datasets provided by Hyrax.

Warning

Direct use of HyraxRandomDatasetBase is not advised. When working with Hyrax, prefer to use HyraxRandomDataset or HyraxRandomIterableDataset.

__init__(config, data_location)[source]

Initialize the dataset using the parameters defined in the configuration.

The data_location parameter is included for API consistency with other dataset classes, but this implementation does not use it. All parameters are controlled by the following keys under the ["data_set"]["HyraxRandomDataset"] table in the configuration:

  • size: The number of random data samples to produce.

  • shape: The shape of each random data sample as a tuple (e.g. (3, 29, 29) = 3 layers of 2D data, each layer is 29x29 elements).

  • seed: The random seed to use for reproducibility.

  • provided_labels: A list of possible labels to randomly select from. If this is provided, the dataset will randomly select a label for each data sample.

  • metadata_fields: A list of metadata field names, used to create a metadata table with one column per field name. All data is numeric.

  • number_invalid_values: The number of invalid values to insert into the data.

  • invalid_value_type: The type of invalid value to insert into the data. Valid values are “nan”, “inf”, “-inf”, “none”, or a float value.

data: numpy.ndarray

The random data samples produced by the dataset.

id_list: list

A list of sequential numeric IDs for each data sample.

provided_labels: list

A list of labels randomly selected from the provided list of possible labels.

data_location
get_image(idx: int) numpy.ndarray[source]

Get the image at the given index as a NumPy array.

get_label(idx: int) str[source]

Get the label at the given index.

get_object_id(idx: int) str[source]

Get the object ID of the item. For this dataset, the object ID is the item's index.
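The attributes and accessors above fit together roughly as follows. This is a hypothetical stand-in (the class name RandomAccessorsSketch is invented), assuming get_label returns None when no labels were configured and get_object_id returns the sequential ID as a string.

```python
from typing import Optional

import numpy as np


class RandomAccessorsSketch:
    """Hypothetical stand-in mirroring the attributes documented above."""

    def __init__(self, data: np.ndarray, provided_labels: Optional[list] = None):
        self.data = data
        self.id_list = list(range(len(data)))  # sequential numeric IDs
        self.provided_labels = provided_labels

    def get_image(self, idx: int) -> np.ndarray:
        # One sample from the stacked random data
        return self.data[idx]

    def get_label(self, idx: int) -> Optional[str]:
        # Assumption: None when no labels were configured
        return None if self.provided_labels is None else self.provided_labels[idx]

    def get_object_id(self, idx: int) -> str:
        # The ID equals the index, returned as a string
        return str(self.id_list[idx])
```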

class InferenceDataSet(config, results_dir: pathlib.Path | str | None = None, verb: str | None = None)[source]

Bases: hyrax.data_sets.data_set_registry.HyraxDataset, torch.utils.data.Dataset

This is a dataset class for situations where we wish to treat the output of inference as a dataset, e.g. when performing UMAP or visualization operations.

Initialize an InferenceDataSet object.

As a user of this code, you should almost never create this class yourself. Instances of this class are returned by the umap and infer verbs; prefer those over creating your own.

If you do end up creating your own instance, you will need a Hyrax config and some knowledge of where the results you are interested in are stored.

Parameters:
  • config (dict) – The hyrax config dictionary

  • results_dir (Optional[Union[Path, str]], optional) –

    The results subdirectory of the inference or umap results you want to access, by default None. If no results subdirectory is provided, the constructor will attempt the following, in order:

    1. Use the directory specified in config['results']['inference_dir'] if set and the directory exists

    2. Look in the results configured in config['general']['results_dir'] (./results/ by default), then use the most recent results directory corresponding to the verb specified.

  • verb (Optional[str], optional) – The name of the verb that generated the results, only important when the most recent results are being fetched. If no verb is provided, “infer” will be assumed.

Raises:

RuntimeError – When the provided results directory is corrupt, or cannot be found.
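The lookup order above can be sketched with pathlib. This is a minimal illustration, not Hyrax's code: the function name resolve_results_dir_sketch is hypothetical, an explicit override is assumed to be a path to an existing directory, run directories are assumed to contain the verb name and start with a timestamp (so lexical order is chronological), and the corruption checks are omitted.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_results_dir_sketch(
    config: dict,
    results_dir: Optional[Union[Path, str]] = None,
    verb: Optional[str] = None,
) -> Path:
    """Hypothetical sketch of the lookup order described above."""
    if results_dir is not None:
        # Assumption: an explicit override is a path to an existing directory
        path = Path(results_dir)
        if not path.is_dir():
            raise RuntimeError(f"Results directory {path} cannot be found")
        return path

    # 1. config['results']['inference_dir'], if set and the directory exists
    inference_dir = config.get("results", {}).get("inference_dir")
    if inference_dir and Path(inference_dir).is_dir():
        return Path(inference_dir)

    # 2. Most recent run of `verb` under config['general']['results_dir']
    verb = verb or "infer"
    base = Path(config.get("general", {}).get("results_dir") or "./results/")
    candidates = []
    if base.is_dir():
        # Assumption: timestamp-prefixed names, so sorting by name is chronological
        candidates = sorted(p for p in base.glob(f"*{verb}*") if p.is_dir())
    if not candidates:
        raise RuntimeError(f"No {verb!r} results found under {base}")
    return candidates[-1]
```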

results_dir
batch_index
length
cached_batch_num: int | None = None
shape_element
_original_dataset_config
original_dataset
_shape()[source]

The shape of an individual element of the dataset (discovered from files).

Returns:

Tuple with the shape of an individual element of the dataset

Return type:

Tuple

ids() collections.abc.Generator[str][source]

IDs of this dataset. Will return a string generator with IDs.

These IDs are the IDs of the dataset used originally to generate this dataset.

Returns:

Generator that yields the string ids of this dataset

Return type:

Generator[str]

Yields:

Generator[str] – Yields the string ids of this dataset

__getitem__(idx: int | numpy.ndarray)[source]

Implements the [] operator

Parameters:

idx (Union[int, np.ndarray]) – Either an index or a numpy array of indexes. These are NOT the ID values of the dataset, but rather a zero-based index starting at the beginning of the inference dataset.

Returns:

Either the tensor corresponding to a single result, or a tensor with a multiplicity of results if multiple indexes were passed.

Return type:

torch.Tensor
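Because __getitem__ takes zero-based positions rather than IDs, looking up results for specific objects needs an ID-to-position translation first. The helper below is a hypothetical sketch of that step; a plain NumPy array stands in for the inference results so it runs without Hyrax.

```python
from collections.abc import Iterable

import numpy as np


def positions_for_ids(all_ids: Iterable[str], wanted_ids: Iterable[str]) -> np.ndarray:
    """Hypothetical helper: convert object IDs into zero-based positions."""
    lookup = {object_id: position for position, object_id in enumerate(all_ids)}
    # Raises KeyError for an ID that is not in the dataset
    return np.array([lookup[object_id] for object_id in wanted_ids], dtype=int)
```

With a real InferenceDataSet one would presumably pass the result of positions_for_ids(inference_ds.ids(), wanted_ids) to the [] operator, since ids() yields IDs in dataset order.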

__len__() int[source]

Returns the length of the dataset.

Returns:

Length of the dataset.

Return type:

int

property original_config: dict

Get the original configuration for the dataset used to generate this inference dataset

Since this sort of dataset is definitionally an intermediate product, this returns the runtime config used to construct that dataset rather than this one.

Returns:

Configuration that can be used to create the original dataset that was used as input for whatever inference process created this dataset.

Return type:

dict

metadata_fields() list[str][source]

Get the metadata fields associated with the original dataset used to generate this one.

Returns:

List of valid field names for metadata queries

Return type:

list[str]

metadata(idxs: numpy.typing.ArrayLike, fields: list[str]) numpy.typing.ArrayLike[source]

Get metadata associated with the data in the InferenceDataSet. This metadata comes from the original dataset, but is indexed according to the InferenceDataSet.

Parameters:
  • idxs (npt.ArrayLike) – Indexes in the InferenceDataSet for which metadata is desired

  • fields (list[str]) – Metadata fields requested

Returns:

An array where the rows correspond to the passed list of indexes and the columns correspond to the fields passed. Order is preserved: metadata[i] corresponds to idxs[i].

Return type:

npt.ArrayLike

_load_from_batch_file(batch_num: int, ids=Union[int, np.ndarray]) numpy.ndarray[source]

Returns an array of tensors for the given IDs within a particular batch, identified by batch_num.

_resolve_results_dir(config, results_dir: pathlib.Path | str | None, verb: str | None) pathlib.Path[source]

Initialize an inference results directory as a data source. Accepts an override of which directory to use.

class HyraxDataset(config: dict, metadata_table=None, object_id_column_name=None)[source]

How to make a hyrax dataset:

from hyrax.data_sets import HyraxDataset
from torch.utils.data import Dataset

class MyDataset(HyraxDataset, Dataset):
    def __init__(self, config: dict):
        super().__init__(config)

    def __getitem__(self, idx):
        # Return the sample at position idx here
        pass

    def __len__(self):
        # Return the total number of samples here
        pass

Optional interfaces:

ids() -> Subclasses may override this directly with their own ids function returning a generator of strings

metadata -> Subclasses may pass an astropy table of metadata to __init__ in the superclass. This table of metadata will be available through the metadata_fields and metadata functions. If desired, a subclass may override these functions directly rather than using the astropy Table interface.

Further documentation is in the Getting started with Hyrax Custom Dataset Classes example notebook.
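The optional ids() override mentioned above only needs to yield strings, in the same order the data is enumerated. In this minimal sketch the class name CatalogIdsSketch and its object_ids argument are hypothetical, and the HyraxDataset and torch bases are left out so it runs without those packages; a real dataset would keep them.

```python
from collections.abc import Generator


class CatalogIdsSketch:
    """Stand-in for `class MyDataset(HyraxDataset, Dataset)`; the Hyrax and
    torch bases are omitted so the sketch runs without those packages."""

    def __init__(self, object_ids):
        self._object_ids = list(object_ids)  # hypothetical catalog ID column

    def __len__(self) -> int:
        return len(self._object_ids)

    def ids(self) -> Generator[str, None, None]:
        # Yield IDs lazily, as strings, in the order the data is enumerated
        for object_id in self._object_ids:
            yield str(object_id)
```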

__init__()[source]

Overall initialization for all datasets, which saves the config.

Subclasses of HyraxDataset ought to call this at the end of their __init__, like:

from hyrax.data_sets import HyraxDataset
from torch.utils.data import Dataset

class MyDataset(HyraxDataset, Dataset):
    def __init__(self, config):
        # <your code>
        super().__init__(config)

If per-tensor metadata is available, it is recommended that dataset authors create an astropy Table of that data, in the same order as their data, and pass it as metadata_table as shown below:

from hyrax.data_sets import HyraxDataset
from torch.utils.data import Dataset
from astropy.table import Table

class MyDataset(HyraxDataset, Dataset):
    def __init__(self, config):
        # <your code>
        # catalog_data is a placeholder for your catalog data, one row per sample
        metadata_table = Table(catalog_data)
        super().__init__(config, metadata_table)

Parameters:
  • config (dict, Optional) – The runtime configuration for hyrax

  • metadata_table (Optional[Table], optional) – An Astropy Table that 1. contains the metadata columns desired for visualization AND 2. is in the order your data will be enumerated.

  • object_id_column_name (Optional[str], optional) – The name of the column containing object IDs. If None, uses the default from config or creates one from the ids() method.

_config
_metadata_table = None
tensorboardx_logger = None
classmethod is_iterable()[source]

Returns True if the underlying dataset is iterable-style, supporting __iter__, as opposed to map-style, where __getitem__/__len__ are the preferred access methods.

Returns:

True if underlying dataset is iterable

Return type:

bool

classmethod is_map()[source]

Returns True if the underlying dataset is map-style, supporting __getitem__/__len__, as opposed to iterable-style, where __iter__ is the preferred access method.

Returns:

True if underlying dataset is map-style

Return type:

bool

property config
classmethod __init_subclass__()[source]
ids() collections.abc.Generator[str][source]

This is the default ids() function you get when you derive from HyraxDataset.

Returns:

A generator yielding all the string IDs of the dataset.

Return type:

Generator[str]

sample_data() dict[source]

Get a sample from the dataset. This is a convenience function that returns the first sample from the dataset, regardless of whether it is iterable or map-style. Often this will be used to instantiate a model that adjusts its form based on the shape of the data.
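The "first sample regardless of style" behaviour can be sketched in a few lines. This is a hypothetical helper, not the method itself: in Hyrax the style would come from the is_iterable()/is_map() classmethods, whereas here it is passed in as a flag.

```python
def first_sample(dataset, is_iterable: bool):
    """Hypothetical helper mirroring the behaviour described for sample_data()."""
    if is_iterable:
        # Iterable-style datasets are consumed through __iter__
        return next(iter(dataset))
    # Map-style datasets support positional access through __getitem__
    return dataset[0]
```

A model can then inspect the returned sample (for example, the shape of its image) before its layers are built.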

metadata_fields() list[str][source]

Returns a list of metadata fields supported by this object

Returns:

The column names of the metadata table passed. An empty list if no metadata was provided during construction of the HyraxDataset (or derived class).

Return type:

list[str]

metadata(idxs: numpy.typing.ArrayLike, fields: list[str]) numpy.typing.ArrayLike[source]

Returns a table representing the metadata given an array of indexes and a list of fields.

Parameters:
  • idxs (npt.ArrayLike) – The indexes of the relevant tensor objects

  • fields (list[str]) – The names of the fields you would like returned. All values must be among those returned by metadata_fields()

Returns:

A numpy record array of your metadata, with only the columns specified. Roughly equivalent to: metadata_table[idxs][fields].as_array() where metadata_table is the astropy table that the HyraxDataset (or derived class) was constructed with.

Return type:

npt.ArrayLike

Raises:

RuntimeError – When none of the provided fields are among those returned by metadata_fields().
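The "roughly equivalent" expression above selects rows by position and then columns by name. The sketch below mirrors it with a NumPy structured array in place of the astropy Table; the helper name metadata_sketch is hypothetical, and field validation is left out.

```python
import numpy as np


def metadata_sketch(table: np.ndarray, idxs, fields: list) -> np.ndarray:
    """Hypothetical analogue of metadata_table[idxs][fields].as_array()."""
    # Rows by position (order of idxs is preserved), then columns by name
    return table[np.asarray(idxs)][list(fields)]
```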

class HyraxCifarBase(config: dict, data_location: pathlib.Path = None)[source]

Base class for Hyrax Cifar datasets

data_location
cifar
get_image(idx)[source]

Get the image at the given index as a NumPy array.

get_label(idx)[source]

Get the label at the given index.

get_index(idx)[source]

Get the index of the item.

get_object_id(idx)[source]

Get the object ID for the item.

class HyraxCSVDataset(config: dict, data_location: pathlib.Path = None)[source]

Bases: hyrax.data_sets.data_set_registry.HyraxDataset

A Hyrax Dataset for CSV files. This class reads a CSV file using pandas with memory mapping enabled. It dynamically creates getter methods for each column in the CSV file, allowing users to request data from specific columns.

Note: Column names found in the CSV file are used to create the getter methods. If a column name contains characters that are invalid for method names, those characters are replaced with underscores.
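The column-to-getter mechanism can be sketched with the standard library. The real class reads the file with pandas and memory mapping; this stand-in uses the csv module so it is self-contained. The "get_" prefix, the names getter_name and CsvGettersSketch, and the exact replacement rule (every non-word character becomes an underscore) are assumptions.

```python
import csv
import io
import re


def getter_name(column: str) -> str:
    """Hypothetical naming rule: 'get_' prefix plus the column name, with each
    character that is not valid in a Python identifier replaced by '_'."""
    return "get_" + re.sub(r"\W", "_", column)


class CsvGettersSketch:
    """Stand-in for the dynamic getters; uses csv instead of pandas."""

    def __init__(self, text: str):
        reader = csv.DictReader(io.StringIO(text))
        self.rows = list(reader)  # one dict per record; values stay strings
        for column in reader.fieldnames or []:
            # Bind the column through a default argument so each getter reads
            # its own column rather than the last one in the loop.
            setattr(self, getter_name(column), lambda idx, _col=column: self.rows[idx][_col])
```

For example, a column named "flux (g-band)" would be reachable as get_flux__g_band_(idx) under these assumptions.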

Example model_inputs configuration:

{
    "train": {
        "data": {
            "dataset_class": "HyraxCSVDataset",
            "data_location": </path/to/data.csv>,
            "fields": ["<column1>", "<column2>", …],
            "primary_id_field": <column name that contains a unique ID>,
        },
    },
    "validate": { <similar to above> },
    "infer": { <similar to above> },
}

__init__()[source]

Overall initialization for all DataSets which saves the config

data_location = None
column_names
mem_mapped_csv = None
__getitem__(idx)[source]

Currently required by Hyrax machinery, but likely to be phased out.

__len__() int[source]

Return the number of records in the CSV.

sample_data()[source]

Return the first record, in dictionary form, as the sample.

classmethod is_map() bool[source]

Boilerplate method to indicate this is a map-style dataset.