Getting started with Hyrax
Installation
Hyrax can be installed via pip:
pip install hyrax
Hyrax is officially supported and tested with Python versions 3.10, 3.11, 3.12, and 3.13. Other versions may work but are not guaranteed to be compatible.
We strongly encourage using a virtual environment when working with Hyrax, because Hyrax depends on several open-source packages whose version requirements may conflict with those of other packages you have installed.
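If you want a quick check that your interpreter is one of the supported versions and that you are actually working inside a virtual environment, a small script like the one below can help. It is our own sketch, not part of Hyrax, and the helper names are hypothetical. It relies on the standard-library behavior that inside a venv, sys.prefix points at the environment while sys.base_prefix points at the base Python installation, so it will not recognize conda environments.

```python
import sys

# Python versions listed as officially supported in this guide, as (major, minor).
SUPPORTED_VERSIONS = {(3, 10), (3, 11), (3, 12), (3, 13)}


def is_supported_python(version_info=sys.version_info) -> bool:
    """Return True if the (major, minor) version is in the supported set."""
    return tuple(version_info[:2]) in SUPPORTED_VERSIONS


def in_virtual_env() -> bool:
    """Return True when running inside a venv/virtualenv environment.

    Inside such an environment sys.prefix points at the environment while
    sys.base_prefix points at the base interpreter. Conda environments are
    not detected by this check.
    """
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)


if __name__ == "__main__":
    print("Supported Python:", is_supported_python())
    print("Inside a virtual environment:", in_virtual_env())
```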
First Steps
This getting-started example uses Hyrax to train a small convolutional neural network (CNN) to classify images from the CIFAR10 dataset. It is based on a similar PyTorch tutorial. The CIFAR10 dataset is described in Alex Krizhevsky, “Learning Multiple Layers of Features from Tiny Images,” 2009.
As part of this example, we will:
1. Create a Hyrax instance
2. Specify a model and a dataset
3. Train the model
4. Predict with the model
5. Evaluate the results
Create a Hyrax instance
The main driver for Hyrax is the Hyrax class. To get started, we’ll create an instance of this class.
from hyrax import Hyrax
h = Hyrax()
When we create the Hyrax instance, it will automatically load a default configuration file. This file contains default settings for all of the components that Hyrax uses.
Specify a model
We’ll need to let Hyrax know which model to use for training. Here we’ll tell Hyrax to use the built-in HyraxCNN model that is based on the simple CNN architecture from the PyTorch CIFAR10 tutorial.
h.set_config('model.name', 'HyraxCNN')
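The first argument to set_config is a dotted key. We read 'model.name' as the name setting inside the model section of the configuration; that interpretation is our assumption rather than something this guide states. The toy function below illustrates the idea on a plain nested dictionary. set_dotted and the example defaults are hypothetical, and this is not how Hyrax itself implements set_config.

```python
import copy


def set_dotted(config: dict, key: str, value) -> dict:
    """Return a copy of `config` with the dotted `key` set to `value`.

    Each dot-separated segment selects one level of nesting, so
    "model.name" refers to config["model"]["name"]. Missing levels are created.
    """
    updated = copy.deepcopy(config)  # leave the original defaults untouched
    *parents, leaf = key.split(".")
    node = updated
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value
    return updated


# Hypothetical defaults standing in for the configuration Hyrax loads.
defaults = {"model": {"name": "ExampleDefault"}}
config = set_dotted(defaults, "model.name", "HyraxCNN")
print(config["model"]["name"])    # HyraxCNN
print(defaults["model"]["name"])  # ExampleDefault
```

Copying before modifying keeps the defaults intact, which mirrors the idea that user settings override defaults rather than replace the default file.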
Defining the dataset
We’ll also need to tell Hyrax what data to use for training, in this case the CIFAR10 dataset. Hyrax has a built-in dataset class for working with CIFAR10 data, HyraxCifarDataset, so we’ll configure that here. You can learn more about the CIFAR10 dataset at the official site: https://www.cs.toronto.edu/~kriz/cifar.html
model_inputs_definition = {
    "train": {
        "data": {
            "dataset_class": "HyraxCifarDataset",
            "data_location": "./data",
            "fields": ["image", "label"],
            "primary_id_field": "object_id",
        },
    }
}

h.set_config("model_inputs", model_inputs_definition)
This may seem like a lot of configuration for such a simple case, but being explicit about the dataset configuration allows for much more flexibility later on, when working with more complex data.
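Because the definition is an ordinary nested dictionary, it is easy to sanity-check before handing it to Hyrax. The sketch below is a hypothetical helper, not a Hyrax validator: it walks each top-level entry (such as "train") and each data source inside it, and reports any of the keys used throughout this guide that are missing. Treating those four keys as required is our assumption, taken from the example above.

```python
# Keys used by every data source in this guide. Treating them as required is
# an assumption for illustration, not a rule documented by Hyrax.
EXPECTED_KEYS = {"dataset_class", "data_location", "fields", "primary_id_field"}


def check_model_inputs(model_inputs: dict) -> list[str]:
    """Return human-readable problems found in a model_inputs-style dict."""
    problems = []
    for verb, sources in model_inputs.items():
        for source_name, source in sources.items():
            missing = EXPECTED_KEYS - source.keys()
            if missing:
                problems.append(f"{verb}.{source_name} is missing {sorted(missing)}")
            if not isinstance(source.get("fields", []), list):
                problems.append(f"{verb}.{source_name}: 'fields' should be a list")
    return problems


model_inputs_definition = {
    "train": {
        "data": {
            "dataset_class": "HyraxCifarDataset",
            "data_location": "./data",
            "fields": ["image", "label"],
            "primary_id_field": "object_id",
        },
    }
}
print(check_model_inputs(model_inputs_definition))  # []
```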
Training the model
Now that we have the model and data specified, we’re ready for training.
We’ll use the train verb to kick off the training process.
h.train()
Once training is complete, the model weights will be saved in a timestamped directory with a name similar to /YYYYmmdd-HHMMSS-train-RAND, where RAND is a random four-character string that avoids collisions if you run multiple training sessions in the same second.
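Because the timestamp comes first and has a fixed width, these directory names sort chronologically as plain strings. The sketch below uses that to pick the most recent training run. It is our own helper, not a Hyrax API: the results root you pass in is up to you (this guide does not state where the directories are created), and the exact character set of the RAND suffix is an assumption.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Matches names such as 20240131-154502-train-ab12. The timestamp format and
# the "-train-" infix follow the pattern described above; accepting any four
# word characters for the random suffix is an assumption.
RUN_DIR_PATTERN = re.compile(r"^\d{8}-\d{6}-train-\w{4}$")


def latest_run_name(names: Iterable[str]) -> Optional[str]:
    """Return the newest training-run name among `names`, or None."""
    runs = [name for name in names if RUN_DIR_PATTERN.match(name)]
    # Fixed-width YYYYmmdd-HHMMSS prefixes sort chronologically as strings;
    # runs started in the same second are ordered by their random suffix.
    return max(runs, default=None)


def latest_train_dir(results_root: Path) -> Optional[Path]:
    """Return the newest training directory directly under results_root, or None."""
    name = latest_run_name(p.name for p in results_root.iterdir() if p.is_dir())
    return results_root / name if name is not None else None
```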
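Predicting with the model
Now that we’ve trained a model, we can use it to infer the classes of samples from the CIFAR10 test dataset. First, we’ll add an "infer" entry to our model input definition to specify the data to use for inference. It mirrors the "train" entry, except that it requests only the "image" field and sets "use_training_data" to False so that the test data is used instead of the training data.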
model_inputs_definition["infer"] = {
    "data": {
        "dataset_class": "HyraxCifarDataset",
        "data_location": "./data",
        "fields": ["image"],
        "primary_id_field": "object_id",
        "dataset_config": {
            "use_training_data": False,
        },
    },
}

h.set_config("model_inputs", model_inputs_definition)
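Most of the "infer" entry repeats the "train" entry. One way to avoid keeping the two in sync by hand is to derive one from the other with copy.deepcopy, as sketched below. This is plain Python dictionary manipulation, not a Hyrax feature; the resulting dictionary has the same contents as the definition written out above, and the set_config call is left commented so the snippet runs on its own.

```python
import copy

train_entry = {
    "data": {
        "dataset_class": "HyraxCifarDataset",
        "data_location": "./data",
        "fields": ["image", "label"],
        "primary_id_field": "object_id",
    },
}

# Deep-copy so that edits to the inference entry never leak into "train".
infer_entry = copy.deepcopy(train_entry)
# Request only the image field, as in the inference definition above.
infer_entry["data"]["fields"] = ["image"]
# Use the CIFAR10 test data rather than the training data.
infer_entry["data"]["dataset_config"] = {"use_training_data": False}

model_inputs_definition = {"train": train_entry, "infer": infer_entry}
# h.set_config("model_inputs", model_inputs_definition)
```

A shallow copy (dict.copy()) would not be enough here: the nested "data" dictionary would be shared, so changing "fields" for inference would also change the training entry.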
Then we’ll use Hyrax’s infer verb to load the trained model weights and process
the data defined above.
inference_results = h.infer()
Evaluate the performance
Let’s compare the model’s predictions to the actual labels from the test dataset.
The model’s prediction for each sample is a 10-element vector, one value per class, where the largest value marks the class the model is most confident in.
So we’ll extract the index of the maximum value for each prediction and save it as predicted_classes.
We’ll also load the original test data to get the true labels for comparison.
import numpy as np
import pickle

# Accumulate the predicted classes (index of the largest output value)
predicted_classes = np.zeros(len(inference_results), dtype=int)
for i, result in enumerate(inference_results):
    predicted_classes[i] = np.argmax(result["model_output"])

# Load the true labels; with encoding="bytes" the dictionary keys are byte strings such as b"labels"
with open("./data/cifar-10-batches-py/test_batch", "rb") as f_in:
    test_data = pickle.load(f_in, encoding="bytes")
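The loop above can also be written as a single vectorized call: stack the per-sample outputs into an (N, 10) array and take the argmax along axis 1. The sketch assumes, as the loop does, that each result stores its output under the "model_output" key, and additionally that the output is array-like (for example a NumPy array or a CPU tensor) so np.asarray can convert it. The dummy results are made up purely to show the shapes involved.

```python
import numpy as np


def predicted_classes_from(results) -> np.ndarray:
    """Return the argmax class index for each result's "model_output"."""
    outputs = np.stack([np.asarray(r["model_output"]) for r in results])  # shape (N, 10)
    return np.argmax(outputs, axis=1)


# Made-up outputs for three samples, only to illustrate the shapes involved.
dummy_results = [
    {"model_output": np.eye(10)[3]},              # one-hot at class 3
    {"model_output": np.eye(10)[7]},              # one-hot at class 7
    {"model_output": np.linspace(0.0, 1.0, 10)},  # increasing, so class 9
]
print(predicted_classes_from(dummy_results).tolist())  # [3, 7, 9]
```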
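With the predicted and true labels in hand, we can compute the overall accuracy, then use scikit-learn’s confusion_matrix and ConfusionMatrixDisplay to compute and plot the confusion matrix, which shows how well the model performed on each class.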
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

y_true = test_data[b"labels"]
y_pred = predicted_classes.tolist()

# Overall accuracy: fraction of test images whose predicted class matches the true label
correct = 0
for t, p in zip(y_true, y_pred):
    correct += t == p

print("\nAccuracy for test dataset:", correct / len(y_true))

# Rows of the confusion matrix are true labels, columns are predicted labels
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()
>> Accuracy for test dataset: 0.5003
The model performs much better than chance (10%, since the ten classes are equally represented in the test set), and some classes are predicted more accurately than others.
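To see which classes the model handles best, one option is per-class accuracy: with rows indexed by the true label, dividing each diagonal entry of the confusion matrix by its row sum gives the fraction of that class classified correctly. The sketch below computes this with NumPy alone. confusion_matrix_np and per_class_accuracy are our own hypothetical helpers, not Hyrax or scikit-learn functions. Passing n_classes explicitly keeps the matrix 10 by 10 even if some class never appears in the data, whereas confusion_matrix infers the labels from the data by default. The class-name list follows the standard CIFAR10 label order.

```python
import numpy as np

# CIFAR10 class names in label-index order (0-9).
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def confusion_matrix_np(y_true, y_pred, n_classes: int = 10) -> np.ndarray:
    """cm[i, j] counts samples with true class i predicted as class j
    (the same row/column convention as scikit-learn's confusion_matrix)."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at is unbuffered, so repeated (true, pred) pairs are all counted.
    np.add.at(cm, (np.asarray(y_true, dtype=int), np.asarray(y_pred, dtype=int)), 1)
    return cm


def per_class_accuracy(cm) -> np.ndarray:
    """Fraction of each true class predicted correctly: diagonal / row sum."""
    cm = np.asarray(cm, dtype=float)
    row_totals = cm.sum(axis=1)
    # Classes with no samples get 0.0 instead of a division-by-zero warning.
    return np.divide(np.diag(cm), row_totals,
                     out=np.zeros_like(row_totals), where=row_totals > 0)


# Tiny made-up example with three classes.
toy_cm = confusion_matrix_np([0, 0, 1, 1, 2], [0, 1, 1, 1, 0], n_classes=3)
print(toy_cm.tolist())                      # [[1, 1, 0], [0, 2, 0], [1, 0, 0]]
print(per_class_accuracy(toy_cm).tolist())  # [0.5, 1.0, 0.0]
print(np.trace(toy_cm) / toy_cm.sum())      # overall accuracy: 0.6

# With the CIFAR10 labels and predictions from above (uncomment to run):
# for name, acc in zip(CIFAR10_CLASSES, per_class_accuracy(confusion_matrix_np(y_true, y_pred))):
#     print(f"{name:>10}: {acc:.2f}")
```

Overall accuracy is the trace of the confusion matrix divided by its total, so it should agree with the accuracy printed by the loop in the evaluation code.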