Hyrax Getting Started
In this getting-started notebook, we’ll create a Hyrax instance, train a built-in model on the CIFAR training data, and then use the trained model to run inference on CIFAR data.
Create a Hyrax instance
[1]:
import hyrax
h = hyrax.Hyrax()
[2025-08-25 21:14:19,112 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
Update the configuration
[2]:
h.config["model"]["name"] = "HyraxAutoencoder"
For this demo, we’ll make a small adjustment to the default configuration that the Hyrax object was instantiated with. Any configuration value can be modified through the .config attribute of the Hyrax instance. Many configuration values can be set, but here we change only the model to train.
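Each override is a nested item assignment on h.config. When several values need to change at once, a small helper that accepts dotted keys can keep a notebook cell compact. The sketch below is a hypothetical helper, not part of Hyrax; it runs against a plain dictionary standing in for h.config and assumes only that the config supports the same config[section][key] indexing used above.

```python
from collections.abc import MutableMapping
from typing import Any


def apply_overrides(config: MutableMapping[str, Any], overrides: dict[str, Any]) -> None:
    """Assign nested config values from dotted keys such as "model.name".

    Hypothetical helper (not part of Hyrax). It only uses config[section][key]
    indexing, the same pattern as h.config["model"]["name"] above, so an
    unknown section raises KeyError instead of being created silently.
    """
    for dotted_key, value in overrides.items():
        *sections, key = dotted_key.split(".")
        target = config
        for section in sections:
            target = target[section]
        target[key] = value


# Plain dict standing in for h.config; the starting values are placeholders.
config = {"model": {"name": "placeholder"}, "data_loader": {"batch_size": 512}}
apply_overrides(config, {"model.name": "HyraxAutoencoder", "data_loader.batch_size": 128})
print(config["model"]["name"], config["data_loader"]["batch_size"])  # HyraxAutoencoder 128
```

Passing h.config instead of the stand-in dictionary should have the same effect as writing out each assignment by hand, provided the config object supports ordinary item access.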
Train a model
[3]:
h.train()
[2025-08-25 21:14:33,430 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
2025-08-25 21:14:33,503 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<hyrax.data_sets.hyr':
{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x70da50623980>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
2025-08-25 21:14:33,504 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<hyrax.data_sets.hyr':
{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x70da5046eea0>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
/home/drew/miniconda3/envs/hyrax/lib/python3.12/site-packages/ignite/handlers/tqdm_logger.py:127: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)
from tqdm.autonotebook import tqdm
2025/08/25 21:14:33 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2025-08-25 21:14:33,874 hyrax.pytorch_ignite:INFO] Training model on device: cuda
[2025-08-25 21:16:09,062 hyrax.pytorch_ignite:INFO] Total training time: 95.19[s]
[2025-08-25 21:16:09,063 hyrax.pytorch_ignite:INFO] Latest checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250825-211424-train-0D0a/checkpoint_epoch_10.pt
[2025-08-25 21:16:09,063 hyrax.pytorch_ignite:INFO] Best metric checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250825-211424-train-0D0a/checkpoint_9_loss=-126.9350.pt
2025/08/25 21:16:09 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/08/25 21:16:09 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2025-08-25 21:16:09,079 hyrax.verbs.train:INFO] Finished Training
[2025-08-25 21:16:09,364 hyrax.model_exporters:INFO] Exported model to ONNX format: /home/drew/code/hyrax/docs/pre_executed/results/20250825-211424-train-0D0a/example_model_opset_20.onnx
[3]:
HyraxAutoencoder(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): GELU(approximate='none')
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): GELU(approximate='none')
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): GELU(approximate='none')
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): GELU(approximate='none')
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): GELU(approximate='none')
    (10): Flatten(start_dim=1, end_dim=-1)
    (11): Linear(in_features=1024, out_features=64, bias=True)
  )
  (dec_linear): Sequential(
    (0): Linear(in_features=64, out_features=1024, bias=True)
    (1): GELU(approximate='none')
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): GELU(approximate='none')
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): GELU(approximate='none')
    (4): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (5): GELU(approximate='none')
    (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): GELU(approximate='none')
    (8): ConvTranspose2d(32, 3, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (9): Tanh()
  )
  (criterion): CrossEntropyLoss()
)
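The layer sizes in the printed model fit together as follows. The sketch below applies PyTorch's output-size formulas for Conv2d and ConvTranspose2d (with dilation 1) to the layers listed above. It assumes a 3 x 32 x 32 CIFAR input and assumes the model reshapes the 1024 outputs of dec_linear to 64 x 4 x 4 before the decoder; that reshape would happen in the model's forward pass, which the printout doesn't show.

```python
def conv2d_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Output height/width of torch.nn.Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(
    size: int, kernel: int = 3, stride: int = 2, padding: int = 1, output_padding: int = 1
) -> int:
    """Output height/width of torch.nn.ConvTranspose2d with dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# (stride, out_channels) of each encoder Conv2d; all use kernel 3, padding 1.
ENCODER = [(2, 32), (1, 32), (2, 64), (1, 64), (2, 64)]
# ("T" = ConvTranspose2d with stride 2, "C" = Conv2d with stride 1), out_channels
DECODER = [("T", 64), ("C", 64), ("T", 32), ("C", 32), ("T", 3)]

size, channels = 32, 3  # assumed CIFAR input: 3 x 32 x 32
for stride, channels in ENCODER:
    size = conv2d_out(size, stride=stride)
encoded_shape = (channels, size, size)
flat_features = channels * size * size  # what Flatten hands to Linear(1024, 64)

# Assumption: dec_linear's 1024 outputs are reshaped to 64 x 4 x 4 in forward().
size, channels = 4, 64
for kind, channels in DECODER:
    size = conv_transpose2d_out(size) if kind == "T" else conv2d_out(size)
decoded_shape = (channels, size, size)

print(encoded_shape, flat_features)  # (64, 4, 4) 1024
print(decoded_shape)  # (3, 32, 32)
```

Three stride-2 convolutions take 32 x 32 down to 4 x 4, so the final feature map holds 64 x 4 x 4 = 1024 values, matching in_features=1024. The three stride-2 transposed convolutions then bring the decoder output back to the input's 3 x 32 x 32 shape.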
The output of training is stored in a timestamped directory under ./results/. By default, a copy of the final configuration used for training is saved there as runtime_config.toml. To train again with the same configuration, reference this runtime_config.toml file.
If running in another notebook, instantiate a Hyrax object like so:
new_hyrax_instance = hyrax.Hyrax(config_file='./results/<timestamped_directory>/runtime_config.toml')
Or from the command line:
$ hyrax train --runtime-config ./results/<timestamped_directory>/runtime_config.toml
Note that here we’re training on a relatively small amount of CIFAR data, but Hyrax has been shown to scale to training sets with more than 1M samples.
Run inference
[4]:
h.config["data_set"]["test_size"] = 1.0
h.config["data_set"]["train_size"] = 0.0
h.config["data_set"]["validate_size"] = 0.0
h.config["data_loader"]["batch_size"] = 128
h.infer()
[2025-08-25 21:16:17,243 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
[2025-08-25 21:16:17,244 hyrax.verbs.infer:INFO] data set has length 50000
2025-08-25 21:16:17,246 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<hyrax.data_sets.hyr':
{'sampler': None, 'batch_size': 128, 'shuffle': False, 'pin_memory': True}
[2025-08-25 21:16:17,263 hyrax.verbs.infer:INFO] Saving inference results at: /home/drew/code/hyrax/docs/pre_executed/results/20250825-211609-infer-SY49
[2025-08-25 21:16:17,602 hyrax.pytorch_ignite:INFO] Evaluating model on device: cuda
[2025-08-25 21:16:17,605 hyrax.pytorch_ignite:INFO] Total epochs: 1
[2025-08-25 21:16:30,310 hyrax.pytorch_ignite:INFO] Total evaluation time: 12.71[s]
[2025-08-25 21:16:30,434 hyrax.verbs.infer:INFO] Inference Complete.
[4]:
<hyrax.data_sets.inference_dataset.InferenceDataSet at 0x70db7a0bafc0>
Once a model has been trained, we can use its weights file to run inference. By default, infer looks for the most recent model weights file available. To use a specific weights file instead, set h.config['infer']['model_weights_file'] = <path_to_model_weights_file>.
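To pin inference to a particular run instead of relying on the default lookup, you first need to locate that run's directory. The hypothetical helper below (not Hyrax's own lookup logic) picks the most recent training directory under a results folder. It relies only on the <YYYYMMDD-HHMMSS>-train-<id> naming visible in the logs above, which sorts chronologically as plain text; the directory names in the demo are examples.

```python
import tempfile
from pathlib import Path
from typing import Optional


def latest_train_dir(results_dir: Path) -> Optional[Path]:
    """Return the newest "<timestamp>-train-<id>" directory, or None if there is none.

    Hypothetical helper: it assumes the timestamp prefix sorts chronologically
    as a string, as in "20250825-211424-train-0D0a".
    """
    train_dirs = [d for d in results_dir.iterdir() if d.is_dir() and "-train-" in d.name]
    return max(train_dirs, key=lambda d: d.name, default=None)


# Demo with example directory names; the "-infer-" run is ignored.
with tempfile.TemporaryDirectory() as tmp:
    results = Path(tmp)
    for name in [
        "20250824-090000-train-abcd",
        "20250825-211424-train-0D0a",
        "20250825-211609-infer-SY49",
    ]:
        (results / name).mkdir()
    latest = latest_train_dir(results)
    print(latest.name)  # 20250825-211424-train-0D0a
```

From the returned directory you can then choose the weights file to assign to h.config['infer']['model_weights_file'].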
Here we’ll use the most recently trained model weights file, update the data set splits so that 100% of the data is used for inference, and set the batch size to 128.
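The split and batch-size settings translate into simple arithmetic. The sketch below assumes each split size is a fraction of the whole data set and uses the length of 50000 reported in the inference log. With a batch size of 128, that gives 391 batches, the last one holding the remaining 80 samples (assuming incomplete batches are not dropped).

```python
import math

dataset_length = 50_000  # from the "data set has length 50000" log line
splits = {"train_size": 0.0, "validate_size": 0.0, "test_size": 1.0}
batch_size = 128  # h.config["data_loader"]["batch_size"]

# Assumption: each split size is a fraction of the full data set.
split_counts = {name: round(dataset_length * frac) for name, frac in splits.items()}
n_batches = math.ceil(split_counts["test_size"] / batch_size)
last_batch = split_counts["test_size"] - (n_batches - 1) * batch_size

print(split_counts)  # {'train_size': 0, 'validate_size': 0, 'test_size': 50000}
print(n_batches, last_batch)  # 391 80
```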
With the configuration updated, we can run inference by calling h.infer().
The results of running inference are saved in the output directory reported by the "Saving inference results at:" log line. The default output format is batched .npy files. Additionally, a ChromaDB vector database is populated with the inference results to enable efficient similarity search.
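The batched .npy output can be reassembled with NumPy. This sketch is hedged in two ways: the batch_<n>.npy file names are a hypothetical stand-in (check the inference output directory for the real names), and the brute-force cosine-similarity search is a simple substitute for illustration, not how the ChromaDB database is queried. The demo writes fake 64-dimensional batches to a temporary directory so it runs on its own.

```python
import tempfile
from pathlib import Path

import numpy as np


def load_batched_npy(directory: Path, pattern: str = "batch_*.npy") -> np.ndarray:
    """Concatenate batched .npy files in numeric batch order.

    The "batch_<n>.npy" naming is a hypothetical stand-in for the real file names.
    """
    files = sorted(directory.glob(pattern), key=lambda p: int(p.stem.split("_")[-1]))
    return np.concatenate([np.load(f) for f in files], axis=0)


def cosine_top_k(vectors: np.ndarray, query: np.ndarray, k: int = 5) -> np.ndarray:
    """Indices of the k rows most similar to query (brute force, not ChromaDB)."""
    normed = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    q = query / np.linalg.norm(query)
    scores = normed @ q
    return np.argsort(-scores)[:k]


# Self-contained demo: write three batches of fake 64-dim vectors, then reload them.
rng = np.random.default_rng(0)
with tempfile.TemporaryDirectory() as tmp:
    out_dir = Path(tmp)
    for i, n in enumerate([128, 128, 80]):
        np.save(out_dir / f"batch_{i}.npy", rng.normal(size=(n, 64)))
    latents = load_batched_npy(out_dir)

print(latents.shape)  # (336, 64)
print(cosine_top_k(latents, latents[10], k=3))  # first index is 10, the query itself
```

Brute force scores every stored vector for each query; the ChromaDB database described above is what the notebook relies on for efficient similarity search.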