hyrax.datasets.data_provider
============================

.. py:module:: hyrax.datasets.data_provider


Attributes
----------

.. autoapisummary::

   hyrax.datasets.data_provider.logger
   hyrax.datasets.data_provider.tensorboardx_logger
   hyrax.datasets.data_provider._JOIN_CACHE_VERSION


Classes
-------

.. autoapisummary::

   hyrax.datasets.data_provider.DataProvider


Functions
---------

.. autoapisummary::

   hyrax.datasets.data_provider._handle_nans
   hyrax.datasets.data_provider._handle_nans_numpy
   hyrax.datasets.data_provider._handle_nans_tuple
   hyrax.datasets.data_provider._handle_nans_logic_numpy
   hyrax.datasets.data_provider._handle_nan_quantile_numpy
   hyrax.datasets.data_provider._handle_nan_zero_numpy
   hyrax.datasets.data_provider._join_cache_fingerprint
   hyrax.datasets.data_provider._join_cache_path
   hyrax.datasets.data_provider._load_join_cache
   hyrax.datasets.data_provider._save_join_cache
   hyrax.datasets.data_provider.generate_data_request_from_config


Module Contents
---------------

.. py:data:: logger

.. py:data:: tensorboardx_logger

.. py:function:: _handle_nans(batch, config)

   The default ``_handle_nans`` implementation, used for batch types that have no specific implementation. Prints a warning and returns ``batch`` unchanged.


.. py:function:: _handle_nans_numpy(batch, config)

.. py:function:: _handle_nans_tuple(batch, config)

   Tuple-specific implementation of ``_handle_nans``. NaN handling is applied
   to each NumPy element of the tuple; non-NumPy elements are returned
   unchanged.
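
   The type-based choice between the default and tuple-specific implementations
   can be sketched with ``functools.singledispatch``. This is an assumption (the
   dispatch mechanism is not documented here), the function names are
   hypothetical, and zero-filling stands in for whatever NumPy strategy the
   configuration selects.

   .. code-block:: python

       import functools
       import warnings

       import numpy as np


       @functools.singledispatch
       def handle_nans_sketch(batch, config):
           # Fallback for unregistered batch types: warn and pass the batch through.
           warnings.warn(f"No NaN handling registered for {type(batch).__name__}")
           return batch


       @handle_nans_sketch.register
       def _(batch: np.ndarray, config):
           # Placeholder strategy: replace NaNs with zero
           # (note that np.nan_to_num also clamps +/-inf to finite values).
           return np.nan_to_num(batch, nan=0.0)


       @handle_nans_sketch.register
       def _(batch: tuple, config):
           # Handle each NumPy element; leave every other element unchanged.
           return tuple(
               handle_nans_sketch(item, config) if isinstance(item, np.ndarray) else item
               for item in batch
           )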


.. py:function:: _handle_nans_logic_numpy(batch, config)

.. py:function:: _handle_nan_quantile_numpy(batch, quantile)

.. py:function:: _handle_nan_zero_numpy(batch)
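
   A minimal sketch of the NumPy strategies follows. It assumes that quantile
   mode replaces each sample's NaNs with the given quantile of that sample's
   non-NaN values, that zero mode replaces NaNs with 0, and that the mode is read
   from hypothetical flat config keys (``nan_mode`` and ``nan_quantile``). The
   real Hyrax option names, nesting, and per-sample behavior may differ.

   .. code-block:: python

       import numpy as np


       def nan_zero_sketch(batch):
           # Replace every NaN in the batch with 0.0.
           return np.where(np.isnan(batch), 0.0, batch)


       def nan_quantile_sketch(batch, quantile):
           # Assumption: each sample (first axis) gets its own fill value,
           # the requested quantile of that sample's non-NaN entries.
           flat = batch.reshape(batch.shape[0], -1)
           fill = np.nanquantile(flat, quantile, axis=1)
           # Reshape the per-sample fill values so they broadcast over the batch.
           fill = fill.reshape((-1,) + (1,) * (batch.ndim - 1))
           return np.where(np.isnan(batch), fill, batch)


       def nan_logic_sketch(batch, config):
           # Hypothetical config keys; the real Hyrax option names may differ.
           mode = config.get("nan_mode", False)
           if mode == "quantile":
               return nan_quantile_sketch(batch, config.get("nan_quantile", 0.05))
           if mode == "zero":
               return nan_zero_sketch(batch)
           return batch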

.. py:data:: _JOIN_CACHE_VERSION
   :value: 1


   Increment this value whenever the cache format changes; doing so invalidates all existing caches.

.. py:function:: _join_cache_fingerprint(dataset_len: int, getter, *, n_probes: int = 8) -> str

   Compute a lightweight fingerprint for a dataset's join-key column.

   The fingerprint incorporates the dataset length and a small number of
   deterministically sampled key values.  If any of these change, the
   fingerprint changes and the cache is considered stale.

   :param dataset_len: Number of items in the dataset.
   :type dataset_len: int
   :param getter: The ``get_<join_field>`` method; called with integer indices.
   :type getter: callable
   :param n_probes: How many key values to sample.  Kept small so the cost is negligible.
   :type n_probes: int
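
   The fingerprinting idea can be illustrated with ``hashlib``. In this sketch
   the probe indices are evenly spaced and the hash is SHA-256; both choices, and
   the function name, are assumptions rather than the documented implementation.

   .. code-block:: python

       import hashlib


       def join_cache_fingerprint_sketch(dataset_len, getter, *, n_probes=8):
           # Hash the length first so that adding or removing rows changes the result.
           digest = hashlib.sha256(str(dataset_len).encode())
           if dataset_len > 0:
               # Evenly spaced, deterministic probe indices (index 0 is always probed).
               count = min(n_probes, dataset_len)
               step = dataset_len / count
               for i in range(count):
                   idx = int(i * step)
                   digest.update(f"|{idx}:{getter(idx)}".encode())
           return digest.hexdigest()

   Because only a handful of keys are probed, a change confined to unsampled
   rows leaves this fingerprint unchanged; that is the price of keeping the
   check cheap.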


.. py:function:: _join_cache_path(data_location: str | None, fingerprint: str) -> pathlib.Path | None

   Return the path where a join cache file would live, or ``None`` if
   caching is not possible (e.g. no ``data_location``).


.. py:function:: _load_join_cache(data_location: str | None, dataset_len: int, getter) -> dict[str, int] | None

   Attempt to load a cached reverse-index map from disk.

   Returns ``None`` on any cache miss (file absent, fingerprint mismatch,
   corrupt data, permission error, etc.).


.. py:function:: _save_join_cache(data_location: str | None, dataset_len: int, getter, reverse_map: dict[str, int]) -> None

   Persist a reverse-index map to disk.  Failures are silently ignored
   (caching is best-effort).
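
   The best-effort caching contract (any failure on load is a cache miss, any
   failure on save is ignored) can be sketched as follows. The JSON file format,
   the file name, and the helper names are hypothetical, and to stay
   self-contained the sketch takes a precomputed fingerprint instead of the
   ``dataset_len``/``getter`` pair used by the documented functions.

   .. code-block:: python

       import json
       import pathlib

       CACHE_VERSION = 1  # plays the role of _JOIN_CACHE_VERSION


       def join_cache_path_sketch(data_location, fingerprint):
           # Without a data_location there is nowhere to put the cache.
           if data_location is None:
               return None
           # Hypothetical file name; the real naming scheme is not documented here.
           return pathlib.Path(data_location) / f".join_cache_{fingerprint}.json"


       def save_join_cache_sketch(data_location, fingerprint, reverse_map):
           path = join_cache_path_sketch(data_location, fingerprint)
           if path is None:
               return
           payload = {"version": CACHE_VERSION, "fingerprint": fingerprint, "map": reverse_map}
           try:
               path.write_text(json.dumps(payload))
           except (OSError, TypeError, ValueError):
               pass  # best-effort: caching must never break the caller


       def load_join_cache_sketch(data_location, fingerprint):
           path = join_cache_path_sketch(data_location, fingerprint)
           if path is None:
               return None
           try:
               payload = json.loads(path.read_text())
               if payload["version"] != CACHE_VERSION or payload["fingerprint"] != fingerprint:
                   return None  # stale: other format version or other dataset state
               return {str(k): int(v) for k, v in payload["map"].items()}
           except (OSError, ValueError, KeyError, TypeError, AttributeError):
               return None  # absent, unreadable, or corrupt file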


.. py:function:: generate_data_request_from_config(config)

   Extract the data request from the configuration.

   If ``[data_request]`` is not defined, a ``RuntimeError`` is raised.

   :param config: The Hyrax configuration that is passed to each dataset instance.
   :type config: dict

   :returns: A dictionary where keys are dataset names and values are lists of fields
   :rtype: dict

   :raises RuntimeError: If ``data_request`` is not provided in the configuration.


.. py:class:: DataProvider(config: dict, request: dict)

   This class presents itself as a PyTorch Dataset, but acts like a GraphQL
   gateway that fetches data from multiple datasets based on the ``data_request``
   dictionary provided during initialization.

   This class allows for flexible data retrieval from multiple dataset classes,
   each of which can have different fields requested.

   Additionally, the user can provide specific configuration options for each
   dataset class that will be merged with the original configuration provided
   during initialization.

   Initialize the DataProvider with a Hyrax config and extract (or create)
   the ``data_request``.

   :param config: The Hyrax configuration that defines the ``data_request``.
   :type config: dict
   :param request: A dictionary that defines the data request.
   :type request: dict
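
   The gateway pattern can be illustrated with a toy provider that, for each
   friendly name, fetches only the requested fields. Both classes are
   hypothetical, and the ``get_<field>`` accessor convention is an assumption
   modeled on the ``get_<join_field>`` getters used for join caching.

   .. code-block:: python

       class ToyDataset:
           """Hypothetical dataset exposing one get_<field> method per field."""

           def __init__(self, ids, values):
               self._ids = ids
               self._values = values

           def __len__(self):
               return len(self._ids)

           def get_object_id(self, idx):
               return self._ids[idx]

           def get_value(self, idx):
               return self._values[idx]


       class ToyProvider:
           """Gateway sketch: fetch only the requested fields from each dataset."""

           def __init__(self, datasets, request):
               # datasets: {friendly_name: dataset}; request: {friendly_name: [field, ...]}
               self.datasets = datasets
               self.request = request

           def __len__(self):
               # Length of the first dataset (no primary dataset in this toy).
               return len(next(iter(self.datasets.values())))

           def __getitem__(self, idx):
               return {
                   name: {field: getattr(self.datasets[name], f"get_{field}")(idx) for field in fields}
                   for name, fields in self.request.items()
               }

   Because ``ToyProvider`` defines ``__getitem__`` and ``__len__``, it follows
   the same map-style protocol that a PyTorch ``DataLoader`` consumes.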


   .. py:attribute:: config


   .. py:attribute:: data_request


   .. py:attribute:: prepped_datasets


   .. py:attribute:: dataset_getters


   .. py:attribute:: all_metadata_fields


   .. py:attribute:: requested_fields


   .. py:attribute:: custom_collate_functions


   .. py:attribute:: field_collate_functions


   .. py:attribute:: primary_dataset
      :value: None



   .. py:attribute:: primary_dataset_id_field_name
      :value: None



   .. py:attribute:: split_fraction
      :value: None



   .. py:attribute:: primary_data_location
      :value: None



   .. py:attribute:: split_indices
      :value: None



   .. py:attribute:: _join_fields
      :type:  dict[str, str]


   .. py:attribute:: _join_maps
      :type:  dict[str, dict[str, int]]


   .. py:attribute:: data_cache


   .. py:method:: pull_up_primary_dataset_methods()

      If a primary dataset is defined, pull some of its methods up to the
      DataProvider level so that they can be called directly on the
      DataProvider instance.
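
      A minimal sketch of pulling methods up by rebinding them onto the
      provider. The method list, the class names, and the rule of never
      overwriting an existing provider attribute are assumptions for
      illustration.

      .. code-block:: python

          class PrimaryDataset:
              def describe(self):
                  return "primary dataset"


          class Provider:
              # Hypothetical list; which methods Hyrax actually pulls up is not documented here.
              PULLED_UP_METHODS = ("describe",)

              def __init__(self, primary=None):
                  self.primary_dataset = primary
                  self.pull_up_primary_dataset_methods()

              def pull_up_primary_dataset_methods(self):
                  if self.primary_dataset is None:
                      return
                  for name in self.PULLED_UP_METHODS:
                      method = getattr(self.primary_dataset, name, None)
                      # Bind only methods the primary dataset provides and that
                      # would not shadow an existing attribute of the provider.
                      if callable(method) and not hasattr(self, name):
                          setattr(self, name, method)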



   .. py:method:: __getitem__(idx) -> dict

      Return the data for a given index.

      Together with ``__len__``, this method lets the class be used as a
      map-style PyTorch Dataset.

      :param idx: The index of the data item to retrieve.
      :type idx: int

      :returns: A dictionary containing the requested data from the prepared datasets.
      :rtype: dict



   .. py:method:: __len__() -> int

      Return the length of the dataset: the length of the primary dataset if
      one is defined, otherwise the length of the first dataset in
      ``self.prepped_datasets``.



   .. py:method:: __repr__() -> str


   .. py:method:: fields() -> dict

      Print and return all the available fields for each dataset in the DataProvider.

      :returns: A dictionary mapping friendly dataset names to their available fields.
      :rtype: dict



   .. py:method:: _setup_trace()

      If tracing is enabled, set up the relevant hooks.



   .. py:method:: prepare_datasets()

      Instantiate each of the requested datasets based on the ``data_request``
      configuration dictionary. Store the prepared instances in the
      ``self.prepped_datasets`` dictionary.



   .. py:method:: _build_join_indices()

      Build reverse-index mappings for datasets that declare a ``join_field``.

      For each joined secondary dataset, a dict ``{str(key): int(index)}``
      is built by iterating over all items in that dataset.  At runtime,
      ``resolve_data`` uses these maps to translate primary object IDs to
      secondary indices (left outer join — unmatched primaries get ``None``).

      Reverse maps for independent secondaries are built in parallel using
      threads.  Built maps are persisted to a cache file next to the
      dataset's ``data_location``; a fingerprint check on subsequent runs
      avoids the O(N) rebuild.

      This method is called once during ``__init__``.  Runtime lookups in
      ``resolve_data`` remain O(1) dict access.
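
      The reverse-map construction and the threaded build can be sketched with
      ``concurrent.futures.ThreadPoolExecutor``. The helper names are
      hypothetical, cache handling is omitted, and duplicate keys resolve to the
      last index seen, which is an assumption of this sketch.

      .. code-block:: python

          from concurrent.futures import ThreadPoolExecutor


          def build_reverse_map(dataset_len, getter):
              # O(N) scan: map each stringified join key to its row index.
              # If a key repeats, the last index wins.
              return {str(getter(i)): i for i in range(dataset_len)}


          def build_join_maps(secondaries, max_workers=4):
              """secondaries: {friendly_name: (dataset_len, getter)} for joined datasets."""
              with ThreadPoolExecutor(max_workers=max_workers) as pool:
                  futures = {
                      name: pool.submit(build_reverse_map, length, getter)
                      for name, (length, getter) in secondaries.items()
                  }
                  return {name: future.result() for name, future in futures.items()}

      Threads mainly pay off when the getters spend their time in I/O or in code
      that releases the GIL; pure-Python getters would largely be serialized.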



   .. py:method:: _apply_configurations(base_config: dict, dataset_definition: dict) -> dict
      :staticmethod:


      Merge the original base config with the dataset-specific config.

      This function uses ``ConfigManager.merge_configs`` to merge the
      dataset-specific configuration into a copy of the original base config.

      If no ``dataset_config`` is provided in the ``dataset_definition`` dict,
      the original base config will be returned unmodified.

      Data request dictionary examples:

      1) Requesting a built-in Hyrax dataset, "MyDataset"

      .. code-block:: python

          data_request = {
              "my_dataset": {
                  "dataset_class": "MyDataset",
                  "data_location": "/path/to/data",
                  "dataset_config": {
                      "MyDataset": {
                          "param1": "value1",
                          "param2": "value2",
                      }
                  },
                  "fields": ["field1", "field2"],
              }
          }

      or equivalently in a .toml file:

      .. code-block:: toml

          [data_request]
          [data_request.my_dataset]
          dataset_class = "MyDataset"
          data_location = "/path/to/data"
          fields = ["field1", "field2"]
          [data_request.my_dataset.dataset_config.MyDataset]
          param1 = "value1"
          param2 = "value2"

      Here the ``dataset_config`` dictionary will be merged into
      the original base config, overriding the values of param1 and param2
      when creating an instance of ``MyDataset``.

      2) Requesting an external (not built-in) dataset, "ExternalDataset"

      Note that the dictionary nesting under ``dataset_config`` should mirror the
      dictionary structure of the external dataset's ``default_config`` file.

      .. code-block:: python

          data_request = {
              "my_dataset": {
                  "dataset_class": "ExternalDataset",
                  "data_location": "/path/to/data",
                  "dataset_config": {
                      "external_example": {
                          "ExternalDataset": {
                              "param1": "value1",
                              "param2": "value2",
                          },
                      },
                  },
                  "fields": ["field1", "field2"],
              }
          }

      or equivalently in a .toml file:

      .. code-block:: toml

          [data_request]
          [data_request.my_dataset]
          dataset_class = "ExternalDataset"
          data_location = "/path/to/data"
          fields = ["field1", "field2"]
          [data_request.my_dataset.dataset_config.external_example.ExternalDataset]
          param1 = "value1"
          param2 = "value2"

      Here the ``dataset_config`` dictionary will be merged into
      the original base config, overriding the values of param1 and param2
      when creating an instance of ``ExternalDataset``.

      :param base_config: The original base configuration dictionary. A copy of it is made,
                          the ``dataset_config`` from ``dataset_definition`` is merged into
                          the copy, and the copy is returned.
      :type base_config: dict
      :param dataset_definition: A dictionary defining the dataset, including any dataset-specific
                                 configuration options in a nested ``dataset_config`` dictionary.
      :type dataset_definition: dict

      :returns: A final configuration dictionary to be passed when creating an instance
                of the dataset class.
      :rtype: dict
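
      The merge described above can be approximated with a recursive dictionary
      merge in which dataset-specific values win. ``merge_configs_sketch`` is a
      hypothetical stand-in for ``ConfigManager.merge_configs``, whose exact
      behavior (for example, any key validation) is not assumed here.

      .. code-block:: python

          import copy


          def merge_configs_sketch(base, override):
              # Stand-in for ConfigManager.merge_configs: recursive, override wins.
              merged = copy.deepcopy(base)
              for key, value in override.items():
                  if isinstance(value, dict) and isinstance(merged.get(key), dict):
                      merged[key] = merge_configs_sketch(merged[key], value)
                  else:
                      merged[key] = copy.deepcopy(value)
              return merged


          def apply_configurations_sketch(base_config, dataset_definition):
              # No dataset_config: hand back the base config unmodified.
              dataset_config = dataset_definition.get("dataset_config")
              if not dataset_config:
                  return base_config
              return merge_configs_sketch(base_config, dataset_config)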



   .. py:method:: sample_data() -> dict

      Return a data sample. This is primarily used when instantiating a model,
      so that any runtime resizing can be handled properly.

      :returns: A dictionary containing the data for index 0.
      :rtype: dict



   .. py:method:: get_object_id(idx: int) -> str

      Returns the ID at a particular index.

      IDs are provided by the primary dataset's primary ID column.



   .. py:method:: ids() -> list[str]

      Returns the IDs of the primary dataset.

      :returns: A list of string IDs corresponding to the primary dataset, ordered by index.
      :rtype: list of str



   .. py:method:: resolve_data(idx: int) -> dict[str, dict[str, Any] | str | None]

      This method requests the field data from the prepared datasets by index.

      For joined secondary datasets (those with ``join_field``), the primary
      dataset's object ID is looked up in the secondary's reverse map.  If a
      match exists the secondary's data is returned normally; if no match
      exists the friendly-name entry is set to ``None`` (left outer join).

      :param idx: The index of the data item to retrieve.
      :type idx: int

      :returns: A dictionary containing the requested data from the prepared datasets.
                Each key is a dataset friendly name mapped to a dict of field values,
                or ``None`` when a joined secondary has no match for this item.
                If a primary dataset is configured, the top-level ``"object_id"`` key
                holds a string representation of the primary ID.
      :rtype: dict[str, dict[str, Any] | str | None]
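
      The left-outer-join lookup can be sketched with plain dictionaries and
      lists standing in for datasets. The helper name and all data layouts below
      are hypothetical.

      .. code-block:: python

          def resolve_sketch(idx, datasets, requests, join_maps, get_id):
              """Hypothetical helper mirroring the left-outer-join lookup.

              datasets:  {friendly_name: {field: list_of_values}}
              requests:  {friendly_name: [field, ...]}
              join_maps: {friendly_name: {str(key): index}} for joined secondaries only
              get_id:    returns the primary object ID for a primary index
              """
              object_id = str(get_id(idx))
              result = {"object_id": object_id}
              for name, fields in requests.items():
                  if name in join_maps:
                      row = join_maps[name].get(object_id)  # None when there is no match
                      if row is None:
                          result[name] = None
                          continue
                  else:
                      row = idx  # non-joined datasets share the primary's indexing
                  result[name] = {field: datasets[name][field][row] for field in fields}
              return result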



   .. py:method:: metadata(idxs=None, fields=None) -> numpy.ndarray

      Fetch the requested metadata fields for the given indices.

      Example:

      .. code-block:: python

          # Fetch the metadata_1 and metadata_2 fields from the dataset with the
          # friendly name "random_1".

          metadata = data_provider.metadata(
              idxs=[0, 1, 2],
              fields=["metadata_1_random_1", "metadata_2_random_1"]
          )

      :param idxs: A list of indices for which to fetch metadata. If None, no metadata
                   will be returned.
      :type idxs: list of int, optional
      :param fields: A list of metadata fields to fetch. If None, no metadata will be
                     returned.
      :type fields: list of str, optional

      :returns: A structured NumPy array containing the requested metadata fields.
                The dtype names of the array will be the metadata field names, modified
                to include the friendly name of the dataset they come from. For example,
                if the "RA" field comes from a dataset with the friendly name "cifar",
                the returned field name will be "RA_cifar".
      :rtype: np.ndarray
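
      One way to assemble the suffixed structured array is sketched below,
      assuming each dataset's metadata is available as a NumPy structured array
      keyed by friendly name. The helper is hypothetical and ignores join
      translation.

      .. code-block:: python

          import numpy as np


          def metadata_sketch(tables, idxs, fields):
              """Hypothetical helper: tables maps friendly names to structured arrays."""
              # Collect the requested columns, renamed to <field>_<friendly_name>.
              columns = []
              for friendly, table in tables.items():
                  for name in table.dtype.names:
                      suffixed = f"{name}_{friendly}"
                      if suffixed in fields:
                          columns.append((suffixed, friendly, name, table.dtype[name]))
              out = np.empty(len(idxs), dtype=[(suffixed, dt) for suffixed, _, _, dt in columns])
              # Copy the selected rows of each column into the combined array.
              for suffixed, friendly, name, _ in columns:
                  out[suffixed] = tables[friendly][name][idxs]
              return out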



   .. py:method:: metadata_fields(friendly_name=None) -> list[str]

      Returns a list of metadata fields that are available across all prepared
      datasets.

      The field names will be modified to include the friendly name of the
      dataset they come from. For example, if the "RA" field comes from a dataset
      with the friendly name "cifar", the returned field name will be "RA_cifar".

      NOTE: If a specific dataset friendly_name is provided, only the metadata
      fields for that dataset will be returned, and the field names will not
      include the friendly name suffix.

      :param friendly_name: If provided, only the metadata fields for the specified friendly name
                            will be returned. If not provided, metadata fields from all datasets
                            will be returned.
      :type friendly_name: str, optional

      :returns: The available metadata field names, or an empty list if no metadata
                was provided during construction of the DataProvider.
      :rtype: list[str]



   .. py:method:: _translate_metadata_indices(idxs, friendly_name)

      Translate primary indices to real dataset indices for metadata.

      For joined secondaries, looks up the matching secondary index via the
      join map.  Indices with no match in the secondary are omitted (the
      caller receives fewer rows than requested for that secondary).

      Returns a tuple ``(translated_idxs, mask)`` where *mask* is a boolean
      list of the same length as *idxs* indicating which positions had a
      valid match.  Non-joined datasets always return a full-True mask.
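
      A sketch of the index translation and mask, assuming the primary IDs are
      available as a sequence and the join map has the ``{str(key): index}``
      form described for ``_build_join_indices``. The helper name and signature
      are hypothetical.

      .. code-block:: python

          def translate_indices_sketch(idxs, primary_ids, join_map=None):
              """Hypothetical helper returning (translated_idxs, mask).

              primary_ids: sequence of primary object IDs, indexed like the primary dataset.
              join_map:    {str(key): secondary_index}, or None for a non-joined dataset.
              """
              if join_map is None:
                  # Non-joined datasets share the primary's indexing: everything matches.
                  return list(idxs), [True] * len(idxs)
              translated, mask = [], []
              for idx in idxs:
                  row = join_map.get(str(primary_ids[idx]))
                  mask.append(row is not None)
                  if row is not None:
                      translated.append(row)  # unmatched indices are omitted
              return translated, mask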



   .. py:method:: _primary_or_first_dataset()

      Returns the primary dataset instance if it exists, otherwise returns
      the first dataset in the prepped_datasets.



   .. py:method:: collate(batch: list[dict]) -> dict

      Custom collate function to be used outside the context of a PyTorch
      DataLoader.

      This function takes a list of data samples (each sample is a dictionary)
      and combines them into a single batch dictionary.

      :param batch: A list of data samples, where each sample is a dictionary.
      :type batch: list of dict

      :returns: A dictionary where each key corresponds to a field and the value is
                a list of values for that field across the batch.
      :rtype: dict
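
      A minimal sketch of turning a list of per-sample dictionaries into a
      dictionary of lists, recursing into nested per-dataset dictionaries. It
      assumes every sample has the same structure; ``None`` entries from
      unmatched joins and array stacking are not handled, and how the real
      method treats those cases is not assumed here.

      .. code-block:: python

          def collate_sketch(batch):
              # Build {key: [value_0, value_1, ...]} across the samples in the batch.
              out = {}
              for key in batch[0]:
                  values = [sample[key] for sample in batch]
                  if isinstance(values[0], dict):
                      # Per-dataset dictionaries are collated field by field.
                      out[key] = collate_sketch(values)
                  else:
                      out[key] = values
              return out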



   .. py:method:: handle_nans(batch_dict)

      Apply NaN handling to a batch dictionary.

      :param batch_dict: Dictionary mapping each data column to an entire batch of data as an ``np.ndarray``.
      :type batch_dict: dict[str, np.ndarray]

      :returns: The same batch dict but with NaNs altered according to the Hyrax configuration.
      :rtype: dict[str, np.ndarray]



