Tracing Data Flow in Hyrax
When working with a new dataset or model it can be hard to know whether your data is being processed correctly. Hyrax provides a trace mode that lets you follow a small batch of data items through the entire pipeline — from raw dataset access, through batching and input preparation, all the way to model evaluation — so you can inspect what is actually happening at each step.
This notebook walks through a typical trace session:
1. Setting up a Hyrax instance
2. Running a verb with trace=N to capture pipeline data
3. Printing and navigating the returned TraceResult
4. Drilling into individual stages and function calls
Trace mode is intended for interactive use in notebooks — it is not a production profiling or logging tool.
Install Hyrax
Skip this step if you have already installed Hyrax in your environment.
[1]:
# %pip install hyrax
Create a Hyrax instance
We start by creating a Hyrax instance, which we then point at a simple built-in model and dataset: the HyraxAutoencoder model with HyraxRandomDataset. HyraxRandomDataset generates random tensors without requiring any downloaded data, which makes it convenient for experimentation.
You can replace this with your own model and dataset configuration to trace your real workflow.
[2]:
import hyrax
h = hyrax.Hyrax()
Configure the model and data
To keep this notebook self-contained, the cell below configures HyraxRandomDataset to produce 50 small random tensors of shape [1, 32, 32], resembling single-band image cutouts, with a fixed seed. It then defines a data_request that reads the same dataset for both the train and infer splits.
[3]:
h.config["model"]["name"] = "HyraxAutoencoder"
# Use a small 1-channel 32×32 image shape to keep things fast
h.config["data_set"]["HyraxRandomDataset"]["shape"] = [1, 32, 32]
h.config["data_set"]["HyraxRandomDataset"]["size"] = 50
h.config["data_set"]["HyraxRandomDataset"]["seed"] = 42
data_request = {
    "train": {
        "data": {
            "dataset_class": "HyraxRandomDataset",
            "data_location": "./trace_data",
            "fields": ["image"],
            "primary_id_field": "object_id",
            "split_fraction": 1.0,
        },
    },
    "infer": {
        "data": {
            "dataset_class": "HyraxRandomDataset",
            "data_location": "./trace_data",
            "fields": ["image"],
            "primary_id_field": "object_id",
            "split_fraction": 1.0,
        },
    },
}
h.set_config("data_request", data_request)
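Because the seed is fixed, HyraxRandomDataset should regenerate the same values each time this notebook runs (an assumption about how the dataset uses its seed). The traces later in this notebook are consistent with that: items 0 and 1 show the same image hashes in both the train and infer traces. The sketch below illustrates the general principle with Python's standard random module only; make_items is a hypothetical helper, not Hyrax's generator.

```python
import random

def make_items(seed, size, n_values):
    """Hypothetical stand-in for a seeded random dataset:
    `size` items, each a flat list of `n_values` floats in [0, 1)."""
    rng = random.Random(seed)  # private generator; global random state is untouched
    return [[rng.random() for _ in range(n_values)] for _ in range(size)]

# 1 x 32 x 32 = 1024 values per item, mirroring the configured shape and size
first_run = make_items(seed=42, size=50, n_values=1 * 32 * 32)
second_run = make_items(seed=42, size=50, n_values=1 * 32 * 32)
other_seed = make_items(seed=7, size=50, n_values=1 * 32 * 32)

print(first_run == second_run)  # same seed: identical data
print(first_run == other_seed)  # different seed: different data
```

Reproducible inputs are what make it meaningful to compare hashes across separate trace runs.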
Running a verb in trace mode
Any instrumented verb (train, infer, test) accepts a trace=N keyword argument. N controls how many data items are traced through the pipeline; keep it small (2–10) so the output stays readable.
When trace=N is passed, the verb:
- Processes only a single batch of N items instead of the full dataset
- Returns a TraceResult object instead of the usual verb return value
The TraceResult captures the inputs and outputs of every major pipeline step.
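Conceptually, trace mode replaces the full loop with one small batch and records every step as it runs. The plain-Python sketch below illustrates that idea under stated assumptions: the stage names echo the ones Hyrax reports, but run_traced, the toy steps, and the record layout are hypothetical and do not describe Hyrax's internals.

```python
import time

def run_traced(items, n, steps):
    """Hypothetical sketch: push only the first `n` items through `steps`
    (a list of (stage_name, function) pairs) and record each call."""
    trace = {}
    value = items[:n]  # a single batch of n items instead of the full dataset
    for stage_name, func in steps:
        start = time.perf_counter()
        output = func(value)
        duration_ms = (time.perf_counter() - start) * 1000
        trace.setdefault(stage_name, []).append(
            {"function": func.__name__, "input": value,
             "output": output, "duration_ms": duration_ms}
        )
        value = output  # each step's output feeds the next step
    return trace

# Toy steps standing in for real pipeline stages
def collate(batch):
    return {"object_id": [item["object_id"] for item in batch],
            "image": [item["image"] for item in batch]}

def prepare_inputs(batch_dict):
    return batch_dict["image"]

dataset = [{"object_id": str(i), "image": [float(i)] * 4} for i in range(50)]
trace = run_traced(dataset, n=2,
                   steps=[("collate", collate), ("prepare_inputs", prepare_inputs)])
print(list(trace))                                  # ['collate', 'prepare_inputs']
print(trace["collate"][0]["output"]["object_id"])   # ['0', '1']
```

Keeping both the input and the output of each step is what lets you see exactly where a value changes shape or content.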
[4]:
trace_result = h.train(trace=2)
[2026-04-23 22:25:17,638 hyrax.trace:WARNING] Starting Trace
[2026-04-23 22:25:17,639 hyrax.trace:WARNING] Trace mode enabled, will only run a single batch of length 2
[2026-04-23 22:25:17,711 hyrax.models.model_registry:INFO] Setting model's self.optimizer from config: torch.optim.SGD with arguments: {'lr': 0.01, 'momentum': 0.9}.
[2026-04-23 22:25:17,712 hyrax.models.model_registry:INFO] Setting model's self.criterion from config: torch.nn.CrossEntropyLoss with default arguments.
[2026-04-23 22:25:17,713 hyrax.models.model_registry:INFO] Setting model's self.scheduler from config: torch.optim.lr_scheduler.ExponentialLR
with arguments: {'gamma': 1}.
[2026-04-23 22:25:17,714 hyrax.verbs.train:INFO] Training model: HyraxAutoencoder
[2026-04-23 22:25:17,715 hyrax.verbs.train:INFO] Training dataset(s):
{'train': Name: data (primary dataset)
Dataset class: HyraxRandomDataset
Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.1/docs/pre_executed/trace_data
Fraction of data to use: 1.0
Primary ID field: object_id
Requested fields: image
}
2026-04-23 22:25:17,718 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data (primary':
{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x739881213f20>, 'batch_size': 2, 'shuffle': False, 'collate_fn': <bound method DataProvider.collate of Name: data (primary dataset)
Dataset class: HyraxRandomDataset
Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.1/docs/pre_executed/trace_data
Fraction of data to use: 1.0
Primary ID field: object_id
Requested fields: image
>, 'pin_memory': False}
2026/04/23 22:25:18 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2026/04/23 22:25:18 INFO mlflow.store.db.utils: Updating database tables
2026/04/23 22:25:20 INFO mlflow.tracking.fluent: Experiment with name 'notebook' does not exist. Creating a new experiment.
2026/04/23 22:25:20 INFO mlflow.system_metrics.system_metrics_monitor: Skip logging GPU metrics. Set logger level to DEBUG for more details.
2026/04/23 22:25:20 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2026-04-23 22:25:21,245 hyrax.pytorch_ignite:INFO] Total training time: 0.35[s]
2026/04/23 22:25:21 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2026/04/23 22:25:21 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2026-04-23 22:25:21,279 hyrax.verbs.train:INFO] Finished Training
Printing the trace
Printing trace_result gives a high-level summary of every pipeline stage and the function calls captured within it. Each entry shows the function name, its input and output names, the shape, hash, and device of any arrays or tensors, and how long the call took.
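Summarizing a large array as a shape plus a short hash keeps the printout readable while still letting you spot changes between steps. In this notebook's train trace, for example, the batched image keeps hash 20065364041793536 from collate through the evaluation input, suggesting it passed through without modification. The sketch below is a hypothetical, stdlib-only summarizer for nested Python lists: describe, shape_of, the output format, and the use of SHA-256 are assumptions and will not reproduce Hyrax's hash values.

```python
import hashlib

def shape_of(value):
    """Shape of a regular nested list, e.g. [[1, 2, 3], [4, 5, 6]] -> (2, 3)."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)

def describe(value):
    """Hypothetical summary in the spirit of the trace printout:
    shape and a short content hash instead of every element."""
    if not isinstance(value, list):
        return repr(value)  # scalars and strings are shown as-is
    digest = hashlib.sha256(repr(value).encode()).hexdigest()[:12]
    return f"<list shape={shape_of(value)} hash={digest}>"

image = [[[0.5] * 32 for _ in range(32)]]  # a 1 x 32 x 32 nested list
print(describe(image))   # shape=(1, 32, 32) followed by a 12-character hash
print(describe("24"))    # '24'
# Equal content gives an equal summary, so matching hashes suggest unchanged data
print(describe(image) == describe([row[:] for row in image]))
```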
[5]:
print(trace_result)
Trace Stages {
dataset_getter: [
data__get_image(index) -> image duration=0.0117 ms
inputs:
index = 1
outputs:
image = <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
data__get_object_id(index) -> object_id duration=0.00884 ms
inputs:
index = 1
outputs:
object_id = '24'
data__get_image(index) -> image duration=0.00278 ms
inputs:
index = 0
outputs:
image = <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
data__get_object_id(index) -> object_id duration=0.00166 ms
inputs:
index = 0
outputs:
object_id = '23'
]
resolve_data: [
DataProvider__resolve_data(index) -> data_dict duration=0.273 ms
inputs:
index = 1
outputs:
data_dict = {
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
DataProvider__resolve_data(index) -> data_dict duration=0.0532 ms
inputs:
index = 0
outputs:
data_dict = {
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
]
collate: [
DataProvider__collate(batch_dicts) -> batch_dict duration=0.168 ms
inputs:
batch_dicts = <list len=2> [
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
]
outputs:
batch_dict = {
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
object_id: <numpy.ndarray shape=(2,) hash=1516682233991693147 device=cpu>
}
DataProvider__handle_nans(batch_dict) -> batch_dict_no_nan duration=0.0958 ms
inputs:
batch_dict = {
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
object_id: <numpy.ndarray shape=(2,) hash=1516682233991693147 device=cpu>
}
outputs:
batch_dict_no_nan = {
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
object_id: <numpy.ndarray shape=(2,) hash=1516682233991693147 device=cpu>
}
]
prepare_inputs: [
HyraxAutoencoder__prepare_inputs(batch_dict) -> batch_ndarray duration=0.00796 ms
inputs:
batch_dict = {
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
object_id: <numpy.ndarray shape=(2,) hash=1516682233991693147 device=cpu>
}
outputs:
batch_ndarray = <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
]
evaluation: [
HyraxAutoencoder__train_batch(batch) -> loss_dict duration=315 ms
inputs:
batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
outputs:
loss_dict = {
loss: 247.11880493164062
}
]
}
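The output lists five stages in pipeline order:

| Stage | What is captured |
|---|---|
| dataset_getter | Individual getter calls on the dataset, one per field per item (here data__get_image and data__get_object_id) |
| resolve_data | DataProvider__resolve_data, which assembles each item's fields into a nested dictionary |
| collate | DataProvider__collate, which combines the per-item dictionaries into one batch, followed by DataProvider__handle_nans |
| prepare_inputs | Model's prepare_inputs, which converts the batch dictionary into the model's input |
| evaluation | Model functions such as train_batch, infer_batch, and forward |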
The trace summary is produced by TraceResult.__str__, so displaying trace_result as a notebook cell output and calling print(trace_result) give identical text.
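To make the stage boundaries concrete, here is a plain-Python sketch of the four data steps that precede evaluation. The function names mirror the stage names, but these toy implementations are assumptions: Hyrax's DataProvider works with NumPy arrays (stacking (1, 32, 32) items into a (2, 1, 32, 32) batch) and also handles NaNs, which this sketch omits.

```python
def get_image(index):
    """dataset_getter: return one field for one item (toy data of shape (1, 3, 3))."""
    return [[[float(index)] * 3 for _ in range(3)]]

def get_object_id(index):
    """dataset_getter: return the item's ID as a string (toy IDs)."""
    return str(index)

def resolve_data(index):
    """resolve_data: gather every field for one item into a nested dict."""
    return {"data": {"image": get_image(index)}, "object_id": get_object_id(index)}

def collate(batch_dicts):
    """collate: combine per-item dicts into one batch dict (a list per field)."""
    return {
        "data": {"image": [d["data"]["image"] for d in batch_dicts]},
        "object_id": [d["object_id"] for d in batch_dicts],
    }

def prepare_inputs(batch_dict):
    """prepare_inputs: extract what the model consumes (here, just the images)."""
    return batch_dict["data"]["image"]

batch = collate([resolve_data(i) for i in (0, 1)])
model_input = prepare_inputs(batch)
print(batch["object_id"])                      # ['0', '1']
print(len(model_input), len(model_input[0]))   # 2 1 (two single-channel images)
```

Note how the per-item dictionary layout from resolve_data is preserved by collate, with each leaf gaining a leading batch dimension, which is the same pattern visible in the trace output.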
Exploring stages
You can access any stage using either attribute access or dictionary-style access — both are equivalent. Tab completion in a notebook environment will suggest the valid stage names.
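Supporting both access styles is a common Python pattern: __getitem__ handles result["collate"], __getattr__ routes result.collate to the same lookup, and __dir__ lists the keys so notebook tab completion can typically offer them. The StageMap class below is a hypothetical sketch of that pattern, not Hyrax's TraceResult.

```python
class StageMap:
    """Hypothetical mapping of stage name -> list of calls, with two access styles."""

    def __init__(self, stages):
        self._stages = dict(stages)

    def __getitem__(self, name):
        return self._stages[name]  # dictionary-style: result["collate"]

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so real attributes win.
        # Reading __dict__ directly avoids infinite recursion if _stages is unset.
        stages = self.__dict__.get("_stages", {})
        if name in stages:
            return stages[name]  # attribute-style: result.collate
        raise AttributeError(name)

    def __dir__(self):
        # Including stage names here lets tab completion suggest them.
        return sorted(set(super().__dir__()) | set(self._stages))

result = StageMap({"collate": ["collate_call", "handle_nans_call"],
                   "evaluation": ["train_batch_call"]})
print(result.collate is result["collate"])  # True: both styles return the same object
print("evaluation" in dir(result))          # True
```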
[6]:
# Attribute-style access
collate_stage = trace_result.collate
print(collate_stage)
[
DataProvider__collate(batch_dicts) -> batch_dict duration=0.168 ms
inputs:
batch_dicts = <list len=2> [
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
]
outputs:
batch_dict = {
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
object_id: <numpy.ndarray shape=(2,) hash=1516682233991693147 device=cpu>
}
DataProvider__handle_nans(batch_dict) -> batch_dict_no_nan duration=0.0958 ms
inputs:
batch_dict = {
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
object_id: <numpy.ndarray shape=(2,) hash=1516682233991693147 device=cpu>
}
outputs:
batch_dict_no_nan = {
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
object_id: <numpy.ndarray shape=(2,) hash=1516682233991693147 device=cpu>
}
]
[7]:
# Dictionary-style access — identical result
evaluation_stage = trace_result["evaluation"]
print(evaluation_stage)
[
HyraxAutoencoder__train_batch(batch) -> loss_dict duration=315 ms
inputs:
batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
outputs:
loss_dict = {
loss: 247.11880493164062
}
]
Each stage is a TraceStage — a list of TraceCall records in the order they were executed. You can ask how many calls were captured:
[8]:
print(f"resolve_data calls : {len(trace_result.resolve_data)}")
print(f"collate calls : {len(trace_result.collate)}")
print(f"evaluation calls : {len(trace_result.evaluation)}")
resolve_data calls : 2
collate calls : 2
evaluation calls : 1
Exploring individual function calls
Within a stage you can index calls by number (call order) or by function name. A TraceCall captures the function’s argument values and return value along with timing information.
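Indexing one container by either position or name can be done by branching on the key type in __getitem__. The sketch below is hypothetical (CallList and the dict-based call records are not Hyrax classes); it mirrors the lookup described in this section, where a name returns every matching call in call order.

```python
class CallList(list):
    """Hypothetical list of call records, indexable by position or by function name."""

    def __getitem__(self, key):
        if isinstance(key, str):
            # Name lookup: every call whose display name matches, in call order
            return [call for call in self if call["name"] == key]
        return super().__getitem__(key)  # ints and slices behave like a normal list

stage = CallList([
    {"name": "DataProvider__collate", "duration_ms": 0.168},
    {"name": "DataProvider__handle_nans", "duration_ms": 0.0958},
])
print(stage[0]["name"])                          # DataProvider__collate
print(len(stage["DataProvider__handle_nans"]))   # 1
```

Returning a list for name lookups, rather than a single call, keeps the behavior uniform when a function runs once per item, as resolve_data does.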
[9]:
# Get the first call in the collate stage
first_collate_call = trace_result.collate[0]
print(first_collate_call)
DataProvider__collate(batch_dicts) -> batch_dict duration=0.168 ms
inputs:
batch_dicts = <list len=2> [
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
]
outputs:
batch_dict = {
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
object_id: <numpy.ndarray shape=(2,) hash=1516682233991693147 device=cpu>
}
[10]:
# Access the captured output value by name; attribute and dict access both work
batch_dict = first_collate_call.batch_dict
print(type(batch_dict))
print(batch_dict)
<class 'dict'>
{'data': {'image': array([[[[0.8218863 , 0.05556786, 0.07764488, ..., 0.95842916,
0.07661045, 0.99546444],
[0.44957864, 0.77210486, 0.4244706 , ..., 0.8390697 ,
0.8016305 , 0.99032164],
[0.65094787, 0.14159584, 0.50290215, ..., 0.58644515,
0.72324306, 0.49328995],
...,
[0.07175517, 0.4650141 , 0.97094595, ..., 0.62270874,
0.05409676, 0.7508386 ],
[0.6373174 , 0.79363537, 0.82733 , ..., 0.9218524 ,
0.6476614 , 0.31915647],
[0.41393703, 0.7261801 , 0.7195647 , ..., 0.9918077 ,
0.45286757, 0.5314138 ]]],
[[[0.08925092, 0.773956 , 0.6545715 , ..., 0.44341415,
0.45045954, 0.22723871],
[0.09213591, 0.55458474, 0.8878898 , ..., 0.7447621 ,
0.36664265, 0.9675097 ],
[0.41085035, 0.32582533, 0.90553576, ..., 0.38747835,
0.8980876 , 0.28832805],
...,
[0.41836828, 0.5780172 , 0.5375471 , ..., 0.6346291 ,
0.9714626 , 0.41181087],
[0.15344363, 0.40878308, 0.8401149 , ..., 0.7874322 ,
0.3427019 , 0.5491443 ],
[0.19697303, 0.43141818, 0.5296637 , ..., 0.07205909,
0.8685205 , 0.84199315]]]], shape=(2, 1, 32, 32), dtype=float32)}, 'object_id': array(['24', '23'], dtype='<U2')}
[11]:
# Get a list of all calls to a particular function by providing the display name
# (visible from the print output above)
all_resolve_calls = trace_result.resolve_data["DataProvider__resolve_data"]
print(f"Number of resolve_data calls: {len(all_resolve_calls)}")
print(all_resolve_calls[0])
Number of resolve_data calls: 2
DataProvider__resolve_data(index) -> data_dict duration=0.273 ms
inputs:
index = 1
outputs:
data_dict = {
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
Tracing other verbs
The trace=N keyword works the same way for infer and test:
[12]:
infer_trace = h.infer(trace=2)
print(infer_trace)
[2026-04-23 22:25:21,356 hyrax.trace:WARNING] Starting Trace
[2026-04-23 22:25:21,356 hyrax.trace:WARNING] Trace mode enabled, will only run a single batch of length 2
[2026-04-23 22:25:21,417 hyrax.models.model_registry:INFO] Setting model's self.optimizer from config: torch.optim.SGD with arguments: {'lr': 0.01, 'momentum': 0.9}.
[2026-04-23 22:25:21,418 hyrax.models.model_registry:INFO] Setting model's self.criterion from config: torch.nn.CrossEntropyLoss with default arguments.
[2026-04-23 22:25:21,419 hyrax.models.model_registry:INFO] Setting model's self.scheduler from config: torch.optim.lr_scheduler.ExponentialLR
with arguments: {'gamma': 1}.
[2026-04-23 22:25:21,420 hyrax.verbs.infer:INFO] Inference model: HyraxAutoencoder
[2026-04-23 22:25:21,420 hyrax.verbs.infer:INFO] Inference dataset(s):
{'infer': Name: data (primary dataset)
Dataset class: HyraxRandomDataset
Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.1/docs/pre_executed/trace_data
Fraction of data to use: 1.0
Primary ID field: object_id
Requested fields: image
}
2026-04-23 22:25:21,422 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data (primary':
{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x739869feeff0>, 'batch_size': 2, 'shuffle': False, 'collate_fn': <bound method DataProvider.collate of Name: data (primary dataset)
Dataset class: HyraxRandomDataset
Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.1/docs/pre_executed/trace_data
Fraction of data to use: 1.0
Primary ID field: object_id
Requested fields: image
>, 'pin_memory': False}
[2026-04-23 22:25:21,430 hyrax.models.model_utils:INFO] Updated config['infer']['model_weights_file'] to: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.1/docs/pre_executed/results/20260423-222517-train-QPb4/example_model.pth
[2026-04-23 22:25:21,433 hyrax.verbs.infer:INFO] Saving inference results at: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.1/docs/pre_executed/results/20260423-222521-infer-fBDm
[2026-04-23T22:25:21Z WARN lance::dataset::write::insert] No existing dataset at /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.1/docs/pre_executed/results/20260423-222521-infer-fBDm/lance_db/results.lance, it will be created
[2026-04-23 22:25:21,471 hyrax.pytorch_ignite:INFO] Total evaluation time: 0.03[s]
[2026-04-23 22:25:21,472 hyrax.datasets.result_dataset:INFO] Optimizing Lance table after 1 batches
[2026-04-23 22:25:21,476 hyrax.datasets.result_dataset:INFO] Lance table optimization complete
[2026-04-23 22:25:21,480 hyrax.verbs.infer:INFO] Inference Complete.
Trace Stages {
dataset_getter: [
data__get_image(index) -> image duration=0.00885 ms
inputs:
index = 0
outputs:
image = <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
data__get_object_id(index) -> object_id duration=0.00915 ms
inputs:
index = 0
outputs:
object_id = '23'
data__get_image(index) -> image duration=0.00248 ms
inputs:
index = 1
outputs:
image = <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
data__get_object_id(index) -> object_id duration=0.0012 ms
inputs:
index = 1
outputs:
object_id = '24'
]
resolve_data: [
DataProvider__resolve_data(index) -> data_dict duration=0.194 ms
inputs:
index = 0
outputs:
data_dict = {
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
DataProvider__resolve_data(index) -> data_dict duration=0.0435 ms
inputs:
index = 1
outputs:
data_dict = {
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
]
collate: [
DataProvider__collate(batch_dicts) -> batch_dict duration=0.265 ms
inputs:
batch_dicts = <list len=2> [
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
]
outputs:
batch_dict = {
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
object_id: <numpy.ndarray shape=(2,) hash=1516682233991693147 device=cpu>
}
DataProvider__handle_nans(batch_dict) -> batch_dict_no_nan duration=0.204 ms
inputs:
batch_dict = {
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
object_id: <numpy.ndarray shape=(2,) hash=1516682233991693147 device=cpu>
}
outputs:
batch_dict_no_nan = {
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
object_id: <numpy.ndarray shape=(2,) hash=1516682233991693147 device=cpu>
}
]
prepare_inputs: [
HyraxAutoencoder_inst_prepare_inputs(batch_dict) -> batch_ndarray duration=0.00685 ms
inputs:
batch_dict = {
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
object_id: <numpy.ndarray shape=(2,) hash=1516682233991693147 device=cpu>
}
outputs:
batch_ndarray = <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
]
evaluation: [
HyraxAutoencoder__infer_batch(batch) -> batch_results duration=1.71 ms
inputs:
batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
outputs:
batch_results = <torch.Tensor shape=(2, 64) hash=9362612189552181248 device=cpu>
HyraxAutoencoder__forward(batch) -> batch_results duration=1.7 ms
inputs:
batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
outputs:
batch_results = <torch.Tensor shape=(2, 64) hash=9362612189552181248 device=cpu>
]
}
Instrumenting custom models and datasets
If you have written a custom model or dataset class you can opt specific methods into trace capture using the @trace_model_func and @trace_dataset_func decorators.
import torch.nn as nn
from hyrax.trace import trace_model_func, trace_dataset_func
from hyrax.datasets import HyraxDataset  # assumed import path; adjust for your Hyrax version

class MyModel(nn.Module):
    @trace_model_func
    def my_custom_forward(self, batch):
        # This call will appear in the 'evaluation' stage of the TraceResult
        ...

class MyDataset(HyraxDataset):
    @trace_dataset_func
    def get_image(self, index):
        # This call will appear in the 'dataset_getter' stage of the TraceResult
        ...
The decorators add a small overhead to every call, so they are intended for use during development and debugging rather than in production. Remove them (or use them selectively) once you are satisfied with how your data is flowing.
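To see where per-call overhead can come from, consider how a tracing decorator of this kind might be written. The sketch below is hypothetical (trace_func, TRACE_ENABLED, and TRACE_LOG are illustrative names, not part of hyrax.trace): the wrapper runs on every call, but only records arguments, return values, and timing while tracing is switched on.

```python
import functools
import time

TRACE_ENABLED = False  # hypothetical global switch
TRACE_LOG = []         # hypothetical store of recorded calls

def trace_func(func):
    """Hypothetical tracing decorator: record inputs, output, and duration when enabled."""
    @functools.wraps(func)  # keep the original name and docstring
    def wrapper(*args, **kwargs):
        if not TRACE_ENABLED:
            return func(*args, **kwargs)  # the flag check is the only extra work
        start = time.perf_counter()
        result = func(*args, **kwargs)
        TRACE_LOG.append({
            "name": func.__qualname__,
            "args": args,  # includes `self` for methods
            "kwargs": kwargs,
            "result": result,
            "duration_ms": (time.perf_counter() - start) * 1000,
        })
        return result
    return wrapper

class ToyDataset:
    @trace_func
    def get_image(self, index):
        return [float(index)] * 4

ds = ToyDataset()
ds.get_image(0)        # not recorded: tracing is off
TRACE_ENABLED = True
ds.get_image(1)        # recorded
TRACE_ENABLED = False
print(len(TRACE_LOG), TRACE_LOG[0]["name"])  # 1 ToyDataset.get_image
```

Even with tracing off, every call still passes through the extra wrapper frame, which is why applying such decorators selectively is sensible.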