Model class reference#
This page is the ground truth for writing a Hyrax model class.
If you are not used to writing classes, follow this page method-by-method and copy/edit the examples.
How Hyrax uses your model class#
Hyrax runs this sequence:
1. Builds one sample batch from your dataset.
2. Calls YourModel.prepare_inputs(sample_batch).
3. Creates your model with YourModel(config, data_sample=prepared_sample).
4. During runs, calls:
   - train_batch for training
   - infer_batch for inference
   - validate_batch for validation during training (if you use validate data)
   - test_batch for test (if you run the test verb)
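The call order above can be mimicked with a small stand-in class. This is only an illustration of the documented sequence, not Hyrax's actual driver code: `ToyModel` and `run_sequence` are hypothetical names, plain lists stand in for tensors, and calling `prepare_inputs` on every batch is an assumption.

```python
class ToyModel:
    """Stand-in model that records which hooks were called, in order."""

    calls = []  # class-level log so the static method can write to it too

    def __init__(self, config, data_sample=None):
        ToyModel.calls.append("__init__")
        self.config = config
        self.data_sample = data_sample

    @staticmethod
    def prepare_inputs(data_dict):
        ToyModel.calls.append("prepare_inputs")
        return data_dict["data"]["flux"], data_dict["data"].get("label")

    def train_batch(self, batch):
        ToyModel.calls.append("train_batch")
        return {"loss": 0.0}  # placeholder loss, nothing is trained here


def run_sequence(model_cls, config, sample_batch, batches):
    """Hypothetical driver that mirrors the documented order of calls."""
    prepared_sample = model_cls.prepare_inputs(sample_batch)
    model = model_cls(config, data_sample=prepared_sample)
    for raw in batches:
        # Assumption: each raw batch is prepared before it reaches train_batch.
        model.train_batch(model_cls.prepare_inputs(raw))
    return model


sample = {"data": {"flux": [[0.1, 0.2]], "label": [1]}}
model = run_sequence(ToyModel, {}, sample, [sample])
print(ToyModel.calls)
# ['prepare_inputs', '__init__', 'prepare_inputs', 'train_batch']
```

The point to take away is that `prepare_inputs` runs before the constructor, which is why it has to be a static method.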
Required methods (checklist)#
Your class must meet these requirements:
- Inherit from torch.nn.Module.
- Add @hyrax_model above the class (this registers your model so Hyrax can find it by name in the configuration).
- Implement:
  - __init__(self, config, data_sample=None)
  - forward(self, ...) (required by PyTorch modules)
  - infer_batch(self, batch)
  - train_batch(self, batch)
- In config, set model.name to the full class path.
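A quick self-check against this checklist might look like the sketch below. `missing_required_methods` is a hypothetical helper, not part of Hyrax, and it is shown on a plain class so it runs without PyTorch. For a real `nn.Module` subclass, `forward` is inherited as a stub that raises when called, so this check cannot tell whether you overrode it.

```python
# __init__ is left out because every Python class has one.
REQUIRED_METHODS = ("forward", "infer_batch", "train_batch")


def missing_required_methods(model_cls):
    """Return the checklist names that model_cls does not provide as callables."""
    return [
        name
        for name in REQUIRED_METHODS
        if not callable(getattr(model_cls, name, None))
    ]


class Incomplete:
    """Example class that forgot infer_batch and train_batch."""

    def __init__(self, config, data_sample=None):
        self.config = config

    def forward(self, x):
        return x


print(missing_required_methods(Incomplete))
# ['infer_batch', 'train_batch']
```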
Method-by-method guide#
__init__(self, config, data_sample=None)#
What to do in this method:
- Save config to self.config.
- Read model-specific settings from config.
- Build your layers.
- Optionally use data_sample to infer shapes.
The model-specific config values are typically defined in your external library config layout. See External package setup for packaging/config setup.
Example (read settings from config):
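# assumes `import torch.nn as nn` at the top of the file
def __init__(self, config, data_sample=None):
    super().__init__()
    self.config = config

    # Model-specific settings live under your package and class name.
    hidden = config["my_package"]["MyModel"]["hidden_size"]
    out_dim = config["my_package"]["MyModel"]["num_classes"]

    # Fallback input width, overridden when a prepared sample is available.
    in_dim = 128
    if data_sample is not None:
        x, _ = data_sample
        in_dim = x.shape[-1]

    self.fc1 = nn.Linear(in_dim, hidden)
    self.fc2 = nn.Linear(hidden, out_dim)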
To set those config values in a notebook, one simple pattern is:
h.config.setdefault("my_package", {}).setdefault("MyModel", {})["hidden_size"] = 64
h.config["my_package"]["MyModel"]["num_classes"] = 10
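The same pattern can be tried on a plain nested dict before touching a real Hyrax object. This sketch assumes `h.config` supports dict-style access, as the lines above suggest; `model_setting` is a hypothetical helper for reading a value with a fallback.

```python
# A plain dict stands in for h.config (assumption: it behaves like a nested dict).
config = {}

# setdefault creates each missing level once and returns the existing one otherwise.
section = config.setdefault("my_package", {}).setdefault("MyModel", {})
section["hidden_size"] = 64
section["num_classes"] = 10


def model_setting(config, key, default):
    """Hypothetical helper: read a MyModel setting, falling back to a default."""
    return config.get("my_package", {}).get("MyModel", {}).get(key, default)


print(model_setting(config, "hidden_size", 32))  # 64, the value set above
print(model_setting(config, "dropout", 0.0))     # 0.0, key absent so the default is used
```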
PyTorch reference for building nn.Module layers:
torch.nn.Module
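Because prepare_inputs may return either a single array or a tuple, the shape lookup in `__init__` can be made tolerant of both. The sketch below uses a hypothetical helper, `infer_in_dim`, and a `SimpleNamespace` with a `shape` attribute as a stand-in for a real array or tensor.

```python
from types import SimpleNamespace


def infer_in_dim(data_sample, default=128):
    """Hypothetical helper: pick the input width from a prepared sample."""
    if data_sample is None:
        return default
    # A tuple is treated as (features, labels); anything else as features alone.
    features = data_sample[0] if isinstance(data_sample, tuple) else data_sample
    return features.shape[-1]


fake_features = SimpleNamespace(shape=(32, 64))  # stands in for an array or tensor

print(infer_in_dim(None))                   # 128
print(infer_in_dim((fake_features, None)))  # 64
print(infer_in_dim(fake_features))          # 64
```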
forward(self, ...)#
Even though Hyrax calls train_batch / infer_batch, your model should still
implement forward because that is the core PyTorch pattern.
PyTorch reference for forward:
Module.forward
Example:
def forward(self, x):
    x = self.fc1(x)
    x = nn.functional.relu(x)
    x = self.fc2(x)
    return x
@staticmethod prepare_inputs(data_dict)#
This converts collated dataset fields into model inputs. For background on where this sits in the pipeline, see Data Flow Through Hyrax.
Important:
- Define it as a @staticmethod.
- Return either:
  - one NumPy array, or
  - a tuple like (features, labels)
Example:
@staticmethod
def prepare_inputs(data_dict):
    features = data_dict["data"]["flux"]
    labels = data_dict["data"].get("label", None)
    return features, labels
If you do not define this method, Hyrax uses a default that expects
data_dict["data"]["image"] and an optional data_dict["data"]["label"].
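To make that expectation concrete, here is a stand-in that reads the same two fields. It is an illustration of the behavior described above, not Hyrax's own default implementation; the function name is hypothetical, and raising `KeyError` on a missing image is a property of this sketch only.

```python
def default_like_prepare_inputs(data_dict):
    """Illustrative stand-in for the described default, not Hyrax's own code."""
    data = data_dict["data"]
    image = data["image"]      # required in this sketch: KeyError if missing
    label = data.get("label")  # optional: None when the dataset has no labels
    return image, label


with_label = {"data": {"image": [[0.0, 1.0]], "label": [3]}}
without_label = {"data": {"image": [[0.0, 1.0]]}}

print(default_like_prepare_inputs(with_label))     # ([[0.0, 1.0]], [3])
print(default_like_prepare_inputs(without_label))  # ([[0.0, 1.0]], None)
```

If your dataset stores its main field under a different key, such as "flux" in the example above, define your own prepare_inputs.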
infer_batch(self, batch)#
Used by the infer verb.
Usually this unpacks the batch and calls forward.
Example:
def infer_batch(self, batch):
    features, _ = batch
    return self.forward(features)
train_batch(self, batch)#
Used by the train verb.
Typical steps:
1. Unpack inputs.
2. Zero gradients.
3. Call forward.
4. Compute loss.
5. Backprop and run the optimizer step.
6. Return a dictionary with "loss" and optional metrics.
Example:
def train_batch(self, batch):
    features, labels = batch
    self.optimizer.zero_grad()
    logits = self.forward(features)
    loss = self.criterion(logits, labels)
    loss.backward()
    self.optimizer.step()
    # "acc" is a placeholder value; compute a real metric in your model.
    return {"loss": loss.item(), "acc": 0.84}
Important details:
- Hyrax treats lower loss as better for best-checkpoint selection.
- Any extra metrics you return (like acc) are logged to TensorBoard and MLflow. See Using TensorBoard and MLFlow. For documentation on comparing runs across experiments, see Model comparison.
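The 0.84 in the example above is a placeholder. One way to return a real accuracy is sketched below. `TinyClassifier` is a hypothetical model that creates its optimizer and criterion manually, and the sketch assumes a classification task with integer class labels.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical model; optimizer and criterion are created manually here."""

    def __init__(self, in_dim=4, num_classes=3):
        super().__init__()
        self.fc = nn.Linear(in_dim, num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1)

    def forward(self, x):
        return self.fc(x)

    def train_batch(self, batch):
        features, labels = batch
        self.optimizer.zero_grad()
        logits = self.forward(features)
        loss = self.criterion(logits, labels)
        loss.backward()
        self.optimizer.step()
        # Fraction of samples whose highest-scoring class matches the label.
        acc = (logits.argmax(dim=-1) == labels).float().mean().item()
        return {"loss": loss.item(), "acc": acc}


torch.manual_seed(0)
model = TinyClassifier()
batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))
metrics = model.train_batch(batch)
print(sorted(metrics))  # ['acc', 'loss']
```

Both values are converted to plain Python floats with `.item()`, matching the `loss.item()` convention used in the examples on this page.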
Optional methods#
validate_batch(self, batch)#
If the training run includes a validate dataset group, the train verb will
run this method each epoch.
Recommended behavior: similar to train_batch but do not update weights.
Example:
def validate_batch(self, batch):
    features, labels = batch
    logits = self.forward(features)
    loss = self.criterion(logits, labels)
    # "val_acc" is a placeholder value; compute a real metric in your model.
    return {"loss": loss.item(), "val_acc": 0.81}
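To make "do not update weights" explicit, the forward pass can be wrapped in `torch.no_grad()`. Whether Hyrax already disables gradients around validation is not stated here, so treat this as a defensive sketch. `SmallNet` is a hypothetical model.

```python
import torch
import torch.nn as nn


class SmallNet(nn.Module):
    """Hypothetical model used only to show a gradient-free validation step."""

    def __init__(self, in_dim=4, num_classes=3):
        super().__init__()
        self.fc = nn.Linear(in_dim, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.fc(x)

    def validate_batch(self, batch):
        features, labels = batch
        # No autograd graph is built, and no optimizer step is taken.
        with torch.no_grad():
            logits = self.forward(features)
            loss = self.criterion(logits, labels)
            acc = (logits.argmax(dim=-1) == labels).float().mean()
        return {"loss": loss.item(), "val_acc": acc.item()}


torch.manual_seed(0)
net = SmallNet()
weights_before = net.fc.weight.detach().clone()
result = net.validate_batch((torch.randn(8, 4), torch.randint(0, 3, (8,))))
print(torch.equal(weights_before, net.fc.weight))  # True, weights are unchanged
```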
test_batch(self, batch)#
Used by the test verb.
Return a dictionary with "loss" (plus optional metrics).
Example:
def test_batch(self, batch):
    features, labels = batch
    logits = self.forward(features)
    loss = self.criterion(logits, labels)
    # "test_acc" is a placeholder value; compute a real metric in your model.
    return {"loss": loss.item(), "test_acc": 0.82}
log_epoch_metrics(self)#
Optional helper for extra per-epoch metrics during training.
If present, Hyrax calls it at epoch end and logs returned values to TensorBoard and MLflow. See Using TensorBoard and MLFlow.
Example:
def log_epoch_metrics(self):
    # Placeholder value; return metrics accumulated over the epoch.
    return {"my_metric": 0.42}
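A metric that is more useful than a constant usually has to be accumulated across batches and reset at the end of each epoch. The sketch below shows one way to do that. `EpochAccuracyMixin`, `record_batch`, and `reset_epoch_stats` are hypothetical names; only `log_epoch_metrics` is the hook described on this page.

```python
class EpochAccuracyMixin:
    """Hypothetical mixin that averages a per-batch count over one epoch."""

    def reset_epoch_stats(self):
        self._correct = 0
        self._seen = 0

    def record_batch(self, num_correct, batch_size):
        # Intended to be called from train_batch after predictions are computed.
        self._correct += num_correct
        self._seen += batch_size

    def log_epoch_metrics(self):
        accuracy = self._correct / self._seen if self._seen else 0.0
        self.reset_epoch_stats()  # start fresh for the next epoch
        return {"epoch_acc": accuracy}


tracker = EpochAccuracyMixin()
tracker.reset_epoch_stats()
tracker.record_batch(num_correct=6, batch_size=8)
tracker.record_batch(num_correct=8, batch_size=8)
print(tracker.log_epoch_metrics())  # {'epoch_acc': 0.875}
```

In a real model you would call `reset_epoch_stats()` once in `__init__` and mix the class in alongside `nn.Module`.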
Optimizer / criterion / scheduler setup#
You have two ways to provide each training component:
- Manual means you create it in your model class (for example, self.optimizer = torch.optim.Adam(...)).
- Automatic means you set its name in config and Hyrax creates it for you.
These choices are independent. You can mix them.
Example: manual optimizer + automatic criterion/scheduler
# in model __init__ (manual optimizer)
self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
# in config (automatic criterion + scheduler)
h.set_config("criterion.name", "torch.nn.CrossEntropyLoss")
h.set_config("scheduler.name", "torch.optim.lr_scheduler.StepLR")
This mixed mode is valid. Optimizer/criterion/scheduler do not all need to use exactly the same style.
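Names such as "torch.nn.CrossEntropyLoss" are dotted import paths. As background, a dotted path can be turned into a class with the standard library as sketched below. This is a general Python technique, not a description of how Hyrax resolves names internally. `resolve_dotted_path` is a hypothetical helper, and a standard-library class is used so the example runs without PyTorch.

```python
import importlib


def resolve_dotted_path(path):
    """Import the module part of a dotted path and return the named attribute."""
    module_name, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


cls = resolve_dotted_path("collections.OrderedDict")
print(cls.__name__)  # OrderedDict
```

A typo in the path surfaces as `ModuleNotFoundError` or `AttributeError`, which is a useful first thing to check when a configured name cannot be found.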
Complete minimal class#
import torch.nn as nn
from hyrax.models.model_registry import hyrax_model


@hyrax_model
class MyModel(nn.Module):
    def __init__(self, config, data_sample=None):
        super().__init__()
        self.config = config
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc(x)

    @staticmethod
    def prepare_inputs(data_dict):
        features = data_dict["data"]["flux"]
        labels = data_dict["data"].get("label", None)
        return features, labels

    def infer_batch(self, batch):
        features, _ = batch
        return self.forward(features)

    def train_batch(self, batch):
        features, labels = batch
        # self.optimizer and self.criterion are not created in this class;
        # this relies on the automatic setup described in the previous section.
        self.optimizer.zero_grad()
        out = self.forward(features)
        loss = self.criterion(out, labels)
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}
Notebook-first path#
Start with Using a custom model class, then move the class into an external package when it works.