hyrax.models.hyrax_autoencoderv2
================================

.. py:module:: hyrax.models.hyrax_autoencoderv2


Attributes
----------

.. autoapisummary::

   hyrax.models.hyrax_autoencoderv2.logger


Classes
-------

.. autoapisummary::

   hyrax.models.hyrax_autoencoderv2.ArcsinhActivation
   hyrax.models.hyrax_autoencoderv2.HyraxAutoencoderV2


Module Contents
---------------

.. py:data:: logger

.. py:class:: ArcsinhActivation(*args: Any, **kwargs: Any)

   Bases: :py:obj:`torch.nn.Module`


   Helper module that wraps the arcsinh function in a :py:obj:`torch.nn.Module` so that HyraxAutoencoderV2 can use it as an activation layer.

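   The following is a minimal sketch of such a wrapper. It assumes that ``forward`` applies
   :func:`torch.asinh` element-wise and that the module has no learnable parameters.
   ``ArcsinhSketch`` is a hypothetical name and is not part of Hyrax:

   .. code-block:: python

      import torch
      from torch import nn


      class ArcsinhSketch(nn.Module):
          """Hypothetical stand-in for ArcsinhActivation: element-wise arcsinh."""

          def forward(self, x: torch.Tensor) -> torch.Tensor:
              # arcsinh(x) = ln(x + sqrt(x**2 + 1)): roughly linear near 0,
              # logarithmic for large |x|, and defined for negative inputs.
              return torch.asinh(x)


      act = ArcsinhSketch()
      y = act(torch.linspace(-10.0, 10.0, steps=5))

   Because it is an ``nn.Module``, the wrapper can be placed inside an ``nn.Sequential``
   like any built-in activation.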


   .. py:method:: forward(x)


.. py:class:: HyraxAutoencoderV2(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   A modified version of HyraxAutoencoder, designed to work with a wide range of imaging datasets.

   Improvements in V2:

   - Configurable final-layer activation
   - Criterion and optimizer are taken from config variables

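   Both V2 changes come down to choosing components by name when the model is constructed.
   The following is a minimal sketch of that pattern. It assumes the config is a plain dictionary.
   The keys ``final_activation``, ``criterion``, ``optimizer`` and ``lr`` and the helper
   ``build_from_config`` are hypothetical and do not necessarily match Hyrax's configuration schema:

   .. code-block:: python

      import torch
      from torch import nn


      def build_from_config(cfg: dict, model: nn.Module):
          """Hypothetical helper: resolve components from names stored in a config."""
          # e.g. "Tanh" -> torch.nn.Tanh(), "Identity" -> torch.nn.Identity()
          final_activation = getattr(nn, cfg["final_activation"])()
          # e.g. "MSELoss" -> torch.nn.MSELoss()
          criterion = getattr(nn, cfg["criterion"])()
          # e.g. "Adam" -> torch.optim.Adam(model.parameters(), lr=...)
          optimizer = getattr(torch.optim, cfg["optimizer"])(model.parameters(), lr=cfg["lr"])
          return final_activation, criterion, optimizer


      cfg = {"final_activation": "Tanh", "criterion": "MSELoss", "optimizer": "Adam", "lr": 1e-3}
      activation, criterion, optimizer = build_from_config(cfg, nn.Linear(4, 4))

   Resolving names with ``getattr`` keeps the config as plain strings. A production version
   would probably validate the names and report unknown ones clearly.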


   .. py:attribute:: config


   .. py:attribute:: c_hid


   .. py:attribute:: latent_dim


   .. py:attribute:: conv_end_w


   .. py:attribute:: conv_end_h


   .. py:attribute:: band_reduction


   .. py:method:: conv2d_multi_layer(input_size, num_applications, **kwargs) -> int


   .. py:method:: conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int
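
      Neither this method nor ``conv2d_multi_layer`` is documented here. Their signatures
      suggest that they compute the spatial size left after one or several convolutions.
      The sketch below assumes they use the output-size formula from the :class:`torch.nn.Conv2d`
      documentation and that ``conv2d_multi_layer`` forwards its keyword arguments to
      ``conv2d_output_size``. It is a reconstruction, not the module's source:

      .. code-block:: python

         def conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int:
             # Output length along one spatial axis, per the torch.nn.Conv2d docs:
             # floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
             return (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


         def conv2d_multi_layer(input_size, num_applications, **kwargs) -> int:
             # Apply the same convolution geometry num_applications times in a row.
             for _ in range(num_applications):
                 input_size = conv2d_output_size(input_size, **kwargs)
             return input_size


         # Three stride-2, 3x3 convolutions with padding 1 halve a 64-pixel side
         # three times: 64 -> 32 -> 16 -> 8.
         final_side = conv2d_multi_layer(64, 3, kernel_size=3, padding=1, stride=2)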


   .. py:method:: _init_encoder()


   .. py:method:: _eval_encoder(x)


   .. py:method:: _init_decoder()


   .. py:method:: _eval_decoder(x)


   .. py:method:: forward(batch)


   .. py:method:: train_batch(batch)

      This method contains the logic for a single training step, i.e. the
      contents of the inner loop of an ML training process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict
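
      The following is a minimal sketch of a typical training step for an autoencoder. It is written
      as a free function that takes ``model``, ``criterion`` and ``optimizer`` explicitly. These are
      hypothetical parameters, because the real method uses the model's own attributes. The
      ``"loss"`` key in the returned dictionary is also an assumption:

      .. code-block:: python

         import torch
         from torch import nn


         def train_step(model, criterion, optimizer, batch):
             """Hypothetical sketch of one autoencoder training step."""
             # Batches may carry labels; an autoencoder only needs the inputs.
             x = batch[0] if isinstance(batch, (tuple, list)) else batch
             optimizer.zero_grad()
             reconstruction = model(x)
             # The reconstruction target is the input itself.
             loss = criterion(reconstruction, x)
             loss.backward()
             optimizer.step()
             # The "loss" key name is an assumption.
             return {"loss": loss.item()}


         model = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # shape-preserving toy "autoencoder"
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
         batch = (torch.randn(8, 1, 16, 16), torch.zeros(8))  # (images, ignored labels)
         result = train_step(model, nn.MSELoss(), optimizer, batch)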



   .. py:method:: validate_batch(batch)

      This method contains the logic for a single validation step that
      processes one batch of data, i.e. the contents of the inner loop of an
      ML validation process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict
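
      Unlike a training step, a validation step only measures the loss. It computes no gradients
      and changes no weights. The following is a minimal sketch written as a free function. The names
      ``validation_step``, ``model`` and ``criterion`` are hypothetical, and the ``"loss"`` key is
      an assumption:

      .. code-block:: python

         import torch
         from torch import nn


         @torch.no_grad()  # no autograd graph is built during validation
         def validation_step(model, criterion, batch):
             """Hypothetical sketch of one validation step: compute loss, update nothing."""
             # Batches may carry labels; only the inputs are used.
             x = batch[0] if isinstance(batch, (tuple, list)) else batch
             reconstruction = model(x)
             # The "loss" key name is an assumption.
             return {"loss": criterion(reconstruction, x).item()}


         model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
         weights_before = model.weight.clone()
         result = validation_step(model, nn.MSELoss(), (torch.randn(8, 1, 16, 16), torch.zeros(8)))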



   .. py:method:: test_batch(batch)

      This method contains the logic for a single testing step that
      processes one batch of data, i.e. the contents of the inner loop of an
      ML testing process. In this case it is identical to :py:meth:`validate_batch`.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This method contains the logic for a single inference step, i.e. the
      contents of the inner loop of an ML inference process.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Reconstructed outputs** -- The reconstructed outputs from the autoencoder.
      :rtype: torch.Tensor



   .. py:method:: prepare_inputs(data_dict) -> tuple
      :staticmethod:


      Extract the image array from the batch dictionary.

      This static method is the interface between the data pipeline and the
      model. Override it on the model class to reshape or select fields from
      the collated batch to match the inputs your model expects.

      Hyrax will convert the returned array to a PyTorch tensor and move it
      to the appropriate device automatically.

      :param data_dict: The collated batch dictionary produced by the data pipeline.
                        Expected to contain a ``"data"`` key with an ``"image"`` field.
      :type data_dict: dict

      :returns: **image** -- The image array extracted from the batch.
      :rtype: numpy.ndarray
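
      For example, a single-band dataset that yields ``(N, H, W)`` arrays could override the
      method to add the channel axis that 2D convolutions expect. In this sketch, ``ChannelFirstModel``
      is a hypothetical stand-in class and the float32 cast is an assumption. It is not the
      default implementation:

      .. code-block:: python

         import numpy as np


         class ChannelFirstModel:
             """Hypothetical model class; in practice this would subclass HyraxAutoencoderV2."""

             @staticmethod
             def prepare_inputs(data_dict):
                 # Select the image field from the collated batch.
                 image = np.asarray(data_dict["data"]["image"], dtype=np.float32)
                 # Hypothetical extra step: turn single-band (N, H, W) batches
                 # into (N, 1, H, W) so they match a Conv2d's expected layout.
                 if image.ndim == 3:
                     image = image[:, np.newaxis, :, :]
                 return image


         batch = {"data": {"image": np.zeros((4, 32, 32))}}
         inputs = ChannelFirstModel.prepare_inputs(batch)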



