hyrax.models.image_dcae
=======================

.. py:module:: hyrax.models.image_dcae


Classes
-------

.. autoapisummary::

   hyrax.models.image_dcae.ArcsinhActivation
   hyrax.models.image_dcae.ImageDCAE


Module Contents
---------------

.. py:class:: ArcsinhActivation(*args: Any, **kwargs: Any)

   Bases: :py:obj:`torch.nn.Module`


   Helper module that wraps the arcsinh function so that ImageDCAE can use it as an activation layer.



   .. py:method:: forward(x)
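
      Judging by its name and description, ``forward`` presumably applies arcsinh elementwise. The sketch below assumes exactly that; ``ArcsinhSketch`` is a hypothetical name and this is not the hyrax source. arcsinh is roughly linear near zero and grows like ``sign(x) * log(2|x|)`` for large ``|x|``, so it compresses a wide dynamic range while remaining defined for negative values.

      .. code-block:: python

         import torch
         from torch import nn


         class ArcsinhSketch(nn.Module):
             """Hypothetical stand-in for ArcsinhActivation: elementwise arcsinh."""

             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 # Approximately x near zero, approximately sign(x) * log(2|x|) for large |x|.
                 return torch.asinh(x)


         act = ArcsinhSketch()
         y = act(torch.tensor([-1000.0, 0.0, 0.01, 1000.0]))
         print(y)  # roughly [-7.6009, 0.0000, 0.0100, 7.6009]

      Small inputs pass through almost unchanged, while an input of 1000 is mapped to about 7.6.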


.. py:class:: ImageDCAE(config, data_sample=None)

   Bases: :py:obj:`torch.nn.Module`


   A convolutional autoencoder with skip connections, intended to work with
   images of arbitrary size and an arbitrary number of channels.



   .. py:attribute:: input_shape


   .. py:attribute:: config


   .. py:attribute:: latent_dim


   .. py:attribute:: base_channel_size


   .. py:attribute:: conv_output_size


   .. py:attribute:: encoder1


   .. py:attribute:: encoder2


   .. py:attribute:: encoder3


   .. py:attribute:: encoder4


   .. py:attribute:: pool


   .. py:attribute:: latent_encoder


   .. py:attribute:: latent_decoder


   .. py:attribute:: decoder4


   .. py:attribute:: decoder3


   .. py:attribute:: decoder2


   .. py:attribute:: decoder1


   .. py:attribute:: activation


   .. py:method:: _calculate_conv_output_size()

      Calculate the output size after all convolutional layers, which determines the input size of the linear bottleneck.
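
      The arithmetic behind this kind of calculation can be sketched with the output-size formula documented for ``torch.nn.Conv2d`` and ``torch.nn.MaxPool2d`` (dilation 1, ``ceil_mode=False``): ``out = floor((n + 2 * padding - kernel_size) / stride) + 1``. The layer layout below (four stages, each a size-preserving 3x3 convolution followed by a 2x2 max-pool) and the helper names are assumptions for illustration, not necessarily what ImageDCAE builds.

      .. code-block:: python

         def conv_out(n: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
             """Spatial output length of a Conv2d or MaxPool2d layer (dilation 1)."""
             return (n + 2 * padding - kernel_size) // stride + 1


         def flattened_size(height: int, width: int, channels: int, n_stages: int = 4) -> int:
             """Hypothetical layout: each stage is a 3x3 conv (padding 1) then a 2x2 max-pool."""
             for _ in range(n_stages):
                 # A 3x3 convolution with padding 1 and stride 1 preserves the size.
                 height = conv_out(height, kernel_size=3, padding=1)
                 width = conv_out(width, kernel_size=3, padding=1)
                 # A 2x2 pool with stride 2 halves the size, flooring odd values.
                 height = conv_out(height, kernel_size=2, stride=2)
                 width = conv_out(width, kernel_size=2, stride=2)
             return channels * height * width


         print(flattened_size(64, 64, channels=128))   # 128 * 4 * 4 = 2048
         print(flattened_size(100, 75, channels=128))  # 128 * 6 * 4 = 3072

      Because each pool floors odd sizes, the shrinkage cannot be undone exactly by factor-of-two upsampling; a decoder meant for arbitrary image sizes typically resizes its feature maps to match the stored skip connections instead.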



   .. py:method:: encode(x)

      Encode input to latent space with skip connections.
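
      A minimal sketch of an encoder that produces a latent vector together with its skip connections, assuming a two-stage layout with ReLU activations and a return value of ``(latent, skips, encoded_shape)``, chosen here to match the arguments of ``decode``. ``TinyEncoder`` and its parameters are hypothetical; the spatial size is fixed at construction so that the linear bottleneck can be sized in advance.

      .. code-block:: python

         import torch
         from torch import nn


         class TinyEncoder(nn.Module):
             """Hypothetical two-stage encoder that keeps each stage's output as a skip connection."""

             def __init__(self, in_channels, base=8, latent_dim=16, hw=(32, 32)):
                 super().__init__()
                 self.enc1 = nn.Sequential(nn.Conv2d(in_channels, base, 3, padding=1), nn.ReLU())
                 self.enc2 = nn.Sequential(nn.Conv2d(base, 2 * base, 3, padding=1), nn.ReLU())
                 self.pool = nn.MaxPool2d(2)
                 # Two 2x2 pools shrink each side to floor(side / 4); size the bottleneck for that.
                 h, w = hw[0] // 4, hw[1] // 4
                 self.to_latent = nn.Linear(2 * base * h * w, latent_dim)

             def encode(self, x):
                 skips = []
                 x = self.enc1(x)
                 skips.append(x)  # full-resolution features
                 x = self.enc2(self.pool(x))
                 skips.append(x)  # half-resolution features
                 x = self.pool(x)
                 encoded_shape = x.shape  # kept so a decoder can undo the flatten
                 latent = self.to_latent(torch.flatten(x, start_dim=1))
                 return latent, skips, encoded_shape


         enc = TinyEncoder(in_channels=3)
         latent, skips, encoded_shape = enc.encode(torch.randn(2, 3, 32, 32))
         print(latent.shape)                     # torch.Size([2, 16])
         print([tuple(s.shape) for s in skips])  # [(2, 8, 32, 32), (2, 16, 16, 16)]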



   .. py:method:: decode(latent, skip_connections, encoded_shape)

      Decode from latent space to image with skip connections.
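
      A hypothetical decoder sketch (``TinyDecoder``, its channel counts, and nearest-neighbour resizing are all assumptions). It restores the pre-flatten shape from ``encoded_shape``; then, at each stage, it resizes to the matching skip connection with ``torch.nn.functional.interpolate`` and concatenates along the channel axis. Resizing to the skip's exact size, rather than upsampling by a fixed factor of two, is one way to keep odd image sizes aligned.

      .. code-block:: python

         import torch
         import torch.nn.functional as F
         from torch import nn


         class TinyDecoder(nn.Module):
             """Hypothetical decoder mirroring a two-stage encoder with 8 and 16 channels."""

             def __init__(self, out_channels, base=8, latent_dim=16, flat_size=16 * 8 * 8):
                 super().__init__()
                 self.from_latent = nn.Linear(latent_dim, flat_size)
                 # Input channels = upsampled features + concatenated skip features.
                 self.dec2 = nn.Sequential(nn.Conv2d(2 * base + 2 * base, base, 3, padding=1), nn.ReLU())
                 self.dec1 = nn.Conv2d(base + base, out_channels, 3, padding=1)

             def decode(self, latent, skip_connections, encoded_shape):
                 # Undo the flatten, keeping whatever batch size the latent has.
                 x = self.from_latent(latent).view(-1, *encoded_shape[1:])
                 # Walk the skips from deepest (lowest resolution) to shallowest.
                 for skip, stage in zip(reversed(skip_connections), (self.dec2, self.dec1)):
                     # Resize to the skip's exact height and width so odd sizes still line up.
                     x = F.interpolate(x, size=tuple(skip.shape[-2:]), mode="nearest")
                     x = stage(torch.cat([x, skip], dim=1))
                 return x


         # A 30x30 input pools to 15x15 and then 7x7, so plain 2x upsampling (7 -> 14) would not fit.
         dec = TinyDecoder(out_channels=3, flat_size=16 * 7 * 7)
         skips = [torch.randn(2, 8, 30, 30), torch.randn(2, 16, 15, 15)]
         out = dec.decode(torch.randn(2, 16), skips, torch.Size([2, 16, 7, 7]))
         print(out.shape)  # torch.Size([2, 3, 30, 30])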



   .. py:method:: forward(x)

      Forward pass; returns the latent representation, which can be used for anomaly detection.



   .. py:method:: reconstruct(x)

      Run the full encode/decode pass and return the reconstruction, for use in evaluation and anomaly detection.
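
      A common way to turn reconstructions into anomaly scores is the per-image mean squared reconstruction error. The sketch below assumes ``reconstruct`` accepts and returns a batch of shape ``(N, C, H, W)``; ``reconstruction_scores`` and the ``ZeroModel`` stand-in are hypothetical helpers, not part of hyrax.

      .. code-block:: python

         import torch


         def reconstruction_scores(model, x):
             """Per-image mean squared reconstruction error; larger values suggest anomalies."""
             with torch.no_grad():
                 recon = model.reconstruct(x)
             # Average the squared error over channels and pixels, one score per image.
             return ((recon - x) ** 2).flatten(start_dim=1).mean(dim=1)


         class ZeroModel:
             """Stand-in whose reconstruction is always zero, so the score equals mean(x ** 2)."""

             def reconstruct(self, x):
                 return torch.zeros_like(x)


         x = torch.stack([torch.ones(1, 4, 4), 3 * torch.ones(1, 4, 4)])
         scores = reconstruction_scores(ZeroModel(), x)
         print(scores)  # tensor([1., 9.])

      With a real ``torch.nn.Module``, calling ``model.eval()`` before scoring keeps layers such as dropout or batch normalization in inference mode.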



   .. py:method:: train_batch(batch)

      This function contains the logic for a single training step.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict
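
      A minimal sketch of a typical reconstruction training step, assuming a mean-squared-error loss, an externally supplied optimizer, and a ``"loss"`` key in the returned dictionary; this page confirms none of these details, and ``train_step`` and ``Scale`` are hypothetical names.

      .. code-block:: python

         import torch
         from torch import nn


         def train_step(model, optimizer, batch):
             """Hypothetical training step for a reconstruction autoencoder."""
             # The batch may carry labels after the images; they are ignored.
             x = batch[0] if isinstance(batch, (tuple, list)) else batch
             model.train()
             optimizer.zero_grad()
             loss = nn.functional.mse_loss(model.reconstruct(x), x)
             loss.backward()
             optimizer.step()
             return {"loss": loss.item()}


         class Scale(nn.Module):
             """Toy model whose reconstruction is a single learned scalar times the input."""

             def __init__(self):
                 super().__init__()
                 self.w = nn.Parameter(torch.zeros(1))

             def reconstruct(self, x):
                 return self.w * x


         model = Scale()
         opt = torch.optim.SGD(model.parameters(), lr=0.1)
         batch = (torch.ones(4, 1, 2, 2), torch.zeros(4))  # (images, ignored labels)
         first = train_step(model, opt, batch)["loss"]
         for _ in range(50):
             last = train_step(model, opt, batch)["loss"]
         print(first, last)  # first is 1.0; last is much smaller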



   .. py:method:: validate_batch(batch)

      This function contains the logic for a single validation step that will
      process a single batch of data.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict
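
      Validation usually differs from training only in skipping gradient tracking and parameter updates. A sketch assuming an MSE reconstruction loss and a ``"loss"`` key in the returned dictionary; ``validation_step`` and ``Half`` are hypothetical names.

      .. code-block:: python

         import torch
         from torch import nn


         def validation_step(model, batch):
             """Hypothetical validation step: training-style loss, no gradients, no updates."""
             x = batch[0] if isinstance(batch, (tuple, list)) else batch
             model.eval()
             with torch.no_grad():
                 loss = nn.functional.mse_loss(model.reconstruct(x), x)
             return {"loss": loss.item()}


         class Half(nn.Module):
             """Toy model that reconstructs every input at half its value."""

             def reconstruct(self, x):
                 return 0.5 * x


         result = validation_step(Half(), (torch.full((2, 1, 3, 3), 2.0), torch.zeros(2)))
         print(result)  # {'loss': 1.0}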



   .. py:method:: test_batch(batch)

      This function contains the logic for a single testing step that will
      process a single batch of data. In this case, it is identical to :py:meth:`validate_batch`.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Current loss value** -- Dictionary containing the loss value for the current batch.
      :rtype: dict



   .. py:method:: infer_batch(batch)

      This function contains the logic for a single inference step.

      :param batch: A tuple containing the input data for the current batch, possibly
                    with labels that are ignored.
      :type batch: tuple

      :returns: **Reconstructed images** -- Tensor containing the reconstructed images for the current batch.
      :rtype: torch.Tensor



   .. py:method:: prepare_inputs(data_dict)
      :staticmethod:


      Convert a structured data dictionary to the tensor format used by the model.



