hyrax.models.hyrax_autoencoderv2

Attributes

logger

Classes

ArcsinhActivation

Helper module that lets HyraxAutoencoderV2 use the arcsinh function as a network layer.

HyraxAutoencoderV2

This is a tweaked version of HyraxAutoencoder, designed to work with a wide range of imaging datasets.

Module Contents

logger
class ArcsinhActivation(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Helper module that lets HyraxAutoencoderV2 use the arcsinh function as a network layer.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)
class HyraxAutoencoderV2(config, data_sample=None)
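The page does not document forward(x). Assuming it simply applies arcsinh elementwise (in PyTorch that would be torch.asinh, also available as torch.arcsinh), the dependency-free sketch below illustrates why arcsinh is a convenient activation for imaging data: it behaves like the identity near zero, grows only logarithmically for large |x|, and, unlike log, is defined for zero and negative pixel values. The function name arcsinh_activation is hypothetical and works on plain floats rather than tensors.

```python
import math


def arcsinh_activation(values):
    """Elementwise arcsinh over a list of floats.

    Hypothetical stand-in for ArcsinhActivation.forward, which is assumed
    to apply the same function to a torch tensor.
    """
    return [math.asinh(v) for v in values]


# Near zero arcsinh(x) ~ x; for large |x| it approaches sign(x) * ln(2|x|).
samples = [-1000.0, -1.0, 0.0, 0.01, 1000.0]
for x, y in zip(samples, arcsinh_activation(samples)):
    print(f"arcsinh({x:g}) = {y:.4f}")
```

Because arcsinh is odd and finite everywhere, it compresses a large dynamic range while preserving the sign of each value.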

Bases: torch.nn.Module

This is a tweaked version of HyraxAutoencoder, designed to work with a wide range of imaging datasets.

V2 improvements:

- Configurable final layer activation
- Uses the criterion and optimizer specified in config variables
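One way to read the configurable final layer activation is as a lookup from a config string to an activation function. The sketch below assumes a hypothetical config key named final_layer and a hypothetical set of activation names; it uses plain floats and the math module so it runs without PyTorch, whereas the real class presumably resolves names to torch.nn modules such as ArcsinhActivation.

```python
import math
from typing import Callable, Dict

# Hypothetical name -> activation mapping (not the library's actual table).
FINAL_ACTIVATIONS: Dict[str, Callable[[float], float]] = {
    "identity": lambda x: x,
    # Numerically stable logistic sigmoid: 0.5 * (1 + tanh(x / 2)).
    "sigmoid": lambda x: 0.5 * (1.0 + math.tanh(x / 2.0)),
    "tanh": math.tanh,
    "arcsinh": math.asinh,
}


def resolve_final_activation(config: dict) -> Callable[[float], float]:
    """Return the activation named by the hypothetical "final_layer" key."""
    name = config.get("final_layer", "identity")
    try:
        return FINAL_ACTIVATIONS[name]
    except KeyError:
        raise ValueError(
            f"Unknown final activation {name!r}; "
            f"expected one of {sorted(FINAL_ACTIVATIONS)}"
        ) from None


act = resolve_final_activation({"final_layer": "arcsinh"})
print(act(10.0))
```

Raising a ValueError that lists the valid names surfaces config typos when the model is built rather than partway through training.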

Initialize internal Module state, shared by both nn.Module and ScriptModule.

config
c_hid
latent_dim
conv_end_w
conv_end_h
band_reduction
conv2d_multi_layer(input_size, num_applications, **kwargs) → int
conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) → int
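These two size helpers are undocumented here. Judging from its signature, conv2d_output_size appears to mirror the output-shape formula in the torch.nn.Conv2d documentation, out = floor((in + 2·padding − dilation·(kernel_size − 1) − 1) / stride + 1), and conv2d_multi_layer presumably applies it num_applications times with the same keyword arguments. A minimal sketch under those assumptions:

```python
def conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) -> int:
    """Output length of one Conv2d along a single spatial dimension.

    Follows the formula in the torch.nn.Conv2d documentation; integer floor
    division implements the floor for both positive and negative numerators.
    """
    numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
    return numerator // stride + 1


def conv2d_multi_layer(input_size, num_applications, **kwargs) -> int:
    """Apply conv2d_output_size repeatedly, assuming every layer shares kwargs."""
    size = input_size
    for _ in range(num_applications):
        size = conv2d_output_size(size, **kwargs)
    return size


# Three stride-2, padding-1, 3x3 convolutions halve a 64-pixel side each time.
print(conv2d_multi_layer(64, 3, kernel_size=3, padding=1, stride=2))  # 8
```

Width and height would be computed with separate calls, which is presumably how values like conv_end_w and conv_end_h could be derived from a sample image; that connection is an assumption.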
_init_encoder()
_eval_encoder(x)
_init_decoder()
_eval_decoder(x)
forward(batch)
train_batch(batch)

This function contains the logic for a single training step, i.e. the contents of the inner loop of an ML training process.

Parameters:

batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.

Returns:

Current loss value – Dictionary containing the loss value for the current batch.

Return type:

dict
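The real method works on torch tensors and, per the V2 improvements, uses the criterion and optimizer from the config. To show only the shape of a single training step without dependencies, the toy sketch below swaps in a hypothetical one-parameter "autoencoder" (x̂ = weight · x), a hand-written MSE loss, and plain gradient descent. The function name and the "loss" dictionary key are assumptions.

```python
def train_batch_sketch(weight, batch, learning_rate=0.1):
    """One gradient step for a toy 1-parameter autoencoder x_hat = weight * x.

    Mirrors the structure of train_batch: unpack the batch (ignoring labels),
    reconstruct, compute a loss, update the parameter, return a loss dict.
    """
    # A batch is either the inputs or a tuple (inputs, labels, ...).
    inputs = batch[0] if isinstance(batch, tuple) else batch
    n = len(inputs)
    recon = [weight * x for x in inputs]
    loss = sum((r - x) ** 2 for r, x in zip(recon, inputs)) / n
    # d(loss)/d(weight) for MSE: mean of 2 * (weight * x - x) * x
    grad = sum(2 * (r - x) * x for r, x in zip(recon, inputs)) / n
    new_weight = weight - learning_rate * grad
    return new_weight, {"loss": loss}


weight = 0.0
batch = ([1.0, 2.0, 3.0], ["label_a", "label_b", "label_c"])  # labels ignored
for step in range(5):
    weight, metrics = train_batch_sketch(weight, batch)
    print(step, metrics["loss"], weight)
```

In the PyTorch version the manual gradient line would correspond to calling backward() on the loss and stepping the configured optimizer.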

validate_batch(batch)[source]#

This function contains the logic for a single validation step that processes a single batch of data, i.e. the contents of the inner loop of an ML validation process.

Parameters:

batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.

Returns:

Current loss value – Dictionary containing the loss value for the current batch.

Return type:

dict

test_batch(batch)[source]#

This function contains the logic for a single testing step that processes a single batch of data, i.e. the contents of the inner loop of an ML testing process. In this case, it is identical to validate_batch.

Parameters:

batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.

Returns:

Current loss value – Dictionary containing the loss value for the current batch.

Return type:

dict

infer_batch(batch)[source]#

This function contains the logic for a single inference step, i.e. the contents of the inner loop of an ML inference process.

Parameters:

batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.

Returns:

Reconstructed outputs – The reconstructed outputs from the autoencoder.

Return type:

torch.Tensor

static prepare_inputs(data_dict) → tuple

This function converts structured data from the data source into the input tensor needed to run the model.

Parameters:

data_dict (dict) – The dictionary returned from our data source
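The page does not say which keys the data source provides. As a minimal sketch, assuming a hypothetical layout with an image under an optional "data" section and an optional label, a converter might look like the following. The returned tuple matches the "input data, possibly with labels" shape that the batch methods above describe, though whether prepare_inputs feeds them directly is an assumption.

```python
from typing import Any, Optional, Tuple


def prepare_inputs_sketch(data_dict: dict) -> Tuple[Any, Optional[Any]]:
    """Convert a data-source dictionary into an (inputs, label) tuple.

    Hypothetical layout: the image lives under data_dict["data"]["image"]
    (or directly under data_dict["image"]), with an optional "label".
    """
    # Fall back to the top level if there is no nested "data" section.
    data = data_dict.get("data", data_dict)
    image = data["image"]  # KeyError here means the assumed key is missing
    label = data.get("label")  # None when absent; the autoencoder ignores labels
    return (image, label)


print(prepare_inputs_sketch({"data": {"image": [[0.1, 0.2], [0.3, 0.4]]}}))
```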