Hyrax Hello World
This example walks through a small end-to-end Hyrax workflow using a built-in convolutional neural network. This notebook uses CIFAR10 data rather than astronomy data so that the basic Hyrax workflow is easier to see and understand. The example is based in part on the PyTorch CIFAR10 tutorial and uses the CIFAR10 dataset described in "Learning multiple layers of features from tiny images" (Alex Krizhevsky, 2009).
In Hyrax, the basic pattern is to first create a Hyrax object, configure a model and a data_request, and then run verbs such as train and infer.
This example will:
Create a Hyrax instance
Specify a model and a dataset
Train the model
Predict with the model
Evaluate the results
Install Hyrax
Before we begin, we’ll need to install Hyrax. You can skip this step if you’re running locally and have already installed Hyrax in your virtual environment.
%pip install hyrax
Create a Hyrax instance
The main driver for Hyrax is the Hyrax class. To get started we’ll create an instance of this class.
When we create the Hyrax instance, it automatically loads a default configuration file containing baseline settings for the components Hyrax uses. In this example, we will update configuration values directly in Python with h.set_config(...), but configuration can also be supplied in an external TOML file with entries such as model.name = "HyraxCNN". You can read more about this in the configuration system documentation.
from hyrax import Hyrax
h = Hyrax()
Specify a model
We’ll need to let Hyrax know which model to use for training. Here we’ll tell Hyrax to use the built-in HyraxCNN model that is based on the simple CNN architecture from the PyTorch CIFAR10 tutorial.
In Hyrax terms, the model is one of the required inputs for a run.
You do not need to use a built-in model. Hyrax also supports defining custom models.
h.set_config("model.name", "HyraxCNN")
Define the dataset
We’ll also need to tell Hyrax what data should be used for training, in this case the CIFAR10 dataset.
Hyrax has a built-in dataset class for working with CIFAR10 data, so we’ll configure that here. You can learn more about CIFAR10 at the official site: https://www.cs.toronto.edu/~kriz/cifar.html
The dataset configuration below may look verbose for such a simple case, but being explicit about it allows greater flexibility when working with more complex data. In Hyrax, the data_request defines what data each stage of the workflow should use.
Hyrax also includes built-in datasets for several astronomy workflows, and users can define custom dataset classes when needed.
data_request_definition = {
    "train": {
        "data": {
            "dataset_class": "HyraxCifarDataset",
            "data_location": "./data",
            "fields": ["image", "label"],
            "primary_id_field": "object_id",
            "split_fraction": 1.0,
        },
    },
}
h.set_config("data_request", data_request_definition)
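The definition is nested as stage name ("train"), then a named data entry ("data"), then that entry's dataset settings. The sketch below walks a dictionary of this shape and lists which dataset class and fields each stage requests. It is plain Python over the dictionary and does not call Hyrax; summarize_data_request is a hypothetical helper, and reading the middle level as a named group of data is an assumption based on the example.

```python
from typing import Any


def summarize_data_request(data_request: dict[str, Any]) -> list[tuple[str, str, str, list[str]]]:
    """Hypothetical helper: list (stage, entry name, dataset_class, fields) for each entry."""
    rows = []
    for stage, entries in data_request.items():
        for entry_name, spec in entries.items():
            rows.append((stage, entry_name, spec["dataset_class"], list(spec["fields"])))
    return rows


# Same shape as the training definition above.
example_request = {
    "train": {
        "data": {
            "dataset_class": "HyraxCifarDataset",
            "data_location": "./data",
            "fields": ["image", "label"],
            "primary_id_field": "object_id",
            "split_fraction": 1.0,
        },
    },
}

for row in summarize_data_request(example_request):
    print(row)
```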
Train the model
Now that we have the model and data specified, we’re ready for training. We’ll use the train verb to kick off the training process.
trained_model = h.train()
Once the training is complete, the model weights will be saved in a timestamped directory with a name similar to .../YYYYmmdd-HHMMSS-train-RAND/. Here RAND is a random four-character string used to avoid collisions if you run multiple training sessions in the same second.
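Because the timestamp in these directory names is zero-padded, the names sort chronologically as plain strings, which makes it easy to find the newest training run programmatically. The sketch below is a hypothetical helper, not part of Hyrax: the results root is left as a parameter because its location is elided above, and RAND is assumed to consist of four word characters.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Assumed pattern: YYYYmmdd-HHMMSS-train-RAND, with RAND taken to be four word characters.
TRAIN_DIR_RE = re.compile(r"^\d{8}-\d{6}-train-\w{4}$")


def latest_train_dir(results_root: Path) -> Optional[Path]:
    """Hypothetical helper: return the newest training directory under results_root, or None."""
    candidates = [
        p for p in results_root.iterdir() if p.is_dir() and TRAIN_DIR_RE.match(p.name)
    ]
    # The zero-padded "YYYYmmdd-HHMMSS" prefix (first 15 characters) sorts chronologically as text.
    return max(candidates, key=lambda p: p.name[:15], default=None)


# Usage with a throwaway directory standing in for the real results root.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name in ["20240101-090000-train-ab12", "20240102-080000-train-cd34", "notes"]:
        (root / name).mkdir()
    print(latest_train_dir(root).name)  # 20240102-080000-train-cd34
```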
Predict with the model
Now that we’ve trained a model, we can use it to infer classes of samples from the CIFAR10 test dataset. This follows the same pattern as training: first update the configuration, then run the next verb.
First we’ll add an "infer" entry to our data_request definition to specify the data to use for inference.
data_request_definition["infer"] = {
    "data": {
        "dataset_class": "HyraxCifarDataset",
        "data_location": "./data",
        "fields": ["image", "object_id"],
        "primary_id_field": "object_id",
        "dataset_config": {
            "HyraxCifarDataset": {
                "use_training_data": False,
            },
        },
    },
}
h.set_config("data_request", data_request_definition)
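The "train" and "infer" entries share most of their keys, so a small builder function can reduce repetition as more stages are added. The cifar_stage_request function below is a hypothetical convenience written for this example, not a Hyrax API. It only produces plain dictionaries matching the two literals above, which could then be passed to h.set_config("data_request", ...) in the same way.

```python
from typing import Any, Optional


def cifar_stage_request(
    fields: list[str],
    use_training_data: bool = True,
    split_fraction: Optional[float] = None,
    data_location: str = "./data",
) -> dict[str, Any]:
    """Hypothetical helper: build one stage entry shaped like the ones in this notebook."""
    spec: dict[str, Any] = {
        "dataset_class": "HyraxCifarDataset",
        "data_location": data_location,
        "fields": list(fields),
        "primary_id_field": "object_id",
    }
    if split_fraction is not None:
        spec["split_fraction"] = split_fraction
    if not use_training_data:
        # Same nested override the "infer" entry uses to select the CIFAR10 test data.
        spec["dataset_config"] = {"HyraxCifarDataset": {"use_training_data": False}}
    return {"data": spec}


data_request_sketch = {
    "train": cifar_stage_request(["image", "label"], split_fraction=1.0),
    "infer": cifar_stage_request(["image", "object_id"], use_training_data=False),
}
print(data_request_sketch["infer"]["data"]["dataset_config"])
```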
Then we’ll use Hyrax’s infer verb to load the trained model weights and process the data defined above.
inference_results = h.infer()
Evaluate the performance
Let’s compare the model’s predictions to the actual labels from the test dataset. Each prediction is a 10-element vector in which the largest value marks the class the model is most confident in, so we’ll take the index of the maximum value for each prediction and store the results in predicted_classes.
import numpy as np

# Get the index of the maximum predicted class for all test samples
predicted_classes = np.zeros(len(inference_results), dtype=int)
for i, result in enumerate(inference_results):
    predicted_classes[i] = np.argmax(result)
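The loop above can also be written as a single vectorized NumPy call by stacking the predictions into an (N, 10) array and taking the argmax along the class axis. This sketch assumes each inference result can be converted to a 1-D array of 10 scores with np.asarray; the three score vectors are made-up stand-ins for inference_results.

```python
import numpy as np

# Made-up stand-ins for inference_results: three 10-element score vectors.
fake_results = [
    np.array([0.1, 2.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
    np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0]),
    np.array([3.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
]


def predicted_classes_from(results) -> np.ndarray:
    """Stack score vectors into an (N, 10) array and pick the highest-scoring class per row."""
    scores = np.stack([np.asarray(r) for r in results])
    return np.argmax(scores, axis=1)


print(predicted_classes_from(fake_results))  # [1 9 0]
```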
We’ll also load the original test data to get the true labels for comparison.
import pickle

with open("./data/cifar-10-batches-py/test_batch", "rb") as f_in:
    test_data = pickle.load(f_in, encoding="bytes")
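Because the file is unpickled with encoding="bytes", the keys of test_data are bytes objects such as b"labels" and b"data". According to the CIFAR10 site, b"labels" is a list of 10,000 integers from 0 to 9, and b"data" is a 10000×3072 uint8 array in which each row stores the red, green, and blue channels of a 32×32 image one after another, each in row-major order. If you want to inspect the images themselves, a reshape recovers them. The sketch below uses synthetic rows so it runs without the dataset; cifar_rows_to_images is a hypothetical helper name.

```python
import numpy as np


def cifar_rows_to_images(data: np.ndarray) -> np.ndarray:
    """Hypothetical helper: convert (N, 3072) CIFAR rows to (N, 32, 32, 3) images."""
    # Each row holds 1024 red, then 1024 green, then 1024 blue values,
    # so reshape to (N, channel, height, width) and move channels last.
    return data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)


# Synthetic stand-in for test_data[b"data"]: one image with constant R=10, G=20, B=30.
fake_rows = np.repeat(np.array([10, 20, 30], dtype=np.uint8), 1024)[None, :]
images = cifar_rows_to_images(fake_rows)
print(images.shape, images[0, 0, 0])  # (1, 32, 32, 3) [10 20 30]
```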
Finally, we’ll compute the overall accuracy and use scikit-learn’s confusion_matrix and ConfusionMatrixDisplay to see how well the model performed on each class.
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

y_true = test_data[b"labels"]
y_pred = predicted_classes.tolist()

# Count how many predictions match the true labels
correct = 0
for t, p in zip(y_true, y_pred):
    correct += t == p
print("\nAccuracy for test dataset:", correct / len(y_true))

# Rows are true classes, columns are predicted classes
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()
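The diagonal of the confusion matrix holds the correctly classified samples, so per-class accuracy is each diagonal entry divided by its row total, and overall accuracy is the trace divided by the total count. The sketch below builds the matrix with NumPy in the same layout scikit-learn uses (rows are true classes, columns are predicted classes) so it runs on toy labels without the dataset. The helper names are hypothetical, and CIFAR10_CLASSES lists the class names in label order as given on the CIFAR10 site.

```python
import numpy as np

CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def confusion_counts(true_labels, pred_labels, n_classes=10) -> np.ndarray:
    """Hypothetical helper: count (true, predicted) pairs; rows are true, columns predicted."""
    counts = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(counts, (np.asarray(true_labels), np.asarray(pred_labels)), 1)
    return counts


def per_class_accuracy(counts: np.ndarray) -> np.ndarray:
    """Fraction of each true class predicted correctly (diagonal / row total, 0 if empty)."""
    totals = counts.sum(axis=1)
    return np.divide(np.diag(counts), totals, out=np.zeros(len(totals)), where=totals > 0)


# Toy labels for three classes.
example_true = [0, 0, 1, 1, 2, 2]
example_pred = [0, 1, 1, 1, 2, 0]
example_cm = confusion_counts(example_true, example_pred, n_classes=3)

for name, acc in zip(CIFAR10_CLASSES, per_class_accuracy(example_cm)):
    print(f"{name}: {acc:.2f}")  # per-class: 0.50, 1.00, 0.50
print("overall:", np.trace(example_cm) / example_cm.sum())
```

In the notebook, the cm returned by scikit-learn has the same layout and could be passed to per_class_accuracy directly, and ConfusionMatrixDisplay accepts display_labels=CIFAR10_CLASSES to label the axes with class names.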
Accuracy for test dataset: 0.4676
The overall accuracy is about 47%, well above the 10% expected from random guessing among 10 classes, and roughly consistent with the results of the PyTorch CIFAR10 tutorial.
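As a worked check of that comparison, the CIFAR10 test set has N = 10,000 images spread evenly over 10 classes, so the reported accuracy and the chance baseline work out as follows:

```latex
\text{accuracy} = \frac{1}{N}\sum_{i=1}^{N} \mathbb{1}\left[\hat{y}_i = y_i\right]
= \frac{4676}{10000} = 0.4676,
\qquad
\text{chance} = \frac{1}{10} = 0.10
```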
What to take away
A Hyrax workflow starts by creating a Hyrax object, whose configuration can be loaded from a configuration file or updated directly in Python.
The model and the data_request are configured separately, and both are required inputs to the workflow.
Verbs such as train and infer run the configured workflow.