Using a custom model class#
Hyrax comes with several built-in model architectures, but you can also define your own PyTorch model and register it with Hyrax. This is useful when you need a custom architecture, multimodal inputs, or other specialized behavior that the built-in models do not provide.
This notebook demonstrates:
- Registering a custom model with the @hyrax_model decorator so Hyrax can find and instantiate it.
- Training an image-only model using the familiar h.train() call.
- Training a multimodal model that combines image data with a variable-length light curve via a custom prepare_inputs method.
A small synthetic dataset is included in this notebook so it runs end-to-end without any external data. If you are new to writing Hyrax dataset classes, see external_dataset_class.ipynb first.
What you will need
- Hyrax installed (pip install hyrax)
- PyTorch (pip install torch)
1) Synthetic dataset#
To keep this notebook self-contained we define a small in-memory dataset. It generates random images, classification labels, and variable-length light curves. The collate_light_curve method pads the light curves to the same length within each batch and creates a corresponding binary mask (1 where real data exists, 0 where padding was added); other fields can use Hyrax’s default collation.
See external_dataset_class.ipynb for a step-by-step explanation of the dataset API.
[ ]:
import numpy as np
from hyrax.datasets import HyraxDataset


class NotebookSurveyDatasetWithLightCurves(HyraxDataset):
    def __init__(self, config, data_location=None):
        n_objects = config.get("n_objects", 64)
        min_len = config.get("min_len", 20)
        self.max_len = config.get("max_len", 100)
        rng = np.random.default_rng(11)
        self.images = rng.normal(size=(n_objects, 3, 32, 32)).astype(np.float32)
        self.labels = rng.integers(0, 5, size=n_objects, dtype=np.int64)
        # Each light curve has a different number of time steps.
        lengths = rng.integers(min_len, self.max_len + 1, size=n_objects)
        self.light_curves = [rng.normal(size=int(n)).astype(np.float32) for n in lengths]
        super().__init__(config)

    def __len__(self):
        return len(self.images)

    def get_image(self, idx: int) -> np.ndarray:
        return self.images[idx]

    def get_label(self, idx: int) -> np.int64:
        return self.labels[idx]

    def get_light_curve(self, idx: int) -> np.ndarray:
        return self.light_curves[idx]  # variable-length 1-D array

    def get_object_id(self, idx: int) -> str:
        return f"obj-{idx:05d}"

    def collate_light_curve(self, samples: list[dict]) -> dict:
        # samples is a list of dicts, each shaped {"field": value, ...}.
        retval = {}
        curves = [s["light_curve"] for s in samples]
        # Pad all light curves to the same length and record which entries are real.
        light_curve = np.zeros((len(curves), self.max_len), dtype=np.float32)
        light_curve_mask = np.zeros((len(curves), self.max_len), dtype=np.float32)
        for i, c in enumerate(curves):
            n = len(c)
            light_curve[i, :n] = c
            light_curve_mask[i, :n] = 1.0  # 1 = real data, 0 = padding
        retval["light_curve"] = light_curve
        retval["light_curve_mask"] = light_curve_mask
        return retval
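To see what the padding and mask look like on concrete values, here is a minimal standalone sketch of the same pad-and-mask idea in plain NumPy. The helper name pad_and_mask and the toy curves are illustrative assumptions and not part of Hyrax; like collate_light_curve above, the sketch pads to a fixed max_len.

```python
import numpy as np


def pad_and_mask(curves, max_len):
    """Pad 1-D arrays to max_len with zeros and return (padded, mask)."""
    padded = np.zeros((len(curves), max_len), dtype=np.float32)
    mask = np.zeros((len(curves), max_len), dtype=np.float32)
    for i, curve in enumerate(curves):
        n = len(curve)
        padded[i, :n] = curve
        mask[i, :n] = 1.0  # 1 = real data, 0 = padding
    return padded, mask


# Hypothetical toy light curves of lengths 2 and 4.
toy_curves = [
    np.array([0.5, 1.5], dtype=np.float32),
    np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
]
padded, mask = pad_and_mask(toy_curves, max_len=5)

print(padded.shape)      # (2, 5)
print(mask.sum(axis=1))  # number of real samples in each curve
```

Because the width is always max_len rather than the longest curve in the batch, every batch has the same second dimension, which matters later when a fixed-size linear layer consumes the padded array.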
2) Define and register a custom model#
A Hyrax-compatible model is a standard PyTorch nn.Module with four additional methods and the @hyrax_model decorator:

| Method | Purpose |
|---|---|
| prepare_inputs | Extracts the tensors your model needs from the batch dictionary returned by the dataset. Returns a tuple that is passed directly to the batch methods below. |
| train_batch | Runs one forward + backward pass, steps the optimizer, and returns a dictionary of scalar metrics (e.g. the loss). |
| validate_batch | Runs one forward pass and returns a dictionary of scalar metrics (e.g. the loss). |
| infer_batch | Runs a forward pass (no gradient) and returns the model output. |
The @hyrax_model decorator registers the class by name so Hyrax can find it when you set model.name.
The constructor receives config (the Hyrax config dict) and an optional data_sample — a single batch produced by prepare_inputs. The sample lets you infer input dimensions automatically instead of hardcoding them.
[ ]:
import torch
import torch.nn as nn
from hyrax.models.model_registry import hyrax_model


@hyrax_model  # registers the class so h.set_config("model.name", "ImageOnlyCNN") works
class ImageOnlyCNN(nn.Module):
    def __init__(self, config, data_sample=None):
        super().__init__()
        self.config = config
        # Unpack the sample to derive input dimensions automatically
        image_sample, _ = data_sample
        c, h, w = image_sample.shape[1:]  # channels, height, width
        self.net = nn.Sequential(
            nn.Conv2d(c, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(8 * h * w, 5),  # 5 output classes
        )

    def forward(self, batch: tuple[torch.Tensor, ...]) -> torch.Tensor:
        image, _ = batch
        return self.net(image)

    @staticmethod
    def prepare_inputs(data_dict: dict) -> tuple[torch.Tensor, ...]:
        """Pull the tensors this model needs out of the batch dictionary."""
        image = data_dict["data"]["image"]
        label = data_dict["data"]["label"]
        return image, label

    # self.optimizer and self.criterion are not defined in this class; the
    # notebook relies on Hyrax supplying them from the optimizer/criterion config.
    def train_batch(self, batch: tuple[torch.Tensor, ...]) -> dict:
        _, labels = batch
        self.optimizer.zero_grad()
        logits = self.forward(batch)
        loss = self.criterion(logits, labels)
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}

    def validate_batch(self, batch: tuple[torch.Tensor, ...]) -> dict:
        _, labels = batch
        logits = self.forward(batch)
        loss = self.criterion(logits, labels)
        return {"loss": loss.item()}

    def infer_batch(self, batch: tuple[torch.Tensor, ...]) -> torch.Tensor:
        return self.forward(batch)
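The nn.Linear(8 * h * w, 5) layer above relies on the convolution preserving the spatial size of the image. As a quick check of that arithmetic, the sketch below applies the standard convolution output-size formula (dilation 1). The helper conv2d_out_size is a hypothetical name introduced here for illustration; it is not a PyTorch or Hyrax function.

```python
def conv2d_out_size(size: int, kernel_size: int, padding: int = 0, stride: int = 1) -> int:
    """Output length along one spatial axis of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# Values from ImageOnlyCNN: 32x32 images, kernel_size=3, padding=1, 8 output channels.
h_out = conv2d_out_size(32, kernel_size=3, padding=1)
w_out = conv2d_out_size(32, kernel_size=3, padding=1)
flattened = 8 * h_out * w_out

print(h_out, w_out)  # 32 32
print(flattened)     # 8192
```

With a different kernel size or padding the spatial size would shrink, and the input size of the linear layer would have to be computed from the convolution output rather than directly from h and w.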
3) Train the image-only model#
The training configuration below uses a standard cross-entropy loss and Adam optimizer. The data_request block tells Hyrax which dataset to use and how to split it into training and validation sets.
- split_fraction: 0.8 (train): use 80 % of the dataset objects for training.
- split_fraction: 0.2 (validate): use the remaining 20 % for validation.
Both train and validate groups must reference the same dataset and use split fractions that add up to ≤ 1.
[ ]:
from hyrax import Hyrax

h = Hyrax()
h.set_config("model.name", "ImageOnlyCNN")
h.set_config("train.epochs", 1)

# These are set to default values, uncomment to change.
# h.set_config("data_loader.batch_size", 512)
# h.set_config("criterion.name", "torch.nn.CrossEntropyLoss")
# h.set_config("optimizer.name", "torch.optim.Adam")
# h.set_config("'torch.optim.Adam'", {"lr": 1e-3})

h.set_config(
    "data_request",
    {
        "train": {
            "data": {
                "dataset_class": "NotebookSurveyDatasetWithLightCurves",
                "data_location": "/fake/location/data/is/randomly/generated",
                "fields": ["image", "label", "object_id"],
                "primary_id_field": "object_id",
                "split_fraction": 0.8,
            }
        },
        "validate": {
            "data": {
                "dataset_class": "NotebookSurveyDatasetWithLightCurves",
                "data_location": "/fake/location/data/is/randomly/generated",
                "fields": ["image", "label", "object_id"],
                "primary_id_field": "object_id",
                "split_fraction": 0.2,
            }
        },
    },
)

model = h.train()
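The train and validate groups above differ only in split_fraction, so one way to avoid repeating the block is to build it with a small helper. This is a sketch under assumptions: make_data_request is a hypothetical helper, not part of Hyrax, and it only mirrors the dictionary layout used in this notebook while checking that the two fractions add up to at most 1.

```python
def make_data_request(dataset_class, fields, train_fraction, validate_fraction,
                      data_location="/fake/location/data/is/randomly/generated"):
    """Build a data_request dict shaped like the one used in this notebook."""
    # Small tolerance so that sums such as 0.7 + 0.3 are not rejected by rounding.
    if train_fraction + validate_fraction > 1.0 + 1e-9:
        raise ValueError("split fractions must add up to at most 1")

    def group(fraction):
        return {
            "data": {
                "dataset_class": dataset_class,
                "data_location": data_location,
                "fields": list(fields),
                "primary_id_field": "object_id",
                "split_fraction": fraction,
            }
        }

    return {"train": group(train_fraction), "validate": group(validate_fraction)}


request = make_data_request(
    "NotebookSurveyDatasetWithLightCurves",
    ["image", "label", "object_id"],
    train_fraction=0.8,
    validate_fraction=0.2,
)
print(request["validate"]["data"]["split_fraction"])  # 0.2
```

The resulting dictionary could then be passed along in the same way as the literal above, with h.set_config("data_request", request).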
4) Extend to a multimodal model (image + light curve)#
Astronomical objects are often described by multiple data modalities — an image, a spectrum, a light curve, and so on. You can combine them all inside a single model by including every field in data_request and returning all required tensors from prepare_inputs.
Here we add a second encoder branch that processes the padded light curve. The binary mask from collate_light_curve is used to zero out the padded time steps so they do not influence the learned representation.
[ ]:
@hyrax_model
class ImageAndLightCurveModel(nn.Module):
    def __init__(self, config, data_sample=None):
        super().__init__()
        self.config = config
        image, light_curve, light_curve_mask, _ = data_sample
        c, h, w = image.shape[1:]
        lc_dim = light_curve.shape[1]  # padded sequence length
        print(f"LC_DIM={lc_dim}")

        # Image branch: simple CNN → 32-d embedding
        self.image_encoder = nn.Sequential(
            nn.Conv2d(c, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(8 * h * w, 32),
            nn.ReLU(),
        )

        # Light-curve branch: linear projection → 32-d embedding
        self.lc_encoder = nn.Sequential(
            nn.Linear(lc_dim, 32),
            nn.ReLU(),
        )

        # Final classifier over the concatenated 64-d representation
        self.classifier = nn.Linear(64, 5)

    def forward(self, batch: tuple[torch.Tensor, ...]) -> torch.Tensor:
        # Zero out padded time steps before encoding
        image, light_curve, light_curve_mask, _ = batch
        masked_lc = light_curve * light_curve_mask
        image_feat = self.image_encoder(image)
        lc_feat = self.lc_encoder(masked_lc)
        return self.classifier(torch.cat([image_feat, lc_feat], dim=1))

    @staticmethod
    def prepare_inputs(data_dict: dict) -> tuple[torch.Tensor, ...]:
        """Return all four tensors needed by this multimodal model."""
        d = data_dict["data"]
        return (d["image"], d["light_curve"], d["light_curve_mask"], d["label"])

    def train_batch(self, batch: tuple[torch.Tensor, ...]) -> dict:
        _, _, _, labels = batch
        self.optimizer.zero_grad()
        logits = self.forward(batch)
        loss = self.criterion(logits, labels)
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}

    def validate_batch(self, batch: tuple[torch.Tensor, ...]) -> dict:
        _, _, _, labels = batch
        logits = self.forward(batch)
        loss = self.criterion(logits, labels)
        return {"loss": loss.item()}

    def infer_batch(self, batch: tuple[torch.Tensor, ...]) -> torch.Tensor:
        return self.forward(batch)
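To illustrate what multiplying by the mask buys you, the following NumPy sketch pushes the same light curve through a linear projection twice: once with zeros in the padded slots and once with arbitrary junk there. The weights and values are made up for illustration; the matrix product stands in for the linear layer in lc_encoder and is not Hyrax code.

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.normal(size=(5, 3))  # stand-in for the weights of a 5 -> 3 linear layer

mask = np.array([1.0, 1.0, 1.0, 0.0, 0.0])               # 3 real steps, 2 padded
curve_zero_pad = np.array([0.2, -0.4, 0.9, 0.0, 0.0])
curve_junk_pad = np.array([0.2, -0.4, 0.9, 7.0, -3.0])   # same data, junk in the padding

out_zero = (curve_zero_pad * mask) @ weights
out_junk = (curve_junk_pad * mask) @ weights
out_unmasked = curve_junk_pad @ weights

# With the mask applied, the padding values have no effect on the projection.
print(np.allclose(out_zero, out_junk))  # True
# Without the mask, whatever sits in the padded slots leaks into the output.
print(np.allclose(out_zero, out_unmasked))
```

In this notebook collate_light_curve already pads with zeros, so the multiplication acts as a safeguard; it starts to matter if the padding value is ever changed.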
The only change to the data_request compared to the image-only run is the addition of "light_curve" to the fields list. Hyrax passes those fields through to the dataset’s collate_light_curve method, which produces light_curve and light_curve_mask in the batch.
[ ]:
h2 = Hyrax()
h2.set_config("model.name", "ImageAndLightCurveModel")
h2.set_config("train.epochs", 1)

# These are set to default values, uncomment to change.
# h2.set_config("data_loader.batch_size", 512)
# h2.set_config("criterion.name", "torch.nn.CrossEntropyLoss")
# h2.set_config("optimizer.name", "torch.optim.Adam")
# h2.set_config("'torch.optim.Adam'", {"lr": 1e-3})

h2.set_config(
    "data_request",
    {
        "train": {
            "data": {
                "dataset_class": "NotebookSurveyDatasetWithLightCurves",
                "data_location": "/fake/location/data/is/randomly/generated",
                "fields": ["image", "light_curve", "label", "object_id"],  # light_curve added
                "primary_id_field": "object_id",
                "split_fraction": 0.8,
            }
        },
        "validate": {
            "data": {
                "dataset_class": "NotebookSurveyDatasetWithLightCurves",
                "data_location": "/fake/location/data/is/randomly/generated",
                "fields": ["image", "light_curve", "label", "object_id"],
                "primary_id_field": "object_id",
                "split_fraction": 0.2,
            }
        },
    },
)

model = h2.train()
5) Moving your model into an external package#
Once the notebook version works, you can move the model (and dataset) class into an installable Python package so it can be reused across projects and shared with collaborators.
1. Copy both classes into your package, e.g. mypackage/models.py and mypackage/datasets.py.
2. Make sure the package is installed (pip install -e .) in the same environment as Hyrax.
3. Use the fully-qualified class name wherever you previously used the short name:

h.set_config("model.name", "mypackage.models.ImageAndLightCurveModel")
h.set_config("data_request", {
    "train": {
        "data": {
            "dataset_class": "mypackage.datasets.NotebookSurveyDatasetWithLightCurves",
            ...
        }
    },
    ...
})
Hyrax will import the class automatically at runtime — no extra registration step is required.
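For intuition about what resolving a fully-qualified name involves, here is a generic Python sketch using importlib. It illustrates the general mechanism only: resolve_class is a hypothetical helper, and Hyrax's own lookup code may work differently.

```python
import importlib


def resolve_class(qualified_name: str):
    """Import 'package.module.ClassName' and return the class object."""
    module_name, _, class_name = qualified_name.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {qualified_name!r}")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Demonstrated with a standard-library class so the snippet runs anywhere.
cls = resolve_class("collections.OrderedDict")
print(cls.__name__)  # OrderedDict
```

A lookup of this kind only succeeds if Python can find the package on its import path, which is why the package needs to be installed in the same environment as Hyrax.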