# Source code for hyrax.models.hsc_autoencoder

# ruff: noqa: D101, D102


import torch
import torch.nn as nn

# extra long import here to address a circular import issue
from hyrax.models.model_registry import hyrax_model


@hyrax_model
class HSCAutoencoder(nn.Module):
    """Convolutional autoencoder for Hyrax's HSC Data Set class.

    This autoencoder is designed to work with datasets that are prepared with
    Hyrax's HSC Data Set class.

    These shapes work with ``[3, 258, 258]`` inputs. Each encoder convolution
    maps a spatial size H to ceil(H / 2) (258 -> 129 -> 65 -> 33). The decoder
    maps an encoded size E to 8 * E - 6 (33 -> 66 -> 132 -> 258). The
    reconstruction therefore has the same height/width as the input only when
    each of them is congruent to 2 modulo 8 (for example 258).

    ``train_batch``, ``validate_batch`` and ``test_batch`` use
    ``self.optimizer`` and/or ``self.criterion``, which are NOT created in this
    class. NOTE(review): they are presumably attached by the Hyrax training
    setup before these methods are called -- confirm.

    Parameters
    ----------
    config
        Hyrax configuration. Stored as ``self.config``; not otherwise read here.
    data_sample : optional
        Unused by this model.
    """

    def __init__(self, config, data_sample=None):
        super().__init__()

        # Encoder: three stride-2 convolutions, each mapping a spatial size H
        # to ceil(H / 2) and widening channels 3 -> 64 -> 128 -> 256.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        # Decoder: the first two transposed convolutions exactly double the
        # spatial size (33 -> 66 -> 132 for a 258 input). The last one uses
        # padding=4, which trims 6 pixels per spatial dimension compared with
        # padding=1, so 132 -> 258 instead of 264.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=4, output_padding=1),
            nn.Sigmoid(),  # Output pixel values between 0 and 1
        )

        self.config = config

    def forward(self, x):
        """Encode ``x`` and decode it back into a reconstruction.

        Parameters
        ----------
        x : torch.Tensor
            Input images. Assumed to be (batch, 3, H, W) because the first
            layer is ``Conv2d(3, ...)`` -- TODO confirm against the dataset.

        Returns
        -------
        torch.Tensor
            Reconstruction with 3 channels and values in (0, 1), produced by
            the final ``Sigmoid``.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

    def train_batch(self, batch):
        """
        This function contains the logic for a single training step.
        i.e. the contents of the inner loop of a ML training process.

        Parameters
        ----------
        batch : tuple
            A tuple containing the input data for the current batch,
            possibly with labels that are ignored.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.
        """
        # Only the first element is used; any labels are ignored.
        data = batch[0]
        self.optimizer.zero_grad()

        decoded = self.forward(data)
        # Reconstruction loss: the target is the input itself.
        loss = self.criterion(decoded, data)
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item()}

    def validate_batch(self, batch):
        """
        This function contains the logic for a single validation step that
        will process a single batch of data. i.e. the contents of the inner
        loop of a ML validation process.

        Parameters
        ----------
        batch : tuple
            A tuple containing the input data for the current batch,
            possibly with labels that are ignored.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.
        """
        data = batch[0]
        # Validation never calls ``backward``, so do not build an autograd
        # graph. This saves memory and compute without changing the loss value.
        with torch.no_grad():
            decoded = self.forward(data)
            loss = self.criterion(decoded, data)

        return {"loss": loss.item()}

    def test_batch(self, batch):
        """
        This function contains the logic for a single testing step that will
        process a single batch of data. i.e. the contents of the inner loop of
        a ML testing process. In this case, it is identical to
        `validate_batch`.

        Parameters
        ----------
        batch : tuple
            A tuple containing the input data for the current batch,
            possibly with labels that are ignored.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.
        """
        # Delegate rather than duplicate, so the two steps cannot drift apart.
        return self.validate_batch(batch)

    def infer_batch(self, batch):
        """
        This function contains the logic for a single inference step that will
        process a single batch of data. i.e. the contents of the inner loop of
        a ML inference process.

        Parameters
        ----------
        batch : tuple
            A tuple containing the input data for the current batch,
            possibly with labels that are ignored.

        Returns
        -------
        Reconstructed outputs : torch.Tensor
            The reconstructed outputs from the autoencoder.
        """
        # NOTE(review): this runs with autograd enabled, so the returned tensor
        # carries a graph unless the caller wraps the call in torch.no_grad().
        data = batch[0]
        return self.forward(data)