hyrax.models.simclr

Classes

NTXentLoss

Normalized Temperature-scaled Cross Entropy (NT-Xent) Loss, based on Chen et al., 2020.

PositiveRescale

Transformation class, intended specifically for ColorJitter, that keeps inputs within the value domain the augmentation expects.

SimCLR

SimCLR model. Implementation based on Chen et al., 2020.

Module Contents

class NTXentLoss(temperature=0.1)

Bases: torch.nn.Module

Normalized Temperature-scaled Cross Entropy (NT-Xent) Loss, based on Chen et al., 2020.
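For reference, this is the NT-Xent loss as defined by Chen et al., 2020. A batch of N samples yields two augmented views, giving 2N projections z_1, ..., z_{2N}; sim is cosine similarity and tau plays the role of the temperature attribute. Whether hyrax's forward matches this term for term is not stated on this page.

```latex
\mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert}, \qquad
\ell_{i,j} = -\log \frac{\exp\!\big(\mathrm{sim}(z_i, z_j)/\tau\big)}
                        {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\!\big(\mathrm{sim}(z_i, z_k)/\tau\big)}, \qquad
\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \big(\ell_{2k-1,\,2k} + \ell_{2k,\,2k-1}\big)
```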


temperature = 0.1
criterion
forward(z_i, z_j)

Forward pass of NTXentLoss, based on Chen et al., 2020. The loss is computed from the representations z_i and z_j of two augmented views of the same batch.
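A minimal sketch of how an NT-Xent forward pass can be written in PyTorch. NTXentLossSketch is a hypothetical name, and the assumption that criterion is an nn.CrossEntropyLoss over temperature-scaled cosine-similarity logits is ours; hyrax's actual implementation may differ in details such as masking or normalization.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NTXentLossSketch(nn.Module):
    """Hypothetical NT-Xent re-implementation; not hyrax's actual code."""

    def __init__(self, temperature: float = 0.1):
        super().__init__()
        self.temperature = temperature
        # Assumption: cross-entropy over similarity logits, following Chen et al., 2020.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
        n = z_i.shape[0]
        # Rows 0..n-1 hold view i, rows n..2n-1 hold view j; L2-normalize each row.
        z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
        # Cosine similarity between all 2n x 2n pairs, scaled by the temperature.
        logits = z @ z.T / self.temperature
        # Drop self-similarity from the softmax denominator (the k != i indicator).
        self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
        logits = logits.masked_fill(self_mask, float("-inf"))
        # The positive for row k is the same sample's row in the other view.
        targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
        # Cross-entropy averages the per-row loss over all 2n rows.
        return self.criterion(logits, targets)
```

Because both orderings of each positive pair appear as rows, swapping z_i and z_j leaves the loss unchanged.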

class PositiveRescale(transform)

Transformation class, intended specifically for ColorJitter, that keeps inputs within the value domain the augmentation expects.

transform
__call__(x)
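torchvision's ColorJitter treats float tensors as lying in [0, 1] and clips its output to that range, so data with negative or large values would be distorted. The sketch below shows one way such a wrapper could work, assuming a min-max rescale into [0, 1] that is undone after the wrapped transform runs; PositiveRescaleSketch is a hypothetical name, and this page does not say which rescaling PositiveRescale actually uses.

```python
from typing import Callable

import torch


class PositiveRescaleSketch:
    """Hypothetical wrapper: rescale to [0, 1], apply `transform`, undo the rescale."""

    def __init__(self, transform: Callable[[torch.Tensor], torch.Tensor]):
        self.transform = transform

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        x_min, x_max = x.min(), x.max()
        # Guard against a constant image, where max - min would be zero.
        scale = torch.clamp(x_max - x_min, min=1e-12)
        # Map values into [0, 1], the range ColorJitter assumes for float tensors.
        x01 = (x - x_min) / scale
        out = self.transform(x01)
        # Map the augmented values back to the original range.
        return out * scale + x_min
```

In use it would wrap the jitter, e.g. PositiveRescaleSketch(ColorJitter(brightness=0.4)), so negative pixel values survive the augmentation instead of being clipped to zero.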
class SimCLR(config, shape)

Bases: torch.nn.Module

SimCLR model. Implementation based on Chen et al., 2020.


config
shape
backbone
projection_head
criterion
forward(x)
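A minimal sketch of the backbone plus projection-head structure described by Chen et al., 2020, where the projection head is an MLP with one hidden layer. SimCLRSketch, hidden_dim and proj_dim are hypothetical; the tiny CNN stands in for whatever backbone hyrax builds from config, and we assume shape is channels-first (C, H, W) and that forward returns the projection used by the loss.

```python
from typing import Sequence

import torch
import torch.nn as nn


class SimCLRSketch(nn.Module):
    """Hypothetical SimCLR-style model: encoder backbone followed by an MLP projection head."""

    def __init__(self, shape: Sequence[int], hidden_dim: int = 128, proj_dim: int = 64):
        super().__init__()
        in_channels = shape[0]  # Assumption: shape is (C, H, W).
        # Stand-in encoder f(x); the real backbone is configured by hyrax.
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),  # Representation h with shape (batch, 32).
        )
        # Projection head g(h): one hidden layer with a ReLU, as in Chen et al., 2020.
        self.projection_head = nn.Sequential(
            nn.Linear(32, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, proj_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: forward returns z = g(f(x)), the input to the contrastive loss.
        return self.projection_head(self.backbone(x))
```

Chen et al. use the backbone output h, not the projection z, as the representation for downstream tasks, which is why the two parts are kept as separate attributes.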
train_step(x)
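A sketch of a typical SimCLR training step: augment the batch twice, project both views, apply the contrastive loss, and update the weights. The real train_step takes only x, so it presumably reads the model, criterion and optimizer from the instance; here they are explicit arguments. train_step_sketch and augment_sketch (a random flip plus Gaussian noise) are hypothetical, and the returned dict format is an assumption.

```python
from typing import Callable

import torch
import torch.nn as nn


def augment_sketch(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical augmentation: random horizontal flip plus small Gaussian noise."""
    if torch.rand(()) < 0.5:
        x = torch.flip(x, dims=[-1])
    return x + 0.05 * torch.randn_like(x)


def train_step_sketch(
    model: nn.Module,
    criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    optimizer: torch.optim.Optimizer,
    x: torch.Tensor,
    augment: Callable[[torch.Tensor], torch.Tensor] = augment_sketch,
) -> dict:
    model.train()
    optimizer.zero_grad()
    # Two independently augmented views of the same batch, each projected by the model.
    z_i = model(augment(x))
    z_j = model(augment(x))
    # The contrastive loss pulls the two views of each sample together.
    loss = criterion(z_i, z_j)
    loss.backward()
    optimizer.step()
    return {"loss": loss.item()}
```

With hyrax's classes, model would be a SimCLR instance and criterion its NTXentLoss; where the real method obtains its optimizer and augmentations is not documented on this page.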