hyrax.models.hyrax_cnn
======================

.. py:module:: hyrax.models.hyrax_cnn


Attributes
----------

.. autoapisummary::

   hyrax.models.hyrax_cnn.logger


Classes
-------

.. autoapisummary::

   hyrax.models.hyrax_cnn.HyraxCNN


Module Contents
---------------

.. py:data:: logger

.. py:class:: HyraxCNN(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This CNN is designed to work with datasets that are prepared with Hyrax's HSC Data Set class.



   .. py:attribute:: config


   .. py:attribute:: conv1


   .. py:attribute:: pool


   .. py:attribute:: conv2


   .. py:attribute:: fc1


   .. py:attribute:: fc2


   .. py:attribute:: fc3


   .. py:method:: conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int
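
      This method has no docstring. Its parameters mirror those of
      :py:class:`torch.nn.Conv2d`, so a plausible reading is that it applies the
      output-size formula from the PyTorch ``Conv2d`` documentation to a single
      spatial dimension. The standalone function below is a minimal sketch under
      that assumption, not the Hyrax implementation:

      .. code-block:: python

         def conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1):
             """Output length of a 2D convolution along one spatial dimension.

             Assumes the torch.nn.Conv2d formula:
             floor((input + 2*padding - dilation*(kernel - 1) - 1) / stride + 1)
             """
             numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
             # Python's // is floor division, matching floor() in the formula.
             return numerator // stride + 1


         print(conv2d_output_size(32, 5))             # 28: a 5x5 kernel trims 4 pixels
         print(conv2d_output_size(32, 5, padding=2))  # 32: padding 2 preserves the size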


   .. py:method:: pool2d_output_size(input_size, kernel_size, stride, padding=0, dilation=1) -> int
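
      This method has no docstring either. Assuming it follows the
      :py:class:`torch.nn.MaxPool2d` formula with ``ceil_mode=False`` (the same
      form as the convolution formula, except that ``stride`` has no default), a
      minimal standalone sketch is shown below. Note how flooring drops the last
      row or column when a 2x2, stride-2 pool meets an odd input size:

      .. code-block:: python

         def pool2d_output_size(input_size, kernel_size, stride, padding=0, dilation=1):
             """Output length of a 2D max-pool along one spatial dimension.

             Assumes the torch.nn.MaxPool2d formula with ceil_mode=False.
             """
             numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
             return numerator // stride + 1


         print(pool2d_output_size(28, kernel_size=2, stride=2))  # 14: even sizes halve
         print(pool2d_output_size(7, kernel_size=2, stride=2))   # 3: the odd last row is dropped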


   .. py:method:: forward(x)
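
      The attribute names ``conv1``, ``pool``, ``conv2``, ``fc1``, ``fc2`` and
      ``fc3`` match the LeNet-style network from the PyTorch CIFAR-10 classifier
      tutorial. Assuming HyraxCNN chains them the same way, ``forward`` would look
      roughly like the sketch below. The class name, the channel counts, the 32x32
      input and the 10 outputs are all hypothetical choices for illustration:

      .. code-block:: python

         import torch
         import torch.nn as nn
         import torch.nn.functional as F


         class LeNetStyleCNN(nn.Module):
             """Hypothetical stand-in that reuses HyraxCNN's layer names."""

             def __init__(self, in_channels=3, image_size=32, num_outputs=10):
                 super().__init__()
                 self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)
                 self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
                 self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
                 # Side length after conv1 -> pool -> conv2 -> pool (5x5 kernels, no padding).
                 side = ((image_size - 4) // 2 - 4) // 2
                 self.fc1 = nn.Linear(16 * side * side, 120)
                 self.fc2 = nn.Linear(120, 84)
                 self.fc3 = nn.Linear(84, num_outputs)

             def forward(self, x):
                 x = self.pool(F.relu(self.conv1(x)))
                 x = self.pool(F.relu(self.conv2(x)))
                 x = torch.flatten(x, 1)  # flatten everything except the batch dimension
                 x = F.relu(self.fc1(x))
                 x = F.relu(self.fc2(x))
                 return self.fc3(x)  # raw scores; no softmax applied here


         model = LeNetStyleCNN()
         out = model(torch.randn(4, 3, 32, 32))
         print(out.shape)  # torch.Size([4, 10])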


   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step that
      processes one batch of data, i.e., the body of the inner loop of an
      ML training process.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict

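      A conventional PyTorch training step with this contract can be sketched as
      follows. The ``criterion`` and ``optimizer`` attributes, the ``"loss"``
      dictionary key and the toy linear model are assumptions for illustration;
      the actual HyraxCNN may configure these differently:

      .. code-block:: python

         import torch
         import torch.nn as nn


         class TrainableModel(nn.Module):
             """Hypothetical model illustrating a train_batch-style step."""

             def __init__(self):
                 super().__init__()
                 self.fc = nn.Linear(8, 3)
                 # Assumed attributes, chosen for this sketch only.
                 self.criterion = nn.CrossEntropyLoss()
                 self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)

             def forward(self, x):
                 return self.fc(x)

             def train_batch(self, batch):
                 inputs, labels = batch
                 self.optimizer.zero_grad()  # clear gradients left from the previous step
                 loss = self.criterion(self(inputs), labels)
                 loss.backward()             # backpropagate
                 self.optimizer.step()       # update the weights
                 return {"loss": loss.item()}  # plain Python float, easy to log


         model = TrainableModel()
         batch = (torch.randn(4, 8), torch.tensor([0, 2, 1, 0]))
         result = model.train_batch(batch)
         print(result)  # a dict holding one float under "loss"

      Returning ``loss.item()`` rather than the loss tensor keeps the result
      detached from the autograd graph, so logging it does not hold memory.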


   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that
      processes one batch of data, i.e., the body of the inner loop of an
      ML validation process. Here it is identical to :py:meth:`test_batch`.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict

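      Because validation and testing share one implementation, a single sketch
      covers both. Assuming the same hypothetical ``criterion`` attribute and
      ``"loss"`` key, the step differs from training mainly in running under
      ``torch.no_grad()`` and never touching an optimizer:

      .. code-block:: python

         import torch
         import torch.nn as nn


         class EvaluableModel(nn.Module):
             """Hypothetical model illustrating validate_batch and test_batch."""

             def __init__(self):
                 super().__init__()
                 self.fc = nn.Linear(8, 3)
                 self.criterion = nn.CrossEntropyLoss()  # assumed loss function

             def forward(self, x):
                 return self.fc(x)

             def validate_batch(self, batch):
                 inputs, labels = batch
                 with torch.no_grad():  # no autograd graph; weights stay unchanged
                     loss = self.criterion(self(inputs), labels)
                 return {"loss": loss.item()}

             # Mirrors the documented behaviour: testing is identical to validation.
             test_batch = validate_batch


         model = EvaluableModel()
         batch = (torch.randn(4, 8), torch.tensor([0, 2, 1, 0]))
         print(model.validate_batch(batch))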


   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that
      processes one batch of data, i.e., the body of the inner loop of an
      ML testing process. Here it is identical to :py:meth:`validate_batch`.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step that
      processes one batch of data, i.e., the body of the inner loop of an
      ML inference process.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Model outputs** -- Tensor containing the model outputs for the current batch.
      :rtype: Tensor

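      A minimal sketch of an inference step with this contract, assuming the
      batch is an ``(inputs, labels)`` tuple whose labels are simply ignored. The
      toy model is hypothetical:

      .. code-block:: python

         import torch
         import torch.nn as nn


         class InferableModel(nn.Module):
             """Hypothetical model illustrating an infer_batch-style step."""

             def __init__(self):
                 super().__init__()
                 self.fc = nn.Linear(8, 3)

             def forward(self, x):
                 return self.fc(x)

             def infer_batch(self, batch):
                 inputs, _ = batch  # labels are not needed at inference time
                 with torch.no_grad():
                     return self(inputs)


         model = InferableModel()
         model.eval()  # switch layers such as dropout to inference behaviour
         outputs = model.infer_batch((torch.randn(5, 8), torch.zeros(5, dtype=torch.long)))
         print(outputs.shape)  # torch.Size([5, 3])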


   .. py:method:: prepare_inputs(data_dict) -> tuple
      :staticmethod:


      This method does NOT convert data to PyTorch tensors. It works
      exclusively with NumPy data types and returns a tuple of NumPy
      data types.
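
      A minimal sketch of a NumPy-only static method with this contract. The
      ``"image"`` and ``"label"`` keys and the chosen dtypes are hypothetical; the
      point is that the result is a tuple of plain ``numpy.ndarray`` objects,
      presumably leaving tensor conversion to a later stage:

      .. code-block:: python

         import numpy as np


         class ExampleModel:
             @staticmethod
             def prepare_inputs(data_dict) -> tuple:
                 """Return (images, labels) as NumPy arrays, never torch tensors.

                 The "image" and "label" keys are assumptions for this sketch.
                 """
                 images = np.asarray(data_dict["image"], dtype=np.float32)
                 labels = np.asarray(data_dict["label"], dtype=np.int64)
                 return images, labels


         sample = {"image": [[[0.0, 1.0], [2.0, 3.0]]], "label": [1]}
         images, labels = ExampleModel.prepare_inputs(sample)  # no instance needed
         print(type(images).__name__, images.shape, labels.dtype)  # ndarray (1, 2, 2) int64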



