hyrax.models.simclr
===================

.. py:module:: hyrax.models.simclr


Classes
-------

.. autoapisummary::

   hyrax.models.simclr.NTXentLoss
   hyrax.models.simclr.PositiveRescale
   hyrax.models.simclr.SimCLR


Module Contents
---------------

.. py:class:: NTXentLoss(temperature=0.1)

   Bases: :py:obj:`torch.nn.Module`


   Normalized temperature-scaled cross-entropy (NT-Xent) loss, based on Chen et al. (2020).
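
   For reference, the NT-Xent loss for a positive pair :math:`(i, j)` as defined
   in the SimCLR paper is reproduced below. Here :math:`\mathrm{sim}` is cosine
   similarity, :math:`\tau` is the temperature, and :math:`N` is the batch size,
   so the two views together give :math:`2N` representations. This is the
   published formula; whether this class matches it term for term is an
   assumption, not something this page states.

   .. math::

      \ell_{i,j} = -\log
      \frac{\exp\left(\mathrm{sim}(z_i, z_j) / \tau\right)}
           {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k) / \tau\right)}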



   .. py:attribute:: temperature
      :value: 0.1



   .. py:attribute:: criterion


   .. py:method:: forward(z_i, z_j)

      Compute the NT-Xent loss, based on Chen et al. (2020).
      The loss is computed from the representations of two augmented views of the same batch.
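
      The computation described above can be sketched in plain PyTorch. This is
      an illustrative sketch only, not the hyrax implementation: the function
      name ``nt_xent_sketch`` is hypothetical, and it assumes ``z_i`` and
      ``z_j`` are 2-D tensors of shape ``(batch_size, dim)`` whose rows are
      aligned, so that row ``k`` of each tensor comes from the same sample.

      .. code-block:: python

         import torch
         import torch.nn.functional as F


         def nt_xent_sketch(z_i, z_j, temperature=0.1):
             """Illustrative NT-Xent loss (hypothetical helper, not hyrax code)."""
             n = z_i.shape[0]
             # Stack both views and L2-normalize so dot products are cosine similarities.
             z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
             sim = z @ z.T / temperature
             # Mask the diagonal so a sample is never compared with itself.
             self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
             sim = sim.masked_fill(self_mask, float("-inf"))
             # The positive for row k is the other view of the same sample.
             targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
             # Cross-entropy over each row averages the per-pair losses.
             return F.cross_entropy(sim, targets)


         z_i = torch.randn(4, 8)
         z_j = torch.randn(4, 8)
         loss = nt_xent_sketch(z_i, z_j)  # scalar tensor

      Treating each row of the similarity matrix as logits and the paired view
      as the target class lets ``cross_entropy`` handle the log-sum-exp in a
      numerically stable way.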



.. py:class:: PositiveRescale(transform)

   Transform wrapper intended specifically for ``ColorJitter``. It prevents the wrapped transform from operating on values in the wrong domain during augmentation.


   .. py:attribute:: transform


   .. py:method:: __call__(x)


.. py:class:: SimCLR(config, shape)

   Bases: :py:obj:`torch.nn.Module`


   SimCLR model. Implementation based on Chen et al. (2020).



   .. py:attribute:: config


   .. py:attribute:: shape


   .. py:attribute:: backbone


   .. py:attribute:: projection_head


   .. py:attribute:: criterion


   .. py:method:: forward(x)


   .. py:method:: train_batch(x)


   .. py:method:: validate_batch(x)


   .. py:method:: test_batch(x)


   .. py:method:: infer_batch(x)

      Run inference on a batch of data.

      :param x: Input tensor of shape (batch_size, channels, height, width).
      :type x: torch.Tensor

      :returns: Output tensor of shape (batch_size, projection_dimension).
      :rtype: torch.Tensor
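
      The shape contract above can be illustrated with a small stand-in module.
      Everything in this sketch is hypothetical: ``TinyStandIn``, its layers,
      and the use of ``torch.no_grad()`` are assumptions made for illustration,
      and none of it reflects how hyrax builds its backbone or projection head.

      .. code-block:: python

         import torch
         from torch import nn


         class TinyStandIn(nn.Module):
             """Hypothetical stand-in with a backbone and a projection head."""

             def __init__(self, channels=3, projection_dimension=16):
                 super().__init__()
                 self.backbone = nn.Sequential(
                     nn.Conv2d(channels, 8, kernel_size=3, padding=1),
                     nn.ReLU(),
                     nn.AdaptiveAvgPool2d(1),  # collapse height and width
                     nn.Flatten(),             # (batch_size, 8)
                 )
                 self.projection_head = nn.Linear(8, projection_dimension)

             def forward(self, x):
                 return self.projection_head(self.backbone(x))

             def infer_batch(self, x):
                 # Assumption: inference does not need gradients.
                 with torch.no_grad():
                     return self.forward(x)


         model = TinyStandIn().eval()
         batch = torch.randn(4, 3, 32, 32)  # (batch_size, channels, height, width)
         out = model.infer_batch(batch)     # (batch_size, projection_dimension)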



