hyrax.models.hyrax_cnn
======================

.. py:module:: hyrax.models.hyrax_cnn


Attributes
----------

.. autoapisummary::

   hyrax.models.hyrax_cnn.logger


Classes
-------

.. autoapisummary::

   hyrax.models.hyrax_cnn.HyraxCNN


Module Contents
---------------

.. py:data:: logger

.. py:class:: HyraxCNN(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This CNN is designed to work with datasets that are prepared with Hyrax's HSC Data Set class.



   .. py:attribute:: config


   .. py:attribute:: conv1


   .. py:attribute:: pool


   .. py:attribute:: conv2


   .. py:attribute:: fc1


   .. py:attribute:: fc2


   .. py:attribute:: fc3


   .. py:method:: conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int
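
      A minimal standalone sketch, assuming this method implements the one-axis
      output-size formula from the ``torch.nn.Conv2d`` documentation. The real
      method is defined on the class, and its body may differ.

      .. code-block:: python

         def conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1):
             """Output size of a 2D convolution along one spatial axis.

             floor((input_size + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
             """
             # Python's // is floor division, matching the floor in the formula.
             return (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


         # A 5x5 kernel with no padding shrinks a 32-pixel axis to 28.
         print(conv2d_output_size(32, 5))  # 28

      Calling it once for the height and once for the width gives the 2D output shape.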


   .. py:method:: pool2d_output_size(input_size, kernel_size, stride, padding=0, dilation=1) -> int
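
      Pooling layers use the same one-axis formula as convolutions (this matches
      ``torch.nn.MaxPool2d`` with ``ceil_mode=False``). Unlike the convolution
      helper, ``stride`` has no default in this signature. A minimal standalone
      sketch under that assumption:

      .. code-block:: python

         def pool2d_output_size(input_size, kernel_size, stride, padding=0, dilation=1):
             """Output size of a 2D pooling layer along one spatial axis (floor mode)."""
             return (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


         # A 2x2 window with stride 2 halves an even-sized axis ...
         print(pool2d_output_size(28, 2, 2))  # 14
         # ... and rounds down on an odd-sized one.
         print(pool2d_output_size(7, 2, 2))  # 3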


   .. py:method:: forward(x)
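
      The attribute names ``conv1``, ``pool``, ``conv2``, ``fc1``, ``fc2`` and
      ``fc3`` match the small LeNet-style network from the PyTorch CIFAR-10
      tutorial. The sketch below assumes ``forward`` follows that pattern; the
      class name ``LeNetStyleCNN``, the ``conv_out`` helper, and all channel
      counts and kernel sizes are hypothetical. It also shows why output-size
      helpers are useful: ``fc1`` needs the flattened feature count left after
      the second pooling step.

      .. code-block:: python

         import torch
         import torch.nn as nn
         import torch.nn.functional as F


         def conv_out(size, kernel, padding=0, stride=1, dilation=1):
             # One-axis output size shared by Conv2d and MaxPool2d (ceil_mode=False).
             return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


         class LeNetStyleCNN(nn.Module):
             """Hypothetical stand-in for HyraxCNN; all layer sizes are illustrative."""

             def __init__(self, in_channels=3, height=32, width=32, num_classes=10):
                 super().__init__()
                 self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)
                 self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
                 self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
                 # fc1 must match the flattened size after conv1 -> pool -> conv2 -> pool.
                 self.fc1 = nn.Linear(16 * self._trace(height) * self._trace(width), 120)
                 self.fc2 = nn.Linear(120, 84)
                 self.fc3 = nn.Linear(84, num_classes)

             @staticmethod
             def _trace(size):
                 size = conv_out(size, 5)            # conv1
                 size = conv_out(size, 2, stride=2)  # pool
                 size = conv_out(size, 5)            # conv2
                 return conv_out(size, 2, stride=2)  # pool

             def forward(self, x):
                 # x has shape (batch, channels, height, width).
                 x = self.pool(F.relu(self.conv1(x)))
                 x = self.pool(F.relu(self.conv2(x)))
                 x = torch.flatten(x, 1)  # flatten everything except the batch dimension
                 x = F.relu(self.fc1(x))
                 x = F.relu(self.fc2(x))
                 return self.fc3(x)  # raw logits, one per class


         model = LeNetStyleCNN()
         logits = model(torch.randn(4, 3, 32, 32))
         print(logits.shape)  # torch.Size([4, 10])

      For a 32x32 input each spatial axis shrinks 32, 28, 14, 10, 5 through the
      four layers, so ``fc1`` receives 16 * 5 * 5 = 400 features.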


   .. py:method:: train_step(batch)

      This method contains the logic for a single training step, i.e. the
      contents of the inner loop of an ML training process.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict
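
      A minimal sketch of such a step, assuming a standard PyTorch training
      loop. In the documented method the model, optimizer and loss function
      presumably come from ``self``; here they are passed in explicitly as an
      assumption, and the toy ``nn.Linear`` model and random batch are purely
      illustrative.

      .. code-block:: python

         import torch
         import torch.nn as nn


         def train_step(model, optimizer, criterion, batch):
             """One inner-loop iteration: forward, loss, backward, parameter update."""
             inputs, labels = batch
             optimizer.zero_grad()              # clear gradients from the previous step
             outputs = model(inputs)            # forward pass
             loss = criterion(outputs, labels)  # scalar loss tensor
             loss.backward()                    # backpropagate
             optimizer.step()                   # update parameters
             return {"loss": loss.item()}       # plain Python float, convenient for logging


         torch.manual_seed(0)
         model = nn.Linear(8, 3)
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
         criterion = nn.CrossEntropyLoss()
         batch = (torch.randn(16, 8), torch.randint(0, 3, (16,)))
         result = train_step(model, optimizer, criterion, batch)
         print(result)

      Returning ``loss.item()`` rather than the tensor detaches the value from the
      autograd graph, so the dictionary can be logged without holding on to it.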



   .. py:method:: to_tensor(data_dict) -> tuple
      :staticmethod:


      Despite its name, this method does NOT convert data to PyTorch tensors.
      It works exclusively with NumPy data types and returns a tuple of
      NumPy data types.
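
      A hypothetical sketch of that contract: take a sample dictionary and return
      a tuple of NumPy values without importing PyTorch. The ``"image"`` and
      ``"label"`` keys and the ``float32`` cast are assumptions for illustration;
      the real layout of ``data_dict`` depends on the dataset class.

      .. code-block:: python

         import numpy as np


         def to_tensor(data_dict):
             """Unpack a sample dict into a tuple of NumPy values (hypothetical keys)."""
             # Assumed cast: float32 matches PyTorch's default floating-point dtype.
             image = np.asarray(data_dict["image"], dtype=np.float32)
             label = np.asarray(data_dict["label"])
             # No torch import: any conversion to tensors is left to the caller.
             return image, label


         sample = {"image": np.zeros((3, 32, 32)), "label": 7}
         image, label = to_tensor(sample)
         print(type(image).__name__, image.dtype, image.shape)  # ndarray float32 (3, 32, 32)

      Keeping the method framework-agnostic means a caller that does need tensors
      can wrap the arrays itself, for example with ``torch.from_numpy``.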



