Source code for hyrax.datasets.hyrax_cifar_dataset
# ruff: noqa: D101, D102
import logging
from pathlib import Path
from .dataset_registry import HyraxDataset
[docs]
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
[docs]
class HyraxCifarDataset(HyraxDataset):
    """Map style CIFAR 10 dataset for Hyrax

    This utilizes the CIFAR dataset from torchvision for retrieving the dataset.
    Items are addressed by integer index. For each index the class exposes the
    normalized image, its label, and a zero-padded string object ID.
    """

    def __init__(self, config: dict, data_location: Path = None):
        """Load (downloading if necessary) the CIFAR 10 split selected by the config.

        Parameters
        ----------
        config : dict
            Hyrax configuration. The boolean at
            ``config["data_set"]["HyraxCifarDataset"]["use_training_data"]`` is passed
            to torchvision as ``train`` to choose between the training and test splits.
        data_location : Path, optional
            Directory handed to torchvision as the dataset ``root``.
        """
        # Imported here rather than at module level, so torchvision is only needed
        # when this dataset is actually instantiated.
        import torchvision.transforms as transforms
        from torchvision.datasets import CIFAR10

        self.data_location = data_location
        self.training_data = config["data_set"]["HyraxCifarDataset"]["use_training_data"]

        # Convert each image to a tensor, then normalize every channel with
        # mean 0.5 and standard deviation 0.5.
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        )

        self.cifar = CIFAR10(
            root=self.data_location, train=self.training_data, download=True, transform=transform
        )

        # Width used to zero-pad object IDs: the number of decimal digits in the
        # dataset length, so every ID has the same number of characters.
        self.id_width = len(str(len(self.cifar)))

        # NOTE(review): the base initializer runs last, presumably because it relies on
        # the attributes set above -- confirm against HyraxDataset before reordering.
        super().__init__(config)

    def get_image(self, idx):
        """Get the image at the given index as a NumPy array.

        The array is the transformed (tensor-converted and normalized) image, so it is
        expected to be channel-first floating point data rather than raw pixel bytes.
        """
        image, _ = self.cifar[idx]
        return image.numpy()

    def get_label(self, idx):
        """Get the label at the given index."""
        # Read the label straight from torchvision's target list. Indexing ``self.cifar``
        # would also load and transform the image only to throw it away; no
        # ``target_transform`` is configured, so the value returned is the same.
        return self.cifar.targets[idx]

    def get_object_id(self, idx) -> str:
        """Get the object ID for the item as a string.

        The ID is the index rendered in decimal and left-padded with zeros to
        ``self.id_width`` characters.
        """
        return f"{idx:0{self.id_width}d}"

    def __len__(self) -> int:
        """Return the number of items in the loaded CIFAR 10 split."""
        return len(self.cifar)