Getting started with Hyrax Custom Dataset Classes

In this notebook we build a custom dataset class for Hyrax and show how to use Hyrax's prepare verb to test various aspects of your new dataset class.

First we will create some synthetic data: 1000 random numpy arrays of shape (3, 10, 10), each associated with a random 15-character file name. Nothing in this cell is specific to developing a Hyrax dataset; it just sets things up in a semi-realistic way.

[1]:
import numpy as np

rng = np.random.default_rng()
num_tensors = 1000

# Generate filenames
alphabet = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
filename_length = 15
filenames = ["".join(rng.choice(alphabet, filename_length)) for _ in range(num_tensors)]

# Generate numpy arrays
shape = (3, 10, 10)
random_data = {file: rng.random(size=shape, dtype=np.float32) for file in filenames}

Building a custom Dataset class

We will treat these arrays as if they are files on the filesystem and write a dataset class that gives Hyrax access to these “files”. Here get_image plays the role of a library function that reads one of our “files” and returns its array, and _list_filenames plays the role of a library function that lists the filenames in a particular path.

The first thing we need to do is define a new class that derives from both HyraxDataset and torch.utils.data.Dataset, as shown below.

[2]:
from torch.utils.data import Dataset
from hyrax.data_sets import HyraxDataset
from pathlib import Path
from typing import Union


class MyDataset(HyraxDataset, Dataset):
    def __init__(self, config: dict, data_location: Union[Path, str] = None):
        self.filenames = MyDataset._list_filenames(data_location)

        super().__init__(config)

    def get_image(self, index):
        """Pretend to read specific data from the disk."""
        filename = self.filenames[index]
        global random_data
        return random_data[filename]

    def __getitem__(self, idx):
        return {
            "data": {"image": self.get_image(idx)},
        }

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _list_filenames(data_location):
        """This is a pretend implementation so we ignore data_location"""
        global filenames
        return filenames

Key aspects of this class that you will need to replicate are:

  • __init__ must call super().__init__(config). This is important for Hyrax to function correctly, and it gives you access to Hyrax's config in your other methods should you want it later. In this example, the data_location value set in the config is passed to __init__ as its data_location argument, which tells your class where to start looking for data.

  • __getitem__ must be implemented. It takes an index and returns the data for that item; in this example it returns a dictionary of the form {"data": {"image": <array>}}.

  • __len__ must return the number of items in your dataset.

Note that all of these are instance methods that take self as their first argument. self is the current MyDataset object, and lets you set and get values, as is done with self.filenames in the code above.

The functions _list_filenames() and get_image() both read our fake data and exist only to make this demonstration work. How you organize the analogous file-reading code for your own data is entirely up to you!
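For comparison, here is a minimal sketch of what the file-reading side might look like for real data stored as one .npy file per item. The helper names list_npy_files and read_npy and the class NpyDirectoryReader are hypothetical, and the sketch deliberately leaves out HyraxDataset so it can run on its own; in a real dataset class you would keep the HyraxDataset base class and the super().__init__(config) call.

```python
import tempfile
from pathlib import Path

import numpy as np


def list_npy_files(data_location):
    """Hypothetical helper: return the .npy files in data_location, sorted by name."""
    return sorted(Path(data_location).glob("*.npy"))


def read_npy(path):
    """Hypothetical helper: load one array from disk."""
    return np.load(path)


class NpyDirectoryReader:
    """Stand-in for the file-reading parts of MyDataset (no Hyrax dependency)."""

    def __init__(self, data_location):
        self.filenames = list_npy_files(data_location)

    def get_image(self, index):
        return read_npy(self.filenames[index])

    def __getitem__(self, idx):
        return {"data": {"image": self.get_image(idx)}}

    def __len__(self):
        return len(self.filenames)


# Usage: write three small arrays to a temporary directory and read them back.
demo_rng = np.random.default_rng(0)
with tempfile.TemporaryDirectory() as tmp:
    for name in ["b", "a", "c"]:
        np.save(Path(tmp) / f"{name}.npy", demo_rng.random((3, 10, 10), dtype=np.float32))

    reader = NpyDirectoryReader(tmp)
    n_items = len(reader)
    first_shape = reader[0]["data"]["image"].shape
    stems = [p.stem for p in reader.filenames]

print(n_items, first_shape, stems)  # 3 (3, 10, 10) ['a', 'b', 'c']
```

Sorting the file list gives every run the same index-to-file mapping, which matters once IDs and metadata have to line up with __getitem__.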

We’re now going to start up Hyrax and use the prepare verb to create an instance of this class and check that it works correctly. Note that we set config["model_inputs"]["data"]["dataset_class"] to the name of our class, so that Hyrax knows to use our dataset class rather than one of the built-in ones, and config["model_inputs"]["data"]["data_location"] to the location of our data, which our __init__ function receives as its data_location argument.

The h.prepare() call below will call our __init__ function with the current Hyrax config.

[3]:
import hyrax

h = hyrax.Hyrax()
h.config["model_inputs"] = {
    "data": {
        "dataset_class": "MyDataset",
        "data_location": "/fake/path/to/some/data",
    }
}

dataset = h.prepare()
[2025-09-17 14:25:05,001 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-17 14:25:06,348 hyrax.prepare:INFO] Finished Prepare

Testing

The object we received from h.prepare() is an instance of our dataset class, which we can now test.

We’re going to index into the dataset object with []; this calls our __getitem__ function and returns the result.

We’re also going to call len() on the dataset, which calls our __len__ function.
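The mapping from [] and len() to these methods is plain Python rather than anything specific to Hyrax. This small, hypothetical class (no Hyrax or torch involved) shows it in isolation:

```python
class TinyDataset:
    """Hypothetical minimal class showing which methods [] and len() call."""

    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, idx):
        # tiny[idx] is translated by Python into type(tiny).__getitem__(tiny, idx)
        return {"data": {"image": self.items[idx]}}

    def __len__(self):
        # len(tiny) is translated into type(tiny).__len__(tiny)
        return len(self.items)


tiny = TinyDataset(["x", "y", "z"])
first = tiny[0]
size = len(tiny)
print(first, size)  # {'data': {'image': 'x'}} 3
```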

[4]:
print("Checking __getitem__ ...", end="\n\n")
item = dataset[0]

print('Shape of our first element, should be "(3, 10, 10)": ')
print(item["data"]["image"].shape, end="\n\n")

print("Type of our first element, should be \"<class 'numpy.ndarray'>\": ")
print(type(item["data"]["image"]), end="\n\n")

print("Checking __len__, should print 1000: ")
print(len(dataset))
Checking __getitem__ ...

Shape of our first element, should be "(3, 10, 10)":
(3, 10, 10)

Type of our first element, should be "<class 'numpy.ndarray'>":
<class 'numpy.ndarray'>

Checking __len__, should print 1000:
1000

This dataset class is suitable for training or inference with Hyrax; however, you may want to read on to learn about more advanced features such as custom IDs for your data elements, metadata, and configuration access.

Below is a short example that uses the HyraxAutoencoder built-in model, demonstrating that training is possible:

[5]:
import hyrax

h = hyrax.Hyrax()
h.config["model"]["name"] = "HyraxAutoencoder"
h.config["model_inputs"] = {
    "data": {
        "dataset_class": "MyDataset",
        "data_location": "/fake/path/to/some/data",
    }
}

h.train()
[2025-09-17 14:25:06,382 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-17 14:25:06,464 hyrax.models.hyrax_autoencoder:INFO] Found shape: (3, 10, 10) in data sample, using this to initialize model.
[2025-09-17 14:25:06,467 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
2025-09-17 14:25:06,575 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data
  Dataset':
        {'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x7c2da90a0ec0>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
2025-09-17 14:25:06,576 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data
  Dataset':
        {'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x7c2da8728910>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
/home/drew/miniconda3/envs/hyrax/lib/python3.13/site-packages/ignite/handlers/tqdm_logger.py:127: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)
  from tqdm.autonotebook import tqdm
2025/09/17 14:25:06 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2025-09-17 14:25:06,938 hyrax.pytorch_ignite:INFO] Training model on device: cuda
[2025-09-17 14:25:07,884 hyrax.pytorch_ignite:INFO] Total training time: 0.95[s]
[2025-09-17 14:25:07,885 hyrax.pytorch_ignite:INFO] Latest checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250917-142506-train-pSy4/checkpoint_epoch_10.pt
[2025-09-17 14:25:07,885 hyrax.pytorch_ignite:INFO] Best metric checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250917-142506-train-pSy4/checkpoint_10_loss=-31.1603.pt
2025/09/17 14:25:07 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/09/17 14:25:07 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2025-09-17 14:25:07,894 hyrax.verbs.train:INFO] Finished Training
[2025-09-17 14:25:07,958 hyrax.model_exporters:INFO] Exported model to ONNX format: /home/drew/code/hyrax/docs/pre_executed/results/20250917-142506-train-pSy4/example_model_opset_20.onnx
[5]:
HyraxAutoencoder(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): GELU(approximate='none')
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): GELU(approximate='none')
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): GELU(approximate='none')
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): GELU(approximate='none')
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): GELU(approximate='none')
    (10): Flatten(start_dim=1, end_dim=-1)
    (11): Linear(in_features=256, out_features=64, bias=True)
  )
  (dec_linear): Sequential(
    (0): Linear(in_features=64, out_features=256, bias=True)
    (1): GELU(approximate='none')
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): GELU(approximate='none')
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): GELU(approximate='none')
    (4): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (5): GELU(approximate='none')
    (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): GELU(approximate='none')
    (8): ConvTranspose2d(32, 3, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (9): Tanh()
  )
  (criterion): CrossEntropyLoss()
)

Extending to support visualization

This section is primarily concerned with binding different sorts of metadata to your dataset. This metadata is used by the Hyrax visualization components to identify the source data of your latent space representation and link it back to a particular object/event in your astronomical dataset.

When we built MyDataset above, we implicitly picked up two major features from HyraxDataset:

  1. Unique IDs: Every item in our dataset got an ID equal to its zero-based index (as a string), which is exactly the argument passed to __getitem__/[]. These IDs are available as a generator by calling ids() on the dataset object. They are used in inference results and visualizations of the data, and they can be overridden.

  2. Metadata interface: Every HyraxDataset can provide an astropy Table of values in the same order as its __getitem__/[] items. This allows each item in the dataset to have associated scalar data such as ra/dec, ephemeris parameters, redshift, magnitude, etc. Our class does not yet define any metadata of its own.

Below is how we would access the metadata and IDs demonstrating the default behavior if your custom class does no overrides:

[6]:
import hyrax

h = hyrax.Hyrax()
h.config["model_inputs"] = {
    "data": {
        "dataset_class": "MyDataset",
        "data_location": "/fake/path/to/some/data",
    }
}

dataset = h.prepare()

print("\nIDs:")
print(f"list(dataset.ids())[0:10] = {list(dataset.ids())[0:10]}")


print("\nMetadata field list:")
print(f"dataset.metadata_fields() = {dataset.metadata_fields()} (there is no metadata)")
[2025-09-17 14:25:07,987 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-17 14:25:08,001 hyrax.prepare:INFO] Finished Prepare

IDs:
list(dataset.ids())[0:10] = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

Metadata field list:
dataset.metadata_fields() = ['object_id'] (there is no metadata)
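The default IDs shown above behave like the following sketch, which yields each zero-based index as a string. DefaultIdsSketch is a hypothetical stand-in written to match the observed output; it is not Hyrax's actual implementation.

```python
class DefaultIdsSketch:
    """Hypothetical stand-in that mimics the default ids() output seen above."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def ids(self):
        # Yield the zero-based index of each item as a string, matching the
        # ['0', '1', '2', ...] output above (observed behaviour, not Hyrax source).
        for i in range(len(self)):
            yield str(i)


sketch = DefaultIdsSketch(1000)
print(list(sketch.ids())[0:10])  # ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
```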

Adding IDs

We’re going to use the filenames in our fake data as IDs by adding a single ids() method to our MyDataset class. The most expedient way to do this is to redefine the entire class below. Note that the methods below the “Unchanged from before” comment are the same as earlier.

Note that ids() is required to return a generator, so we use a for loop and yield each value in order. This interface allows Hyrax to enumerate only some of the IDs in a dataset when that is desirable. It is easy enough to get all the IDs in order with list(dataset.ids()).
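As an illustration of partial enumeration, itertools.islice can take just the first few IDs from a generator without producing the rest. FilenameIds below is a hypothetical stand-in for MyDataset, with a counter added only to show how many IDs were actually generated.

```python
from itertools import islice


class FilenameIds:
    """Hypothetical stand-in with an ids() generator like the one in MyDataset."""

    def __init__(self, names):
        self.filenames = names
        self.produced = 0  # counts how many IDs the generator has yielded

    def ids(self):
        for filename in self.filenames:
            self.produced += 1
            yield filename


demo_ids = FilenameIds([f"file_{i:04d}" for i in range(1000)])

# Take only the first three IDs; the generator is never advanced past them.
first_three = list(islice(demo_ids.ids(), 3))
produced_after_slice = demo_ids.produced
print(first_three, produced_after_slice)  # ['file_0000', 'file_0001', 'file_0002'] 3

# Materialize every ID when you really need all of them.
all_ids = list(demo_ids.ids())
print(len(all_ids))  # 1000
```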

[7]:
from pathlib import Path
from typing import Union

from torch.utils.data import Dataset
from hyrax.data_sets import HyraxDataset


class MyDataset(HyraxDataset, Dataset):
    def ids(self):
        for filename in self.filenames:
            yield filename

    # Unchanged from before below this comment ...
    def __init__(self, config: dict, data_location: Union[Path, str] = None):
        self.filenames = MyDataset._list_filenames(data_location)

        super().__init__(config)

    def get_image(self, index):
        """Pretend to read specific data from the disk."""
        filename = self.filenames[index]
        global random_data
        return random_data[filename]

    def __getitem__(self, idx):
        return {
            "data": {"image": self.get_image(idx)},
        }

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _list_filenames(data_location):
        """This is a pretend implementation so we ignore data_location"""
        global filenames
        return filenames

If we run prepare again with our newly defined dataset class, we can see that the IDs are now the fake “filenames” we generated at the top of the notebook, rather than sequential integers:

[8]:
import hyrax

h = hyrax.Hyrax()
h.config["model_inputs"] = {
    "data": {
        "dataset_class": "MyDataset",
        "data_location": "/fake/path/to/some/data",
    }
}

dataset = h.prepare()

print("\nIDs:")
print(f"list(dataset.ids())[0:5] = {list(dataset.ids())[0:5]}")
[2025-09-17 14:25:08,051 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-17 14:25:08,063 hyrax.prepare:INFO] Finished Prepare

IDs:
list(dataset.ids())[0:5] = ['UYvkmfDTyMHydyA', 'SOErTFtdwGspvJS', 'QVsvWTasGFiEtKc', 'QWqksbRihihOqde', 'BFXVfmWywdOSWmi']

Adding Metadata

Now we are going to generate some fake metadata for our fake data. This will take the form of random ra/dec pairs for each fake object.

[9]:
import astropy.units as u

ras = rng.uniform(low=0.0, high=360.0, size=num_tensors) * u.deg
decs = rng.uniform(low=-90.0, high=90.0, size=num_tensors) * u.deg

To provide metadata, we give HyraxDataset an astropy Table containing all of the metadata for our dataset. We do this in __init__ by passing the table to super().__init__ as the optional metadata_table keyword argument.

Note the new function _read_metadata() which constructs this table. On a real dataset this function would most likely call astropy’s Table.read high level interface to construct a table directly from your catalog.

As before we re-implement the entire class below with small modifications marked with comments:

[34]:
from pathlib import Path
from typing import Union

from torch.utils.data import Dataset
from hyrax.data_sets import HyraxDataset


class MyDataset(HyraxDataset, Dataset):
    # Changed: read the metadata table and pass it to HyraxDataset
    def __init__(self, config: dict, data_location: Union[Path, str] = None):
        self.filenames = MyDataset._list_filenames(data_location)
        metadata_table = MyDataset._read_metadata(data_location)
        super().__init__(config, metadata_table=metadata_table)

    # New: build the metadata table
    @staticmethod
    def _read_metadata(path_to_data):
        """This is a pretend implementation so we don't use the path passed, which you might use
        to find your .csv/.fits/.tsv catalog file and call astropy's Table.read().

        We simply construct a table from our mock data"""
        from astropy.table import Table

        global ras, decs, filenames
        return Table({"object_id": filenames, "ra": ras, "dec": decs})

    # Unchanged from before below this comment, except for __getitem__ ...
    def ids(self):
        for filename in self.filenames:
            yield filename

    def get_image(self, index):
        """Pretend to read specific data from the disk."""
        filename = self.filenames[index]
        global random_data
        return random_data[filename]

    # Changed: __getitem__ now also returns the object ID
    def __getitem__(self, idx):
        return {
            "data": {"image": self.get_image(idx)},
            "object_id": self.filenames[idx],
        }

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _list_filenames(data_location):
        """This is a pretend implementation so we ignore data_location"""
        global filenames
        return filenames

Now that our dataset class supports metadata, we can access the metadata interface directly using the metadata_fields() and metadata() methods on the dataset object.

  • metadata_fields() lists the available fields. These are the columns defined in _read_metadata() above; in the output below they carry a _data suffix, matching the "data" key in config["model_inputs"], and an object_id field is listed as well.

  • metadata takes a list (or array) of indexes, and a list (or array) of valid fields. It returns a numpy rec-array of the selected metadata fields for the selected data indexes. It is essentially equivalent to metadata_table[indexes][fields].as_array() where metadata_table is the original astropy table.

[41]:
import hyrax
from astropy.table import Table

h = hyrax.Hyrax()
h.config["model_inputs"] = {
    "data": {
        "dataset_class": "MyDataset",
        "data_location": "/fake/path/to/some/data",
    }
}

dataset = h.prepare()

print("\nMetadata field list:")
print(f"dataset.metadata_fields() = {dataset.metadata_fields()}")
print(f'Table(dataset.metadata([1, 3, 4], ["ra_data", "dec_data"])) =>')
Table(dataset.metadata([1, 3, 4], ["ra_data", "dec_data"]))
[2025-09-17 16:47:01,811 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-17 16:47:01,832 hyrax.prepare:INFO] Finished Prepare

Metadata field list:
dataset.metadata_fields() = ['object_id_data', 'ra_data', 'dec_data', 'object_id']
Table(dataset.metadata([1, 3, 4], ["ra_data", "dec_data"])) =>
[41]:
Table length=3
ra_data             dec_data
float64             float64
292.4007388647287   -87.13940122247789
84.67360544447406   -59.027978794560354
218.76113549876985  -1.0705419754372798

Now that we have a dataset class with ‘ra’ and ‘dec’ metadata, we can run a full analysis with Hyrax: training the model, inferring the latent space, reducing the latent space to a 2D representation with UMAP, and visualizing the result.

[32]:
import hyrax

h = hyrax.Hyrax()
h.config["model_inputs"] = {
    "data": {
        "dataset_class": "MyDataset",
        "data_location": "/fake/path/to/some/data",
        "primary_id_field": "object_id",
    }
}
h.config["model"]["name"] = "HyraxAutoencoder"

h.train()
h.infer()
h.umap()
v = h.visualize()
[2025-09-17 16:44:01,044 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2025-09-17 16:44:01,062 hyrax.models.hyrax_autoencoder:INFO] Found shape: (3, 10, 10) in data sample, using this to initialize model.
[2025-09-17 16:44:01,065 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
2025-09-17 16:44:01,066 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data
  Dataset':
        {'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x7c2a6f2d5630>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
2025-09-17 16:44:01,066 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data
  Dataset':
        {'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x7c2a6f4c7da0>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
2025/09/17 16:44:01 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2025-09-17 16:44:01,085 hyrax.pytorch_ignite:INFO] Training model on device: cuda
[2025-09-17 16:44:01,760 hyrax.pytorch_ignite:INFO] Total training time: 0.67[s]
[2025-09-17 16:44:01,760 hyrax.pytorch_ignite:INFO] Latest checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250917-164401-train-tfhe/checkpoint_epoch_10.pt
[2025-09-17 16:44:01,761 hyrax.pytorch_ignite:INFO] Best metric checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250917-164401-train-tfhe/checkpoint_10_loss=-28.0658.pt
2025/09/17 16:44:01 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/09/17 16:44:01 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2025-09-17 16:44:01,769 hyrax.verbs.train:INFO] Finished Training
[2025-09-17 16:44:01,817 hyrax.model_exporters:INFO] Exported model to ONNX format: /home/drew/code/hyrax/docs/pre_executed/results/20250917-164401-train-tfhe/example_model_opset_20.onnx
[2025-09-17 16:44:01,832 hyrax.models.hyrax_autoencoder:INFO] Found shape: (3, 10, 10) in data sample, using this to initialize model.
[2025-09-17 16:44:01,834 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
[2025-09-17 16:44:01,835 hyrax.verbs.infer:INFO] data set has length 1000
2025-09-17 16:44:01,835 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data
  Dataset':
        {'sampler': None, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
[2025-09-17 16:44:01,840 hyrax.verbs.infer:INFO] Saving inference results at: /home/drew/code/hyrax/docs/pre_executed/results/20250917-164401-infer-ZP0M
[2025-09-17 16:44:02,237 hyrax.pytorch_ignite:INFO] Evaluating model on device: cuda
[2025-09-17 16:44:02,239 hyrax.pytorch_ignite:INFO] Total epochs: 1
[2025-09-17 16:44:02,293 hyrax.pytorch_ignite:INFO] Total evaluation time: 0.06[s]
[2025-09-17 16:44:02,369 hyrax.verbs.infer:INFO] Inference Complete.
[2025-09-17 16:44:02,406 hyrax.data_sets.inference_dataset:INFO] Using most recent results dir /home/drew/code/hyrax/docs/pre_executed/results/20250917-164401-infer-ZP0M for lookup. Use the [results] inference_dir config to set a directory or pass it to this verb.
[2025-09-17 16:44:02,441 hyrax.verbs.umap:INFO] Saving UMAP results to /home/drew/code/hyrax/docs/pre_executed/results/20250917-164402-umap-Volp
[2025-09-17 16:44:02,826 hyrax.verbs.umap:INFO] Fitting the UMAP
[2025-09-17 16:44:03,793 hyrax.verbs.umap:INFO] Saving fitted UMAP Reducer
[2025-09-17 16:44:05,269 hyrax.verbs.umap:INFO] Finished transforming all data through UMAP
[2025-09-17 16:44:05,308 hyrax.verbs.visualize:INFO] UMAP directory not specified at runtime. Reading from config values.
[2025-09-17 16:44:05,310 hyrax.data_sets.inference_dataset:INFO] Using most recent results dir /home/drew/code/hyrax/docs/pre_executed/results/20250917-164402-umap-Volp for lookup. Use the [results] inference_dir config to set a directory or pass it to this verb.
[2025-09-17 16:44:05,350 hyrax.verbs.visualize:INFO] Rendering UMAP from the following directory: /home/drew/code/hyrax/docs/pre_executed/results/20250917-164402-umap-Volp