# ruff: noqa: D101, D102
# This is a more flexible version of hsc_dcae.py that should
# work with a variety of image sizes and includes a true latent bottleneck
# for better anomaly detection capabilities.
import torch
import torch.nn as nn
import torch.nn.functional as f
# extra long import here to address a circular import issue
from hyrax.models.model_registry import hyrax_model
[docs]
class ArcsinhActivation(nn.Module):
"""Helper module for ImageDCAE to use the arcsinh function"""
[docs]
def forward(self, x):
return torch.arcsinh(x)
@hyrax_model
[docs]
class ImageDCAE(nn.Module):
"""
This is an autoencoder with skipconnections that should work with
arbitarily sized images with arbitrary number of channels.
"""
def __init__(self, config, data_sample=None):
super().__init__()
if data_sample is None:
raise ValueError("data_sample must be provided to ImageDCAE for dynamic sizing.")
# Store input shape for dynamic sizing
# Extract dimensions from input shape
if len(self.input_shape) == 4: # Batch dimension included
self.num_input_channels = self.input_shape[1]
self.image_height, self.image_width = self.input_shape[2], self.input_shape[3]
else: # No batch dimension
self.num_input_channels = self.input_shape[0]
self.image_height, self.image_width = self.input_shape[1], self.input_shape[2]
# Get latent dimension from config (similar to HyraxAutoencoder)
[docs]
self.latent_dim = config["model"].get("latent_dim", 512)
[docs]
self.base_channel_size = config["model"].get("base_channel_size", 32)
# Calculate the size after convolutional layers for the linear bottleneck
[docs]
self.conv_output_size = self._calculate_conv_output_size()
in_channels = self.num_input_channels
# Encoder - using configurable base channel size
[docs]
self.encoder1 = nn.Conv2d(in_channels, self.base_channel_size, kernel_size=3, stride=1, padding=1)
[docs]
self.encoder2 = nn.Conv2d(
self.base_channel_size, self.base_channel_size * 2, kernel_size=3, stride=1, padding=1
)
[docs]
self.encoder3 = nn.Conv2d(
self.base_channel_size * 2, self.base_channel_size * 4, kernel_size=3, stride=1, padding=1
)
[docs]
self.encoder4 = nn.Conv2d(
self.base_channel_size * 4, self.base_channel_size * 8, kernel_size=3, stride=1, padding=1
)
[docs]
self.pool = nn.MaxPool2d(2, 2)
# Latent bottleneck - similar to HyraxAutoencoder
[docs]
self.latent_encoder = nn.Sequential(
nn.Flatten(), nn.Linear(self.conv_output_size, self.latent_dim), nn.GELU()
)
[docs]
self.latent_decoder = nn.Sequential(nn.Linear(self.latent_dim, self.conv_output_size), nn.GELU())
# Decoder - using normal Conv2d with upsampling instead of ConvTranspose2d
# This approach is more flexible for different image sizes
[docs]
self.decoder4 = nn.Conv2d(
self.base_channel_size * 8, self.base_channel_size * 4, kernel_size=3, stride=1, padding=1
)
[docs]
self.decoder3 = nn.Conv2d(
self.base_channel_size * 4, self.base_channel_size * 2, kernel_size=3, stride=1, padding=1
)
[docs]
self.decoder2 = nn.Conv2d(
self.base_channel_size * 2, self.base_channel_size, kernel_size=3, stride=1, padding=1
)
[docs]
self.decoder1 = nn.Conv2d(self.base_channel_size, in_channels, kernel_size=3, stride=1, padding=1)
[docs]
self.activation = nn.GELU() # Better gradients than ReLU
# Configure final activation
final_layer = config["model"].get("final_layer", "identity")
if final_layer == "sigmoid":
self.final_activation = nn.Sigmoid()
elif final_layer == "tanh":
self.final_activation = nn.Tanh()
elif final_layer == "arcsinh":
self.final_activation = ArcsinhActivation()
else:
self.final_activation = nn.Identity()
[docs]
def _calculate_conv_output_size(self):
"""Calculate the output size after all convolutional layers for the linear bottleneck."""
# Simulate the forward pass through conv layers to get the size
h, w = self.image_height, self.image_width
# After 3 pooling operations (each divides by 2)
h = h // 8
w = w // 8
# Final feature map size: (base_channel_size * 8) * h * w
return (self.base_channel_size * 8) * h * w
[docs]
def encode(self, x):
"""Encode input to latent space with skip connections."""
# Encoder with skip connections
x1 = self.activation(self.encoder1(x))
p1 = self.pool(x1)
x2 = self.activation(self.encoder2(p1))
p2 = self.pool(x2)
x3 = self.activation(self.encoder3(p2))
p3 = self.pool(x3)
x4 = self.activation(self.encoder4(p3))
# Store the spatial dimensions for reconstruction
self.encoded_spatial_shape = x4.shape
# Pass through latent bottleneck
latent = self.latent_encoder(x4)
return latent, [x3, x2, x1], x4.shape
[docs]
def decode(self, latent, skip_connections, encoded_shape):
"""Decode from latent space to image with skip connections."""
# Reconstruct from latent space
x = self.latent_decoder(latent)
# Reshape back to convolutional feature map
x = x.reshape(encoded_shape)
# Decoder with skip connections and dynamic upsampling
x = f.interpolate(x, size=skip_connections[0].shape[2:], mode="bilinear", align_corners=False)
x = self.activation(self.decoder4(x) + skip_connections[0])
x = f.interpolate(x, size=skip_connections[1].shape[2:], mode="bilinear", align_corners=False)
x = self.activation(self.decoder3(x) + skip_connections[1])
x = f.interpolate(x, size=skip_connections[2].shape[2:], mode="bilinear", align_corners=False)
x = self.activation(self.decoder2(x) + skip_connections[2])
# Final interpolation to input size
if hasattr(self, "original_size"):
x = f.interpolate(x, size=self.original_size, mode="bilinear", align_corners=False)
x = self.final_activation(self.decoder1(x))
return x
[docs]
def forward(self, x):
"""Forward pass - returns latent representation for anomaly detection."""
# Store original spatial dimensions for decoding
self.original_size = x.shape[2:]
# Encode to latent space
latent, skip_connections, encoded_shape = self.encode(x)
return latent
[docs]
def reconstruct(self, x):
"""Full reconstruction for evaluation and anomaly detection."""
# Dropping labels if present
x = x[0] if isinstance(x, tuple) else x
# Store original spatial dimensions for decoding
self.original_size = x.shape[2:]
# Encode to latent space
latent, skip_connections, encoded_shape = self.encode(x)
# Decode back to image
reconstructed = self.decode(latent, skip_connections, encoded_shape)
return reconstructed
[docs]
def train_step(self, batch):
"""This function contains the logic for a single training step.
Parameters
----------
batch : tuple
A tuple containing the two values the loss function
Returns
-------
Current loss value : dict
Dictionary containing the loss value for the current batch.
"""
data = batch
self.optimizer.zero_grad()
# Store original spatial dimensions for decoding
self.original_size = data.shape[2:]
# Encode to latent space
latent, skip_connections, encoded_shape = self.encode(data)
# Decode back to image
decoded = self.decode(latent, skip_connections, encoded_shape)
# Compute loss
loss = self.criterion(decoded, data)
loss.backward()
self.optimizer.step()
return {"loss": loss.item()}
@staticmethod
[docs]
def to_tensor(data_dict):
"""Convert structured data to tensor format."""
data_dict = data_dict["data"]
if "image" in data_dict:
return data_dict["image"]
else:
raise RuntimeError("Data dict did not contain image key.")