Source code for hyrax.datasets.hats_dataset

from pathlib import Path
from types import MethodType

from hyrax.datasets.dataset_registry import HyraxDataset


class HyraxHATSDataset(HyraxDataset):
    """Generic Hyrax dataset for HATS catalogs loaded through LSDB.

    Notes
    -----
    This phase-1 implementation materializes the LSDB catalog to a pandas
    DataFrame and dynamically creates ``get_<column>`` methods for requested
    columns.
    """

    def __init__(self, config: dict, data_location: Path = None):
        """Open the catalog, materialize it and attach per-column getters.

        Parameters
        ----------
        config : dict
            Hyrax runtime configuration. The ``data_request`` (or
            ``model_inputs``) section is read to decide which columns to
            load, and ``data_set.<ClassName>.open_catalog_kwargs`` is
            forwarded to ``lsdb.open_catalog``.
        data_location : Path
            Location of the HATS catalog to open.

        Raises
        ------
        ValueError
            If ``data_location`` is ``None`` or ``False``.
        """
        if data_location is None or data_location is False:
            raise ValueError("A `data_location` to a HATS catalog must be provided.")

        # Must be assigned before resolving the requested columns, which
        # compares each dataset definition against ``self.data_location``.
        self.data_location = data_location

        requested_columns = self._requested_columns_from_config(config)
        # This is a private copy, so adding "columns" below does not leak
        # back into the caller's configuration.
        open_catalog_kwargs = self._open_catalog_kwargs_from_config(config)
        if requested_columns and "columns" not in open_catalog_kwargs:
            open_catalog_kwargs["columns"] = requested_columns

        # Imported lazily so that importing this module does not require lsdb.
        import lsdb
        import numpy as np
        import pandas as pd

        catalog = lsdb.open_catalog(data_location, **open_catalog_kwargs)
        self.dataframe = catalog.compute()
        self.column_names = requested_columns if requested_columns else list(self.dataframe.columns)

        def _make_getter(column: str):
            """Build the ``get_<column>`` accessor bound to ``column``."""

            def getter(self, idx: int, _col: str = column):
                # Select the column first and then the row. Building the full
                # row with ``iloc[idx]`` would upcast mixed numeric columns to
                # a common dtype (e.g. int64 ids to float64) and costs a whole
                # row per access.
                ret_val = self.dataframe[_col].iloc[idx]
                if isinstance(ret_val, pd.Series):
                    ret_val = ret_val.to_numpy()
                elif isinstance(ret_val, (list, tuple)):
                    ret_val = np.asarray(ret_val)
                return ret_val

            return getter

        for col in self.column_names:
            method_name = f"get_{col}"
            # Never shadow an attribute or method that already exists.
            if not hasattr(self, method_name):
                setattr(self, method_name, MethodType(_make_getter(col), self))

        # The base class is initialized last, after the getters are in place.
        super().__init__(config)

    def _requested_columns_from_config(self, config: dict) -> "list[str] | None":
        """Collect the columns requested for this catalog by the configuration.

        Only dataset definitions whose ``dataset_class`` equals this class
        name and whose resolved ``data_location`` equals this instance's
        resolved ``data_location`` are considered.

        Parameters
        ----------
        config : dict
            Hyrax runtime configuration; ``data_request`` is read, falling
            back to ``model_inputs``.

        Returns
        -------
        list[str] or None
            Sorted union of ``fields``, ``primary_id_field`` and
            ``join_field`` over all matching definitions. ``None`` when a
            matching definition lists no fields, and an empty list when no
            definition matches; both are falsy and cause every column to be
            loaded.
        """
        data_request = config.get("data_request") or config.get("model_inputs") or {}
        requested_columns = set()
        target_location = str(Path(self.data_location).resolve())
        class_name = type(self).__name__

        for request_group in data_request.values():
            for dataset_definition in request_group.values():
                if dataset_definition.get("dataset_class") != class_name:
                    continue

                # A definition without a location cannot refer to this catalog.
                definition_location = dataset_definition.get("data_location")
                if not definition_location:
                    continue
                if str(Path(definition_location).resolve()) != target_location:
                    continue

                # If any dataset request has no fields specified, that means we need all the
                # columns no matter what any other request group says, so just early-return
                # requesting everything.
                fields = dataset_definition.get("fields", None)
                if not fields:
                    return None
                requested_columns.update(fields)

                primary_id_field = dataset_definition.get("primary_id_field")
                if primary_id_field:
                    requested_columns.add(primary_id_field)

                join_field = dataset_definition.get("join_field")
                if join_field:
                    requested_columns.add(join_field)

        return sorted(requested_columns)

    def _open_catalog_kwargs_from_config(self, config: dict) -> dict:
        """Return the keyword arguments to forward to ``lsdb.open_catalog``.

        Parameters
        ----------
        config : dict
            Hyrax runtime configuration; the value is read from
            ``config["data_set"][<ClassName>]["open_catalog_kwargs"]``.

        Returns
        -------
        dict
            A shallow copy of the configured keyword arguments, or an empty
            dict when the section, the class table or the key is missing or
            set to a falsy value. Returning a copy lets the caller add keys
            without modifying ``config``.
        """
        dataset_config = (config.get("data_set") or {}).get(type(self).__name__) or {}
        return dict(dataset_config.get("open_catalog_kwargs") or {})

    def __len__(self) -> int:
        """Return the number of rows in the materialized catalog."""
        return len(self.dataframe)