Source code for hyrax.datasets.hyrax_cifar_dataset

# ruff: noqa: D101, D102
import logging
from pathlib import Path

from .dataset_registry import HyraxDataset

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class HyraxCifarDataset(HyraxDataset):
    """Map-style Hyrax dataset backed by torchvision's CIFAR10.

    Samples are served by a ``torchvision.datasets.CIFAR10`` instance stored
    in ``self.cifar``; this class layers per-item accessors (image, label,
    object ID) on top of it.
    """

    def __init__(self, config: dict, data_location: Path = None):
        """Build the underlying CIFAR10 dataset, then initialise the base class.

        Parameters
        ----------
        config : dict
            Hyrax configuration. The value at
            ``config["data_set"]["HyraxCifarDataset"]["use_training_data"]``
            is forwarded to CIFAR10 as its ``train`` argument.
        data_location : Path, optional
            Forwarded to CIFAR10 as its ``root`` argument.
        """
        # torchvision is imported inside the constructor rather than at module
        # level (presumably to keep importing this module cheap -- confirm).
        from torchvision import transforms
        from torchvision.datasets import CIFAR10

        self.data_location = data_location
        dataset_config = config["data_set"]["HyraxCifarDataset"]
        self.training_data = dataset_config["use_training_data"]

        # Convert to a tensor, then normalise each of the three channels with
        # mean 0.5 and standard deviation 0.5.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        pipeline = transforms.Compose([to_tensor, normalize])

        self.cifar = CIFAR10(
            root=self.data_location,
            train=self.training_data,
            download=True,
            transform=pipeline,
        )

        # Number of digits in the item count; get_object_id pads IDs to this.
        self.id_width = len(str(len(self.cifar)))

        # The base class is initialised last, once the attributes above are
        # set (presumably HyraxDataset.__init__ relies on them -- confirm).
        super().__init__(config)

    def get_image(self, idx):
        """Return the transformed image at ``idx`` as a NumPy array."""
        sample = self.cifar[idx]
        return sample[0].numpy()

    def get_label(self, idx):
        """Return the label stored alongside the image at ``idx``."""
        sample = self.cifar[idx]
        return sample[1]

    def get_object_id(self, idx):
        """Return ``idx`` as a string, zero-padded to ``self.id_width`` digits."""
        return f"{idx:0{self.id_width}d}"

    def __len__(self):
        """Return the number of items in the underlying CIFAR10 dataset."""
        return len(self.cifar)