Custom dataset collation
In this notebook we create a custom collation function for the HyraxRandomDataset.
What is collation? When a DataLoader assembles individual samples into a batch, it calls a collate function to combine them. Hyrax provides a default implementation, but custom collation is required when samples have non-uniform shapes. A common example is light-curve data: different objects may have different numbers of photometric observations, so shorter sequences must be padded to a common length and a Boolean mask must indicate which values are padding. The model then uses the mask to ignore padded positions.
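To make the padding idea concrete, here is a minimal sketch that assumes each light curve is a 1-D NumPy array of flux values. The helper `pad_and_mask` is hypothetical and is not part of Hyrax. It pads every sequence with zeros to the length of the longest one and returns a Boolean mask that is True at padded positions, which is the convention this notebook uses.

```python
import numpy as np


def pad_and_mask(sequences, pad_value=0.0):
    """Pad 1-D sequences to a common length and build a padding mask.

    Hypothetical helper for illustration only; not part of Hyrax.
    Returns (padded, mask), where mask is True at padded positions.
    """
    batch_size = len(sequences)
    max_len = max(len(seq) for seq in sequences)
    padded = np.full((batch_size, max_len), pad_value, dtype=np.float32)
    # Start with every position flagged as padding, then clear the real ones.
    mask = np.ones((batch_size, max_len), dtype=bool)
    for i, seq in enumerate(sequences):
        n = len(seq)
        padded[i, :n] = seq
        mask[i, :n] = False
    return padded, mask


# Toy light curves with 3, 5, and 2 flux measurements (made-up values).
light_curves = [
    np.array([1.0, 1.5, 0.75]),
    np.array([2.0, 2.25, 2.5, 2.25, 2.0]),
    np.array([0.5, 0.25]),
]
padded, mask = pad_and_mask(light_curves)
print(padded.shape)      # (3, 5)
print(mask[2].tolist())  # [False, False, True, True, True]
```

Because the mask records which entries are padding, a zero pad value is harmless as long as downstream code consults the mask.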
For simplicity we use a dataset with a uniform shape, so we can focus on the mechanics of creating a custom collate function rather than the padding logic. A mask is still created, but all its values will be False (no padding needed).
We begin by creating a Hyrax instance.
[1]:
from hyrax import Hyrax
h = Hyrax()
Next we define a data request. The data request is a nested dictionary that tells Hyrax which datasets to load and which fields to expose. The top-level key "train" names the data split; inside it, "data" is the friendly name assigned to this particular dataset. The friendly name appears as a key in every sample returned by the DataProvider, so choosing a descriptive name (e.g. "spectra" or "images") helps keep multi-dataset workflows readable.
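The nesting (split, then friendly name, then dataset settings) can be seen by simply walking the dictionary. This sketch rebuilds the same request used in the next cell and applies a hypothetical helper, `describe_request` (not a Hyrax API), to list the fields each friendly name exposes in each split.

```python
# Same structure as the data request defined in the next cell.
data_request = {
    "train": {
        "data": {
            "dataset_class": "HyraxRandomDataset",
            "data_location": "./data",
            "fields": ["object_id", "image", "label"],
            "split_fraction": 1.0,
            "primary_id_field": "object_id",
        },
    }
}


def describe_request(request):
    """Return {split: {friendly_name: fields}} for a nested data request.

    Hypothetical helper for illustration; not a Hyrax API.
    """
    summary = {}
    for split, datasets in request.items():
        summary[split] = {
            friendly_name: list(spec.get("fields", []))
            for friendly_name, spec in datasets.items()
        }
    return summary


print(describe_request(data_request))
# {'train': {'data': ['object_id', 'image', 'label']}}
```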
[2]:
data_request = {
    "train": {
        "data": {
            "dataset_class": "HyraxRandomDataset",
            "data_location": "./data",
            "fields": ["object_id", "image", "label"],
            "split_fraction": 1.0,
            "primary_id_field": "object_id",
        },
    }
}

h.set_config("data_request", data_request)
Add a custom collation function
A custom collate static method can be attached directly to the dataset class. Hyrax detects it automatically when the dataset is prepared and uses it in place of the default collation logic.
The function receives a list of per-sample field dictionaries for the dataset's friendly name: each element corresponds to dataset[i]["data"] for one item in the batch. It should return a single dictionary that maps each field name to a batched array.
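Later in this notebook, calling collate on full samples returns results nested under the friendly name. A minimal sketch of that kind of dispatch, assuming a mapping from friendly name to per-dataset collate function, is shown below. `collate_by_friendly_name`, `collate_data`, and the toy samples are hypothetical illustrations, not Hyrax's DataProvider code.

```python
import numpy as np


def collate_data(samples):
    # Per-dataset collate: each element is one sample's field dictionary.
    collated = {}
    if "label" in samples[0]:
        collated["label"] = np.array([s["label"] for s in samples])
    return collated


def collate_by_friendly_name(full_samples, collate_fns):
    """Group full samples by friendly name and collate each group.

    Hypothetical sketch; not Hyrax's DataProvider implementation.
    collate_fns maps friendly name -> per-dataset collate function.
    """
    batch = {}
    for name, fn in collate_fns.items():
        # fn sees sample[name] (e.g. dataset[i]["data"]), not the full sample.
        batch[name] = fn([sample[name] for sample in full_samples])
    return batch


# Toy samples shaped like train_dataset[i] in this notebook (made-up values).
full_samples = [
    {"data": {"object_id": "a", "label": 0}, "object_id": "a"},
    {"data": {"object_id": "b", "label": 1}, "object_id": "b"},
]
batch = collate_by_friendly_name(full_samples, {"data": collate_data})
print(batch["data"]["label"].tolist())  # [0, 1]
```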
Important: guard against missing fields. A dataset class may expose more fields than a user requests in their data_request, and the collate function is always called when it exists, even if only a subset of fields was requested. To avoid KeyError exceptions when a field is absent, wrap each field's collation in an if check on the first sample (e.g. if "image" in samples[0]); this assumes every sample in a batch carries the same set of fields. Without these guards the collate function will fail whenever a user omits a field from their request.
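The effect of the guards can be checked without Hyrax. This sketch applies the same guarded pattern used in this notebook to toy samples that lack an `image` field (simulating a request for only `object_id` and `label`) and contrasts it with a hypothetical unguarded version that raises `KeyError`.

```python
import numpy as np


def guarded_collate(samples):
    # Same guarded pattern as the collate function in this notebook.
    collated = {}
    if "object_id" in samples[0]:
        collated["object_id"] = np.array([s["object_id"] for s in samples])
    if "image" in samples[0]:
        collated["image"] = np.stack([s["image"] for s in samples], axis=0)
        collated["mask"] = np.zeros_like(collated["image"], dtype=bool)
    if "label" in samples[0]:
        collated["label"] = np.array([s["label"] for s in samples])
    return collated


def unguarded_collate(samples):
    # Hypothetical counter-example: assumes "image" is always present.
    return {"image": np.stack([s["image"] for s in samples], axis=0)}


# Simulate a request for only "object_id" and "label" (no "image").
partial = [{"object_id": "1", "label": 0}, {"object_id": "2", "label": 1}]
print(sorted(guarded_collate(partial)))  # ['label', 'object_id']

try:
    unguarded_collate(partial)
except KeyError as err:
    print("unguarded collate raised KeyError:", err)  # unguarded collate raised KeyError: 'image'
```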
In production the method would live in the dataset class file. Attaching it dynamically here, as shown below, is convenient for prototyping in a notebook.
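For comparison, the production layout simply defines `collate` as a `@staticmethod` inside the class body. `MyLightCurveDataset` below is a hypothetical stand-in: a real Hyrax dataset would inherit from Hyrax's dataset base class and implement its data-access methods, which are omitted here.

```python
import numpy as np


class MyLightCurveDataset:
    """Hypothetical dataset class, used only to show where ``collate`` lives.

    A real Hyrax dataset would inherit from Hyrax's dataset base class and
    implement its data-access methods; those parts are omitted here.
    """

    @staticmethod
    def collate(samples: list[dict]) -> dict:
        # Guarded per-field collation, as recommended above.
        collated = {}
        if "label" in samples[0]:
            collated["label"] = np.array([s["label"] for s in samples])
        return collated


# A staticmethod can be called on the class itself; no instance is needed.
batch = MyLightCurveDataset.collate([{"label": 0}, {"label": 1}])
print(batch["label"].tolist())  # [0, 1]
```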
[3]:
from hyrax.datasets.random.hyrax_random_dataset import HyraxRandomDataset
import numpy as np


@staticmethod
def collate(samples: list[dict]) -> dict:
    """Collate a list of dictionaries into a single batch.

    This method takes a list of samples and collates them into a single batch.
    The returned batch dictionary will contain the following keys (when present
    in the requested fields):

    - ``object_id``: NumPy array of object IDs for the samples in the batch.
    - ``image``: NumPy array of stacked images for the samples in the batch.
    - ``mask``: NumPy array of masks with the same shape as the images. Derived
      from ``image`` and only present when ``image`` is requested.
    - ``label``: NumPy array of labels for the samples in the batch.

    Each field is guarded with an ``if`` check so that the function works
    correctly even when the user has not requested every field the dataset
    exposes.

    Parameters
    ----------
    samples : list of dict
        A list of samples to collate.

    Returns
    -------
    dict
        A single dictionary containing the collated data as NumPy arrays.
    """
    collated_data = {}

    if "object_id" in samples[0]:
        collated_data["object_id"] = np.array([sample["object_id"] for sample in samples])

    if "image" in samples[0]:
        # stack the images along a new leading (batch) axis
        collated_data["image"] = np.stack([sample["image"] for sample in samples], axis=0)
        # all-False Boolean mask with the same shape as the image stack;
        # uniformly shaped images need no padding, so nothing is masked
        collated_data["mask"] = np.zeros_like(collated_data["image"], dtype=bool)

    if "label" in samples[0]:
        collated_data["label"] = np.array([sample["label"] for sample in samples])

    return collated_data


# Attach the collation function to the HyraxRandomDataset class
HyraxRandomDataset.collate = collate
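Attaching a `staticmethod` to a class after the class is defined is plain Python, so the mechanism can be tried in isolation. The sketch assumes a detection rule of the form "use the class's `collate` if it exists, otherwise fall back to a default"; `pick_collate`, `fallback_collate`, and `ToyDataset` are hypothetical, and Hyrax's actual detection logic may differ.

```python
def fallback_collate(samples):
    # Stand-in for a framework's default collation (hypothetical).
    return {"n_samples": len(samples)}


def pick_collate(dataset_cls):
    # Hypothetical detection rule: prefer the class's own collate, if any.
    return getattr(dataset_cls, "collate", fallback_collate)


class ToyDataset:
    """Hypothetical dataset class defined without a collate method."""


print(pick_collate(ToyDataset) is fallback_collate)  # True


@staticmethod
def collate(samples):
    return {"labels": [s["label"] for s in samples]}


# Attach after the class is defined, as the notebook does with HyraxRandomDataset.
ToyDataset.collate = collate

chosen = pick_collate(ToyDataset)
print(chosen([{"label": 0}, {"label": 1}]))  # {'labels': [0, 1]}
# As a staticmethod it receives no implicit self, even when called via an instance.
print(ToyDataset().collate([{"label": 2}]))  # {'labels': [2]}
```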
Prepare the dataset
h.prepare() instantiates the requested dataset classes and returns a dictionary keyed by split name. It also picks up the custom collate method we just attached to HyraxRandomDataset.
[4]:
dataset = h.prepare()
# Access the "train" data group for clarity in the next steps
train_dataset = dataset["train"]
[2026-04-23 22:24:09,049 hyrax.verbs.prepare:INFO] Finished Prepare
Inspect individual samples
We request two samples by index and print their field types and shapes. Each sample is a dictionary whose keys are the friendly name ("data") plus a top-level "object_id" entry. The fields we requested live under sample["data"].
[5]:
sample_0 = train_dataset[0]
sample_1 = train_dataset[1]
[6]:
print("Type and shape / value of fields in sample 0:")
# Fields are nested under the friendly name "data"
print(f"object_id type: {type(sample_0['data']['object_id'])}, value: {sample_0['data']['object_id']}")
print(f"image type: {type(sample_0['data']['image'])}, shape: {sample_0['data']['image'].shape}")
print(f"label type: {type(sample_0['data']['label'])}, value: {sample_0['data']['label']}")
print("\nSample 0")
print(sample_0)
Type and shape / value of fields in sample 0:
object_id type: <class 'str'>, value: 19
image type: <class 'numpy.ndarray'>, shape: (2, 5, 5)
label type: <class 'numpy.int64'>, value: 0
Sample 0
{'data': {'object_id': '19', 'image': array([[[0.08925092, 0.773956 , 0.6545715 , 0.43887842, 0.43301523],
[0.8585979 , 0.08594561, 0.697368 , 0.20146948, 0.09417731],
[0.52647895, 0.9756223 , 0.73575234, 0.7611397 , 0.71747726],
[0.78606427, 0.51322657, 0.12811363, 0.8397482 , 0.45038593],
[0.5003519 , 0.370798 , 0.1825496 , 0.92676497, 0.78156745]],
[[0.6438651 , 0.40241432, 0.8227616 , 0.5454291 , 0.44341415],
[0.45045954, 0.22723871, 0.09213591, 0.55458474, 0.8878898 ],
[0.0638172 , 0.85829127, 0.8276311 , 0.27675968, 0.6316644 ],
[0.16522902, 0.7580877 , 0.70052296, 0.35452592, 0.06791997],
[0.970698 , 0.44568747, 0.89312106, 0.677919 , 0.7783835 ]]],
dtype=float32), 'label': np.int64(0)}, 'object_id': '19'}
Collate two samples manually
We call train_dataset.collate() directly to see the result. During training, the PyTorch DataLoader calls this same method automatically to build each mini-batch.
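In PyTorch, a custom collation function is passed to `torch.utils.data.DataLoader` through its `collate_fn` argument. The simplified sketch below mimics only the core batching loop, without shuffling, samplers, or worker processes: consecutive samples are grouped into lists of `batch_size` and each list is handed to the collate function. `iterate_batches` and `collate_labels` are hypothetical and are neither PyTorch nor Hyrax functions.

```python
def iterate_batches(dataset, batch_size, collate_fn):
    """Yield collated batches in index order.

    Hypothetical, simplified sketch of how a DataLoader uses collate_fn
    (no shuffling, samplers, or worker processes).
    """
    batch = []
    for idx in range(len(dataset)):
        batch.append(dataset[idx])
        if len(batch) == batch_size:
            yield collate_fn(batch)
            batch = []
    if batch:  # keep the final short batch, like DataLoader's drop_last=False
        yield collate_fn(batch)


def collate_labels(samples):
    return {"label": [s["label"] for s in samples]}


toy_dataset = [{"label": i} for i in range(5)]
batches = list(iterate_batches(toy_dataset, batch_size=2, collate_fn=collate_labels))
print(batches)
# [{'label': [0, 1]}, {'label': [2, 3]}, {'label': [4]}]
```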
[7]:
collated = train_dataset.collate([sample_0, sample_1])
[8]:
print("Collated types and shapes from samples 0 and 1:")
# Results are nested under the friendly name "data", mirroring the sample structure
print(f"image type: {type(collated['data']['image'])}, shape: {collated['data']['image'].shape}")
print(f"object_id type: {type(collated['data']['object_id'])}, shape: {collated['data']['object_id'].shape}")
print(f"label type: {type(collated['data']['label'])}, shape: {collated['data']['label'].shape}")
print(f"mask type: {type(collated['data']['mask'])}, shape: {collated['data']['mask'].shape}")
Collated types and shapes from samples 0 and 1:
image type: <class 'numpy.ndarray'>, shape: (2, 2, 5, 5)
object_id type: <class 'numpy.ndarray'>, shape: (2,)
label type: <class 'numpy.ndarray'>, shape: (2,)
mask type: <class 'numpy.ndarray'>, shape: (2, 2, 5, 5)
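The leading 2 in the image and mask shapes is the batch axis created by `np.stack(..., axis=0)`. A quick comparison with `np.concatenate` and `np.vstack`, using two toy arrays with the same (2, 5, 5) shape as sample 0's image, shows why stacking (rather than concatenating) keeps each sample separate.

```python
import numpy as np

# Two toy arrays with the same shape as sample 0's image.
a = np.zeros((2, 5, 5), dtype=np.float32)
b = np.ones((2, 5, 5), dtype=np.float32)

stacked = np.stack([a, b], axis=0)     # adds a new leading batch axis
concatenated = np.concatenate([a, b])  # joins along the existing first axis
vstacked = np.vstack([a, b])           # also joins along the first axis

print(stacked.shape)       # (2, 2, 5, 5)
print(concatenated.shape)  # (4, 5, 5)
print(vstacked.shape)      # (4, 5, 5)
```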
The batch now contains a mask field with the same type and shape as image. All values are False here because the data was uniformly shaped and required no padding. In a real use case, positions that were padded would be True, signalling to the model that those values should be ignored.
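One way a model or loss can honor the mask is to exclude padded positions from reductions. This minimal sketch (the `masked_mean` helper and toy arrays are hypothetical) averages over the last axis while ignoring entries where the mask is True. A plain `values.mean(axis=-1)` would instead give 1.5 for the first row, because it counts the padded zero.

```python
import numpy as np


def masked_mean(values, mask):
    """Mean over the last axis, ignoring positions where mask is True (padding)."""
    valid = ~mask
    total = np.where(valid, values, 0.0).sum(axis=-1)
    count = valid.sum(axis=-1)
    # Avoid division by zero for rows that are entirely padding (they yield 0).
    return total / np.maximum(count, 1)


values = np.array([[1.0, 2.0, 3.0, 0.0],
                   [4.0, 0.0, 0.0, 0.0]])
mask = np.array([[False, False, False, True],
                 [False, True, True, True]])
print(masked_mean(values, mask).tolist())  # [2.0, 4.0]
```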
When HyraxRandomDataset is used in a training run, the DataLoader calls this same collate method automatically; no additional configuration is required.