hyrax.models
============

.. py:module:: hyrax.models


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/hyrax/models/hsc_autoencoder/index
   /autoapi/hyrax/models/hsc_dcae/index
   /autoapi/hyrax/models/hyrax_autoencoder/index
   /autoapi/hyrax/models/hyrax_autoencoderv2/index
   /autoapi/hyrax/models/hyrax_cnn/index
   /autoapi/hyrax/models/hyrax_loopback/index
   /autoapi/hyrax/models/image_dcae/index
   /autoapi/hyrax/models/model_registry/index
   /autoapi/hyrax/models/model_utils/index
   /autoapi/hyrax/models/simclr/index


Classes
-------

.. autoapisummary::

   hyrax.models.HSCAutoencoder
   hyrax.models.HSCDCAE
   hyrax.models.ImageDCAE
   hyrax.models.HyraxAutoencoder
   hyrax.models.HyraxAutoencoderV2
   hyrax.models.HyraxCNN
   hyrax.models.HyraxLoopback
   hyrax.models.SimCLR


Functions
---------

.. autoapisummary::

   hyrax.models.hyrax_model


Package Contents
----------------

.. py:class:: HSCAutoencoder(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This autoencoder is designed to work with datasets that are prepared with Hyrax's HSC Data Set class.



   .. py:attribute:: encoder


   .. py:attribute:: decoder


   .. py:attribute:: config


   .. py:method:: forward(x)


   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step, i.e. the
      contents of the inner loop of an ML training process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML validation process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML testing process. In this case, it is identical to `validate_batch`.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML inference process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Reconstructed outputs** -- The reconstructed outputs from the autoencoder.
      :rtype: torch.Tensor

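   The ``*_batch`` methods are designed to be called once per batch by an outer
   training or evaluation loop. The following framework-free sketch shows that
   calling pattern, assuming the returned dictionary stores the loss under a
   ``"loss"`` key; ``StubModel`` and ``run_epoch`` are hypothetical stand-ins,
   not part of Hyrax.

   .. code-block:: python

      from statistics import mean


      class StubModel:
          """Hypothetical stand-in exposing the same batch-method interface."""

          def train_batch(self, batch):
              # Unpack the inputs; any labels in the tuple are ignored.
              inputs, *_labels = batch
              # Placeholder "loss" (mean absolute input); the "loss" key is an assumption.
              return {"loss": mean(abs(v) for v in inputs)}

          def validate_batch(self, batch):
              # A real model would skip the optimizer step here; the stub reuses train_batch.
              return self.train_batch(batch)


      def run_epoch(model, batches, step):
          """Call ``step`` (e.g. ``model.train_batch``) on every batch and average the losses."""
          return mean(step(batch)["loss"] for batch in batches)


      model = StubModel()
      batches = [([1.0, -3.0], "label-a"), ([2.0, 2.0], "label-b")]
      print(run_epoch(model, batches, model.train_batch))  # 2.0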


.. py:class:: HSCDCAE(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This autoencoder is designed to work with datasets that are prepared with Hyrax's HSC Data Set class.



   .. py:attribute:: encoder1


   .. py:attribute:: encoder2


   .. py:attribute:: encoder3


   .. py:attribute:: encoder4


   .. py:attribute:: pool


   .. py:attribute:: decoder4


   .. py:attribute:: decoder3


   .. py:attribute:: decoder2


   .. py:attribute:: decoder1


   .. py:attribute:: activation


   .. py:attribute:: config


   .. py:method:: forward(x)


   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step, i.e. the
      contents of the inner loop of an ML training process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML validation process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML testing process. In this case, it is identical to `validate_batch`.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML inference process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Reconstructed outputs** -- The reconstructed outputs from the autoencoder.
      :rtype: torch.Tensor



.. py:class:: ImageDCAE(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This is an autoencoder with skip connections that should work with
   arbitrarily sized images with an arbitrary number of channels.



   .. py:attribute:: input_shape


   .. py:attribute:: config


   .. py:attribute:: latent_dim


   .. py:attribute:: base_channel_size


   .. py:attribute:: conv_output_size


   .. py:attribute:: encoder1


   .. py:attribute:: encoder2


   .. py:attribute:: encoder3


   .. py:attribute:: encoder4


   .. py:attribute:: pool


   .. py:attribute:: latent_encoder


   .. py:attribute:: latent_decoder


   .. py:attribute:: decoder4


   .. py:attribute:: decoder3


   .. py:attribute:: decoder2


   .. py:attribute:: decoder1


   .. py:attribute:: activation


   .. py:method:: _calculate_conv_output_size()

      Calculate the output size after all convolutional layers for the linear bottleneck.

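      As an illustration only, the calculation can be sketched as below. The sketch
      assumes, hypothetically, four encoder stages with a 2x2, stride-2 pool between
      consecutive stages and a channel count that doubles at each stage starting from
      ``base_channel_size``; the actual layer configuration is not documented here.

      .. code-block:: python

         def bottleneck_size(height, width, base_channel_size, num_stages=4):
             """Flattened feature size feeding a linear bottleneck (hypothetical layout)."""
             channels = base_channel_size
             for _ in range(num_stages - 1):
                 # Assumed: a 2x2 pool with stride 2 halves each spatial dimension, rounding down.
                 height, width = height // 2, width // 2
                 # Assumed: each subsequent stage doubles the channel count.
                 channels *= 2
             return channels * height * width


         print(bottleneck_size(64, 64, base_channel_size=32))  # 256 * 8 * 8 = 16384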


   .. py:method:: encode(x)

      Encode input to latent space with skip connections.



   .. py:method:: decode(latent, skip_connections, encoded_shape)

      Decode from latent space to image with skip connections.



   .. py:method:: forward(x)

      Forward pass. Returns the latent representation used for anomaly detection.



   .. py:method:: reconstruct(x)

      Full reconstruction for evaluation and anomaly detection.

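      Reconstruction-based anomaly detection usually scores each sample by how poorly
      it is reconstructed. A minimal sketch of that scoring step, using plain Python
      lists in place of tensors; ``anomaly_scores`` is a hypothetical helper and mean
      squared error is an assumed choice of metric:

      .. code-block:: python

         def anomaly_scores(originals, reconstructions):
             """Per-sample mean squared reconstruction error; larger means more anomalous."""
             # Hypothetical helper; MSE is an assumed metric, not necessarily Hyrax's.
             scores = []
             for original, recon in zip(originals, reconstructions):
                 squared = [(a - b) ** 2 for a, b in zip(original, recon)]
                 scores.append(sum(squared) / len(squared))
             return scores


         # Two flattened "images"; the second one is reconstructed poorly.
         originals = [[0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]]
         reconstructions = [[0.0, 1.0, 0.9, 0.1], [0.0, 0.0, 1.0, 1.0]]
         print(anomaly_scores(originals, reconstructions))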


   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data. In this case, it is identical to `validate_batch`.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Reconstructed images** -- Tensor containing the reconstructed images for the current batch.
      :rtype: torch.Tensor



   .. py:method:: prepare_inputs(data_dict)
      :staticmethod:


      Convert structured data to tensor format.



.. py:class:: HyraxAutoencoder(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This autoencoder is designed to work with a wide range of image datasets to allow testing.

   This example model is taken from this
   `autoencoder tutorial <https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial9/AE_CIFAR10.html>`_.

   The train function has been converted into ``train_batch`` for use with PyTorch-Ignite.



   .. py:attribute:: config


   .. py:attribute:: c_hid


   .. py:attribute:: latent_dim


   .. py:attribute:: conv_end_w


   .. py:attribute:: conv_end_h


   .. py:method:: conv2d_multi_layer(input_size, num_applications, **kwargs) -> int


   .. py:method:: conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int

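      A minimal sketch, assuming this method implements the standard output-size
      formula from the PyTorch ``torch.nn.Conv2d`` documentation, and that
      ``conv2d_multi_layer`` (listed above without a description) applies it
      ``num_applications`` times with the same keyword arguments:

      .. code-block:: python

         def conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1):
             """Spatial output size of a 2D convolution along one dimension."""
             # Assumed to match the torch.nn.Conv2d output-size formula.
             return (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


         def conv2d_multi_layer(input_size, num_applications, **kwargs):
             """Apply conv2d_output_size repeatedly, as for a stack of identical conv layers."""
             size = input_size
             for _ in range(num_applications):
                 size = conv2d_output_size(size, **kwargs)
             return size


         # A 3x3 kernel with padding=1 and stride=2 halves a 32-pixel side: 32 -> 16.
         print(conv2d_output_size(32, kernel_size=3, padding=1, stride=2))  # 16
         # Three such layers: 32 -> 16 -> 8 -> 4.
         print(conv2d_multi_layer(32, 3, kernel_size=3, padding=1, stride=2))  # 4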

   .. py:method:: _init_encoder()


   .. py:method:: _eval_encoder(x)


   .. py:method:: _init_decoder()


   .. py:method:: _eval_decoder(x)


   .. py:method:: forward(batch)


   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step, i.e. the
      contents of the inner loop of an ML training process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML validation process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML testing process. In this case, it is identical to `validate_batch`.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step, i.e. the
      contents of the inner loop of an ML inference process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Reconstructed inputs** -- The reconstructed inputs from the autoencoder.
      :rtype: torch.Tensor



   .. py:method:: prepare_inputs(data_dict) -> tuple
      :staticmethod:


      Convert structured data into the input tensors needed to run the model.

      :param data_dict: The dictionary returned by the data source.
      :type data_dict: dict

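      A minimal sketch of the kind of conversion this performs, assuming a data
      dictionary with hypothetical ``"image"`` and optional ``"label"`` keys. The real
      keys depend on the data source, and the real method produces tensors rather than
      the plain lists used here:

      .. code-block:: python

         def prepare_inputs(data_dict):
             """Pull the model input (and label, if any) out of a structured record."""
             # "image" and "label" are hypothetical keys used only for illustration.
             image = data_dict["image"]
             label = data_dict.get("label")
             return (image, label)


         # Fields the model does not need (here "object_id") are simply ignored.
         record = {"image": [[0.1, 0.2], [0.3, 0.4]], "label": 7, "object_id": "abc"}
         inputs, label = prepare_inputs(record)
         print(inputs, label)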


   .. py:method:: _optimizer()


.. py:class:: HyraxAutoencoderV2(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This is a tweaked version of HyraxAutoencoder, designed to work with a wide range of imaging datasets.

   V2 improvements:

   - Configurable final layer activation
   - Uses criterion and optimizer from config variables

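   Selecting the criterion and optimizer from configuration usually means resolving
   a class name stored as a string. The sketch below shows one common approach using
   the standard-library ``importlib`` module; the config layout and the helper
   ``load_from_config`` are hypothetical, not Hyrax's actual keys or functions. A
   standard-library class stands in for something like ``torch.nn.MSELoss`` so the
   example runs without PyTorch.

   .. code-block:: python

      import importlib


      def load_from_config(dotted_name, *args, **kwargs):
          """Import the class named by ``dotted_name`` and instantiate it (hypothetical helper)."""
          # Split "package.module.ClassName" into a module path and a class name.
          module_name, _, class_name = dotted_name.rpartition(".")
          cls = getattr(importlib.import_module(module_name), class_name)
          # Positional args could carry, e.g., model parameters for an optimizer.
          return cls(*args, **kwargs)


      # Hypothetical config entry. A real one might name "torch.nn.MSELoss"; a
      # standard-library class is used here so the sketch runs without PyTorch.
      entry = {"name": "datetime.timedelta", "kwargs": {"seconds": 90}}
      obj = load_from_config(entry["name"], **entry["kwargs"])
      print(obj)  # 0:01:30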


   .. py:attribute:: config


   .. py:attribute:: c_hid


   .. py:attribute:: latent_dim


   .. py:attribute:: conv_end_w


   .. py:attribute:: conv_end_h


   .. py:attribute:: band_reduction


   .. py:method:: conv2d_multi_layer(input_size, num_applications, **kwargs) -> int


   .. py:method:: conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int


   .. py:method:: _init_encoder()


   .. py:method:: _eval_encoder(x)


   .. py:method:: _init_decoder()


   .. py:method:: _eval_decoder(x)


   .. py:method:: forward(batch)


   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step, i.e. the
      contents of the inner loop of an ML training process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML validation process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML testing process. In this case, it is identical to `validate_batch`.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step, i.e. the
      contents of the inner loop of an ML inference process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Reconstructed outputs** -- The reconstructed outputs from the autoencoder.
      :rtype: torch.Tensor



   .. py:method:: prepare_inputs(data_dict) -> tuple
      :staticmethod:


      Convert structured data into the input tensors needed to run the model.

      :param data_dict: The dictionary returned by the data source.
      :type data_dict: dict



.. py:class:: HyraxCNN(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This CNN is designed to work with datasets that are prepared with Hyrax's HSC Data Set class.



   .. py:attribute:: config


   .. py:attribute:: conv1


   .. py:attribute:: pool


   .. py:attribute:: conv2


   .. py:attribute:: fc1


   .. py:attribute:: fc2


   .. py:attribute:: fc3


   .. py:method:: conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int


   .. py:method:: pool2d_output_size(input_size, kernel_size, stride, padding=0, dilation=1) -> int

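      ``conv2d_output_size`` and ``pool2d_output_size`` are presumably combined to size
      the first fully connected layer. A minimal sketch of that calculation, assuming the
      standard formulas from the PyTorch ``torch.nn.Conv2d`` and ``torch.nn.MaxPool2d``
      documentation and a hypothetical layout of conv (5x5), pool (2x2, stride 2),
      conv (5x5), pool (2x2, stride 2) with 16 output channels; the actual kernel sizes
      and channel counts are not documented here:

      .. code-block:: python

         def conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1):
             # Standard torch.nn.Conv2d output-size formula, along one dimension.
             return (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


         def pool2d_output_size(input_size, kernel_size, stride, padding=0, dilation=1):
             # torch.nn.MaxPool2d uses the same formula when ceil_mode=False.
             return (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


         def fc1_in_features(side, out_channels=16):
             """Flattened size after conv -> pool -> conv -> pool (hypothetical layout)."""
             side = pool2d_output_size(conv2d_output_size(side, 5), kernel_size=2, stride=2)
             side = pool2d_output_size(conv2d_output_size(side, 5), kernel_size=2, stride=2)
             return out_channels * side * side


         print(fc1_in_features(32))  # 16 * 5 * 5 = 400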

   .. py:method:: forward(x)


   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML training process.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML validation process. In this case, it is identical to `test_batch`.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML testing process. In this case, it is identical to `validate_batch`.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML inference process.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Model outputs** -- Tensor containing the model outputs for the current batch.
      :rtype: torch.Tensor



   .. py:method:: prepare_inputs(data_dict) -> tuple
      :staticmethod:


      Does NOT convert to PyTorch tensors.
      This method works exclusively with NumPy data types and returns
      a tuple of NumPy data types.



.. py:class:: HyraxLoopback(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   Simple model for testing that returns its own input.



   .. py:attribute:: unused_module


   .. py:attribute:: config


   .. py:attribute:: load


   .. py:method:: forward(x)

      Return the input unchanged.



   .. py:method:: train_batch(batch)

      Training is a no-op.



   .. py:method:: validate_batch(batch)

      Validation is just a forward pass.



   .. py:method:: test_batch(batch)

      Testing is just a forward pass.



   .. py:method:: infer_batch(batch)

      Inference is just a forward pass.



.. py:function:: hyrax_model(cls)

   Decorator that registers a model class with the model registry and adds common interface functions to it.

   :param cls: The model class to decorate.
   :type cls: type

   :returns: The class with additional interface functions.
   :rtype: type

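   A registering decorator of this kind typically records the class in a
   module-level mapping and attaches shared methods before returning the class.
   The following is a minimal sketch of that general pattern only;
   ``MODEL_REGISTRY``, ``register_model`` and the injected ``describe`` method are
   hypothetical and do not reflect Hyrax's actual registry or the functions it adds.

   .. code-block:: python

      # Hypothetical registry; not Hyrax's implementation.
      MODEL_REGISTRY = {}


      def register_model(cls):
          """Register ``cls`` by name and add a shared helper method to it."""
          if cls.__name__ in MODEL_REGISTRY:
              raise ValueError(f"Model {cls.__name__!r} is already registered")
          MODEL_REGISTRY[cls.__name__] = cls

          # Only add the common method if the class does not define its own.
          if not hasattr(cls, "describe"):
              def describe(self):
                  return f"{type(self).__name__} (registered model)"

              cls.describe = describe
          return cls


      @register_model
      class TinyModel:
          pass


      print(MODEL_REGISTRY["TinyModel"] is TinyModel)  # True
      print(TinyModel().describe())  # TinyModel (registered model)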

.. py:class:: SimCLR(config, shape)

   Bases: :py:obj:`torch.nn.Module`


   SimCLR model. The implementation is based on Chen et al. (2020), *A Simple Framework for Contrastive Learning of Visual Representations*.

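   SimCLR, as described by Chen et al., is trained with the normalized
   temperature-scaled cross-entropy (NT-Xent) loss, which pulls the two augmented
   views of each image together and pushes apart views of different images. The
   ``criterion`` attribute is not documented here, so the following pure-Python
   sketch of NT-Xent is illustrative only: ``nt_xent`` is a hypothetical helper, the
   embeddings are plain lists rather than tensors, and the temperature is an
   arbitrary choice.

   .. code-block:: python

      import math


      def cosine(u, v):
          dot = sum(a * b for a, b in zip(u, v))
          return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


      def nt_xent(view_a, view_b, temperature=0.5):
          """Illustrative NT-Xent loss; view_a[i] and view_b[i] form a positive pair."""
          z = view_a + view_b  # 2N embeddings
          n = len(view_a)
          total = 0.0
          for k in range(2 * n):
              positive = (k + n) % (2 * n)  # index of the other view of the same image
              # Similarities to every other embedding act as the logits.
              logits = [cosine(z[k], z[m]) / temperature for m in range(2 * n) if m != k]
              pos_logit = cosine(z[k], z[positive]) / temperature
              # Cross-entropy with the positive pair as the target class.
              total += math.log(sum(math.exp(x) for x in logits)) - pos_logit
          return total / (2 * n)


      aligned = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[0.9, 0.1], [0.1, 0.9]])
      shuffled = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[0.1, 0.9], [0.9, 0.1]])
      print(aligned < shuffled)  # True

   Practical implementations compute the same quantity in a vectorized way from a
   single (2N, 2N) similarity matrix rather than looping over anchors.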


   .. py:attribute:: config


   .. py:attribute:: shape


   .. py:attribute:: backbone


   .. py:attribute:: projection_head


   .. py:attribute:: criterion


   .. py:method:: forward(x)


   .. py:method:: train_batch(x)


   .. py:method:: validate_batch(x)


   .. py:method:: test_batch(x)


   .. py:method:: infer_batch(x)

      Run inference on a batch of data.

      :param x: Input tensor of shape (batch_size, channels, height, width).
      :type x: torch.Tensor

      :returns: Output tensor of shape (batch_size, projection_dimension).
      :rtype: torch.Tensor



