Source code for hyrax.datasets.dataset_registry

# ruff: noqa: D102, B027
import logging
from collections.abc import Callable
from types import MethodType
from typing import Any

import numpy.typing as npt

from hyrax.plugin_utils import get_or_load_class, update_registry

# Module-level logger; used by HyraxDataset.metadata to report unknown field names.
logger = logging.getLogger(__name__)

# Maps a dataset class name to the class object. HyraxDataset.__init_subclass__ adds
# every concrete subclass here, and fetch_dataset_class reads from it.
DATASET_REGISTRY: dict[str, type["HyraxDataset"]] = {}
class HyraxDataset:
    """Base class for hyrax datasets.

    How to make a hyrax dataset:

    .. code-block:: python

        from hyrax.datasets import HyraxDataset

        class MyDataset(HyraxDataset):
            def __init__(self, config: dict):
                super().__init__(config)

            def __len__(self):
                # Your len function goes here
                pass

    Optional interfaces:

    ``metadata`` -> Subclasses may pass an astropy table of metadata to ``__init__`` in the
    superclass. This table of metadata will be available through the ``metadata_fields`` and
    ``metadata`` functions. If desired, a subclass may override these functions directly rather
    than using the astropy Table interface.

    Further documentation is in the :doc:`/pre_executed/external_dataset_class` example notebook.
    """

    def __init__(self, config: dict, metadata_table=None, object_id_column_name=None):
        """
        .. py:method:: __init__

        Overall initialization for all Datasets which saves the config

        Subclasses of HyraxDataset ought call this at the end of their __init__ like:

        .. code-block:: python

            from hyrax.datasets import HyraxDataset

            class MyDataset(HyraxDataset):
                def __init__(self, config):
                    <your code>
                    super().__init__(config)

        If per tensor metadata is available, it is recommended that dataset authors create an
        astropy Table of that data, in the same order as their data and pass that
        `metadata_table` as shown below:

        .. code-block:: python

            from hyrax.datasets import HyraxDataset
            from astropy.table import Table

            class MyDataset(HyraxDataset):
                def __init__(self, config):
                    <your code>
                    metadata_table = Table(<Your catalog data goes here>)
                    super().__init__(config, metadata_table)

        For every column ``<col>`` of ``metadata_table`` a bound method ``get_<col>(idx)`` is
        attached to the instance, unless an attribute of that name already exists.

        Parameters
        ----------
        config : dict
            The runtime configuration for hyrax
        metadata_table : Optional[Table], optional
            An Astropy Table with
            1. the metadata columns desired for visualization AND
            2. in the order your data will be enumerated.
        object_id_column_name : Optional[str], optional
            The name of the column containing object IDs.
            NOTE(review): this argument is not read anywhere in this initializer; confirm where
            the previously documented fallback (config default / ``ids()``) is implemented.
        """
        self._config = config
        self._metadata_table = metadata_table

        # Pull up all metadata fields as HyraxQL getters.
        if self._metadata_table is not None:

            def _make_getter(column):
                # ``_col`` is bound as a default argument so each getter keeps its own column
                # name instead of sharing the loop variable.
                def getter(self, idx, _col=column):
                    return self._metadata_table[_col][idx]

                return getter

            for col in self._metadata_table.colnames:
                method_name = f"get_{col}"
                # Never shadow a getter the subclass (or an earlier column) already provides.
                if not hasattr(self, method_name):
                    setattr(self, method_name, MethodType(_make_getter(col), self))

    @property
    def config(self):
        """The configuration object that was passed to ``__init__``."""
        return self._config

    def __init_subclass__(cls, **kwargs):
        """Validate and register every concrete subclass.

        Subclasses that list ``abc.ABC`` directly among their bases are treated as abstract:
        they are neither validated nor registered. Every other subclass must expose ``__len__``
        and is added to ``DATASET_REGISTRY`` under its class name.

        Raises
        ------
        RuntimeError
            If a non-abstract subclass does not define (or inherit) ``__len__``.
        """
        from abc import ABC

        # Forward to the next class in the MRO so that mixins / sibling base classes with their
        # own ``__init_subclass__`` hook still run.
        super().__init_subclass__(**kwargs)

        if ABC in cls.__bases__:
            return

        # We only require a user to implement a __len__ method.
        if not hasattr(cls, "__len__"):
            msg = f"Hyrax data set {cls.__name__} is missing required length function. "
            msg += "__len__ must be defined."
            raise RuntimeError(msg)

        # Ensure the class is in the registry so the config system can find it
        update_registry(DATASET_REGISTRY, cls.__name__, cls)

    def metadata_fields(self) -> list[str]:
        """Returns a list of metadata fields supported by this object

        Returns
        -------
        list[str]
            The column names of the metadata table passed. Empty list if no metadata table
            was provided during construction of the HyraxDataset (or derived class).
        """
        if self._metadata_table is None:
            return []
        return list(self._metadata_table.colnames)

    def metadata(self, idxs: npt.ArrayLike, fields: list[str]) -> npt.ArrayLike:
        """Returns a table representing the metadata given an array of indexes and a list of fields.

        Parameters
        ----------
        idxs : npt.ArrayLike
            The indexes of the relevant tensor objects
        fields : list[str]
            The names of the fields you would like returned. Names that are not among those
            returned by ``metadata_fields()`` are logged as errors and skipped.

        Returns
        -------
        npt.ArrayLike
            A numpy record array of your metadata, with only the columns specified.
            Roughly equivalent to: `metadata_table[idxs][fields].as_array()` where
            metadata_table is the astropy table that the HyraxDataset (or derived class) was
            constructed with. If that result is a masked array, masked entries are filled
            with NaN and a regular array is returned.

        Raises
        ------
        RuntimeError
            When none of the provided fields are available on this dataset.
        """
        metadata_fields = self.metadata_fields()

        # Keep the requested order; report every unknown field but carry on with the rest.
        columns = []
        for field in fields:
            if field in metadata_fields:
                columns.append(field)
            else:
                msg = f"Field {field} is not available for {self.__class__.__name__}."
                logger.error(msg)

        if not columns:
            msg = (
                f"None of the metadata fields passed [{fields}] are available for {self.__class__.__name__}."
            )
            raise RuntimeError(msg)

        result = self._metadata_table[idxs][columns].as_array()

        # Convert masked arrays to regular arrays with NaN for masked values
        # NOTE(review): NaN is presumably only representable in floating-point columns;
        # confirm the behaviour for masked integer/string columns.
        import numpy as np
        import numpy.ma as ma

        if ma.isMaskedArray(result):
            result = ma.filled(result, np.nan)

        return result
def fetch_dataset_class(class_name: str) -> type[HyraxDataset]:
    """Look up a dataset class by name.

    Parameters
    ----------
    class_name : str
        The name of the dataset class to fetch. Either the class name of a built in dataset,
        or the fully qualified name of a user-defined dataset.
        e.g. "my_module.my_submodule.MyDatasetClass" or "HyraxRandomDataset".

    Returns
    -------
    type[HyraxDataset]
        Whatever ``get_or_load_class`` resolves ``class_name`` to, using ``DATASET_REGISTRY``.

    Raises
    ------
    RuntimeError
        If ``class_name`` is empty (or otherwise falsy).

    Notes
    -----
    Exceptions raised by ``get_or_load_class`` propagate unchanged.
    NOTE(review): earlier documentation listed ``ValueError`` for a built in dataset missing
    from the registry; presumably that comes from ``get_or_load_class`` - confirm.
    """
    if class_name:
        return get_or_load_class(class_name, DATASET_REGISTRY)

    raise RuntimeError("dataset_class must be specified in 'data_request'.")
class HyraxImageDataset:
    """
    This is a mixin for Image datasets primarily concerned with providing utility functions
    to allow derived classes to set and apply transformations based on configs.

    The various set_*_transform functions stack individual transformations on a single stack
    The stack can be applied with apply_transform.

    The methods read ``self.config``, so the class they are mixed into must provide it.
    """

    def set_function_transform(self):
        """Add the numpy function named in ``config["data_set"]["transform"]`` to the stack.

        Nothing is added when that config value is falsy.

        Raises
        ------
        RuntimeError
            If the configured name does not resolve to a callable attribute of numpy.
        """
        from torchvision.transforms.v2 import Lambda

        function_name = self.config["data_set"]["transform"]
        if function_name:
            transform_func = self._get_np_function(function_name)
            self._update_transform(Lambda(lambd=transform_func))

    def set_crop_transform(self, cutout_shape=None):
        """Add a ``CenterCrop`` of ``cutout_shape`` to the stack.

        Parameters
        ----------
        cutout_shape : list or tuple, optional
            Two-element crop size. When None, ``config["data_set"]["crop_to"]`` is used.

        Raises
        ------
        RuntimeError
            If the resulting shape is not a list or tuple of length 2.
        """
        from torchvision.transforms.v2 import CenterCrop

        if cutout_shape is None:
            # A falsy config value (e.g. False or an empty list) counts as "not configured".
            cutout_shape = self.config["data_set"]["crop_to"] or None

        if not isinstance(cutout_shape, (list, tuple)) or len(cutout_shape) != 2:
            msg = "Must provide a cutout shape in config['data_set']['crop_to']."
            msg += " Shape should be a list of integer pixel sizes e.g. [100,100]"
            raise RuntimeError(msg)

        self._update_transform(CenterCrop(size=cutout_shape))

    def apply_transform(self, data_torch):
        """Run the transform stack on ``data_torch`` and return the result of ``.numpy()``.

        When no transform has been set, ``data_torch.numpy()`` is returned directly.
        """
        # Instances mixed in without ever setting a transform have no ``transform`` attribute
        # yet; create it lazily (an explicit ``False`` is treated the same way).
        if self.__dict__.get("transform", False) is False:
            self.transform = None

        data_transformed = self.transform(data_torch) if self.transform is not None else data_torch
        return data_transformed.numpy()

    def _update_transform(self, new_transform):
        """Put ``new_transform`` on the stack.

        With an empty stack ``new_transform`` becomes the stack; otherwise the stack becomes
        ``Compose([new_transform, <previous stack>])``.
        """
        from torchvision.transforms.v2 import Compose

        # Lazily create the attribute, mirroring apply_transform.
        if self.__dict__.get("transform", False) is False:
            self.transform = None

        self.transform = new_transform if self.transform is None else Compose([new_transform, self.transform])

    def _get_np_function(self, transform_str: str) -> Callable[..., Any]:
        """
        _get_np_function. Returns the numpy mathematical function that
        the supplied string maps to; or raises an error if the supplied
        string cannot be mapped to a function.

        Parameters
        ----------
        transform_str: str
            The string to be mapped to a numpy function

        Returns
        -------
        Callable[..., Any]
            The callable attribute of the ``numpy`` module named ``transform_str``.

        Raises
        ------
        RuntimeError
            If numpy has no attribute named ``transform_str``, or if that attribute exists but
            is not callable (for example the constant ``"pi"``).
        """
        import numpy as np

        msg = f"{transform_str} is not a valid numpy function.\n"
        msg += "The string passed to the transform variable needs to be a numpy function"

        try:
            func: Callable[..., Any] = getattr(np, transform_str)
        except AttributeError as err:
            raise RuntimeError(msg) from err

        # A non-callable attribute used to fall through and return None, which only failed
        # later inside the transform pipeline; reject it here instead.
        if not callable(func):
            raise RuntimeError(msg)

        return func