List Models and Dataset Classes

Hyrax provides two convenience methods for discovering what models and dataset classes are available in your current environment:

  • h.list_models() returns an alphabetically sorted list of registered model names.

  • h.list_dataset_classes() returns an alphabetically sorted list of registered dataset class names.

Both lists automatically include any models or dataset classes contributed by installed third-party plugins.

[1]:
from hyrax import Hyrax

h = Hyrax()

List available models

[2]:
h.list_models()
[2]:
['HSCAutoencoder',
 'HSCDCAE',
 'HyraxAutoencoder',
 'HyraxAutoencoderV2',
 'HyraxCNN',
 'HyraxLoopback',
 'ImageDCAE',
 'SimCLR']
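The `[2]:` output shows that `list_models()` returns an ordinary Python list, so standard list operations apply to it. In a live session you would write `models = h.list_models()`. The sketch below instead hard-codes the names from the output above so that it runs without Hyrax installed; the membership test and prefix filter are plain Python, not Hyrax features.

```python
# Model names copied from the list_models() output above.
# In a live session: models = h.list_models()
models = [
    "HSCAutoencoder", "HSCDCAE", "HyraxAutoencoder", "HyraxAutoencoderV2",
    "HyraxCNN", "HyraxLoopback", "ImageDCAE", "SimCLR",
]

# Membership test, e.g. to confirm a model is available before using it.
print("SimCLR" in models)  # True

# Keep only the names that start with a given prefix.
hyrax_models = [name for name in models if name.startswith("Hyrax")]
print(hyrax_models)
# ['HyraxAutoencoder', 'HyraxAutoencoderV2', 'HyraxCNN', 'HyraxLoopback']
```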

List available dataset classes

[3]:
h.list_dataset_classes()
[3]:
['DownloadedLSSTDataset',
 'FitsImageDataset',
 'HSCDataset',
 'HyraxCSVDataset',
 'HyraxCifarDataset',
 'HyraxRandomDataset',
 'InferenceDataset',
 'LSSTDataset',
 'MultimodalUniverseDataset',
 'ResultDataset']
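In the output above, HyraxCSVDataset comes before HyraxCifarDataset. This ordering matches Python's default, case-sensitive string sort, which compares Unicode code points, so uppercase "S" (U+0053) sorts before lowercase "i" (U+0069). The short sketch below reproduces the effect on a few of the names. Whether Hyrax uses exactly `sorted()` internally is an assumption here; only the resulting order is taken from the output.

```python
# A few dataset-class names from the output above, in scrambled order.
names = ["HyraxCifarDataset", "HyraxCSVDataset", "FitsImageDataset"]

# Default sort: code-point comparison, so "CSV" precedes "Cifar".
print(sorted(names))
# ['FitsImageDataset', 'HyraxCSVDataset', 'HyraxCifarDataset']

# A case-insensitive key gives dictionary-style ordering instead.
print(sorted(names, key=str.lower))
# ['FitsImageDataset', 'HyraxCifarDataset', 'HyraxCSVDataset']
```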

Registering custom models and dataset classes

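New models and dataset classes are added to the registry as soon as their class definitions are executed: a model is registered when a class decorated with @hyrax_model is defined, and a dataset class is registered when a subclass of HyraxDataset is defined. The example below defines a toy model and a toy dataset class, then shows that both appear in the lists returned by list_models() and list_dataset_classes(). The AAA_ prefix on their names simply places them first in the sorted lists.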
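Registering a class at definition time is a common Python pattern. The following minimal, self-contained sketch shows two ways to build such a registry: a class decorator (the role @hyrax_model plays for AAA_MyCustomModel in the cell that follows) and an `__init_subclass__` hook on a base class (the role HyraxDataset plays for AAA_MyCustomDataset). The names `register_model`, `BaseDataset`, `MODEL_REGISTRY`, and `DATASET_REGISTRY` are hypothetical. This is an illustration of the pattern, not Hyrax's actual implementation.

```python
# Hypothetical registries; Hyrax's real ones live inside the library and may differ.
MODEL_REGISTRY: dict[str, type] = {}
DATASET_REGISTRY: dict[str, type] = {}


def register_model(cls: type) -> type:
    """Hypothetical class decorator: record the class under its name, return it unchanged."""
    MODEL_REGISTRY[cls.__name__] = cls
    return cls


class BaseDataset:
    """Hypothetical base class whose subclasses register themselves when defined."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Runs once per subclass definition; BaseDataset itself is never recorded.
        DATASET_REGISTRY[cls.__name__] = cls


def list_models() -> list[str]:
    return sorted(MODEL_REGISTRY)


def list_dataset_classes() -> list[str]:
    return sorted(DATASET_REGISTRY)


# Defining the classes is all it takes; no explicit registration call is made.
@register_model
class ToyModel:
    pass


class ToyDataset(BaseDataset):
    pass


print(list_models())           # ['ToyModel']
print(list_dataset_classes())  # ['ToyDataset']
```

The decorator route requires an explicit marker on each class, while the subclass hook registers everything that inherits from the base class automatically.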

[4]:
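import torch.nn as nn

from hyrax.models.model_registry import hyrax_model
from hyrax.datasets import HyraxDataset


# The @hyrax_model decorator registers this class in Hyrax's model registry.
@hyrax_model
class AAA_MyCustomModel(nn.Module):
    """A simple toy model for demonstration purposes."""

    def __init__(self, config, data_sample=None):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

    # Training and inference hooks are no-op stubs; this demo only shows registration.
    def train_batch(self, batch):
        pass

    def infer_batch(self, batch):
        pass


# Subclassing HyraxDataset registers this class; no decorator is needed.
class AAA_MyCustomDataset(HyraxDataset):
    """A simple toy dataset for demonstration purposes."""

    def __init__(self, config):
        super().__init__(config)

    # An empty dataset: enough to demonstrate registration, not to train on.
    def __len__(self):
        return 0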

Running list_models() and list_dataset_classes() again shows that the newly defined model and dataset class are now registered with Hyrax and available for use.

[5]:
print("Models:")
h.list_models()
Models:
[5]:
['AAA_MyCustomModel',
 'HSCAutoencoder',
 'HSCDCAE',
 'HyraxAutoencoder',
 'HyraxAutoencoderV2',
 'HyraxCNN',
 'HyraxLoopback',
 'ImageDCAE',
 'SimCLR']
[6]:
print("Dataset classes:")
h.list_dataset_classes()
Dataset classes:
[6]:
['AAA_MyCustomDataset',
 'DownloadedLSSTDataset',
 'FitsImageDataset',
 'HSCDataset',
 'HyraxCSVDataset',
 'HyraxCifarDataset',
 'HyraxRandomDataset',
 'InferenceDataset',
 'LSSTDataset',
 'MultimodalUniverseDataset',
 'ResultDataset']