Tracing Data Flow in Hyrax

When working with a new dataset or model it can be hard to know whether your data is being processed correctly. Hyrax provides a trace mode that lets you follow a small batch of data items through the entire pipeline — from raw dataset access, through batching and input preparation, all the way to model evaluation — so you can inspect what is actually happening at each step.

This notebook walks through a typical trace session:

  1. Setting up a Hyrax instance

  2. Running a verb with trace=N to capture pipeline data

  3. Printing and navigating the returned TraceResult

  4. Drilling into individual stages and function calls

Trace mode is intended for interactive use in notebooks — it is not a production profiling or logging tool.

Install Hyrax

Skip this step if you have already installed Hyrax in your environment.

[1]:
# %pip install hyrax

Create a Hyrax instance

We start by creating a Hyrax instance and pointing it at a simple built-in model and dataset. Here we use HyraxAutoencoder trained on HyraxRandomDataset — a dataset that generates random tensors without requiring any downloaded data, which makes it convenient for experimentation.

You can replace this with your own model and dataset configuration to trace your real workflow.

[2]:
import hyrax

h = hyrax.Hyrax()

Configure the model and data

We use the HyraxAutoencoder model and HyraxRandomDataset to keep this notebook self-contained. Below we configure the random dataset to produce 50 small random arrays of shape [1, 32, 32], which resemble single-band image cutouts, and fix the seed so the run is reproducible.

[3]:
h.config["model"]["name"] = "HyraxAutoencoder"

# Use a small 1-channel 32×32 image shape to keep things fast
h.config["data_set"]["HyraxRandomDataset"]["shape"] = [1, 32, 32]
h.config["data_set"]["HyraxRandomDataset"]["size"] = 50
h.config["data_set"]["HyraxRandomDataset"]["seed"] = 42

data_request = {
    "train": {
        "data": {
            "dataset_class": "HyraxRandomDataset",
            "data_location": "./trace_data",
            "fields": ["image"],
            "primary_id_field": "object_id",
        },
    },
    "infer": {
        "data": {
            "dataset_class": "HyraxRandomDataset",
            "data_location": "./trace_data",
            "fields": ["image"],
            "primary_id_field": "object_id",
        },
    },
}
h.set_config("data_request", data_request)
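Because the train and infer entries above are identical, you could also build the request from a single template. The sketch below is plain Python; data_template is a hypothetical name. The resulting dict could be passed to h.set_config("data_request", data_request) exactly as above.

```python
# Settings shared by every split (values copied from the cell above).
data_template = {
    "dataset_class": "HyraxRandomDataset",
    "data_location": "./trace_data",
    "fields": ["image"],
    "primary_id_field": "object_id",
}

# One entry per split; copying the "fields" list keeps the splits independent,
# so editing one split's fields later does not silently change the other.
data_request = {
    split: {"data": {**data_template, "fields": list(data_template["fields"])}}
    for split in ("train", "infer")
}

print(data_request["train"] == data_request["infer"])
```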
[4]:
from hyrax.datasets.random.hyrax_random_dataset import HyraxRandomDataset
import numpy as np


def collate_image(self, samples: list[dict]) -> dict:
    collated_data = {}
    collated_data["image"] = np.stack([sample["image"] for sample in samples], axis=0)
    collated_data["mask"] = np.zeros_like(collated_data["image"], dtype=bool)
    return collated_data


HyraxRandomDataset.collate = collate_image
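Before tracing, it can help to call the collate function directly on a couple of hand-made samples and confirm the output shapes. The sketch below uses a self-contained copy of collate_image, so it runs without Hyrax, and applies it to two random [1, 32, 32] samples. Passing None for self works only because the function never uses it.

```python
import numpy as np


def collate_image(self, samples: list[dict]) -> dict:
    # Same logic as the cell above; `self` is unused.
    collated_data = {}
    collated_data["image"] = np.stack([sample["image"] for sample in samples], axis=0)
    collated_data["mask"] = np.zeros_like(collated_data["image"], dtype=bool)
    return collated_data


# Two hand-made samples shaped like HyraxRandomDataset items.
rng = np.random.default_rng(0)
samples = [{"image": rng.random((1, 32, 32), dtype=np.float32)} for _ in range(2)]

batch = collate_image(None, samples)
print(batch["image"].shape, batch["mask"].shape, batch["mask"].any())
```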

Above, we assign collate_image to HyraxRandomDataset.collate, which makes it the dataset-level collate function (it appears as data__collate in the trace below). We could equally define the function under the name collate, as shown below. The TraceResult records whichever function is attached. Dataset-level collate functions appear in the collate stage. Field-level collation functions appear in the field_level_collation stage, which is empty when a dataset-level collate function is defined.

def collate(self, samples: list[dict]) -> dict:
    collated_data = {}
    collated_data["image"] = np.stack([sample["image"] for sample in samples], axis=0)
    collated_data["mask"] = np.zeros_like(collated_data["image"], dtype=bool)
    return collated_data

HyraxRandomDataset.collate = collate

Running a verb in trace mode

Any instrumented verb (train, infer, test) accepts a trace=N keyword argument. N controls how many data items are traced through the pipeline. Keep it small (2–10) so the output stays readable.

When trace=N is passed, the verb:

  • Processes only a single batch of N items instead of the full dataset

  • Returns a TraceResult object instead of the usual verb return value

The TraceResult captures the inputs and outputs of every major pipeline step.

[5]:
trace_result = h.train(trace=2)
[2026-06-30 20:45:23,094 hyrax.trace:WARNING] Starting Trace
[2026-06-30 20:45:23,095 hyrax.trace:WARNING] Trace mode enabled, will only run a single batch of length 2
[2026-06-30 20:45:23,176 hyrax.models.model_registry:INFO] Setting model's self.optimizer from config: torch.optim.SGD with arguments: {'lr': 0.01, 'momentum': 0.9}.
[2026-06-30 20:45:23,177 hyrax.models.model_registry:INFO] Setting model's self.criterion from config: torch.nn.CrossEntropyLoss with default arguments.
[2026-06-30 20:45:23,178 hyrax.models.model_registry:INFO] Setting model's self.scheduler from config: torch.optim.lr_scheduler.ExponentialLR
with arguments: {'gamma': 1}.
[2026-06-30 20:45:23,178 hyrax.verbs.train:INFO] Training model: HyraxAutoencoder
[2026-06-30 20:45:23,179 hyrax.verbs.train:INFO] Training dataset(s):
{'train': Name: data (primary dataset)
  Dataset class: HyraxRandomDataset
  Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.9.0/docs/pre_executed/trace_data
  Selected items: 50
  Primary ID field: object_id
  Requested fields: image
}
2026-06-30 20:45:23,189 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<torch.utils.data.da':
        {'sampler': <torch.utils.data.sampler.WeightedRandomSampler object at 0x7527f4440bf0>, 'batch_size': 2, 'collate_fn': <bound method DataProvider.collate of Name: data (primary dataset)
  Dataset class: HyraxRandomDataset
  Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.9.0/docs/pre_executed/trace_data
  Selected items: 50
  Primary ID field: object_id
  Requested fields: image
>, 'pin_memory': False}
2026/06/30 20:45:24 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2026/06/30 20:45:24 INFO mlflow.store.db.utils: Updating database tables
2026/06/30 20:45:26 INFO mlflow.tracking.fluent: Experiment with name 'notebook' does not exist. Creating a new experiment.
2026/06/30 20:45:26 INFO mlflow.system_metrics.system_metrics_monitor: Skip logging GPU metrics. Set logger level to DEBUG for more details.
2026/06/30 20:45:26 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2026-06-30 20:45:26,659 hyrax.pytorch_ignite:INFO] Total training time: 0.27[s]
2026/06/30 20:45:26 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2026/06/30 20:45:26 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2026-06-30 20:45:26,687 hyrax.verbs.train:INFO] Finished Training

Printing the trace

Printing trace_result gives a high-level summary of all pipeline stages and the function calls captured within them. Each entry shows the function name, its input and output names, how long the call took, and, for any arrays or tensors, their shape, hash, and device.

[6]:
print(trace_result)
Trace Stages {
        dataset_getter: [
                data__get_image(index) -> image duration=0.00749 ms
                inputs:
                        index = 19
                outputs:
                        image = <numpy.ndarray shape=(1, 32, 32) hash=118355071061721088 device=cpu>

                data__get_object_id(index) -> object_id duration=0.00518 ms
                inputs:
                        index = 19
                outputs:
                        object_id = '42'

                data__get_image(index) -> image duration=0.00268 ms
                inputs:
                        index = 21
                outputs:
                        image = <numpy.ndarray shape=(1, 32, 32) hash=64471622010011648 device=cpu>

                data__get_object_id(index) -> object_id duration=0.0263 ms
                inputs:
                        index = 21
                outputs:
                        object_id = '44'

        ]
        resolve_data: [
                DataProvider__resolve_data(index) -> data_dict duration=0.216 ms
                inputs:
                        index = 19
                outputs:
                        data_dict = {
                                data: {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=118355071061721088 device=cpu>
                                }
                                object_id: '42'
                        }

                DataProvider__resolve_data(index) -> data_dict duration=0.105 ms
                inputs:
                        index = 21
                outputs:
                        data_dict = {
                                data: {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=64471622010011648 device=cpu>
                                }
                                object_id: '44'
                        }

        ]
        field_level_collation: []
        collate: [
                DataProvider__collate(batch_dicts) -> batch_dict duration=0.203 ms
                inputs:
                        batch_dicts = <list len=2> [
                                {
                                        data: {
                                                image: <numpy.ndarray shape=(1, 32, 32) hash=118355071061721088 device=cpu>
                                        }
                                        object_id: '42'
                                }
                                {
                                        data: {
                                                image: <numpy.ndarray shape=(1, 32, 32) hash=64471622010011648 device=cpu>
                                        }
                                        object_id: '44'
                                }
                        ]
                outputs:
                        batch_dict = {
                                object_id: <numpy.ndarray shape=(2,) hash=13615829435956363149 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>
                                        mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                                }
                        }

                data__collate(samples) -> batch_dict duration=0.0612 ms
                inputs:
                        samples = <list len=2> [
                                {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=118355071061721088 device=cpu>
                                }
                                {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=64471622010011648 device=cpu>
                                }
                        ]
                outputs:
                        batch_dict = {
                                image: <numpy.ndarray shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>
                                mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                        }

                DataProvider__handle_nans(batch_dict) -> batch_dict_no_nan duration=0.0968 ms
                inputs:
                        batch_dict = {
                                object_id: <numpy.ndarray shape=(2,) hash=13615829435956363149 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>
                                        mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                                }
                        }
                outputs:
                        batch_dict_no_nan = {
                                object_id: <numpy.ndarray shape=(2,) hash=13615829435956363149 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>
                                        mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                                }
                        }

        ]
        prepare_inputs: [
                HyraxAutoencoder__prepare_inputs(batch_dict) -> batch_ndarray duration=0.00839 ms
                inputs:
                        batch_dict = {
                                object_id: <numpy.ndarray shape=(2,) hash=13615829435956363149 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>
                                        mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                                }
                        }
                outputs:
                        batch_ndarray = <numpy.ndarray shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>

        ]
        evaluation: [
                HyraxAutoencoder__train_batch(batch) -> loss_dict duration=238 ms
                inputs:
                        batch = <torch.Tensor shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>
                outputs:
                        loss_dict = {
                                loss: 634.53857421875
                        }

        ]
}

The output lists six stages in pipeline order:

Stage                    What is captured
dataset_getter           Individual get_* calls on the dataset class
resolve_data             DataProvider.resolve_data, which assembles all fields for each item
field_level_collation    Individual collate_* calls for fields that define such a function
collate                  DataProvider.collate, the dataset's own collate function (if defined), and DataProvider.handle_nans; together these build the batch
prepare_inputs           The model's prepare_inputs, which converts the batch dictionary into model inputs
evaluation               Model functions such as forward, train_batch, and infer_batch

The trace printout is produced by TraceResult.__str__, so displaying trace_result as a notebook cell's output and calling print(trace_result) produce the same text.
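To make the stage order concrete, here is a toy pipeline in plain NumPy that passes two items through the same sequence of steps. All function names are hypothetical stand-ins for the Hyrax calls in the trace. The evaluate step is a placeholder rather than a real model, NaN handling is omitted, and there is no field-level collation step because, as in the traced run, no collate_* field functions are defined.

```python
import numpy as np

# Hypothetical stand-in for the dataset: 50 random single-band 32x32 images.
rng = np.random.default_rng(42)
IMAGES = rng.random((50, 1, 32, 32), dtype=np.float32)


def get_image(index):  # dataset_getter stage
    return IMAGES[index]


def get_object_id(index):  # dataset_getter stage
    return str(index)


def resolve_data(index):  # resolve_data stage: gather every field for one item
    return {"data": {"image": get_image(index)}, "object_id": get_object_id(index)}


def collate(items):  # collate stage: stack per-item arrays into batch arrays
    image = np.stack([item["data"]["image"] for item in items], axis=0)
    return {
        "object_id": np.array([item["object_id"] for item in items]),
        "data": {"image": image, "mask": np.zeros_like(image, dtype=bool)},
    }


def prepare_inputs(batch):  # prepare_inputs stage: pick out the model input
    return batch["data"]["image"]


def evaluate(inputs):  # evaluation stage: placeholder "loss", not a real model
    return {"loss": float(np.mean(inputs ** 2))}


batch = collate([resolve_data(i) for i in (19, 21)])
inputs = prepare_inputs(batch)
print(batch["object_id"], inputs.shape, evaluate(inputs))
```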

Exploring stages

You can access any stage using either attribute access or dictionary-style access — both are equivalent. Tab completion in a notebook environment will suggest the valid stage names.

[7]:
# Attribute-style access
collate_stage = trace_result.collate
print(collate_stage)
[
        DataProvider__collate(batch_dicts) -> batch_dict duration=0.203 ms
        inputs:
                batch_dicts = <list len=2> [
                        {
                                data: {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=118355071061721088 device=cpu>
                                }
                                object_id: '42'
                        }
                        {
                                data: {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=64471622010011648 device=cpu>
                                }
                                object_id: '44'
                        }
                ]
        outputs:
                batch_dict = {
                        object_id: <numpy.ndarray shape=(2,) hash=13615829435956363149 device=cpu>
                        data: {
                                image: <numpy.ndarray shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>
                                mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                        }
                }

        data__collate(samples) -> batch_dict duration=0.0612 ms
        inputs:
                samples = <list len=2> [
                        {
                                image: <numpy.ndarray shape=(1, 32, 32) hash=118355071061721088 device=cpu>
                        }
                        {
                                image: <numpy.ndarray shape=(1, 32, 32) hash=64471622010011648 device=cpu>
                        }
                ]
        outputs:
                batch_dict = {
                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>
                        mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                }

        DataProvider__handle_nans(batch_dict) -> batch_dict_no_nan duration=0.0968 ms
        inputs:
                batch_dict = {
                        object_id: <numpy.ndarray shape=(2,) hash=13615829435956363149 device=cpu>
                        data: {
                                image: <numpy.ndarray shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>
                                mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                        }
                }
        outputs:
                batch_dict_no_nan = {
                        object_id: <numpy.ndarray shape=(2,) hash=13615829435956363149 device=cpu>
                        data: {
                                image: <numpy.ndarray shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>
                                mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                        }
                }

]
[8]:
# Dictionary-style access — identical result
evaluation_stage = trace_result["evaluation"]
print(evaluation_stage)
[
        HyraxAutoencoder__train_batch(batch) -> loss_dict duration=238 ms
        inputs:
                batch = <torch.Tensor shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>
        outputs:
                loss_dict = {
                        loss: 634.53857421875
                }

]
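Having attribute access and dictionary-style access return the same object is a common Python pattern. One way to get it is to have __getattr__ fall back to the same lookup as __getitem__ and to list the stage names in __dir__, so that dir()-based tab completion can offer them. The sketch below shows this with a hypothetical StageContainer class. It is not Hyrax's actual TraceResult code.

```python
class StageContainer:
    """Hypothetical container allowing both result.collate and result["collate"]."""

    def __init__(self, stages: dict):
        self._stages = dict(stages)

    def __getitem__(self, name):
        return self._stages[name]

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails; refuse private names
        # so a missing _stages can never trigger infinite recursion.
        if name.startswith("_"):
            raise AttributeError(name)
        try:
            return self._stages[name]
        except KeyError:
            raise AttributeError(name) from None

    def __dir__(self):
        # Exposing stage names here lets dir()-based completion suggest them.
        return sorted(set(super().__dir__()) | set(self._stages))


result = StageContainer({"collate": ["DataProvider__collate", "data__collate"]})
print(result.collate is result["collate"])
```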

Each stage is a TraceStage — a list of TraceCall records in the order they were executed. You can ask how many calls were captured:

[9]:
print(f"resolve_data calls : {len(trace_result.resolve_data)}")
print(f"collate calls      : {len(trace_result.collate)}")
print(f"evaluation calls   : {len(trace_result.evaluation)}")
resolve_data calls : 2
collate calls      : 3
evaluation calls   : 1

Exploring individual function calls

Within a stage you can index calls by number (call order) or by function name. A TraceCall captures the function’s argument values and return value along with timing information.
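The two indexing modes can be pictured with a hypothetical Stage class. An integer returns one call by position, while a string returns every call with that display name, in call order. This is a sketch of the behavior described above, not Hyrax's TraceStage implementation.

```python
from dataclasses import dataclass, field


@dataclass
class Call:
    """Hypothetical record of one traced call."""
    name: str
    inputs: dict = field(default_factory=dict)
    outputs: dict = field(default_factory=dict)
    duration_ms: float = 0.0


class Stage(list):
    """Hypothetical list of calls, indexable by position or by display name."""

    def __getitem__(self, key):
        if isinstance(key, str):
            # Name lookup returns all calls with that display name, in call order.
            return [call for call in self if call.name == key]
        return super().__getitem__(key)


stage = Stage([
    Call("DataProvider__resolve_data", inputs={"index": 19}),
    Call("DataProvider__resolve_data", inputs={"index": 21}),
])
print(stage[0].inputs, len(stage["DataProvider__resolve_data"]))
```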

[10]:
# Get the first call in the collate stage
first_collate_call = trace_result.collate[0]
print(first_collate_call)
DataProvider__collate(batch_dicts) -> batch_dict duration=0.203 ms
inputs:
        batch_dicts = <list len=2> [
                {
                        data: {
                                image: <numpy.ndarray shape=(1, 32, 32) hash=118355071061721088 device=cpu>
                        }
                        object_id: '42'
                }
                {
                        data: {
                                image: <numpy.ndarray shape=(1, 32, 32) hash=64471622010011648 device=cpu>
                        }
                        object_id: '44'
                }
        ]
outputs:
        batch_dict = {
                object_id: <numpy.ndarray shape=(2,) hash=13615829435956363149 device=cpu>
                data: {
                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=90485229067698176 device=cpu>
                        mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                }
        }

[11]:
# Access the captured output by name; it is a dict of NumPy arrays (first_collate_call["batch_dict"] works too)
batch_dict = first_collate_call.batch_dict
print(type(batch_dict))
print(batch_dict)
<class 'dict'>
{'object_id': array(['42', '44'], dtype='<U2'), 'data': {'image': array([[[[0.3532344 , 0.607178  , 0.7395802 , ..., 0.57619756,
          0.9900934 , 0.6433857 ],
         [0.74579805, 0.803482  , 0.03820932, ..., 0.0764904 ,
          0.73687065, 0.24468738],
         [0.50652415, 0.8825237 , 0.49233162, ..., 0.77767473,
          0.18754518, 0.10764486],
         ...,
         [0.6180621 , 0.3048144 , 0.18005782, ..., 0.5115385 ,
          0.42217797, 0.20555836],
         [0.99082094, 0.25810826, 0.05632359, ..., 0.6630122 ,
          0.82801294, 0.7073826 ],
         [0.74559903, 0.73958987, 0.44700366, ..., 0.35947347,
          0.70961577, 0.4010756 ]]],


       [[[0.9984733 , 0.681059  , 0.55139065, ..., 0.24912357,
          0.5124912 , 0.42957765],
         [0.58149946, 0.80072296, 0.15518284, ..., 0.2593645 ,
          0.99248797, 0.24980205],
         [0.93099046, 0.98935384, 0.511387  , ..., 0.940135  ,
          0.39599067, 0.8682785 ],
         ...,
         [0.36340976, 0.8872119 , 0.9558456 , ..., 0.061755  ,
          0.3128417 , 0.9939142 ],
         [0.5558388 , 0.05441999, 0.97539544, ..., 0.23685807,
          0.85406137, 0.4906596 ],
         [0.5213873 , 0.25020808, 0.91928196, ..., 0.54372746,
          0.05825591, 0.2868395 ]]]], shape=(2, 1, 32, 32), dtype=float32), 'mask': array([[[[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]]],


       [[[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]]]],
      shape=(2, 1, 32, 32))}}
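Printing the captured dictionary dumps the array values themselves, which quickly becomes hard to read. A small recursive helper that reports only shapes and dtypes is often more useful. summarize below is a hypothetical helper, not part of Hyrax, and the example dictionary mimics the structure of batch_dict above.

```python
import numpy as np


def summarize(obj, indent=0):
    """Describe a nested dict of arrays by shape and dtype instead of values."""
    pad = "    " * indent
    lines = []
    for key, value in obj.items():
        if isinstance(value, dict):
            lines.append(f"{pad}{key}:")
            lines.append(summarize(value, indent + 1))
        elif hasattr(value, "shape") and hasattr(value, "dtype"):
            # Covers NumPy arrays and torch tensors alike.
            lines.append(f"{pad}{key}: shape={tuple(value.shape)} dtype={value.dtype}")
        else:
            lines.append(f"{pad}{key}: {value!r}")
    return "\n".join(lines)


example = {
    "object_id": np.array(["42", "44"]),
    "data": {
        "image": np.zeros((2, 1, 32, 32), dtype=np.float32),
        "mask": np.zeros((2, 1, 32, 32), dtype=bool),
    },
}
print(summarize(example))
```

In the notebook, print(summarize(batch_dict)) would give the same kind of overview for the real captured batch.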
[12]:
# Get a list of all calls to a particular function by providing the display name
# (visible from the print output above)
all_resolve_calls = trace_result.resolve_data["DataProvider__resolve_data"]
print(f"Number of resolve_data calls: {len(all_resolve_calls)}")
print(all_resolve_calls[0])
Number of resolve_data calls: 2
DataProvider__resolve_data(index) -> data_dict duration=0.216 ms
inputs:
        index = 19
outputs:
        data_dict = {
                data: {
                        image: <numpy.ndarray shape=(1, 32, 32) hash=118355071061721088 device=cpu>
                }
                object_id: '42'
        }
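Because a trace keeps both the per-item arrays from resolve_data and the collated batch, you can check that collation stacked the items in the expected order. check_collation below is a hypothetical NumPy helper. In the notebook, the item images would come from each resolve_data call's data_dict["data"]["image"] and the batch image from batch_dict["data"]["image"], assuming the data_dict output can be reached by name in the same way batch_dict was above.

```python
import numpy as np


def check_collation(item_images, batch_image):
    """Return True if batch_image equals the item images stacked in the given order."""
    expected = np.stack(item_images, axis=0)
    # array_equal also returns False when the shapes differ.
    return bool(np.array_equal(expected, batch_image))


rng = np.random.default_rng(0)
items = [rng.random((1, 32, 32), dtype=np.float32) for _ in range(2)]
batch = np.stack(items, axis=0)

print(check_collation(items, batch))        # items in their original order
print(check_collation(items[::-1], batch))  # a swapped order is detected
```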

Tracing other verbs

The trace=N keyword works the same way for infer and test:

[13]:
infer_trace = h.infer(trace=2)
print(infer_trace)
[2026-06-30 20:45:26,787 hyrax.trace:WARNING] Starting Trace
[2026-06-30 20:45:26,788 hyrax.trace:WARNING] Trace mode enabled, will only run a single batch of length 2
[2026-06-30 20:45:26,864 hyrax.models.model_registry:INFO] Setting model's self.optimizer from config: torch.optim.SGD with arguments: {'lr': 0.01, 'momentum': 0.9}.
[2026-06-30 20:45:26,865 hyrax.models.model_registry:INFO] Setting model's self.criterion from config: torch.nn.CrossEntropyLoss with default arguments.
[2026-06-30 20:45:26,866 hyrax.models.model_registry:INFO] Setting model's self.scheduler from config: torch.optim.lr_scheduler.ExponentialLR
with arguments: {'gamma': 1}.
[2026-06-30 20:45:26,866 hyrax.verbs.infer:INFO] Inference model: HyraxAutoencoder
[2026-06-30 20:45:26,867 hyrax.verbs.infer:INFO] Inference dataset(s):
{'infer': Name: data (primary dataset)
  Dataset class: HyraxRandomDataset
  Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.9.0/docs/pre_executed/trace_data
  Selected items: 50
  Primary ID field: object_id
  Requested fields: image
}
2026-06-30 20:45:26,868 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<torch.utils.data.da':
        {'sampler': None, 'batch_size': 2, 'collate_fn': <bound method DataProvider.collate of Name: data (primary dataset)
  Dataset class: HyraxRandomDataset
  Data location: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.9.0/docs/pre_executed/trace_data
  Selected items: 50
  Primary ID field: object_id
  Requested fields: image
>, 'pin_memory': False}
[2026-06-30 20:45:26,879 hyrax.models.model_utils:INFO] Updated config['infer']['model_weights_file'] to: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.9.0/docs/pre_executed/results/20260630-204523-train-LTmR/example_model.pth
[2026-06-30 20:45:26,882 hyrax.verbs.infer:INFO] Saving inference results at: /home/docs/checkouts/readthedocs.org/user_builds/hyrax/checkouts/v0.9.0/docs/pre_executed/results/20260630-204526-infer-tif7
[2026-06-30 20:45:26,901 hyrax.pytorch_ignite:INFO] Total evaluation time: 0.01[s]
[2026-06-30 20:45:26,903 hyrax.datasets.result_dataset:INFO] Optimizing Lance table after 1 batches
[2026-06-30 20:45:26,910 hyrax.datasets.result_dataset:INFO] Lance table optimization complete
[2026-06-30 20:45:26,914 hyrax.verbs.infer:INFO] Inference Complete.
Trace Stages {
        dataset_getter: [
                data__get_image(index) -> image duration=0.00855 ms
                inputs:
                        index = 0
                outputs:
                        image = <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>

                data__get_object_id(index) -> object_id duration=0.00681 ms
                inputs:
                        index = 0
                outputs:
                        object_id = '23'

                data__get_image(index) -> image duration=0.00287 ms
                inputs:
                        index = 1
                outputs:
                        image = <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>

                data__get_object_id(index) -> object_id duration=0.00251 ms
                inputs:
                        index = 1
                outputs:
                        object_id = '24'

        ]
        resolve_data: [
                DataProvider__resolve_data(index) -> data_dict duration=0.16 ms
                inputs:
                        index = 0
                outputs:
                        data_dict = {
                                data: {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
                                }
                                object_id: '23'
                        }

                DataProvider__resolve_data(index) -> data_dict duration=0.0647 ms
                inputs:
                        index = 1
                outputs:
                        data_dict = {
                                data: {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
                                }
                                object_id: '24'
                        }

        ]
        field_level_collation: []
        collate: [
                DataProvider__collate(batch_dicts) -> batch_dict duration=0.196 ms
                inputs:
                        batch_dicts = <list len=2> [
                                {
                                        data: {
                                                image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
                                        }
                                        object_id: '23'
                                }
                                {
                                        data: {
                                                image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
                                        }
                                        object_id: '24'
                                }
                        ]
                outputs:
                        batch_dict = {
                                object_id: <numpy.ndarray shape=(2,) hash=14226149857159391975 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                        mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                                }
                        }

                data__collate(samples) -> batch_dict duration=0.0568 ms
                inputs:
                        samples = <list len=2> [
                                {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=26660822658842624 device=cpu>
                                }
                                {
                                        image: <numpy.ndarray shape=(1, 32, 32) hash=7316936887107584 device=cpu>
                                }
                        ]
                outputs:
                        batch_dict = {
                                image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                        }

                DataProvider__handle_nans(batch_dict) -> batch_dict_no_nan duration=0.0982 ms
                inputs:
                        batch_dict = {
                                object_id: <numpy.ndarray shape=(2,) hash=14226149857159391975 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                        mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                                }
                        }
                outputs:
                        batch_dict_no_nan = {
                                object_id: <numpy.ndarray shape=(2,) hash=14226149857159391975 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                        mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                                }
                        }

        ]
        prepare_inputs: [
                HyraxAutoencoder_inst_prepare_inputs(batch_dict) -> batch_ndarray duration=0.00627 ms
                inputs:
                        batch_dict = {
                                object_id: <numpy.ndarray shape=(2,) hash=14226149857159391975 device=cpu>
                                data: {
                                        image: <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                                        mask: <numpy.ndarray shape=(2, 1, 32, 32) hash=0 device=cpu>
                                }
                        }
                outputs:
                        batch_ndarray = <numpy.ndarray shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>

        ]
        evaluation: [
                HyraxAutoencoder__infer_batch(batch) -> batch_results duration=1.85 ms
                inputs:
                        batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                outputs:
                        batch_results = <torch.Tensor shape=(2, 64) hash=57172974704263168 device=cpu>

                HyraxAutoencoder__forward(batch) -> batch_results duration=1.83 ms
                inputs:
                        batch = <torch.Tensor shape=(2, 1, 32, 32) hash=20065364041793536 device=cpu>
                outputs:
                        batch_results = <torch.Tensor shape=(2, 64) hash=57172974704263168 device=cpu>

        ]
}

Instrumenting custom models and datasets

If you have written a custom model or dataset class, you can opt specific methods into trace capture using the @trace_model_func and @trace_dataset_func decorators.

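import torch.nn as nn

from hyrax.trace import trace_model_func, trace_dataset_func

# HyraxDataset is Hyrax's dataset base class; import it from wherever
# your Hyrax version provides it.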

class MyModel(nn.Module):
    @trace_model_func
    def my_custom_forward(self, batch):
        # This call will appear in the 'evaluation' stage of the TraceResult
        ...

class MyDataset(HyraxDataset):
    @trace_dataset_func
    def get_image(self, index):
        # This call will appear in the 'dataset_getter' stage of the TraceResult
        ...

The decorators add a small overhead to every call, so they are intended for use during development and debugging rather than in production. Remove them (or use them selectively) once you are satisfied with how your data is flowing.
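To see where that overhead comes from, here is a minimal sketch of a call-recording decorator in plain Python. It is not Hyrax's implementation, and record_calls, CALLS, and ToyDataset are hypothetical names. It only illustrates the general technique: every decorated call pays for a timer and for storing its arguments and return value.

```python
import functools
import time

# Hypothetical in-memory log of recorded calls.
CALLS = []


def record_calls(func):
    """Record each call's arguments, return value, and wall-clock duration."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        CALLS.append({
            "name": func.__qualname__,
            "args": args,  # includes `self` for methods
            "kwargs": kwargs,
            "result": result,
            "duration_ms": elapsed_ms,
        })
        return result

    return wrapper


class ToyDataset:
    @record_calls
    def get_image(self, index):
        return [index] * 3


ToyDataset().get_image(7)
print(CALLS[0]["name"], CALLS[0]["result"], f"{CALLS[0]['duration_ms']:.4f} ms")
```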