Multi-modal data with a GraphQL alternative

This notebook will demonstrate working with multi-modal or multi-source data. By default, Hyrax assumes only a single dataset will be used for a given model.

As before, the two required configuration parameters are config['general']['data_dir'] (to define the location of the data) and config['data_set']['name'] (to define the dataset class that will load specific data from disk).

Below we the standard approach of loading the default dataset defined in the Hyrax config using h.prepare(). We then print the dataset object to see the basic information about it.

[1]:
from hyrax import Hyrax

h = Hyrax()

# Set a few configs for later use
h.config["train"]["epochs"] = 1
h.config["data_set"]["name"] = "HyraxRandomDataset"
h.config["data_set"]["HyraxRandomDataset"]["shape"] = (1, 32, 32)
h.config["data_set"]["HyraxRandomDataset"]["size"] = 50000
[2025-09-12 12:16:27,891 hyrax:INFO] Runtime Config read from: /Users/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
[2]:
ds = h.prepare()
print(ds)
[2025-09-12 12:16:30,455 hyrax.data_sets.data_provider:INFO] No fields were specified for 'data'. The request will be modified to select all by default. You can specify `fields` in `model_inputs`.
[2025-09-12 12:16:30,768 hyrax.prepare:INFO] Finished Prepare
Name: data
  Dataset class: HyraxRandomDataset
  Data location: /Users/drew/code/hyrax/docs/pre_executed/data
  Requested fields: image, label, meta_field_1, meta_field_2, object_id

Pass the data through to_tensor

Since we have access to the model class, we can call the to_tensor method with example data. This allows easy checking that the output matches the expectations of the model architecture.

In this example, we expect to_tensor to return a tuple of (Tensor, int), or specifically a multi-channel image and a label.

[3]:
samp = ds[2335]
print(samp)
m = h.model()
res = m.to_tensor(samp)
print(f"Type and shape of resulting image: {type(res[0])}, {res[0].shape}")
{'data': {'image': array([[[0.9993206 , 0.24844742, 0.07472038, ..., 0.19890529,
         0.6762412 , 0.2656371 ],
        [0.26220918, 0.305416  , 0.83473223, ..., 0.00641263,
         0.21008873, 0.7148612 ],
        [0.8158195 , 0.7323841 , 0.13330418, ..., 0.81598246,
         0.2337833 , 0.3971014 ],
        ...,
        [0.41137224, 0.40423387, 0.19257122, ..., 0.34238744,
         0.58312994, 0.7775264 ],
        [0.3764146 , 0.6264323 , 0.28519917, ..., 0.14582026,
         0.93530905, 0.5605867 ],
        [0.75328624, 0.73888373, 0.36882102, ..., 0.8995268 ,
         0.3712474 , 0.03462631]]], dtype=float32), 'label': np.int64(0), 'meta_field_1': np.float64(23832.5), 'meta_field_2': np.float64(15888.333333333334), 'object_id': '2388'}, 'object_id': '2388'}
Type and shape of resulting image: <class 'numpy.ndarray'>, (32, 32)

Defining the data request specification

The following shows a way to request new datasets. For this example, we’ll use a contrived scenario were we want two instances of cifar images from the HyraxCifarDataset in addition to a random image from the HyraxRandomDataset.

We can directly modify the loaded configuration using h.config["model_inputs"]

[4]:
h.config["model_inputs"] = {
    "cifar_0": {
        "dataset_class": "HyraxCifarDataSet",
        "data_location": "./data",
        "fields": ["object_id", "image", "label"],
        "primary_id_field": "object_id",
    },
    "cifar_1": {
        "dataset_class": "HyraxCifarDataSet",
        "data_location": "./data",
        "fields": ["image", "label"],
    },
    "random": {
        "dataset_class": "HyraxRandomDataset",
        "data_location": "./data",
        "fields": ["image"],
        "dataset_config": {
            "seed": 4200,
        },
    },
}

When experimentation is complete, the preceding can be specified directly in a configuration .toml file like so:

[model_inputs]
[model_inputs.cifar_0]
dataset_class = "HyraxCifarDataSet"
data_location = "./data"
fields = ["object_id", "image", "label"]
primary_id_field = "object_id"

[model_inputs.cifar_1]
dataset_class = "HyraxCifarDataSet"
data_location = "./data"
fields = ["image"]

[model_inputs.random]
dataset_class = "HyraxRandomDataset"
data_location = "./data"
fields = ["image"]
[model_inputs.rando.dataset_config]
seed = 4200

Examine the multimodal dataset

As before, calling h.prepare() will return an instance of the DataProvider dataset. The DataProvider class can be thought of as a container of multiple datasets, as well as a gateway (in GraphQL terminology) that will send requests for specific data to the datasets it contains. Printing the dataset object will show the configuration of the dataset.

[5]:
ds = h.prepare()
print(ds)
[2025-09-12 12:16:44,770 hyrax.prepare:INFO] Finished Prepare
Name: cifar_0
  Dataset class: HyraxCifarDataSet
  Data location: ./data
  Requested fields: object_id, image, label
Name: cifar_1
  Dataset class: HyraxCifarDataSet
  Data location: ./data
  Requested fields: image, label
Name: random
  Dataset class: HyraxRandomDataset
  Data location: ./data
  Requested fields: image
  Dataset config:
    seed: 4200

[6]:
ds.metadata_fields()
[6]:
['label_cifar_0',
 'object_id_cifar_0',
 'label_cifar_1',
 'object_id_cifar_1',
 'object_id_random',
 'meta_field_1_random',
 'meta_field_2_random',
 'object_id']
[7]:
ds.metadata(idxs=[0, 1, 2, 4, 5], fields=["label_cifar_0", "label_cifar_1", "meta_field_1_random"])
[7]:
array([(6, 6, 25000. ), (9, 9, 24999.5), (9, 9, 24999. ), (1, 1, 24998. ),
       (1, 1, 24997.5)],
      dtype=[('label_cifar_0', '<i8'), ('label_cifar_1', '<i8'), ('meta_field_1_random', '<f8')])

The various datasets contained within the DataProvider instance.

[8]:
ds.prepped_datasets
[8]:
{'cifar_0': <hyrax.data_sets.hyrax_cifar_data_set.HyraxCifarDataSet at 0x16282a570>,
 'cifar_1': <hyrax.data_sets.hyrax_cifar_data_set.HyraxCifarDataSet at 0x16282b590>,
 'random': <hyrax.data_sets.random.hyrax_random_dataset.HyraxRandomDataset at 0x163396750>}
[9]:
print(f"Is iterable: {ds.is_iterable()}")  # Should return False
print(f"Is mappable: {ds.is_map()}")  # Should return True
Is iterable: False
Is mappable: True

Checking the length of the dataset is the same as always.

[10]:
print(f"Length of the multimodal dataset: {len(ds)}")
print(f"Length of a specific dataset contained inside: {len(ds.prepped_datasets['cifar_0'])}")
Length of the multimodal dataset: 50000
Length of a specific dataset contained inside: 50000
[11]:
samp = ds[2335]
print("Fields from cifar_0")
print(samp["cifar_0"]["image"].shape)
print(samp["cifar_0"]["label"])
print(samp["cifar_0"]["object_id"])
print("Fields from cifar_1")
print(samp["cifar_1"]["image"].shape)
print(samp["cifar_1"]["label"])
print("Fields from random")
print(samp["random"]["image"].shape)
Fields from cifar_0
(3, 32, 32)
5
2335
Fields from cifar_1
(3, 32, 32)
5
Fields from random
(1, 32, 32)

Updating to_tensor

The default implementation of to_tensor only makes use of “cifar_0” and “rando”. But if we are experimenting, we don’t want to have to make code changes in the model class. It would be much easier to experiment with in the notebook. Here, we redefine the to_tensor method, and check the results by running sample data through the method.

[12]:
import torch
import numpy as np


@staticmethod
def to_tensor(data_dict):
    """This function converts structured data to the input tensor we need to run

    Parameters
    ----------
    data_dict : dict
        The dictionary returned from our data source
    """
    cifar_data = data_dict.get("cifar_0", {})
    random_data = data_dict.get("random", {})
    more_cifar_data = data_dict.get("cifar_1", {})

    # Get the image data from each dataset
    cifar_image = cifar_data["image"]
    random_image = random_data["image"]
    more_cifar_image = more_cifar_data["image"]

    stack_dim = 0 if cifar_image.ndim == 3 else 1

    # Stack the images together to produce a single 7 channel image.
    image = torch.from_numpy(np.concatenate([cifar_image, random_image, more_cifar_image], axis=stack_dim))

    return image


m.to_tensor = to_tensor

After running the same sample through as before, we can see that the number of channels in the image has changed (from 4 to 7), while all the other values have remained the same.

[13]:
new_res = m.to_tensor(samp)
print(f"Type and shape of resulting 0th layer of the image: {type(new_res[0])}, {new_res[0].shape}")
print(f"Type and shape of the 1st layer of the image: {type(new_res[1])}, {new_res[1].shape}")
print(f"Overall shape of the returned image: {new_res.shape}")
Type and shape of resulting 0th layer of the image: <class 'torch.Tensor'>, torch.Size([32, 32])
Type and shape of the 1st layer of the image: <class 'torch.Tensor'>, torch.Size([32, 32])
Overall shape of the returned image: torch.Size([7, 32, 32])

Train with this model

Now that we’ve seen that the to_tensor method is returning a reasonable form of data, we can train our model. As before, we call h.train(). While it is quiet verbose, the initialization logging shows that the model instance is created with data from the DataProvider class, and that our new implementation of to_tensor is being used to manipulate the data from DataProvider into the a form that our model architecture accepts.

[14]:
h.train()
[2025-09-12 12:16:58,923 hyrax.models.hyrax_autoencoder:INFO] Found shape: torch.Size([7, 32, 32]) in data sample, using this to initialize model.
[2025-09-12 12:16:58,926 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
2025-09-12 12:16:58,938 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: cifar_0
  Data':
        {'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x1633a2810>, 'batch_size': 512, 'shuffle': False, 'pin_memory': False}
2025-09-12 12:16:58,939 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: cifar_0
  Data':
        {'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x1633a0950>, 'batch_size': 512, 'shuffle': False, 'pin_memory': False}
/Users/drew/opt/miniconda3/envs/hyrax/lib/python3.12/site-packages/ignite/handlers/tqdm_logger.py:127: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)
  from tqdm.autonotebook import tqdm
2025/09/12 12:16:59 INFO mlflow.system_metrics.system_metrics_monitor: Skip logging GPU metrics. Set logger level to DEBUG for more details.
2025/09/12 12:16:59 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2025-09-12 12:16:59,068 hyrax.pytorch_ignite:INFO] Training model on device: mps
[2025-09-12 12:17:15,552 hyrax.pytorch_ignite:INFO] Total training time: 16.48[s]
[2025-09-12 12:17:15,553 hyrax.pytorch_ignite:INFO] Latest checkpoint saved as: /Users/drew/code/hyrax/docs/pre_executed/results/20250912-121644-train-1-X8/checkpoint_epoch_1.pt
[2025-09-12 12:17:15,554 hyrax.pytorch_ignite:INFO] Best metric checkpoint saved as: /Users/drew/code/hyrax/docs/pre_executed/results/20250912-121644-train-1-X8/checkpoint_1_loss=-1077.0980.pt
2025/09/12 12:17:15 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/09/12 12:17:15 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2025-09-12 12:17:15,571 hyrax.verbs.train:INFO] Finished Training
[2025-09-12 12:17:16,239 hyrax.model_exporters:INFO] Exported model to ONNX format: /Users/drew/code/hyrax/docs/pre_executed/results/20250912-121644-train-1-X8/example_model_opset_20.onnx
[14]:
HyraxAutoencoder(
  (encoder): Sequential(
    (0): Conv2d(7, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): GELU(approximate='none')
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): GELU(approximate='none')
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): GELU(approximate='none')
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): GELU(approximate='none')
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): GELU(approximate='none')
    (10): Flatten(start_dim=1, end_dim=-1)
    (11): Linear(in_features=1024, out_features=64, bias=True)
  )
  (dec_linear): Sequential(
    (0): Linear(in_features=64, out_features=1024, bias=True)
    (1): GELU(approximate='none')
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): GELU(approximate='none')
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): GELU(approximate='none')
    (4): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (5): GELU(approximate='none')
    (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): GELU(approximate='none')
    (8): ConvTranspose2d(32, 7, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (9): Tanh()
  )
  (criterion): CrossEntropyLoss()
)
[15]:
h.infer()
[2025-09-12 12:17:30,395 hyrax.models.hyrax_autoencoder:INFO] Found shape: torch.Size([7, 32, 32]) in data sample, using this to initialize model.
[2025-09-12 12:17:30,398 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
[2025-09-12 12:17:30,398 hyrax.verbs.infer:INFO] data set has length 50000
2025-09-12 12:17:30,399 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset 'Name: cifar_0
  Data':
        {'sampler': None, 'batch_size': 512, 'shuffle': False, 'pin_memory': False}
[2025-09-12 12:17:30,410 hyrax.verbs.infer:INFO] Saving inference results at: /Users/drew/code/hyrax/docs/pre_executed/results/20250912-121716-infer-P7Wo
[2025-09-12 12:17:30,614 hyrax.pytorch_ignite:INFO] Evaluating model on device: mps
[2025-09-12 12:17:30,614 hyrax.pytorch_ignite:INFO] Total epochs: 1
[2025-09-12 12:17:46,576 hyrax.pytorch_ignite:INFO] Total evaluation time: 15.96[s]
[2025-09-12 12:17:46,630 hyrax.verbs.infer:INFO] Inference Complete.
[15]:
<hyrax.data_sets.inference_dataset.InferenceDataSet at 0x16282a7e0>
[16]:
h.umap()
[2025-09-12 12:18:04,144 hyrax.data_sets.inference_dataset:INFO] Using most recent results dir /Users/drew/code/hyrax/docs/pre_executed/results/20250912-121716-infer-P7Wo for lookup. Use the [results] inference_dir config to set a directory or pass it to this verb.
[2025-09-12 12:18:18,301 hyrax.verbs.umap:INFO] Saving UMAP results to /Users/drew/code/hyrax/docs/pre_executed/results/20250912-121818-umap-C6Hl
[2025-09-12 12:18:18,548 hyrax.verbs.umap:INFO] Fitting the UMAP
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
[2025-09-12 12:18:23,240 hyrax.verbs.umap:INFO] Saving fitted UMAP Reducer
[2025-09-12 12:19:06,495 hyrax.verbs.umap:INFO] Finished transforming all data through UMAP
[16]:
<hyrax.data_sets.inference_dataset.InferenceDataSet at 0x40d7ace30>
[17]:
h.config["visualize"]["fields"] = ["meta_field_1_random", "meta_field_2_random"]
v = h.visualize()
[2025-09-12 12:19:23,181 hyrax.verbs.visualize:INFO] UMAP directory not specified at runtime. Reading from config values.
[2025-09-12 12:19:23,182 hyrax.data_sets.inference_dataset:INFO] Using most recent results dir /Users/drew/code/hyrax/docs/pre_executed/results/20250912-121818-umap-C6Hl for lookup. Use the [results] inference_dir config to set a directory or pass it to this verb.
[2025-09-12 12:19:37,275 hyrax.verbs.visualize:INFO] Rendering UMAP from the following directory: /Users/drew/code/hyrax/docs/pre_executed/results/20250912-121818-umap-C6Hl