"""Convolutional image autoencoder model for Hyrax (``hyrax.models.hyrax_autoencoder``)."""

# ruff: noqa: D101, D102
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa N812
import torch.optim as optim
from torchvision.transforms.v2 import CenterCrop

# extra long import here to address a circular import issue
from hyrax.models.model_registry import hyrax_model

logger = logging.getLogger(__name__)


@hyrax_model
class HyraxAutoencoder(nn.Module):
    """
    This autoencoder is designed to work with a wide range of image datasets to allow testing.

    This example model is taken from this `autoencoder tutorial
    <https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial9/AE_CIFAR10.html>`_

    The train function has been converted into train_batch for use with pytorch-ignite.
    """

    def __init__(self, config, data_sample=None):
        """Build an encoder/decoder pair sized to match ``data_sample``.

        Parameters
        ----------
        config : dict
            Configuration; ``config["model"]["HyraxAutoencoder"]`` must provide
            ``base_channel_size`` and ``latent_dim``.
        data_sample : torch.Tensor
            A sample batch whose 4-D shape is used to size the layers. It defaults
            to ``None`` only for signature compatibility; it is required.

        Raises
        ------
        ValueError
            If ``data_sample`` is None or its shape is not 4-dimensional.
        """
        super().__init__()
        self.config = config

        if data_sample is None:
            raise ValueError("HyraxAutoencoder requires a data_sample to infer the input shape.")

        shape = data_sample.shape
        logger.debug("Found shape: %s in data sample, using this to initialize model.", shape)
        if len(shape) != 4:
            raise ValueError(
                f"HyraxAutoencoder expects a 4-D data sample (batch, channels, dim2, dim3), got shape {shape}."
            )

        # Unpack the shape of the image; the batch size is ignored during initialization.
        # NOTE: PyTorch image tensors are conventionally (N, C, H, W). Here dim 2 is stored
        # as ``image_width`` and dim 3 as ``image_height``; the names are kept for backward
        # compatibility and every use below indexes consistently with that mapping.
        _, self.num_input_channels, self.image_width, self.image_height = shape

        model_config = self.config["model"]["HyraxAutoencoder"]
        self.c_hid = model_config["base_channel_size"]
        self.latent_dim = model_config["latent_dim"]

        # Calculate how much our convolutional layers will affect the size of final convolution.
        # Formula evaluated from: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        # The encoder has three stride-2 convolutions (the stride-1 ones preserve size);
        # if the number of layers changes this will need to be rewritten.
        self.conv_end_w = self.conv2d_multi_layer(self.image_width, 3, kernel_size=3, padding=1, stride=2)
        self.conv_end_h = self.conv2d_multi_layer(self.image_height, 3, kernel_size=3, padding=1, stride=2)

        self._init_encoder()
        self._init_decoder()

    def conv2d_multi_layer(self, input_size, num_applications, **kwargs) -> int:
        """Return the size along one spatial dim after ``num_applications`` identical Conv2d layers.

        ``kwargs`` are forwarded to :meth:`conv2d_output_size` for every application.
        """
        for _ in range(num_applications):
            input_size = self.conv2d_output_size(input_size, **kwargs)
        return int(input_size)

    def conv2d_output_size(self, input_size, kernel_size, padding=0, stride=1, dilation=1) -> int:
        """Return the output size along one spatial dim of a single ``nn.Conv2d``.

        Implements ``floor((input_size + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1)``
        from https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        """
        numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
        # Floor division matches PyTorch's floor(); int() of a true division truncates toward zero.
        return int(numerator // stride + 1)

    def _init_encoder(self):
        """Create ``self.encoder``: conv stack -> flatten -> linear projection to ``latent_dim``."""
        self.encoder = nn.Sequential(
            nn.Conv2d(self.num_input_channels, self.c_hid, kernel_size=3, padding=1, stride=2),
            nn.GELU(),
            nn.Conv2d(self.c_hid, self.c_hid, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(self.c_hid, 2 * self.c_hid, kernel_size=3, padding=1, stride=2),
            nn.GELU(),
            nn.Conv2d(2 * self.c_hid, 2 * self.c_hid, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(2 * self.c_hid, 2 * self.c_hid, kernel_size=3, padding=1, stride=2),
            nn.GELU(),
            nn.Flatten(),  # Image grid to single feature vector
            # 2 * c_hid channels remain, each of spatial size conv_end_w x conv_end_h.
            nn.Linear(2 * self.conv_end_h * self.conv_end_w * self.c_hid, self.latent_dim),
        )

    def _eval_encoder(self, x):
        """Map an input batch to its latent vectors."""
        return self.encoder(x)

    def _init_decoder(self):
        """Create ``self.dec_linear`` (latent -> feature grid) and ``self.decoder`` (grid -> image)."""
        self.dec_linear = nn.Sequential(
            nn.Linear(self.latent_dim, 2 * self.conv_end_h * self.conv_end_w * self.c_hid), nn.GELU()
        )
        self.decoder = nn.Sequential(
            # Each stride-2 ConvTranspose2d (kernel 3, padding 1, output_padding 1) doubles
            # the spatial size, e.g. 4x4 => 8x8 => 16x16 => 32x32 for a 32x32 input.
            nn.ConvTranspose2d(
                2 * self.c_hid, 2 * self.c_hid, kernel_size=3, output_padding=1, padding=1, stride=2
            ),  # 4x4 => 8x8
            nn.GELU(),
            nn.Conv2d(2 * self.c_hid, 2 * self.c_hid, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(
                2 * self.c_hid, self.c_hid, kernel_size=3, output_padding=1, padding=1, stride=2
            ),  # 8x8 => 16x16
            nn.GELU(),
            nn.Conv2d(self.c_hid, self.c_hid, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(
                self.c_hid, self.num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2
            ),  # 16x16 => 32x32
            nn.Tanh(),  # The input images are scaled between -1 and 1, so the output has to be bounded as well
        )

    def _eval_decoder(self, x):
        """Map latent vectors back to images with the same spatial size as the input."""
        x = self.dec_linear(x)
        # Lay the spatial dims out in the same order as the input tensor:
        # dim 2 <-> conv_end_w (derived from image_width), dim 3 <-> conv_end_h.
        # Swapping them transposes the grid for non-square inputs.
        x = x.reshape(x.shape[0], -1, self.conv_end_w, self.conv_end_h)
        x = self.decoder(x)
        # The decoder yields 8 * conv_end per dim, which exceeds the input size when a dim is
        # not a multiple of 8; crop back to (dim 2, dim 3) of the original input.
        x = CenterCrop(size=(self.image_width, self.image_height))(x)
        return x

    def forward(self, batch):
        """Return the latent representation of ``batch`` (encoder output only)."""
        return self._eval_encoder(batch)

    def _reconstruction_loss(self, batch):
        """Return the batch-mean of the per-sample summed squared reconstruction error."""
        z = self._eval_encoder(batch)
        x_hat = self._eval_decoder(z)
        loss = F.mse_loss(x_hat, batch, reduction="none")
        # Sum over each sample's (channel, dim2, dim3) elements, then average over the batch.
        return loss.sum(dim=[1, 2, 3]).mean(dim=[0])

    def train_batch(self, batch):
        """This function contains the logic for a single training step.
        i.e. the contents of the inner loop of a ML training process.

        Parameters
        ----------
        batch : torch.Tensor
            The input images for the current batch (the tensor Hyrax builds from
            ``prepare_inputs``); presumably shaped like the ``data_sample`` given
            at construction — confirm.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.
        """
        loss = self._reconstruction_loss(batch)

        # NOTE(review): ``self.optimizer`` is not assigned in this class; presumably the
        # framework attaches it (e.g. from ``_optimizer``) — confirm.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}

    def validate_batch(self, batch):
        """This function contains the logic for a single validation step that will
        process a single batch of data. i.e. the contents of the inner loop of a ML
        validation process.

        Parameters
        ----------
        batch : torch.Tensor
            The input images for the current batch.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.
        """
        # No parameter update happens here, so skip building the autograd graph.
        with torch.no_grad():
            loss = self._reconstruction_loss(batch)
        return {"loss": loss.item()}

    def test_batch(self, batch):
        """This function contains the logic for a single testing step that will
        process a single batch of data. i.e. the contents of the inner loop of a ML
        testing process. In this case, it is identical to `validate_batch`.

        Parameters
        ----------
        batch : torch.Tensor
            The input images for the current batch.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.
        """
        with torch.no_grad():
            loss = self._reconstruction_loss(batch)
        return {"loss": loss.item()}

    def infer_batch(self, batch):
        """This function contains the logic for a single inference step.
        i.e. the contents of the inner loop of a ML inference process.

        Parameters
        ----------
        batch : torch.Tensor
            The input images for the current batch.

        Returns
        -------
        Latent vectors : torch.Tensor
            The encoder output for ``batch`` (shape ``(batch_size, latent_dim)``);
            no reconstruction is performed.
        """
        return self.forward(batch)

    @staticmethod
    def prepare_inputs(data_dict):
        """Extract the image array from the batch dictionary.

        This static method is the interface between the data pipeline and the model.
        Override it on the model class to reshape or select fields from the collated
        batch to match the inputs your model expects. Hyrax will convert the returned
        array to a PyTorch tensor and move it to the appropriate device automatically.

        Parameters
        ----------
        data_dict : dict
            The collated batch dictionary produced by the data pipeline. Expected to
            contain a ``"data"`` key with an ``"image"`` field.

        Returns
        -------
        image : numpy.ndarray
            The image array extracted from the batch, or an empty array if absent.
        """
        # Necessary for e2e tests to pass since this function ends up written out as a file and
        # must stand alone from imports when loaded for inference
        import numpy as np  # noqa: F811

        data = data_dict.get("data", {})
        # np.array([]) is a well-defined empty array; the low-level np.ndarray([])
        # constructor would allocate an uninitialized 0-d array instead.
        image = data.get("image", np.array([]))
        return image

    def _optimizer(self):
        """Return an Adam optimizer over all model parameters with learning rate 1e-3."""
        return optim.Adam(self.parameters(), lr=1e-3)