# Source code for hyrax.config_schemas.data_request

"""
Pydantic models describing the structure of the ``data_request`` configuration.

These schemas validate and enforce the structure of dataset requests used throughout
the Hyrax framework.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any
from urllib.parse import urlparse

from pydantic import Field, RootModel, field_validator, model_validator

from .base import BaseConfigModel


class DataRequestConfig(BaseConfigModel):
    """Per-dataset configuration used within ``data_request``.

    One instance describes a single dataset source inside a dataset group.
    Instances are validated field-by-field, after which two model-level checks
    run:

    * ``join_field`` and ``primary_id_field`` may not both be set.
    * When ``augment`` is a list, it must not contain ``primary_id_field`` and
      (if ``fields`` is given) must be a subset of ``fields``.
    """

    dataset_class: str = Field(..., description="Fully qualified dataset class name.")
    # Normalized by ``resolve_data_location``: URIs are kept verbatim, filesystem
    # paths become absolute.
    data_location: str = Field(..., description="Path or URI describing where the dataset is stored.")
    fields: list[str] | None = Field(
        None, description="Subset of columns/fields to request from the dataset."
    )
    primary_id_field: str | None = Field(
        None, description="Name of the primary identifier field in the dataset."
    )
    join_field: str | None = Field(
        None,
        description=(
            "Field in this dataset whose values match the primary dataset's "
            "primary_id_field. When set, the DataProvider joins this secondary "
            "dataset to the primary by matching field values instead of requiring "
            "index alignment."
        ),
    )
    dataset_config: dict | None = Field(
        None,
        description="Dataset-specific configuration as a free-form dictionary.",
    )
    augment: bool | list[str] | None = Field(
        None,
        description=(
            "Enable augmentation for this dataset. When True (boolean), all "
            "augment_<field> methods found on the dataset class are used, with "
            "fallback to get_<field> for fields without an augment method. "
            "When a list of field names, only the listed fields are augmented "
            "(each must have an augment_<field> method — hard error if missing) "
            "and unlisted fields use get_<field>."
        ),
    )

    @field_validator("data_location")
    @classmethod
    def resolve_data_location(cls, v: str) -> str:
        """Fully resolve ``data_location`` unless it is a URI.

        A value that starts with ``<scheme>://`` is treated as a URI and returned
        unchanged. The scheme match is case-insensitive, so ``S3://bucket`` and
        ``s3://bucket`` behave the same. Any other value is treated as a
        filesystem path: ``~`` is expanded and the result is made absolute with
        :meth:`pathlib.Path.resolve`.
        """
        parsed = urlparse(v)
        # urlparse() lower-cases the scheme it reports, so the raw prefix must be
        # compared case-insensitively. Otherwise an upper/mixed-case URI such as
        # "S3://bucket/key" would be mangled into a local path.
        uri_prefix = f"{parsed.scheme}://"
        if parsed.scheme and v[: len(uri_prefix)].lower() == uri_prefix:
            return v
        return str(Path(v).expanduser().resolve())

    @model_validator(mode="after")
    def join_field_excludes_primary(self) -> DataRequestConfig:
        """Ensure that ``join_field`` and ``primary_id_field`` are mutually exclusive.

        Raises:
            ValueError: if both fields are set on the same dataset.
        """
        if self.join_field is not None and self.primary_id_field is not None:
            raise ValueError(
                "'join_field' and 'primary_id_field' are mutually exclusive. "
                "'join_field' is for secondary datasets that join to the primary."
            )
        return self

    @model_validator(mode="after")
    def validate_augment_list(self) -> DataRequestConfig:
        """Validate the list form of ``augment`` against ``fields`` and ``primary_id_field``.

        Only runs when ``augment`` is a list. The boolean/None forms need no
        cross-field checks.

        Raises:
            ValueError: if the list names ``primary_id_field``, or names a field
                that is not in ``fields`` (when ``fields`` is set).
        """
        if not isinstance(self.augment, list):
            return self
        if self.primary_id_field is not None and self.primary_id_field in self.augment:
            raise ValueError(
                f"Cannot enable augmentation on primary_id_field '{self.primary_id_field}'. "
                f"The primary ID field is implicitly repeated and must not be augmented."
            )
        if self.fields is not None:
            extra = set(self.augment) - set(self.fields)
            if extra:
                raise ValueError(f"augment list contains fields not in 'fields': {sorted(extra)}")
        return self

    def as_dict(self, *, exclude_unset: bool = False) -> dict[str, Any]:
        """Return the configuration as a plain dictionary.

        Args:
            exclude_unset: forwarded to ``model_dump``; when True, fields the user
                never set are omitted.
        """
        return self.model_dump(exclude_unset=exclude_unset)
# A dataset group maps each user-chosen friendly name to its dataset config,
# e.g. {"my_dataset": DataRequestConfig(...)}.
DatasetGroupValue = dict[str, DataRequestConfig]
def _normalize_dataset_group(value: Any) -> DatasetGroupValue:
    """Coerce one dataset group into ``dict[friendly_name, DataRequestConfig]``.

    Each dataset source in a group must be keyed by an explicit, user-supplied
    friendly name. That key is preserved in the returned mapping.

    Accepted:
        A ``dict`` whose values are ``DataRequestConfig`` instances (kept as-is)
        or plain dicts (validated with ``DataRequestConfig.model_validate``).

    Raises:
        ValueError: if ``value`` is a bare ``DataRequestConfig``; if it is a dict
            with ``dataset_class`` at the top level (a config missing its friendly
            name wrapper); if any entry is neither a dict nor a
            ``DataRequestConfig``; or if ``value`` is not a dict at all.
    """
    # A bare config has no friendly name to key it by.
    if isinstance(value, DataRequestConfig):
        raise ValueError(
            "A friendly name is required for each dataset source. "
            "Wrap the config in a dict with a friendly name, e.g. "
            '{"<friendly_name>": <DataRequestConfig>}.'
        )

    if not isinstance(value, dict):
        raise ValueError(f"Cannot parse dataset group value of type {type(value).__name__}")

    # "dataset_class" at the top level means the user wrote a flat config
    # instead of {friendly_name: config}.
    if "dataset_class" in value:
        raise ValueError(
            "A friendly name is required for each dataset source. "
            'Instead of {"dataset_class": ..., ...}, use '
            '{"<friendly_name>": {"dataset_class": ..., ...}}.'
        )

    named: DatasetGroupValue = {}
    for friendly_name, entry in value.items():
        if isinstance(entry, DataRequestConfig):
            named[friendly_name] = entry
            continue
        if not isinstance(entry, dict):
            raise ValueError(
                f"Value for friendly name '{friendly_name}' must be a dict or DataRequestConfig instance, "
                f"got {type(entry).__name__}."
            )
        named[friendly_name] = DataRequestConfig.model_validate(entry)
    return named
[docs] def _iter_all_configs( groups: dict[str, DatasetGroupValue], ) -> list[tuple[str, DataRequestConfig]]: """Yield ``(group_name, config)`` pairs across all groups.""" result = [] for group_name, group_value in groups.items(): for config in group_value.values(): result.append((group_name, config)) return result
class DataRequestDefinition(RootModel[dict[str, DatasetGroupValue]]):
    """Typed representation of the full ``data_request`` table.

    Accepts any number of arbitrarily-named dataset groups (e.g. ``train``,
    ``validate``, ``infer``, ``test``, ``finetune``, …). Each group value is a
    ``dict`` of *friendly-named* ``DataRequestConfig`` instances. A friendly name
    must always be provided explicitly — the schema will raise a validation
    error if a dataset source is specified without one.

    Example (Python)::

        {
            "train": {
                "my_dataset": {
                    "dataset_class": "HyraxRandomDataset",
                    "data_location": "/path/to/data",
                    "primary_id_field": "object_id",
                }
            }
        }

    Example (TOML)::

        [data_request.train.my_dataset]
        dataset_class = "HyraxRandomDataset"
        data_location = "/path/to/data"
        primary_id_field = "object_id"
    """

    @model_validator(mode="before")
    @classmethod
    def normalize_all_groups(cls, value: Any) -> dict[str, DatasetGroupValue]:
        """Parse every top-level key into the expected group format.

        Groups whose value is ``None`` are dropped. Every other group is passed
        through ``_normalize_dataset_group``, which enforces friendly names.

        Raises:
            ValueError: if ``value`` is not a dict, or if a group is malformed.
        """
        if not isinstance(value, dict):
            raise ValueError("DataRequestDefinition expects a dictionary of dataset groups.")
        normalized: dict[str, DatasetGroupValue] = {}
        for group_name, group_value in value.items():
            if group_value is None:
                continue
            normalized[group_name] = _normalize_dataset_group(group_value)
        return normalized

    @model_validator(mode="after")
    def reject_augment_on_infer(self) -> DataRequestDefinition:
        """Reject augmentation on any dataset in the ``infer`` group.

        ``augment`` counts as enabled when it is ``True`` or a non-empty list.
        Every other group name is allowed to use augmentation.

        Raises:
            ValueError: naming the first ``infer`` dataset with augmentation enabled.
        """
        # Only the "infer" group is restricted, so look it up directly instead of
        # scanning every group.
        infer_group = self.root.get("infer")
        if infer_group is None:
            return self
        for friendly_name, cfg in infer_group.items():
            has_augment = cfg.augment is True or (isinstance(cfg.augment, list) and len(cfg.augment) > 0)
            if has_augment:
                raise ValueError(
                    f"Augmentation cannot be enabled on 'infer' data group "
                    f"(dataset '{friendly_name}'). Augmentation is only valid for "
                    f"non-inference groups such as 'train', 'validate', and 'test'."
                )
        return self

    @model_validator(mode="after")
    def require_at_least_one_dataset(self) -> DataRequestDefinition:
        """Ensure at least one dataset group remains after ``None`` groups are dropped.

        Raises:
            ValueError: if no dataset groups are present.
        """
        if not self.root:
            raise ValueError("At least one dataset group must be provided.")
        return self

    @model_validator(mode="after")
    def validate_primary_id_fields(self) -> DataRequestDefinition:
        """Require exactly one config per group to set ``primary_id_field``.

        When a group holds several datasets, exactly one of them must specify
        the field used as the primary identifier.

        Raises:
            ValueError: if a group has zero, or more than one, config with
                ``primary_id_field`` set.
        """
        for group_name, group_value in self.root.items():
            primary_count = sum(1 for config in group_value.values() if config.primary_id_field is not None)
            if primary_count == 0:
                raise ValueError(
                    f"'{group_name}' must have exactly one DataRequestConfig with "
                    f"'primary_id_field' set, but found none."
                )
            elif primary_count > 1:
                raise ValueError(
                    f"'{group_name}' must have exactly one DataRequestConfig with "
                    f"'primary_id_field' set, but found {primary_count}."
                )
        return self

    def validate_cross_group(self, groups: set[str]) -> None:
        """No-op: cross-group split validation is now handled by splitting_utils.validate_split_config.

        The method is kept so existing callers keep working. ``groups`` is
        accepted but ignored.
        """

    def __contains__(self, key: str) -> bool:
        """Return True if the group name is present in the definition."""
        return key in self.root

    def __getitem__(self, key: str) -> DatasetGroupValue:
        """Return the dataset group value for the given group name.

        Raises:
            KeyError: if the group is not present.
        """
        return self.root[key]

    def as_dict(self, *, exclude_unset: bool = False) -> dict[str, Any]:
        """Export as a nested dictionary compatible with existing configs.

        Each group value is a dict of ``{friendly_name: flat_config_dict}``. No
        implicit ``"data"`` wrapper is added, and the friendly names supplied by
        the user are preserved verbatim.

        Args:
            exclude_unset: forwarded to each config's ``as_dict``.
        """
        return {
            name: {key: cfg.as_dict(exclude_unset=exclude_unset) for key, cfg in value.items()}
            for name, value in self.root.items()
        }