Model class reference#

This page is the ground truth for writing a Hyrax model class.

If you are not used to writing classes, follow this page method-by-method and copy/edit the examples.

How Hyrax uses your model class#

Hyrax does this sequence:

  1. Builds one sample batch from your dataset.

  2. Calls YourModel.prepare_inputs(sample_batch).

  3. Creates your model with YourModel(config, data_sample=prepared_sample).

  4. During runs, calls:

    • train_batch for training

    • infer_batch for inference

    • validate_batch for validation during training (if you use validate data)

    • test_batch for test (if you run the test verb)

Required methods (checklist)#

Your class must have these:

  1. Inherit from torch.nn.Module.

  2. Add @hyrax_model above the class (this registers your model so Hyrax can find it by name in the configuration).

  3. Implement:

    • __init__(self, config, data_sample=None)

    • forward(self, ...) (required by PyTorch modules)

    • infer_batch(self, batch)

    • train_batch(self, batch)

  4. Set model.name to full class path in config.

Method-by-method guide#

__init__(self, config, data_sample=None)#

What to do in this method:

  1. Save config to self.config.

  2. Read model-specific settings from config.

  3. Build your layers.

  4. Optionally use data_sample to infer shapes.

The model-specific config values are typically defined in your external library config layout. See External package setup for packaging/config setup.

Example (read settings from config):

def __init__(self, config, data_sample=None):
    super().__init__()
    self.config = config

    hidden = config["my_package"]["MyModel"]["hidden_size"]
    out_dim = config["my_package"]["MyModel"]["num_classes"]

    in_dim = 128
    if data_sample is not None:
        x, _ = data_sample
        in_dim = x.shape[-1]

    self.fc1 = nn.Linear(in_dim, hidden)
    self.fc2 = nn.Linear(hidden, out_dim)

To set those config values in a notebook, one simple pattern is:

h.config.setdefault("my_package", {}).setdefault("MyModel", {})["hidden_size"] = 64
h.config["my_package"]["MyModel"]["num_classes"] = 10

PyTorch reference for building nn.Module layers: torch.nn.Module

forward(self, ...)#

Even though Hyrax calls train_batch / infer_batch, your model should still implement forward because that is the core PyTorch pattern.

PyTorch reference for forward: Module.forward

Example:

def forward(self, x):
    x = self.fc1(x)
    x = nn.functional.relu(x)
    x = self.fc2(x)
    return x

@staticmethod prepare_inputs(data_dict)#

This converts collated dataset fields into model inputs. For background on where this sits in the pipeline, see Data Flow Through Hyrax.

Important:

  • define it as @staticmethod

  • return either: * one NumPy array, or * a tuple like (features, labels)

Example:

@staticmethod
def prepare_inputs(data_dict):
    features = data_dict["data"]["flux"]
    labels = data_dict["data"].get("label", None)
    return features, labels

If you do not define this method, Hyrax uses a default expecting data_dict["data"]["image"] and optional data_dict["data"]["label"].

infer_batch(self, batch)#

Used by the infer verb.

Usually this unpacks the batch and calls forward.

Example:

def infer_batch(self, batch):
    features, _ = batch
    return self.forward(features)

train_batch(self, batch)#

Used by the train verb.

Typical steps:

  1. unpack inputs

  2. zero gradients

  3. call forward

  4. compute loss

  5. backprop + optimizer step

  6. return a dictionary with "loss" and optional metrics

Example:

def train_batch(self, batch):
    features, labels = batch
    self.optimizer.zero_grad()
    logits = self.forward(features)
    loss = self.criterion(logits, labels)
    loss.backward()
    self.optimizer.step()
    return {"loss": loss.item(), "acc": 0.84}

Important details:

  • Hyrax treats lower ``loss`` as better for best-checkpoint selection.

  • Any extra metrics you return (like acc) are logged to TensorBoard and MLflow. See Using TensorBoard and MLFlow. For documentation on comparing runs across experiments, see Model comparison.

Optional methods#

validate_batch(self, batch)#

If the training run includes a validate dataset group, the train verb will run this method each epoch.

Recommended behavior: similar to train_batch but do not update weights.

Example:

def validate_batch(self, batch):
    features, labels = batch
    logits = self.forward(features)
    loss = self.criterion(logits, labels)
    return {"loss": loss.item(), "val_acc": 0.81}

test_batch(self, batch)#

Used by the test verb.

Return a dictionary with "loss" (plus optional metrics).

Example:

def test_batch(self, batch):
    features, labels = batch
    logits = self.forward(features)
    loss = self.criterion(logits, labels)
    return {"loss": loss.item(), "test_acc": 0.82}

log_epoch_metrics(self)#

Optional helper for extra per-epoch metrics during training.

If present, Hyrax calls it at epoch end and logs returned values to TensorBoard and MLflow. See Using TensorBoard and MLFlow.

Example:

def log_epoch_metrics(self):
    return {"my_metric": 0.42}

Optimizer / criterion / scheduler setup#

You have two ways to provide each training component:

  • Manual means you create it in your model class (for example, self.optimizer = torch.optim.Adam(...)).

  • Automatic means you set its name in config and Hyrax creates it for you.

These choices are independent. You can mix them.

Example: manual optimizer + automatic criterion/scheduler

# in model __init__ (manual optimizer)
self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

# in config (automatic criterion + scheduler)
h.set_config("criterion.name", "torch.nn.CrossEntropyLoss")
h.set_config("scheduler.name", "torch.optim.lr_scheduler.StepLR")

This mixed mode is valid. Optimizer/criterion/scheduler do not all need to use exactly the same style.

Complete minimal class#

import torch.nn as nn
from hyrax.models.model_registry import hyrax_model


@hyrax_model
class MyModel(nn.Module):
    def __init__(self, config, data_sample=None):
        super().__init__()
        self.config = config
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc(x)

    @staticmethod
    def prepare_inputs(data_dict):
        features = data_dict["data"]["flux"]
        labels = data_dict["data"].get("label", None)
        return features, labels

    def infer_batch(self, batch):
        features, _ = batch
        return self.forward(features)

    def train_batch(self, batch):
        features, labels = batch
        self.optimizer.zero_grad()
        out = self.forward(features)
        loss = self.criterion(out, labels)
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}

Notebook-first path#

Start with Using a custom model class, then move the class into an external package when it works.