hyrax.splitting_utils

Splitting and dataset balancing utilities for Hyrax datasets.

Attributes

logger

Functions

_is_path_value(→ bool)

Return True when val is a non-empty string (i.e. a file path, not a fraction).

_resolve_seed(→ int | None)

Return the effective RNG seed, resolving '' or false to config['data_set']['seed'].

_shuffle(→ None)

Shuffle indices in-place using the configured RNG.

_primary_instance(→ Any)

Return the primary dataset instance from provider.

_find_primary_cfg(→ dict | None)

Return the first dataset config in group_dict that has primary_id_field set.

_compute_weights(→ numpy.ndarray)

Compute per-sample WeightedRandomSampler weights.

validate_split_config(→ None)

Validate [split] config values.

validate_balance_config(→ None)

Validate [balance] config values (pre-scan checks only).

validate_distribution_labels(→ None)

Cross-check distribution keys against the observed class labels (post-scan).

_compute_splits(→ dict[str, dict])

Compute split indices (and optional balance weights) for each group.

persist_splits(→ None)

Write one <group>_split.npz per group and a split_config.toml.

load_split_files(→ dict[str, dict])

Load previously persisted split files.

assign_splits_to_providers(→ None)

Attach split indices and weights onto each provider in datasets.

configs_equivalent(→ tuple[bool, list[str]])

Check whether prev config would produce the same splits as cur.

find_equivalent_split(→ dict[str, pathlib.Path] | None)

Scan the results directory for a previously persisted equivalent split.

create_splits(→ dict[str, dict])

Compute (or load) splits and weights for each data group.

Module Contents

logger
_is_path_value(val: Any) → bool

Return True when val is a non-empty string (i.e. a file path, not a fraction).

_resolve_seed(config: dict) → int | None

Return the effective RNG seed, resolving '' or false to config['data_set']['seed'].
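A minimal sketch of the seed-resolution rule, assuming the split-specific override lives at config['split']['rng_seed'] (the split.rng_seed key named in the _shuffle description) and that the fallback is config['data_set']['seed']. The function name resolve_seed_sketch is hypothetical; this is an illustration, not the module's implementation.

```python
from typing import Optional


def resolve_seed_sketch(config: dict) -> Optional[int]:
    """Hypothetical stand-in for _resolve_seed.

    Assumption: the split seed is stored at config["split"]["rng_seed"].
    An empty string or False means "not set", so the dataset-wide seed at
    config["data_set"]["seed"] is used instead.
    """
    rng_seed = config.get("split", {}).get("rng_seed", "")
    # Use `is False` rather than `== False` so that a seed of 0 is kept.
    if rng_seed == "" or rng_seed is False:
        return config["data_set"]["seed"]
    return rng_seed
```

The identity check matters because 0 == False in Python; a naive equality test would silently discard an explicit seed of 0.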

_shuffle(indices: list[int], config: dict) → None

Shuffle indices in-place using the configured RNG.

When split.rng_seed is empty, this reproduces the legacy global-seed shuffle used by create_splits_from_fractions bit-for-bit.
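The general idea of a seeded in-place shuffle can be sketched with NumPy's Generator API. This assumes a NumPy generator is used at all, which the description does not state, and it deliberately does not attempt to reproduce the legacy global-seed behaviour; shuffle_sketch is a hypothetical name.

```python
from typing import List, Optional

import numpy as np


def shuffle_sketch(indices: List[int], seed: Optional[int]) -> None:
    """Shuffle ``indices`` in place with a generator seeded by ``seed``.

    Hypothetical illustration only; a seed of None gives a fresh,
    non-reproducible ordering.
    """
    rng = np.random.default_rng(seed)
    # Generator.shuffle permutes a mutable sequence (such as a list) in place.
    rng.shuffle(indices)
```

Two calls with the same seed produce the same permutation, which is what makes persisted splits reproducible.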

_primary_instance(provider: hyrax.datasets.data_provider.DataProvider) → Any

Return the primary dataset instance from provider.

_find_primary_cfg(group_dict: dict) → dict | None

Return the first dataset config in group_dict that has primary_id_field set.
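A sketch of the lookup, assuming group_dict maps dataset names to per-dataset config dicts and that "first" means dict insertion order. The name find_primary_cfg_sketch and the exact dict layout are assumptions for illustration.

```python
from typing import Optional


def find_primary_cfg_sketch(group_dict: dict) -> Optional[dict]:
    """Return the first dataset config whose primary_id_field is set.

    Assumption: group_dict maps dataset names to config dicts, and an
    empty or missing primary_id_field counts as "not set".
    """
    for dataset_cfg in group_dict.values():
        if isinstance(dataset_cfg, dict) and dataset_cfg.get("primary_id_field"):
            return dataset_cfg
    return None
```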

_compute_weights(indices: list[int], index_to_label: dict[int, Any], distribution: dict, num_classes: int) → numpy.ndarray

Compute per-sample WeightedRandomSampler weights.

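w_i = target_{class(i)} / count_{class(i)}, where target is the class's entry in distribution and count is the number of samples of that class. The weights are raw, not normalised: WeightedRandomSampler (WRS) normalises internally, and raw values stay interpretable against distribution.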
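The weighting rule above can be sketched as follows, assuming count is taken over the given indices and that a class absent from distribution gets weight 0. The sketch omits num_classes because its role is not described here; compute_weights_sketch is a hypothetical name.

```python
from collections import Counter
from typing import Any, Dict, List

import numpy as np


def compute_weights_sketch(
    indices: List[int],
    index_to_label: Dict[int, Any],
    distribution: Dict[Any, float],
) -> np.ndarray:
    """Per-sample weights w_i = target[class(i)] / count[class(i)].

    count is the number of samples of each class among ``indices``.
    Assumption: a class missing from ``distribution`` gets weight 0.
    The result is left unnormalised.
    """
    labels = [index_to_label[i] for i in indices]
    counts = Counter(labels)
    return np.array(
        [distribution.get(label, 0.0) / counts[label] for label in labels],
        dtype=np.float64,
    )
```

Summed over a class, the weights equal that class's target fraction, so once PyTorch's WeightedRandomSampler divides by the total, each class is drawn in proportion to its entry in distribution regardless of how imbalanced the raw counts are.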

validate_split_config(config: dict, datasets: dict[str, hyrax.datasets.data_provider.DataProvider]) → None

Validate [split] config values.

Raises:

RuntimeError – On any violated constraint: mixed float/path values, a value outside its allowed domain, a shared-location sum > 1.0, or paths not in the same directory.
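A simplified sketch of several of these checks over a flat mapping of split name to value. It assumes '' means "unset", a non-empty string is a file path, and anything else is a fraction; it also treats all fractions as drawing from one shared location, which is a simplification of the shared-location rule. check_split_values_sketch is hypothetical and covers only a subset of what the real validator checks.

```python
from pathlib import Path
from typing import Dict, Union


def check_split_values_sketch(split_values: Dict[str, Union[float, str]]) -> None:
    """Hypothetical subset of the [split] checks, not the real validator.

    Raises RuntimeError, the exception type documented above.
    """
    values = [v for v in split_values.values() if v != ""]  # assumption: '' is unset
    is_path = [isinstance(v, str) for v in values]
    if any(is_path) and not all(is_path):
        raise RuntimeError("[split] mixes fractions and file paths")
    if values and all(is_path):
        # Persisted split files are expected to live in one directory.
        if len({Path(v).parent for v in values}) > 1:
            raise RuntimeError("[split] paths are not in the same directory")
        return
    fractions = [float(v) for v in values]
    if any(not 0.0 <= f <= 1.0 for f in fractions):
        raise RuntimeError("[split] fractions must lie in [0, 1]")
    # Small tolerance so float rounding alone does not trigger a failure.
    if sum(fractions) > 1.0 + 1e-9:
        raise RuntimeError("[split] fractions sum to more than 1.0")
```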

validate_balance_config(config: dict, datasets: dict[str, hyrax.datasets.data_provider.DataProvider]) → None

Validate [balance] config values (pre-scan checks only).

Raises:

RuntimeError – If getter is missing, distribution is malformed, or distribution sum ≠ 1.0.
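The distribution-shape checks can be sketched as below, assuming distribution is a non-empty mapping of label to non-negative fraction. It uses math.isclose with an absolute tolerance because float sums such as 0.1 + 0.2 + 0.7 may not land exactly on 1.0. The getter check is omitted, and check_distribution_sketch is a hypothetical name.

```python
import math
from typing import Any, Dict


def check_distribution_sketch(distribution: Dict[Any, float]) -> None:
    """Hypothetical version of the distribution checks.

    Requires a non-empty mapping of label -> non-negative number whose
    values sum to 1.0 within a small absolute tolerance.
    """
    if not isinstance(distribution, dict) or not distribution:
        raise RuntimeError("distribution must be a non-empty table of label = fraction")
    for label, frac in distribution.items():
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(frac, bool) or not isinstance(frac, (int, float)) or frac < 0:
            raise RuntimeError(f"distribution[{label!r}] must be a non-negative number")
    total = sum(distribution.values())
    if not math.isclose(total, 1.0, rel_tol=0.0, abs_tol=1e-9):
        raise RuntimeError(f"distribution sums to {total}, expected 1.0")
```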

validate_distribution_labels(distribution: dict, observed_labels: set) → None

Cross-check distribution keys against the observed class labels (post-scan).

Raises:

RuntimeError – If distribution contains a label absent from the dataset.
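The post-scan cross-check amounts to a set difference between the distribution keys and the labels seen in the data. A sketch, with check_labels_sketch as a hypothetical name. One caveat worth keeping in mind: TOML table keys are always strings, so comparing them against integer class labels may need a conversion on one side; how the real function handles that is not stated here.

```python
from typing import Any, Dict, Set


def check_labels_sketch(distribution: Dict[Any, float], observed_labels: Set[Any]) -> None:
    """Raise RuntimeError if distribution names a label never seen in the data."""
    missing = set(distribution) - set(observed_labels)
    if missing:
        # repr() keeps mixed label types sortable and readable in the message.
        names = sorted(repr(label) for label in missing)
        raise RuntimeError(f"distribution has labels absent from the dataset: {names}")
```

Labels that appear in the data but not in distribution pass this check; only the reverse direction is an error.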

_compute_splits(config: dict, datasets: dict[str, hyrax.datasets.data_provider.DataProvider]) → dict[str, dict]

Compute split indices (and optional balance weights) for each group.

Returns:

dict mapping group_name → {"indexes": np.ndarray[int64], "weights": np.ndarray[float64] | None}

Return type:

dict[str, dict]

persist_splits(results_dir: pathlib.Path, splits: dict[str, dict], config: dict) → None

Write one <group>_split.npz per group and a split_config.toml.

For unbalanced groups (weights is None), the weights array is omitted from the .npz file entirely to save space; load_split_files treats its absence as None.
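A sketch of the save/load round trip for one group's <group>_split.npz, using numpy.savez and numpy.load. The array names "indexes" and "weights" inside the archive are assumed to match the keys of the returned dicts, and writing split_config.toml is left out; save_split_sketch and load_split_sketch are hypothetical names.

```python
from pathlib import Path
from typing import Dict, Optional

import numpy as np


def save_split_sketch(path: Path, indexes, weights=None) -> None:
    """Write indexes (and weights, only if present) to a .npz archive."""
    arrays = {"indexes": np.asarray(indexes, dtype=np.int64)}
    if weights is not None:
        arrays["weights"] = np.asarray(weights, dtype=np.float64)
    np.savez(path, **arrays)


def load_split_sketch(path: Path) -> Dict[str, Optional[np.ndarray]]:
    """Read a split archive, mapping a missing weights array to None."""
    with np.load(path) as data:
        # Indexing an NpzFile reads the array into memory, so the results
        # remain valid after the file is closed.
        return {
            "indexes": data["indexes"],
            "weights": data["weights"] if "weights" in data.files else None,
        }
```

Omitting the key, rather than storing an empty array, keeps "unbalanced" distinguishable from "balanced with zero samples".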

load_split_files(paths: dict[str, pathlib.Path]) → dict[str, dict]

Load previously persisted split files.

Parameters:

paths – Mapping of group name → path to <group>_split.npz.

Returns:

dict mapping group_name → {"indexes": ndarray, "weights": ndarray | None}

Return type:

dict[str, dict]

assign_splits_to_providers(datasets: dict[str, hyrax.datasets.data_provider.DataProvider], splits: dict[str, dict]) → None

Attach split indices and weights onto each provider in datasets.
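A sketch of the attachment step, using the split_indices and split_weights attribute names given in the create_splits description. The FakeProvider class is a hypothetical stand-in for DataProvider, and skipping groups with no computed split is an assumption.

```python
from typing import Dict


class FakeProvider:
    """Hypothetical stand-in for DataProvider, used only to show the assignment."""

    split_indices = None
    split_weights = None


def assign_splits_sketch(datasets: Dict[str, object], splits: Dict[str, dict]) -> None:
    for group_name, provider in datasets.items():
        split = splits.get(group_name)
        if split is None:
            continue  # assumption: groups without a split are left untouched
        provider.split_indices = split["indexes"]
        provider.split_weights = split["weights"]
```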

configs_equivalent(prev: dict, cur: dict) → tuple[bool, list[str]]

Check whether prev config would produce the same splits as cur.

Returns:

(equivalent, diffs), where equivalent is True only when all compared fields match and diffs is a human-readable list of differences (empty when equivalent).

Return type:

tuple[bool, list[str]]
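The comparison pattern can be sketched as a walk over a fixed list of nested config keys, collecting one readable line per mismatch. Which fields the real function compares is not documented here, so COMPARED_KEYS below is a hypothetical choice, as is the name configs_equivalent_sketch.

```python
from typing import Any, List, Tuple


class _Missing:
    """Sentinel for a key absent from a config; prints as <missing>."""

    def __repr__(self) -> str:
        return "<missing>"


MISSING = _Missing()

# Hypothetical choice of split-relevant fields.
COMPARED_KEYS: List[Tuple[str, ...]] = [("split",), ("balance",), ("data_set", "seed")]


def _lookup(cfg: dict, path: Tuple[str, ...]) -> Any:
    """Walk nested dicts along path, returning MISSING if any key is absent."""
    node: Any = cfg
    for key in path:
        if not isinstance(node, dict) or key not in node:
            return MISSING
        node = node[key]
    return node


def configs_equivalent_sketch(prev: dict, cur: dict) -> Tuple[bool, List[str]]:
    diffs: List[str] = []
    for path in COMPARED_KEYS:
        old, new = _lookup(prev, path), _lookup(cur, path)
        if old != new:
            diffs.append(f"{'.'.join(path)}: {old!r} -> {new!r}")
    return (not diffs, diffs)
```

Fields outside COMPARED_KEYS (for example training settings) are ignored, which is what lets an unrelated config change reuse an existing split.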

find_equivalent_split(config: dict, results_root: pathlib.Path | None = None) → dict[str, pathlib.Path] | None

Scan the results directory for a previously persisted equivalent split.

Returns the group→npz path mapping of the first match, or None.

create_splits(config: dict, datasets: dict[str, hyrax.datasets.data_provider.DataProvider], *, results_dir: pathlib.Path | None = None, persist: bool = True) → dict[str, dict]

Compute (or load) splits and weights for each data group.

Assigns split_indices / split_weights on each provider via assign_splits_to_providers().

Returns:

dict mapping group_name → {"indexes": ndarray[int64], "weights": ndarray[float64] | None}

Return type:

dict[str, dict]