Getting started with Hyrax Custom Dataset Classes
In this notebook we are going to build a custom dataset class for Hyrax, and show how you can use Hyrax's prepare verb to test various aspects of your new dataset class.
First we will create some synthetic data: 1000 numpy arrays of shape 3x10x10, each associated with a random file name. Nothing in this cell is specific to developing a Hyrax dataset; it just sets things up in a semi-realistic way.
[1]:
import numpy as np
rng = np.random.default_rng()
num_tensors = 1000
# Generate filenames
alphabet = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
filename_length = 15
filenames = ["".join(rng.choice(alphabet, filename_length)) for _ in range(num_tensors)]
# Generate numpy arrays
shape = (3, 10, 10)
random_data = {file: rng.random(size=shape, dtype=np.float32) for file in filenames}
Building a custom Dataset class
We will treat these arrays as if they were files on the filesystem, and write a dataset class that gives Hyrax access to these "files". In it, get_image plays the role of a library function that reads one "file" and returns its array, and _list_filenames plays the role of a library function that lists the filenames at a particular path.
The first thing we need to do is create a new class derived from both HyraxDataset and torch.utils.data.Dataset, as shown below.
[2]:
from pathlib import Path
from typing import Union

from torch.utils.data import Dataset
from hyrax.data_sets import HyraxDataset


class MyDataset(HyraxDataset, Dataset):
    def __init__(self, config: dict, data_location: Union[Path, str] = None):
        self.filenames = MyDataset._list_filenames(data_location)
        super().__init__(config)

    def get_image(self, index):
        """Pretend to read specific data from the disk."""
        filename = self.filenames[index]
        global random_data
        return random_data[filename]

    def __getitem__(self, idx):
        # get_image is an instance method that takes the integer index
        return {
            "data": {"image": self.get_image(idx)},
        }

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _list_filenames(data_location):
        """This is a pretend implementation so we ignore data_location"""
        global filenames
        return filenames
Key aspects of this class that you will need to replicate are:
- __init__ must call super().__init__(config). This is important for Hyrax to function appropriately, and gives you access to Hyrax's config in other methods should you want it later. You will probably want to use the data_location argument (or config["general"]["data_dir"]) to figure out which directory to start in.
- __getitem__ must be implemented. It takes an index and returns the data for that element; here that is a dictionary whose ["data"]["image"] entry holds a numpy array.
- __len__ must return the number of elements in your dataset.
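To see why the super().__init__(config) requirement matters, here is a sketch using a hypothetical FakeHyraxBase stand-in. It is not Hyrax's real HyraxDataset, whose internals are not reproduced here; it only shows the general Python behavior that a base class sets up state in its own __init__, and a subclass that skips the super() call never receives that state.

```python
class FakeHyraxBase:
    """Hypothetical stand-in for HyraxDataset; NOT the real Hyrax class."""

    def __init__(self, config: dict):
        # The base class stores state that its other methods rely on.
        self._config = config

    def config_value(self, *keys):
        # Walk nested config dictionaries, e.g. ("general", "data_dir").
        value = self._config
        for key in keys:
            value = value[key]
        return value


class GoodDataset(FakeHyraxBase):
    def __init__(self, config: dict):
        self.filenames = ["a", "b", "c"]
        super().__init__(config)  # base-class state is initialized


class ForgetfulDataset(FakeHyraxBase):
    def __init__(self, config: dict):
        self.filenames = ["a", "b", "c"]
        # super().__init__(config) is never called, so self._config never exists


config = {"general": {"data_dir": "/fake/path"}}
print(GoodDataset(config).config_value("general", "data_dir"))  # /fake/path
try:
    ForgetfulDataset(config).config_value("general", "data_dir")
except AttributeError as err:
    print(f"AttributeError: {err}")
```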
Note that all of these are instance methods that use self as the first argument. This self is the current MyDataset object, and allows you to set and get values as is done with self.filenames in the code above.
The methods _list_filenames() and get_image() both read our fake data and exist only so we have an effective demonstration. How you organize your analogous file-reading code is entirely up to you!
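As one illustration of what "real" file-reading code might look like, here is a minimal sketch assuming your data is stored as one .npy file per object in a single directory. The helper names list_filenames and read_image are hypothetical, not part of Hyrax; sorting the file stems keeps the index order stable between runs, which matters because indexes, IDs, and metadata must all line up.

```python
import tempfile
from pathlib import Path
from typing import Union

import numpy as np


def list_filenames(data_location: Union[Path, str]) -> list:
    """Hypothetical helper: return the stems of all .npy files, sorted for a stable order."""
    return sorted(p.stem for p in Path(data_location).glob("*.npy"))


def read_image(data_location: Union[Path, str], filename: str) -> np.ndarray:
    """Hypothetical helper: load one array that was saved with np.save."""
    return np.load(Path(data_location) / f"{filename}.npy")


# Usage: write two small arrays to a temporary directory, then list and read them back.
with tempfile.TemporaryDirectory() as tmp:
    rng = np.random.default_rng(0)
    for name in ["b_obj", "a_obj"]:
        np.save(Path(tmp) / f"{name}.npy", rng.random((3, 10, 10), dtype=np.float32))
    names = list_filenames(tmp)
    first_shape = read_image(tmp, names[0]).shape

print(names)        # ['a_obj', 'b_obj']
print(first_shape)  # (3, 10, 10)
```

Inside MyDataset, _list_filenames could call a helper like list_filenames(data_location), and get_image could call read_image with self.filenames[index].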
We’re now going to start up Hyrax and use the prepare verb to create an instance of this class and check that it works correctly. Note that in config["model_inputs"] we set "dataset_class" to the name of our class, so that Hyrax uses our dataset class rather than one of the built-in ones, and "data_location" to the location of our data, which corresponds to the data_location argument of the __init__ we wrote earlier.
The h.prepare() call in the cell below calls our __init__ function with the current Hyrax config.
[3]:
import hyrax
h = hyrax.Hyrax()
h.config["model_inputs"] = {
    "data": {
        "dataset_class": "MyDataset",
        "data_location": "/fake/path/to/some/data",
    }
}
dataset = h.prepare()
[2025-09-17 14:25:05,001 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-17 14:25:06,348 hyrax.prepare:INFO] Finished Prepare
Testing
The object we received from h.prepare() is an instance of our dataset, which we can test for functionality.
Indexing into the dataset object with [] calls our __getitem__ function and returns the result.
Calling len() on the dataset calls our __len__ function.
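If you check the same things every time you change your class, a small helper can automate it. This is a sketch with a hypothetical check_dataset function (not a Hyrax API) that assumes items look like {"data": {"image": array}}, as in MyDataset. It checks len() and a few items, including the last index, which tends to catch off-by-one mistakes in __getitem__.

```python
import numpy as np


def check_dataset(dataset, expected_shape, expected_len, n_samples=5):
    """Hypothetical helper: spot-check __len__ and __getitem__ of a dataset
    whose items look like {"data": {"image": array}}."""
    assert len(dataset) == expected_len, f"len() is {len(dataset)}, expected {expected_len}"
    # Check the first few items plus the last one.
    indexes = list(range(min(n_samples, expected_len))) + [expected_len - 1]
    for idx in indexes:
        image = dataset[idx]["data"]["image"]
        assert isinstance(image, np.ndarray), f"item {idx} is {type(image)}"
        assert image.shape == expected_shape, f"item {idx} has shape {image.shape}"
    return True


# A plain list of dicts is enough to exercise the helper outside Hyrax.
rng = np.random.default_rng(0)
toy = [{"data": {"image": rng.random((3, 10, 10))}} for _ in range(8)]
print(check_dataset(toy, expected_shape=(3, 10, 10), expected_len=8))  # True
```

With the notebook's dataset, the equivalent call would be check_dataset(dataset, (3, 10, 10), 1000).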
[4]:
print("Checking __getitem__ ...", end="\n\n")
item = dataset[0]
print('Shape of our first element, should be "(3, 10, 10)": ')
print(item["data"]["image"].shape, end="\n\n")
print("Type of our first element, should be \"<class 'numpy.ndarray'>\": ")
print(type(item["data"]["image"]), end="\n\n")
print("Checking __len__, should print 1000: ")
print(len(dataset))
Checking __getitem__ ...
Shape of our first element, should be "(3, 10, 10)":
(3, 10, 10)
Type of our first element, should be "<class 'numpy.ndarray'>":
<class 'numpy.ndarray'>
Checking __len__, should print 1000:
1000
This dataset class is suitable for training or inference with Hyrax; however, you may want to read on to learn about more advanced features such as custom IDs for your data elements, metadata, and configuration access.
Below is a short example that uses the HyraxAutoencoder built-in model, demonstrating that training is possible:
[5]:
import hyrax
h = hyrax.Hyrax()
h.config["model"]["name"] = "HyraxAutoencoder"
h.config["model_inputs"] = {
    "data": {
        "dataset_class": "MyDataset",
        "data_location": "/fake/path/to/some/data",
    }
}
h.train()
[2025-09-17 14:25:06,382 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-17 14:25:06,464 hyrax.models.hyrax_autoencoder:INFO] Found shape: (3, 10, 10) in data sample, using this to initialize model.
[2025-09-17 14:25:06,467 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
2025-09-17 14:25:06,575 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data
Dataset':
{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x7c2da90a0ec0>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
2025-09-17 14:25:06,576 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data
Dataset':
{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x7c2da8728910>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
/home/drew/miniconda3/envs/hyrax/lib/python3.13/site-packages/ignite/handlers/tqdm_logger.py:127: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)
from tqdm.autonotebook import tqdm
2025/09/17 14:25:06 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2025-09-17 14:25:06,938 hyrax.pytorch_ignite:INFO] Training model on device: cuda
[2025-09-17 14:25:07,884 hyrax.pytorch_ignite:INFO] Total training time: 0.95[s]
[2025-09-17 14:25:07,885 hyrax.pytorch_ignite:INFO] Latest checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250917-142506-train-pSy4/checkpoint_epoch_10.pt
[2025-09-17 14:25:07,885 hyrax.pytorch_ignite:INFO] Best metric checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250917-142506-train-pSy4/checkpoint_10_loss=-31.1603.pt
2025/09/17 14:25:07 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/09/17 14:25:07 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2025-09-17 14:25:07,894 hyrax.verbs.train:INFO] Finished Training
[2025-09-17 14:25:07,958 hyrax.model_exporters:INFO] Exported model to ONNX format: /home/drew/code/hyrax/docs/pre_executed/results/20250917-142506-train-pSy4/example_model_opset_20.onnx
[5]:
HyraxAutoencoder(
(encoder): Sequential(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): GELU(approximate='none')
(2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): GELU(approximate='none')
(4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(5): GELU(approximate='none')
(6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): GELU(approximate='none')
(8): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(9): GELU(approximate='none')
(10): Flatten(start_dim=1, end_dim=-1)
(11): Linear(in_features=256, out_features=64, bias=True)
)
(dec_linear): Sequential(
(0): Linear(in_features=64, out_features=256, bias=True)
(1): GELU(approximate='none')
)
(decoder): Sequential(
(0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(1): GELU(approximate='none')
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): GELU(approximate='none')
(4): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(5): GELU(approximate='none')
(6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): GELU(approximate='none')
(8): ConvTranspose2d(32, 3, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(9): Tanh()
)
(criterion): CrossEntropyLoss()
)
Extending to support visualization
This section is primarily concerned with binding different sorts of metadata to your dataset. This metadata is used by the Hyrax visualization components to identify the source data of your latent space representation and link it back to a particular object/event in your astronomical dataset.
When we built MyDataset above, we implicitly picked up two major features from HyraxDataset:
- Unique IDs: Every tensor in our dataset got an ID equal to its sequential zero-based index (as a string), which is exactly the argument to __getitem__/[]. These IDs are available as an iterator by calling ids() on the dataset object. They are used in inference results and visualizations of the data, but they can be overridden.
- Metadata interface: Every HyraxDataset can provide an astropy Table of values in the same order as its __getitem__/[] indexes. This allows each tensor in the dataset to have associated scalar data such as ra/dec, ephemeris parameters, redshift, magnitude, etc. Our class does not define any metadata of its own yet.
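The default ID behavior described above can be mimicked in a few lines. This sketch uses a hypothetical IndexIdsMixin that reproduces what the output below shows (string IDs '0', '1', ...); it is an illustration of the behavior, not Hyrax's actual source code.

```python
class IndexIdsMixin:
    """Sketch of the default ID behavior: one string ID per index.
    Mimics the observed output; not Hyrax's actual implementation."""

    def ids(self):
        # A generator: IDs are produced lazily, one per dataset index.
        for index in range(len(self)):
            yield str(index)


class ToyDataset(IndexIdsMixin):
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n


print(list(ToyDataset(4).ids()))  # ['0', '1', '2', '3']
```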
Below is how we would access the metadata and IDs demonstrating the default behavior if your custom class does no overrides:
[6]:
import hyrax
h = hyrax.Hyrax()
h.config["model_inputs"] = {
    "data": {
        "dataset_class": "MyDataset",
        "data_location": "/fake/path/to/some/data",
    }
}
dataset = h.prepare()
print("\nIDs:")
print(f"list(dataset.ids())[0:10] = {list(dataset.ids())[0:10]}")
print("\nMetadata field list:")
print(f"dataset.metadata_fields() = {dataset.metadata_fields()} (there is no metadata)")
[2025-09-17 14:25:07,987 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-17 14:25:08,001 hyrax.prepare:INFO] Finished Prepare
IDs:
list(dataset.ids())[0:10] = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
Metadata field list:
dataset.metadata_fields() = ['object_id'] (there is no metadata)
Adding IDs
We’re going to use the filenames in our fake data as IDs by adding a single ids() method to our MyDataset class. The most expedient way to do this is to redefine the entire class below. Note that the methods below the "Unchanged from before" comment are the same as earlier.
Note that the ids() function is required to return a generator, so we will use a for loop and yield each sequential value. This interface allows Hyrax to partially enumerate the IDs in a dataset when that is desirable. It is easy enough to get all the ids in order with list(dataset.ids()).
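Because ids() is a generator, a caller can take just the first few IDs without walking the whole dataset. The sketch below uses itertools.islice on a hypothetical IdsDemo stand-in with the same ids() structure as MyDataset, and counts how many IDs were actually produced. It illustrates generator laziness in general, not how Hyrax itself consumes ids().

```python
from itertools import islice


class IdsDemo:
    """Hypothetical stand-in with the same ids() structure as MyDataset."""

    def __init__(self, filenames):
        self.filenames = filenames
        self.yielded = 0  # counts how many IDs were actually produced

    def ids(self):
        for filename in self.filenames:
            self.yielded += 1
            yield filename


demo = IdsDemo([f"file_{i}" for i in range(1000)])
first_three = list(islice(demo.ids(), 3))  # only the first 3 IDs are generated
print(first_three)   # ['file_0', 'file_1', 'file_2']
print(demo.yielded)  # 3
```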
[7]:
from pathlib import Path
from typing import Union

from torch.utils.data import Dataset
from hyrax.data_sets import HyraxDataset


class MyDataset(HyraxDataset, Dataset):
    def ids(self):
        for filename in self.filenames:
            yield filename

    # Unchanged from before below this comment ...
    def __init__(self, config: dict, data_location: Union[Path, str] = None):
        self.filenames = MyDataset._list_filenames(data_location)
        super().__init__(config)

    def get_image(self, index):
        """Pretend to read specific data from the disk."""
        filename = self.filenames[index]
        global random_data
        return random_data[filename]

    def __getitem__(self, idx):
        return {
            "data": {"image": self.get_image(idx)},
        }

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _list_filenames(data_location):
        """This is a pretend implementation so we ignore data_location"""
        global filenames
        return filenames
Running prepare again on our newly defined dataset class, we can see that the ids are now the fake “filenames” we generated at the top of the notebook, rather than sequential integers:
[8]:
import hyrax
h = hyrax.Hyrax()
h.config["model_inputs"] = {
    "data": {
        "dataset_class": "MyDataset",
        "data_location": "/fake/path/to/some/data",
    }
}
dataset = h.prepare()
print("\nIDs:")
print(f"list(dataset.ids())[0:5] = {list(dataset.ids())[0:5]}")
[2025-09-17 14:25:08,051 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-17 14:25:08,063 hyrax.prepare:INFO] Finished Prepare
IDs:
list(dataset.ids())[0:5] = ['UYvkmfDTyMHydyA', 'SOErTFtdwGspvJS', 'QVsvWTasGFiEtKc', 'QWqksbRihihOqde', 'BFXVfmWywdOSWmi']
Adding Metadata
Now we are going to generate some fake metadata for our fake data. This will take the form of random ra/dec pairs for each fake object.
[9]:
import astropy.units as u
ras = rng.uniform(low=0.0, high=360.0, size=num_tensors) * u.deg
decs = rng.uniform(low=-90.0, high=90.0, size=num_tensors) * u.deg
To provide metadata, we hand HyraxDataset an astropy Table containing all of it when our class is constructed. We do this in __init__ by passing the table to super().__init__ through the optional metadata_table argument.
Note the new method _read_metadata(), which constructs this table. On a real dataset, this method would most likely call astropy's high-level Table.read interface to construct a table directly from your catalog.
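One practical detail when reading a real catalog: the metadata rows must be in the same order as your __getitem__ indexes, but a catalog file usually lists objects in its own order. The sketch below reorders catalog columns to match a list of filenames using a plain dictionary lookup; the catalog values and names are made up for illustration. An astropy Table accepts the same integer-array indexing (table[order]), and a filename with no catalog row raises a KeyError here, which is usually what you want.

```python
import numpy as np

# Hypothetical catalog: rows arrive in a different order than our filenames.
filenames = ["obj_c", "obj_a", "obj_b"]  # order used by __getitem__
catalog_ids = np.array(["obj_a", "obj_b", "obj_c"])
catalog_ra = np.array([10.0, 20.0, 30.0])

# Map each catalog object_id to its row, then look rows up in dataset order.
row_of = {obj_id: row for row, obj_id in enumerate(catalog_ids.tolist())}
order = np.array([row_of[name] for name in filenames])

aligned_ids = catalog_ids[order]
aligned_ra = catalog_ra[order]
print(aligned_ids.tolist())  # ['obj_c', 'obj_a', 'obj_b']
print(aligned_ra.tolist())   # [30.0, 10.0, 20.0]
```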
As before we re-implement the entire class below with small modifications marked with comments:
[34]:
from pathlib import Path
from typing import Union

from astropy.table import Table
from torch.utils.data import Dataset
from hyrax.data_sets import HyraxDataset


class MyDataset(HyraxDataset, Dataset):
    def __init__(self, config: dict, data_location: Union[Path, str] = None):
        self.filenames = MyDataset._list_filenames(data_location)
        metadata_table = MyDataset._read_metadata(data_location)
        super().__init__(config, metadata_table=metadata_table)

    @staticmethod
    def _read_metadata(path_to_data):
        """This is a pretend implementation, so we don't use the path passed. A real version
        might use it to find your .csv/.fits/.tsv catalog file and call astropy's Table.read().
        Here we simply construct a table from our mock data."""
        global ras, decs, filenames
        return Table({"object_id": filenames, "ra": ras, "dec": decs})

    # Unchanged from before below this comment ...
    def ids(self):
        for filename in self.filenames:
            yield filename

    def get_image(self, index):
        """Pretend to read specific data from the disk."""
        filename = self.filenames[index]
        global random_data
        return random_data[filename]

    def __getitem__(self, idx):
        return {
            "data": {"image": self.get_image(idx)},
            "object_id": self.filenames[idx],
        }

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _list_filenames(data_location):
        """This is a pretend implementation so we ignore data_location"""
        global filenames
        return filenames
Now that our dataset class supports metadata, we can access the metadata interface directly through the metadata_fields() and metadata() methods of the dataset object.
- metadata_fields lists the available fields. In our case these are exactly the columns defined in the cell above (object_id, ra, and dec), which appear with a _data suffix, plus an object_id field.
- metadata takes a list (or array) of indexes and a list (or array) of valid field names. It returns a numpy record array of the selected fields for the selected indexes. It is essentially equivalent to metadata_table[indexes][fields].as_array(), where metadata_table is the original astropy table.
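To make the "rows by index, columns by field name" selection concrete without Hyrax or astropy, here is a sketch using a numpy structured array as a stand-in for the metadata table. The values are made up, and the field names here have no _data suffix; the point is only the shape of the result, a record array holding the requested fields for the requested rows.

```python
import numpy as np

# A numpy structured array standing in for the metadata table (made-up values).
metadata = np.array(
    [("obj_0", 10.0, -5.0), ("obj_1", 20.0, 15.0), ("obj_2", 30.0, 45.0),
     ("obj_3", 40.0, -60.0), ("obj_4", 50.0, 80.0)],
    dtype=[("object_id", "U10"), ("ra", "f8"), ("dec", "f8")],
)

# Rows [1, 3, 4], fields ["ra", "dec"]: analogous to dataset.metadata([1, 3, 4], [...]).
selected = metadata[[1, 3, 4]][["ra", "dec"]]
print(selected.dtype.names)     # ('ra', 'dec')
print(selected["ra"].tolist())  # [20.0, 40.0, 50.0]
```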
[41]:
import hyrax
from astropy.table import Table
h = hyrax.Hyrax()
h.config["model_inputs"] = {
    "data": {
        "dataset_class": "MyDataset",
        "data_location": "/fake/path/to/some/data",
    }
}
dataset = h.prepare()
print("\nMetadata field list:")
print(f"dataset.metadata_fields() = {dataset.metadata_fields()}")
print('Table(dataset.metadata([1, 3, 4], ["ra_data", "dec_data"])) =>')
Table(dataset.metadata([1, 3, 4], ["ra_data", "dec_data"]))
[2025-09-17 16:47:01,811 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-17 16:47:01,832 hyrax.prepare:INFO] Finished Prepare
Metadata field list:
dataset.metadata_fields() = ['object_id_data', 'ra_data', 'dec_data', 'object_id']
Table(dataset.metadata([1, 3, 4], ["ra_data", "dec_data"])) =>
[41]:
| ra_data | dec_data |
|---|---|
| float64 | float64 |
| 292.4007388647287 | -87.13940122247789 |
| 84.67360544447406 | -59.027978794560354 |
| 218.76113549876985 | -1.0705419754372798 |
Now that we have a dataset with ‘ra’ and ‘dec’ metadata, we can do a full analysis with Hyrax: training the model, inferring the latent space, running UMAP to reduce the latent space to a 2D representation, and visualizing the result.
[32]:
import hyrax
h = hyrax.Hyrax()
h.config["model_inputs"] = {
    "data": {
        "dataset_class": "MyDataset",
        "data_location": "/fake/path/to/some/data",
        "primary_id_field": "object_id",
    }
}
h.config["model"]["name"] = "HyraxAutoencoder"
h.train()
h.infer()
h.umap()
v = h.visualize()
[2025-09-17 16:44:01,044 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-17 16:44:01,062 hyrax.models.hyrax_autoencoder:INFO] Found shape: (3, 10, 10) in data sample, using this to initialize model.
[2025-09-17 16:44:01,065 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
2025-09-17 16:44:01,066 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data
Dataset':
{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x7c2a6f2d5630>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
2025-09-17 16:44:01,066 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data
Dataset':
{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x7c2a6f4c7da0>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
2025/09/17 16:44:01 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2025-09-17 16:44:01,085 hyrax.pytorch_ignite:INFO] Training model on device: cuda
[2025-09-17 16:44:01,760 hyrax.pytorch_ignite:INFO] Total training time: 0.67[s]
[2025-09-17 16:44:01,760 hyrax.pytorch_ignite:INFO] Latest checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250917-164401-train-tfhe/checkpoint_epoch_10.pt
[2025-09-17 16:44:01,761 hyrax.pytorch_ignite:INFO] Best metric checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250917-164401-train-tfhe/checkpoint_10_loss=-28.0658.pt
2025/09/17 16:44:01 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/09/17 16:44:01 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2025-09-17 16:44:01,769 hyrax.verbs.train:INFO] Finished Training
[2025-09-17 16:44:01,817 hyrax.model_exporters:INFO] Exported model to ONNX format: /home/drew/code/hyrax/docs/pre_executed/results/20250917-164401-train-tfhe/example_model_opset_20.onnx
[2025-09-17 16:44:01,832 hyrax.models.hyrax_autoencoder:INFO] Found shape: (3, 10, 10) in data sample, using this to initialize model.
[2025-09-17 16:44:01,834 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
[2025-09-17 16:44:01,835 hyrax.verbs.infer:INFO] data set has length 1000
2025-09-17 16:44:01,835 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data
Dataset':
{'sampler': None, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
[2025-09-17 16:44:01,840 hyrax.verbs.infer:INFO] Saving inference results at: /home/drew/code/hyrax/docs/pre_executed/results/20250917-164401-infer-ZP0M
[2025-09-17 16:44:02,237 hyrax.pytorch_ignite:INFO] Evaluating model on device: cuda
[2025-09-17 16:44:02,239 hyrax.pytorch_ignite:INFO] Total epochs: 1
[2025-09-17 16:44:02,293 hyrax.pytorch_ignite:INFO] Total evaluation time: 0.06[s]
[2025-09-17 16:44:02,369 hyrax.verbs.infer:INFO] Inference Complete.
[2025-09-17 16:44:02,406 hyrax.data_sets.inference_dataset:INFO] Using most recent results dir /home/drew/code/hyrax/docs/pre_executed/results/20250917-164401-infer-ZP0M for lookup. Use the [results] inference_dir config to set a directory or pass it to this verb.
[2025-09-17 16:44:02,441 hyrax.verbs.umap:INFO] Saving UMAP results to /home/drew/code/hyrax/docs/pre_executed/results/20250917-164402-umap-Volp
[2025-09-17 16:44:02,826 hyrax.verbs.umap:INFO] Fitting the UMAP
[2025-09-17 16:44:03,793 hyrax.verbs.umap:INFO] Saving fitted UMAP Reducer
[2025-09-17 16:44:05,269 hyrax.verbs.umap:INFO] Finished transforming all data through UMAP
[2025-09-17 16:44:05,308 hyrax.verbs.visualize:INFO] UMAP directory not specified at runtime. Reading from config values.
[2025-09-17 16:44:05,310 hyrax.data_sets.inference_dataset:INFO] Using most recent results dir /home/drew/code/hyrax/docs/pre_executed/results/20250917-164402-umap-Volp for lookup. Use the [results] inference_dir config to set a directory or pass it to this verb.
[2025-09-17 16:44:05,350 hyrax.verbs.visualize:INFO] Rendering UMAP from the following directory: /home/drew/code/hyrax/docs/pre_executed/results/20250917-164402-umap-Volp