.. _dataset_splits:

Dataset splits (subsets)
========================

Machine learning datasets are typically split into subsets so that a model is not evaluated on the same data
it was fit to. This helps detect overfitting and check that the model is learning what the researcher intends.
In Hyrax, splits are defined in the top-level ``[split]`` configuration table, keyed by the group names
declared in ``data_request``.

Splits in training
------------------

To split a dataset between training and validation, define named groups in the ``data_request`` that point
to the **same** ``data_location``, then assign each group a fraction in the ``[split]`` table. Hyrax will
partition the dataset so that each group receives a non-overlapping subset of the data proportional to its
fraction. The fractions for all groups sharing a ``data_location`` must sum to ``<= 1.0``. Groups omitted
from ``[split]`` default to ``1.0`` (the full dataset).

For example, to use 80% of the data for training and 20% for validation:

.. tab-set::

    .. tab-item:: Notebook

        .. code-block:: python

            from hyrax import Hyrax
            h = Hyrax()

            data_request = {
                "train": {
                    "my_data": {
                        "dataset_class": "HyraxCifarDataset",
                        "data_location": "./all_data",
                        "primary_id_field": "object_id",
                    }
                },
                "validate": {
                    "my_data": {
                        "dataset_class": "HyraxCifarDataset",
                        "data_location": "./all_data",
                        "primary_id_field": "object_id",
                    }
                },
            }
            h.set_config("data_request", data_request)

            split = {
                "train": 0.8,
                "validate": 0.2,
            }
            h.set_config("split", split)

    .. tab-item:: CLI

        .. code-block:: toml

            [data_request.train.my_data]
            dataset_class = "HyraxCifarDataset"
            data_location = "./all_data"
            primary_id_field = "object_id"

            [data_request.validate.my_data]
            dataset_class = "HyraxCifarDataset"
            data_location = "./all_data"
            primary_id_field = "object_id"

            [split]
            train = 0.8
            validate = 0.2
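
The configuration above only declares the fractions. To illustrate the kind of partitioning described at the
start of this section, the following minimal sketch shuffles the indices once and hands out consecutive,
non-overlapping slices sized by each fraction. ``split_indices`` is a hypothetical helper, not part of the Hyrax
API, and the sketch uses Python's standard ``random`` module as an assumption; Hyrax's own implementation may differ.

.. code-block:: python

    import random


    def split_indices(n_items, fractions, seed=None):
        """Partition range(n_items) into disjoint, randomly chosen groups.

        Hypothetical helper for illustration only; it is not part of Hyrax.
        ``fractions`` maps group names (e.g. "train") to fractions of n_items.
        """
        total = sum(fractions.values())
        if total > 1.0 + 1e-9:  # small tolerance for floating-point rounding
            raise ValueError(f"split fractions sum to {total}, expected <= 1.0")

        indices = list(range(n_items))
        # With seed=None, random.Random seeds itself from OS entropy.
        random.Random(seed).shuffle(indices)

        groups, start = {}, 0
        for name, fraction in fractions.items():
            count = int(fraction * n_items)  # round down so groups never overflow
            groups[name] = indices[start:start + count]  # consecutive slices never overlap
            start += count
        return groups


    groups = split_indices(100, {"train": 0.8, "validate": 0.2}, seed=1)
    print(len(groups["train"]), len(groups["validate"]))  # 80 20

Because every group in this sketch is a consecutive slice of a single shuffled list, no index can land in two
groups, and any remainder left when the fractions sum to less than ``1.0`` is simply unused.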

The ``train`` :doc:`verb </verbs>` trains on the ``train`` group and, when a ``validate`` group is present,
computes a validation loss on it each epoch. You can also define a ``test`` group; the ``train`` verb does not
use it during training, but it remains available for downstream evaluation.

For more detail on data requests, including how to use separate directories for each split, see the
:doc:`data requests notebook </notebooks/data_requests>`.

Randomness in splits
--------------------

When ``[split]`` fractions are used, Hyrax randomly assigns indices to each group. To make splits
reproducible, set the ``rng_seed`` key in ``[split]`` to any integer. If ``rng_seed`` is omitted or set to
``false``, Hyrax falls back to ``data_set.seed`` if that is set, and otherwise seeds the random number
generator from system entropy. For example:

.. tab-set::

    .. tab-item:: Notebook

        .. code-block:: python

            from hyrax import Hyrax
            h = Hyrax()
            h.config["split"]["rng_seed"] = 1

    .. tab-item:: CLI

        .. code-block:: toml

            [split]
            rng_seed = 1
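
As an illustration of the fallback order described above, the following minimal sketch resolves a seed from a
plain dictionary shaped like the TOML tables and shows that a fixed seed reproduces the same shuffle. The names
``resolve_split_seed`` and ``shuffled_indices`` are hypothetical and not part of Hyrax, and the sketch assumes
Python's standard ``random`` module rather than whatever generator Hyrax uses internally.

.. code-block:: python

    import random


    def _is_unset(value):
        # TOML has no null, so ``false`` (or a missing key) means "not set".
        return value is None or value is False


    def resolve_split_seed(config):
        """Return the seed for the split RNG, or None for system entropy.

        Hypothetical helper; ``config`` is assumed to be a plain nested dict
        shaped like the TOML tables, not Hyrax's own config object.
        """
        seed = config.get("split", {}).get("rng_seed")
        if _is_unset(seed):
            seed = config.get("data_set", {}).get("seed")
        return None if _is_unset(seed) else seed


    def shuffled_indices(n_items, seed):
        indices = list(range(n_items))
        random.Random(seed).shuffle(indices)  # seed=None -> OS entropy
        return indices


    config = {"split": {"rng_seed": 1}, "data_set": {"seed": False}}
    seed = resolve_split_seed(config)
    # The same integer seed always yields the same ordering, hence the same split.
    assert shuffled_indices(10, seed) == shuffled_indices(10, seed)

The identity checks (``is False``) in this sketch keep an integer seed of ``0`` from being mistaken for ``false``.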
