hyrax.models.image_dcae

Classes

ArcsinhActivation

Helper module for ImageDCAE to use the arcsinh function

ImageDCAE

This is an autoencoder with skip connections that should work with arbitrarily sized images with an arbitrary number of channels.

Module Contents

class ArcsinhActivation(*args, **kwargs)

Bases: torch.nn.Module

Helper module for ImageDCAE to use the arcsinh function

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)
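
As a rough illustration of what such a helper might look like, the sketch below wraps `torch.asinh` in an `nn.Module` so that it can be placed in a layer stack like any other activation. The class name `ArcsinhSketch` is hypothetical, and the real `ArcsinhActivation` may differ in detail.

```python
import torch
from torch import nn


class ArcsinhSketch(nn.Module):
    """Hypothetical stand-in that applies arcsinh element-wise."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # arcsinh is roughly linear near zero and logarithmic for large |x|.
        # Unlike log, it is defined for zero and negative inputs.
        return torch.asinh(x)


activation = ArcsinhSketch()
x = torch.tensor([-1000.0, -1.0, 0.0, 1.0, 1000.0])
y = activation(x)  # same shape as x, odd-symmetric around zero
```

Because the function is odd and passes through zero, large positive and negative pixel values are compressed by the same amount.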

class ImageDCAE(config, data_sample=None)

Bases: torch.nn.Module

This is an autoencoder with skip connections that should work with arbitrarily sized images with an arbitrary number of channels.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

input_shape
config
latent_dim
base_channel_size
conv_output_size
encoder1
encoder2
encoder3
encoder4
pool
latent_encoder
latent_decoder
decoder4
decoder3
decoder2
decoder1
activation
_calculate_conv_output_size()

Calculate the output size after all convolutional layers for the linear bottleneck.
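
How the method arrives at this size is not spelled out here. One common approach is to apply the output-size formula from the PyTorch `Conv2d` and `MaxPool2d` documentation layer by layer, as in the sketch below. The helper names, the four-stage layout, the channel doubling, and every kernel, stride, and padding value are assumptions for illustration, not values taken from ImageDCAE.

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis (PyTorch Conv2d/MaxPool2d formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def flattened_size(height, width, base_channels, stages=4):
    """Number of features fed to a linear bottleneck after `stages` blocks."""
    channels = base_channels
    for stage in range(stages):
        if stage > 0:
            channels *= 2  # assumed: channel count doubles at each later stage
        # Assumed 3x3 convolution with padding 1: spatial size is unchanged.
        height = conv2d_out(height, kernel=3, stride=1, padding=1)
        width = conv2d_out(width, kernel=3, stride=1, padding=1)
        # Assumed 2x2 max pool with stride 2: size is halved, rounding down.
        height = conv2d_out(height, kernel=2, stride=2)
        width = conv2d_out(width, kernel=2, stride=2)
    return channels * height * width


size_64 = flattened_size(64, 64, base_channels=32)
size_75 = flattened_size(75, 75, base_channels=32)
```

With these assumed layers, a 64x64 and a 75x75 input both reduce to a 4x4 feature map, which shows why the bottleneck size has to be computed from the actual input shape instead of being hard-coded.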

encode(x)

Encode input to latent space with skip connections.

decode(latent, skip_connections, encoded_shape)

Decode from latent space to image with skip connections.
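
The way `encode` and `decode` hand skip connections to each other can be sketched with a much smaller model. The example below is a hypothetical two-stage autoencoder, not the ImageDCAE implementation: the class name `TinySkipAE`, the layer sizes, the additive (rather than concatenated) skips, and the assumption that `encode` returns exactly the three values `decode` accepts are all illustrative choices.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinySkipAE(nn.Module):
    """Hypothetical two-stage autoencoder with additive skip connections."""

    def __init__(self, channels=3, base=8, latent_dim=16, input_hw=(32, 32)):
        super().__init__()
        self.enc1 = nn.Conv2d(channels, base, 3, padding=1)
        self.enc2 = nn.Conv2d(base, base * 2, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        # Two 2x2 pools, each rounding down, shrink every side by a factor of 4.
        h, w = input_hw[0] // 4, input_hw[1] // 4
        flat = base * 2 * h * w
        self.to_latent = nn.Linear(flat, latent_dim)
        self.from_latent = nn.Linear(latent_dim, flat)
        self.dec2 = nn.Conv2d(base * 2, base, 3, padding=1)
        self.dec1 = nn.Conv2d(base, channels, 3, padding=1)

    def encode(self, x):
        s1 = F.relu(self.enc1(x))              # full resolution
        s2 = F.relu(self.enc2(self.pool(s1)))  # half resolution
        encoded = self.pool(s2)                # quarter resolution
        latent = self.to_latent(encoded.flatten(1))
        return latent, [s1, s2], encoded.shape

    def decode(self, latent, skip_connections, encoded_shape):
        s1, s2 = skip_connections
        x = self.from_latent(latent).view(encoded_shape)
        # Resize to each skip's spatial size so odd dimensions still line up.
        x = F.interpolate(x, size=s2.shape[-2:]) + s2
        x = F.relu(self.dec2(x))
        x = F.interpolate(x, size=s1.shape[-2:]) + s1
        return self.dec1(x)


model = TinySkipAE(input_hw=(30, 30))
batch = torch.randn(2, 3, 30, 30)
latent, skips, shape = model.encode(batch)
recon = model.decode(latent, skips, shape)
```

Passing `encoded_shape` along with the latent vector lets the decoder undo the flattening without recomputing the pooled size, and interpolating to each skip tensor's size keeps the additions valid when pooling has rounded an odd dimension down.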

forward(x)

Forward pass - returns latent representation for anomaly detection.

reconstruct(x)

Full reconstruction for evaluation and anomaly detection.

train_step(batch)

This function contains the logic for a single training step.

Parameters:

batch (tuple) – A tuple containing the two values the loss function

Returns:

Current loss value – Dictionary containing the loss value for the current batch.

Return type:

dict

_optimizer()

Default optimizer configuration.

static to_tensor(data_dict)

Convert structured data to tensor format.