# ruff: noqa: D101, D102
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa N812
from torch import Tensor
from torchvision.transforms.v2 import CenterCrop
# extra long import here to address a circular import issue
from hyrax.models.model_registry import hyrax_model
[docs]
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
[docs]
class ArcsinhActivation(nn.Module):
    """Activation layer that applies the inverse hyperbolic sine element-wise.

    Wraps the functional form in an ``nn.Module`` so it can be used as a
    layer, e.g. as the last entry of an ``nn.Sequential``.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``asinh(x)`` computed element-wise.

        Parameters
        ----------
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            Tensor with the same shape as ``x``.
        """
        # ``torch.arcsinh`` is an alias of ``torch.asinh``.
        activated = torch.asinh(x)
        return activated
@hyrax_model
class HyraxAutoencoderV2(nn.Module):
    """
    This is a tweaked version of HyraxAutoencoder and is designed to work with a
    wide range of imaging datasets.

    V2 improvements:
    - Configurable final layer activation
    - Uses criterion and optimizer from config variables

    Notes
    -----
    ``self.criterion`` and ``self.optimizer`` are used by ``train_step`` but are never
    assigned in this class. NOTE(review): presumably they are attached by the
    ``@hyrax_model`` decorator -- confirm against ``hyrax.models.model_registry``.
    """

    def __init__(self, config, data_sample=None):
        """Build the encoder and decoder for samples shaped like ``data_sample``.

        Parameters
        ----------
        config : dict
            Configuration. Read keys: ``["model"]["HyraxAutoencoderV2"]`` with
            ``base_channel_size``, ``latent_dim`` and ``final_layer``, and (during
            training) ``["criterion"]``.
        data_sample : Tensor
            A single un-batched sample whose ``shape`` has exactly three entries.
            They are stored, in order, as ``num_input_channels``, ``image_width``
            and ``image_height``.

        Raises
        ------
        ValueError
            If ``data_sample`` is None, or its shape does not have three entries.
        """
        super().__init__()

        # NOTE(review): every method reads ``self.config`` but nothing in this class
        # assigns it -- presumably ``@hyrax_model`` attaches it before this body runs
        # (confirm). Fall back to the constructor argument when it has not.
        if not hasattr(self, "config"):
            self.config = config

        if data_sample is None:
            # Without a sample there is no way to size the linear layers.
            raise ValueError(
                "HyraxAutoencoderV2 requires a data_sample to determine the input shape."
            )

        shape = data_sample.shape
        logger.debug("Found shape: %s in data sample, using this to initialize model.", shape)

        # Positional unpacking: shape[1] -> image_width, shape[2] -> image_height.
        # Raises ValueError when the sample does not have exactly three dimensions.
        self.num_input_channels, self.image_width, self.image_height = shape

        model_config = self.config["model"]["HyraxAutoencoderV2"]
        self.c_hid = model_config["base_channel_size"]
        self.latent_dim = model_config["latent_dim"]

        # Calculate how much our convolutional layers will affect the size of final convolution
        # Formula evaluated from: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        #
        # Only the three stride-2 convolutions of the encoder change the spatial size.
        # If the number of layers are changed this will need to be rewritten.
        self.conv_end_w = self.conv2d_multi_layer(self.image_width, 3, kernel_size=3, padding=1, stride=2)
        self.conv_end_h = self.conv2d_multi_layer(self.image_height, 3, kernel_size=3, padding=1, stride=2)

        self._init_encoder()
        self._init_decoder()

    def conv2d_multi_layer(self, input_size, num_applications, **kwargs) -> int:
        """Return the size of one spatial axis after repeated identical convolutions.

        Parameters
        ----------
        input_size : int
            Size of the axis before the first convolution.
        num_applications : int
            How many times the convolution is applied.
        **kwargs
            Forwarded to ``conv2d_output_size`` (``kernel_size``, ``padding``,
            ``stride``, ``dilation``).

        Returns
        -------
        int
            Size of the axis after the last convolution.
        """
        for _ in range(num_applications):
            input_size = self.conv2d_output_size(input_size, **kwargs)

        return int(input_size)

    def conv2d_output_size(self, input_size, kernel_size, padding=0, stride=1, dilation=1) -> int:
        """Return the size of one spatial axis after a single ``nn.Conv2d``.

        Parameters
        ----------
        input_size : int
            Size of the axis before the convolution.
        kernel_size : int
            Kernel extent along the axis.
        padding : int, optional
            Padding added to both ends of the axis.
        stride : int, optional
            Convolution stride.
        dilation : int, optional
            Spacing between kernel elements.

        Returns
        -------
        int
            Size of the axis after the convolution.
        """
        # From https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
        # The documented formula floors the quotient; floor division does that exactly.
        return int(numerator // stride + 1)

    def _init_encoder(self):
        """Create ``self.encoder``: five convolutions followed by a linear projection."""
        c_hid = self.c_hid
        flattened_size = 2 * self.conv_end_h * self.conv_end_w * c_hid

        self.encoder = nn.Sequential(
            nn.Conv2d(self.num_input_channels, c_hid, kernel_size=3, padding=1, stride=2),
            nn.GELU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),
            nn.GELU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),
            nn.GELU(),
            nn.Flatten(),  # Image grid to single feature vector
            nn.Linear(flattened_size, self.latent_dim),
        )

    def _eval_encoder(self, x):
        """Run ``x`` through the encoder and return the latent vectors."""
        return self.encoder(x)

    def _init_decoder(self):
        """Create ``self.dec_linear``, ``self.final_activation`` and ``self.decoder``."""
        c_hid = self.c_hid
        flattened_size = 2 * self.conv_end_h * self.conv_end_w * c_hid

        self.dec_linear = nn.Sequential(nn.Linear(self.latent_dim, flattened_size), nn.GELU())

        # Configure final activation
        # Should be set to the same value as ["dataset"]["transform"] in most cases
        final_layer_value = self.config["model"]["HyraxAutoencoderV2"]["final_layer"]
        # An empty / falsy setting means "use the default", which is tanh.
        final_layer = final_layer_value if final_layer_value else "tanh"

        if final_layer == "sigmoid":
            self.final_activation = nn.Sigmoid()
        elif final_layer == "tanh":
            self.final_activation = nn.Tanh()
        elif final_layer == "arcsinh":
            self.final_activation = ArcsinhActivation()
        elif final_layer == "identity":
            self.final_activation = nn.Identity()
        else:
            # Keep the historical tanh fallback, but make the misconfiguration visible.
            logger.warning(
                "Unrecognized final_layer %r for HyraxAutoencoderV2; falling back to tanh.",
                final_layer,
            )
            self.final_activation = nn.Tanh()

        # Each transposed convolution below doubles both spatial axes.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2 * c_hid, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),
            nn.GELU(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),
            nn.GELU(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(
                c_hid, self.num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2
            ),
            self.final_activation,
        )

    def _eval_decoder(self, x):
        """Decode latent vectors back to the spatial size of the original sample.

        Parameters
        ----------
        x : Tensor
            Latent vectors, one row per batch element.

        Returns
        -------
        Tensor
            Reconstruction whose last two axes are ``(image_width, image_height)``.
        """
        x = self.dec_linear(x)

        # Rebuild the feature map in the same axis order the encoder produced it.
        # ``conv_end_w`` is derived from ``image_width`` (first spatial axis of the
        # sample) and ``conv_end_h`` from ``image_height`` (second spatial axis), and
        # the crop below targets (image_width, image_height). Using the opposite
        # order here transposes non-square outputs, which CenterCrop would then crop
        # on one axis and zero-pad on the other. Square inputs are unaffected.
        x = x.reshape(x.shape[0], -1, self.conv_end_w, self.conv_end_h)
        x = self.decoder(x)

        # The decoder upsamples by 8x, which can overshoot the original size; trim it.
        x = CenterCrop(size=(self.image_width, self.image_height))(x)
        return x

    def forward(self, batch):
        """Return the latent representation of ``batch`` (encoder output only)."""
        return self._eval_encoder(batch)

    def train_step(self, batch):
        """This function contains the logic for a single training step. i.e. the
        contents of the inner loop of a ML training process.

        Parameters
        ----------
        batch : Tensor
            Input images for the current batch. The value is passed to the encoder
            as-is and is also used as the reconstruction target.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.

        Raises
        ------
        ValueError
            If ``config["criterion"]["band_loss_reduction"]`` is neither ``"sum"``
            nor ``"mean"``.
        """
        x = batch
        z = self._eval_encoder(x)
        x_hat = self._eval_decoder(z)

        # Configurable band reduction strategy
        band_reduction = self.config["criterion"].get("band_loss_reduction", "mean")

        # The loss averaging strategy here is different from v1 which averages
        # over only the batch dimension. Here we always average over both batch
        # and spatial dimensions; so as the loss-value is not impacted by image size.
        if band_reduction == "sum":
            # Sum across bands, mean over spatial dims and batch
            # More channels will result in larger loss values
            # but MIGHT result in better popping out of bad reconstruction
            # in a single band/channel
            #
            # NOTE(review): the criterion is re-created from its class with only
            # ``reduction="none"``; any other arguments the configured criterion
            # was built with are not carried over -- confirm this is intended.
            criterion_cls = type(self.criterion)
            loss = criterion_cls(reduction="none")(x_hat, x)
            loss = loss.sum(dim=1).mean()
        elif band_reduction == "mean":
            # Default: Mean over all dimensions (batch,channel,spatial)
            loss = self.criterion(x_hat, x)
        else:
            raise ValueError(
                f"band_loss_reduction:{band_reduction} not supported by HyraxAutoencoderV2. "
                "Current supported options are sum and mean (default)"
            )

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item()}

    @staticmethod
    def to_tensor(data_dict) -> tuple[Tensor]:
        """This function converts structured data to the input tensor we need to run

        Parameters
        ----------
        data_dict : dict
            The dictionary returned from our data source

        Returns
        -------
        The value stored under the ``"image"`` key, returned unchanged.

        Raises
        ------
        RuntimeError
            If ``data_dict`` has no ``"image"`` key.
        """
        if "image" not in data_dict:
            raise RuntimeError("Data dict did not contain image key.")

        return data_dict["image"]