hyrax.models.hyrax_cnn
======================

.. py:module:: hyrax.models.hyrax_cnn


Attributes
----------

.. autoapisummary::

   hyrax.models.hyrax_cnn.logger


Classes
-------

.. autoapisummary::

   hyrax.models.hyrax_cnn.HyraxCNN


Module Contents
---------------

.. py:data:: logger

.. py:class:: HyraxCNN(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This CNN is designed to work with datasets that are prepared with Hyrax's HSC Data Set class.

   Initialize internal Module state, shared by both nn.Module and ScriptModule.


   .. py:attribute:: config


   .. py:attribute:: conv1


   .. py:attribute:: pool


   .. py:attribute:: conv2


   .. py:attribute:: fc1


   .. py:attribute:: fc2


   .. py:attribute:: fc3


   .. py:method:: conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int


   .. py:method:: pool2d_output_size(input_size, kernel_size, stride, padding=0, dilation=1) -> int


   .. py:method:: forward(x)


   .. py:method:: train_batch(batch)

      Run a single training step on one batch of data, i.e. the body of the
      inner loop of an ML training process.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: validate_batch(batch)

      Run a single validation step on one batch of data, i.e. the body of the
      inner loop of an ML validation process. In this case it is identical to
      ``test_batch``.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      Run a single testing step on one batch of data, i.e. the body of the
      inner loop of an ML testing process. In this case it is identical to
      ``validate_batch``.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      Run a single inference step on one batch of data, i.e. the body of the
      inner loop of an ML inference process.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Model outputs** -- Tensor containing the model outputs for the current batch.
      :rtype: Tensor



   .. py:method:: prepare_inputs(data_dict) -> tuple
      :staticmethod:


      Extract image and label arrays from the batch dictionary.

      This static method is the interface between the data pipeline and the
      model. Override it on the model class to reshape or select fields from
      the collated batch to match the inputs your model expects.

      Hyrax will convert the returned arrays to PyTorch tensors and move them
      to the appropriate device automatically.

      :param data_dict: The collated batch dictionary produced by the data pipeline.
                        Expected to contain a ``"data"`` key with ``"image"`` and optionally
                        ``"label"`` fields.
      :type data_dict: dict

      :returns: **inputs** -- A tuple of ``(image, label)`` as float32 and int64 arrays respectively.
      :rtype: tuple of numpy.ndarray
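
As a rough illustration of the ``prepare_inputs`` contract described above, the following sketch builds a collated batch by hand and extracts ``(image, label)`` arrays from it. The function name ``prepare_inputs_sketch``, the array shapes, and the handling of a missing ``"label"`` field (returning ``None``) are assumptions made for this example and are not taken from the Hyrax source.

```python
import numpy as np


def prepare_inputs_sketch(data_dict):
    """Hypothetical stand-in for an overridden ``prepare_inputs``."""
    data = data_dict["data"]
    # Images are cast to float32, matching the documented return value.
    image = np.asarray(data["image"], dtype=np.float32)
    # Labels are optional; returning None when absent is an assumption.
    label = None
    if "label" in data:
        label = np.asarray(data["label"], dtype=np.int64)
    return image, label


# A made-up batch of four 3-channel 8x8 images with integer class labels.
batch = {
    "data": {
        "image": np.zeros((4, 3, 8, 8), dtype=np.float64),
        "label": [0, 1, 2, 1],
    }
}
image, label = prepare_inputs_sketch(batch)
```

Because the method is static, a sketch like this needs no model instance, so it can be exercised on a hand-built dictionary before any training run.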




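``conv2d_output_size`` and ``pool2d_output_size`` are listed above without a description. Their parameters match those of the output-size formula given in the PyTorch documentation for ``torch.nn.Conv2d`` and ``torch.nn.MaxPool2d``, and the sketch below implements that formula for one spatial dimension. Whether ``HyraxCNN`` uses exactly this arithmetic is an assumption, and ``output_size_sketch`` is a hypothetical helper name.

```python
def output_size_sketch(input_size, kernel_size, padding=0, stride=1, dilation=1):
    """Size of one spatial dimension after a convolution or pooling layer.

    Uses floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride + 1),
    the formula from the PyTorch Conv2d and MaxPool2d documentation.
    """
    numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
    return numerator // stride + 1


# Illustrative values only: a 32-pixel side through a 5x5 convolution,
# then through a 2x2 pooling layer with stride 2.
after_conv = output_size_sketch(32, kernel_size=5)
after_pool = output_size_sketch(after_conv, kernel_size=2, stride=2)
```

Chaining the calls in this way is one plausible means of deriving the input width of the first fully connected layer from the image size.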