Source code for hyrax.models.hyrax_autoencoderv2

# ruff: noqa: D101, D102
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa N812
from torchvision.transforms.v2 import CenterCrop

# extra long import here to address a circular import issue
from hyrax.models.model_registry import hyrax_model

# Module-level logger named after this module; used for debug/warning output
# while the model is being constructed.
logger = logging.getLogger(__name__)
class ArcsinhActivation(nn.Module):
    """Inverse hyperbolic sine packaged as an ``nn.Module``.

    Wrapping the function in a module lets HyraxAutoencoderV2 use it as the
    final activation inside an ``nn.Sequential``.
    """

    def forward(self, x):
        """Return the element-wise arcsinh of ``x``."""
        # torch.arcsinh is an alias of torch.asinh.
        return torch.asinh(x)
@hyrax_model
class HyraxAutoencoderV2(nn.Module):
    """Convolutional autoencoder for imaging data.

    This is a tweaked version of HyraxAutoencoder and is designed to work with
    a wide range of imaging datasets.

    V2 improvements:

    - Configurable final layer activation
    - Uses criterion and optimizer from config variables

    Notes
    -----
    ``self.criterion`` and ``self.optimizer`` are used by the batch methods but
    are never assigned in this class; presumably they are attached by the
    ``hyrax_model`` decorator or the surrounding framework -- TODO confirm.
    """

    def __init__(self, config, data_sample=None):
        """Build the encoder and decoder from the config and a sample batch.

        Parameters
        ----------
        config : dict
            Configuration mapping. This class reads
            ``config["model"]["HyraxAutoencoderV2"]`` (keys ``base_channel_size``,
            ``latent_dim``, ``final_layer``) and
            ``config["criterion"]["band_loss_reduction"]``.
        data_sample : object with a 4-element ``.shape``
            Sample input used only for its shape, interpreted as
            (batch_size, num_channels, width, height). Required, despite the
            ``None`` default in the signature.

        Raises
        ------
        ValueError
            If ``data_sample`` is ``None``.
        """
        super().__init__()

        if data_sample is None:
            # Fail with an actionable message instead of an AttributeError on
            # ``None.shape`` below.
            raise ValueError(
                "HyraxAutoencoderV2 requires a data_sample to determine the input shape, but got None."
            )

        self.config = config

        shape = data_sample.shape
        logger.debug("Found shape: %s in data sample, using this to initialize model.", shape)

        # Unpack the shape of the image (batch_size, num_channels, width, height)
        # we'll ignore the batch_size during initialization.
        _, self.num_input_channels, self.image_width, self.image_height = shape

        model_config = self.config["model"]["HyraxAutoencoderV2"]
        self.c_hid = model_config["base_channel_size"]
        self.latent_dim = model_config["latent_dim"]

        # Calculate how much our convolutional layers will affect the size of final convolution
        # Formula evaluated from: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        #
        # The encoder contains exactly three stride-2 convolutions (the stride-1
        # layers keep the size unchanged), hence ``3`` applications below.
        # If the number of layers are changed this will need to be rewritten.
        self.conv_end_w = self.conv2d_multi_layer(self.image_width, 3, kernel_size=3, padding=1, stride=2)
        self.conv_end_h = self.conv2d_multi_layer(self.image_height, 3, kernel_size=3, padding=1, stride=2)

        self._init_encoder()
        self._init_decoder()

        # Configurable band reduction strategy ("sum" or "mean"); validated
        # lazily when a loss is actually computed.
        self.band_reduction = self.config["criterion"]["band_loss_reduction"]

    def conv2d_multi_layer(self, input_size, num_applications, **kwargs) -> int:
        """Return the size of one spatial axis after repeated identical convolutions.

        Parameters
        ----------
        input_size : int
            Size of the axis before any convolution.
        num_applications : int
            How many times the convolution is applied in sequence.
        **kwargs
            Forwarded to ``conv2d_output_size`` (``kernel_size``, ``padding``,
            ``stride``, ``dilation``).

        Returns
        -------
        int
            Size of the axis after all applications.
        """
        for _ in range(num_applications):
            input_size = self.conv2d_output_size(input_size, **kwargs)
        return int(input_size)

    def conv2d_output_size(self, input_size, kernel_size, padding=0, stride=1, dilation=1) -> int:
        """Return the size of one spatial axis after a single ``nn.Conv2d``.

        Implements the output-shape formula from
        https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

        Parameters
        ----------
        input_size : int
            Size of the axis before the convolution.
        kernel_size : int
            Kernel size along this axis.
        padding : int, optional
            Zero padding added to both sides of the axis.
        stride : int, optional
            Stride of the convolution.
        dilation : int, optional
            Spacing between kernel elements.

        Returns
        -------
        int
            Size of the axis after the convolution.
        """
        numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
        # The reference formula uses floor(); integer floor division matches it
        # exactly and avoids a round trip through float.
        return int(numerator // stride + 1)

    def _init_encoder(self):
        """Create ``self.encoder``: three stride-2 conv stages followed by a linear projection."""
        self.encoder = nn.Sequential(
            nn.Conv2d(self.num_input_channels, self.c_hid, kernel_size=3, padding=1, stride=2),
            nn.GELU(),
            nn.Conv2d(self.c_hid, self.c_hid, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(self.c_hid, 2 * self.c_hid, kernel_size=3, padding=1, stride=2),
            nn.GELU(),
            nn.Conv2d(2 * self.c_hid, 2 * self.c_hid, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(2 * self.c_hid, 2 * self.c_hid, kernel_size=3, padding=1, stride=2),
            nn.GELU(),
            nn.Flatten(),  # Image grid to single feature vector
            nn.Linear(2 * self.conv_end_h * self.conv_end_w * self.c_hid, self.latent_dim),
        )

    def _eval_encoder(self, x):
        """Encode a batch of images into latent vectors."""
        return self.encoder(x)

    def _init_decoder(self):
        """Create ``self.dec_linear``, ``self.final_activation`` and ``self.decoder``."""
        self.dec_linear = nn.Sequential(
            nn.Linear(self.latent_dim, 2 * self.conv_end_h * self.conv_end_w * self.c_hid), nn.GELU()
        )

        # Configure final activation
        # Should be set to the same value as ["dataset"]["transform"] in most cases
        activations = {
            "sigmoid": nn.Sigmoid,
            "tanh": nn.Tanh,
            "arcsinh": ArcsinhActivation,
            "identity": nn.Identity,
        }
        # A falsy config value (e.g. empty string or False) selects the tanh default.
        final_layer = self.config["model"]["HyraxAutoencoderV2"]["final_layer"] or "tanh"
        activation_cls = activations.get(final_layer) if isinstance(final_layer, str) else None
        if activation_cls is None:
            # Unknown values still fall back to tanh, but no longer silently.
            logger.warning(
                "Unknown final_layer %r for HyraxAutoencoderV2; falling back to tanh.", final_layer
            )
            activation_cls = nn.Tanh
        self.final_activation = activation_cls()

        # The size comments below describe a 32x32 input as an example.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(
                2 * self.c_hid, 2 * self.c_hid, kernel_size=3, output_padding=1, padding=1, stride=2
            ),  # 4x4 => 8x8
            nn.GELU(),
            nn.Conv2d(2 * self.c_hid, 2 * self.c_hid, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(
                2 * self.c_hid, self.c_hid, kernel_size=3, output_padding=1, padding=1, stride=2
            ),  # 8x8 => 16x16
            nn.GELU(),
            nn.Conv2d(self.c_hid, self.c_hid, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(
                self.c_hid, self.num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2
            ),  # 16x16 => 32x32
            self.final_activation,
        )

    def _eval_decoder(self, x):
        """Decode latent vectors back to images with the input's spatial size."""
        x = self.dec_linear(x)
        # Keep the spatial axes in the same (dim 2, dim 3) order as the input,
        # which this class names (width, height). Using the opposite order
        # transposes non-square images, so the crop below would cut one axis
        # and zero-pad the other.
        x = x.reshape(x.shape[0], -1, self.conv_end_w, self.conv_end_h)
        x = self.decoder(x)
        # Trim the upsampled output back to the exact input size.
        return CenterCrop(size=(self.image_width, self.image_height))(x)

    def _reconstruction_loss(self, batch):
        """Encode and decode ``batch`` and return the reconstruction loss tensor.

        Shared by ``train_batch``, ``validate_batch`` and ``test_batch``.

        Raises
        ------
        ValueError
            If ``self.band_reduction`` is neither ``"sum"`` nor ``"mean"``.
        """
        x = batch
        x_hat = self._eval_decoder(self._eval_encoder(x))

        # The loss averaging strategy here is different from v1 which averages
        # over only the batch dimension. Here we always average over both batch
        # and spatial dimensions; so as the loss-value is not impacted by image size.
        if self.band_reduction == "sum":
            # Sum across bands, mean over spatial dims and batch
            # More channels will result in larger loss values
            # but MIGHT result in better popping out of bad reconstruction
            # in a single band/channel
            # NOTE(review): the criterion is rebuilt from its class with only
            # ``reduction="none"``, so any other constructor arguments of the
            # configured criterion are not carried over.
            criterion_cls = type(self.criterion)
            per_element = criterion_cls(reduction="none")(x_hat, x)
            return per_element.sum(dim=1).mean()

        if self.band_reduction == "mean":
            # Default: Mean over all dimensions (batch,channel,spatial)
            return self.criterion(x_hat, x)

        raise ValueError(
            f"band_loss_reduction:{self.band_reduction} not supported by HyraxAutoencoderV2. "
            "Current supported options are sum and mean (default)"
        )

    def forward(self, batch):
        """Return the encoder output (latent vectors) for ``batch``."""
        return self._eval_encoder(batch)

    def train_batch(self, batch):
        """This function contains the logic for a single training step. i.e. the
        contents of the inner loop of a ML training process.

        Parameters
        ----------
        batch : torch.Tensor
            Batch of input images; it is passed directly to the encoder.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.

        Raises
        ------
        ValueError
            If the configured ``band_loss_reduction`` is not supported.
        """
        loss = self._reconstruction_loss(batch)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item()}

    def validate_batch(self, batch):
        """This function contains the logic for a single validation step that will
        process a single batch of data. i.e. the contents of the inner loop of a
        ML validation process.

        Parameters
        ----------
        batch : torch.Tensor
            Batch of input images; it is passed directly to the encoder.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.

        Raises
        ------
        ValueError
            If the configured ``band_loss_reduction`` is not supported.
        """
        return {"loss": self._reconstruction_loss(batch).item()}

    def test_batch(self, batch):
        """This function contains the logic for a single testing step that will
        process a single batch of data. i.e. the contents of the inner loop of a
        ML testing process. In this case, it is identical to `validate_batch`.

        Parameters
        ----------
        batch : torch.Tensor
            Batch of input images; it is passed directly to the encoder.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.

        Raises
        ------
        ValueError
            If the configured ``band_loss_reduction`` is not supported.
        """
        return {"loss": self._reconstruction_loss(batch).item()}

    def infer_batch(self, batch):
        """This function contains the logic for a single inference step. i.e. the
        contents of the inner loop of a ML inference process.

        Parameters
        ----------
        batch : torch.Tensor
            Batch of input images; it is passed directly to the encoder.

        Returns
        -------
        Latent representation : torch.Tensor
            The encoder output for the batch. Note that this is the latent
            vector, not a reconstructed image: ``forward`` only runs the encoder.
        """
        return self.forward(batch)

    @staticmethod
    def prepare_inputs(data_dict) -> tuple:
        """Extract the image array from the batch dictionary.

        This static method is the interface between the data pipeline and the
        model. Override it on the model class to reshape or select fields from
        the collated batch to match the inputs your model expects. Hyrax will
        convert the returned array to a PyTorch tensor and move it to the
        appropriate device automatically.

        Parameters
        ----------
        data_dict : dict
            The collated batch dictionary produced by the data pipeline.
            Expected to contain a ``"data"`` key with an ``"image"`` field.

        Returns
        -------
        image : numpy.ndarray
            The image array extracted from the batch.

        Raises
        ------
        RuntimeError
            If ``data_dict`` has no ``"data"`` key, or ``data_dict["data"]``
            has no ``"image"`` key.
        """
        # NOTE(review): the ``-> tuple`` annotation does not match the single
        # array returned here; left unchanged to keep the declared interface.
        if "data" not in data_dict:
            raise RuntimeError("Unable to find `data` key in data_dict")

        data = data_dict["data"]
        if "image" not in data:
            raise RuntimeError("Data dict did not contain image key.")

        return data["image"]