hyrax.models.hyrax_autoencoderv2
================================

.. py:module:: hyrax.models.hyrax_autoencoderv2


Attributes
----------

.. autoapisummary::

   hyrax.models.hyrax_autoencoderv2.logger


Classes
-------

.. autoapisummary::

   hyrax.models.hyrax_autoencoderv2.ArcsinhActivation
   hyrax.models.hyrax_autoencoderv2.HyraxAutoencoderV2


Module Contents
---------------

.. py:data:: logger

.. py:class:: ArcsinhActivation(*args, **kwargs)

   Bases: :py:obj:`torch.nn.Module`


   Helper module that lets HyraxAutoencoderV2 apply the arcsinh function as a network layer.



   .. py:method:: forward(x)
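
      A minimal sketch of an equivalent module, assuming ``forward`` simply applies :py:func:`torch.asinh` elementwise (the real implementation may differ, for example by rescaling the input first). Arcsinh is close to linear near zero and grows logarithmically for large magnitudes, so it compresses a wide dynamic range without clipping.

      .. code-block:: python

         import torch
         from torch import nn


         class ArcsinhActivation(nn.Module):
             """Apply arcsinh elementwise (sketch; assumes no extra scaling)."""

             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 # ~x for small |x|, ~sign(x) * log(2|x|) for large |x|
                 return torch.asinh(x)


         act = ArcsinhActivation()
         out = act(torch.tensor([0.0, 1.0, -1.0, 100.0]))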


.. py:class:: HyraxAutoencoderV2(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   This is a tweaked version of HyraxAutoencoder, designed to work with a wide range of imaging datasets.

   V2 improvements:

   - Configurable final-layer activation
   - Uses the criterion (loss function) and optimizer specified in the config
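
   As an illustration of choosing these components from configuration values, the sketch below resolves class names with ``getattr`` on :py:mod:`torch.nn` and :py:mod:`torch.optim`. The dictionary keys (``criterion``, ``optimizer``, ``lr``, ``final_activation``) are hypothetical and are not taken from Hyrax's actual configuration schema; a custom module such as ``ArcsinhActivation`` is not in :py:mod:`torch.nn` and would need its own lookup.

   .. code-block:: python

      import torch
      from torch import nn

      # Hypothetical config values; Hyrax's real config keys may differ
      config = {
          "criterion": "MSELoss",
          "optimizer": "Adam",
          "lr": 1e-3,
          "final_activation": "Sigmoid",
      }

      model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 8))

      # Resolve classes by name so the config can swap them without code changes
      criterion = getattr(nn, config["criterion"])()
      optimizer = getattr(torch.optim, config["optimizer"])(model.parameters(), lr=config["lr"])
      final_activation = getattr(nn, config["final_activation"])()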



   .. py:attribute:: config


   .. py:attribute:: c_hid


   .. py:attribute:: latent_dim


   .. py:attribute:: conv_end_w


   .. py:attribute:: conv_end_h


   .. py:method:: conv2d_multi_layer(input_size, num_applications, **kwargs) -> int
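
      A sketch of what this helper plausibly computes, assuming it applies the single-layer convolution output-size formula ``num_applications`` times and passes ``kwargs`` (``kernel_size``, ``padding``, ``stride``, ``dilation``) to each layer. The standalone functions below are illustrative and are not part of Hyrax's API.

      .. code-block:: python

         def conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1):
             # Output-size formula from the torch.nn.Conv2d documentation
             return (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


         def conv2d_multi_layer(input_size, num_applications, **kwargs):
             # Feed each layer's output size into the next, identically configured layer
             size = input_size
             for _ in range(num_applications):
                 size = conv2d_output_size(size, **kwargs)
             return size


         # Three stride-2 layers shrink a 28-pixel side: 28 -> 14 -> 7 -> 4
         print(conv2d_multi_layer(28, 3, kernel_size=3, padding=1, stride=2))  # 4

      Presumably a helper like this is how ``conv_end_w`` and ``conv_end_h`` are derived from the input image size instead of being hard-coded.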


   .. py:method:: conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int
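
      Assuming this helper follows the output-size formula given in the :py:class:`torch.nn.Conv2d` documentation, the result for one spatial dimension, with :math:`k` = ``kernel_size``, :math:`p` = ``padding``, :math:`s` = ``stride`` and :math:`d` = ``dilation``, is:

      .. math::

         n_{\text{out}} = \left\lfloor \frac{n_{\text{in}} + 2p - d\,(k - 1) - 1}{s} \right\rfloor + 1

      For example, :math:`n_{\text{in}} = 28`, :math:`k = 3`, :math:`p = 1`, :math:`s = 2` and :math:`d = 1` give :math:`\lfloor 27/2 \rfloor + 1 = 14`.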


   .. py:method:: _init_encoder()


   .. py:method:: _eval_encoder(x)


   .. py:method:: _init_decoder()


   .. py:method:: _eval_decoder(x)


   .. py:method:: forward(batch)


   .. py:method:: train_step(batch)

      This method contains the logic for a single training step, i.e. the
      contents of the inner loop of an ML training process.

      :param batch: A tuple containing the inputs and labels for the current batch.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict
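
      A minimal sketch of this pattern in plain PyTorch, assuming the model keeps its loss function and optimizer in ``criterion`` and ``optimizer`` attributes and ignores the labels in ``batch`` (an autoencoder is scored against its own input). ``TinyAutoencoder`` is a hypothetical stand-in, not HyraxAutoencoderV2 itself.

      .. code-block:: python

         import torch
         from torch import nn


         class TinyAutoencoder(nn.Module):
             """Hypothetical stand-in used only to illustrate train_step."""

             def __init__(self):
                 super().__init__()
                 self.encoder = nn.Linear(16, 4)
                 self.decoder = nn.Linear(4, 16)
                 # Assumed attribute names for the loss and optimizer
                 self.criterion = nn.MSELoss()
                 self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1)

             def forward(self, x):
                 return self.decoder(self.encoder(x))

             def train_step(self, batch):
                 inputs, _labels = batch  # labels are unused by an autoencoder
                 self.optimizer.zero_grad()
                 reconstruction = self(inputs)
                 # Reconstruction loss: compare the output with the original input
                 loss = self.criterion(reconstruction, inputs)
                 loss.backward()
                 self.optimizer.step()
                 return {"loss": loss.item()}


         torch.manual_seed(0)
         model = TinyAutoencoder()
         batch = (torch.randn(8, 16), torch.zeros(8))
         result = model.train_step(batch)

      Returning a plain Python float via ``loss.item()`` keeps the result dictionary free of autograd graph references, which makes it cheap to log.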



   .. py:method:: to_tensor(data_dict) -> tuple[torch.Tensor]
      :staticmethod:


      This function converts structured data into the input tensor(s) needed to run the model.

      :param data_dict: The dictionary returned by the data source.
      :type data_dict: dict
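
      The layout of ``data_dict`` is not documented here, so the sketch below assumes a hypothetical ``"image"`` key holding array-like pixel data. It returns a one-element tuple to match the annotated ``tuple[torch.Tensor]`` return type. Because ``to_tensor`` is a static method, it can be called as ``HyraxAutoencoderV2.to_tensor(data_dict)`` without an instance.

      .. code-block:: python

         import torch


         def to_tensor(data_dict):
             # "image" is a hypothetical key; the real data source may use another layout
             image = torch.as_tensor(data_dict["image"], dtype=torch.float32)
             # Wrap in a tuple to match the tuple[torch.Tensor] return annotation
             return (image,)


         sample = {"image": [[0, 1], [2, 3]]}
         (tensor,) = to_tensor(sample)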



