hyrax.splitting_utils
=====================

.. py:module:: hyrax.splitting_utils

.. autoapi-nested-parse::

   Splitting and dataset balancing utilities for Hyrax datasets.



Attributes
----------

.. autoapisummary::

   hyrax.splitting_utils.logger


Functions
---------

.. autoapisummary::

   hyrax.splitting_utils._is_path_value
   hyrax.splitting_utils._resolve_seed
   hyrax.splitting_utils._shuffle
   hyrax.splitting_utils._primary_instance
   hyrax.splitting_utils._find_primary_cfg
   hyrax.splitting_utils._compute_weights
   hyrax.splitting_utils.validate_split_config
   hyrax.splitting_utils.validate_balance_config
   hyrax.splitting_utils.validate_distribution_labels
   hyrax.splitting_utils._compute_splits
   hyrax.splitting_utils.persist_splits
   hyrax.splitting_utils.load_split_files
   hyrax.splitting_utils.assign_splits_to_providers
   hyrax.splitting_utils.configs_equivalent
   hyrax.splitting_utils.find_equivalent_split
   hyrax.splitting_utils.create_splits


Module Contents
---------------

.. py:data:: logger

.. py:function:: _is_path_value(val: Any) -> bool

   Return True when val is a non-empty string (i.e. a file path, not a fraction).


.. py:function:: _resolve_seed(config: dict) -> int | None

   Return the effective RNG seed, falling back to ``config['data_set']['seed']``
   when the configured seed is ``''`` or ``false``.
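
   The fallback rule can be sketched as a standalone function. This is a
   hypothetical illustration, not hyrax's implementation. ``resolve_seed`` is an
   invented name, and the sketch assumes, as the ``_shuffle`` docstring suggests,
   that the split-specific seed is read from ``config['split']['rng_seed']``.

   .. code-block:: python

      from typing import Optional


      def resolve_seed(config: dict) -> Optional[int]:
          """Hypothetical sketch of the seed fallback rule."""

          def unset(value) -> bool:
              # '' and false mean "not configured". Using `is False` keeps a
              # real seed of 0 from being treated as false (0 == False in Python).
              return value is None or value is False or value == ""

          # Assumption: the split seed lives under [split] rng_seed.
          seed = config.get("split", {}).get("rng_seed", "")
          if unset(seed):
              seed = config.get("data_set", {}).get("seed")
          return None if unset(seed) else seed

   Checking identity with ``is False`` rather than truthiness matters because a
   seed of ``0`` is a valid, reproducible seed.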


.. py:function:: _shuffle(indices: list[int], config: dict) -> None

   Shuffle *indices* in-place using the configured RNG.

   When ``split.rng_seed`` is empty, this reproduces bit-for-bit the legacy
   global-seed shuffle used by ``create_splits_from_fractions``.
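
   A minimal sketch of seeded, in-place shuffling with a dedicated NumPy
   ``Generator``, which leaves the global random state untouched. This is an
   assumption for illustration: the generator hyrax actually uses, and the legacy
   global-seed path it falls back to, may differ. ``shuffle_in_place`` is a
   hypothetical name.

   .. code-block:: python

      from typing import Optional

      import numpy as np


      def shuffle_in_place(indices: list, seed: Optional[int]) -> None:
          """Hypothetical helper: shuffle *indices* in place with a private RNG."""
          # default_rng(None) draws fresh OS entropy, so a None seed is not reproducible.
          rng = np.random.default_rng(seed)
          # Generator.shuffle accepts Python lists and mutates them in place.
          rng.shuffle(indices)

   Two calls with the same seed produce the same permutation, which is what makes
   a persisted split reproducible from its config.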


.. py:function:: _primary_instance(provider: hyrax.datasets.data_provider.DataProvider) -> Any

   Return the primary dataset instance from *provider*.


.. py:function:: _find_primary_cfg(group_dict: dict) -> dict | None

   Return the first dataset config in *group_dict* that has ``primary_id_field`` set.


.. py:function:: _compute_weights(indices: list[int], index_to_label: dict[int, Any], distribution: dict, num_classes: int) -> numpy.ndarray

   Compute per-sample WeightedRandomSampler weights.

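   For each sample :math:`i` in *indices*:

   .. math::

      w_i = \frac{\mathrm{target}_{\mathrm{class}(i)}}{\mathrm{count}_{\mathrm{class}(i)}}

   where :math:`\mathrm{target}_k` is the fraction assigned to class :math:`k` in
   *distribution* and :math:`\mathrm{count}_k` is the number of samples of class
   :math:`k`. The weights are raw, not normalised: ``WeightedRandomSampler``
   normalises internally, and raw values stay interpretable against *distribution*.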
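
   A minimal NumPy sketch of the weighting rule. It assumes *distribution* maps
   each class label to its target fraction, that counts are taken over *indices*,
   and that a class missing from *distribution* gets weight 0 (hyrax may instead
   reject such labels). ``compute_weights_sketch`` is a hypothetical name.

   .. code-block:: python

      from collections import Counter

      import numpy as np


      def compute_weights_sketch(indices, index_to_label, distribution):
          """Hypothetical: w_i = target[class(i)] / count[class(i)] over *indices*."""
          labels = [index_to_label[i] for i in indices]
          counts = Counter(labels)  # number of samples in *indices* per class
          # Raw (unnormalised) weights; a class absent from distribution gets 0.
          return np.array(
              [distribution.get(label, 0.0) / counts[label] for label in labels],
              dtype=np.float64,
          )

   Summing the weights within a class recovers that class's target fraction. For
   three ``"a"`` samples, one ``"b"`` sample and a 50/50 target, each ``"a"``
   sample gets 1/6 and the ``"b"`` sample gets 1/2. The resulting array can be
   passed as the ``weights`` argument of ``torch.utils.data.WeightedRandomSampler``.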


.. py:function:: validate_split_config(config: dict, datasets: dict[str, hyrax.datasets.data_provider.DataProvider]) -> None

   Validate ``[split]`` config values.

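   :raises RuntimeError: If any constraint is violated: fractions and file paths are
       mixed, a value falls outside its valid domain, fractions for splits sharing a
       location sum to more than 1.0, or split-file paths are not all in the same directory.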
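
   The checks above can be sketched for a flat ``{split_name: value}`` mapping.
   This is a simplified, hypothetical validator, not hyrax's code: it treats
   ``''`` and ``false`` as unset, assumes fractions must lie in [0, 1], and sums
   all fractions together rather than per shared location.

   .. code-block:: python

      from pathlib import Path


      def check_split_values(split: dict) -> None:
          """Hypothetical validator for a [split] table; raises RuntimeError."""
          # Assumption: '' and false mark a split as unset, so they are skipped.
          values = {k: v for k, v in split.items() if v != "" and v is not False}
          is_path = {k: isinstance(v, str) for k, v in values.items()}

          if any(is_path.values()) and not all(is_path.values()):
              raise RuntimeError("[split] mixes fractions and file paths")

          if values and all(is_path.values()):
              parents = {Path(v).parent for v in values.values()}
              if len(parents) > 1:
                  raise RuntimeError("split files must all live in the same directory")
              return

          for name, frac in values.items():
              # bool is a subclass of int, so exclude it explicitly.
              if isinstance(frac, bool) or not isinstance(frac, (int, float)):
                  raise RuntimeError(f"split.{name} must be a number, got {frac!r}")
              if not 0.0 <= frac <= 1.0:
                  raise RuntimeError(f"split.{name}={frac} is outside [0, 1]")

          # Small tolerance for floating-point rounding in the sum.
          if sum(values.values()) > 1.0 + 1e-9:
              raise RuntimeError("split fractions sum to more than 1.0")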


.. py:function:: validate_balance_config(config: dict, datasets: dict[str, hyrax.datasets.data_provider.DataProvider]) -> None

   Validate ``[balance]`` config values (pre-scan checks only).

   :raises RuntimeError: If the getter is missing, *distribution* is malformed, or the
       *distribution* values do not sum to 1.0.
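
   For the *distribution* part of these checks, a hypothetical helper might
   verify a non-empty mapping of labels to non-negative numbers whose sum is 1.0
   within a floating-point tolerance. The tolerance and the helper name
   ``check_distribution`` are assumptions; the exact check hyrax applies is not
   documented here.

   .. code-block:: python

      import math


      def check_distribution(distribution) -> None:
          """Hypothetical pre-scan check of a [balance] distribution table."""
          if not isinstance(distribution, dict) or not distribution:
              raise RuntimeError("balance.distribution must be a non-empty table")
          for label, frac in distribution.items():
              # Reject bools, non-numbers and negative fractions.
              if isinstance(frac, bool) or not isinstance(frac, (int, float)) or frac < 0:
                  raise RuntimeError(f"distribution[{label!r}] must be a non-negative number")
          total = sum(distribution.values())
          # Floating-point sums of fractions are rarely exact, so use a tolerance.
          if not math.isclose(total, 1.0, rel_tol=0.0, abs_tol=1e-9):
              raise RuntimeError(f"distribution sums to {total}, expected 1.0")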


.. py:function:: validate_distribution_labels(distribution: dict, observed_labels: set) -> None

   Cross-check distribution keys against the observed class labels (post-scan).

   :raises RuntimeError: If distribution contains a label absent from the dataset.
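
   The post-scan cross-check reduces to a set difference. In this hypothetical
   sketch, observed labels that *distribution* does not mention are allowed, since
   the docstring only rejects the opposite case. If *distribution* comes from a
   TOML config, its keys are always strings, so a real implementation may need to
   normalise label types before comparing; that step is omitted here.

   .. code-block:: python

      def check_labels(distribution: dict, observed_labels: set) -> None:
          """Hypothetical: every distribution key must appear in the data."""
          missing = set(distribution) - set(observed_labels)
          if missing:
              # key=repr keeps the message stable even for mixed label types.
              raise RuntimeError(
                  f"distribution labels not found in data: {sorted(missing, key=repr)}"
              )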


.. py:function:: _compute_splits(config: dict, datasets: dict[str, hyrax.datasets.data_provider.DataProvider]) -> dict[str, dict]

   Compute split indices (and optional balance weights) for each group.

   :returns: Mapping of group name → ``{"indexes": np.ndarray[int64], "weights": np.ndarray[float64] | None}``.
   :rtype: dict[str, dict]
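
   A minimal sketch of fraction-based splitting for a single group. It assumes
   the splits are consecutive, non-overlapping slices of one shuffled index array,
   each of size ``floor(fraction * N)``, and it leaves ``weights`` as ``None``.
   Path-based splits and balancing, which the real function also handles, are not
   shown, and ``split_by_fractions`` is a hypothetical name.

   .. code-block:: python

      import numpy as np


      def split_by_fractions(n: int, fractions: dict, seed: int) -> dict:
          """Hypothetical: carve a shuffled range(n) into disjoint slices."""
          order = np.random.default_rng(seed).permutation(n).astype(np.int64)
          result, start = {}, 0
          for name, frac in fractions.items():
              # Floor the size; with fractions summing to <= 1 the slices fit in n.
              size = int(frac * n)
              result[name] = {"indexes": order[start:start + size], "weights": None}
              start += size
          return result

   Because every slice comes from the same permutation, no index can land in two
   splits.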


.. py:function:: persist_splits(results_dir: pathlib.Path, splits: dict[str, dict], config: dict) -> None

   Write one ``<group>_split.npz`` per group and a ``split_config.toml``.

   The ``weights`` array is omitted entirely for unbalanced groups (``None``)
   to save space; ``load_split_files`` treats its absence as ``None``.
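
   The ``.npz`` round-trip, including the omitted ``weights`` entry, can be
   sketched with ``numpy.savez`` and ``numpy.load``. ``save_group`` and
   ``load_group`` are hypothetical names, and writing ``split_config.toml`` is not
   shown (Python's standard ``tomllib`` can read TOML but not write it).

   .. code-block:: python

      from pathlib import Path

      import numpy as np


      def save_group(results_dir, group: str, split: dict) -> Path:
          """Hypothetical writer for one <group>_split.npz file."""
          arrays = {"indexes": np.asarray(split["indexes"], dtype=np.int64)}
          # Omit the weights entry entirely for unbalanced groups.
          if split.get("weights") is not None:
              arrays["weights"] = np.asarray(split["weights"], dtype=np.float64)
          path = Path(results_dir) / f"{group}_split.npz"
          np.savez(path, **arrays)
          return path


      def load_group(path) -> dict:
          """Hypothetical reader: a missing weights entry becomes None."""
          with np.load(path) as data:
              # Indexing an NpzFile reads the whole array, so it outlives the file.
              weights = data["weights"] if "weights" in data.files else None
              return {"indexes": data["indexes"], "weights": weights}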


.. py:function:: load_split_files(paths: dict[str, pathlib.Path]) -> dict[str, dict]

   Load previously persisted split files.

   :param paths: Mapping of group name → path to ``<group>_split.npz``.

   :returns: Mapping of group name → ``{"indexes": ndarray, "weights": ndarray | None}``.
   :rtype: dict[str, dict]


.. py:function:: assign_splits_to_providers(datasets: dict[str, hyrax.datasets.data_provider.DataProvider], splits: dict[str, dict]) -> None

   Attach split indices and weights onto each provider in *datasets*.
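
   Assuming providers expose plain writable attributes named ``split_indices``
   and ``split_weights`` (the names given in the ``create_splits`` docstring),
   attaching results might look like this sketch. ``SimpleNamespace`` stands in
   for a real ``DataProvider``, ``attach_splits`` is a hypothetical name, and
   leaving groups without a split untouched is an assumption.

   .. code-block:: python

      from types import SimpleNamespace

      import numpy as np


      def attach_splits(datasets: dict, splits: dict) -> None:
          """Hypothetical: copy each group's indexes/weights onto its provider."""
          for group, provider in datasets.items():
              if group not in splits:
                  continue  # assumption: groups without a split are left as-is
              provider.split_indices = splits[group]["indexes"]
              provider.split_weights = splits[group].get("weights")


      # Usage with stand-in providers.
      providers = {"train": SimpleNamespace(), "infer": SimpleNamespace()}
      attach_splits(providers, {"train": {"indexes": np.array([2, 0, 1]), "weights": None}})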


.. py:function:: configs_equivalent(prev: dict, cur: dict) -> tuple[bool, list[str]]

   Check whether *prev* config would produce the same splits as *cur*.

   :returns: A tuple ``(equivalent, diffs)``. *equivalent* is ``True`` only when all
             compared fields match; *diffs* is a human-readable list of differences
             (empty when equivalent).
   :rtype: tuple[bool, list[str]]
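
   The docstring does not say which fields are compared, so this hypothetical
   sketch assumes the ``[split]`` and ``[balance]`` tables are compared key by
   key. It returns the same ``(equivalent, diffs)`` shape, with one readable line
   per differing key.

   .. code-block:: python

      def compare_configs(prev: dict, cur: dict, sections=("split", "balance")) -> tuple:
          """Hypothetical: diff selected config tables key by key."""
          diffs = []
          for section in sections:
              old, new = prev.get(section, {}), cur.get(section, {})
              # Union of keys so that added and removed keys are reported too.
              for key in sorted(set(old) | set(new)):
                  if old.get(key) != new.get(key):
                      diffs.append(f"{section}.{key}: {old.get(key)!r} -> {new.get(key)!r}")
          return (not diffs, diffs)

   For example, changing ``train`` from 0.6 to 0.7 yields
   ``(False, ["split.train: 0.6 -> 0.7"])``.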


.. py:function:: find_equivalent_split(config: dict, results_root: pathlib.Path | None = None) -> dict[str, pathlib.Path] | None

   Scan the results directory for a previously persisted equivalent split.

   Returns the group→npz path mapping of the first match, or ``None``.
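
   A sketch of the scan, assuming each run lives in its own subdirectory of the
   results root containing ``split_config.toml`` and ``<group>_split.npz`` files.
   The directory layout, the injected ``equivalent`` predicate and the name
   ``find_matching_run`` are all assumptions; the sketch uses ``tomllib``, which
   requires Python 3.11 or later.

   .. code-block:: python

      import tomllib
      from pathlib import Path
      from typing import Callable, Optional


      def find_matching_run(
          config: dict,
          results_root: Path,
          equivalent: Callable[[dict, dict], bool],
      ) -> Optional[dict]:
          """Hypothetical: return {group: npz path} for the first equivalent run."""
          # Sorted so that "first match" is deterministic across filesystems.
          for run_dir in sorted(p for p in Path(results_root).iterdir() if p.is_dir()):
              cfg_file = run_dir / "split_config.toml"
              if not cfg_file.is_file():
                  continue
              with cfg_file.open("rb") as fh:  # tomllib requires a binary file
                  previous = tomllib.load(fh)
              if equivalent(previous, config):
                  return {
                      npz.name.removesuffix("_split.npz"): npz
                      for npz in sorted(run_dir.glob("*_split.npz"))
                  }
          return None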


.. py:function:: create_splits(config: dict, datasets: dict[str, hyrax.datasets.data_provider.DataProvider], *, results_dir: pathlib.Path | None = None, persist: bool = True) -> dict[str, dict]

   Compute (or load) splits and weights for each data group.

   Assigns ``split_indices`` / ``split_weights`` on each provider via
   :func:`assign_splits_to_providers`.

   :returns: Mapping of group name → ``{"indexes": ndarray[int64], "weights": ndarray[float64] | None}``.
   :rtype: dict[str, dict]
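
   The compute-or-load flow can be sketched with injected callables, which keeps
   the control flow visible without hyrax internals. Every name here is
   hypothetical, and the policy shown (persist only freshly computed splits, always
   assign the result to providers) is an assumption rather than documented
   behaviour.

   .. code-block:: python

      from typing import Callable, Optional


      def load_or_compute(
          find: Callable[[], Optional[dict]],
          load: Callable[[dict], dict],
          compute: Callable[[], dict],
          save: Callable[[dict], None],
          assign: Callable[[dict], None],
          persist: bool = True,
      ) -> dict:
          """Hypothetical orchestration: reuse an equivalent split if one exists."""
          previous = find()
          if previous is not None:
              splits = load(previous)  # reuse a previously persisted split
          else:
              splits = compute()
              if persist:
                  save(splits)  # assumption: only new splits are written
          assign(splits)  # providers receive the result either way
          return splits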


