Custom training metrics

Hyrax automatically logs every key/value pair that your model returns from train_batch and validate_batch. By default the built-in models return only {"loss": ...}, but you can add any number of extra metrics simply by including them in the returned dictionary.

Hyrax also supports epoch-level metrics through an optional log_epoch_metrics method on your model. This is useful for quantities that only make sense when computed over a full epoch rather than a single batch.

All metrics are written to both TensorBoard and MLflow with no extra configuration required.

Batch-level metrics

The train_batch method is called once per batch during training, and the dictionary it returns is treated as that batch's metrics. Every 10 iterations, those metrics are logged to TensorBoard and MLflow under the training/ prefix.

The same applies to validate_batch — its returned dictionary is logged under the validation/ prefix.
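The sketch below makes the naming and cadence concrete by showing the logging side. It is hypothetical and is not Hyrax's implementation. The helper name log_batch_metrics, its every parameter, and the assumption that logging fires when the iteration index is a multiple of 10 are all illustrative. Each sink is any callable taking (tag, value, step).

```python
from typing import Callable, Dict, Iterable

# Hypothetical: a "sink" is any callable accepting (tag, value, step).
Sink = Callable[[str, float, int], None]


def log_batch_metrics(
    metrics: Dict[str, float],
    prefix: str,
    step: int,
    sinks: Iterable[Sink],
    every: int = 10,
) -> bool:
    """Write prefixed metrics to every sink, but only every `every` steps.

    Returns True if anything was written on this step.
    """
    if step % every != 0:
        return False
    for sink in sinks:
        for key, value in metrics.items():
            # e.g. prefix "training" + key "mae" -> tag "training/mae"
            sink(f"{prefix}/{key}", float(value), step)
    return True


# Usage: record calls in a list instead of writing to a real backend.
records = []


def record(tag, value, step):
    records.append((tag, value, step))


for step in range(25):
    batch_metrics = {"loss": 1.0 / (step + 1), "mae": 0.5}
    log_batch_metrics(batch_metrics, "training", step, [record])

print(sorted({tag for tag, _, _ in records}))  # ['training/loss', 'training/mae']
print(sorted({s for _, _, s in records}))      # [0, 10, 20]
```

To point this at real backends, you could pass writer.add_scalar from a torch.utils.tensorboard.SummaryWriter and mlflow.log_metric as sinks. Both take the tag (or key), the value, and the step in that positional order.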

The minimal autoencoder below returns two metrics from every batch: the mean squared reconstruction error (its training loss) and the mean absolute error (MAE).

import torch.nn as nn
import torch.nn.functional as F  # noqa: N812

from hyrax.models.model_registry import hyrax_model


@hyrax_model
class MetricsDemoAutoencoder(nn.Module):
    """Tiny autoencoder that demonstrates custom training metrics."""

    def __init__(self, config, data_sample=None):
        super().__init__()
        self.encoder = nn.Linear(4, 2)
        self.decoder = nn.Linear(2, 4)

    def forward(self, x):
        return self.encoder(x)

    # ------------------------------------------------------------------
    # KEY POINT: return a dict with *all* the metrics you want logged.
    # Hyrax will automatically log every key/value pair.
    # ------------------------------------------------------------------
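    def train_batch(self, batch):
        # self.optimizer is never created in this class; it is expected to be
        # attached to the model by Hyrax before training starts.
        self.optimizer.zero_grad()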
        x_hat = self.decoder(self(batch))

        mse_loss = F.mse_loss(x_hat, batch)
        mae = F.l1_loss(x_hat, batch)

        mse_loss.backward()
        self.optimizer.step()

        # Return as many metrics as you like — they are all logged.
        return {
            "loss": mse_loss.item(),
            "mae": mae.item(),
        }

    def validate_batch(self, batch):
        x_hat = self.decoder(self(batch))
        return {
            "loss": F.mse_loss(x_hat, batch).item(),
            "mae": F.l1_loss(x_hat, batch).item(),
        }
Hyrax model MetricsDemoAutoencoder missing required method infer_batch.

With the model above, every 10 training iterations Hyrax will write both training/loss and training/mae to TensorBoard and MLflow. Validation metrics will appear as validation/loss and validation/mae.

You can add as many keys as you need — for example, per-band losses, gradient norms, or classification accuracy.
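The plain-PyTorch model below is a sketch of that idea. Alongside the usual loss, it returns a global gradient norm and one loss per input column. It is not decorated with @hyrax_model, so it runs on its own. It also creates its own torch.optim.SGD optimizer, because nothing attaches one outside Hyrax. Treating each column as a "band" is an assumption for illustration. The grad_norm and loss_band_<i> key names are likewise illustrative choices, not Hyrax conventions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExtraMetricsAutoencoder(nn.Module):
    """Standalone sketch returning per-band losses and a gradient norm."""

    def __init__(self, n_bands: int = 4):
        super().__init__()
        self.encoder = nn.Linear(n_bands, 2)
        self.decoder = nn.Linear(2, n_bands)
        # Assumption: inside Hyrax an optimizer would be attached for you;
        # here we create one so the class can run standalone.
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)

    def forward(self, x):
        return self.encoder(x)

    def train_batch(self, batch):
        self.optimizer.zero_grad()
        x_hat = self.decoder(self(batch))
        loss = F.mse_loss(x_hat, batch)
        loss.backward()

        # Global L2 norm of all parameter gradients, taken before the update.
        grad_norm = torch.sqrt(
            sum(p.grad.pow(2).sum() for p in self.parameters() if p.grad is not None)
        )
        self.optimizer.step()

        metrics = {"loss": loss.item(), "grad_norm": grad_norm.item()}

        # Per-band MSE: average the squared error over the batch dimension,
        # leaving one value per input column ("band").
        with torch.no_grad():
            per_band = ((x_hat - batch) ** 2).mean(dim=0)
        for i, value in enumerate(per_band.tolist()):
            metrics[f"loss_band_{i}"] = value
        return metrics


torch.manual_seed(0)
model = ExtraMetricsAutoencoder()
print(sorted(model.train_batch(torch.randn(8, 4))))
```

Inside Hyrax, each of these keys would be logged with the training/ prefix, for example training/grad_norm. The gradient norm is computed between backward() and step(), so it reflects the gradients used for this update.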

Epoch-level metrics

Some metrics only make sense when computed over an entire epoch — for example, a running average or a statistic accumulated across all batches.

If your model defines a method called log_epoch_metrics, Hyrax will call it at the end of every training epoch and log the returned dictionary under training/epoch/. The method takes no arguments (beyond self) and must return a dict[str, float].
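On the calling side, "if your model defines" amounts to an optional attribute lookup. The sketch below is hypothetical: end_of_epoch and the two toy classes are not Hyrax code. It calls log_epoch_metrics only when the model has that method, and it prefixes each returned key with training/epoch/.

```python
from typing import Callable, Dict


def end_of_epoch(model, log: Callable[[str, float], None]) -> Dict[str, float]:
    """Hypothetical end-of-epoch hook: call log_epoch_metrics only if defined."""
    hook = getattr(model, "log_epoch_metrics", None)
    if hook is None:
        return {}
    metrics = hook()
    for key, value in metrics.items():
        log(f"training/epoch/{key}", float(value))
    return metrics


class WithEpochMetrics:
    def __init__(self):
        self.losses = []

    def log_epoch_metrics(self):
        avg = sum(self.losses) / max(len(self.losses), 1)
        self.losses = []  # reset for the next epoch
        return {"avg_loss": avg}


class WithoutEpochMetrics:
    pass


logged = []


def record(tag, value):
    logged.append((tag, value))


m = WithEpochMetrics()
m.losses = [0.5, 0.25, 0.75]
end_of_epoch(m, record)
end_of_epoch(WithoutEpochMetrics(), record)  # no method, nothing logged
print(logged)  # [('training/epoch/avg_loss', 0.5)]
```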

Here is an example that tracks the average loss over each epoch.

@hyrax_model
class EpochMetricsDemoAutoencoder(nn.Module):
    """Autoencoder that reports both batch and epoch-level metrics."""

    def __init__(self, config, data_sample=None):
        super().__init__()
        self.encoder = nn.Linear(4, 2)
        self.decoder = nn.Linear(2, 4)

        # Accumulators for epoch-level stats
        self._epoch_loss_sum = 0.0
        self._epoch_batch_count = 0

    def forward(self, x):
        return self.encoder(x)

    def train_batch(self, batch):
        self.optimizer.zero_grad()
        x_hat = self.decoder(self(batch))
        loss = F.mse_loss(x_hat, batch)
        loss.backward()
        self.optimizer.step()

        # Accumulate for the epoch-level metric
        self._epoch_loss_sum += loss.item()
        self._epoch_batch_count += 1

        return {"loss": loss.item()}

    # ------------------------------------------------------------------
    # KEY POINT: implement log_epoch_metrics to report epoch-level stats.
    # Hyrax calls this at the end of every training epoch and logs the result.
    # ------------------------------------------------------------------
    def log_epoch_metrics(self):
        avg_loss = self._epoch_loss_sum / max(self._epoch_batch_count, 1)

        # Reset accumulators for the next epoch
        self._epoch_loss_sum = 0.0
        self._epoch_batch_count = 0

        return {"avg_loss": avg_loss}
Hyrax model EpochMetricsDemoAutoencoder missing required method infer_batch.

During training, the avg_loss value will appear in TensorBoard and MLflow as training/epoch/avg_loss, plotted once per epoch.
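The accumulator in EpochMetricsDemoAutoencoder averages per-batch losses, so every batch counts equally even when the final batch is smaller. For a true per-sample mean, weight each batch by its size instead. The plain-Python sketch below shows one way to do this (the class name is hypothetical). It assumes every sample has the same number of elements, so that F.mse_loss's default mean reduction times the batch size gives the batch's summed per-sample loss. In a model, you would call update() from train_batch with loss.item() and the batch length, then return compute_and_reset() from log_epoch_metrics.

```python
class WeightedEpochMean:
    """Accumulates a per-sample mean over an epoch from per-batch mean losses."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total = 0.0  # sum of (batch mean loss * batch size)
        self.count = 0    # number of samples seen this epoch

    def update(self, batch_mean: float, batch_size: int):
        self.total += batch_mean * batch_size
        self.count += batch_size

    def compute_and_reset(self) -> float:
        value = self.total / max(self.count, 1)
        self.reset()
        return value


acc = WeightedEpochMean()
acc.update(1.0, 4)  # full batch of 4 samples
acc.update(4.0, 1)  # small final batch of 1 sample
print(acc.compute_and_reset())  # 1.6, i.e. (1.0*4 + 4.0*1) / 5
```

An unweighted mean of the two batch losses would give 2.5 here, which overstates the influence of the single-sample final batch.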

Summary

| What you want to log | Where to put it | Logged as |
|---|---|---|
| Per-batch training metrics | Return from train_batch | training/<key> |
| Per-batch validation metrics | Return from validate_batch | validation/<key> |
| Per-epoch training metrics | Return from log_epoch_metrics | training/epoch/<key> |

No extra configuration is needed — Hyrax picks up every key in the returned dictionaries and writes them to TensorBoard and MLflow automatically.