hyrax.models.hyrax_cnn

Attributes

Classes

HyraxCNN

This CNN is designed to work with datasets that are prepared with Hyrax's HSC Data Set class.

Module Contents

logger
class HyraxCNN(config, data_sample=None)

Bases: torch.nn.Module

This CNN is designed to work with datasets that are prepared with Hyrax’s HSC Data Set class.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

config
conv1
pool
conv2
fc1
fc2
fc3
conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) → int
pool2d_output_size(input_size, kernel_size, stride, padding=0, dilation=1) → int
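The two size helpers above are listed without a description. As a minimal sketch, assuming they follow the output-shape formula given in the PyTorch documentation for `torch.nn.Conv2d` and `torch.nn.MaxPool2d`, they could be written as the free functions below. The parameter names and defaults mirror the signatures above; the function bodies are an assumption and are not taken from the Hyrax source.

```python
def conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1):
    """Output length along one spatial dimension of a 2D convolution.

    Assumed formula (from the PyTorch Conv2d docs):
    floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
    """
    numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
    return numerator // stride + 1


def pool2d_output_size(input_size, kernel_size, stride, padding=0, dilation=1):
    """Output length along one spatial dimension of a 2D max-pooling layer.

    Assumed to use the same formula as the convolution; the default
    floor rounding of MaxPool2d (ceil_mode=False) is assumed as well.
    """
    numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
    return numerator // stride + 1


# Hypothetical 32x32 input passed through conv(5) -> pool(2, 2) -> conv(5) -> pool(2, 2).
size = conv2d_output_size(32, kernel_size=5)
size = pool2d_output_size(size, kernel_size=2, stride=2)
size = conv2d_output_size(size, kernel_size=5)
size = pool2d_output_size(size, kernel_size=2, stride=2)
print(size)  # side length of the feature map fed to the first fully connected layer
```

Helpers of this kind are typically used to derive the input width of the first fully connected layer from the image size instead of hard-coding it.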
forward(x)
train_step(batch)

This function contains the logic for a single training step, i.e., the contents of the inner loop of an ML training process.

Parameters:

batch (tuple) – A tuple containing the inputs and labels for the current batch.

Returns:

Current loss value – Dictionary containing the loss value for the current batch.

Return type:

dict
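To illustrate the contract described above (a tuple of inputs and labels goes in, a dictionary holding the loss comes out), here is a minimal sketch using a stand-in model. `TinyModel`, its layer sizes, the choice of loss and optimizer, and the `"loss"` dictionary key are all hypothetical; the actual HyraxCNN implementation may differ.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in model; not the HyraxCNN architecture."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)
        # Loss and optimizer choices are assumptions made for this sketch.
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1)

    def forward(self, x):
        return self.fc(x)

    def train_step(self, batch):
        """Run one optimization step and return the loss in a dict."""
        inputs, labels = batch
        self.optimizer.zero_grad()       # clear gradients from the previous step
        outputs = self(inputs)           # forward pass
        loss = self.criterion(outputs, labels)
        loss.backward()                  # backpropagate
        self.optimizer.step()            # update the weights
        return {"loss": loss.item()}     # key name "loss" is an assumption


torch.manual_seed(0)
model = TinyModel()
batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))
result = model.train_step(batch)
print(result)
```

An outer training loop would call `train_step` once per batch and log or aggregate the returned dictionary.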

static to_tensor(data_dict) → tuple

Despite its name, this method does NOT convert to PyTorch Tensors. It works exclusively with NumPy data types and returns a tuple of NumPy data types.
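The behavior described for `to_tensor` can be sketched as follows. This is only an illustration: the function name `to_tensor_sketch` and the dictionary keys `"image"` and `"label"` are hypothetical, since the layout of `data_dict` is not documented here.

```python
import numpy as np


def to_tensor_sketch(data_dict):
    """Return a tuple of NumPy values from a data dictionary.

    The keys "image" and "label" are hypothetical placeholders.
    No PyTorch tensors are created at any point.
    """
    image = np.asarray(data_dict["image"], dtype=np.float32)
    label = np.asarray(data_dict["label"], dtype=np.int64)
    return image, label


sample = {"image": [[0.0, 1.0], [2.0, 3.0]], "label": 1}
image, label = to_tensor_sketch(sample)
print(type(image).__name__, image.shape, label.dtype)
```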