Source code for hyrax.splitting_utils

"""Splitting and dataset balancing utilities for Hyrax datasets."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
import tomlkit

if TYPE_CHECKING:
    from hyrax.datasets.data_provider import DataProvider

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# ── Private helpers ────────────────────────────────────────────────────────────
[docs] def _is_path_value(val: Any) -> bool: """Return True when val is a non-empty string (i.e. a file path, not a fraction).""" return isinstance(val, str) and bool(val)
[docs] def _resolve_seed(config: dict) -> int | None: """Return the effective RNG seed, resolving '' or false to config['data_set']['seed'].""" rng_seed = config["split"]["rng_seed"] if config["split"]["rng_seed"] else None if not rng_seed: raw = config.get("data_set", {}).get("seed") return raw if raw else None return rng_seed
[docs] def _shuffle(indices: list[int], config: dict) -> None: """Shuffle *indices* in-place using the configured RNG. When ``split.rng_seed`` is empty, reproduces the legacy global-seed shuffle used by ``create_splits_from_fractions`` bit-for-bit. """ rng_seed = config["split"].get("rng_seed") if isinstance(rng_seed, str): raise RuntimeError( f"split.rng_seed must be an integer (or false to use data_set.seed); got {rng_seed!r}." ) if rng_seed in ("", None, False): seed = config["data_set"]["seed"] if config["data_set"]["seed"] else None np.random.seed(seed) np.random.shuffle(indices) else: np.random.default_rng(int(rng_seed)).shuffle(indices)
[docs] def _primary_instance(provider: DataProvider) -> Any: """Return the primary dataset instance from *provider*.""" return provider.prepped_datasets[provider.primary_dataset]
[docs] def _find_primary_cfg(group_dict: dict) -> dict | None: """Return the first dataset config in *group_dict* that has primary_id_field set.""" for cfg in group_dict.values(): if isinstance(cfg, dict) and cfg.get("primary_id_field"): return cfg return None
[docs] def _compute_weights( indices: list[int], index_to_label: dict[int, Any], distribution: dict, num_classes: int, ) -> np.ndarray: """Compute per-sample WeightedRandomSampler weights. w_i = target_{class(i)} / count_{class(i)} (raw, not normalised — WRS normalises internally, and raw values stay interpretable against distribution). """ count_c: dict[Any, int] = {} for idx in indices: label = index_to_label[idx] count_c[label] = count_c.get(label, 0) + 1 if distribution: target_c = {label: float(distribution.get(label, 0)) for label in count_c} else: uniform = 1.0 / num_classes target_c = {label: uniform for label in count_c} # Use Float32 for compatibility with mps. # # Note that the minimum representable positive number in Float32 is about 10^-38 # (without using subnormals). # # The maximum int64 is about 2*10^19, and The maximum int128 is 3*10^38. # # Even if there is only 1 member of a class, we are probably fine with Float32 unless you need an # int128 to describe your indexes. weights = np.array( [target_c.get(index_to_label[idx], 0) / count_c[index_to_label[idx]] for idx in indices], dtype=np.float32, ) return weights
# ── Validation ─────────────────────────────────────────────────────────────────
def validate_split_config(config: dict, datasets: dict[str, DataProvider]) -> None:
    """Validate ``[split]`` config values.

    Each group in *datasets* takes its value from ``split.<group>`` (default
    ``1.0``). The values must be either all file paths (non-empty strings) or
    all fractions. ``''``, ``false`` and ``None`` count as a fraction of 1.0.

    Raises
    ------
    RuntimeError
        On any violated constraint:

        * float and path values are mixed;
        * a fraction is non-numeric or outside (0.0, 1.0];
        * fractions sharing a ``primary_data_location`` sum to more than 1.0
          (``infer`` is excluded from this sum);
        * split paths are not in the same directory, or a path does not exist.
    """
    split_cfg = config["split"]

    def _fraction(name: str) -> float:
        # A missing, empty or false value means "use the full dataset".
        raw = split_cfg.get(name, 1.0)
        if raw in ("", None, False):
            return 1.0
        try:
            return float(raw)
        except (TypeError, ValueError) as err:
            raise RuntimeError(f"split.{name} = {raw!r} is not a valid fraction.") from err

    group_values = {name: split_cfg.get(name, 1.0) for name in datasets}
    path_groups = [n for n, v in group_values.items() if _is_path_value(v)]
    float_groups = [n for n in group_values if n not in path_groups]
    if path_groups and float_groups:
        raise RuntimeError(
            "split values must be all floats or all paths, not mixed. "
            f"Float groups: {float_groups}; path groups: {path_groups}."
        )

    if path_groups:
        parents = {Path(str(split_cfg[name])).parent for name in path_groups}
        if len(parents) > 1:
            raise RuntimeError(
                "All split path files must share a common parent directory. "
                f"Found multiple parents: {[str(p) for p in parents]}"
            )
        for name in path_groups:
            if not Path(str(split_cfg[name])).exists():
                raise RuntimeError(f"split.{name} = {split_cfg[name]} does not exist.")
        return

    # All floats: validate the domain of each fraction.
    for name in datasets:
        frac = _fraction(name)
        if not (0.0 < frac <= 1.0):
            raise RuntimeError(f"split.{name} = {frac} is out of range (0.0, 1.0].")

    # Sum check per shared primary_data_location (infer is always independent).
    location_fracs: dict[str, list[float]] = {}
    for name, provider in datasets.items():
        if name == "infer":
            continue
        loc = provider.primary_data_location
        if loc:
            location_fracs.setdefault(loc, []).append(_fraction(name))
    for loc, fracs in location_fracs.items():
        total = sum(fracs)
        if np.round(total, 5) > 1.0:
            raise RuntimeError(
                f"split fractions for data_location '{loc}' sum to {total:.6f}, which exceeds 1.0."
            )
def validate_balance_config(config: dict, datasets: dict[str, DataProvider]) -> None:
    """Validate ``[balance]`` and ``[label]`` config values (pre-scan checks only).

    Checks performed:

    * ``balance.groups`` / ``balance.distribution`` require ``balance.field``.
    * Every non-``infer`` group's primary dataset must define ``get_<field>``.
    * ``balance.groups`` entries that are not in *datasets* only log a warning.
    * ``balance.distribution`` values must be floats in [0.0, 1.0] that sum to
      1.0 (compared after rounding to 5 decimal places).
    * A non-empty ``[label]`` table must map aliases to distinct raw values
      (compared via ``str``) and must define every distribution key.

    Raises
    ------
    RuntimeError
        Raised in any of these cases:

        * the getter is missing;
        * the distribution is malformed, or its sum ≠ 1.0;
        * the ``[label]`` table is inconsistent.
    """
    balance_cfg = config["balance"]
    field = balance_cfg["field"] if balance_cfg["field"] else None
    balance_groups = balance_cfg["groups"]
    distribution = balance_cfg["distribution"]

    if not field:
        if balance_groups or distribution:
            raise RuntimeError(
                "balance.field must be set when balance.groups or balance.distribution are provided."
            )
        return

    for group_name, provider in datasets.items():
        if group_name == "infer":
            continue
        primary_ds = _primary_instance(provider)
        if not hasattr(primary_ds, f"get_{field}"):
            raise RuntimeError(
                f"balance.field='{field}' requires a get_{field} method on the primary "
                f"dataset of group '{group_name}', but none was found on "
                f"{type(primary_ds).__name__}."
            )

    for g in balance_groups:
        if g not in datasets:
            logger.warning("balance.groups contains '%s' which is not in data_request; ignoring.", g)

    if distribution:
        for label, val in distribution.items():
            try:
                fval = float(val)
            except (TypeError, ValueError) as err:
                raise RuntimeError(
                    f"balance.distribution['{label}'] = {val!r} is not a valid float."
                ) from err
            if not (0.0 <= fval <= 1.0):
                raise RuntimeError(f"balance.distribution['{label}'] = {fval} is out of range [0.0, 1.0].")
        total = sum(float(v) for v in distribution.values())
        if np.round(total, 5) != 1.0:
            raise RuntimeError(
                f"balance.distribution values sum to {total:.6f}; they must sum to exactly 1.0."
            )

    # [label] pre-scan checks. They run whenever the table is non-empty; the
    # key-coverage check additionally requires a distribution. A missing
    # [label] table is treated as empty.
    label_cfg = config.get("label") or {}
    if label_cfg:
        raw_values = list(label_cfg.values())
        if len(raw_values) != len({str(v) for v in raw_values}):
            raise RuntimeError(
                "[label] values must be unique — two or more aliases map to the same raw value."
            )
        if distribution:
            for dist_key in distribution:
                if dist_key not in label_cfg:
                    raise RuntimeError(
                        f"balance.distribution key '{dist_key}' is not defined in [label]. "
                        "All distribution keys must appear in [label] when [label] is non-empty."
                    )
def validate_distribution_labels(distribution: dict, observed_labels: set) -> None:
    """Cross-check distribution keys against the observed class labels (post-scan).

    A no-op when *distribution* is empty. Observed labels that are missing from
    *distribution* only log a warning, since they get weight 0.

    Raises
    ------
    RuntimeError
        If distribution contains a label absent from the dataset.
    """
    if not distribution:
        return
    # Sort by str: labels may mix types (e.g. int and str), which a plain
    # sorted() cannot order.
    ordered_labels = sorted(observed_labels, key=str)
    for label in distribution:
        if label not in observed_labels:
            raise RuntimeError(
                f"balance.distribution contains label '{label}' not found in the dataset. "
                f"Observed labels: {ordered_labels}"
            )
    for label in ordered_labels:
        if label not in distribution:
            logger.warning(
                "Dataset class '%s' is absent from balance.distribution; "
                "it will receive weight 0 (no samples drawn for this class).",
                label,
            )
# ── Core split computation ─────────────────────────────────────────────────────
def _compute_splits(config: dict, datasets: dict[str, DataProvider]) -> dict[str, dict]:
    """Compute split indices (and optional balance weights) for each group.

    Non-``infer`` groups are bucketed by their provider's
    ``primary_data_location`` (falling back to the group name). Groups in the
    same bucket take disjoint, consecutive slices of one shuffled index list.
    When their fractions sum to (nearly) 1.0, the last group also receives any
    rounding remainder. ``infer`` is always independent: it gets the first
    ``round(len * frac)`` indices, unshuffled and unweighted.

    When ``balance.field`` is set, the split is stratified instead:

    * each item's label is read via the primary dataset's ``get_<field>``,
      optionally re-keyed through ``[label]``;
    * every class is shuffled and divided by the group fractions separately;
    * groups selected for balancing receive per-sample weights from
      :func:`_compute_weights`.

    Returns
    -------
    dict
        mapping group_name → {"indexes": np.ndarray[int64], "weights": np.ndarray[float32] | None}
    """
    split_cfg = config["split"]
    balance_cfg = config["balance"]
    field = balance_cfg["field"] if balance_cfg["field"] else None
    balance_groups_cfg = balance_cfg["groups"]
    distribution = balance_cfg["distribution"]

    def _fraction(name: str) -> float:
        # A missing, empty or false value means "use the full dataset".
        raw = split_cfg.get(name, 1.0)
        return 1.0 if raw in ("", None, False) else float(raw)

    # Resolve groups_to_balance per spec §4.2: an explicit list wins; otherwise
    # a distribution plus a field balances every non-infer group.
    if balance_groups_cfg:
        groups_to_balance = set(balance_groups_cfg)
    elif distribution and field:
        groups_to_balance = set(datasets.keys()) - {"infer"}
    else:
        groups_to_balance = set()

    result: dict[str, dict] = {}

    # Infer: always independent, no shuffle, no weights.
    if "infer" in datasets:
        n_items = len(datasets["infer"])
        count = round(n_items * _fraction("infer"))
        result["infer"] = {
            "indexes": np.arange(count, dtype=np.int64),
            "weights": None,
        }

    # Bucket the remaining providers by primary_data_location.
    location_groups: dict[str, dict[str, DataProvider]] = {}
    for group_name, provider in datasets.items():
        if group_name == "infer":
            continue
        loc = provider.primary_data_location or group_name
        location_groups.setdefault(loc, {})[group_name] = provider

    for loc, loc_datasets in location_groups.items():
        first_provider = next(iter(loc_datasets.values()))
        n_items = len(first_provider)
        fractions = {name: _fraction(name) for name in loc_datasets}
        total = sum(fractions.values())
        last_name = list(loc_datasets.keys())[-1]
        # When the fractions (nearly) cover everything, give the rounding
        # remainder to the last group so no item is silently dropped.
        covers_all = total >= 1.0 - 1e-5

        if not field:
            # Non-stratified: mirror create_splits_from_fractions semantics.
            indices = list(range(n_items))
            _shuffle(indices, config)
            offset = 0
            for name, frac in fractions.items():
                count = min(round(n_items * frac), n_items - offset)
                if name == last_name and covers_all:
                    count = n_items - offset
                result[name] = {
                    "indexes": np.array(indices[offset : offset + count], dtype=np.int64),
                    "weights": None,
                }
                offset += count
            continue

        # Stratified: build the class → indices map, then split each class.
        logger.info(
            "Computing stratified or balanced splits for data_location '%s' "
            "using balance.field '%s'. This requires a full scan of "
            "the dataset, which may take a while for large datasets.",
            loc,
            field,
        )
        primary_ds = _primary_instance(first_provider)
        getter = getattr(primary_ds, f"get_{field}")
        class_inds: dict[Any, list[int]] = {}
        for i in range(n_items):
            class_inds.setdefault(getter(i), []).append(i)

        # [label] re-keying (§4.3): translate raw values to alias strings; raw
        # values without an alias are excluded from every group.
        label_cfg = dict(config.get("label") or {})
        if label_cfg:
            raw_to_name = {v: k for k, v in label_cfg.items()}
            rekeyed: dict[Any, list[int]] = {}
            for raw_val, inds in class_inds.items():
                alias = raw_to_name.get(raw_val)
                if alias is None:
                    logger.warning(
                        "Dataset contains raw label value %r from get_%s "
                        "that has no alias in [label]; %d item(s) with this value "
                        "will be excluded from all split groups.",
                        raw_val,
                        field,
                        len(inds),
                    )
                else:
                    rekeyed[alias] = inds
            class_inds = rekeyed

        validate_distribution_labels(distribution, set(class_inds))

        # Reverse lookup for weight computation.
        index_to_label: dict[int, Any] = {i: label for label, inds in class_inds.items() for i in inds}
        num_classes = len(class_inds)

        per_group: dict[str, list[int]] = {name: [] for name in loc_datasets}
        # Deterministic class order (str key tolerates mixed label types).
        for label in sorted(class_inds, key=str):
            inds = list(class_inds[label])
            _shuffle(inds, config)
            offset = 0
            for name, frac in fractions.items():
                count = min(round(len(inds) * frac), len(inds) - offset)
                per_group[name] += inds[offset : offset + count]
                offset += count
            if offset < len(inds) and covers_all:
                per_group[last_name] += inds[offset:]

        for name, indices_list in per_group.items():
            if name in groups_to_balance:
                weights = _compute_weights(indices_list, index_to_label, distribution, num_classes)
            else:
                weights = None
            result[name] = {
                "indexes": np.array(indices_list, dtype=np.int64),
                "weights": weights,
            }

    return result
# ── Persistence / loading ──────────────────────────────────────────────────────
def persist_splits(results_dir: Path, splits: dict[str, dict], config: dict) -> None:
    """Write one ``<group>_split.npz`` per group and a ``split_config.toml``.

    Each npz holds an ``indexes`` array. The ``weights`` array is omitted
    entirely for unbalanced groups (``None``) to save space;
    ``load_split_files`` treats its absence as ``None``.

    ``split_config.toml`` records the ``data_request``, ``split``, ``balance``
    and ``label`` tables. It also records ``data_set.seed``, the value an empty
    ``split.rng_seed`` falls back to. Without it, a later equivalence check
    could not resolve the RNG seed these splits were produced with.
    """
    results_dir = Path(results_dir)
    for group, data in splits.items():
        save_kwargs: dict[str, np.ndarray] = {"indexes": data["indexes"]}
        if data["weights"] is not None:
            save_kwargs["weights"] = data["weights"]
        np.savez_compressed(results_dir / f"{group}_split.npz", **save_kwargs)

    split_config: dict = {}
    for key in ("data_request", "split", "balance", "label"):
        if key in config:
            split_config[key] = config[key]

    # Persist the fallback seed. tomlkit cannot serialise None, so only write a
    # seed that has a value (false is kept: it resolves to None either way).
    data_set_cfg = config["data_set"] if "data_set" in config else {}
    seed = data_set_cfg.get("seed") if hasattr(data_set_cfg, "get") else None
    if seed is not None:
        split_config["data_set"] = {"seed": seed}

    # TOML is defined as UTF-8; don't depend on the platform default encoding.
    with open(results_dir / "split_config.toml", "w", encoding="utf-8") as f:
        f.write(tomlkit.dumps(split_config))
def load_split_files(paths: dict[str, Path]) -> dict[str, dict]:
    """Load previously persisted split files.

    Parameters
    ----------
    paths:
        Mapping of group name → path (``Path`` or ``str``) to ``<group>_split.npz``.

    Returns
    -------
    dict
        mapping group_name → {"indexes": ndarray, "weights": ndarray | None}

    Raises
    ------
    RuntimeError
        If a file does not exist or lacks the required ``indexes`` array.
    """
    result: dict[str, dict] = {}
    for group, raw_path in paths.items():
        path = Path(raw_path)
        if not path.exists():
            raise RuntimeError(f"Split file for group '{group}' not found: {path}")
        # np.load on an .npz returns an NpzFile that holds the file open; the
        # context manager closes it once the needed arrays have been read.
        with np.load(path) as npz:
            if "indexes" not in npz.files:
                raise RuntimeError(f"Split file '{path}' is missing the required 'indexes' array.")
            result[group] = {
                "indexes": npz["indexes"],
                "weights": npz["weights"] if "weights" in npz.files else None,
            }
    return result
def assign_splits_to_providers(datasets: dict[str, DataProvider], splits: dict[str, dict]) -> None:
    """Copy each group's split onto its provider.

    For every group present in both *splits* and *datasets*, this sets:

    * ``provider.split_indices`` to a plain Python list built from ``"indexes"``;
    * ``provider.split_weights`` to the ``"weights"`` entry (possibly ``None``).

    Groups in *splits* with no matching provider are skipped.
    """
    for group in (name for name in splits if name in datasets):
        entry = splits[group]
        target = datasets[group]
        target.split_indices = entry["indexes"].tolist()
        target.split_weights = entry["weights"]
# ── Equivalency ────────────────────────────────────────────────────────────────
def configs_equivalent(prev: dict, cur: dict) -> tuple[bool, list[str]]:
    """Check whether *prev* config would produce the same splits as *cur*.

    Global fields compared:

    * ``balance.field``;
    * ``balance.distribution`` (keys as ``str``, values as ``float``);
    * the RNG seed resolved by :func:`_resolve_seed`.

    For every group in *cur*'s ``data_request``, this also compares:

    * presence in *prev*;
    * the primary dataset's ``dataset_class`` and ``data_location``;
    * the ``split.<group>`` value;
    * membership in ``balance.groups``.

    Split values of ``''``, ``false``/``0``, ``None`` or a missing key all
    compare as ``1.0``, because that is how splits are validated and computed.
    Groups present only in *prev* are ignored.

    Returns
    -------
    (equivalent, diffs)
        *equivalent* is True only when all compared fields match.
        *diffs* is a human-readable list of differences (empty when equivalent).
    """
    diffs: list[str] = []

    def _get(cfg: dict, *keys: str, default: Any = None) -> Any:
        # Walk nested tables; return *default* as soon as a level is not a dict.
        node = cfg
        for k in keys:
            if not isinstance(node, dict):
                return default
            node = node.get(k, default)
        return node

    def _split_value(cfg: dict, group: str) -> Any:
        raw = _get(cfg, "split", group)
        return 1.0 if raw in ("", None, False) else raw

    # Global comparisons
    prev_field = _get(prev, "balance", "field") or ""
    cur_field = _get(cur, "balance", "field") or ""
    if str(prev_field) != str(cur_field):
        diffs.append(f"balance.field: {prev_field!r} → {cur_field!r}")

    prev_dist = dict(_get(prev, "balance", "distribution") or {})
    cur_dist = dict(_get(cur, "balance", "distribution") or {})
    if {str(k): float(v) for k, v in prev_dist.items()} != {str(k): float(v) for k, v in cur_dist.items()}:
        diffs.append("balance.distribution changed")

    prev_seed = _resolve_seed(prev)
    cur_seed = _resolve_seed(cur)
    if prev_seed != cur_seed:
        diffs.append(f"rng_seed (resolved): {prev_seed!r} → {cur_seed!r}")

    # balance.groups does not depend on the group being examined; read it once.
    cur_groups = list(_get(cur, "balance", "groups") or [])
    prev_groups = list(_get(prev, "balance", "groups") or [])

    # Per-group comparisons
    cur_dr = _get(cur, "data_request") or {}
    prev_dr = _get(prev, "data_request") or {}
    for group_name in cur_dr:
        if group_name not in prev_dr:
            diffs.append(f"group '{group_name}' absent from previous split config")
            continue

        cur_primary = _find_primary_cfg(cur_dr[group_name])
        prev_primary = _find_primary_cfg(prev_dr[group_name])
        if cur_primary is None or prev_primary is None:
            diffs.append(f"group '{group_name}': cannot find primary dataset config")
            continue

        if cur_primary.get("dataset_class") != prev_primary.get("dataset_class"):
            diffs.append(
                f"group '{group_name}' dataset_class: "
                f"{prev_primary.get('dataset_class')!r} → {cur_primary.get('dataset_class')!r}"
            )
        if cur_primary.get("data_location") != prev_primary.get("data_location"):
            diffs.append(
                f"group '{group_name}' data_location: "
                f"{prev_primary.get('data_location')!r} → {cur_primary.get('data_location')!r}"
            )

        cur_frac = _split_value(cur, group_name)
        prev_frac = _split_value(prev, group_name)
        if cur_frac != prev_frac:
            diffs.append(f"split.{group_name}: {prev_frac!r} → {cur_frac!r}")

        cur_in = group_name in cur_groups
        prev_in = group_name in prev_groups
        if cur_in != prev_in:
            diffs.append(f"group '{group_name}' balance.groups membership: {prev_in} → {cur_in}")

    return (len(diffs) == 0, diffs)
def find_equivalent_split(config: dict, results_root: Path | None = None) -> dict[str, Path] | None:
    """Scan the results directory for a previously persisted equivalent split.

    Candidates are subdirectories of *results_root* whose names match
    ``*-splits-*``. They are scanned in reverse lexicographic name order. A
    candidate matches when:

    * its ``split_config.toml`` parses;
    * that config is equivalent to *config* per :func:`configs_equivalent`;
    * it holds a ``<group>_split.npz`` for every ``data_request`` group.

    Unreadable or incomparable candidates are skipped (logged at debug level).

    Parameters
    ----------
    config:
        Current configuration.
    results_root:
        Directory to scan (``Path`` or ``str``). Defaults to
        ``general.results_dir`` with ``~`` expanded and the path resolved.

    Returns
    -------
    dict[str, Path] | None
        The group→npz path mapping of the first match, or ``None``.
    """
    if results_root is None:
        results_root = Path(config["general"]["results_dir"]).expanduser().resolve()
    else:
        results_root = Path(results_root)
    if not results_root.exists():
        return None

    split_dirs = sorted(
        (p for p in results_root.glob("*-splits-*") if p.is_dir()),
        key=lambda p: p.name,
        reverse=True,
    )
    required_groups = set(config.get("data_request") or {})
    for split_dir in split_dirs:
        config_path = split_dir / "split_config.toml"
        if not config_path.exists():
            continue
        # Best-effort scan: a corrupt or outdated split_config.toml (parse error,
        # missing keys, bad values) must not abort the search for other matches.
        try:
            with open(config_path, encoding="utf-8") as f:
                prev_config = dict(tomlkit.parse(f.read()))
            equivalent, _ = configs_equivalent(prev_config, config)
        except Exception as err:
            logger.debug("Skipping split directory %s: %s", split_dir, err)
            continue
        if not equivalent:
            continue
        paths = {g: split_dir / f"{g}_split.npz" for g in required_groups}
        if required_groups and all(p.exists() for p in paths.values()):
            return paths
    return None
# ── Public driver ──────────────────────────────────────────────────────────────
def create_splits(
    config: dict,
    datasets: dict[str, DataProvider],
    *,
    results_dir: Path | None = None,
    persist: bool = True,
) -> dict[str, dict]:
    """Compute (or load) splits and weights for each data group.

    Resolution order:

    1. If ``[split]`` supplies file paths, load them with
       :func:`load_split_files`. A warning is logged if the sibling
       ``split_config.toml`` describes a different config.
    2. Otherwise reuse a previously persisted equivalent split found by
       :func:`find_equivalent_split`.
    3. Otherwise compute fresh splits with :func:`_compute_splits`.

    In cases 2 and 3 the splits are also written to *results_dir* via
    :func:`persist_splits` when *persist* is true and *results_dir* is given.
    In every case this assigns ``split_indices`` / ``split_weights`` on each
    provider via :func:`assign_splits_to_providers`.

    Raises
    ------
    RuntimeError
        Propagated from config validation or split-file loading.

    Returns
    -------
    dict
        mapping group_name → {"indexes": ndarray[int64], "weights": ndarray | None}
        (freshly computed weights are float32).
    """
    validate_split_config(config, datasets)
    validate_balance_config(config, datasets)

    split_cfg = config.get("split", {})

    # Paths supplied: validation guarantees either every group is a path or none is.
    using_paths = any(_is_path_value(split_cfg.get(name)) for name in datasets)
    if using_paths:
        paths = {name: Path(str(split_cfg[name])) for name in datasets if name in split_cfg}
        splits = load_split_files(paths)

        # Warn if the sibling split_config.toml differs from the current config.
        first_path = next(iter(paths.values()))
        sibling_cfg_path = first_path.parent / "split_config.toml"
        if sibling_cfg_path.exists():
            try:
                with open(sibling_cfg_path, encoding="utf-8") as f:
                    prev_config = dict(tomlkit.parse(f.read()))
                equivalent, diffs = configs_equivalent(prev_config, config)
            except Exception as err:
                # The comparison is advisory only; report it but never block loading.
                logger.warning(
                    "Could not compare supplied split files against %s: %s", sibling_cfg_path, err
                )
            else:
                if not equivalent:
                    logger.warning(
                        "Supplied split files were produced with a different config. Differences: %s",
                        "; ".join(diffs),
                    )
        assign_splits_to_providers(datasets, splits)
        return splits

    # Reuse an equivalent split if one exists; otherwise compute a fresh one.
    equivalent_paths = find_equivalent_split(config)
    if equivalent_paths is not None:
        logger.info("Reusing equivalent split from %s", next(iter(equivalent_paths.values())).parent)
        splits = load_split_files(equivalent_paths)
    else:
        splits = _compute_splits(config, datasets)

    assign_splits_to_providers(datasets, splits)
    if persist and results_dir is not None:
        persist_splits(results_dir, splits, config)
    return splits