Source code for hyrax.data_sets.hyrax_cifar_data_set

# ruff: noqa: D101, D102
import logging

import numpy as np
from torch.utils.data import Dataset, IterableDataset

from hyrax.config_utils import ConfigDict

from .data_set_registry import HyraxDataset

# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class HyraxCifarBase:
    """Base class for Hyrax Cifar datasets.

    Loads the CIFAR-10 training split into ``self.cifar`` and hands a metadata
    table with one ``label`` entry per image to the next ``__init__`` in the MRO.
    """

    def __init__(self, config: ConfigDict):
        """Download/load CIFAR-10 and build the label metadata table.

        Parameters
        ----------
        config : ConfigDict
            Hyrax configuration. ``config["general"]["data_dir"]`` is used as the
            torchvision CIFAR-10 root directory (downloaded there if missing).
        """
        # Heavy optional dependencies are imported lazily so importing this module
        # stays cheap.
        import torchvision.transforms as transforms
        from astropy.table import Table
        from torchvision.datasets import CIFAR10

        # ToTensor yields floats in [0, 1]; Normalize with mean 0.5 / std 0.5 per
        # channel maps them to [-1, 1].
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        )

        # Only the training split is used; Hyrax splits it further per configuration.
        self.cifar = CIFAR10(
            root=config["general"]["data_dir"], train=True, download=True, transform=transform
        )

        # Read labels straight from ``targets``. With no ``target_transform`` these
        # are exactly what ``self.cifar[i][1]`` returns, but indexing the dataset
        # would also run the image transform on every one of the images just to
        # discard the resulting tensors.
        metadata_table = Table({"label": np.array(self.cifar.targets)})

        # Cooperative init: pass config and metadata to the next class in the MRO.
        super().__init__(config, metadata_table)
class HyraxCifarDataSet(HyraxCifarBase, HyraxDataset, Dataset):
    """Map-style CIFAR-10 dataset for Hyrax.

    A thin wrapper around torchvision's CIFAR10, configured from the Hyrax config
    with a transform suited to the example code. Only the 50k-image training split
    is loaded; Hyrax itself partitions it into train/test/validate sets according
    to configuration.
    """

    def __len__(self):
        """Return the number of images in the wrapped CIFAR-10 split."""
        num_images = len(self.cifar)
        return num_images

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict with object_id, image and label."""
        img, target = self.cifar[idx]
        return dict(object_id=idx, image=img, label=target)
class HyraxCifarIterableDataSet(HyraxCifarBase, HyraxDataset, IterableDataset):
    """Iterable-style CIFAR-10 dataset for Hyrax.

    A thin wrapper around torchvision's CIFAR10, configured from the Hyrax config
    with a transform suited to the example code. This version supports only
    iteration, not map-style (indexed) access.

    Only the 50k-image training split is loaded; Hyrax itself partitions it into
    train/test/validate sets according to configuration.
    """

    def __iter__(self):
        """Yield ``{"object_id", "image", "label"}`` dicts in ascending index order.

        In a single process every index is yielded once. Inside a multi-worker
        ``DataLoader`` each worker receives its own copy of this dataset, so each
        worker yields only the indices ``worker_id, worker_id + num_workers, ...``;
        together the workers cover the dataset exactly once instead of each
        producing a full duplicate copy.
        """
        from torch.utils.data import get_worker_info

        worker_info = get_worker_info()
        if worker_info is None:
            # Main-process iteration: walk the whole dataset.
            start, step = 0, 1
        else:
            start, step = worker_info.id, worker_info.num_workers

        for idx in range(start, len(self.cifar), step):
            image, label = self.cifar[idx]
            yield {
                "object_id": idx,
                "image": image,
                "label": label,
            }