Source code for hyrax.datasets.nested_pandas_dataset

from pathlib import Path
from types import MethodType

from hyrax.datasets.dataset_registry import HyraxDataset


class NestedPandasDataset(HyraxDataset):
    """A minimal Hyrax wrapper around ``nested_pandas.read_parquet``.

    The parquet data is read once at construction time and kept on
    ``self.nested_frame``. For every available field (each column of the
    frame, plus whatever ``get_subcolumns()`` reports when the frame offers
    that method) a ``get_<field>`` method is attached to the instance.
    """

    def __init__(self, config: dict, data_location: Path | str | None = None):
        """Load the data and attach one getter per available field.

        Parameters
        ----------
        config : dict
            Hyrax configuration. The reader keyword arguments are taken from
            ``config["data_set"]["NestedPandasDataset"]["read_kwargs"]``.
        data_location : Path | str | None
            Location passed to ``nested_pandas.read_parquet`` (stored as a
            string). Required, even though the signature defaults to ``None``.

        Raises
        ------
        ValueError
            If ``data_location`` is ``None``.
        """
        if data_location is None:
            raise ValueError("A `data_location` must be provided.")

        self.data_location = str(data_location)
        self.read_kwargs = config["data_set"]["NestedPandasDataset"]["read_kwargs"]
        self.nested_frame = self._load_nested_frame(self.read_kwargs)

        # Order matters here: the getters are attached before the base-class
        # initializer runs.
        # NOTE(review): presumably HyraxDataset.__init__ relies on the
        # ``get_*`` methods already existing -- confirm before reordering.
        self._register_getters()
        super().__init__(config)

    def _load_nested_frame(self, read_kwargs: dict):
        """Read ``self.data_location`` with ``nested_pandas.read_parquet``.

        ``nested_pandas`` is imported lazily so that the dependency is only
        needed when this dataset class is actually used.

        Parameters
        ----------
        read_kwargs : dict
            Keyword arguments forwarded verbatim to ``read_parquet``.

        Returns
        -------
        The object returned by ``nested_pandas.read_parquet``.

        Raises
        ------
        ImportError
            If ``nested_pandas`` cannot be imported; the message includes
            installation instructions and chains the original error.
        """
        try:
            import nested_pandas
        except ImportError as exc:
            message = (
                "NestedPandasDataset requires the `nested-pandas` package. "
                "Install it with `pip install nested-pandas`."
            )
            raise ImportError(message) from exc

        return nested_pandas.read_parquet(self.data_location, **read_kwargs)

    def _all_available_fields(self) -> list[str]:
        """Return the names of every field a getter should be created for.

        Returns
        -------
        list[str]
            The frame's columns, followed by the result of
            ``get_subcolumns()`` when the frame provides that method.
        """
        frame = self.nested_frame
        names = list(frame.columns)

        # Frames without ``get_subcolumns`` only contribute their columns.
        if not hasattr(frame, "get_subcolumns"):
            return names

        return names + list(frame.get_subcolumns())

    def _register_getters(self) -> None:
        """Attach a bound ``get_<field>`` method for every available field.

        Attributes that already exist under the same name are left untouched,
        so explicitly defined methods take precedence over generated ones.
        """

        def _make_getter(field_name: str):
            def getter(self, idx, _field_name=field_name):
                import pandas as pd

                frame = self.nested_frame
                selected = frame[_field_name]
                # Translate the position ``idx`` into the frame's index label
                # and look that label up on the selected field.
                # NOTE(review): a label lookup (rather than ``iloc``) looks
                # deliberate, presumably because a subcolumn selection can
                # hold several entries per row label -- confirm against
                # nested-pandas before changing.
                row_label = frame.index[idx]
                value = selected.loc[row_label]

                # pandas containers are handed back as numpy arrays; any
                # other value is returned unchanged.
                if isinstance(value, (pd.DataFrame, pd.Series)):
                    return value.to_numpy()
                return value

            return getter

        for name in self._all_available_fields():
            attr_name = f"get_{name}"
            if hasattr(self, attr_name):
                continue
            setattr(self, attr_name, MethodType(_make_getter(name), self))

    def __len__(self) -> int:
        """Return the number of rows in the loaded frame."""
        return len(self.nested_frame)