hyrax.models.hyrax_autoencoderv2

Attributes

Classes

ArcsinhActivation

Helper module for HyraxAutoencoderV2 to use the arcsinh function

HyraxAutoencoderV2

This is a tweaked version of HyraxAutoencoder, designed to work with a wide range of imaging datasets.

Module Contents

logger
class ArcsinhActivation(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Helper module for HyraxAutoencoderV2 to use the arcsinh function

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)
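
A minimal sketch of what an arcsinh activation module can look like, assuming `forward` simply applies `torch.arcsinh` elementwise. The class name `ArcsinhSketch` is hypothetical, and the actual ArcsinhActivation implementation may differ:

```python
import torch
from torch import nn


class ArcsinhSketch(nn.Module):
    """Hypothetical stand-in: wraps torch.arcsinh so it can be used as a layer."""

    def forward(self, x):
        # Elementwise inverse hyperbolic sine; the output has the same shape as x.
        return torch.arcsinh(x)


activation = ArcsinhSketch()
values = activation(torch.tensor([0.0, 1.0, -1.0]))
```

arcsinh is close to linear near zero and grows logarithmically for large magnitudes, so it compresses a wide range of pixel values while remaining defined for negative inputs.
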
class HyraxAutoencoderV2(config, data_sample=None)

Bases: torch.nn.Module

This is a tweaked version of HyraxAutoencoder, designed to work with a wide range of imaging datasets.

V2 improvements:

- Configurable final layer activation
- Uses criterion and optimizer from config variables

Initialize internal Module state, shared by both nn.Module and ScriptModule.

config
c_hid
latent_dim
conv_end_w
conv_end_h
band_reduction
conv2d_multi_layer(input_size, num_applications, **kwargs) → int
conv2d_output_size(input_size, kernel_size, padding=0, stride=1, dilation=1) → int
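
Neither method carries a docstring here, so the following is a sketch under stated assumptions rather than the library's code. It assumes `conv2d_output_size` follows the output-shape formula documented for `torch.nn.Conv2d`, and that `conv2d_multi_layer` applies that formula `num_applications` times with the same keyword arguments. The `_sketch` function names are hypothetical:

```python
def conv2d_output_size_sketch(input_size, kernel_size, padding=0, stride=1, dilation=1):
    """Size of one spatial dimension after a single Conv2d layer.

    Uses floor((n + 2p - d*(k - 1) - 1) / s) + 1, the formula given in the
    torch.nn.Conv2d documentation.
    """
    numerator = input_size + 2 * padding - dilation * (kernel_size - 1) - 1
    return numerator // stride + 1


def conv2d_multi_layer_sketch(input_size, num_applications, **kwargs):
    """Assumed behavior: chain the single-layer formula num_applications times."""
    size = input_size
    for _ in range(num_applications):
        size = conv2d_output_size_sketch(size, **kwargs)
    return size


# A 64-pixel edge through three stride-2, kernel-3, padding-1 convolutions:
# 64 -> 32 -> 16 -> 8
final_edge = conv2d_multi_layer_sketch(64, 3, kernel_size=3, padding=1, stride=2)
```

A calculation of this kind is one way to determine the spatial extent at the end of the encoder, which the `conv_end_w` and `conv_end_h` attribute names suggest.
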
_init_encoder()
_eval_encoder(x)
_init_decoder()
_eval_decoder(x)
forward(batch)
train_batch(batch)

This function contains the logic for a single training step, i.e. the contents of the inner loop of an ML training process.

Parameters:

batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.

Returns:

Current loss value – Dictionary containing the loss value for the current batch.

Return type:

dict
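
The shape of such a training step can be sketched as follows. This is an illustration only: `TinyAutoencoder`, its layers, the fixed `MSELoss` and `SGD` choices, and the `"loss"` dictionary key are all assumptions made for the example, whereas HyraxAutoencoderV2 takes its criterion and optimizer from the config.

```python
import torch
from torch import nn


class TinyAutoencoder(nn.Module):
    """Hypothetical stand-in model; not the HyraxAutoencoderV2 architecture."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 4)
        self.decoder = nn.Linear(4, 16)
        # Assumed choices; the real model reads these from its config.
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)

    def forward(self, x):
        return self.decoder(self.encoder(x))

    def train_batch(self, batch):
        # Accept (data, labels) tuples and ignore the labels.
        data = batch[0] if isinstance(batch, tuple) else batch
        self.optimizer.zero_grad()
        reconstruction = self.forward(data)
        # Autoencoder objective: reconstruct the input itself.
        loss = self.criterion(reconstruction, data)
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}


torch.manual_seed(0)
model = TinyAutoencoder()
result = model.train_batch((torch.randn(8, 16), torch.zeros(8)))
```

Returning a plain dictionary of Python floats keeps the step's output easy to log without holding on to the computation graph.
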

validate_batch(batch)

This function contains the logic for a single validation step that processes a single batch of data, i.e. the contents of the inner loop of an ML validation process.

Parameters:

batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.

Returns:

Current loss value – Dictionary containing the loss value for the current batch.

Return type:

dict

test_batch(batch)

This function contains the logic for a single testing step that processes a single batch of data, i.e. the contents of the inner loop of an ML testing process. In this case, it is identical to validate_batch.

Parameters:

batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.

Returns:

Current loss value – Dictionary containing the loss value for the current batch.

Return type:

dict

infer_batch(batch)

This function contains the logic for a single inference step, i.e. the contents of the inner loop of an ML inference process.

Parameters:

batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.

Returns:

Reconstructed outputs – The reconstructed outputs from the autoencoder.

Return type:

torch.Tensor

static prepare_inputs(data_dict) → tuple

Extract the image array from the batch dictionary.

This static method is the interface between the data pipeline and the model. Override it on the model class to reshape or select fields from the collated batch to match the inputs your model expects.

Hyrax will convert the returned array to a PyTorch tensor and move it to the appropriate device automatically.

Parameters:

data_dict (dict) – The collated batch dictionary produced by the data pipeline. Expected to contain a "data" key with an "image" field.

Returns:

image – The image array extracted from the batch.

Return type:

numpy.ndarray
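
A short sketch of the extraction described above, together with one possible override. It assumes the image sits at `data_dict["data"]["image"]` and uses a (batch, channels, height, width) layout for the example. `prepare_inputs_sketch` and `FirstBandOnly` are hypothetical names, not part of Hyrax:

```python
import numpy as np


def prepare_inputs_sketch(data_dict):
    """Return the image array stored under data_dict["data"]["image"]."""
    return np.asarray(data_dict["data"]["image"])


class FirstBandOnly:
    """Hypothetical model class that overrides the hook to keep one band."""

    @staticmethod
    def prepare_inputs(data_dict):
        image = np.asarray(data_dict["data"]["image"])
        # Keep only the first channel while preserving the channel axis.
        return image[:, :1, :, :]


# A fake collated batch: 4 images, 3 bands, 32 x 32 pixels.
batch = {"data": {"image": np.zeros((4, 3, 32, 32), dtype=np.float32)}}
default_image = prepare_inputs_sketch(batch)
first_band = FirstBandOnly.prepare_inputs(batch)
```

Because the hook is a static method, it can be exercised on a sample batch without constructing the model.
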