hyrax.models
============

.. py:module:: hyrax.models


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/hyrax/models/hsc_autoencoder/index
   /autoapi/hyrax/models/hsc_dcae/index
   /autoapi/hyrax/models/hyrax_autoencoder/index
   /autoapi/hyrax/models/hyrax_autoencoderv2/index
   /autoapi/hyrax/models/hyrax_cnn/index
   /autoapi/hyrax/models/hyrax_loopback/index
   /autoapi/hyrax/models/image_dcae/index
   /autoapi/hyrax/models/model_registry/index
   /autoapi/hyrax/models/model_utils/index
   /autoapi/hyrax/models/simclr/index


Classes
-------

.. autoapisummary::

   hyrax.models.HSCAutoencoder
   hyrax.models.HSCDCAE
   hyrax.models.ImageDCAE
   hyrax.models.HyraxAutoencoder
   hyrax.models.HyraxAutoencoderV2
   hyrax.models.HyraxCNN
   hyrax.models.HyraxLoopback
   hyrax.models.SimCLR


Functions
---------

.. autoapisummary::

   hyrax.models.hyrax_model


Package Contents
----------------

.. py:class:: HSCAutoencoder(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This autoencoder is designed to work with datasets that are prepared with Hyrax's HSC Data Set class.

   Initialize internal Module state, shared by both nn.Module and ScriptModule.


   .. py:attribute:: encoder


   .. py:attribute:: decoder


   .. py:attribute:: config


   .. py:method:: forward(x)


   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step, i.e. the
      contents of the inner loop of an ML training process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict
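
      As a rough illustration of the contract described above, the sketch below
      runs one training step on a hypothetical ``TinyAutoencoder``. The layer
      sizes, the MSE criterion, the SGD optimizer, and the ``"loss"`` dictionary
      key are assumptions made for this example; they are not taken from the
      actual ``HSCAutoencoder`` implementation.

      .. code-block:: python

         import torch
         from torch import nn


         class TinyAutoencoder(nn.Module):
             """Hypothetical stand-in model, not the HSCAutoencoder architecture."""

             def __init__(self):
                 super().__init__()
                 self.encoder = nn.Linear(16, 4)
                 self.decoder = nn.Linear(4, 16)
                 # Assumed choices; a real model may read these from its config.
                 self.criterion = nn.MSELoss()
                 self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)

             def forward(self, x):
                 return self.decoder(self.encoder(x))

             def train_batch(self, batch):
                 # Labels, if present, are ignored: only the first element is used.
                 data = batch[0]
                 self.optimizer.zero_grad()
                 loss = self.criterion(self.forward(data), data)
                 loss.backward()
                 self.optimizer.step()
                 # The "loss" key is an assumption of this sketch.
                 return {"loss": loss.item()}


         model = TinyAutoencoder()
         batch = (torch.rand(8, 16), torch.zeros(8))
         result = model.train_batch(batch)

      The returned dictionary holds a plain Python float, so a training loop can
      log it without touching the computation graph.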



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML validation process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML testing process. In this case, it is identical to ``validate_batch``.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML inference process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Reconstructed outputs** -- The reconstructed outputs from the autoencoder.
      :rtype: torch.Tensor



.. py:class:: HSCDCAE(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This autoencoder is designed to work with datasets that are prepared with Hyrax's HSC Data Set class.

   Initialize internal Module state, shared by both nn.Module and ScriptModule.


   .. py:attribute:: encoder1


   .. py:attribute:: encoder2


   .. py:attribute:: encoder3


   .. py:attribute:: encoder4


   .. py:attribute:: pool


   .. py:attribute:: decoder4


   .. py:attribute:: decoder3


   .. py:attribute:: decoder2


   .. py:attribute:: decoder1


   .. py:attribute:: activation


   .. py:attribute:: config


   .. py:method:: forward(x)


   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step, i.e. the
      contents of the inner loop of an ML training process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML validation process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML testing process. In this case, it is identical to ``validate_batch``.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML inference process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Reconstructed outputs** -- The reconstructed outputs from the autoencoder.
      :rtype: torch.Tensor



.. py:class:: ImageDCAE(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This is an autoencoder with skip connections that should work with
   arbitrarily sized images with an arbitrary number of channels.

   Initialize internal Module state, shared by both nn.Module and ScriptModule.


   .. py:attribute:: input_shape


   .. py:attribute:: config


   .. py:attribute:: latent_dim


   .. py:attribute:: base_channel_size


   .. py:attribute:: conv_output_size


   .. py:attribute:: encoder1


   .. py:attribute:: encoder2


   .. py:attribute:: encoder3


   .. py:attribute:: encoder4


   .. py:attribute:: pool


   .. py:attribute:: latent_encoder


   .. py:attribute:: latent_decoder


   .. py:attribute:: decoder4


   .. py:attribute:: decoder3


   .. py:attribute:: decoder2


   .. py:attribute:: decoder1


   .. py:attribute:: activation


   .. py:method:: _calculate_conv_output_size()

      Calculate the output size after all convolutional layers for the linear bottleneck.



   .. py:method:: encode(x)

      Encode input to latent space with skip connections.



   .. py:method:: decode(latent, skip_connections, encoded_shape)

      Decode from latent space to image with skip connections.



   .. py:method:: forward(x)

      Forward pass - returns latent representation for anomaly detection.



   .. py:method:: reconstruct(x)

      Full reconstruction for evaluation and anomaly detection.



   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data. In this case, it is identical to ``validate_batch``.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Reconstructed images** -- Tensor containing the reconstructed images for the current batch.
      :rtype: torch.Tensor



   .. py:method:: prepare_inputs(data_dict)
      :staticmethod:


      Extract the image array from the batch dictionary.

      This static method is the interface between the data pipeline and the
      model. Override it on the model class to reshape or select fields from
      the collated batch to match the inputs your model expects.

      Hyrax will convert the returned array to a PyTorch tensor and move it
      to the appropriate device automatically.

      :param data_dict: The collated batch dictionary produced by the data pipeline.
                        Expected to contain a ``"data"`` key with an ``"image"`` field.
      :type data_dict: dict

      :returns: **image** -- The image array extracted from the batch.
      :rtype: numpy.ndarray
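
      A minimal sketch of such an override, assuming the ``"data"`` /
      ``"image"`` layout described above. The class name ``MyModel`` is
      hypothetical, and the cast to ``float32`` is an assumption of this
      example rather than documented behavior.

      .. code-block:: python

         import numpy as np


         class MyModel:
             """Hypothetical model class; only the static method matters here."""

             @staticmethod
             def prepare_inputs(data_dict):
                 # Key layout follows the parameter description above.
                 image = data_dict["data"]["image"]
                 # Casting to float32 is an assumption of this sketch.
                 return np.asarray(image, dtype=np.float32)


         batch = {"data": {"image": np.zeros((4, 3, 32, 32))}}
         image = MyModel.prepare_inputs(batch)

      Because the method is static, it can be called on the class without
      constructing a model instance.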



.. py:class:: HyraxAutoencoder(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This autoencoder is designed to work with a wide range of image datasets to allow testing.

   This example model is taken from this
   `autoencoder tutorial <https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial9/AE_CIFAR10.html>`_.

   The train function has been converted into ``train_batch`` for use with pytorch-ignite.

   Initialize internal Module state, shared by both nn.Module and ScriptModule.


   .. py:attribute:: config


   .. py:attribute:: c_hid


   .. py:attribute:: latent_dim


   .. py:attribute:: conv_end_w


   .. py:attribute:: conv_end_h


   .. py:method:: conv2d_multi_layer(input_size, num_applications, **kwargs) -> int


   .. py:method:: conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int
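
      A minimal sketch of what these two helpers might compute, assuming
      ``conv2d_output_size`` follows the output-size formula from the
      ``torch.nn.Conv2d`` documentation and ``conv2d_multi_layer`` applies it
      ``num_applications`` times with the same keyword arguments. Both are
      written here as plain functions; the actual methods may differ.

      .. code-block:: python

         def conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1):
             # Output-size formula from the torch.nn.Conv2d documentation,
             # for a single spatial dimension.
             numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
             return numerator // stride + 1


         def conv2d_multi_layer(input_size, num_applications, **kwargs):
             # Assumed behavior: apply the same layer geometry repeatedly.
             size = input_size
             for _ in range(num_applications):
                 size = conv2d_output_size(size, **kwargs)
             return size


         one_layer = conv2d_output_size(32, kernel_size=3, padding=1, stride=2)
         three_layers = conv2d_multi_layer(32, 3, kernel_size=3, padding=1, stride=2)

      With a 3x3 kernel, padding 1, and stride 2, each application halves a
      32-pixel side: 32, 16, 8, 4.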


   .. py:method:: _init_encoder()


   .. py:method:: _eval_encoder(x)


   .. py:method:: _init_decoder()


   .. py:method:: _eval_decoder(x)


   .. py:method:: forward(batch)


   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step, i.e. the
      contents of the inner loop of an ML training process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML validation process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML testing process. In this case, it is identical to ``validate_batch``.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step, i.e. the
      contents of the inner loop of an ML inference process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Reconstructed inputs** -- The reconstructed inputs from the autoencoder.
      :rtype: torch.Tensor



   .. py:method:: prepare_inputs(data_dict) -> tuple
      :staticmethod:


      Extract the image array from the batch dictionary.

      This static method is the interface between the data pipeline and the
      model. Override it on the model class to reshape or select fields from
      the collated batch to match the inputs your model expects.

      Hyrax will convert the returned array to a PyTorch tensor and move it
      to the appropriate device automatically.

      :param data_dict: The collated batch dictionary produced by the data pipeline.
                        Expected to contain a ``"data"`` key with an ``"image"`` field.
      :type data_dict: dict

      :returns: **image** -- The image array extracted from the batch.
      :rtype: numpy.ndarray



   .. py:method:: _optimizer()


.. py:class:: HyraxAutoencoderV2(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This is a tweaked version of HyraxAutoencoder and is designed to work with a wide range of imaging datasets.

   V2 improvements:

   - Configurable final layer activation
   - Uses criterion and optimizer from config variables

   Initialize internal Module state, shared by both nn.Module and ScriptModule.


   .. py:attribute:: config


   .. py:attribute:: c_hid


   .. py:attribute:: latent_dim


   .. py:attribute:: conv_end_w


   .. py:attribute:: conv_end_h


   .. py:attribute:: band_reduction


   .. py:method:: conv2d_multi_layer(input_size, num_applications, **kwargs) -> int


   .. py:method:: conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int


   .. py:method:: _init_encoder()


   .. py:method:: _eval_encoder(x)


   .. py:method:: _init_decoder()


   .. py:method:: _eval_decoder(x)


   .. py:method:: forward(batch)


   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step, i.e. the
      contents of the inner loop of an ML training process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML validation process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML testing process. In this case, it is identical to ``validate_batch``.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step, i.e. the
      contents of the inner loop of an ML inference process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Reconstructed outputs** -- The reconstructed outputs from the autoencoder.
      :rtype: torch.Tensor



   .. py:method:: prepare_inputs(data_dict) -> tuple
      :staticmethod:


      Extract the image array from the batch dictionary.

      This static method is the interface between the data pipeline and the
      model. Override it on the model class to reshape or select fields from
      the collated batch to match the inputs your model expects.

      Hyrax will convert the returned array to a PyTorch tensor and move it
      to the appropriate device automatically.

      :param data_dict: The collated batch dictionary produced by the data pipeline.
                        Expected to contain a ``"data"`` key with an ``"image"`` field.
      :type data_dict: dict

      :returns: **image** -- The image array extracted from the batch.
      :rtype: numpy.ndarray



.. py:class:: HyraxCNN(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This CNN is designed to work with datasets that are prepared with Hyrax's HSC Data Set class.

   Initialize internal Module state, shared by both nn.Module and ScriptModule.


   .. py:attribute:: config


   .. py:attribute:: conv1


   .. py:attribute:: pool


   .. py:attribute:: conv2


   .. py:attribute:: fc1


   .. py:attribute:: fc2


   .. py:attribute:: fc3


   .. py:method:: conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int


   .. py:method:: pool2d_output_size(input_size, kernel_size, stride, padding=0, dilation=1) -> int


   .. py:method:: forward(x)


   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML training process.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML validation process. In this case, it is identical to ``test_batch``.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML testing process. In this case, it is identical to ``validate_batch``.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step that will
      process a single batch of data, i.e. the contents of the inner loop of an
      ML inference process.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Model outputs** -- Tensor containing the model outputs for the current batch.
      :rtype: torch.Tensor



   .. py:method:: prepare_inputs(data_dict) -> tuple
      :staticmethod:


      Extract image and label arrays from the batch dictionary.

      This static method is the interface between the data pipeline and the
      model. Override it on the model class to reshape or select fields from
      the collated batch to match the inputs your model expects.

      Hyrax will convert the returned arrays to PyTorch tensors and move them
      to the appropriate device automatically.

      :param data_dict: The collated batch dictionary produced by the data pipeline.
                        Expected to contain a ``"data"`` key with ``"image"`` and optionally
                        ``"label"`` fields.
      :type data_dict: dict

      :returns: **inputs** -- A tuple of ``(image, label)`` as float32 and int64 arrays respectively.
      :rtype: tuple of numpy.ndarray
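
      The return contract above can be sketched as follows. This is a
      standalone function written for illustration, not the ``HyraxCNN``
      source; in particular, returning ``None`` when no ``"label"`` field is
      present is an assumption of this example.

      .. code-block:: python

         import numpy as np


         def prepare_inputs(data_dict):
             data = data_dict["data"]
             # Images are cast to float32, as stated in the return description.
             image = np.asarray(data["image"], dtype=np.float32)
             # The label is optional; falling back to None is an assumption.
             label = data.get("label")
             if label is not None:
                 label = np.asarray(label, dtype=np.int64)
             return image, label


         with_label = prepare_inputs({"data": {"image": [[1, 2]], "label": [3]}})
         without_label = prepare_inputs({"data": {"image": [[1, 2]]}})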



.. py:class:: HyraxLoopback(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   Simple model for testing that returns its own input.

   Initialize internal Module state, shared by both nn.Module and ScriptModule.


   .. py:attribute:: unused_module


   .. py:attribute:: config


   .. py:attribute:: load


   .. py:method:: forward(x)

      We simply return our input.



   .. py:method:: train_batch(batch)

      Training is a no-op.



   .. py:method:: validate_batch(batch)

      Validation is just a forward pass.



   .. py:method:: test_batch(batch)

      Testing is just a forward pass.



   .. py:method:: infer_batch(batch)

      Inference is just a forward pass.



.. py:function:: hyrax_model(cls)

   Decorator to register a model with the model registry and to add common interface functions.

   :returns: The class with additional interface functions.
   :rtype: type
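
   The general pattern of a registering class decorator can be sketched as
   below. ``MODEL_REGISTRY``, ``register_model``, and the ``describe`` helper
   are hypothetical names invented for this example; they do not describe the
   internals of ``hyrax_model`` or the interface functions it adds.

   .. code-block:: python

      MODEL_REGISTRY = {}  # hypothetical registry, keyed by class name


      def register_model(cls):
          """Hypothetical stand-in for a registering class decorator."""

          def describe(self):
              return f"{type(self).__name__} model"

          # Attach a shared helper only if the class does not define its own.
          if not hasattr(cls, "describe"):
              cls.describe = describe

          MODEL_REGISTRY[cls.__name__] = cls
          # Returning the class keeps the decorated name bound to it.
          return cls


      @register_model
      class Toy:
          pass


      description = Toy().describe()

   Returning the same class object, rather than a wrapper, means that
   ``isinstance`` checks and subclassing keep working on the decorated class.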


.. py:class:: SimCLR(config, shape)

   Bases: :py:obj:`torch.nn.Module`


   SimCLR model. Implementation based on Chen et al. (2020).

   Initialize internal Module state, shared by both nn.Module and ScriptModule.


   .. py:attribute:: config


   .. py:attribute:: shape


   .. py:attribute:: backbone


   .. py:attribute:: projection_head


   .. py:attribute:: criterion


   .. py:method:: forward(x)


   .. py:method:: train_batch(x)


   .. py:method:: validate_batch(x)


   .. py:method:: test_batch(x)


   .. py:method:: infer_batch(x)

      Function to run inference on a batch of data.

      :param x: Input tensor of shape (batch_size, channels, height, width).
      :type x: torch.Tensor

      :returns: Output tensor of shape (batch_size, projection_dimension).
      :rtype: torch.Tensor
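
      The shape contract above can be illustrated with a small stand-in
      network. The backbone layers, the linear projection head, and the value
      of ``projection_dimension`` are all assumptions made for this sketch;
      they are not the backbone or head used by the ``SimCLR`` class.

      .. code-block:: python

         import torch
         from torch import nn

         projection_dimension = 32  # hypothetical value

         # Hypothetical backbone: one convolution followed by global pooling.
         backbone = nn.Sequential(
             nn.Conv2d(3, 8, kernel_size=3, padding=1),
             nn.ReLU(),
             nn.AdaptiveAvgPool2d(1),
             nn.Flatten(),
         )
         # Hypothetical projection head mapping features to the output space.
         projection_head = nn.Linear(8, projection_dimension)

         x = torch.rand(4, 3, 16, 16)  # (batch_size, channels, height, width)
         with torch.no_grad():
             out = projection_head(backbone(x))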



