Tracing Data Flow in Hyrax
When working with a new dataset or model it can be hard to know whether your data is being processed correctly. Hyrax provides a trace mode that lets you follow a small batch of data items through the entire pipeline — from raw dataset access, through batching and input preparation, all the way to model evaluation — so you can inspect what is actually happening at each step.
This notebook walks through a typical trace session:
Setting up a Hyrax instance
Running a verb with trace=N to capture pipeline data
Printing and navigating the returned TraceResult
Drilling into individual stages and function calls
Trace mode is intended for interactive use in notebooks — it is not a production profiling or logging tool.
Install Hyrax
Skip this step if you have already installed Hyrax in your environment.
[1]:
# %pip install hyrax
Create a Hyrax instance
We start by creating a Hyrax instance and pointing it at a simple built-in model and dataset. Here we use HyraxAutoencoder trained on HyraxRandomDataset — a dataset that generates random tensors without requiring any downloaded data, which makes it convenient for experimentation.
You can replace this with your own model and dataset configuration to trace your real workflow.
[2]:
import hyrax
h = hyrax.Hyrax()
Configure the model and data
We use the HyraxAutoencoder model and HyraxRandomDataset to keep this notebook self-contained. The cell below configures the random dataset to produce 50 items. Each item is a random array of shape [1, 32, 32] that resembles a single-band image cutout, and a fixed seed of 42 keeps the values reproducible.
[3]:
h.config["model"]["name"] = "HyraxAutoencoder"
# Use a small 1-channel 32×32 image shape to keep things fast
h.config["data_set"]["HyraxRandomDataset"]["shape"] = [1, 32, 32]
h.config["data_set"]["HyraxRandomDataset"]["size"] = 50
h.config["data_set"]["HyraxRandomDataset"]["seed"] = 42
data_request = {
    "train": {
        "data": {
            "dataset_class": "HyraxRandomDataset",
            "data_location": "./trace_data",
            "fields": ["image"],
            "primary_id_field": "object_id",
            "split_fraction": 1.0,
        },
    },
    "infer": {
        "data": {
            "dataset_class": "HyraxRandomDataset",
            "data_location": "./trace_data",
            "fields": ["image"],
            "primary_id_field": "object_id",
            "split_fraction": 1.0,
        },
    },
}
h.set_config("data_request", data_request)
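Before tracing, it can help to confirm that every split in data_request carries the keys you expect. The helper below is a hypothetical sketch and not a Hyrax function. REQUIRED_KEYS simply lists the keys this notebook sets; the keys Hyrax actually requires may differ.

```python
# Hypothetical helper (not part of Hyrax): report dataset entries in a
# data_request dictionary that lack the keys used in this notebook.
REQUIRED_KEYS = ("dataset_class", "data_location", "fields", "primary_id_field")


def check_data_request(data_request):
    problems = []
    for split, datasets in data_request.items():
        for name, spec in datasets.items():
            for key in REQUIRED_KEYS:
                if key not in spec:
                    problems.append(f"{split}.{name}: missing '{key}'")
            # Flag an empty field list as well as a missing one
            if "fields" in spec and not spec["fields"]:
                problems.append(f"{split}.{name}: 'fields' is empty")
    return problems


incomplete = {"train": {"data": {"dataset_class": "HyraxRandomDataset", "fields": ["image"]}}}
print(check_data_request(incomplete))
```

Applied to the data_request dictionary above, this helper returns an empty list, because both splits set all four keys.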
Running a verb in trace mode
Any instrumented verb (train, infer, test) accepts a trace=N keyword argument. N controls how many data items are traced through the pipeline. Keep it small (2–10) so the output stays readable.
When trace=N is passed, the verb:
Processes only a single batch of N items instead of the full dataset
Returns a TraceResult object instead of the usual verb return value
The TraceResult captures the inputs and outputs of every major pipeline step.
[4]:
trace_result = h.train(trace=2)
[2026-05-18 21:37:47,433 hyrax.trace:WARNING] Starting Trace
[2026-05-18 21:37:47,434 hyrax.trace:WARNING] Trace mode enabled, will only run a single batch of length 2
[2026-05-18 21:37:47,518 hyrax.models.model_registry:INFO] Setting model's self.optimizer from config: torch.optim.SGD with arguments: {'lr': 0.01, 'momentum': 0.9}.
[2026-05-18 21:37:47,519 hyrax.models.model_registry:INFO] Setting model's self.criterion from config: torch.nn.CrossEntropyLoss with default arguments.
[2026-05-18 21:37:47,520 hyrax.models.model_registry:INFO] Setting model's self.scheduler from config: torch.optim.lr_scheduler.ExponentialLR
with arguments: {'gamma': 1}.
[2026-05-18 21:37:47,521 hyrax.verbs.train:INFO] Training model: HyraxAutoencoder
[2026-05-18 21:37:47,523 hyrax.verbs.train:INFO] Training dataset(s):
{'train': Name: data (primary dataset)
Dataset class: HyraxRandomDataset
Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.3/docs/pre_executed/trace_data
Fraction of data to use: 1.0
Primary ID field: object_id
Requested fields: image
}
2026-05-18 21:37:47,526 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data (primary':
{'sampler': <torch.utils.data.sampler.SubsetRandomSampler object at 0x7439223ee5a0>, 'batch_size': 2, 'collate_fn': <bound method DataProvider.collate of Name: data (primary dataset)
Dataset class: HyraxRandomDataset
Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.3/docs/pre_executed/trace_data
Fraction of data to use: 1.0
Primary ID field: object_id
Requested fields: image
>, 'pin_memory': False}
2026/05/18 21:37:48 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2026/05/18 21:37:48 INFO mlflow.store.db.utils: Updating database tables
2026/05/18 21:37:50 INFO mlflow.tracking.fluent: Experiment with name 'notebook' does not exist. Creating a new experiment.
2026/05/18 21:37:50 INFO mlflow.system_metrics.system_metrics_monitor: Skip logging GPU metrics. Set logger level to DEBUG for more details.
2026/05/18 21:37:50 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2026-05-18 21:37:51,042 hyrax.pytorch_ignite:INFO] Total training time: 0.33[s]
2026/05/18 21:37:51 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2026/05/18 21:37:51 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2026-05-18 21:37:51,076 hyrax.verbs.train:INFO] Finished Training
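The two warnings at the top of the log confirm that only a single batch of length 2 was run. The sketch below uses made-up names (make_getter, run_single_batch) to show why the dataset_getter stage later lists four calls: each of the N items is read once per field it needs, which here means image plus the object_id primary ID. For simplicity it takes the first N indices. Which items a real verb picks depends on its sampler.

```python
from itertools import islice

calls = []  # (getter name, index) pairs, in call order


def make_getter(field):
    # Build a toy getter that records each call before returning a value
    def getter(index):
        calls.append((f"get_{field}", index))
        return f"{field}-{index}"
    return getter


getters = {field: make_getter(field) for field in ("image", "object_id")}


def run_single_batch(dataset_size, n):
    # Only the first n indices are read; the rest of the dataset is never touched
    indices = islice(range(dataset_size), n)
    return [{field: get(i) for field, get in getters.items()} for i in indices]


batch = run_single_batch(dataset_size=50, n=2)
print(len(batch), len(calls))  # 2 4
print(calls[:2])  # [('get_image', 0), ('get_object_id', 0)]
```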
Printing the trace
Printing trace_result gives a high-level summary of all pipeline stages and the function calls captured within them. Each entry shows the function name, its input and output names, the shape, hash, and device of any arrays or tensors, and how long the call took.
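The hash values let you tell at a glance whether two steps saw the same data. Hyrax's hashing scheme is not described here, so the code below is only a stand-in, and content_hash and summarize are hypothetical names. It hashes the values alone, which means identical contents in different container types produce the same hash. That behaviour is consistent with the trace further down, where a numpy array and the torch tensor made from it share the hash 20065364041793536.

```python
import hashlib


def flatten(x):
    # Flatten nested lists/tuples into one list of scalars
    if isinstance(x, (list, tuple)):
        out = []
        for v in x:
            out.extend(flatten(v))
        return out
    return [x]


def content_hash(x):
    # Hash the values only, so the container type does not affect the result
    payload = repr([float(v) for v in flatten(x)]).encode()
    return int(hashlib.sha256(payload).hexdigest()[:15], 16)


def summarize(name, x):
    # Mimic a one-line summary: containers get a length and hash, scalars a repr
    if isinstance(x, (list, tuple)):
        return f"{name} = <{type(x).__name__} len={len(x)} hash={content_hash(x)}>"
    return f"{name} = {x!r}"


as_list = [[0.1, 0.2], [0.3, 0.4]]
as_tuple = ((0.1, 0.2), (0.3, 0.4))
print(content_hash(as_list) == content_hash(as_tuple))  # True
print(summarize("index", 0))  # index = 0
```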
[5]:
print(trace_result)
Trace Stages {
dataset_getter: [
data__get_image(index) -> image duration=0.0124 ms
inputs:
index = 0
outputs:
image = <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
data__get_object_id(index) -> object_id duration=0.00879 ms
inputs:
index = 0
outputs:
object_id = '23'
data__get_image(index) -> image duration=0.00554 ms
inputs:
index = 1
outputs:
image = <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
data__get_object_id(index) -> object_id duration=0.00407 ms
inputs:
index = 1
outputs:
object_id = '24'
]
resolve_data: [
DataProvider__resolve_data(index) -> data_dict duration=0.297 ms
inputs:
index = 0
outputs:
data_dict = {
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
DataProvider__resolve_data(index) -> data_dict duration=0.105 ms
inputs:
index = 1
outputs:
data_dict = {
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
]
collate: [
DataProvider__collate(batch_dicts) -> batch_dict duration=0.237 ms
inputs:
batch_dicts = <list len=2> [
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
]
outputs:
batch_dict = {
object_id: <numpy.ndarray shape=(2,) hash=9467597824855676281 device=cpu>
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
}
DataProvider__handle_nans(batch_dict) -> batch_dict_no_nan duration=0.124 ms
inputs:
batch_dict = {
object_id: <numpy.ndarray shape=(2,) hash=9467597824855676281 device=cpu>
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
}
outputs:
batch_dict_no_nan = {
object_id: <numpy.ndarray shape=(2,) hash=9467597824855676281 device=cpu>
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
}
]
prepare_inputs: [
HyraxAutoencoder__prepare_inputs(batch_dict) -> batch_ndarray duration=0.0104 ms
inputs:
batch_dict = {
object_id: <numpy.ndarray shape=(2,) hash=9467597824855676281 device=cpu>
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
}
outputs:
batch_ndarray = <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
]
evaluation: [
HyraxAutoencoder__train_batch(batch) -> loss_dict duration=272 ms
inputs:
batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
outputs:
loss_dict = {
loss: 526.142822265625
}
]
}
The output lists five stages in pipeline order:
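| Stage | What is captured |
|---|---|
| dataset_getter | Individual field getter calls on the dataset, such as data__get_image and data__get_object_id |
| resolve_data | DataProvider__resolve_data, which assembles the data dictionary for one item |
| collate | DataProvider__collate and DataProvider__handle_nans, which combine items into a batch |
| prepare_inputs | Model's prepare_inputs, which converts the batch dictionary into model input |
| evaluation | Model functions such as train_batch, infer_batch, and forward |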
The trace summary is rendered as plain text by TraceResult.__str__, so the notebook cell output and a plain print() call look identical.
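To see how the five stages fit together, here is a toy, pure-Python version of the same flow for two items that each have an image and an object_id. Every function is a hypothetical stand-in named after the traced stage it imitates. In particular, replacing NaN with 0.0 is an assumed policy and not necessarily what DataProvider__handle_nans does.

```python
import math


def get_image(index):
    # dataset_getter: a tiny 1x2x2 "image"; item 1 contains a NaN
    return [[[float(index), 1.0], [2.0, math.nan if index == 1 else 3.0]]]


def get_object_id(index):
    return f"obj{index}"


def resolve_data(index):
    # resolve_data: assemble one item's dictionary from the getters
    return {"data": {"image": get_image(index)}, "object_id": get_object_id(index)}


def collate(batch_dicts):
    # collate: gather per-item values under a new leading batch axis
    return {
        "object_id": [d["object_id"] for d in batch_dicts],
        "data": {"image": [d["data"]["image"] for d in batch_dicts]},
    }


def _fill_nan(x, fill):
    if isinstance(x, list):
        return [_fill_nan(v, fill) for v in x]
    return fill if isinstance(x, float) and math.isnan(x) else x


def handle_nans(batch_dict, fill=0.0):
    # Assumed policy: replace every NaN in the image with a constant
    return {**batch_dict, "data": {"image": _fill_nan(batch_dict["data"]["image"], fill)}}


def prepare_inputs(batch_dict):
    # prepare_inputs: hand the model only the array it consumes
    return batch_dict["data"]["image"]


def shape(x):
    # Shape of a nested list, read along the first element of each level
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


batch = handle_nans(collate([resolve_data(i) for i in range(2)]))
model_input = prepare_inputs(batch)
print(shape(get_image(0)), shape(model_input))  # (1, 2, 2) (2, 1, 2, 2)
print(batch["object_id"])  # ['obj0', 'obj1']
```

The per-item shape (1, 2, 2) gains a leading batch axis in collate, mirroring the (1, 32, 32) to (2, 1, 32, 32) change visible in the real trace.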
Exploring stages
You can access any stage using either attribute access or dictionary-style access — both are equivalent. Tab completion in a notebook environment will suggest the valid stage names.
[6]:
# Attribute-style access
collate_stage = trace_result.collate
print(collate_stage)
[
DataProvider__collate(batch_dicts) -> batch_dict duration=0.237 ms
inputs:
batch_dicts = <list len=2> [
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
]
outputs:
batch_dict = {
object_id: <numpy.ndarray shape=(2,) hash=9467597824855676281 device=cpu>
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
}
DataProvider__handle_nans(batch_dict) -> batch_dict_no_nan duration=0.124 ms
inputs:
batch_dict = {
object_id: <numpy.ndarray shape=(2,) hash=9467597824855676281 device=cpu>
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
}
outputs:
batch_dict_no_nan = {
object_id: <numpy.ndarray shape=(2,) hash=9467597824855676281 device=cpu>
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
}
]
[7]:
# Dictionary-style access — identical result
evaluation_stage = trace_result["evaluation"]
print(evaluation_stage)
[
HyraxAutoencoder__train_batch(batch) -> loss_dict duration=272 ms
inputs:
batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
outputs:
loss_dict = {
loss: 526.142822265625
}
]
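Supporting both trace_result.collate and trace_result["collate"] takes little code in Python. The class below is a hypothetical stand-in and not Hyrax's implementation:

- __getitem__ handles dictionary-style access.
- __getattr__ handles attribute access. Python calls it only when normal attribute lookup fails.
- __dir__ adds the stage names to the object's attribute list, so tab completion, which typically relies on dir(), can offer them.

```python
class StageContainer:
    """Hypothetical TraceResult-like object mapping stage names to stages."""

    def __init__(self, stages):
        self._stages = dict(stages)

    def __getitem__(self, name):
        return self._stages[name]

    def __getattr__(self, name):
        # Only reached when normal lookup fails; read __dict__ directly so a
        # missing _stages (e.g. during copying) cannot cause infinite recursion
        stages = self.__dict__.get("_stages", {})
        if name in stages:
            return stages[name]
        raise AttributeError(name)

    def __dir__(self):
        # Advertise stage names alongside the normal attributes
        return sorted(set(super().__dir__()) | set(self._stages))


result = StageContainer({"collate": ["collate call"], "evaluation": ["train_batch call"]})
print(result.collate is result["collate"])  # True
print("evaluation" in dir(result))  # True
```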
Each stage is a TraceStage — a list of TraceCall records in the order they were executed. You can ask how many calls were captured:
[8]:
print(f"resolve_data calls : {len(trace_result.resolve_data)}")
print(f"collate calls : {len(trace_result.collate)}")
print(f"evaluation calls : {len(trace_result.evaluation)}")
resolve_data calls : 2
collate calls : 2
evaluation calls : 1
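The collate stage reports two calls because it records both DataProvider__collate and DataProvider__handle_nans. A stage that behaves like a list but also accepts a function name can be sketched by subclassing list. Stage and Call are hypothetical names, and the durations are copied from the training trace.

```python
from dataclasses import dataclass


@dataclass
class Call:
    name: str
    duration_ms: float


class Stage(list):
    """Hypothetical stage: a list of calls that also accepts a display name."""

    def __getitem__(self, key):
        if isinstance(key, str):
            # A name returns every matching call, in execution order
            return [call for call in self if call.name == key]
        # Integers and slices keep normal list behaviour
        return super().__getitem__(key)


collate = Stage([
    Call("DataProvider__collate", 0.237),
    Call("DataProvider__handle_nans", 0.124),
])
print(len(collate))  # 2
print(collate[1].name)  # DataProvider__handle_nans
print(len(collate["DataProvider__collate"]))  # 1
```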
Exploring individual function calls
Within a stage you can index calls by number (call order) or by function name. A TraceCall captures the function’s argument values and return value along with timing information.
[9]:
# Get the first call in the collate stage
first_collate_call = trace_result.collate[0]
print(first_collate_call)
DataProvider__collate(batch_dicts) -> batch_dict duration=0.237 ms
inputs:
batch_dicts = <list len=2> [
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
]
outputs:
batch_dict = {
object_id: <numpy.ndarray shape=(2,) hash=9467597824855676281 device=cpu>
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
}
[10]:
# Access a captured output value directly; attribute and dict access both work
batch_dict = first_collate_call.batch_dict
print(type(batch_dict))
print(batch_dict)
<class 'dict'>
{'object_id': array(['23', '24'], dtype='<U2'), 'data': {'image': array([[[[0.08925092, 0.773956 , 0.6545715 , ..., 0.44341415,
0.45045954, 0.22723871],
[0.09213591, 0.55458474, 0.8878898 , ..., 0.7447621 ,
0.36664265, 0.9675097 ],
[0.41085035, 0.32582533, 0.90553576, ..., 0.38747835,
0.8980876 , 0.28832805],
...,
[0.41836828, 0.5780172 , 0.5375471 , ..., 0.6346291 ,
0.9714626 , 0.41181087],
[0.15344363, 0.40878308, 0.8401149 , ..., 0.7874322 ,
0.3427019 , 0.5491443 ],
[0.19697303, 0.43141818, 0.5296637 , ..., 0.07205909,
0.8685205 , 0.84199315]]],
[[[0.8218863 , 0.05556786, 0.07764488, ..., 0.95842916,
0.07661045, 0.99546444],
[0.44957864, 0.77210486, 0.4244706 , ..., 0.8390697 ,
0.8016305 , 0.99032164],
[0.65094787, 0.14159584, 0.50290215, ..., 0.58644515,
0.72324306, 0.49328995],
...,
[0.07175517, 0.4650141 , 0.97094595, ..., 0.62270874,
0.05409676, 0.7508386 ],
[0.6373174 , 0.79363537, 0.82733 , ..., 0.9218524 ,
0.6476614 , 0.31915647],
[0.41393703, 0.7261801 , 0.7195647 , ..., 0.9918077 ,
0.45286757, 0.5314138 ]]]], shape=(2, 1, 32, 32), dtype=float32)}}
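Note that batch_dict is the actual dictionary of numpy arrays, not a summary string. A call record that exposes its arguments and return value by name could look like the hypothetical RecordedCall below. This sketch stores references rather than copies, so any later in-place change to those objects would also show up in the record. Whether Hyrax copies captured values is an open assumption.

```python
import time


class RecordedCall:
    """Hypothetical record of one call: argument and return values by name."""

    def __init__(self, name, values, duration_ms):
        self.name = name
        self._values = values
        self.duration_ms = duration_ms

    def __getitem__(self, key):
        return self._values[key]

    def __getattr__(self, key):
        values = self.__dict__.get("_values", {})
        if key in values:
            return values[key]
        raise AttributeError(key)


def record(func, output_name, **kwargs):
    # Run func, timing it, and keep its keyword arguments plus the named output
    start = time.perf_counter()
    result = func(**kwargs)
    elapsed_ms = (time.perf_counter() - start) * 1000.0
    return RecordedCall(func.__name__, {**kwargs, output_name: result}, elapsed_ms)


def collate(batch_dicts):
    return {"object_id": [d["object_id"] for d in batch_dicts]}


call = record(collate, "batch_dict", batch_dicts=[{"object_id": "23"}, {"object_id": "24"}])
print(call.batch_dict)  # {'object_id': ['23', '24']}
print(call["batch_dict"] is call.batch_dict)  # True
```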
[11]:
# Get a list of all calls to a particular function by providing the display name
# (visible from the print output above)
all_resolve_calls = trace_result.resolve_data["DataProvider__resolve_data"]
print(f"Number of resolve_data calls: {len(all_resolve_calls)}")
print(all_resolve_calls[0])
Number of resolve_data calls: 2
DataProvider__resolve_data(index) -> data_dict duration=0.297 ms
inputs:
index = 0
outputs:
data_dict = {
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
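Because every array is printed with a hash, the trace can show whether a step changed the data. In the training trace, the batched image keeps the hash 20065364041793536 from collate through the input to train_batch. Assuming the hash reflects array contents, this suggests that handle_nans and prepare_inputs passed the values through unchanged. The hypothetical helper first_change automates that check for a list of (step, hash) pairs that you collect by hand.

```python
def first_change(steps):
    """Return the (before, after) step names where the hash first changes, or None."""
    for (prev_name, prev_hash), (name, current_hash) in zip(steps, steps[1:]):
        if current_hash != prev_hash:
            return prev_name, name
    return None


# Hash of the batched image at each step, copied from the training trace above
train_path = [
    ("collate", 20065364041793536),
    ("handle_nans", 20065364041793536),
    ("prepare_inputs", 20065364041793536),
    ("train_batch input", 20065364041793536),
]
print(first_change(train_path))  # None

# Made-up hashes, only to show what a detected change looks like
print(first_change([("step_a", 1), ("step_b", 1), ("step_c", 2)]))  # ('step_b', 'step_c')
```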
Tracing other verbs
The trace=N keyword works the same way for infer and test:
[12]:
infer_trace = h.infer(trace=2)
print(infer_trace)
[2026-05-18 21:37:51,165 hyrax.trace:WARNING] Starting Trace
[2026-05-18 21:37:51,165 hyrax.trace:WARNING] Trace mode enabled, will only run a single batch of length 2
[2026-05-18 21:37:51,239 hyrax.models.model_registry:INFO] Setting model's self.optimizer from config: torch.optim.SGD with arguments: {'lr': 0.01, 'momentum': 0.9}.
[2026-05-18 21:37:51,241 hyrax.models.model_registry:INFO] Setting model's self.criterion from config: torch.nn.CrossEntropyLoss with default arguments.
[2026-05-18 21:37:51,241 hyrax.models.model_registry:INFO] Setting model's self.scheduler from config: torch.optim.lr_scheduler.ExponentialLR
with arguments: {'gamma': 1}.
[2026-05-18 21:37:51,243 hyrax.verbs.infer:INFO] Inference model: HyraxAutoencoder
[2026-05-18 21:37:51,244 hyrax.verbs.infer:INFO] Inference dataset(s):
{'infer': Name: data (primary dataset)
Dataset class: HyraxRandomDataset
Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.3/docs/pre_executed/trace_data
Fraction of data to use: 1.0
Primary ID field: object_id
Requested fields: image
}
2026-05-18 21:37:51,246 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data (primary':
{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x74390adab2f0>, 'batch_size': 2, 'collate_fn': <bound method DataProvider.collate of Name: data (primary dataset)
Dataset class: HyraxRandomDataset
Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.3/docs/pre_executed/trace_data
Fraction of data to use: 1.0
Primary ID field: object_id
Requested fields: image
>, 'pin_memory': False}
[2026-05-18 21:37:51,258 hyrax.models.model_utils:INFO] Updated config['infer']['model_weights_file'] to: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.3/docs/pre_executed/results/20260518-213747-train-IcYq/example_model.pth
[2026-05-18 21:37:51,261 hyrax.verbs.infer:INFO] Saving inference results at: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.3/docs/pre_executed/results/20260518-213751-infer-wsmD
[2026-05-18T21:37:51Z WARN lance::dataset::write::insert] No existing dataset at /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.3/docs/pre_executed/results/20260518-213751-infer-wsmD/lance_db/results.lance, it will be created
[2026-05-18 21:37:51,301 hyrax.pytorch_ignite:INFO] Total evaluation time: 0.03[s]
[2026-05-18 21:37:51,303 hyrax.datasets.result_dataset:INFO] Optimizing Lance table after 1 batches
[2026-05-18 21:37:51,307 hyrax.datasets.result_dataset:INFO] Lance table optimization complete
[2026-05-18 21:37:51,311 hyrax.verbs.infer:INFO] Inference Complete.
Trace Stages {
dataset_getter: [
data__get_image(index) -> image duration=0.00997 ms
inputs:
index = 0
outputs:
image = <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
data__get_object_id(index) -> object_id duration=0.0072 ms
inputs:
index = 0
outputs:
object_id = '23'
data__get_image(index) -> image duration=0.00404 ms
inputs:
index = 1
outputs:
image = <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
data__get_object_id(index) -> object_id duration=0.00304 ms
inputs:
index = 1
outputs:
object_id = '24'
]
resolve_data: [
DataProvider__resolve_data(index) -> data_dict duration=0.178 ms
inputs:
index = 0
outputs:
data_dict = {
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
DataProvider__resolve_data(index) -> data_dict duration=0.08 ms
inputs:
index = 1
outputs:
data_dict = {
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
]
collate: [
DataProvider__collate(batch_dicts) -> batch_dict duration=0.182 ms
inputs:
batch_dicts = <list len=2> [
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
}
object_id: '23'
}
{
data: {
image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
}
object_id: '24'
}
]
outputs:
batch_dict = {
object_id: <numpy.ndarray shape=(2,) hash=9467597824855676281 device=cpu>
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
}
DataProvider__handle_nans(batch_dict) -> batch_dict_no_nan duration=0.0961 ms
inputs:
batch_dict = {
object_id: <numpy.ndarray shape=(2,) hash=9467597824855676281 device=cpu>
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
}
outputs:
batch_dict_no_nan = {
object_id: <numpy.ndarray shape=(2,) hash=9467597824855676281 device=cpu>
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
}
]
prepare_inputs: [
HyraxAutoencoder_inst_prepare_inputs(batch_dict) -> batch_ndarray duration=0.00694 ms
inputs:
batch_dict = {
object_id: <numpy.ndarray shape=(2,) hash=9467597824855676281 device=cpu>
data: {
image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
}
}
outputs:
batch_ndarray = <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
]
evaluation: [
HyraxAutoencoder__infer_batch(batch) -> batch_results duration=2.08 ms
inputs:
batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
outputs:
batch_results = <torch.Tensor shape=(2, 64) hash=50005469393059840 device=cpu>
HyraxAutoencoder__forward(batch) -> batch_results duration=2.07 ms
inputs:
batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
outputs:
batch_results = <torch.Tensor shape=(2, 64) hash=50005469393059840 device=cpu>
]
}
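Comparing the two printouts, the first three stages show the same calls and hashes. In the evaluation stage, however, the infer run records HyraxAutoencoder__infer_batch and HyraxAutoencoder__forward, where train recorded HyraxAutoencoder__train_batch. The prepare_inputs entry also carries a different display name. The hypothetical helper below compares call names stage by stage, which makes such differences easy to spot. The names are copied from the two traces.

```python
def differing_stages(trace_a, trace_b):
    """Return stage names whose ordered call names differ between two traces."""
    stages = list(trace_a) + [s for s in trace_b if s not in trace_a]
    return [s for s in stages if trace_a.get(s) != trace_b.get(s)]


shared = {
    "dataset_getter": ["data__get_image", "data__get_object_id"] * 2,
    "resolve_data": ["DataProvider__resolve_data"] * 2,
    "collate": ["DataProvider__collate", "DataProvider__handle_nans"],
}
train_names = {
    **shared,
    "prepare_inputs": ["HyraxAutoencoder__prepare_inputs"],
    "evaluation": ["HyraxAutoencoder__train_batch"],
}
infer_names = {
    **shared,
    "prepare_inputs": ["HyraxAutoencoder_inst_prepare_inputs"],
    "evaluation": ["HyraxAutoencoder__infer_batch", "HyraxAutoencoder__forward"],
}
print(differing_stages(train_names, infer_names))  # ['prepare_inputs', 'evaluation']
```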
Instrumenting custom models and datasets
If you have written a custom model or dataset class, you can opt specific methods into trace capture using the @trace_model_func and @trace_dataset_func decorators.
import torch.nn as nn

from hyrax.trace import trace_model_func, trace_dataset_func

# HyraxDataset must also be imported; its module path depends on your Hyrax version.


class MyModel(nn.Module):
    @trace_model_func
    def my_custom_forward(self, batch):
        # This call will appear in the 'evaluation' stage of the TraceResult
        ...


class MyDataset(HyraxDataset):
    @trace_dataset_func
    def get_image(self, index):
        # This call will appear in the 'dataset_getter' stage of the TraceResult
        ...
The decorators add a small overhead to every call, so they are intended for use during development and debugging rather than in production. Remove them (or use them selectively) once you are satisfied with how your data is flowing.
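Because the decorators cost something on every call, one way to leave them in place cheaply is to check a flag before doing any recording. The sketch below shows that pattern with hypothetical names (traced, TRACE_ENABLED, RECORDS). It is not how hyrax.trace is implemented. In the example above, Hyrax's own decorators are applied without arguments and determine the stage themselves, whereas this sketch takes the stage as a parameter.

```python
import functools
import time

TRACE_ENABLED = False  # hypothetical global switch
RECORDS = []


def traced(stage):
    """Hypothetical decorator factory: record calls only while tracing is on."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not TRACE_ENABLED:
                # Fast path: the only overhead is this flag check
                return func(*args, **kwargs)
            start = time.perf_counter()
            result = func(*args, **kwargs)
            RECORDS.append({
                "stage": stage,
                "name": func.__qualname__,
                "duration_ms": (time.perf_counter() - start) * 1000.0,
            })
            return result
        return wrapper
    return decorator


class ToyDataset:
    @traced("dataset_getter")
    def get_image(self, index):
        return [[float(index)] * 4]


ds = ToyDataset()
ds.get_image(0)  # not recorded
TRACE_ENABLED = True
ds.get_image(1)  # recorded
print(len(RECORDS), RECORDS[0]["name"])  # 1 ToyDataset.get_image
```

Using functools.wraps keeps the original method name and docstring, so the decorated getter still looks like get_image to introspection tools.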