Hyrax Model Export

In this getting-started notebook we'll create a Hyrax object, train a built-in model on the CIFAR training dataset, and then show where the trained model's weights are exported and how to load them for inspection or evaluation.

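For evaluating many inputs at once, we recommend Hyrax's infer verb. The approaches below are better suited to inspecting individual items or running the model outside Hyrax.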

Train a Hyrax model

We configure Hyrax to use the built-in HyraxAutoencoder model and immediately train it on the sample CIFAR dataset. The prepare verb returns the input dataset, which we reuse below, and the infer verb saves the post-training latent-space representation for later exploration.

[ ]:
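import hyrax
import torch

# Create a Hyrax instance with the default configuration
h = hyrax.Hyrax()
# Select the built-in autoencoder model
h.config["model"]["name"] = "HyraxAutoencoder"

dataset = h.prepare()     # load the CIFAR dataset
model = h.train()         # train; returns the trained torch model
latent_space = h.infer()  # compute latent vectors and save them to a results directory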
[2025-07-07 15:33:07,791 hyrax:INFO] Runtime Config read from: /Users/mtauraso/src/hyrax/src/hyrax/hyrax_default_config.toml
/Users/mtauraso/miniforge3/envs/hyrax/lib/python3.10/site-packages/ignite/handlers/checkpoint.py:16: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
  from torch.distributed.optim import ZeroRedundancyOptimizer
Files already downloaded and verified
[2025-07-07 15:33:12,556 hyrax.prepare:INFO] Finished Prepare
Files already downloaded and verified
[2025-07-07 15:33:15,425 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
2025-07-07 15:33:15,434 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<hyrax.data_sets.hyr':
        {'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x151fc8220>, 'batch_size': 512, 'shuffle': False, 'pin_memory': False}
2025-07-07 15:33:15,435 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<hyrax.data_sets.hyr':
        {'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x151fc8310>, 'batch_size': 512, 'shuffle': False, 'pin_memory': False}
/Users/mtauraso/miniforge3/envs/hyrax/lib/python3.10/site-packages/ignite/handlers/tqdm_logger.py:127: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)
  from tqdm.autonotebook import tqdm
2025/07/07 15:33:15 WARNING mlflow.system_metrics.system_metrics_monitor: Skip logging GPU metrics because creating `GPUMonitor` failed with error: Failed to initialize NVML, skip logging GPU metrics: NVML Shared Library Not Found.
2025/07/07 15:33:15 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2025-07-07 15:33:15,591 hyrax.pytorch_ignite:INFO] Training model on device: mps
[2025-07-07 15:33:58,529 hyrax.pytorch_ignite:INFO] Total training time: 42.94[s]
[2025-07-07 15:33:58,530 hyrax.pytorch_ignite:INFO] Latest checkpoint saved as: /Users/mtauraso/src/hyrax/docs/pre_executed/results/20250707-153312-train-tNXh/checkpoint_epoch_10.pt
[2025-07-07 15:33:58,530 hyrax.pytorch_ignite:INFO] Best metric checkpoint saved as: /Users/mtauraso/src/hyrax/docs/pre_executed/results/20250707-153312-train-tNXh/checkpoint_10_loss=-132.5959.pt
2025/07/07 15:33:58 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/07/07 15:33:58 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2025-07-07 15:33:58,542 hyrax.verbs.train:INFO] Finished Training
[2025-07-07 15:33:58,911 hyrax.model_exporters:INFO] Exported model to ONNX format: /Users/mtauraso/src/hyrax/docs/pre_executed/results/20250707-153312-train-tNXh/example_model_opset_20.onnx
Files already downloaded and verified
[2025-07-07 15:34:01,733 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
[2025-07-07 15:34:01,733 hyrax.verbs.infer:INFO] data set has length 50000
2025-07-07 15:34:01,734 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<hyrax.data_sets.hyr':
        {'sampler': None, 'batch_size': 512, 'shuffle': False, 'pin_memory': False}
[2025-07-07 15:34:01,752 hyrax.verbs.infer:INFO] Saving inference results at: /Users/mtauraso/src/hyrax/docs/pre_executed/results/20250707-153358-infer-A1QT
[2025-07-07 15:34:01,913 hyrax.pytorch_ignite:INFO] Evaluating model on device: mps
[2025-07-07 15:34:01,913 hyrax.pytorch_ignite:INFO] Total epochs: 1
[2025-07-07 15:34:05,654 hyrax.pytorch_ignite:INFO] Total evaluation time: 3.74[s]
[2025-07-07 15:34:05,701 hyrax.verbs.infer:INFO] Inference Complete.
Files already downloaded and verified

Inspect the model

The return value from train is the torch model with its weights set by training.

We can run individual items through the model to see the output. Note that whenever we pass a single item to the model, we must wrap it as torch.stack([<our data>]). PyTorch modules accept and return batches of data rather than individual items, and stacking a one-element list adds the leading batch dimension of size 1.
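The same batch-dimension rule applies to the NumPy arrays handed to ONNX Runtime later in this notebook. The sketch below uses a random array as a stand-in for one CIFAR image, assuming a channels-first shape of (3, 32, 32), and shows that stacking a one-element list and indexing with `np.newaxis` both add the leading batch axis.

```python
import numpy as np

# Stand-in for one CIFAR image; the (channels, height, width) = (3, 32, 32)
# layout is an assumption about how the dataset returns images.
image = np.random.default_rng(0).standard_normal((3, 32, 32)).astype(np.float32)

# Stacking a one-element list adds a batch axis of size 1,
# mirroring torch.stack([image]) in the cell below.
batch = np.stack([image])

# Indexing with np.newaxis produces the same array without building a list.
batch_alt = image[np.newaxis, ...]

print(batch.shape)                       # (1, 3, 32, 32)
print(np.array_equal(batch, batch_alt))  # True
```

In PyTorch, the equivalent of `image[np.newaxis, ...]` is `image.unsqueeze(0)`.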

[26]:
test_batch = torch.stack([dataset[0]["image"]])
encoded = model.forward(test_batch)
encoded[0]
[26]:
tensor([-0.2684,  2.0327,  0.6681,  0.4684, -0.0093, -0.7515, -0.4369, -0.0560,
        -0.4807, -0.2047, -0.5424, -1.7949,  0.3290, -0.0839,  0.2517, -0.4631,
        -0.5715,  0.9547,  2.2121,  1.6036, -0.2378, -0.6814, -1.1888,  0.1074,
         0.6576, -0.9482, -0.3272,  0.1347, -0.3849,  0.3405,  0.6250,  0.1028,
         1.2349,  0.9591, -1.0280,  0.0798, -0.3281,  1.2027,  0.1158, -1.4270,
        -0.1354,  0.5108, -0.2088, -0.6593, -2.0212, -0.6324, -0.8815, -1.0164,
        -0.4959, -0.2724,  0.5190,  0.0217, -0.8420,  1.0206,  1.4802,  0.6457,
        -0.0134, -0.1939,  0.3689,  0.4015, -0.0540,  0.9022, -0.3325,  0.3758],
       grad_fn=<SelectBackward0>)

In addition to the forward function that every model class must provide, HyraxAutoencoder has private methods _eval_encoder and _eval_decoder. Calling _eval_decoder on the latent vector produced above reconstructs an image from it.
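A decoded image is easiest to judge against the input it came from. The sketch below computes a per-image mean squared reconstruction error with NumPy. `reconstruction_mse` is a hypothetical helper, not part of Hyrax; in this notebook you would pass it `test_batch.numpy()` and `decoded.detach().numpy()` after running the decoder cell.

```python
import numpy as np


def reconstruction_mse(original: np.ndarray, decoded: np.ndarray) -> np.ndarray:
    """Return the mean squared error for each item in a batch.

    Both arrays must share the same (batch, channels, height, width) shape.
    """
    if original.shape != decoded.shape:
        raise ValueError(f"shape mismatch: {original.shape} vs {decoded.shape}")
    squared_error = (original - decoded) ** 2
    # Average over every axis except the leading batch axis.
    return squared_error.reshape(squared_error.shape[0], -1).mean(axis=1)


# Toy stand-ins: a perfect reconstruction and one that is off by 0.5 everywhere.
original = np.zeros((2, 3, 32, 32), dtype=np.float32)
decoded = original.copy()
decoded[1] += 0.5

errors = reconstruction_mse(original, decoded)
# expected: 0.0 for the first item and 0.25 for the second
print(errors)
```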

[27]:
decoded = model._eval_decoder(encoded)
decoded
[27]:
tensor([[[[-0.3913, -0.4311, -0.4035,  ..., -0.0961, -0.1124, -0.1265],
          [-0.4319, -0.4706, -0.4257,  ..., -0.1439, -0.1663, -0.1757],
          [-0.4710, -0.4931, -0.4316,  ..., -0.1495, -0.1690, -0.1926],
          ...,
          [ 0.3070,  0.3168,  0.3672,  ...,  0.1029,  0.0959,  0.0497],
          [ 0.2906,  0.3178,  0.3541,  ...,  0.0275,  0.0309, -0.0040],
          [ 0.2458,  0.2776,  0.2970,  ..., -0.0419, -0.0128, -0.0273]],

         [[-0.4696, -0.5285, -0.5056,  ..., -0.2339, -0.2303, -0.1976],
          [-0.5366, -0.5698, -0.5865,  ..., -0.3387, -0.3203, -0.2556],
          [-0.5772, -0.6236, -0.5876,  ..., -0.3770, -0.3559, -0.3159],
          ...,
          [ 0.1188,  0.1156,  0.0948,  ..., -0.1672, -0.1211, -0.1243],
          [ 0.1446,  0.1269,  0.1182,  ..., -0.1876, -0.1373, -0.0832],
          [ 0.1481,  0.1643,  0.1435,  ..., -0.1726, -0.0903, -0.1622]],

         [[-0.6897, -0.7122, -0.7377,  ..., -0.5629, -0.5428, -0.4896],
          [-0.7524, -0.7772, -0.7948,  ..., -0.6499, -0.6240, -0.5584],
          [-0.7834, -0.7909, -0.8095,  ..., -0.6947, -0.6706, -0.6258],
          ...,
          [-0.2752, -0.3248, -0.3875,  ..., -0.5943, -0.5198, -0.4225],
          [-0.2124, -0.2342, -0.3135,  ..., -0.5495, -0.4852, -0.4268],
          [-0.1406, -0.1733, -0.2076,  ..., -0.4798, -0.3858, -0.3236]]]],
       grad_fn=<AliasBackward0>)

Export the model

Training has already exported the model to the most recent results directory in two forms:

  1. A PyTorch weights file, example_model.pth

  2. An ONNX model file, example_model_opset_##.onnx, where ## is the ONNX opset version (20 in this run)
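Because the ONNX filename embeds the opset number, a script that does not know it in advance can locate both files by suffix. A minimal sketch using only the standard library; `find_exported_models` is a hypothetical helper, and the `_opset_<n>.onnx` naming pattern is an assumption based on the files produced in this notebook.

```python
import re
import tempfile
from pathlib import Path

# Matches names such as "example_model_opset_20.onnx". The naming pattern is an
# assumption based on the files Hyrax produced in this notebook.
OPSET_PATTERN = re.compile(r"_opset_(\d+)\.onnx$")


def find_exported_models(results_dir):
    """Return the exported .pth and .onnx paths and the ONNX opset number.

    Any item that cannot be found is returned as None. Training checkpoints use
    the .pt suffix, so the "*.pth" glob does not pick them up.
    """
    results_dir = Path(results_dir)
    pth_files = sorted(results_dir.glob("*.pth"))
    onnx_files = sorted(results_dir.glob("*.onnx"))

    onnx_path = onnx_files[0] if onnx_files else None
    opset = None
    if onnx_path is not None:
        match = OPSET_PATTERN.search(onnx_path.name)
        if match:
            opset = int(match.group(1))

    return {"pth": pth_files[0] if pth_files else None, "onnx": onnx_path, "opset": opset}


# Demo on a throwaway directory that mimics the file names from this notebook.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["checkpoint_epoch_10.pt", "example_model.pth",
                 "example_model_opset_20.onnx", "runtime_config.toml"]:
        (Path(tmp) / name).touch()
    found = find_exported_models(tmp)

print(found["pth"].name, found["onnx"].name, found["opset"])
# example_model.pth example_model_opset_20.onnx 20
```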

The results directory path appears in the training log above, but we can also find and list it programmatically:

[28]:
import os

results_dir = hyrax.config_utils.find_most_recent_results_dir(h.config, "train")
print(results_dir)
os.listdir(results_dir)
/Users/mtauraso/src/hyrax/docs/pre_executed/results/20250707-153312-train-tNXh
[28]:
['checkpoint_10_loss=-132.5959.pt',
 'runtime_config.toml',
 'example_model.pth',
 'events.out.tfevents.1751927592.Michaels-MacBook-Pro.local',
 'checkpoint_epoch_10.pt',
 'example_model_opset_20.onnx']
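The runtime_config.toml file in this listing records the configuration the run used, which helps when you later need to know which model class the weights belong to (see the PyTorch section below). A minimal sketch, assuming the saved file mirrors `h.config` with a `[model]` table containing a `name` key; `trained_model_name` is a hypothetical helper.

```python
import tempfile
from pathlib import Path

try:
    import tomllib  # standard library in Python 3.11+
except ModuleNotFoundError:  # Python 3.10 and earlier
    import tomli as tomllib  # third-party backport: pip install tomli


def trained_model_name(results_dir):
    """Read the model name recorded in a results directory's runtime_config.toml.

    Assumes the file has a [model] table with a "name" key, like h.config.
    """
    with open(Path(results_dir) / "runtime_config.toml", "rb") as f:
        config = tomllib.load(f)
    return config["model"]["name"]


# Demo against a throwaway directory holding a minimal, hypothetical config.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "runtime_config.toml").write_text('[model]\nname = "HyraxAutoencoder"\n')
    name = trained_model_name(tmp)

print(name)  # HyraxAutoencoder
```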

Running your trained model outside Hyrax

If you want to run your trained model without importing Hyrax, we highly recommend ONNX Runtime. Loading the PyTorch weights file instead has significant drawbacks, which the next section describes.

Evaluating an ONNX model in Python takes only a few lines.

Further information on using ONNX can be found in the ONNX documentation.

[31]:
import onnxruntime as ort

# Find the exported .onnx file in the results directory
onnx_model_filename = [filename for filename in os.listdir(results_dir) if filename.endswith(".onnx")][0]
onnx_model_path = results_dir / onnx_model_filename
print(f"Onnx model filename: {onnx_model_path}")

# Run our single datum through the model with ONNX Runtime
test_batch = torch.stack([dataset[0]["image"]])
ort_sess = ort.InferenceSession(str(onnx_model_path))
# ONNX Runtime takes NumPy arrays keyed by input name; look the name up rather than hard-coding it
input_name = ort_sess.get_inputs()[0].name
outputs = ort_sess.run(None, {input_name: test_batch.numpy()})
outputs
Onnx model filename: /Users/mtauraso/src/hyrax/docs/pre_executed/results/20250707-153312-train-tNXh/example_model_opset_20.onnx
[31]:
[array([[-0.2683956 ,  2.032734  ,  0.6681053 ,  0.46839038, -0.00929314,
         -0.7515385 , -0.43687135, -0.05602254, -0.48072147, -0.20469648,
         -0.5423851 , -1.7948971 ,  0.3289881 , -0.08393019,  0.25169367,
         -0.4631039 , -0.5715224 ,  0.9547064 ,  2.2121174 ,  1.603585  ,
         -0.23777348, -0.68139553, -1.1888374 ,  0.10737786,  0.6576071 ,
         -0.9481759 , -0.32715958,  0.13470943, -0.38490096,  0.3404901 ,
          0.6249982 ,  0.10281795,  1.2349452 ,  0.9590781 , -1.028022  ,
          0.0798444 , -0.3281045 ,  1.2027371 ,  0.1158179 , -1.4270462 ,
         -0.13537334,  0.51084507, -0.2087721 , -0.6593263 , -2.0212321 ,
         -0.6324384 , -0.881453  , -1.0164088 , -0.49588868, -0.27238098,
          0.5189738 ,  0.02166937, -0.841993  ,  1.0205503 ,  1.4802283 ,
          0.6457264 , -0.0134454 , -0.19393134,  0.36890328,  0.40152574,
         -0.05401462,  0.9021904 , -0.33254507,  0.3757854 ]],
       dtype=float32)]
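The ONNX values match the PyTorch output from earlier in the notebook to the four decimal places PyTorch printed. The stdlib check below makes this concrete for the first five values, copied from the two outputs above. With live arrays you would more likely compare `encoded.detach().numpy()` against `outputs[0][0]` using `numpy.allclose` with a tolerance suited to your model.

```python
import math

# First five latent values as printed by PyTorch (rounded to 4 decimals) ...
torch_values = [-0.2684, 2.0327, 0.6681, 0.4684, -0.0093]
# ... and the same five values returned by ONNX Runtime.
onnx_values = [-0.2683956, 2.032734, 0.6681053, 0.46839038, -0.00929314]

# Rounding to 4 decimals introduces at most 0.00005 of error,
# so use that as the absolute tolerance.
matches = [
    math.isclose(t, o, rel_tol=0.0, abs_tol=5e-5)
    for t, o in zip(torch_values, onnx_values)
]
print(all(matches))  # True
```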

Running your trained model with pytorch

(not recommended)

To load a PyTorch weights file, the class structure of the model must be exactly the same at load time as it was at save time. This means the program that loads the weights needs a complete, up-to-date copy of your Python model class. You will also need to use exactly the same Python and PyTorch versions.

If these conditions are met, loading the model is relatively straightforward:

[ ]:
from hyrax.models.hyrax_autoencoder import HyraxAutoencoder

# Find the exported .pth file in the results directory
pth_model_filename = [filename for filename in os.listdir(results_dir) if filename.endswith(".pth")][0]
pth_model_path = results_dir / pth_model_filename
print(f"Pytorch module filename: {pth_model_path}")

# Recreate the model with the same dataset and config, then load the trained weights
test_batch = torch.stack([dataset[0]["image"]])
imported_model = HyraxAutoencoder(dataset=dataset, config=h.config)
imported_model.load(pth_model_path)
imported_model.to(device="cpu")

encoded_from_import = imported_model.forward(test_batch)
encoded_from_import[0]
[2025-07-07 16:15:53,728 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
Pytorch module filename: /Users/mtauraso/src/hyrax/docs/pre_executed/results/20250707-153312-train-tNXh/example_model.pth
tensor([-0.2684,  2.0327,  0.6681,  0.4684, -0.0093, -0.7515, -0.4369, -0.0560,
        -0.4807, -0.2047, -0.5424, -1.7949,  0.3290, -0.0839,  0.2517, -0.4631,
        -0.5715,  0.9547,  2.2121,  1.6036, -0.2378, -0.6814, -1.1888,  0.1074,
         0.6576, -0.9482, -0.3272,  0.1347, -0.3849,  0.3405,  0.6250,  0.1028,
         1.2349,  0.9591, -1.0280,  0.0798, -0.3281,  1.2027,  0.1158, -1.4270,
        -0.1354,  0.5108, -0.2088, -0.6593, -2.0212, -0.6324, -0.8815, -1.0164,
        -0.4959, -0.2724,  0.5190,  0.0217, -0.8420,  1.0206,  1.4802,  0.6457,
        -0.0134, -0.1939,  0.3689,  0.4015, -0.0540,  0.9022, -0.3325,  0.3758],
       grad_fn=<SelectBackward0>)
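Because loading the .pth file depends on matching Python and PyTorch versions, it can help to record those versions next to the weights and check them before loading. The sketch below does this with a JSON sidecar file; the sidecar, its naming scheme, and all helper functions are hypothetical conveniences, not something Hyrax writes for you.

```python
import json
import platform
import tempfile
from importlib import metadata
from pathlib import Path


def current_versions():
    """Collect the running Python version and the installed PyTorch version (or None)."""
    try:
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        torch_version = None
    return {"python": platform.python_version(), "torch": torch_version}


def sidecar_path(weights_path):
    """Hypothetical naming scheme: example_model.pth -> example_model.pth.versions.json."""
    weights_path = Path(weights_path)
    return weights_path.with_name(weights_path.name + ".versions.json")


def write_version_sidecar(weights_path):
    """Record the current versions next to a weights file."""
    path = sidecar_path(weights_path)
    path.write_text(json.dumps(current_versions()))
    return path


def version_mismatches(weights_path):
    """Return {name: (saved, current)} for every recorded version that differs now."""
    saved = json.loads(sidecar_path(weights_path).read_text())
    now = current_versions()
    return {key: (saved.get(key), now[key]) for key in now if saved.get(key) != now[key]}


# Demo: writing and re-checking in the same environment reports no mismatches.
with tempfile.TemporaryDirectory() as tmp:
    weights = Path(tmp) / "example_model.pth"
    write_version_sidecar(weights)
    mismatches = version_mismatches(weights)

print(mismatches)  # {}
```

In practice you would call `write_version_sidecar(pth_model_path)` after training and check `version_mismatches(pth_model_path)` before calling `imported_model.load(...)`.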