Build a dataset class in a notebook

Hyrax uses dataset classes to connect your data to the training and inference pipeline. You write the logic to load one object at a time; Hyrax handles batching, shuffling, train/validation splitting, and delivery to your model.

The key design contract is simple: name each loading method get_<field>(self, idx), and Hyrax will call it automatically whenever "<field>" appears in the fields list of your data_request configuration.
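To make the naming contract concrete, here is a minimal sketch of name-based dispatch using getattr. It is not Hyrax's internal code: fetch_fields and ToyDataset are hypothetical names used only for illustration, and the nested {"data": {...}} result mirrors the sample shape you will see in Step 2.

```python
class ToyDataset:
    """Hypothetical stand-in for a dataset class (does not inherit HyraxDataset)."""

    def __len__(self):
        return 3

    def get_flux(self, idx):
        return [1.0 * idx, 2.0 * idx]

    def get_object_id(self, idx):
        return f"obj-{idx:05d}"


def fetch_fields(dataset, idx, fields, friendly_name="data"):
    """Hypothetical helper: call get_<field>(idx) for every requested field."""
    values = {}
    for field in fields:
        method = getattr(dataset, f"get_{field}", None)
        if method is None:
            raise AttributeError(f"{type(dataset).__name__} has no method get_{field}")
        values[field] = method(idx)
    # Wrap the field values under the friendly name, like a Hyrax sample.
    return {friendly_name: values}


sample = fetch_fields(ToyDataset(), 2, ["flux", "object_id"])
print(sample)  # {'data': {'flux': [2.0, 4.0], 'object_id': 'obj-00002'}}
```

Because dispatch is driven purely by the method name, adding a new field to a dataset only requires adding one more get_<field> method.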

This notebook shows a notebook-first workflow:

  1. Write the minimum required methods for datasets with fixed-shape fields.

  2. Wire the class into Hyrax with h.prepare() and inspect a sample.

  3. Add a variable-length field (light curves) and a custom collate function to pad sequences to a uniform batch shape.

  4. Move the finished class to a standalone package for production use.

For a complete method-by-method reference, see :doc:`/dataset_class_reference`.

Step 1 — Define the minimum required methods

Every Hyrax dataset class must:

  • Inherit from HyraxDataset.

  • Implement __len__ so Hyrax knows how many objects are in your dataset.

  • Implement get_<field>(self, idx) for each field listed in data_request. For example, if fields contains "image", Hyrax will call get_image(idx).

  • Call super().__init__(config) at the end of __init__, after your own attributes are set.
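Before involving Hyrax, you can check the contract above with plain Python. The sketch below assumes nothing about Hyrax internals: check_dataset_contract is a hypothetical helper, and ToySurvey is a hypothetical class that does not inherit from HyraxDataset, so the cell runs without Hyrax installed. It confirms that __len__ works and that every requested field has a callable get_<field> method that returns a value for index 0.

```python
import numpy as np


class ToySurvey:
    """Hypothetical class following the get_<field> naming convention."""

    def __init__(self, n_objects=4):
        self.images = np.zeros((n_objects, 3, 8, 8), dtype=np.float32)

    def __len__(self):
        return len(self.images)

    def get_image(self, idx):
        return self.images[idx]

    def get_object_id(self, idx):
        return f"obj-{idx:05d}"


def check_dataset_contract(dataset, fields):
    """Hypothetical helper: report problems with the get_<field> contract."""
    problems = []
    if len(dataset) == 0:
        problems.append("__len__ returned 0")
    for field in fields:
        method = getattr(dataset, f"get_{field}", None)
        if not callable(method):
            problems.append(f"missing get_{field}(self, idx)")
        elif len(dataset) > 0:
            method(0)  # Call once so loading errors surface early.
    return problems


print(check_dataset_contract(ToySurvey(), ["image", "object_id", "label"]))
# ['missing get_label(self, idx)']
```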

The class below uses randomly-generated NumPy arrays in place of real files so the notebook runs without any data on disk.

[ ]:
import numpy as np
from hyrax.datasets import HyraxDataset


class NotebookSurveyDataset(HyraxDataset):
    def __init__(self, config, data_location=None):
        # We control the number of objects through the config; Step 2 shows how to set it.
        n_objects = config.get("n_objects", 64)

        # In a real dataset, use data_location to find your files and
        # load catalogs or file-path lists here (not the heavy per-object data).
        rng = np.random.default_rng(7)
        self.images = rng.normal(size=(n_objects, 3, 32, 32)).astype(np.float32)
        self.labels = rng.integers(0, 5, size=n_objects, dtype=np.int64)
        # Always call super().__init__ last so the base class can inspect your
        # attributes safely.
        super().__init__(config)

    def __len__(self):
        return len(self.images)

    # Each get_* method loads one field for one object.
    # The method name must match the field name in data_request["fields"].
    def get_image(self, idx: int) -> np.ndarray:
        return self.images[idx]

    def get_label(self, idx: int) -> np.int64:
        return self.labels[idx]

    def get_object_id(self, idx: int) -> str:
        return f"obj-{idx:05d}"

Step 2 — Wire the dataset into Hyrax and inspect a sample

Use h.set_config("data_request", ...) to tell Hyrax which dataset class to use and which fields to request.

The data_request dict has the structure:

{
    "<step_name>": {
        "<friendly_name>": {
            "dataset_class": "<ClassName or fully-qualified path>",
            "data_location": "/path/to/your/data",
            "dataset_config": {"<setting>": "<value>", ...},
            "fields": ["<field1>", "<field2>", ...],
            "primary_id_field": "<field_that_uniquely_identifies_each_object>",
        }
    }
}
  • ``<step_name>``: one of "train", "test", "validate", or "infer". Separate entries let you give each ML step its own data specification.

  • ``<friendly_name>``: a label you choose (e.g. "data"), used as the key when accessing batch data in your model. For multi-modal data where each modality has its own dataset class, add one of these stanzas per dataset class.

  • ``data_location``: where your data lives on the filesystem or network. Hyrax also uses it to tell data sources apart when several of them use the same dataset class with different underlying files.

  • ``dataset_config``: dataset-specific settings, such as n_objects in the examples in this notebook.

  • ``fields``: the field names whose get_* methods Hyrax will call for each object.

  • ``primary_id_field``: the field whose value uniquely identifies an object; used for dataset splitting and result lookup.
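As a sketch of how these pieces combine, the dict below describes a hypothetical multi-modal setup: two friendly names ("images" and "spectra") under "train", plus an "infer" step. The class names, paths, and the check_data_request helper are hypothetical; the helper only checks the step names and a couple of the keys described above, and is not part of Hyrax.

```python
VALID_STEPS = {"train", "test", "validate", "infer"}
REQUIRED_KEYS = ("dataset_class", "fields")

data_request = {
    "train": {
        "images": {
            "dataset_class": "my_package.datasets.ImageDataset",  # hypothetical
            "data_location": "/path/to/images",
            "fields": ["image", "object_id"],
            "primary_id_field": "object_id",
        },
        "spectra": {
            "dataset_class": "my_package.datasets.SpectrumDataset",  # hypothetical
            "data_location": "/path/to/spectra",
            "fields": ["spectrum", "object_id"],
            "primary_id_field": "object_id",
        },
    },
    "infer": {
        "images": {
            "dataset_class": "my_package.datasets.ImageDataset",  # hypothetical
            "data_location": "/path/to/new_images",
            "fields": ["image", "object_id"],
            "primary_id_field": "object_id",
        },
    },
}


def check_data_request(request):
    """Hypothetical helper: list structural problems in a data_request dict."""
    problems = []
    for step, stanzas in request.items():
        if step not in VALID_STEPS:
            problems.append(f"unknown step {step!r}")
        for friendly_name, spec in stanzas.items():
            for key in REQUIRED_KEYS:
                if key not in spec:
                    problems.append(f"{step}/{friendly_name}: missing {key!r}")
    return problems


print(check_data_request(data_request))  # []
```

In a batch from this hypothetical train step, each sample would carry one entry per friendly name, so a model would read sample["images"] and sample["spectra"] separately.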

h.prepare() instantiates the dataset and returns one DataProvider per step, keyed by step name (for example prepared["train"]). You can index a DataProvider like a list to retrieve individual samples.

[ ]:
from hyrax import Hyrax

h = Hyrax()
h.set_config(
    "data_request",
    {
        "train": {
            "data": {  # friendly name — can be any string
                "dataset_class": "NotebookSurveyDataset",
                "data_location": "fake/location/because/data/is/randomly/generated",
                "dataset_config": {"n_objects": 32},
                "fields": ["image", "label", "object_id"],
                "primary_id_field": "object_id",
            }
        }
    },
)

prepared = h.prepare()

# prepared holds one data provider per step; prepared["train"] is the training data provider.
# Each item you index from a data provider is a dict keyed by friendly name, so
# prepared["train"][0]["data"] contains the field values for the first object.
first = prepared["train"][0]
print("top-level keys:", list(first.keys()))
print("data keys:  ", list(first["data"].keys()))
print("image shape:", first["data"]["image"].shape)

Step 3 — Handle variable-length fields with a custom collate function

Fields that return a different-length array for each object (such as light curves sampled at irregular cadences) cannot be stacked into a single tensor directly. You need to pad the sequences to a common length and provide a mask that tells your model which entries are real data (1) and which are padding (0).
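The padding-and-mask idea fits in a few lines of NumPy. pad_with_mask is a hypothetical helper; like the collate method later in this step, it pads to a fixed max_len and marks real entries with 1.0 and padding with 0.0. It raises an error when a sequence is longer than max_len instead of silently truncating it.

```python
import numpy as np


def pad_with_mask(sequences, max_len):
    """Hypothetical helper: pad 1-D sequences to max_len and build a 0/1 mask."""
    padded = np.zeros((len(sequences), max_len), dtype=np.float32)
    mask = np.zeros((len(sequences), max_len), dtype=np.float32)
    for i, seq in enumerate(sequences):
        n = len(seq)
        if n > max_len:
            raise ValueError(f"sequence {i} has length {n} > max_len={max_len}")
        padded[i, :n] = seq
        mask[i, :n] = 1.0  # 1 = real data, 0 = padding
    return padded, mask


curves = [np.array([0.5, 0.7]), np.array([1.0, 2.0, 3.0])]
padded, mask = pad_with_mask(curves, max_len=4)
print(padded.shape, mask.sum(axis=1).tolist())  # (2, 4) [2.0, 3.0]
```

Each row of the mask sums to the original length of that sequence, so the lengths are recoverable from the batch.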

Hyrax supports this through an optional collate(self, samples) method on your dataset class. When present, Hyrax delegates batch assembly to your method instead of using its default stacking logic.

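The samples argument is a list of per-object dicts for this dataset, each shaped {"field1": value, "field2": value, ...}, without the friendly-name wrapper. Your method returns one dict with a batched array per key, and it may add extra keys such as light_curve_mask. When you collate through the DataProvider, as later in this step, the result comes back wrapped under the friendly name, so you read it as padded["data"]["light_curve"].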
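One way to picture how a per-dataset collate fits into batches keyed by friendly name is the sketch below. collate_by_friendly_name is a hypothetical helper, not Hyrax's DataProvider code: it pulls each sample's inner dict for one friendly name, hands that list to the dataset's collate, and wraps the result back under the same name, matching the padded["data"][...] access used later in this step. ToyCollateDataset is also hypothetical.

```python
import numpy as np


class ToyCollateDataset:
    """Hypothetical dataset whose collate stacks a single 'value' field."""

    def collate(self, samples):
        # samples: list of flat dicts such as {"value": 1.0}
        return {"value": np.array([s["value"] for s in samples], dtype=np.float32)}


def collate_by_friendly_name(batch, datasets):
    """Hypothetical helper: unwrap per friendly name, collate, and re-wrap."""
    result = {}
    for friendly_name, dataset in datasets.items():
        inner = [sample[friendly_name] for sample in batch]
        result[friendly_name] = dataset.collate(inner)
    return result


batch = [{"data": {"value": 1.0}}, {"data": {"value": 2.5}}]
out = collate_by_friendly_name(batch, {"data": ToyCollateDataset()})
print(out["data"]["value"].tolist())  # [1.0, 2.5]
```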

The class below extends the previous example with get_light_curve and a collate method that pads every curve to max_len (set through dataset_config) and returns a matching light_curve_mask.

[ ]:
import numpy as np
from hyrax.datasets import HyraxDataset


class NotebookSurveyDatasetWithLightCurves(HyraxDataset):
    def __init__(self, config, data_location=None):
        n_objects = config.get("n_objects", 64)
        min_len = config.get("min_len", 20)
        self.max_len = config.get("max_len", 100)

        rng = np.random.default_rng(11)
        self.images = rng.normal(size=(n_objects, 3, 32, 32)).astype(np.float32)
        self.labels = rng.integers(0, 5, size=n_objects, dtype=np.int64)
        # Each light curve has a different number of time steps.
        lengths = rng.integers(min_len, self.max_len + 1, size=n_objects)
        self.light_curves = [rng.normal(size=int(n)).astype(np.float32) for n in lengths]
        super().__init__(config)

    def __len__(self):
        return len(self.images)

    def get_image(self, idx: int) -> np.ndarray:
        return self.images[idx]

    def get_label(self, idx: int) -> np.int64:
        return self.labels[idx]

    def get_light_curve(self, idx: int) -> np.ndarray:
        return self.light_curves[idx]  # variable-length 1-D array

    def get_object_id(self, idx: int) -> str:
        return f"obj-{idx:05d}"

    def collate(self, samples: list[dict]) -> dict:
        # samples is a list of dicts, each shaped {"field": value, ...}.
        data_keys = list(samples[0].keys())

        # Collate only the fields present in the samples, so this method works for
        # any subset of fields in data_request. Each if block adds entries to retval.
        retval = {}

        if "image" in data_keys:
            images = np.stack([s["image"] for s in samples], axis=0).astype(np.float32)
            retval["image"] = images

        if "label" in data_keys:
            labels = np.array([s["label"] for s in samples], dtype=np.int64)
            retval["label"] = labels

        if "light_curve" in data_keys:
            curves = [s["light_curve"] for s in samples]

            # Pad all light curves to the same length and record which entries are real.
            light_curve = np.zeros((len(curves), self.max_len), dtype=np.float32)
            light_curve_mask = np.zeros((len(curves), self.max_len), dtype=np.float32)
            for i, c in enumerate(curves):
                n = len(c)
                light_curve[i, :n] = c
                light_curve_mask[i, :n] = 1.0  # 1 = real data, 0 = padding

            retval["light_curve"] = light_curve
            retval["light_curve_mask"] = light_curve_mask

        if "object_id" in data_keys:
            object_ids = np.array([s["object_id"] for s in samples], dtype=str)
            retval["object_id"] = object_ids

        return retval
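The collate above always pads to max_len, which gives every batch the same shape. If you prefer less padding, a common alternative is to pad each batch only to its longest curve; batch shapes then vary from batch to batch, which your model must tolerate. Here is a minimal standalone sketch of that variant for the light_curve field; collate_light_curves_dynamic is a hypothetical function you could adapt into your own collate method.

```python
import numpy as np


def collate_light_curves_dynamic(samples):
    """Hypothetical variant: pad light curves to the longest one in this batch."""
    curves = [s["light_curve"] for s in samples]
    longest = max(len(c) for c in curves)
    light_curve = np.zeros((len(curves), longest), dtype=np.float32)
    light_curve_mask = np.zeros((len(curves), longest), dtype=np.float32)
    for i, c in enumerate(curves):
        light_curve[i, : len(c)] = c
        light_curve_mask[i, : len(c)] = 1.0  # 1 = real data, 0 = padding
    return {"light_curve": light_curve, "light_curve_mask": light_curve_mask}


samples = [
    {"light_curve": np.ones(3, dtype=np.float32)},
    {"light_curve": np.ones(7, dtype=np.float32)},
]
batch = collate_light_curves_dynamic(samples)
print(batch["light_curve"].shape)  # (2, 7)
```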

[ ]:
h2 = Hyrax()
h2.set_config(
    "data_request",
    {
        "train": {
            "data": {
                "dataset_class": "NotebookSurveyDatasetWithLightCurves",
                "data_location": "fake/location/because/data/is/randomly/generated",
                "dataset_config": {
                    "n_objects": 32,
                    "max_len": 50,
                },
                "fields": ["image", "label", "light_curve", "object_id"],
                "primary_id_field": "object_id",
            }
        }
    },
)

prepared = h2.prepare()
prepared_train = prepared["train"]

# Manually collate four samples to verify the padded shapes.
raw_samples = [prepared_train[i] for i in range(4)]
padded = prepared_train.collate(raw_samples)

# light_curve shape: (batch_size, max_len)
print("light_curve shape:     ", padded["data"]["light_curve"].shape)
print("light_curve_mask shape:", padded["data"]["light_curve_mask"].shape)
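Inside your model, the mask keeps padding from leaking into computations. For example, a masked mean over time divides by the number of real entries rather than by max_len. This NumPy sketch is illustrative; masked_mean is a hypothetical helper, and in a PyTorch model you would write the same arithmetic with tensors.

```python
import numpy as np


def masked_mean(values, mask):
    """Hypothetical helper: per-row mean over entries where mask == 1."""
    counts = mask.sum(axis=1)
    # Guard against rows with no real entries to avoid dividing by zero.
    return (values * mask).sum(axis=1) / np.maximum(counts, 1.0)


light_curve = np.array([[2.0, 4.0, 0.0, 0.0], [1.0, 1.0, 1.0, 5.0]], dtype=np.float32)
light_curve_mask = np.array([[1, 1, 0, 0], [1, 1, 1, 1]], dtype=np.float32)
print(masked_mean(light_curve, light_curve_mask).tolist())  # [3.0, 2.0]
```

A plain mean of the first row would give 1.5, because the two padding zeros would be counted as observations.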

Step 4 — Move the class into a standalone package

Once the class works in the notebook, move it into a Python package so it can be reused across projects and referenced by its fully-qualified import path in Hyrax config:

"dataset_class": "my_package.datasets.my_dataset.MyDataset"
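A fully-qualified path like this is an ordinary Python import path: everything before the last dot names a module, and the last part names a class in that module. The sketch below shows the standard way such a string can be resolved with importlib; it illustrates the general mechanism and is not necessarily how Hyrax does it internally. It is also a quick way to confirm that your package is installed and importable before pointing Hyrax at it. resolve_class is a hypothetical helper.

```python
import importlib


def resolve_class(path):
    """Hypothetical helper: turn 'package.module.ClassName' into the class object."""
    module_name, _, class_name = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Try it with a standard-library class; swap in your own dataset path.
cls = resolve_class("collections.OrderedDict")
print(cls.__name__)  # OrderedDict
```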

See :doc:`/external_library_package` for step-by-step instructions on creating the package structure and installing it locally, and :doc:`/pre_executed/external_model_class` for the companion notebook that shows how to define a model class following the same workflow.