hyrax.splitting_utils
Splitting and dataset balancing utilities for Hyrax datasets.
Attributes
Functions
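_is_path_value(val) | Return True when val is a non-empty string (i.e. a file path, not a fraction).
_resolve_seed(config) | Return the effective RNG seed, resolving '' or false to config['data_set']['seed'].
_shuffle(indices, config) | Shuffle indices in-place using the configured RNG.
_primary_instance(provider) | Return the primary dataset instance from provider.
_find_primary_cfg(group_dict) | Return the first dataset config in group_dict that has primary_id_field set.
_compute_weights(indices, index_to_label, distribution, num_classes) | Compute per-sample WeightedRandomSampler weights.
validate_split_config(config, datasets) | Validate [split] config values.
validate_balance_config(config, datasets) | Validate [balance] config values (pre-scan checks only).
validate_distribution_labels(distribution, observed_labels) | Cross-check distribution keys against the observed class labels (post-scan).
_compute_splits(config, datasets) | Compute split indices (and optional balance weights) for each group.
persist_splits(results_dir, splits, config) | Write one <group>_split.npz per group and a split_config.toml.
load_split_files(paths) | Load previously persisted split files.
assign_splits_to_providers(datasets, splits) | Attach split indices and weights onto each provider in datasets.
configs_equivalent(prev, cur) | Check whether prev config would produce the same splits as cur.
find_equivalent_split(config, results_root) | Scan the results directory for a previously persisted equivalent split.
create_splits(config, datasets, *, results_dir, persist) | Compute (or load) splits and weights for each data group.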
Module Contents
- _is_path_value(val: Any) → bool
Return True when val is a non-empty string (i.e. a file path, not a fraction).
- _resolve_seed(config: dict) → int | None
Return the effective RNG seed, resolving '' or false to config['data_set']['seed'].
- _shuffle(indices: list[int], config: dict) → None
Shuffle indices in-place using the configured RNG.
When split.rng_seed is empty, reproduces the legacy global-seed shuffle used by create_splits_from_fractions bit-for-bit.
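The seed fallback and in-place shuffle described above can be sketched as follows. This is a minimal illustration, not the module's code: the function names `resolve_seed_sketch` and `shuffle_sketch` are hypothetical, the config layout (`config['split']['rng_seed']`) is inferred from the docstrings, and the standard-library `random` module stands in for whatever RNG Hyrax actually uses, so it does not reproduce the legacy shuffle bit-for-bit.

```python
import random
from typing import Any, Optional


def _is_unset(value: Any) -> bool:
    # '' and False mean "not configured"; 0 is a valid seed and must be kept.
    return value is None or value is False or value == ""


def resolve_seed_sketch(config: dict) -> Optional[int]:
    """Hypothetical stand-in for _resolve_seed (config layout is assumed)."""
    seed = config.get("split", {}).get("rng_seed", "")
    if _is_unset(seed):
        # Fall back to the dataset-wide seed, as the docstring describes.
        seed = config.get("data_set", {}).get("seed")
    return None if _is_unset(seed) else seed


def shuffle_sketch(indices: list, config: dict) -> None:
    """Shuffle indices in place with an RNG seeded from the config."""
    # random.Random(None) seeds from OS entropy, i.e. a non-reproducible shuffle.
    random.Random(resolve_seed_sketch(config)).shuffle(indices)
```

With the same config, two calls on equal lists yield the same permutation, which is the property a persisted split relies on.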
- _primary_instance(provider: hyrax.datasets.data_provider.DataProvider) → Any
Return the primary dataset instance from provider.
- _find_primary_cfg(group_dict: dict) → dict | None
Return the first dataset config in group_dict that has primary_id_field set.
- _compute_weights(indices: list[int], index_to_label: dict[int, Any], distribution: dict, num_classes: int) → numpy.ndarray
Compute per-sample WeightedRandomSampler weights.
w_i = target_{class(i)} / count_{class(i)}. The weights are raw, not normalised: WeightedRandomSampler normalises internally, and raw values stay interpretable against distribution.
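As a worked instance of the formula above, the sketch below computes w_i = target / count per sample. The name `compute_weights_sketch` is hypothetical, the `num_classes` argument of the real function is left out because its role is not documented here, and giving weight 0.0 to labels missing from `distribution` is an assumption.

```python
from collections import Counter
from typing import Any

import numpy as np


def compute_weights_sketch(
    indices: list,
    index_to_label: dict,
    distribution: dict,
) -> np.ndarray:
    """Raw per-sample weights: target fraction of the class / class count."""
    labels: list[Any] = [index_to_label[i] for i in indices]
    counts = Counter(labels)  # count_{class(i)} within this split
    # Assumption: a label absent from `distribution` gets weight 0.0.
    return np.array(
        [distribution.get(label, 0.0) / counts[label] for label in labels],
        dtype=np.float64,
    )


weights = compute_weights_sketch(
    indices=[0, 1, 2, 3],
    index_to_label={0: "a", 1: "a", 2: "a", 3: "b"},
    distribution={"a": 0.5, "b": 0.5},
)
```

Here the three "a" samples each get 0.5 / 3 and the single "b" sample gets 0.5, so the weights of each class sum to that class's target fraction.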
- validate_split_config(config: dict, datasets: dict[str, hyrax.datasets.data_provider.DataProvider]) → None
Validate [split] config values.
- Raises:
RuntimeError – On any violated constraint (mixed float/path, bad domain, shared-location sum > 1.0, paths not in same directory).
- validate_balance_config(config: dict, datasets: dict[str, hyrax.datasets.data_provider.DataProvider]) → None
Validate [balance] config values (pre-scan checks only).
- Raises:
RuntimeError – If getter is missing, distribution is malformed, or distribution sum ≠ 1.0.
- validate_distribution_labels(distribution: dict, observed_labels: set) → None
Cross-check distribution keys against the observed class labels (post-scan).
- Raises:
RuntimeError – If distribution contains a label absent from the dataset.
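The two distribution checks documented above (sum equal to 1.0 before the scan, labels present in the data after it) might look roughly like this. It is a sketch under assumptions: `check_distribution_sketch` is a hypothetical name that merges both checks into one function for brevity, and the floating-point tolerance is not taken from Hyrax.

```python
import math


def check_distribution_sketch(distribution: dict, observed_labels: set) -> None:
    """Raise RuntimeError if the target distribution is unusable."""
    if not distribution:
        raise RuntimeError("distribution is empty")

    # Pre-scan style check: target fractions must sum to 1.0.
    # Assumption: a small absolute tolerance absorbs float rounding.
    total = sum(distribution.values())
    if not math.isclose(total, 1.0, abs_tol=1e-9):
        raise RuntimeError(f"distribution sums to {total}, expected 1.0")

    # Post-scan style check: every requested label must occur in the data.
    missing = set(distribution) - set(observed_labels)
    if missing:
        names = ", ".join(sorted(map(str, missing)))
        raise RuntimeError(f"distribution labels not found in dataset: {names}")
```

Observed labels that are absent from `distribution` are accepted here; only the reverse direction is treated as an error, matching the documented Raises clause.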
- _compute_splits(config: dict, datasets: dict[str, hyrax.datasets.data_provider.DataProvider]) → dict[str, dict]
Compute split indices (and optional balance weights) for each group.
- Returns:
dict mapping group_name → {"indexes": np.ndarray[int64], "weights": np.ndarray[float64] | None}
- persist_splits(results_dir: pathlib.Path, splits: dict[str, dict], config: dict) → None
Write one <group>_split.npz per group and a split_config.toml.
The weights array is omitted entirely for unbalanced groups (None) to save space; load_split_files treats its absence as None.
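The omit-on-save, None-on-load convention for weights can be illustrated with a small NumPy round trip. This is a sketch, not the Hyrax implementation: `save_split_sketch` and `load_split_sketch` are hypothetical names, and the on-disk array keys `indexes` and `weights` are assumed to mirror the keys of the returned dicts.

```python
import tempfile
from pathlib import Path
from typing import Optional, Sequence

import numpy as np


def save_split_sketch(
    path: Path, indexes: Sequence[int], weights: Optional[Sequence[float]] = None
) -> None:
    """Write one split file; skip the weights array when the group is unbalanced."""
    arrays = {"indexes": np.asarray(indexes, dtype=np.int64)}
    if weights is not None:
        arrays["weights"] = np.asarray(weights, dtype=np.float64)
    np.savez(path, **arrays)


def load_split_sketch(path: Path) -> dict:
    """Read a split file, mapping a missing weights array back to None."""
    with np.load(path) as data:
        weights = data["weights"] if "weights" in data.files else None
        return {"indexes": data["indexes"], "weights": weights}


with tempfile.TemporaryDirectory() as tmp:
    unbalanced = Path(tmp) / "train_split.npz"
    save_split_sketch(unbalanced, [3, 1, 2])
    loaded_unbalanced = load_split_sketch(unbalanced)

    balanced = Path(tmp) / "validate_split.npz"
    save_split_sketch(balanced, [0, 1], [0.25, 0.75])
    loaded_balanced = load_split_sketch(balanced)
```

Checking `data.files` rather than catching a `KeyError` keeps the "absent means None" rule explicit at the point where the file is read.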
- load_split_files(paths: dict[str, pathlib.Path]) → dict[str, dict]
Load previously persisted split files.
- Parameters:
paths – Mapping of group name → path to <group>_split.npz.
- Returns:
dict mapping group_name → {"indexes": ndarray, "weights": ndarray | None}
- assign_splits_to_providers(datasets: dict[str, hyrax.datasets.data_provider.DataProvider], splits: dict[str, dict]) → None
Attach split indices and weights onto each provider in datasets.
- configs_equivalent(prev: dict, cur: dict) → tuple[bool, list[str]]
Check whether prev config would produce the same splits as cur.
- Returns:
(equivalent, diffs): equivalent is True only when all compared fields match; diffs is a human-readable list of differences (empty when equivalent).
- Return type:
tuple[bool, list[str]]
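The (equivalent, diffs) return shape can be sketched with a field-by-field comparison. Which fields Hyrax actually compares is not stated here, so the `sections` default below is an assumption, and `configs_equivalent_sketch` is a hypothetical name.

```python
def configs_equivalent_sketch(
    prev: dict, cur: dict, sections: tuple = ("split", "balance")
) -> tuple:
    """Compare selected config sections and describe every mismatch."""
    diffs = []
    for section in sections:  # Assumption: only these sections affect splits.
        old, new = prev.get(section, {}), cur.get(section, {})
        for field in sorted(set(old) | set(new)):
            if old.get(field) != new.get(field):
                diffs.append(
                    f"[{section}] {field}: {old.get(field)!r} != {new.get(field)!r}"
                )
    # Equivalent exactly when no difference was recorded.
    return (not diffs, diffs)


same, no_diffs = configs_equivalent_sketch(
    {"split": {"train": 0.8}}, {"split": {"train": 0.8}}
)
changed, why = configs_equivalent_sketch(
    {"split": {"train": 0.8}}, {"split": {"train": 0.7}}
)
```

Returning the list of differences alongside the boolean lets a caller log why a previously persisted split was not reused.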
- find_equivalent_split(config: dict, results_root: pathlib.Path | None = None) → dict[str, pathlib.Path] | None
Scan the results directory for a previously persisted equivalent split.
Returns the group→npz path mapping of the first match, or None.
- create_splits(config: dict, datasets: dict[str, hyrax.datasets.data_provider.DataProvider], *, results_dir: pathlib.Path | None = None, persist: bool = True) → dict[str, dict]
Compute (or load) splits and weights for each data group.
Assigns split_indices / split_weights on each provider via assign_splits_to_providers().
- Returns:
dict mapping group_name → {"indexes": ndarray[int64], "weights": ndarray[float64] | None}