Source code for hyrax.datasets.mmu_dataset
import itertools
import re
from pathlib import Path
from types import MethodType
from typing import Any
from hyrax.datasets.dataset_registry import HyraxDataset
[docs]
class _IndexedSubset:
    """Fallback wrapper to enforce a max row count for indexable datasets.

    Only the first ``max_samples`` rows of the wrapped object are reachable
    through ``__getitem__``; ``__len__`` reports ``max_samples``. The wrapped
    object is indexed on demand and never copied.

    Parameters
    ----------
    dataset : Any
        Object that supports integer indexing.
    max_samples : int
        Number of leading rows to expose.
    """

    def __init__(self, dataset: Any, max_samples: int):
        self._dataset = dataset
        self._max_samples = max_samples

    def __getitem__(self, idx: int) -> Any:
        """Return row ``idx`` of the subset.

        Negative indices are resolved relative to the end of the subset (not
        the end of the wrapped dataset), so ``subset[-1]`` is the last row
        that the subset exposes.

        Raises
        ------
        IndexError
            If ``idx`` falls outside ``[-max_samples, max_samples)``.
        """
        # Resolve negative indices against the subset length first; passing
        # them straight through would reach rows beyond the cap.
        position = idx + self._max_samples if idx < 0 else idx
        if not 0 <= position < self._max_samples:
            raise IndexError(idx)
        return self._dataset[position]

    def __len__(self) -> int:
        """Return the number of rows exposed by the subset."""
        return self._max_samples
[docs]
class MultimodalUniverseDataset(HyraxDataset):
"""Load a MultimodalUniverse dataset through Hugging Face ``datasets``.
This dataset class is intentionally generic so one configuration pattern can
be used for image, spectra, and time-series MMU datasets.
Examples
--------
Example ``data_request`` configuration::
{
"infer": {
"mmu": {
"dataset_class": "MultimodalUniverseDataset",
"data_location": "hf://MultimodalUniverse/plasticc",
"primary_id_field": "object_id",
"dataset_config": {
"MultimodalUniverseDataset": {
"split": "train",
"max_samples": 32,
}
},
}
}
}
"""
[docs]
def __init__(self, config: dict, data_location: Path | str | None = None):
if data_location is None:
raise ValueError(
"A `data_location` must be provided. Use either a local dataset path or "
"a Hugging Face URI like 'hf://MultimodalUniverse/plasticc'."
)
[docs]
self.data_location = str(data_location)
dataset_settings = config["data_set"]["MultimodalUniverseDataset"]
[docs]
self.split = dataset_settings["split"]
[docs]
self.max_samples = int(dataset_settings["max_samples"]) if dataset_settings["max_samples"] else None
[docs]
self.streaming = dataset_settings["streaming"]
dataset_source = self._normalize_data_location(self.data_location)
[docs]
self.dataset = self._load_dataset(dataset_source)
[docs]
self._column_name_map = self._build_column_name_map()
self._register_getters()
super().__init__(config)
[docs]
def _normalize_data_location(self, data_location: str) -> str:
    """Strip the ``hf://`` scheme from a Hugging Face URI.

    The constructor accepts ``pathlib.Path`` objects and stringifies them.
    On POSIX, ``str(Path("hf://org/name"))`` collapses the double slash to
    ``"hf:/org/name"``, so the collapsed form is recognised as well.

    Parameters
    ----------
    data_location : str
        Local path or ``hf://``-prefixed dataset identifier.

    Returns
    -------
    str
        ``data_location`` without the scheme prefix, or unchanged when no
        prefix is present.
    """
    # Longest prefix first so "hf://x" never leaves a stray leading slash.
    for scheme in ("hf://", "hf:/"):
        if data_location.startswith(scheme):
            return data_location[len(scheme):]
    return data_location
[docs]
def _load_dataset(self, dataset_source: str):
    """Load ``self.split`` of ``dataset_source`` and apply the sample cap.

    Parameters
    ----------
    dataset_source : str
        Identifier or path handed to ``datasets.load_dataset``.

    Returns
    -------
    The loaded dataset with the ``"numpy"`` format requested. In streaming
    mode this is a ``list`` holding the first ``self.max_samples`` rows;
    otherwise it is the result of ``_limit_non_streaming_dataset`` when a
    cap is set, or the dataset itself when no cap is set.

    Raises
    ------
    ValueError
        If streaming is enabled without ``max_samples``.
    ImportError
        If the ``datasets`` package is not installed.
    """
    # Pure configuration error: report it before importing ``datasets`` or
    # calling ``load_dataset``.
    if self.streaming and self.max_samples is None:
        raise ValueError(
            "When streaming=True, set data_set.MultimodalUniverseDataset.max_samples "
            "to avoid iterating through the full dataset."
        )

    try:
        from datasets import load_dataset
    except ImportError as err:
        raise ImportError(
            "MultimodalUniverseDataset requires the `datasets` package. "
            "Install it with `pip install datasets`."
        ) from err

    dataset = load_dataset(dataset_source, split=self.split, streaming=self.streaming)
    dataset = dataset.with_format("numpy")

    if self.streaming:
        # Materialise the capped prefix so the result supports indexing and len().
        return list(itertools.islice(dataset, self.max_samples))
    if self.max_samples is not None:
        return self._limit_non_streaming_dataset(dataset, self.max_samples)
    return dataset
[docs]
def _limit_non_streaming_dataset(self, dataset: Any, max_samples: int):
    """Return at most the first ``max_samples`` rows of ``dataset``.

    The strategy depends on what the object offers: its own ``select``
    method when present, a slice for plain lists, and an
    ``_IndexedSubset`` wrapper for anything else.
    """
    total_rows = len(dataset)
    # Never ask for more rows than the dataset holds.
    cap = max_samples if max_samples <= total_rows else total_rows

    if hasattr(dataset, "select"):
        limited = dataset.select(range(cap))
    elif isinstance(dataset, list):
        limited = dataset[:cap]
    else:
        limited = _IndexedSubset(dataset, cap)
    return limited
[docs]
def _build_column_name_map(self) -> dict[str, str]:
    """
    Build the lookup from accessor names to original column names.

    Column names are taken from the keys of the first row of
    ``self.dataset``. Each raw name maps to itself, so exact field names
    (punctuation included) can always be requested. The sanitized form
    produced by ``_sanitize_name`` is added as an alias unless that name
    is already present in the map.

    A raw column name always takes precedence over an alias of the same
    spelling; between two aliases, the one registered first is kept.
    """
    first_row = self.dataset[0]
    name_map: dict[str, str] = {}
    for raw_name in first_row:
        # A raw name always maps to itself, replacing any earlier alias.
        name_map[raw_name] = raw_name
        alias = self._sanitize_name(raw_name)
        # When the alias equals the raw name it is already present, so
        # nothing is added.
        if alias not in name_map:
            name_map[alias] = raw_name
    return name_map
[docs]
def _sanitize_name(self, column_name: str) -> str:
    """
    Turn a column name into one usable as a method-name suffix.

    Every character matched by the regex ``\\W`` is replaced with an
    underscore. An empty name becomes ``"field"``, and a result whose first
    character is a digit is prefixed with ``field_``.
    """
    cleaned = re.sub(r"\W", "_", column_name)
    if cleaned == "":
        return "field"
    starts_with_digit = cleaned[0].isdigit()
    return f"field_{cleaned}" if starts_with_digit else cleaned
[docs]
def _register_getters(self) -> None:
    """Attach a bound ``get_<name>(idx)`` method for every mapped column.

    One method is created per entry of ``self._column_name_map``. A name
    that already resolves to an attribute on the instance is left alone.
    """

    def _getter_for(column):
        # The column is bound as a default argument so each generated
        # getter carries its own source column.
        def getter(self, idx, _source_name=column):
            import numpy as np
            from PIL.Image import Image

            value = self.dataset[idx][_source_name]
            # Some fields in MMU are PIL images.
            # Hyrax only accepts numpy arrays
            if isinstance(value, Image):
                value = np.asarray(value)
            return value

        return getter

    for suffix, column in self._column_name_map.items():
        attr_name = f"get_{suffix}"
        if hasattr(self, attr_name):
            continue
        setattr(self, attr_name, MethodType(_getter_for(column), self))
[docs]
def __len__(self) -> int:
    """Return the number of rows in the loaded (and possibly capped) dataset."""
    row_count = len(self.dataset)
    return row_count