hyrax.models.image_dcae
Classes
ArcsinhActivation | Helper module for ImageDCAE to use the arcsinh function
ImageDCAE | This is an autoencoder with skip connections that should work with arbitrarily sized images with an arbitrary number of channels.
Module Contents
- class ArcsinhActivation(*args: Any, **kwargs: Any)
Bases: torch.nn.Module
Helper module for ImageDCAE to use the arcsinh function.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
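The arcsinh function itself is arcsinh(x) = ln(x + sqrt(x^2 + 1)). The sketch below is not the ArcsinhActivation implementation; it only illustrates the shape of the function with the standard library, so it runs without PyTorch. The helper name `arcsinh` and the sample values are assumptions made for illustration. In PyTorch the corresponding elementwise operation is `torch.asinh`.

```python
import math


def arcsinh(x: float) -> float:
    """Inverse hyperbolic sine, equal to ln(x + sqrt(x^2 + 1))."""
    return math.asinh(x)


# Near zero the function is close to the identity.
small = arcsinh(0.01)

# For large |x| it grows like sign(x) * ln(2|x|), compressing large values.
large = arcsinh(1.0e4)
negative = arcsinh(-1.0e4)

print(small, large, negative)
```

Unlike a plain logarithm, arcsinh is defined for zero and negative inputs and is odd-symmetric, so it can be applied to signed values directly.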
- class ImageDCAE(config, data_sample=None)
Bases: torch.nn.Module
This is an autoencoder with skip connections that should work with arbitrarily sized images with an arbitrary number of channels.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _calculate_conv_output_size()
Calculate the output size after all convolutional layers for the linear bottleneck.
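As a rough illustration of this kind of calculation, the sketch below applies the standard output-size formula for a convolution (the one given in the `torch.nn.Conv2d` documentation) layer by layer, then multiplies out the flattened feature count a linear bottleneck would need. It is not the method's actual code: the helper names, the layer list, the 64x64 input, and the 128 output channels are all hypothetical.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of one convolution along a single dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_size(height, width, out_channels, layers):
    """Number of features left after a stack of conv layers is flattened.

    `layers` is a list of (kernel_size, stride, padding) tuples.
    """
    for kernel_size, stride, padding in layers:
        height = conv_output_size(height, kernel_size, stride, padding)
        width = conv_output_size(width, kernel_size, stride, padding)
    return out_channels * height * width


# Hypothetical encoder: three 3x3 convolutions with stride 2 and padding 1.
layers = [(3, 2, 1), (3, 2, 1), (3, 2, 1)]
n_features = flattened_size(64, 64, out_channels=128, layers=layers)
print(n_features)  # 64 -> 32 -> 16 -> 8 per side, so 128 * 8 * 8 = 8192
```

Because the result depends on the input height and width, a model that accepts arbitrarily sized images has to compute it from the actual input shape rather than hard-code it.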
- decode(latent, skip_connections, encoded_shape)
Decode from latent space to image with skip connections.
- train_batch(batch)
This function contains the logic for a single training step.
- Parameters:
batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.
- Returns:
Current loss value – Dictionary containing the loss value for the current batch.
- Return type:
dict
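The calling convention described here (a tuple whose labels, if present, are ignored, and a dictionary holding the loss) can be sketched in plain Python. This is only a toy under stated assumptions, not the ImageDCAE training step: the function names, the `"loss"` key, the mean-squared-error loss, and the use of flat lists in place of tensors are all hypothetical.

```python
def split_batch(batch):
    """Return only the input data, dropping labels if the batch carries any."""
    if isinstance(batch, tuple):
        return batch[0]
    return batch


def mean_squared_error(inputs, outputs):
    """Mean squared difference between two equal-length flat sequences."""
    return sum((a - b) ** 2 for a, b in zip(inputs, outputs)) / len(inputs)


def toy_train_step(batch, reconstruct):
    """One illustrative step: unpack the batch, reconstruct, score, report."""
    inputs = split_batch(batch)
    outputs = reconstruct(inputs)
    loss = mean_squared_error(inputs, outputs)
    # A real training step would also backpropagate and update the optimizer.
    return {"loss": loss}


# Hypothetical batch: flat pixel values plus labels that are ignored.
batch = ([0.0, 0.5, 1.0], ["galaxy", "star", "galaxy"])
result = toy_train_step(batch, reconstruct=lambda xs: [x * 0.5 for x in xs])
print(result)
```

For an autoencoder the reconstruction target is the input itself, which is why the labels can be dropped.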
- validate_batch(batch)
This function contains the logic for a single validation step that will process a single batch of data.
- Parameters:
batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.
- Returns:
Current loss value – Dictionary containing the loss value for the current batch.
- Return type:
dict
- test_batch(batch)
This function contains the logic for a single testing step that will process a single batch of data. In this case, it is identical to validate_batch.
- Parameters:
batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.
- Returns:
Current loss value – Dictionary containing the loss value for the current batch.
- Return type:
dict
- infer_batch(batch)
This function contains the logic for a single inference step.
- Parameters:
batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.
- Returns:
Reconstructed images – Tensor containing the reconstructed images for the current batch.
- Return type:
torch.Tensor