hyrax.models.image_dcae

Classes

ArcsinhActivation

Helper module for ImageDCAE to use the arcsinh function

ImageDCAE

An autoencoder with skip connections intended to work with images of arbitrary size and an arbitrary number of channels.

Module Contents

class ArcsinhActivation(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Helper module for ImageDCAE to use the arcsinh function

forward(x)
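The forward method is listed without a docstring. Below is a minimal sketch that assumes it simply applies `torch.asinh` element-wise; it is not the hyrax source. arcsinh is roughly linear near zero, grows logarithmically for large magnitudes, and is defined for negative inputs, which is one plausible reason to prefer it over `log` for image data with a wide dynamic range.

```python
import torch
import torch.nn as nn


class ArcsinhActivation(nn.Module):
    """Sketch (assumption): an activation layer that applies arcsinh element-wise."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # asinh(x) ~ x near zero and ~ sign(x) * log(2|x|) for large |x|;
        # unlike log, it is defined for zero and negative inputs.
        return torch.asinh(x)


act = ArcsinhActivation()
x = torch.tensor([-1000.0, -1.0, 0.0, 1.0, 1000.0])
y = act(x)
```

Here an input of 1000 maps to about 7.60 while 1 maps to about 0.88, so bright pixels are compressed far more than faint ones.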

class ImageDCAE(config, data_sample=None)

Bases: torch.nn.Module

An autoencoder with skip connections intended to work with images of arbitrary size and an arbitrary number of channels.


input_shape
config
latent_dim
base_channel_size
conv_output_size
encoder1
encoder2
encoder3
encoder4
pool
latent_encoder
latent_decoder
decoder4
decoder3
decoder2
decoder1
activation
_calculate_conv_output_size()

Calculate the output size after all convolutional layers for the linear bottleneck.
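The arithmetic behind this kind of calculation can be sketched as follows. The helper name `conv_output_size`, its parameters, and the layer settings are hypothetical: it assumes every convolution preserves height and width (3x3 kernel, padding 1) and every pooling step is a 2x2 max pool with stride 2. The real method may instead push a dummy tensor through the encoder and read off the shape.

```python
def conv_output_size(input_shape, bottleneck_channels, num_pools):
    """Hypothetical helper: flattened size of the deepest feature map.

    Assumes size-preserving convolutions and 2x2/stride-2 max pooling,
    which rounds odd heights and widths down at every step.
    """
    _, height, width = input_shape  # (channels, height, width)
    for _ in range(num_pools):
        height //= 2
        width //= 2
    return bottleneck_channels * height * width


print(conv_output_size((3, 64, 64), 256, 3))   # 256 * 8 * 8 = 16384
print(conv_output_size((5, 101, 75), 256, 3))  # 256 * 12 * 9 = 27648
```

The second call shows why arbitrary image sizes need care: 101 becomes 50, 25, then 12, so simply doubling sizes on the way back up would not recover the original 101 rows.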

encode(x)

Encode input to latent space with skip connections.
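A reduced sketch of an encoder that keeps skip connections is shown below. It has two levels instead of four. It borrows the attribute names `encoder1`, `encoder2`, `pool` and `latent_encoder` from the list above, but the layer choices are assumptions. The return value `(latent, skip_connections, encoded_shape)` is also an assumption, inferred from the arguments that `decode` takes.

```python
import torch
import torch.nn as nn


class TinySkipEncoder(nn.Module):
    """Hypothetical two-level encoder; not the hyrax implementation."""

    def __init__(self, in_channels, base, latent_dim, height, width):
        super().__init__()
        self.encoder1 = nn.Conv2d(in_channels, base, kernel_size=3, padding=1)
        self.encoder2 = nn.Conv2d(base, base * 2, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        flat_size = base * 2 * (height // 2) * (width // 2)
        self.latent_encoder = nn.Linear(flat_size, latent_dim)

    def encode(self, x):
        s1 = torch.relu(self.encoder1(x))              # full resolution, kept as a skip
        s2 = torch.relu(self.encoder2(self.pool(s1)))  # half resolution, kept as a skip
        encoded_shape = s2.shape                       # needed later to un-flatten
        latent = self.latent_encoder(s2.flatten(start_dim=1))
        return latent, [s1, s2], encoded_shape


enc = TinySkipEncoder(in_channels=3, base=8, latent_dim=16, height=33, width=20)
latent, skips, encoded_shape = enc.encode(torch.randn(4, 3, 33, 20))
```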

decode(latent, skip_connections, encoded_shape)

Decode from latent space to image with skip connections.
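The decoding direction can be sketched the same way. In this hypothetical two-level decoder, `latent_decoder`, `decoder2` and `decoder1` mirror names from the attribute list, while the layers themselves are assumptions. It uses `encoded_shape` to undo the flatten, concatenates each skip connection along the channel axis, and resizes to each skip's exact spatial size. That last step lets odd input sizes line up again after pooling has rounded them down.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySkipDecoder(nn.Module):
    """Hypothetical two-level decoder matching a two-level skip encoder."""

    def __init__(self, out_channels, base, latent_dim, bottleneck_chw):
        super().__init__()
        c, h, w = bottleneck_chw  # channels, height, width of the deepest map
        self.latent_decoder = nn.Linear(latent_dim, c * h * w)
        self.decoder2 = nn.Conv2d(c * 2, base, kernel_size=3, padding=1)
        self.decoder1 = nn.Conv2d(base * 2, out_channels, kernel_size=3, padding=1)

    def decode(self, latent, skip_connections, encoded_shape):
        s1, s2 = skip_connections
        # Undo the flatten: (batch, c*h*w) -> (batch, c, h, w)
        x = self.latent_decoder(latent).view(-1, *encoded_shape[1:])
        x = torch.relu(self.decoder2(torch.cat([x, s2], dim=1)))
        # Resize to the shallower skip's exact height/width before concatenating
        x = F.interpolate(x, size=tuple(s1.shape[-2:]), mode="nearest")
        return self.decoder1(torch.cat([x, s1], dim=1))


dec = TinySkipDecoder(out_channels=3, base=8, latent_dim=16, bottleneck_chw=(16, 16, 10))
s1 = torch.randn(4, 8, 33, 20)
s2 = torch.randn(4, 16, 16, 10)
out = dec.decode(torch.randn(4, 16), [s1, s2], torch.Size([4, 16, 16, 10]))
```

The output has the full 33 x 20 resolution even though the bottleneck was 16 x 10.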

forward(x)

Forward pass: returns the latent representation of x, for use in anomaly detection.

reconstruct(x)

Full reconstruction of x (encode followed by decode), for evaluation and anomaly detection.
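One common way to turn reconstructions into an anomaly score is per-sample reconstruction error. The sketch below assumes any model exposing a `reconstruct` method. The helper `anomaly_scores` and the stand-in `ZeroReconstructor` are hypothetical, and this is not necessarily how hyrax scores anomalies.

```python
import torch
import torch.nn as nn


def anomaly_scores(model, x):
    """Per-sample mean squared reconstruction error; larger means more anomalous."""
    model.eval()
    with torch.no_grad():
        recon = model.reconstruct(x)
    return ((recon - x) ** 2).flatten(start_dim=1).mean(dim=1)


class ZeroReconstructor(nn.Module):
    """Hypothetical stand-in exposing reconstruct(); always predicts zeros."""

    def reconstruct(self, x):
        return torch.zeros_like(x)


x = torch.randn(5, 3, 8, 8)
scores = anomaly_scores(ZeroReconstructor(), x)
```

The result holds one score per image, so samples can be ranked directly.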

train_batch(batch)

This function contains the logic for a single training step.

Parameters:

batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.

Returns:

Current loss value – Dictionary containing the loss value for the current batch.

Return type:

dict
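A minimal sketch of a training step with this contract follows. It takes a tuple whose first element is the data, ignores any labels, and returns a dict holding the loss. Several details are assumptions: the `"loss"` key, the MSE loss, the `optimizer` attribute, and the small `TinyAutoencoder` stand-in.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyAutoencoder(nn.Module):
    """Hypothetical stand-in; the optimizer attribute is an assumption."""

    def __init__(self, channels=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, 4, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(4, channels, kernel_size=3, padding=1),
        )
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def reconstruct(self, x):
        return self.net(x)

    def train_batch(self, batch):
        # Use only the first element of the tuple; labels after it are ignored.
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        self.train()
        self.optimizer.zero_grad()
        loss = F.mse_loss(self.reconstruct(x), x)
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}


model = TinyAutoencoder()
before = [p.detach().clone() for p in model.parameters()]
out = model.train_batch((torch.randn(2, 3, 8, 8), torch.tensor([0, 1])))
```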

validate_batch(batch)

This function contains the logic for a single validation step that will process a single batch of data.

Parameters:

batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.

Returns:

Current loss value – Dictionary containing the loss value for the current batch.

Return type:

dict
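A validation step usually computes the same loss as training but without gradients or parameter updates. The sketch below assumes a model with a `reconstruct` method and a `"loss"` key in the returned dict. The function form and the `ConvReconstructor` stand-in are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def validate_batch(model, batch):
    """Sketch: training-style loss, but in eval mode and without gradients."""
    x = batch[0] if isinstance(batch, (tuple, list)) else batch
    model.eval()
    with torch.no_grad():
        loss = F.mse_loss(model.reconstruct(x), x)
    return {"loss": loss.item()}


class ConvReconstructor(nn.Module):
    """Hypothetical stand-in model with a reconstruct() method."""

    def __init__(self, channels=3):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def reconstruct(self, x):
        return self.conv(x)


model = ConvReconstructor()
result = validate_batch(model, (torch.randn(2, 3, 8, 8), torch.tensor([1, 0])))
```

Because the testing step is described as identical, it could delegate to the same function under these assumptions.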

test_batch(batch)

This function contains the logic for a single testing step that will process a single batch of data. In this case, it is identical to validate_batch.

Parameters:

batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.

Returns:

Current loss value – Dictionary containing the loss value for the current batch.

Return type:

dict

infer_batch(batch)

This function contains the logic for a single inference step.

Parameters:

batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.

Returns:

Reconstructed images – Tensor containing the reconstructed images for the current batch.

Return type:

torch.Tensor
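An inference step that returns reconstructed images, rather than a loss, can be sketched as follows. The loop over a `DataLoader` shows that each batch arrives as an `[images, labels]` pair whose labels are dropped. The function and the `ConvReconstructor` stand-in are hypothetical.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def infer_batch(model, batch):
    """Sketch: return reconstructions for one batch, without gradients."""
    x = batch[0] if isinstance(batch, (tuple, list)) else batch
    model.eval()
    with torch.no_grad():
        return model.reconstruct(x)


class ConvReconstructor(nn.Module):
    """Hypothetical stand-in model with a reconstruct() method."""

    def __init__(self, channels=3):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def reconstruct(self, x):
        return self.conv(x)


images = torch.randn(10, 3, 8, 8)
labels = torch.zeros(10)
loader = DataLoader(TensorDataset(images, labels), batch_size=4)  # batches of 4, 4, 2
model = ConvReconstructor()
recon = torch.cat([infer_batch(model, batch) for batch in loader])
```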

static prepare_inputs(data_dict)

Convert structured data to tensor format.
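The structure of `data_dict` is not documented here. The sketch below is purely illustrative: it assumes a dict with an `"image"` field and an optional `"label"` field, both hypothetical keys. It returns a tuple in the shape the batch-step methods above accept, with the data first and labels, if any, after it.

```python
import torch


class ExampleModel:
    @staticmethod
    def prepare_inputs(data_dict):
        """Hypothetical: convert array-like dict fields into a tensor tuple.

        The keys "image" and "label" are assumptions for illustration only.
        """
        images = torch.as_tensor(data_dict["image"], dtype=torch.float32)
        labels = data_dict.get("label")
        if labels is None:
            return (images,)
        return (images, torch.as_tensor(labels))


batch = ExampleModel.prepare_inputs({"image": [[[[0, 1], [2, 3]]]]})
```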