Source code for hyrax.models.image_dcae

# ruff: noqa: D101, D102

# This is a more flexible version of hsc_dcae.py that should
# work with a variety of image sizes and includes a true latent bottleneck
# for better anomaly detection capabilities.

import torch
import torch.nn as nn
import torch.nn.functional as f

# extra long import here to address a circular import issue
from hyrax.models.model_registry import hyrax_model


class ArcsinhActivation(nn.Module):
    """Activation layer applying the inverse hyperbolic sine element-wise.

    Helper module that lets ImageDCAE use arcsinh as a final activation in
    the same way as the built-in ``nn`` activation modules.
    """

    def forward(self, x):
        """Return the element-wise inverse hyperbolic sine of ``x``."""
        # torch.arcsinh is an alias of torch.asinh; the result is identical.
        return torch.asinh(x)
@hyrax_model
class ImageDCAE(nn.Module):
    """Convolutional autoencoder with skip connections and a latent bottleneck.

    Layer sizes are derived from ``data_sample``, so the model should work
    with arbitrarily sized images with an arbitrary number of channels, as
    long as each spatial dimension is at least 8 pixels (the encoder applies
    three 2x2 max-pooling steps before the linear bottleneck).

    ``self.optimizer`` and ``self.criterion`` are used by the batch methods
    but are not created in this class; they are expected to be attached
    externally before training/evaluation (presumably by the framework --
    confirm against ``hyrax_model``).

    Parameters
    ----------
    config : dict
        Configuration dictionary. ``config["model"]["ImageDCAE"]`` must
        contain ``latent_dim``, ``base_channel_size`` and ``final_layer``.
    data_sample : object exposing ``.shape``
        Example input used only for sizing. Its shape must be either
        ``(batch, channels, height, width)`` or ``(channels, height, width)``.

    Raises
    ------
    ValueError
        If ``data_sample`` is missing, has an unsupported number of
        dimensions, or is too small to survive the three pooling steps.
    """

    def __init__(self, config, data_sample=None):
        super().__init__()

        if data_sample is None:
            raise ValueError("data_sample must be provided to ImageDCAE for dynamic sizing.")

        # Store input shape for dynamic sizing
        self.input_shape = data_sample.shape
        self.config = config

        # Extract dimensions from input shape
        rank = len(self.input_shape)
        if rank == 4:
            # Batch dimension included
            self.num_input_channels = self.input_shape[1]
            self.image_height, self.image_width = self.input_shape[2], self.input_shape[3]
        elif rank == 3:
            # No batch dimension
            self.num_input_channels = self.input_shape[0]
            self.image_height, self.image_width = self.input_shape[1], self.input_shape[2]
        else:
            # Fail with a clear message instead of an IndexError (rank < 3)
            # or a silently mis-sized network (rank > 4).
            raise ValueError(
                "ImageDCAE expects data_sample of shape (C, H, W) or (B, C, H, W); "
                f"got shape {tuple(self.input_shape)}."
            )

        # Get latent dimension from config (similar to HyraxAutoencoder)
        self.latent_dim = config["model"]["ImageDCAE"]["latent_dim"]
        self.base_channel_size = config["model"]["ImageDCAE"]["base_channel_size"]

        # Calculate the size after convolutional layers for the linear bottleneck
        self.conv_output_size = self._calculate_conv_output_size()
        if self.conv_output_size <= 0:
            # A zero-width bottleneck would otherwise be built silently.
            raise ValueError(
                "ImageDCAE requires images of at least 8x8 pixels and a positive "
                f"base_channel_size; got {self.image_height}x{self.image_width} pixels "
                f"and base_channel_size={self.base_channel_size}."
            )

        in_channels = self.num_input_channels
        base = self.base_channel_size

        # Encoder - using configurable base channel size
        self.encoder1 = nn.Conv2d(in_channels, base, kernel_size=3, stride=1, padding=1)
        self.encoder2 = nn.Conv2d(base, base * 2, kernel_size=3, stride=1, padding=1)
        self.encoder3 = nn.Conv2d(base * 2, base * 4, kernel_size=3, stride=1, padding=1)
        self.encoder4 = nn.Conv2d(base * 4, base * 8, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # Latent bottleneck - similar to HyraxAutoencoder
        self.latent_encoder = nn.Sequential(
            nn.Flatten(), nn.Linear(self.conv_output_size, self.latent_dim), nn.GELU()
        )
        self.latent_decoder = nn.Sequential(nn.Linear(self.latent_dim, self.conv_output_size), nn.GELU())

        # Decoder - using normal Conv2d with upsampling instead of ConvTranspose2d
        # This approach is more flexible for different image sizes
        self.decoder4 = nn.Conv2d(base * 8, base * 4, kernel_size=3, stride=1, padding=1)
        self.decoder3 = nn.Conv2d(base * 4, base * 2, kernel_size=3, stride=1, padding=1)
        self.decoder2 = nn.Conv2d(base * 2, base, kernel_size=3, stride=1, padding=1)
        self.decoder1 = nn.Conv2d(base, in_channels, kernel_size=3, stride=1, padding=1)

        self.activation = nn.GELU()

        # Configure final activation; unrecognised values fall back to identity
        final_layer = config["model"]["ImageDCAE"]["final_layer"]
        if final_layer == "sigmoid":
            self.final_activation = nn.Sigmoid()
        elif final_layer == "tanh":
            self.final_activation = nn.Tanh()
        elif final_layer == "arcsinh":
            self.final_activation = ArcsinhActivation()
        else:
            self.final_activation = nn.Identity()

    @staticmethod
    def _unpack_batch(batch):
        """Return the image tensor from ``batch``, dropping labels if present."""
        return batch[0] if isinstance(batch, (tuple, list)) else batch

    def _calculate_conv_output_size(self):
        """Calculate the flattened feature size fed to the linear bottleneck.

        Returns
        -------
        int
            ``(base_channel_size * 8) * (height // 8) * (width // 8)``.
        """
        # Three 2x2 poolings with floor rounding are equivalent to one
        # integer division by 8.
        h = self.image_height // 8
        w = self.image_width // 8

        # Final feature map size: (base_channel_size * 8) * h * w
        return (self.base_channel_size * 8) * h * w

    def encode(self, x):
        """Encode input to latent space, keeping feature maps for skip connections.

        Parameters
        ----------
        x : torch.Tensor
            Batched input images.

        Returns
        -------
        tuple
            ``(latent, [x3, x2, x1], encoded_shape)``: the latent tensor, the
            intermediate encoder feature maps ordered deepest first, and the
            shape of the deepest feature map before flattening.
        """
        x1 = self.activation(self.encoder1(x))
        x2 = self.activation(self.encoder2(self.pool(x1)))
        x3 = self.activation(self.encoder3(self.pool(x2)))
        x4 = self.activation(self.encoder4(self.pool(x3)))

        # Store the spatial dimensions for reconstruction
        self.encoded_spatial_shape = x4.shape

        # Pass through latent bottleneck
        latent = self.latent_encoder(x4)

        return latent, [x3, x2, x1], x4.shape

    def decode(self, latent, skip_connections, encoded_shape):
        """Decode from latent space to image with skip connections.

        Parameters
        ----------
        latent : torch.Tensor
            Latent tensor as returned by ``encode``.
        skip_connections : list of torch.Tensor
            Encoder feature maps ordered deepest first, as returned by ``encode``.
        encoded_shape : torch.Size
            Shape to restore after the linear latent decoder.

        Returns
        -------
        torch.Tensor
            The reconstructed image batch after the final activation.
        """
        # Reconstruct from latent space and reshape back to a feature map
        x = self.latent_decoder(latent).reshape(encoded_shape)

        # Upsample to each skip connection's spatial size, convolve, then add it
        decoder_stages = (self.decoder4, self.decoder3, self.decoder2)
        for index, conv in enumerate(decoder_stages):
            skip = skip_connections[index]
            x = f.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=False)
            x = self.activation(conv(x) + skip)

        # Final interpolation to input size (only known once a batch has been seen)
        if hasattr(self, "original_size"):
            x = f.interpolate(x, size=self.original_size, mode="bilinear", align_corners=False)

        return self.final_activation(self.decoder1(x))

    def forward(self, x):
        """Forward pass - returns latent representation for anomaly detection.

        Parameters
        ----------
        x : torch.Tensor
            Batched input images.

        Returns
        -------
        torch.Tensor
            The latent tensor; use ``reconstruct`` to obtain images.
        """
        # Store original spatial dimensions for decoding
        self.original_size = x.shape[2:]

        latent, _, _ = self.encode(x)
        return latent

    def reconstruct(self, x):
        """Full reconstruction for evaluation and anomaly detection.

        Parameters
        ----------
        x : torch.Tensor or tuple
            Batched input images, optionally wrapped with labels that are dropped.

        Returns
        -------
        torch.Tensor
            The reconstructed images.
        """
        x = self._unpack_batch(x)

        # Store original spatial dimensions for decoding
        self.original_size = x.shape[2:]

        latent, skip_connections, encoded_shape = self.encode(x)
        return self.decode(latent, skip_connections, encoded_shape)

    def _reconstruction_loss(self, batch):
        """Encode and decode ``batch`` and return ``self.criterion(decoded, data)``."""
        data = self._unpack_batch(batch)

        # Store original spatial dimensions for decoding
        self.original_size = data.shape[2:]

        latent, skip_connections, encoded_shape = self.encode(data)
        decoded = self.decode(latent, skip_connections, encoded_shape)
        return self.criterion(decoded, data)

    def train_batch(self, batch):
        """This function contains the logic for a single training step.

        Parameters
        ----------
        batch : torch.Tensor or tuple
            The input data for the current batch, possibly wrapped in a tuple
            with labels that are ignored.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.
        """
        self.optimizer.zero_grad()

        loss = self._reconstruction_loss(batch)
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item()}

    def validate_batch(self, batch):
        """This function contains the logic for a single validation step.

        Parameters
        ----------
        batch : torch.Tensor or tuple
            The input data for the current batch, possibly wrapped in a tuple
            with labels that are ignored.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.
        """
        loss = self._reconstruction_loss(batch)
        return {"loss": loss.item()}

    def test_batch(self, batch):
        """This function contains the logic for a single testing step.

        In this case, it is identical to `validate_batch`.

        Parameters
        ----------
        batch : torch.Tensor or tuple
            The input data for the current batch, possibly wrapped in a tuple
            with labels that are ignored.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.
        """
        loss = self._reconstruction_loss(batch)
        return {"loss": loss.item()}

    def infer_batch(self, batch):
        """This function contains the logic for a single inference step.

        Parameters
        ----------
        batch : torch.Tensor or tuple
            The input data for the current batch, possibly wrapped in a tuple
            with labels that are ignored.

        Returns
        -------
        Latent representation : torch.Tensor
            Tensor containing the latent vectors for the current batch, as
            produced by ``forward``. Use ``reconstruct`` for reconstructed images.
        """
        return self.forward(self._unpack_batch(batch))

    @staticmethod
    def prepare_inputs(data_dict):
        """Extract the image array from the batch dictionary.

        This static method is the interface between the data pipeline and the
        model. Override it on the model class to reshape or select fields from
        the collated batch to match the inputs your model expects.

        Parameters
        ----------
        data_dict : dict
            The collated batch dictionary produced by the data pipeline.
            Expected to contain a ``"data"`` key with an ``"image"`` field.

        Returns
        -------
        image
            The value stored under ``data_dict["data"]["image"]``.

        Raises
        ------
        RuntimeError
            If the ``"data"`` entry has no ``"image"`` key.
        """
        fields = data_dict["data"]
        if "image" not in fields:
            raise RuntimeError("Data dict did not contain image key.")
        return fields["image"]