# Source code for hyrax.models.hyrax_cnn

# ruff: noqa: D101, D102

# This example model is taken from the PyTorch CIFAR10 tutorial:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa N812

from .model_registry import hyrax_model

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@hyrax_model
class HyraxCNN(nn.Module):
    """Small LeNet-style convolutional classifier sized from a data sample.

    This CNN is designed to work with datasets that are prepared with Hyrax's
    HSC Data Set class. The layer stack follows the PyTorch CIFAR10 tutorial:
    two 5x5 convolutions, each followed by ReLU and 2x2 max pooling, then three
    fully connected layers. The input width of the first fully connected layer
    is derived at construction time from the shape of ``data_sample``.

    ``train_batch``, ``validate_batch`` and ``test_batch`` use
    ``self.optimizer`` and ``self.criterion``, which are not assigned in this
    class. Presumably the Hyrax framework attaches them before use (TODO
    confirm).
    """

    def __init__(self, config, data_sample=None):
        """Build the layers, sizing ``fc1`` from ``data_sample``.

        Parameters
        ----------
        config : dict
            Hyrax configuration. ``config["model"]["HyraxCNN"]["output_classes"]``
            sets the number of outputs of the final layer.
        data_sample : tuple, optional
            A sample batch whose first element is an image batch with a
            4-element ``shape`` (batch_size, num_channels, spatial, spatial).
            Despite the ``None`` default, it is required.

        Raises
        ------
        ValueError
            If ``data_sample`` is None, if the image batch is not 4-dimensional,
            or if the images are too small to survive the convolution and
            pooling layers.
        """
        super().__init__()
        self.config = config

        if data_sample is None:
            raise ValueError("A `data_sample` must be provided to HyraxCNN for dynamic sizing.")

        image_sample = data_sample[0]
        if len(image_sample.shape) != 4:
            raise ValueError(
                "HyraxCNN expects a 4-D image batch (batch_size, num_channels, dim_1, dim_2); "
                f"got shape {tuple(image_sample.shape)}."
            )

        # Unpack the image shape; batch_size is ignored during initialization.
        # NOTE(review): PyTorch's Conv2d treats the last two axes as (height, width),
        # so these attribute names may be swapped relative to PyTorch convention.
        # The sizing below is symmetric in the two axes, so the layers are unaffected.
        _, self.num_input_channels, self.image_width, self.image_height = image_sample.shape

        hidden_channels_1 = 6
        hidden_channels_2 = 16

        # Spatial size of the final feature map, used to size fc1.
        pool2_end_w = self._feature_map_size(self.image_width)
        pool2_end_h = self._feature_map_size(self.image_height)
        if pool2_end_w < 1 or pool2_end_h < 1:
            # Fail early with a clear message instead of an obscure shape error
            # from nn.Linear or from the first forward pass.
            raise ValueError(
                f"Input images of spatial shape {tuple(image_sample.shape[2:])} are too small "
                "for HyraxCNN: the convolution and pooling layers would reduce them to "
                f"({pool2_end_w}, {pool2_end_h})."
            )

        self.conv1 = nn.Conv2d(self.num_input_channels, hidden_channels_1, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(hidden_channels_1, hidden_channels_2, 5)
        self.fc1 = nn.Linear(hidden_channels_2 * pool2_end_h * pool2_end_w, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, self.config["model"]["HyraxCNN"]["output_classes"])

    def _feature_map_size(self, input_size) -> int:
        """Return one spatial axis's size after conv1 -> pool -> conv2 -> pool.

        Mirrors the layer stack applied in ``forward``; if the layers change,
        this must be updated to match.
        """
        size = self.conv2d_output_size(input_size, kernel_size=5)
        size = self.pool2d_output_size(size, kernel_size=2, stride=2)
        size = self.conv2d_output_size(size, kernel_size=5)
        return self.pool2d_output_size(size, kernel_size=2, stride=2)

    def conv2d_output_size(self, input_size, kernel_size, padding=0, stride=1, dilation=1) -> int:
        """Return the output length of one spatial axis of ``nn.Conv2d``.

        Implements the formula from
        https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html,
        ``floor((in + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1)``.
        Integer floor division matches that ``floor`` exactly, including for
        negative intermediates, and avoids float rounding.
        """
        numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
        return int(numerator // stride + 1)

    def pool2d_output_size(self, input_size, kernel_size, stride, padding=0, dilation=1) -> int:
        """Return the output length of one spatial axis of ``nn.MaxPool2d``.

        Implements the ``ceil_mode=False`` formula from
        https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html using
        integer floor division.
        """
        numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
        return int(numerator // stride + 1)

    def forward(self, x):
        """Run the network on an ``(image, label)`` pair.

        Parameters
        ----------
        x : tuple
            ``(image, label)`` as produced by ``prepare_inputs``. The label is
            ignored.

        Returns
        -------
        torch.Tensor
            Raw (unactivated) ``fc3`` outputs, with the batch dimension preserved.
        """
        x, _ = x  # Unpack data and ignore labels
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # Flatten everything except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def train_batch(self, batch):
        """This function contains the logic for a single training step that will
        process a single batch of data, i.e. the contents of the inner loop of
        a ML training process.

        Parameters
        ----------
        batch : tuple
            A tuple containing the inputs and labels for the current batch.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.
        """
        _, labels = batch
        self.optimizer.zero_grad()
        outputs = self(batch)
        loss = self.criterion(outputs, labels)
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}

    def _evaluation_loss(self, batch) -> dict:
        """Compute the criterion loss for ``batch`` without updating any weights.

        Shared by ``validate_batch`` and ``test_batch``. Gradient tracking is
        disabled because no backward pass follows, so building the autograd
        graph would only cost memory. The returned value is unchanged.
        """
        _, labels = batch
        with torch.no_grad():
            outputs = self(batch)
            loss = self.criterion(outputs, labels)
        return {"loss": loss.item()}

    def validate_batch(self, batch):
        """This function contains the logic for a single validation step that will
        process a single batch of data, i.e. the contents of the inner loop of
        a ML validation process. In this case it is identical to `test_batch`.

        Parameters
        ----------
        batch : tuple
            A tuple containing the inputs and labels for the current batch.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.
        """
        return self._evaluation_loss(batch)

    def test_batch(self, batch):
        """This function contains the logic for a single testing step that will
        process a single batch of data, i.e. the contents of the inner loop of
        a ML testing process. In this case, it is identical to `validate_batch`.

        Parameters
        ----------
        batch : tuple
            A tuple containing the inputs and labels for the current batch.

        Returns
        -------
        Current loss value : dict
            Dictionary containing the loss value for the current batch.
        """
        return self._evaluation_loss(batch)

    def infer_batch(self, batch):
        """This function contains the logic for a single inference step that will
        process a single batch of data, i.e. the contents of the inner loop of
        a ML inference process.

        Parameters
        ----------
        batch : tuple
            A tuple containing the inputs and labels for the current batch.

        Returns
        -------
        Model outputs : Tensor
            Tensor containing the model outputs for the current batch.
        """
        return self(batch)

    @staticmethod
    def prepare_inputs(data_dict) -> tuple:
        """Extract image and label arrays from the batch dictionary.

        This static method is the interface between the data pipeline and the
        model. Override it on the model class to reshape or select fields from
        the collated batch to match the inputs your model expects. Hyrax will
        convert the returned arrays to PyTorch tensors and move them to the
        appropriate device automatically.

        Parameters
        ----------
        data_dict : dict
            The collated batch dictionary produced by the data pipeline.
            Expected to contain a ``"data"`` key with ``"image"`` and
            optionally ``"label"`` fields.

        Returns
        -------
        inputs : tuple of numpy.ndarray
            A tuple of ``(image, label)`` as float32 and int64 arrays
            respectively. ``label`` is an empty int64 array when the batch has
            no ``"label"`` field.

        Raises
        ------
        RuntimeError
            If ``data_dict`` has no ``"data"`` key.
        """
        import numpy as np

        if "data" not in data_dict:
            raise RuntimeError("Unable to find `data` key in data_dict")

        data = data_dict["data"]
        image = np.asarray(data["image"], dtype=np.float32)
        # Missing labels become an empty int64 array so the tuple shape is stable.
        label = np.asarray(data.get("label", []), dtype=np.int64)
        return (image, label)