Tracing Data Flow in Hyrax

When working with a new dataset or model, it can be hard to know whether your data is being processed correctly. Hyrax provides a trace mode that lets you follow a small batch of data items through the entire pipeline, from raw dataset access through batching and input preparation to model evaluation, so you can inspect what actually happens at each step.

This notebook walks through a typical trace session:

  1. Setting up a Hyrax instance

  2. Running a verb with trace=N to capture pipeline data

  3. Printing and navigating the returned TraceResult

  4. Drilling into individual stages and function calls

Trace mode is intended for interactive use in notebooks — it is not a production profiling or logging tool.

Install Hyrax

Skip this step if you have already installed Hyrax in your environment.

[1]:
# %pip install hyrax

Create a Hyrax instance

We start by creating a Hyrax instance, which the next step points at a simple built-in model and dataset: HyraxAutoencoder trained on HyraxRandomDataset. This dataset generates random tensors without requiring any downloaded data, which makes it convenient for experimentation.

You can replace this with your own model and dataset configuration to trace your real workflow.

[2]:
import hyrax

h = hyrax.Hyrax()

Configure the model and data

We will use the HyraxAutoencoder model and HyraxRandomDataset to keep this notebook self-contained. The random dataset produces small random tensors of shape [1, 32, 32] that resemble single-band image cutouts.

[3]:
h.config["model"]["name"] = "HyraxAutoencoder"

# Use a small 1-channel 32×32 image shape to keep things fast
h.config["data_set"]["HyraxRandomDataset"]["shape"] = [1, 32, 32]
h.config["data_set"]["HyraxRandomDataset"]["size"] = 50
h.config["data_set"]["HyraxRandomDataset"]["seed"] = 42

data_request = {
    "train": {
        "data": {
            "dataset_class": "HyraxRandomDataset",
            "data_location": "./trace_data",
            "fields": ["image"],
            "primary_id_field": "object_id",
            "split_fraction": 1.0,
        },
    },
    "infer": {
        "data": {
            "dataset_class": "HyraxRandomDataset",
            "data_location": "./trace_data",
            "fields": ["image"],
            "primary_id_field": "object_id",
            "split_fraction": 1.0,
        },
    },
}
h.set_config("data_request", data_request)

Running a verb in trace mode

Any instrumented verb (train, infer, test) accepts a trace=N keyword argument. N controls how many data items are traced through the pipeline; keep it small (2–10) so the output stays readable.

When trace=N is passed, the verb:

  • Processes only a single batch of N items instead of the full dataset

  • Returns a TraceResult object instead of the usual verb return value

The TraceResult captures the inputs and outputs of every major pipeline step.

[4]:
trace_result = h.train(trace=2)
[2026-05-16 03:47:14,200 hyrax.trace:WARNING] Starting Trace
[2026-05-16 03:47:14,201 hyrax.trace:WARNING] Trace mode enabled, will only run a single batch of length 2
[2026-05-16 03:47:14,257 hyrax.models.model_registry:INFO] Setting model's self.optimizer from config: torch.optim.SGD with arguments: {'lr': 0.01, 'momentum': 0.9}.
[2026-05-16 03:47:14,257 hyrax.models.model_registry:INFO] Setting model's self.criterion from config: torch.nn.CrossEntropyLoss with default arguments.
[2026-05-16 03:47:14,258 hyrax.models.model_registry:INFO] Setting model's self.scheduler from config: torch.optim.lr_scheduler.ExponentialLR
with arguments: {'gamma': 1}.
[2026-05-16 03:47:14,258 hyrax.verbs.train:INFO] Training model: HyraxAutoencoder
[2026-05-16 03:47:14,259 hyrax.verbs.train:INFO] Training dataset(s):
{'train': Name: data (primary dataset)
  Dataset class: HyraxRandomDataset
  Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.2/docs/pre_executed/trace_data
  Fraction of data to use: 1.0
  Primary ID field: object_id
  Requested fields: image
}
2026-05-16 03:47:14,261 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data (primary':
        {'sampler': <torch.utils.data.sampler.SubsetRandomSampler object at 0x734bb89745c0>, 'batch_size': 2, 'collate_fn': <bound method DataProvider.collate of Name: data (primary dataset)
  Dataset class: HyraxRandomDataset
  Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.2/docs/pre_executed/trace_data
  Fraction of data to use: 1.0
  Primary ID field: object_id
  Requested fields: image
>, 'pin_memory': False}
2026/05/16 03:47:15 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2026/05/16 03:47:15 INFO mlflow.store.db.utils: Updating database tables
2026/05/16 03:47:16 INFO mlflow.tracking.fluent: Experiment with name 'notebook' does not exist. Creating a new experiment.
2026/05/16 03:47:16 INFO mlflow.system_metrics.system_metrics_monitor: Skip logging GPU metrics. Set logger level to DEBUG for more details.
2026/05/16 03:47:16 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2026-05-16 03:47:17,190 hyrax.pytorch_ignite:INFO] Total training time: 0.24[s]
2026/05/16 03:47:17 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2026/05/16 03:47:17 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2026-05-16 03:47:17,211 hyrax.verbs.train:INFO] Finished Training

Printing the trace

Printing trace_result gives a high-level summary of all pipeline stages and the function calls captured within them. Each entry shows the function name, its input and output names, a shape, hash, and device summary for any arrays or tensors, and how long the call took.

[5]:
print(trace_result)
Trace Stages {
        dataset_getter: [
                data__get_image(index) -> image duration=0.0117 ms
                inputs:
                        index = 0
                outputs:
                        image = <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>

                data__get_object_id(index) -> object_id duration=0.00767 ms
                inputs:
                        index = 0
                outputs:
                        object_id = '23'

                data__get_image(index) -> image duration=0.00117 ms
                inputs:
                        index = 1
                outputs:
                        image = <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>

                data__get_object_id(index) -> object_id duration=0.00118 ms
                inputs:
                        index = 1
                outputs:
                        object_id = '24'

        ]
        resolve_data: [
                DataProvider__resolve_data(index) -> data_dict duration=0.193 ms
                inputs:
                        index = 0
                outputs:
                        data_dict = {
                                data: {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
                                }
                                object_id: '23'
                        }

                DataProvider__resolve_data(index) -> data_dict duration=0.0375 ms
                inputs:
                        index = 1
                outputs:
                        data_dict = {
                                data: {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
                                }
                                object_id: '24'
                        }

        ]
        collate: [
                DataProvider__collate(batch_dicts) -> batch_dict duration=0.12 ms
                inputs:
                        batch_dicts = <list len=2> [
                                {
                                        data: {
                                                image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
                                        }
                                        object_id: '23'
                                }
                                {
                                        data: {
                                                image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
                                        }
                                        object_id: '24'
                                }
                        ]
                outputs:
                        batch_dict = {
                                object_id: <numpy.ndarray shape=(2,) hash=16373538229540728419 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                }
                        }

                DataProvider__handle_nans(batch_dict) -> batch_dict_no_nan duration=0.0596 ms
                inputs:
                        batch_dict = {
                                object_id: <numpy.ndarray shape=(2,) hash=16373538229540728419 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                }
                        }
                outputs:
                        batch_dict_no_nan = {
                                object_id: <numpy.ndarray shape=(2,) hash=16373538229540728419 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                }
                        }

        ]
        prepare_inputs: [
                HyraxAutoencoder__prepare_inputs(batch_dict) -> batch_ndarray duration=0.00685 ms
                inputs:
                        batch_dict = {
                                object_id: <numpy.ndarray shape=(2,) hash=16373538229540728419 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                }
                        }
                outputs:
                        batch_ndarray = <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>

        ]
        evaluation: [
                HyraxAutoencoder__train_batch(batch) -> loss_dict duration=213 ms
                inputs:
                        batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                outputs:
                        loss_dict = {
                                loss: 218.092529296875
                        }

        ]
}

The output lists five stages in pipeline order:

Stage           What is captured
dataset_getter  Individual get_* calls on the dataset class
resolve_data    DataProvider.resolve_data: assembles all fields for each item
collate         DataProvider.collate and handle_nans: builds batch tensors
prepare_inputs  Model’s prepare_inputs: converts the data dictionary to model input tensors
evaluation      Model functions such as forward, train_batch, etc.

The trace summary shown earlier is rendered as plain text by the TraceResult.__str__ implementation, so the notebook cell output and a plain print() call look identical.
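
Since print(trace_result) emits plain text, the summary can be scanned with ordinary string tools. The sketch below groups the distinct hashes found in a printed trace by array shape, which is one way to see whether the same data appears unchanged in several stages. It assumes the `<type shape=(...) hash=<digits> device=...>` format seen in this notebook's output stays stable; the helper name hashes_by_shape and the regular expression are illustrative and not part of Hyrax.

```python
import re
from collections import defaultdict

# Matches array summaries such as
# "<numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>"
ARRAY_SUMMARY = re.compile(r"<([\w.]+) shape=\(([^)]*)\) hash=(\d+) device=(\w+)>")


def hashes_by_shape(trace_text):
    """Group the distinct hashes found in a printed trace by array shape."""
    found = defaultdict(set)
    for _type_name, shape, digest, _device in ARRAY_SUMMARY.findall(trace_text):
        found[f"({shape})"].add(int(digest))
    return {shape: sorted(digests) for shape, digests in found.items()}


# Sample lines copied from the printed trace in this notebook
sample = """
image = <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
image = <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
batch_ndarray = <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
"""

summary = hashes_by_shape(sample)
for shape, digests in summary.items():
    print(shape, digests)
```

With a real trace you would pass str(trace_result) instead of the sample text. A single hash for the (2, 1, 32, 32) shape suggests that the output of prepare_inputs and the input to the model are the same data, on the assumption that the hash is derived from array contents.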

Exploring stages

You can access any stage using either attribute access or dictionary-style access — both are equivalent. Tab completion in a notebook environment will suggest the valid stage names.

[6]:
# Attribute-style access
collate_stage = trace_result.collate
print(collate_stage)
[
        DataProvider__collate(batch_dicts) -> batch_dict duration=0.12 ms
        inputs:
                batch_dicts = <list len=2> [
                        {
                                data: {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
                                }
                                object_id: '23'
                        }
                        {
                                data: {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
                                }
                                object_id: '24'
                        }
                ]
        outputs:
                batch_dict = {
                        object_id: <numpy.ndarray shape=(2,) hash=16373538229540728419 device=cpu>
                        data: {
                                image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                        }
                }

        DataProvider__handle_nans(batch_dict) -> batch_dict_no_nan duration=0.0596 ms
        inputs:
                batch_dict = {
                        object_id: <numpy.ndarray shape=(2,) hash=16373538229540728419 device=cpu>
                        data: {
                                image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                        }
                }
        outputs:
                batch_dict_no_nan = {
                        object_id: <numpy.ndarray shape=(2,) hash=16373538229540728419 device=cpu>
                        data: {
                                image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                        }
                }

]
[7]:
# Dictionary-style access — identical result
evaluation_stage = trace_result["evaluation"]
print(evaluation_stage)
[
        HyraxAutoencoder__train_batch(batch) -> loss_dict duration=213 ms
        inputs:
                batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
        outputs:
                loss_dict = {
                        loss: 218.092529296875
                }

]
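
For readers curious how one object can offer attribute access, dictionary-style access, and tab completion at once, here is a minimal sketch in plain Python. StageContainer is a hypothetical class written for illustration; it is not Hyrax's TraceResult implementation, which may work differently.

```python
class StageContainer:
    """Hypothetical container offering attribute and dictionary-style access."""

    def __init__(self, stages):
        self._stages = dict(stages)

    def __getitem__(self, name):
        # Dictionary-style access: container["collate"]
        return self._stages[name]

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so real
        # attributes such as _stages are never shadowed.
        try:
            return self.__dict__["_stages"][name]
        except KeyError:
            raise AttributeError(name) from None

    def __dir__(self):
        # Tab completion consults dir(), so advertise the stage names too.
        return sorted(set(super().__dir__()) | set(self._stages))


stages = StageContainer(
    {"collate": ["collate call", "handle_nans call"], "evaluation": ["train_batch call"]}
)
print(stages.collate is stages["collate"])  # True: both routes reach the same object
print("evaluation" in dir(stages))          # True: the name is visible to completion
```

Routing both access styles through one dictionary keeps them equivalent by construction, which matches the behavior described above.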

Each stage is a TraceStage — a list of TraceCall records in the order they were executed. You can ask how many calls were captured:

[8]:
print(f"resolve_data calls : {len(trace_result.resolve_data)}")
print(f"collate calls      : {len(trace_result.collate)}")
print(f"evaluation calls   : {len(trace_result.evaluation)}")
resolve_data calls : 2
collate calls      : 2
evaluation calls   : 1
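
These counts are consistent with trace=2: resolve_data runs once per traced item, while the collate stage holds one collate call and one handle_nans call for the single batch. The sketch below turns that reasoning into a reusable check. The helper names and the stand-in fake_trace are hypothetical; the only behavior assumed of a real TraceResult is the trace[stage_name] and len() access shown above, and the per-item expectations are inferred from this notebook's output rather than from a documented guarantee.

```python
def count_calls(trace, stage_names):
    """Return {stage name: number of captured calls} for the given stages."""
    return {name: len(trace[name]) for name in stage_names}


def check_per_item_stages(counts, n_items, getters_per_item):
    """Compare per-item stage counts with what trace=N would suggest."""
    expected = {
        "dataset_getter": n_items * getters_per_item,  # one call per getter per item
        "resolve_data": n_items,                       # one call per item
    }
    return {stage: counts.get(stage) == want for stage, want in expected.items()}


# Stand-in for a TraceResult: stage name -> list of call records
fake_trace = {
    "dataset_getter": ["get_image(0)", "get_object_id(0)", "get_image(1)", "get_object_id(1)"],
    "resolve_data": ["resolve(0)", "resolve(1)"],
    "collate": ["collate", "handle_nans"],
    "evaluation": ["train_batch"],
}

counts = count_calls(fake_trace, list(fake_trace))
checks = check_per_item_stages(counts, n_items=2, getters_per_item=2)
print(counts)
print(checks)
```

With a real trace, pass trace_result and an explicit list of stage names. Here getters_per_item is 2 because each item triggers one get_image call and one get_object_id call.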

Exploring individual function calls

Within a stage you can index calls by number (call order) or by function name. A TraceCall captures the function’s argument values and return value along with timing information.

[9]:
# Get the first call in the collate stage
first_collate_call = trace_result.collate[0]
print(first_collate_call)
DataProvider__collate(batch_dicts) -> batch_dict duration=0.12 ms
inputs:
        batch_dicts = <list len=2> [
                {
                        data: {
                                image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
                        }
                        object_id: '23'
                }
                {
                        data: {
                                image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
                        }
                        object_id: '24'
                }
        ]
outputs:
        batch_dict = {
                object_id: <numpy.ndarray shape=(2,) hash=16373538229540728419 device=cpu>
                data: {
                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                }
        }

[10]:
# Access a captured output directly by name (here the collated batch dictionary);
# attribute and dict access both work
batch_dict = first_collate_call.batch_dict
print(type(batch_dict))
print(batch_dict)
<class 'dict'>
{'object_id': array(['23', '24'], dtype='<U2'), 'data': {'image': array([[[[0.08925092, 0.773956  , 0.6545715 , ..., 0.44341415,
          0.45045954, 0.22723871],
         [0.09213591, 0.55458474, 0.8878898 , ..., 0.7447621 ,
          0.36664265, 0.9675097 ],
         [0.41085035, 0.32582533, 0.90553576, ..., 0.38747835,
          0.8980876 , 0.28832805],
         ...,
         [0.41836828, 0.5780172 , 0.5375471 , ..., 0.6346291 ,
          0.9714626 , 0.41181087],
         [0.15344363, 0.40878308, 0.8401149 , ..., 0.7874322 ,
          0.3427019 , 0.5491443 ],
         [0.19697303, 0.43141818, 0.5296637 , ..., 0.07205909,
          0.8685205 , 0.84199315]]],


       [[[0.8218863 , 0.05556786, 0.07764488, ..., 0.95842916,
          0.07661045, 0.99546444],
         [0.44957864, 0.77210486, 0.4244706 , ..., 0.8390697 ,
          0.8016305 , 0.99032164],
         [0.65094787, 0.14159584, 0.50290215, ..., 0.58644515,
          0.72324306, 0.49328995],
         ...,
         [0.07175517, 0.4650141 , 0.97094595, ..., 0.62270874,
          0.05409676, 0.7508386 ],
         [0.6373174 , 0.79363537, 0.82733   , ..., 0.9218524 ,
          0.6476614 , 0.31915647],
         [0.41393703, 0.7261801 , 0.7195647 , ..., 0.9918077 ,
          0.45286757, 0.5314138 ]]]], shape=(2, 1, 32, 32), dtype=float32)}}
[11]:
# Get a list of all calls to a particular function by providing the display name
# (visible from the print output above)
all_resolve_calls = trace_result.resolve_data["DataProvider__resolve_data"]
print(f"Number of resolve_data calls: {len(all_resolve_calls)}")
print(all_resolve_calls[0])
Number of resolve_data calls: 2
DataProvider__resolve_data(index) -> data_dict duration=0.193 ms
inputs:
        index = 0
outputs:
        data_dict = {
                data: {
                        image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
                }
                object_id: '23'
        }
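
Combining name-based indexing with by-name output access makes it easy to gather one output from every call to a function. The sketch below is run against stand-in objects so that it is self-contained. collect_output is a hypothetical helper, and using it on a real trace assumes that data_dict can be read as an attribute of each call in the same way batch_dict was read earlier.

```python
from types import SimpleNamespace


def collect_output(stage, function_name, output_name):
    """Return the named output from every captured call to one function."""
    return [getattr(call, output_name) for call in stage[function_name]]


# Stand-in objects mimicking the structure printed above
fake_stage = {
    "DataProvider__resolve_data": [
        SimpleNamespace(data_dict={"data": {"image": "image-0"}, "object_id": "23"}),
        SimpleNamespace(data_dict={"data": {"image": "image-1"}, "object_id": "24"}),
    ]
}

data_dicts = collect_output(fake_stage, "DataProvider__resolve_data", "data_dict")
object_ids = [d["object_id"] for d in data_dicts]
print(object_ids)  # ['23', '24']
```

On a real trace the call would look like collect_output(trace_result.resolve_data, "DataProvider__resolve_data", "data_dict"), which gives a quick way to confirm which object IDs were traced.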

Tracing other verbs

The trace=N keyword works the same way for infer and test:

[12]:
infer_trace = h.infer(trace=2)
print(infer_trace)
[2026-05-16 03:47:17,271 hyrax.trace:WARNING] Starting Trace
[2026-05-16 03:47:17,271 hyrax.trace:WARNING] Trace mode enabled, will only run a single batch of length 2
[2026-05-16 03:47:17,326 hyrax.models.model_registry:INFO] Setting model's self.optimizer from config: torch.optim.SGD with arguments: {'lr': 0.01, 'momentum': 0.9}.
[2026-05-16 03:47:17,326 hyrax.models.model_registry:INFO] Setting model's self.criterion from config: torch.nn.CrossEntropyLoss with default arguments.
[2026-05-16 03:47:17,327 hyrax.models.model_registry:INFO] Setting model's self.scheduler from config: torch.optim.lr_scheduler.ExponentialLR
with arguments: {'gamma': 1}.
[2026-05-16 03:47:17,327 hyrax.verbs.infer:INFO] Inference model: HyraxAutoencoder
[2026-05-16 03:47:17,328 hyrax.verbs.infer:INFO] Inference dataset(s):
{'infer': Name: data (primary dataset)
  Dataset class: HyraxRandomDataset
  Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.2/docs/pre_executed/trace_data
  Fraction of data to use: 1.0
  Primary ID field: object_id
  Requested fields: image
}
2026-05-16 03:47:17,328 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: data (primary':
        {'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x734ba11f0ec0>, 'batch_size': 2, 'collate_fn': <bound method DataProvider.collate of Name: data (primary dataset)
  Dataset class: HyraxRandomDataset
  Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.2/docs/pre_executed/trace_data
  Fraction of data to use: 1.0
  Primary ID field: object_id
  Requested fields: image
>, 'pin_memory': False}
[2026-05-16 03:47:17,336 hyrax.models.model_utils:INFO] Updated config['infer']['model_weights_file'] to: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.2/docs/pre_executed/results/20260516-034714-train-0UUr/example_model.pth
[2026-05-16 03:47:17,337 hyrax.verbs.infer:INFO] Saving inference results at: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.2/docs/pre_executed/results/20260516-034717-infer-9bI3
[2026-05-16T03:47:17Z WARN  lance::dataset::write::insert] No existing dataset at /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.8.2/docs/pre_executed/results/20260516-034717-infer-9bI3/lance_db/results.lance, it will be created
[2026-05-16 03:47:17,358 hyrax.pytorch_ignite:INFO] Total evaluation time: 0.02[s]
[2026-05-16 03:47:17,360 hyrax.datasets.result_dataset:INFO] Optimizing Lance table after 1 batches
[2026-05-16 03:47:17,362 hyrax.datasets.result_dataset:INFO] Lance table optimization complete
[2026-05-16 03:47:17,366 hyrax.verbs.infer:INFO] Inference Complete.
Trace Stages {
        dataset_getter: [
                data__get_image(index) -> image duration=0.00381 ms
                inputs:
                        index = 0
                outputs:
                        image = <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>

                data__get_object_id(index) -> object_id duration=0.00325 ms
                inputs:
                        index = 0
                outputs:
                        object_id = '23'

                data__get_image(index) -> image duration=0.00117 ms
                inputs:
                        index = 1
                outputs:
                        image = <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>

                data__get_object_id(index) -> object_id duration=0.00123 ms
                inputs:
                        index = 1
                outputs:
                        object_id = '24'

        ]
        resolve_data: [
                DataProvider__resolve_data(index) -> data_dict duration=0.0934 ms
                inputs:
                        index = 0
                outputs:
                        data_dict = {
                                data: {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
                                }
                                object_id: '23'
                        }

                DataProvider__resolve_data(index) -> data_dict duration=0.0357 ms
                inputs:
                        index = 1
                outputs:
                        data_dict = {
                                data: {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
                                }
                                object_id: '24'
                        }

        ]
        collate: [
                DataProvider__collate(batch_dicts) -> batch_dict duration=0.0872 ms
                inputs:
                        batch_dicts = <list len=2> [
                                {
                                        data: {
                                                image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
                                        }
                                        object_id: '23'
                                }
                                {
                                        data: {
                                                image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
                                        }
                                        object_id: '24'
                                }
                        ]
                outputs:
                        batch_dict = {
                                object_id: <numpy.ndarray shape=(2,) hash=16373538229540728419 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                }
                        }

                DataProvider__handle_nans(batch_dict) -> batch_dict_no_nan duration=0.0471 ms
                inputs:
                        batch_dict = {
                                object_id: <numpy.ndarray shape=(2,) hash=16373538229540728419 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                }
                        }
                outputs:
                        batch_dict_no_nan = {
                                object_id: <numpy.ndarray shape=(2,) hash=16373538229540728419 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                }
                        }

        ]
        prepare_inputs: [
                HyraxAutoencoder_inst_prepare_inputs(batch_dict) -> batch_ndarray duration=0.00366 ms
                inputs:
                        batch_dict = {
                                object_id: <numpy.ndarray shape=(2,) hash=16373538229540728419 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                }
                        }
                outputs:
                        batch_ndarray = <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>

        ]
        evaluation: [
                HyraxAutoencoder__infer_batch(batch) -> batch_results duration=1.16 ms
                inputs:
                        batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                outputs:
                        batch_results = <torch.Tensor shape=(2, 64) hash=17938661593055232 device=cpu>

                HyraxAutoencoder__forward(batch) -> batch_results duration=1.15 ms
                inputs:
                        batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                outputs:
                        batch_results = <torch.Tensor shape=(2, 64) hash=17938661593055232 device=cpu>

        ]
}

Instrumenting custom models and datasets

If you have written a custom model or dataset class, you can opt specific methods into trace capture using the @trace_model_func and @trace_dataset_func decorators.

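from torch import nn

from hyrax.datasets import HyraxDataset  # assumed import path; adjust for your Hyrax version
from hyrax.trace import trace_model_func, trace_dataset_func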

class MyModel(nn.Module):
    @trace_model_func
    def my_custom_forward(self, batch):
        # This call will appear in the 'evaluation' stage of the TraceResult
        ...

class MyDataset(HyraxDataset):
    @trace_dataset_func
    def get_image(self, index):
        # This call will appear in the 'dataset_getter' stage of the TraceResult
        ...

The decorators add a small overhead to every call, so they are intended for use during development and debugging rather than in production. Remove them (or use them selectively) once you are satisfied with how your data is flowing.
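
To make the source of that overhead concrete, here is a minimal sketch of a call-recording decorator written in plain Python. It is not how Hyrax implements @trace_model_func or @trace_dataset_func; record_calls, CAPTURED, and ToyDataset are hypothetical names used only for illustration.

```python
import functools
import inspect
import time

CAPTURED = {}  # hypothetical storage: stage name -> list of call records


def record_calls(stage):
    """Hypothetical decorator factory: record each call under `stage`."""

    def decorator(func):
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Map positional and keyword arguments onto parameter names
            bound = signature.bind(*args, **kwargs)
            start = time.perf_counter()
            result = func(*args, **kwargs)
            duration_ms = (time.perf_counter() - start) * 1000.0
            inputs = {k: v for k, v in bound.arguments.items() if k != "self"}
            CAPTURED.setdefault(stage, []).append(
                {
                    "name": func.__qualname__,
                    "inputs": inputs,
                    "output": result,
                    "duration_ms": duration_ms,
                }
            )
            return result

        return wrapper

    return decorator


class ToyDataset:
    @record_calls("dataset_getter")
    def get_object_id(self, index):
        return str(index + 23)


ToyDataset().get_object_id(0)
record = CAPTURED["dataset_getter"][0]
print(record["name"], record["inputs"], record["output"])
```

Every decorated call pays for argument binding, two clock reads, and a stored record, which is why a decorator of this kind is better suited to debugging sessions than to production runs.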