hyrax.datasets
Hyrax has several built-in datasets that you can use for astronomical data. For many uses, these datasets can be configured out of the box for a given project.
FitsImageDataset is a generic container for FITS image cutout data indexed by a user-provided catalog file. It attempts to cover common usage paradigms, such as multiple images of the same object differentiated by telescope filter; however, for advanced usage, extending the class as a custom dataset may be a better fit.
LSSTDataset is an alpha-quality container for LSST cutout images. It is currently limited to deep_coadd images and can only run in a Rubin Observatory RSP environment where the LSST Pipeline tools and a data butler with the appropriate images are available.
DownloadedLSSTDataset is a subclass of LSSTDataset that generates cutouts from the butler and saves them as .pt files on first access. On subsequent access, it loads the cutouts directly from these files, which can significantly speed up data loading. It inherits from LSSTDataset to access the data butler and catalog functionality.
HSCDataset works similarly to FitsImageDataset, but is specialized to Hyper Suprime-Cam (HSC) cutout images downloaded with the hyrax download verb. It contains additional integrity checks and is tightly integrated with the download and rebuild_manifest verbs. In the future, this class and the downloader may become a separate package.
HyraxCifarDataset gives access to the standard CIFAR-10 labeled image dataset, automatically downloading it if it is not present. This dataset is useful for testing Hyrax and occasionally individual models, but it is not an astronomical dataset.
HyraxRandomDataset is a utility dataset that generates random data with a specific shape. This dataset makes it easy to test new models with simple random data. It is highly configurable, so it can simulate input data for models that are still under development.
Each of these datasets can be used as a starting point for a custom dataset by inheriting from it (e.g. from FitsImageDataset), or you can write an entirely custom dataset by following the dataset class reference and/or the dataset class example notebook.
The remaining classes in this module exist primarily for Hyrax interface purposes:
InferenceDataset is a dataset class that represents an infer or umap result; it may be returned from those verbs to provide access to the data.
HyraxDataset is a base class for all datasets in Hyrax and must be within the inheritance hierarchy of all custom datasets. It is not usable on its own, but provides various fall-back functionality to make custom datasets easier to write. See the dataset class reference and example notebook for more information.
Submodules
- hyrax.datasets.data_cache
- hyrax.datasets.data_provider
- hyrax.datasets.dataset_registry
- hyrax.datasets.downloaded_lsst_dataset
- hyrax.datasets.fits_image_dataset
- hyrax.datasets.hsc_dataset
- hyrax.datasets.hyrax_cifar_dataset
- hyrax.datasets.hyrax_csv_dataset
- hyrax.datasets.inference_dataset
- hyrax.datasets.lsst_dataset
- hyrax.datasets.mmu_dataset
- hyrax.datasets.nested_pandas_dataset
- hyrax.datasets.random
- hyrax.datasets.result_dataset
- hyrax.datasets.result_factories
Classes
- FitsImageDataset: Dataset for FITS images, typically cutouts.
- LSSTDataset: A dataset to access deep_coadd images from LSST pipelines.
- DownloadedLSSTDataset: A dataset that inherits from LSSTDataset and downloads cutouts from the LSST butler.
- HSCDataset: Dataset for sets of HSC cutouts created by the fibad download command.
- HyraxCifarDataset: Map-style CIFAR-10 dataset for Hyrax.
- HyraxRandomDataset: This dataset is a stand-in for a map-style dataset.
- HyraxRandomDatasetBase: This is the base class for the random datasets provided by Hyrax.
- InferenceDataset: This is a dataset class to represent situations where we wish to treat the output of inference as a dataset.
- ResultDataset: Reader for Lance-based inference results.
- Writer for Lance-based inference results.
- How to make a hyrax dataset:
- A Hyrax Dataset for CSV files.
- Load a MultimodalUniverse dataset through Hugging Face
- A minimal Hyrax wrapper around
- DataCache tracks and manages a caching layer which can be used most effectively if the entirety of a
Functions
- Create a writer for results (Lance format).
- Load a results dataset, auto-detecting format.
Package Contents
- class FitsImageDataset(config: dict, data_location=None)
Bases: hyrax.datasets.dataset_registry.HyraxDataset, hyrax.datasets.dataset_registry.HyraxImageDataset, torch.utils.data.Dataset
Dataset for FITS images, typically cutouts.
Initialize a FitsImageDataset.
Most work is done in _init_from_path and the functions it calls, in order to allow subclasses to override behavior.
- Parameters:
config (dict) – Nested configuration dictionary for hyrax
data_location (Optional[Union[Path, str]]) – The directory location of the data that this dataset class will access
- _called_from_test = False
- _config
- data_location = None
- object_id_column_name
- filter_column_name
- filename_column_name
- _init_from_path(path: pathlib.Path | str)
__init__ helper. Initialize an HSC dataset from a path. This involves several filesystem scan operations and ultimately opens and reads the header info of every FITS file in the given directory.
- Parameters:
path (Union[Path, str]) – Path or string specifying the directory path that is the root of all filenames in the catalog table
- _set_crop_transform()
Returns the crop transform on the image.
If overridden, the subclass must: 1) set self.cutout_shape to a tuple of ints representing the size of the cutouts that will be returned, at some point in the init flow; and 2) update the crop transform using self.set_crop_transform() from the HyraxImageDataset mixin.
- _parse_filter_catalog(table) → None
Sets self.files by parsing the catalog.
Subclasses may override this function to control parsing of the table more directly, but the overriding class must create the files dict which has type dict[object_id -> dict[filter -> filename]] with object_id, filter, and filename all strings. In the case of no filter distinction, a single flag value may be used for the filter dict keys in the inner dicts.
- Parameters:
table (Table) – The catalog we read in
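As a minimal sketch of the files contract described above, assume the catalog arrives as an iterable of row mappings (for example, rows of an astropy Table). The function name build_files_dict, its default column names, and the NO_FILTER flag are hypothetical; FitsImageDataset takes its real column names from object_id_column_name, filter_column_name, and filename_column_name.

```python
from collections.abc import Iterable, Mapping
from typing import Optional

# Hypothetical flag used as the inner key when the catalog has no filter column.
NO_FILTER = "_"


def build_files_dict(
    rows: Iterable[Mapping[str, object]],
    object_id_column: str = "object_id",
    filter_column: Optional[str] = "filter",
    filename_column: str = "filename",
) -> dict[str, dict[str, str]]:
    """Group catalog rows into {object_id: {filter: filename}}, all as strings."""
    files: dict[str, dict[str, str]] = {}
    for row in rows:
        object_id = str(row[object_id_column])
        # Without a filter column, every file for an object shares one flag key.
        filter_name = str(row[filter_column]) if filter_column else NO_FILTER
        files.setdefault(object_id, {})[filter_name] = str(row[filename_column])
    return files


rows = [
    {"object_id": "1", "filter": "g", "filename": "1_g.fits"},
    {"object_id": "1", "filter": "r", "filename": "1_r.fits"},
    {"object_id": "2", "filter": "g", "filename": "2_g.fits"},
]
files = build_files_dict(rows)
```

Building the nested dict with setdefault keeps the parse to a single pass over the catalog.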
- shape() → tuple[int, int, int]
Shape of the individual cutouts this dataset will give to a model.
- Returns:
Tuple describing the dimensions of the 3-dimensional tensor handed back to models. The first index is the number of filters, the second index is the width of each image, and the third index is the height of each image.
- Return type:
tuple[int,int,int]
- __len__() → int
Returns the number of objects in this loader.
- Returns:
Number of objects in this data loader
- Return type:
int
- get_object_id(idx: int) → str
Get the object ID at the given index
- Parameters:
idx (int) – Index of the object ID to return
- Returns:
The object ID at the given index
- Return type:
str
- get_image(idx: int)
Get the image at the given index as a PyTorch Tensor.
- Parameters:
idx (int) – Index of the image to return
- Returns:
The image at the given index as a PyTorch Tensor.
- Return type:
torch.Tensor
- __contains__(object_id: str) → bool
Allows queries of the form "object_id in dataset". Used by testing code.
- Parameters:
object_id (str) – The object ID to look for in the dataset
- Returns:
True if the given object_id is in the dataset
- Return type:
bool
- _get_file(index: int) → pathlib.Path
Private indexing method across all files.
Returns the file path corresponding to the given index.
The index is zero-based and follows the same total order as the _all_files() and _object_files() iterators. Useful if you have an np.array() or list built from _all_files() and need to select an individual item.
Only valid after self.object_ids, self.files, self.path, and self.num_filters have been initialized in __init__
- Parameters:
index (int) – Index, see above for order semantics
- Returns:
The path to the file
- Return type:
Path
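The index semantics of _get_file can be sketched as follows, assuming every object has exactly num_filters files and that the total order is object-major, then filter order within each object. The free functions all_files and get_file are hypothetical stand-ins for the private methods, not their actual implementation.

```python
from collections.abc import Iterator
from pathlib import Path


def all_files(object_ids: list[str], files: dict[str, dict[str, str]], root: Path) -> Iterator[Path]:
    """Yield file paths object by object, then filter by filter within each object."""
    for object_id in object_ids:
        for filename in files[object_id].values():
            yield root / filename


def get_file(index: int, object_ids: list[str], files: dict[str, dict[str, str]],
             root: Path, num_filters: int) -> Path:
    """Return the path at a flat zero-based index, in the same order as all_files()."""
    # Object-major order: consecutive blocks of num_filters files belong to one object.
    object_index, filter_index = divmod(index, num_filters)
    filenames = list(files[object_ids[object_index]].values())
    return root / filenames[filter_index]


object_ids = ["a", "b"]
files = {"a": {"g": "a_g.fits", "r": "a_r.fits"}, "b": {"g": "b_g.fits", "r": "b_r.fits"}}
root = Path("/data")
paths = list(all_files(object_ids, files, root))
```

Because both functions walk the same order, get_file(i, ...) always equals the i-th item produced by all_files.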
- _all_ids(log_every=None) → collections.abc.Generator[str]
Private read-only iterator over all object_ids that enforces a strict total order across objects. Will not work before self.files is initialized in __init__.
- Yields:
Iterator[str] – Object IDs currently in the dataset
- _all_files()
Private read-only iterator over all files that enforces a strict total order across objects and filters. Will not work before self.files and self.path are initialized in __init__.
- Yields:
Path – The path to the file.
- _filter_filename(object_id)
Private read-only iterator over all files for a given object. This enforces a strict total order across filters. Will not work before self.files is initialized in __init__.
- Yields:
filter_name, file name – The name of a filter and the file name for the FITS file. The file name is relative to self.path.
- _object_files(object_id)
Private read-only iterator over all files for a given object. This enforces a strict total order across filters. Will not work before self.files and self.path are initialized in __init__.
- Yields:
Path – The path to the file.
- _file_to_path(filename: str) → pathlib.Path
Turns a filename into a full path suitable for open(). Equivalent to:
Path(self.path) / Path(filename)
- Parameters:
filename (str) – The filename string
- Returns:
A full path that is openable.
- Return type:
Path
- _object_id_to_tensor(object_id: str)
Converts an object_id to a PyTorch tensor with dimensions (self.num_filters, self.cutout_shape[0], self.cutout_shape[1]). This is done by reading the file and slicing away any excess pixels at the far corners of the image from (0,0).
The current implementation reads the files once the first time they are accessed, and then keeps them in a dict for future accesses.
- Parameters:
object_id (str) – The object_id requested
- Returns:
A tensor with dimension (self.num_filters, self.cutout_shape[0], self.cutout_shape[1])
- Return type:
torch.Tensor
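The read-once, crop-from-(0,0) behavior described above can be illustrated with plain Python lists standing in for FITS data and tensors. CroppingCache, its method name, and the read_images callable are hypothetical; the real method reads FITS files and returns a torch.Tensor.

```python
from typing import Callable

Image = list[list[float]]


class CroppingCache:
    """Hypothetical stand-in: read each object's images once, crop them, and cache the result."""

    def __init__(self, read_images: Callable[[str], list[Image]], cutout_shape: tuple[int, int]):
        self.read_images = read_images  # returns one 2D image per filter
        self.cutout_shape = cutout_shape
        self._cache: dict[str, list[Image]] = {}

    def object_id_to_array(self, object_id: str) -> list[Image]:
        if object_id not in self._cache:
            n_rows, n_cols = self.cutout_shape
            # Keep the corner at (0, 0) and slice away excess pixels at the far edges.
            self._cache[object_id] = [
                [row[:n_cols] for row in image[:n_rows]]
                for image in self.read_images(object_id)
            ]
        return self._cache[object_id]


reads = []


def fake_reader(object_id: str) -> list[Image]:
    reads.append(object_id)
    grid = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    return [grid, grid]  # two "filters"


cache = CroppingCache(fake_reader, (2, 2))
first = cache.object_id_to_array("obj1")
second = cache.object_id_to_array("obj1")
```

The second call returns the cached object without calling the reader again, which is the trade-off of memory for I/O described above.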
- class LSSTDataset(config, data_location=None)
Bases: hyrax.datasets.dataset_registry.HyraxDataset, hyrax.datasets.dataset_registry.HyraxImageDataset, torch.utils.data.Dataset
LSSTDataset: A dataset to access deep_coadd images from LSST pipelines via the butler. Must be run in an RSP.
Initialize the dataset with either a HATS catalog or an astropy table.
Config can specify either:
- config["data_set"]["hats_catalog"]: path to a HATS catalog
- config["data_set"]["astropy_table"]: path to any file readable by Astropy Table
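The either/or catalog configuration can be sketched as a small dispatch on the data_set table. The function name choose_catalog_source is hypothetical, and the precedence (HATS first when both keys are set) is an assumption this page does not document.

```python
def choose_catalog_source(data_set_config: dict) -> tuple[str, str]:
    """Return (kind, path) for the catalog named under config["data_set"]."""
    # Assumption: HATS takes precedence if both keys are set; unset keys may be missing or falsy.
    if data_set_config.get("hats_catalog"):
        return "hats", data_set_config["hats_catalog"]
    if data_set_config.get("astropy_table"):
        return "astropy", data_set_config["astropy_table"]
    raise ValueError(
        'Set config["data_set"]["hats_catalog"] or config["data_set"]["astropy_table"]'
    )


config = {"data_set": {"astropy_table": "catalog.fits"}}
kind, path = choose_catalog_source(config["data_set"])
```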
- BANDS = ['u', 'g', 'r', 'i', 'z', 'y']
- object_id_autodetect_names = ['object_id', 'objectId']
- catalog
- sh_deg
- sw_deg
- oid_column_name
- _get_butler_thread_safe()
Thread-safe butler creation.
This function ensures that there is one and only one butler created per thread and that threads always use their assigned butler.
This is necessary because child classes of this one use butlers, and butler objects are not safe for multithreaded access.
- Returns:
The butler assigned to the current thread.
- Return type:
butler
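The one-butler-per-thread rule above can be sketched with threading.local as a generic pattern, assuming building the resource is safe to do once per thread. PerThreadResource and make_fake_butler are hypothetical; in real code the factory would construct an LSST data butler, which this sketch does not do.

```python
import threading
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class PerThreadResource(Generic[T]):
    """Create at most one resource per thread; each thread always reuses its own."""

    def __init__(self, factory: Callable[[], T]):
        self._factory = factory
        # Attributes set on a threading.local are visible only to the thread that set them.
        self._local = threading.local()

    def get(self) -> T:
        if not hasattr(self._local, "resource"):
            self._local.resource = self._factory()
        return self._local.resource


created = []


def make_fake_butler() -> object:
    """Hypothetical stand-in for constructing a real data butler."""
    created.append(1)
    return object()


butlers = PerThreadResource(make_fake_butler)
main_butler = butlers.get()
```

Repeated calls from one thread return the same object, while a new thread triggers exactly one new factory call.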
- _load_catalog(data_set_config)
Load the catalog from either a HATS catalog or an astropy table.
- _load_astropy_catalog(table_path)
Load the catalog from astropy table format or a pickled astropy table.
- get_image(idxs)
Get image cutouts for the given indices.
- Parameters:
idxs (int or list of int) – The index or indices of the cutouts to retrieve.
- Returns:
Single cutout tensor or list of cutout tensors.
- Return type:
list or torch.Tensor
- __getitem__(idxs)
Get default data fields for this dataset.
- Parameters:
idxs (int or list of int) – The index or indices of the cutouts to retrieve.
- Returns:
A dictionary containing the default data fields.
- Return type:
dict
- _parse_box(patch, row)
Return a Box2I representing the desired cutout in pixel space, given a “row” of catalog data which includes the semi-height (sh) and semi-width (sw) in degrees desired for the cutout.
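A stripped-down version of the degrees-to-pixels arithmetic behind a method like _parse_box might look as follows. It assumes the cutout centre has already been converted to pixel coordinates and that the pixel scale is known and uniform; the function name cutout_box and the pixel_scale_arcsec parameter are hypothetical, and the real method returns an LSST Box2I rather than a tuple.

```python
import math


def cutout_box(x_center: float, y_center: float, sh_deg: float, sw_deg: float,
               pixel_scale_arcsec: float) -> tuple[int, int, int, int]:
    """Return (x_min, y_min, width, height) of a pixel box centred on (x_center, y_center)."""
    # Convert semi-height/semi-width from degrees to arcseconds, then to whole pixels.
    half_h = round(sh_deg * 3600.0 / pixel_scale_arcsec)
    half_w = round(sw_deg * 3600.0 / pixel_scale_arcsec)
    x_min = math.floor(x_center) - half_w
    y_min = math.floor(y_center) - half_h
    # The box spans the centre pixel plus half_w / half_h pixels on each side.
    return x_min, y_min, 2 * half_w + 1, 2 * half_h + 1


box = cutout_box(100.0, 200.0, 10 / 3600, 5 / 3600, 0.2)
```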
- _parse_sphere_point(row)
Return a SpherePoint with the RA and Dec given in the "row" of catalog data. The row must include the RA and Dec as "ra" and "dec" columns, respectively.
- _get_tract_patch(row)
Return (tractInfo, patchInfo) for a given row.
In the case of overlap, this function only returns the single principal tract and patch.
- class DownloadedLSSTDataset(config, data_location)
Bases: hyrax.datasets.lsst_dataset.LSSTDataset
DownloadedLSSTDataset: A dataset that inherits from LSSTDataset and downloads cutouts from the LSST butler, saving them as .pt files on first access. On subsequent accesses, it loads cutouts directly from these cached files.
This class also creates a manifest file with the shape of each cutout and the corresponding filename.
- Public Methods:
- download_cutouts(indices=None, sync_filesystem=True, max_workers=None, force_retry=False):
Download cutouts with parallel processing. Automatically resumes from previous progress. Use max_workers to control thread count, force_retry to re-attempt failed downloads.
- manifest_stats():
Returns dict with download statistics: total, successful, failed, pending counts and manifest file path.
- download_progress():
Returns detailed progress metrics including completion percentage and failure rates.
- reset_failed_downloads():
Resets all failed download attempts to allow retry without force_retry flag. Returns count of reset entries.
- save_manifest_now():
Forces immediate manifest save (normally saved periodically during downloads).
- cache_info():
Returns LRU cache statistics for patch fetching performance monitoring.
- clear_cache():
Clears the patch LRU cache to free memory.
- Usage Example:

    import hyrax

    # Initialize Hyrax
    h = hyrax.Hyrax()
    a = h.prepare()

    # Download all cutouts (resumes automatically)
    a.download_cutouts(max_workers=4)
    # WARNING: The LRU caching scheme is slightly complicated, so it is
    # recommended to use the default max_workers=1 for the first download.
    # Simply using more workers may not always speed up the download process.

    # Check progress
    a.download_progress()

    # Retry failed downloads
    a.download_cutouts(force_retry=True)

    # Access cutouts (loads from cache)
    cutout = a[0]      # Single cutout
    cutouts = a[0:10]  # Multiple cutouts

File Organization:
- Cutouts saved as: cutout_{object_id}.pt or cutout_{index:04d}.pt
- Manifest saved as: manifest.fits (Astropy) or manifest.parquet (HATS)
- All files stored in the data_location provided during initialization
Initialize the dataset with either a HATS catalog or an astropy table.
Config can specify either:
- config["data_set"]["hats_catalog"]: path to a HATS catalog
- config["data_set"]["astropy_table"]: path to any file readable by Astropy Table
- download_dir
- catalog_object_ids
- _manifest_lock
- _updates_since_save = 0
- _save_interval = 1000
- _band_failure_stats
- _band_failure_lock
- _manifest_filter_object_ids = None
- _catalog_to_manifest_index_map = None
- _manifest_to_catalog_index_map = None
- _initialize_manifest()
Create a new manifest or load and merge with an existing manifest, with band filtering validation.
The manifest is always an astropy Table with at least the following columns:
- cutout_shape: np.array of dimensions, e.g. [3,150,150]
- filename: string containing the name of the file that holds the tensor for the downloaded object
- downloaded_bands: string containing a comma-separated list of the bands downloaded. The order is expected to be consistent between rows.
When this astropy table is loaded into memory, multiple sources are consulted:
- The manifest on the filesystem, which is the source of truth for which files have been downloaded. If it is not found, it is created.
- The bands given in the catalog passed in
- _update_manifest_from_catalog(existing_manifest)
Using object_id as a unique key, adds manifest entries to existing_manifest, using self.catalog as the source of any new objects.
self.catalog is not altered by this operation.
Entries in existing_manifest are not altered by this operation. New entries are added to the end of existing_manifest with a state indicating they have not been downloaded.
- _build_catalog_to_manifest_index_map()
Build efficient mapping from catalog indices to manifest indices.
- _add_manifest_columns_to_table(table)
Add cutout_shape, filename, and downloaded_bands columns to manifest.
- _get_available_bands_from_manifest(manifest)
Get available bands by finding entries with complete band coverage.
Uses cutout_shape[0] to determine the expected number of bands, then finds entries where downloaded_bands has that many entries (i.e., complete downloads).
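The complete-coverage rule can be sketched with the manifest modeled as a list of dicts rather than an astropy Table. Because band order is expected to be consistent between rows, this sketch simply returns the band list of the first complete entry; the helper names band_list and available_bands are hypothetical.

```python
from typing import Optional, Sequence


def band_list(entry: dict) -> list[str]:
    """Split a manifest entry's comma-separated downloaded_bands string."""
    return [band for band in (entry.get("downloaded_bands") or "").split(",") if band]


def available_bands(manifest: Sequence[dict]) -> Optional[list[str]]:
    """Return the band list of the first entry whose download covers every band."""
    for entry in manifest:
        shape = entry.get("cutout_shape")
        if shape is None:
            continue  # never downloaded, or every band failed
        # cutout_shape[0] is the number of bands a complete download should have.
        if len(band_list(entry)) == shape[0]:
            return band_list(entry)
    return None


manifest = [
    {"cutout_shape": None, "downloaded_bands": ""},
    {"cutout_shape": (3, 64, 64), "downloaded_bands": "g,r"},  # one band failed
    {"cutout_shape": (3, 64, 64), "downloaded_bands": "g,r,i"},
]
```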
- _setup_band_filtering(requested_bands, original_band_order)
Set up band filtering to extract only the requested bands from cached cutouts.
- _get_cutout_path_from_idx(idx)
Generate cutout file path for a given index.
This simply applies a pattern to the filename using the object_id column. No guarantees are made about the file itself.
- _get_cutout_path_from_manifest(idx)
Get the cutout path by consulting the manifest.
The download thread ensures that the filename is not written to the manifest until all the bands that we intend to download are downloaded.
This function is intended to be a thread-safe way to get valid cutout paths. If the file exists and is believed to be correctly downloaded, you get a path; if there is some other issue, this returns None.
- Parameters:
idx (int) – The catalog index of the relevant cutout
- Returns:
Path to the cutout, or None
- Return type:
Path | None
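The two path lookups above can be sketched as a pure naming rule plus a manifest check. This assumes the object_id form of the name is used whenever an object ID is available and the index form otherwise, and that the manifest stores the file's basename (as _update_manifest_entry documents); both function names are hypothetical.

```python
from pathlib import Path
from typing import Optional

ATTEMPTED = "Attempted"  # manifest filename marker used when every band failed


def cutout_path_from_idx(download_dir: Path, index: int, object_id: Optional[str] = None) -> Path:
    """Apply the cutout naming pattern; nothing is checked on disk."""
    if object_id is not None:
        return download_dir / f"cutout_{object_id}.pt"
    return download_dir / f"cutout_{index:04d}.pt"


def cutout_path_from_manifest(download_dir: Path, filename: Optional[str]) -> Optional[Path]:
    """Return a path only when the manifest records a real file that exists."""
    if not filename or filename == ATTEMPTED:
        return None
    path = download_dir / filename
    return path if path.exists() else None
```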
- _update_manifest_entry(idx, cutout_shape=None, filename='Attempted', downloaded_bands=None)
Thread-safe manifest update with periodic saves.
- Parameters:
idx – Index in the manifest
cutout_shape – Shape tuple of the cutout tensor, or None for failed downloads
filename – Basename of the saved file, or “Attempted” only when ALL bands fail
downloaded_bands – List of band names successfully downloaded in tensor order
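The update rule documented above (take a lock, write the entry, save every so often) can be sketched as follows. The save interval mirrors the _save_interval attribute listed above, but the class, its method names, and the choice to save while holding the lock are assumptions, and the manifest is a list of dicts rather than an astropy Table.

```python
import threading
from typing import Callable, Optional, Sequence


class ManifestUpdater:
    """Hypothetical sketch: lock-protected manifest updates with a save every N updates."""

    def __init__(self, manifest: list[dict], save: Callable[[list[dict]], None],
                 save_interval: int = 1000):
        self.manifest = manifest
        self._save = save
        self._save_interval = save_interval
        self._updates_since_save = 0
        self._lock = threading.Lock()

    def update_entry(self, idx: int, cutout_shape: Optional[tuple] = None,
                     filename: str = "Attempted",
                     downloaded_bands: Optional[Sequence[str]] = None) -> None:
        with self._lock:  # one writer at a time
            entry = self.manifest[idx]
            entry["cutout_shape"] = cutout_shape
            entry["filename"] = filename
            entry["downloaded_bands"] = ",".join(downloaded_bands or [])
            self._updates_since_save += 1
            if self._updates_since_save >= self._save_interval:
                self._save(self.manifest)
                self._updates_since_save = 0

    def save_now(self) -> None:
        with self._lock:
            self._save(self.manifest)
            self._updates_since_save = 0


saves = []
manifest = [{} for _ in range(100)]
updater = ManifestUpdater(manifest, lambda m: saves.append(len(m)), save_interval=10)
threads = [
    threading.Thread(target=updater.update_entry,
                     args=(i, (3, 8, 8), f"cutout_{i:04d}.pt", ["g", "r", "i"]))
    for i in range(25)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Calling update_entry with only an index records a failed download, matching the "Attempted" default documented above.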
- _sync_manifest_with_filesystem()
Sync the manifest with the actual downloaded files on disk.
This updates the manifest to reflect what is on the filesystem. For existing cutouts, this loads every file using torch.load.
- static _request_patch_cached(tract_index, patch_index, butler, skymap_name, bands_tuple)
Cached patch fetching using a static method.
Being a static method means there is no 'self' in the cache key, which makes the cache truly global. Thread-safe because each call creates its own Butler instance.
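Why a static method helps can be seen in a small functools.lru_cache sketch: because the cached function has no self argument, its cache key is just the call arguments, so the cache is shared by every caller. The bands argument must be a tuple because lru_cache requires hashable arguments. PatchFetcher, its fetch counter, and its return value are hypothetical placeholders; the real method also takes a butler and skymap name and returns LSST image data.

```python
from functools import lru_cache


class PatchFetcher:
    fetch_count = 0  # counts real (uncached) fetches, for illustration

    @staticmethod
    @lru_cache(maxsize=16)
    def request_patch_cached(tract_index: int, patch_index: int, bands_tuple: tuple) -> dict:
        # No `self` in the signature, so the cache is shared by every caller.
        PatchFetcher.fetch_count += 1
        return {band: f"tract {tract_index} patch {patch_index} band {band}" for band in bands_tuple}


first = PatchFetcher.request_patch_cached(9813, 40, ("g", "r", "i"))
again = PatchFetcher.request_patch_cached(9813, 40, ("g", "r", "i"))
```

functools exposes cache_info() and cache_clear() on the cached function, which is presumably what the public cache_info() and clear_cache() methods listed above report on and reset.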
- _fetch_single_cutout(row, idx=None, manifest_idx=None)
Fetch cutout, using saved cutout if available, with optional band filtering.
- _fetch_cutout_with_cache(row)
Generate cutout using cached patch fetching with NaN filling for failed bands.
- _get_manifest_index_for_catalog_index(catalog_idx)
Map catalog index to manifest index. None return indicates no such item in manifest.
- get_image(idxs)
Fetch image cutout(s) for the given index or indices, using caching and band filtering.
Parameters:
- idxs: int or slice or list
Index or indices to fetch.
Returns:
- torch.Tensor or list of torch.Tensor:
Single cutout tensor or list of cutout tensors.
- __getitem__(idxs) → dict
Modified to pass the index for saving cutouts.
Parameters:
- idxs: int or slice or list
Index or indices to fetch.
Returns:
- dict:
Dictionary with key 'data' containing another dict of default data fields to return. Currently only 'image' is supported.
- download_cutouts(indices=None, sync_filesystem=True, max_workers=None, force_retry=False)
Download cutouts using multiple threads with caching.
- Parameters:
indices – List of indices to download, or None for all
sync_filesystem – Whether to sync manifest with existing files on disk
max_workers – Maximum number of worker threads, or None to use default
force_retry – Whether to retry previously failed downloads
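The "resumes automatically" and force_retry behavior can be sketched as a selection step over the manifest. This assumes a pending entry has an empty or missing filename, a fully failed one has "Attempted", and any other filename means the cutout was saved; the function name indices_to_download is hypothetical.

```python
from typing import Optional, Sequence

ATTEMPTED = "Attempted"  # filename recorded when every band of a download failed


def indices_to_download(manifest: Sequence[dict], indices: Optional[Sequence[int]] = None,
                        force_retry: bool = False) -> list[int]:
    """Pick manifest entries that still need downloading, so reruns resume where they left off."""
    candidates = range(len(manifest)) if indices is None else indices
    selected = []
    for idx in candidates:
        filename = manifest[idx].get("filename")
        if not filename:
            selected.append(idx)  # pending: never attempted
        elif filename == ATTEMPTED and force_retry:
            selected.append(idx)  # failed before; only retried on request
        # Any other filename means the cutout was saved, so it is skipped.
    return selected


manifest = [
    {"filename": "cutout_0000.pt"},
    {"filename": ATTEMPTED},
    {"filename": None},
    {"filename": ""},
]
```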
- class HSCDataset(config: dict, data_location=None)
Bases: hyrax.datasets.fits_image_dataset.FitsImageDataset
Dataset for sets of HSC cutouts created by the fibad download command.
- _called_from_test = False
- filters_config
- _parse_filter_catalog(table) → None
Sets self.files by parsing the catalog.
Subclasses may override this function to control parsing of the table more directly, but the overriding class must create the files dict which has type dict[object_id -> dict[filter -> filename]] with object_id, filter, and filename all strings. In the case of no filter distinction, a single flag value may be used for the filter dict keys in the inner dicts.
- Parameters:
table (Table) – The catalog we read in
- _set_crop_transform()
Returns the crop transform on the image.
If overridden, the subclass must: 1) set self.cutout_shape to a tuple of ints representing the size of the cutouts that will be returned, at some point in the init flow; and 2) update the crop transform using self.set_crop_transform() from the HyraxImageDataset mixin.
- _scan_file_names(filters: list[str] | None, filter_obj_ids: list[str] | None = None) → hyrax.datasets.fits_image_dataset.files_dict
Class initialization helper.
- Parameters:
filters (list[str], optional) – List of filters that we should look for in the data corpus
filter_obj_ids (list[str], optional) – Filter the file scan to only file names which have the provided object IDs, skipping other files. When not provided, all file names in the configured data directory that match the pattern from hyrax download are parsed.
- Returns:
Nested dictionary where the first level maps object_id -> dict, and the second level maps filter_name -> file name. Corresponds to self.files
- Return type:
dict[str,dict[str,str]]
- static _scan_file_dimension(processing_unit: tuple[str, list[str]]) → tuple[str, list[tuple[int, int]]]
- _prune_objects(filters_ref: list[str], cutout_shape: tuple[int, int] | None)
Class initialization helper. Prunes objects from the list of objects.
Removes any objects which do not have all the filters specified in filters_ref.
If a cutout_shape was provided in the constructor, prunes files that are too small for the chosen cutout size.
This function deletes from self.files and self.dims via _prune_object.
- Parameters:
files (dict[str,dict[str,str]]) – Nested dictionary where the first level maps object_id -> dict, and the second level maps filter_name -> file name. This is created by _scan_files()
filters_ref (list[str]) – List of the filter names
cutout_shape (tuple[int, int]) – Cutout shape tuple provided from constructor
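The two pruning rules can be sketched as below, assuming dims maps each object_id to a list of per-file (dim0, dim1) pixel sizes, which is the shape _scan_file_dimension returns. This sketch drops the whole object when any of its files is too small; prune_objects is a hypothetical free-function stand-in for the private method.

```python
from typing import Optional


def prune_objects(files: dict[str, dict[str, str]], dims: dict[str, list[tuple[int, int]]],
                  filters_ref: list[str], cutout_shape: Optional[tuple[int, int]] = None) -> list[str]:
    """Drop objects missing any reference filter or with any file smaller than cutout_shape."""
    pruned = []
    for object_id in list(files):  # copy the keys so we can delete while iterating
        missing_filter = any(f not in files[object_id] for f in filters_ref)
        too_small = cutout_shape is not None and any(
            d0 < cutout_shape[0] or d1 < cutout_shape[1] for d0, d1 in dims.get(object_id, [])
        )
        if missing_filter or too_small:
            del files[object_id]
            dims.pop(object_id, None)
            pruned.append(object_id)
    return pruned


files = {"a": {"g": "a_g.fits", "r": "a_r.fits"}, "b": {"g": "b_g.fits"},
         "c": {"g": "c_g.fits", "r": "c_r.fits"}}
dims = {"a": [(64, 64), (64, 64)], "b": [(64, 64)], "c": [(60, 64), (64, 64)]}
pruned = prune_objects(files, dims, ["g", "r"], cutout_shape=(62, 62))
```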
- _check_file_dimensions() → tuple[int, int]
Class initialization helper. Finds the maximal pixel size that all images can support.
It is assumed that all the cutouts will be of very similar size; however, HSC’s cutout server does not return exactly the same number of pixels for every query, even when it is given the same angular spread for every cutout.
Machine learning models expect all images to be the same size.
This function warns on significant differences (>2px) on any dimension between the largest and smallest images.
- Returns:
The minimum width and height in pixels of the entire dataset. In other words: the maximal image size in pixels that can be generated from ALL cutout images via cropping.
- Return type:
tuple(int,int)
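The size check described above reduces to taking the per-axis minimum over all files and warning when the spread exceeds a tolerance. This sketch assumes the same dims layout as above (object_id to a list of per-file pixel sizes); check_file_dimensions as a free function and its tolerance parameter are hypothetical, with the 2-pixel default taken from the text.

```python
import warnings


def check_file_dimensions(dims: dict[str, list[tuple[int, int]]], tolerance: int = 2) -> tuple[int, int]:
    """Return the largest (dim0, dim1) that every image can be cropped to."""
    all_dims = [d for per_object in dims.values() for d in per_object]
    dim0 = [d[0] for d in all_dims]
    dim1 = [d[1] for d in all_dims]
    # Warn when the largest and smallest images differ by more than `tolerance` pixels.
    if max(dim0) - min(dim0) > tolerance or max(dim1) - min(dim1) > tolerance:
        warnings.warn(f"Cutout sizes vary by more than {tolerance}px; images will be cropped")
    return min(dim0), min(dim1)
```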
- __contains__(object_id: str) → bool
Allows queries of the form "object_id in dataset". Used by testing code.
- Parameters:
object_id (str) – The object ID to look for in the dataset
- Returns:
True if the given object_id is in the dataset
- Return type:
bool
- _all_files_full()
Private read-only iterator over all files that enforces a strict total order across objects and filters. Will not work before self.files and self.path are initialized in __init__.
- Yields:
Tuple[object_id, filter, filename, dim] – Members of this tuple are:
- The object_id as a string
- The filter name as a string
- The filename relative to self.path
- A tuple containing the dimensions of the FITS file in pixels
- _object_files(object_id)
Private read-only iterator over all files for a given object. This enforces a strict total order across filters. Will not work before self.files and self.path are initialized in __init__.
Guaranteed to only return files that have filters in self.filters_ref.
- Yields:
Path – The path to the file.
- class HyraxCifarDataset(config: dict, data_location: pathlib.Path = None)
Bases: hyrax.datasets.dataset_registry.HyraxDataset
Map-style CIFAR-10 dataset for Hyrax.
This utilizes the CIFAR dataset from torchvision for retrieving the dataset.
Overall initialization for all Datasets, which saves the config.
Subclasses of HyraxDataset ought to call this at the end of their __init__, like:

    from hyrax.datasets import HyraxDataset

    class MyDataset(HyraxDataset):
        def __init__(self, config):
            # <your code>
            super().__init__(config)

If per-tensor metadata is available, it is recommended that dataset authors create an astropy Table of that data, in the same order as their data, and pass that metadata_table as shown below:

    from astropy.table import Table
    from hyrax.datasets import HyraxDataset

    class MyDataset(HyraxDataset):
        def __init__(self, config):
            # <your code>
            metadata_table = Table(...)  # <your catalog data goes here>
            super().__init__(config, metadata_table)

- Parameters:
config (dict, Optional) – The runtime configuration for hyrax
metadata_table (Optional[Table], optional) – An Astropy Table containing the metadata columns desired for visualization, with rows in the same order in which your data will be enumerated.
object_id_column_name (Optional[str], optional) – The name of the column containing object IDs. If None, uses the default from config or creates one from the ids() method.
- data_location = None
- training_data
- cifar
- id_width = 0
- class HyraxRandomDataset(config, data_location)
Bases: HyraxRandomDatasetBase, hyrax.datasets.dataset_registry.HyraxDataset, torch.utils.data.Dataset
This dataset is a stand-in for a map-style dataset. It will produce random numpy arrays along with sequential numeric IDs and, optionally, labels randomly selected from the provided list of possible labels.
Initialize the dataset using the parameters defined in the configuration.
The data_location parameter is included for API consistency with other dataset classes, though it is not used by this implementation. All parameters are controlled by the following keys under the ["data_set"]["HyraxRandomDataset"] table in the configuration:
- size: The number of random data samples to produce.
- shape: The shape of each random data sample as a tuple (e.g. (3, 29, 29) = 3 layers of 2D data, each layer 29x29 elements).
- seed: The random seed to use for reproducibility.
- provided_labels: A list of possible labels to randomly select from. If this is provided, the dataset will randomly select a label for each data sample.
- metadata_fields: A list of metadata field names. Used to create a metadata table with columns corresponding to each field name. All data is numeric.
- number_invalid_values: The number of invalid values to insert into the data.
- invalid_value_type: The type of invalid value to insert into the data. Valid values are "nan", "inf", "-inf", "none", or a float value.
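The configuration keys above map naturally onto a seeded generator. The stdlib-only sketch below assumes flattened Python lists instead of numpy arrays, string IDs starting at "0", and a single float invalid value, and it omits metadata_fields; make_random_samples and its parameters are hypothetical and only illustrate how size, shape, seed, provided_labels, and number_invalid_values interact.

```python
import math
import random
from typing import Optional, Sequence


def make_random_samples(size: int, shape: tuple[int, ...], seed: int,
                        provided_labels: Optional[Sequence] = None,
                        number_invalid_values: int = 0,
                        invalid_value: Optional[float] = math.nan) -> list[dict]:
    """Generate `size` samples of flattened random data with sequential IDs."""
    rng = random.Random(seed)  # a dedicated generator keeps results reproducible
    n_values = math.prod(shape)
    samples = []
    for index in range(size):
        sample = {"index": index, "object_id": str(index),
                  "image": [rng.random() for _ in range(n_values)]}
        if provided_labels:
            sample["label"] = rng.choice(list(provided_labels))
        samples.append(sample)
    # Overwrite randomly chosen, distinct (sample, position) pairs with the invalid value.
    positions = [(i, j) for i in range(size) for j in range(n_values)]
    for i, j in rng.sample(positions, number_invalid_values):
        samples[i]["image"][j] = invalid_value
    return samples
```

Using a private random.Random instance rather than the module-level functions means other code drawing random numbers cannot disturb the sequence for a given seed.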
- __getitem__(idx: int) → dict
Get a data sample by index.
The returned dictionary will contain the following keys:
- index: The index of the data sample.
- object_id: The ID of the data sample.
- image: The data sample as a numpy array.
- label: The label of the data sample (if provided).
- Parameters:
idx (int) – The index of the data sample to retrieve.
- Returns:
A dictionary containing the data sample and its metadata.
- Return type:
dict
- class HyraxRandomDatasetBase(config, data_location)
This is the base class for the random datasets provided by Hyrax.
Warning
Direct use of HyraxRandomDatasetBase is not advised. When working with Hyrax, prefer to use HyraxRandomDataset.
Initialize the dataset using the parameters defined in the configuration.
The data_location parameter is included for API consistency with other dataset classes, though it is not used by this implementation. All parameters are controlled by the following keys under the ["data_set"]["HyraxRandomDataset"] table in the configuration:
- size: The number of random data samples to produce.
- shape: The shape of each random data sample as a tuple (e.g. (3, 29, 29) = 3 layers of 2D data, each layer 29x29 elements).
- seed: The random seed to use for reproducibility.
- provided_labels: A list of possible labels to randomly select from. If this is provided, the dataset will randomly select a label for each data sample.
- metadata_fields: A list of metadata field names. Used to create a metadata table with columns corresponding to each field name. All data is numeric.
- number_invalid_values: The number of invalid values to insert into the data.
- invalid_value_type: The type of invalid value to insert into the data. Valid values are "nan", "inf", "-inf", "none", or a float value.
- data: numpy.ndarray
The random data samples produced by the dataset.
- id_list: list
A list of sequential numeric IDs for each data sample.
- provided_labels: list
A list of labels randomly selected from the provided list of possible labels.
- data_location
- class InferenceDataset(config, results_dir: pathlib.Path | str | None = None, verb: str | None = None)
Bases: hyrax.datasets.dataset_registry.HyraxDataset, torch.utils.data.Dataset
This is a dataset class to represent situations where we wish to treat the output of inference as a dataset, e.g. when performing umap/visualization operations.
Initialize an InferenceDataset object.
As a user of this code, you should almost never create this class directly. Instances of this class are returned by the umap and infer verbs; prefer those over creating your own.
If you do end up creating your own instance, you will need a hyrax config and need to know where the result you are interested in is stored.
- Parameters:
config (dict) – The hyrax config dictionary
results_dir (Optional[Union[Path, str]], optional) –
The results subdirectory of the inference or umap results you want to access, by default None. If no results subdirectory is provided, this function will attempt the following in order:
1. Use the directory specified in config['results']['inference_dir'] if it is set and the directory exists.
2. Look in the results directory configured in config['general']['results_dir'] (./results/ by default), then use the most recent results directory corresponding to the verb specified.
verb (Optional[str], optional) – The name of the verb that generated the results, only important when the most recent results are being fetched. If no verb is provided, “infer” will be assumed.
- Raises:
RuntimeError – When the provided results directory is corrupt, or cannot be found.
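The lookup order above can be sketched with pathlib. The config keys come from this page, but the naming convention for run directories (names containing "-<verb>" that sort chronologically) is an assumption, and resolve_results_dir is a hypothetical helper, not the InferenceDataset constructor.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_results_dir(config: dict, results_dir: Optional[Union[Path, str]] = None,
                        verb: Optional[str] = None) -> Path:
    """Pick a results directory following the documented lookup order."""
    if results_dir is not None:
        path = Path(results_dir)
        if not path.is_dir():
            raise RuntimeError(f"Results directory not found: {path}")
        return path
    verb = verb or "infer"
    # 1. An explicit inference_dir wins if it is set and exists.
    inference_dir = config.get("results", {}).get("inference_dir")
    if inference_dir and Path(inference_dir).is_dir():
        return Path(inference_dir)
    # 2. Otherwise take the most recent run of this verb under general.results_dir.
    root = Path(config.get("general", {}).get("results_dir") or "./results/")
    # Assumption: run directories contain "-<verb>" in their name and sort chronologically.
    runs = sorted(p for p in root.glob(f"*-{verb}*") if p.is_dir()) if root.is_dir() else []
    if not runs:
        raise RuntimeError(f"No '{verb}' results found under {root}")
    return runs[-1]
```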
- results_dir
- batch_index
- length
- cached_batch_num: int | None = None
- shape_element
- _original_dataset_config
- original_dataset
- _shape()
The shape of the dataset (discovered from files).
- Returns:
Tuple with the shape of an individual element of the dataset
- Return type:
Tuple
- get_object_id(idx) → str
Returns the ID at a particular index.
IDs are provided by the primary dataset’s primary ID column.
- ids() → list[str]
Returns the IDs of the dataset.
IDs flow from the primary dataset and the primary ID column.
For an InferenceDataset instance, self.ids() is canonically the same as [self.get_object_id(i) for i in range(len(self))].
- _ids() → collections.abc.Generator[str]
IDs of this dataset. Will return a string generator with IDs.
These IDs are the IDs of the dataset used originally to generate this dataset.
- Returns:
Generator that yields the string ids of this dataset
- Return type:
Generator[str]
- Yields:
Generator[str] – Yields the string ids of this dataset
- __getitem__(idx: int | numpy.ndarray)
Implements the [] operator.
- Parameters:
idx (Union[int, np.ndarray]) – Either an index or a numpy array of indexes. These are NOT the ID values of the dataset, but rather a zero-based index starting at the beginning of the inference dataset.
- Returns:
Either the tensor corresponding to a single result, or a tensor with a multiplicity of results if multiple indexes were passed.
- Return type:
torch.tensor
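Here is a minimal sketch of the indexing contract. An int returns one element, and an array of zero-based positions (not IDs) returns the matching elements stacked along the first axis. ToyResults is hypothetical and uses NumPy arrays in place of the torch tensors the real class returns.

```python
import numpy as np


class ToyResults:
    """Hypothetical in-memory stand-in for index-based access to results."""

    def __init__(self, results: np.ndarray):
        self._results = results  # shape: (n_items, *element_shape)

    def __len__(self) -> int:
        return len(self._results)

    def __getitem__(self, idx):
        # idx is a zero-based position in this dataset, not an object ID.
        if isinstance(idx, (int, np.integer)):
            return self._results[idx]          # a single element
        return self._results[np.asarray(idx)]  # elements stacked on axis 0
```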
- __len__() int[source]#
Returns the length of the dataset.
- Returns:
Length of the dataset.
- Return type:
int
- property original_config: dict#
Get the original configuration for the dataset used to generate this inference dataset
Since this sort of dataset is definitionally an intermediate product, this returns the runtime config used to construct that dataset rather than this one.
- Returns:
Configuration that can be used to create the original dataset that was used as input for whatever inference process created this dataset.
- Return type:
dict
- metadata_fields() list[str][source]#
Get the metadata fields associated with the original dataset used to generate this one
- Returns:
List of valid field names for metadata queries
- Return type:
list[str]
- metadata(idxs: numpy.typing.ArrayLike, fields: list[str]) numpy.typing.ArrayLike[source]#
Get metadata associated with the data in the InferenceDataset. This metadata comes from the original dataset, but is indexed according to the InferenceDataset.
- Parameters:
idxs (npt.ArrayLike) – Indexes in the InferenceDataset for which metadata is desired
fields (list[str]) – Metadata fields requested
- Returns:
An array where the rows correspond to the passed list of indexes and the columns correspond to the fields passed. Order is preserved: metadata[i] corresponds to idxs[i].
- Return type:
npt.ArrayLike
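One way to serve original-dataset metadata "indexed according to the InferenceDataset" is to translate indexes through object IDs. This is a hypothetical sketch. The function metadata_via_original is invented here. The sketch assumes that both datasets expose their IDs and that the original metadata is a NumPy structured array in original-dataset order. The source does not say how Hyrax performs this mapping.

```python
import numpy as np


def metadata_via_original(inference_ids, original_ids, original_metadata, idxs, fields):
    """Hypothetical sketch: serve original-dataset metadata by inference index.

    Assumes both datasets expose object IDs and that the original metadata is a
    NumPy structured array whose rows follow the original dataset's order.
    """
    # Map each object ID to its row in the original dataset.
    row_of = {obj_id: row for row, obj_id in enumerate(original_ids)}
    # Translate inference-dataset indexes into original-dataset rows.
    rows = np.array([row_of[inference_ids[i]] for i in np.atleast_1d(idxs)], dtype=int)
    # Output row k corresponds to idxs[k]; only the requested fields are kept.
    return original_metadata[rows][fields]
```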
- class ResultDataset(config: dict, data_location: pathlib.Path | str)[source]#
Bases:
hyrax.datasets.dataset_registry.HyraxDataset
Reader for Lance-based inference results.
Provides HyraxQL-compatible getters to results stored in Lance format.
Initialize the dataset.
- Parameters:
config (dict) – Hyrax configuration dictionary
data_location (Union[Path, str]) – Path to results directory containing lance_db/
- data_location#
- lance_dir#
- db#
- table#
- lance_dataset#
- tensor_shape#
- tensor_dtype#
- __getitem__(idx: int | numpy.ndarray)[source]#
Get data by index.
- Parameters:
idx (Union[int, np.ndarray]) – Single index or array of indices
- Returns:
Data tensor(s)
- Return type:
np.ndarray
- Raises:
IndexError – If index is out of range
- get_data(idx: int)[source]#
Get data tensor at index (HyraxQL getter).
- Parameters:
idx (int) – Index of the data item
- Returns:
Data tensor
- Return type:
np.ndarray
- class ResultDatasetWriter(result_dir: str | pathlib.Path)[source]#
Writer for Lance-based inference results.
Writes inference results incrementally to Lance format using table.add() for each batch, avoiding memory accumulation.
Initialize the writer.
- Parameters:
result_dir (Union[str, Path]) – Directory where Lance database will be created
- result_dir#
- lance_dir#
- db = None#
- table = None#
- schema = None#
- tensor_dtype = None#
- tensor_shape = None#
- batch_count = 0#
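The write-per-batch pattern can be sketched without Lance. In this stand-in, each batch is persisted as soon as it arrives, so memory use is bounded by one batch. IncrementalNpyWriter and write_batch are hypothetical names. The storage is deliberately swapped to one .npy file per batch instead of a Lance table.add() call. The tensor_shape, tensor_dtype, and batch_count fields only mirror the attribute names listed above.

```python
from pathlib import Path

import numpy as np


class IncrementalNpyWriter:
    """Hypothetical stand-in for the write-per-batch pattern (NOT Lance).

    Each write_batch() call persists that batch immediately, so only one
    batch needs to be held in memory at a time.
    """

    def __init__(self, result_dir):
        self.result_dir = Path(result_dir)
        self.result_dir.mkdir(parents=True, exist_ok=True)
        self.tensor_shape = None
        self.tensor_dtype = None
        self.batch_count = 0

    def write_batch(self, batch: np.ndarray) -> None:
        if self.tensor_shape is None:
            # Record the per-item shape and dtype from the first batch.
            self.tensor_shape = batch.shape[1:]
            self.tensor_dtype = batch.dtype
        elif batch.shape[1:] != self.tensor_shape:
            raise ValueError("All batches must share the same per-item shape")
        # Swapped-in storage: one .npy file per batch instead of a Lance table.
        np.save(self.result_dir / f"batch_{self.batch_count:06d}.npy", batch)
        self.batch_count += 1
```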
- create_results_writer(result_dir: str | pathlib.Path)[source]#
Create a writer for results (Lance format).
This factory creates a ResultDatasetWriter for writing inference results to Lance format. New writes always use Lance format going forward.
- Parameters:
result_dir (Union[str, Path]) – Directory where results should be saved
- Returns:
Writer instance for Lance storage
- Return type:
ResultDatasetWriter
- load_results_dataset(config: dict, results_dir: pathlib.Path | str | None = None, verb: str | None = None)[source]#
Load a results dataset, auto-detecting format.
This factory auto-detects whether the results are in Lance or .npy format and returns the appropriate dataset class.
- Parameters:
config (dict) – The hyrax config dictionary
results_dir (Union[Path, str, None], optional) – The results subdirectory to load from
verb (Union[str, None], optional) – The name of the verb that generated the results (for auto-discovery)
- Returns:
The appropriate dataset instance based on detected format
- Return type:
Union[ResultDataset, InferenceDataset]
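Format auto-detection might look roughly like the sketch below. It assumes two markers. A Lance run is recognized by the lance_db/ subdirectory mentioned for ResultDataset. A legacy run is recognized by the presence of .npy files. The helper name detect_results_format is hypothetical, and Hyrax's actual checks may differ.

```python
from pathlib import Path


def detect_results_format(results_dir) -> str:
    """Hypothetical sketch of results-format auto-detection."""
    results_dir = Path(results_dir)
    # Assumption: Lance results live in a lance_db/ subdirectory.
    if (results_dir / "lance_db").is_dir():
        return "lance"  # would be served by ResultDataset
    # Assumption: legacy results are stored as .npy files.
    if any(results_dir.glob("*.npy")):
        return "npy"  # would be served by InferenceDataset
    raise RuntimeError(f"No recognizable results found in {results_dir}")
```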
- class HyraxDataset(config: dict, metadata_table=None, object_id_column_name=None)[source]#
How to make a hyrax dataset:
    from hyrax.datasets import HyraxDataset

    class MyDataset(HyraxDataset):
        def __init__(self, config: dict):
            super().__init__(config)

        def __len__(self):
            # Your len function goes here
            pass
Optional interfaces:
metadata -> Subclasses may pass an astropy Table of metadata to __init__ in the superclass. This table of metadata will be available through the metadata_fields and metadata functions. If desired, a subclass may override these functions directly rather than using the astropy Table interface. Further documentation is in the "Build a dataset class in a notebook" example notebook.
Overall initialization for all Datasets which saves the config
Subclasses of HyraxDataset ought to call this at the end of their __init__ like:

    from hyrax.datasets import HyraxDataset

    class MyDataset(HyraxDataset):
        def __init__(self, config):
            # <your code>
            super().__init__(config)

If per-tensor metadata is available, it is recommended that dataset authors create an astropy Table of that data, in the same order as their data, and pass that metadata_table as shown below:

    from astropy.table import Table
    from hyrax.datasets import HyraxDataset

    class MyDataset(HyraxDataset):
        def __init__(self, config):
            # <your code>
            metadata_table = Table(...)  # <your catalog data goes here>
            super().__init__(config, metadata_table)
- Parameters:
config (dict, Optional) – The runtime configuration for hyrax
metadata_table (Optional[Table], optional) – An Astropy Table that (1) contains the metadata columns desired for visualization and (2) has its rows in the same order your data will be enumerated.
object_id_column_name (Optional[str], optional) – The name of the column containing object IDs. If None, uses the default from config or creates one from the ids() method.
- _config#
- _metadata_table = None#
- property config#
- metadata_fields() list[str][source]#
Returns a list of metadata fields supported by this object
- Returns:
The column names of the metadata table passed. An empty list if no metadata was provided during construction of the HyraxDataset (or derived class).
- Return type:
list[str]
- metadata(idxs: numpy.typing.ArrayLike, fields: list[str]) numpy.typing.ArrayLike[source]#
Returns a table representing the metadata given an array of indexes and a list of fields.
- Parameters:
idxs (npt.ArrayLike) – The indexes of the relevant tensor objects
fields (list[str]) – The names of the fields you would like returned. All values must be among those returned by metadata_fields()
- Returns:
A numpy record array of your metadata, with only the columns specified. Roughly equivalent to: metadata_table[idxs][fields].as_array() where metadata_table is the astropy table that the HyraxDataset (or derived class) was constructed with.
- Return type:
npt.ArrayLike
- Raises:
RuntimeError – When none of the provided fields are among those returned by metadata_fields()
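The "roughly equivalent" expression above can be mimicked with a NumPy structured array standing in for the astropy Table. The function select_metadata is hypothetical. It silently ignores unknown field names and raises RuntimeError only when none are valid. How Hyrax treats a partial match is not specified here.

```python
import numpy as np


def select_metadata(metadata_array: np.ndarray, idxs, fields: list[str]) -> np.ndarray:
    """Hypothetical sketch of metadata(idxs, fields) on a structured array."""
    known = set(metadata_array.dtype.names or ())
    # Assumption: unknown names are dropped; error only if nothing is valid.
    valid = [f for f in fields if f in known]
    if not valid:
        raise RuntimeError(f"None of the requested fields {fields} are available")
    # Select rows first (order follows idxs), then keep only requested columns.
    return metadata_array[np.asarray(idxs)][valid]
```

With an astropy Table, the comparable call is metadata_table[idxs][fields].as_array(), as the docstring notes.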
- class HyraxCSVDataset(config: dict, data_location: pathlib.Path = None)[source]#
Bases:
hyrax.datasets.dataset_registry.HyraxDataset
A Hyrax Dataset for CSV files.
This class reads a CSV file using pandas with memory mapping enabled. It dynamically creates getter methods for each column in the CSV file, allowing users to request data from specific columns.
Note
Column names found in the CSV file are used to create the getter methods. If a column name contains characters that are invalid for method names, those characters are replaced with underscores.
Examples
Example data_request configuration:

    {
        "train": {
            "data": {
                "dataset_class": "HyraxCSVDataset",
                "data_location": "</path/to/data.csv>",
                "fields": ["<column1>", "<column2>", ...],
                "primary_id_field": "<column name that contains a unique ID>",
            },
        },
        "validate": {"<similar to above>"},
        "infer": {"<similar to above>"},
    }
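The dynamic-getter idea can be sketched by attaching one method per column with setattr. ToyCSVGetters is hypothetical. It uses the standard-library csv module instead of pandas with memory mapping, so every value comes back as a string. The get_<column> naming is an assumption made for illustration.

```python
import csv
import io
import re


class ToyCSVGetters:
    """Hypothetical sketch: one getter method per CSV column."""

    def __init__(self, csv_text: str):
        reader = csv.DictReader(io.StringIO(csv_text))
        self._rows = list(reader)
        self.column_names = reader.fieldnames or []
        for column in self.column_names:
            # Characters invalid in Python identifiers become underscores.
            safe = re.sub(r"\W", "_", column)
            # Assumption: getters are named get_<sanitized column name>.
            setattr(self, f"get_{safe}", self._make_getter(column))

    def _make_getter(self, column: str):
        # Bind the column now so each getter reads its own column.
        def getter(idx: int) -> str:
            return self._rows[idx][column]

        return getter
```

Creating each getter inside a helper avoids the classic late-binding bug, in which every closure defined in a loop would see the last column name.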
- data_location = None#
- column_names#
- mem_mapped_csv = None#
- class MultimodalUniverseDataset(config: dict, data_location: pathlib.Path | str | None = None)[source]#
Bases:
hyrax.datasets.dataset_registry.HyraxDataset
Load a MultimodalUniverse dataset through Hugging Face datasets.
This dataset class is intentionally generic so one configuration pattern can be used for image, spectra, and time-series MMU datasets.
Examples
Example data_request configuration:

    {
        "infer": {
            "mmu": {
                "dataset_class": "MultimodalUniverseDataset",
                "data_location": "hf://MultimodalUniverse/plasticc",
                "primary_id_field": "object_id",
                "dataset_config": {
                    "MultimodalUniverseDataset": {
                        "split": "train",
                        "max_samples": 32,
                    }
                },
            }
        }
    }
- data_location = ''#
- split#
- max_samples#
- streaming#
- dataset#
- _column_name_map#
- _build_column_name_map() dict[str, str][source]#
Returns a map from sanitized column names to the original column names.
It’s possible for a column name to contain punctuation or start with a number. In these cases we also allow column access via a sanitized name in which all punctuation is replaced with the underscore character, and any name starting with a number is prefixed with field_. Every field is entered in the dictionary regardless of whether it needed sanitization; for fields that did not, the sanitized name is exactly the field name.
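The mapping described above can be sketched as follows. The function name build_column_name_map is hypothetical. Two rules are assumptions. Punctuation is taken to mean Python's string.punctuation. Names starting with a digit are assumed to get a field_ prefix, since the original wording at that point is garbled.

```python
import string


def build_column_name_map(columns: list[str]) -> dict[str, str]:
    """Hypothetical sketch of a sanitized-name -> original-name map."""
    column_map = {}
    for name in columns:
        # Replace every punctuation character (string.punctuation) with "_".
        sanitized = "".join("_" if ch in string.punctuation else ch for ch in name)
        # Assumption: names starting with a digit get a "field_" prefix.
        if sanitized[:1].isdigit():
            sanitized = f"field_{sanitized}"
        # Every column is recorded, even when no change was needed.
        column_map[sanitized] = name
    return column_map
```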
- class NestedPandasDataset(config: dict, data_location: pathlib.Path | str | None = None)[source]#
Bases:
hyrax.datasets.dataset_registry.HyraxDataset
A minimal Hyrax wrapper around nested_pandas.read_parquet.
- data_location = ''#
- read_kwargs#
- nested_frame#
- class DataCache(config, data_provider: hyrax.datasets.data_provider.DataProvider)[source]#
DataCache tracks and manages a caching layer which can be used most effectively if the entirety of a training (or inference) epoch fits in system RAM.
Two configs control this functionality:
h.config["data_set"]["use_cache"] which determines if we are serving data dictionaries out of a cache. When set, the first epoch of training fills the cache with tensors, and subsequent epochs are served out of the cache.
h.config["data_set"]["preload_cache"] starts a thread which iterates over the dataset/dataloader class to completion. The thread pre-loads the cache with tensors independently of the training process. The hope is that this thread proceeds faster than the first epoch of training and speeds up the first epoch as well.
In this class we cache the output of DataProvider, before being batched. Users can control the size of data cached by only selecting particular fields in their data_request specification.
The class logs to the tensorboard logger in the DataProvider (when configured).
Initialize the DataCache with a Hyrax config.
- Parameters:
config (dict) – The Hyrax configuration that defines the data_request.
data_provider (DataProvider) – The DataProvider object which we are caching for.
- _max_length#
- _resolve_data_func#
- _data_provider#
- _use_cache#
- _preload_cache#
- _data_size_bytes = 0#
- _insert_count = 0#
- logging_interval = 1000#
- _cache_map#
- _preload_thread = None#
- start_preload_thread()[source]#
Start the cache preload thread if configured
This exists to separate initialization from thread start in DataProvider’s constructor, so the thread started can always count on a fully initialized DataProvider.
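The separation of construction from thread start can be sketched with threading.Thread. PreloadingCache and its load_item callable are hypothetical. The sketch only shows the general pattern: build the object fully, then start a daemon thread that fills the cache in the background.

```python
import threading


class PreloadingCache:
    """Hypothetical sketch: construct first, start the preload thread later."""

    def __init__(self, load_item, length: int):
        self._load_item = load_item  # hypothetical callable: idx -> data dict
        self._length = length
        self._cache_map = {}
        self._lock = threading.Lock()
        self._preload_thread = None  # deliberately not started in __init__

    def start_preload_thread(self) -> None:
        # Called only after the owning object is fully initialized.
        self._preload_thread = threading.Thread(target=self._preload, daemon=True)
        self._preload_thread.start()

    def _preload(self) -> None:
        for idx in range(self._length):
            data = self._load_item(idx)
            with self._lock:
                # Do not overwrite entries that training may have inserted already.
                self._cache_map.setdefault(idx, data)
```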
- try_fetch(idx: int) dict | None[source]#
Try to fetch a data_dict from the cache.
- Parameters:
idx (int) – The DataProvider index of the data dict
- Returns:
The data dict from the cache, None on a cache miss.
- Return type:
Optional[dict]
- insert_into_cache(idx: int, data: dict[str, dict[str, Any]])[source]#
Insert a data dict into the cache
- Parameters:
idx (int) – Index of the data dict
data (dict[str, dict[str, Any]]) – The data dict
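The try_fetch()/insert_into_cache() pair supports a read-through pattern. The first epoch loads each item and stores it, and later epochs hit the cache. ToyCache, its load_item callable, and its get() method are hypothetical. Only the two method names and the use_cache behaviour come from the DataCache description.

```python
from typing import Any, Callable, Optional


class ToyCache:
    """Hypothetical sketch of an epoch-level, in-RAM read-through cache."""

    def __init__(self, load_item: Callable[[int], dict], use_cache: bool = True):
        self._load_item = load_item  # hypothetical loader: idx -> data dict
        self._use_cache = use_cache
        self._cache_map: dict[int, dict[str, Any]] = {}

    def try_fetch(self, idx: int) -> Optional[dict]:
        return self._cache_map.get(idx)  # None on a cache miss

    def insert_into_cache(self, idx: int, data: dict) -> None:
        self._cache_map[idx] = data

    def get(self, idx: int) -> dict:
        if self._use_cache:
            cached = self.try_fetch(idx)
            if cached is not None:
                return cached
        data = self._load_item(idx)  # first epoch (or no cache): load from source
        if self._use_cache:
            self.insert_into_cache(idx, data)
        return data
```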
- _lazy_map_executor(executor: concurrent.futures.Executor, idxs: collections.abc.Iterable[int])[source]#
Lazy evaluation version of concurrent.futures.Executor.map().
This limits memory usage during preloading by keeping only a small number of data dictionaries in memory at once.
- Parameters:
executor (concurrent.futures.Executor) – An executor for running futures
idxs (Iterable[int]) – An iterable list of DataProvider indexes
- Yields:
Iterator[torch.Tensor] – An iterator over torch tensors, lazily loaded
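A lazy alternative to Executor.map() keeps only a bounded number of futures pending. It submits new work only as results are consumed, and it yields results in input order. This is a sketch of the general technique, not Hyrax's code. The func and max_in_flight parameters are hypothetical additions for illustration.

```python
import concurrent.futures
from collections import deque
from collections.abc import Callable, Iterable, Iterator
from typing import Any


def lazy_map(
    executor: concurrent.futures.Executor,
    func: Callable[[int], Any],
    idxs: Iterable[int],
    max_in_flight: int = 4,
) -> Iterator[Any]:
    """Like executor.map(), but keeps at most max_in_flight futures pending."""
    pending: deque = deque()
    for idx in idxs:
        pending.append(executor.submit(func, idx))
        if len(pending) >= max_in_flight:
            # Block on the oldest future before submitting more work.
            yield pending.popleft().result()
    # Drain whatever is still in flight, preserving input order.
    while pending:
        yield pending.popleft().result()
```

Because idxs is consumed only as results are pulled, at most max_in_flight loaded items exist at once. That is the memory bound preloading needs.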