# Source code for hyrax.verbs.create_splits

import logging

from hyrax.trace import trace_verb_data

from .verb_registry import Verb, hyrax_verb

# Module-level logger named after this module, used by the verb below.
logger = logging.getLogger(__name__)
@hyrax_verb
class CreateSplits(Verb):
    """Create and persist reproducible dataset splits."""

    # Verb name; "create_splits" is also what run_cli logs when invoked from the CLI.
    cli_name = "create_splits"
    # Extra keyword arguments for this verb's parser; none are needed.
    # NOTE(review): class-level dict shared by all instances -- presumably only
    # read by the Verb base class / registry; confirm it is never mutated.
    add_parser_kwargs = {}
    # Human-readable summary of the verb (presumably surfaced as CLI help text).
    description = "Compute and persist dataset splits for reproducible training workflows."
    # Empty: this verb declares no required or optional data groups of its own.
    REQUIRED_DATA_GROUPS = ()
    OPTIONAL_DATA_GROUPS = ()

    @staticmethod
    def setup_parser(parser) -> None:
        """No additional CLI options needed."""

    def run_cli(self, args=None) -> None:
        """CLI stub for CreateSplits verb.

        Logs the CLI invocation and delegates to :meth:`run`.

        Parameters
        ----------
        args : optional
            Parsed CLI arguments. Unused: this verb registers no options of its
            own (see :meth:`setup_parser`) and reads everything from
            ``self.config`` inside :meth:`run`.
        """
        logger.info("create_splits run from CLI")
        self.run()

    @trace_verb_data
    def run(self):
        """Compute dataset splits and write them to a results directory.

        Reads the ``[split]`` and ``[balance]`` config tables to determine how
        to partition each data group, then persists ``.npz`` index files and a
        ``split_config.toml`` under a timestamped ``*-splits-*`` results
        directory. Subsequent verbs (``train``, ``infer``, ``test``) can point
        at this directory to reuse the same split without recomputing it.

        Returns
        -------
        dict[str, DataProvider]
            The populated dataset providers, keyed by group name.
        """
        # Deferred imports: only loaded when the verb actually runs
        # (presumably to keep CLI start-up and verb registration light).
        from hyrax.config_utils import create_results_dir, log_runtime_config
        from hyrax.pytorch_ignite import setup_dataset
        from hyrax.splitting_utils import create_splits

        config = self.config
        results_dir = create_results_dir(config, "splits")
        datasets = setup_dataset(config)
        # persist=True asks create_splits to write the split files into results_dir.
        create_splits(config, datasets, results_dir=results_dir, persist=True)
        log_runtime_config(config, results_dir)
        # Lazy %-style arguments: the message is only formatted if INFO is enabled.
        logger.info("Split files written to: %s", results_dir)
        return datasets