hyrax.models.image_dcae
Classes
- ArcsinhActivation – Helper module for ImageDCAE to use the arcsinh function.
- ImageDCAE – An autoencoder with skip connections that should work with arbitrarily sized images with an arbitrary number of channels.
Module Contents
- class ArcsinhActivation(*args: Any, **kwargs: Any)
Bases: torch.nn.Module
Helper module for ImageDCAE to use the arcsinh function.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
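A minimal sketch of what such a helper could look like, assuming it does nothing more than apply `torch.asinh` element-wise in `forward`; the actual hyrax implementation may differ. arcsinh is approximately linear near zero and grows logarithmically for large |x|, which is why it is often used to compress high-dynamic-range data.

```python
import torch
import torch.nn as nn


class ArcsinhActivation(nn.Module):
    """Sketch: apply the inverse hyperbolic sine element-wise (assumed behavior)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # asinh is defined for every real input, so no clamping or masking is needed
        return torch.asinh(x)


act = ArcsinhActivation()
x = torch.tensor([-100.0, 0.0, 0.5, 100.0])
y = act(x)  # small values stay close to x, large values are compressed
```

Wrapping the function in an `nn.Module` lets it be dropped into an `nn.Sequential` like any built-in activation.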
- class ImageDCAE(config, data_sample=None)
Bases: torch.nn.Module
An autoencoder with skip connections that should work with arbitrarily sized images with an arbitrary number of channels.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _calculate_conv_output_size()
Calculate the output size after all convolutional layers for the linear bottleneck.
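One way to compute this size is to apply PyTorch's `Conv2d` output-shape formula layer by layer, as in the sketch below. The layer stack `(kernel_size, stride, padding) = (3, 2, 1)` repeated three times and the 128 output channels are hypothetical; ImageDCAE's real encoder configuration is not documented here, and another common approach is to pass a zero tensor through the encoder and read off its shape.

```python
def conv_output_size(size: int, kernel_size: int, stride: int = 1,
                     padding: int = 0, dilation: int = 1) -> int:
    # Output length along one spatial axis, per the torch.nn.Conv2d shape formula
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_size(height: int, width: int, channels_out: int, layers) -> int:
    # Walk every conv layer, then flatten channels x height x width for nn.Linear
    for kernel_size, stride, padding in layers:
        height = conv_output_size(height, kernel_size, stride, padding)
        width = conv_output_size(width, kernel_size, stride, padding)
    return channels_out * height * width


# Hypothetical encoder: three stride-2 convs, each roughly halving the image
layers = [(3, 2, 1)] * 3
size = flattened_size(64, 64, 128, layers)  # 64 -> 32 -> 16 -> 8, so 128 * 8 * 8
```

Because the formula uses floor division, odd sizes such as 75 shrink to 38, 19, then 10, which is how an encoder of this kind can accept arbitrarily sized images.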
- decode(latent, skip_connections, encoded_shape)
Decode from latent space to image with skip connections.
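The sketch below shows a U-Net-style decoding pass with the same argument names as `decode`: the latent vector is projected back to `encoded_shape`, and at each stage the upsampled features are concatenated along the channel axis with the matching encoder feature map. The class name `SkipDecoder`, the channel counts, the two-stage depth, and the deepest-first ordering of `skip_connections` are all assumptions for illustration, not ImageDCAE's actual layers.

```python
import torch
import torch.nn as nn


class SkipDecoder(nn.Module):
    """Hypothetical two-stage decoder; channel counts and depth are illustrative."""

    def __init__(self, latent_dim: int, encoded_shape: tuple):
        super().__init__()
        c, h, w = encoded_shape
        # Project the latent vector back to the size of the flattened encoder output
        self.fc = nn.Linear(latent_dim, c * h * w)
        # Each transposed conv doubles height and width
        self.up1 = nn.ConvTranspose2d(c, 16, kernel_size=2, stride=2)
        # Input channels = upsampled features (16) + skip features (16)
        self.conv1 = nn.Conv2d(16 + 16, 16, kernel_size=3, padding=1)
        self.up2 = nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(8 + 8, 1, kernel_size=3, padding=1)

    def decode(self, latent, skip_connections, encoded_shape):
        # Undo the flatten that fed the linear bottleneck
        x = self.fc(latent).view(-1, *encoded_shape)
        # skip_connections is ordered deepest-first in this sketch
        x = self.up1(x)
        x = torch.relu(self.conv1(torch.cat([x, skip_connections[0]], dim=1)))
        x = self.up2(x)
        return self.conv2(torch.cat([x, skip_connections[1]], dim=1))


encoded_shape = (32, 4, 4)
decoder = SkipDecoder(latent_dim=10, encoded_shape=encoded_shape)
skips = [torch.randn(2, 16, 8, 8), torch.randn(2, 8, 16, 16)]
out = decoder.decode(torch.randn(2, 10), skips, encoded_shape)  # shape (2, 1, 16, 16)
```

When an image size is not divisible by the total downsampling factor, an upsampled map and its skip tensor can differ by a pixel, so a decoder for arbitrary sizes has to crop or interpolate before concatenating; this sketch avoids the issue by using sizes that divide evenly.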
- train_batch(batch)
This function contains the logic for a single training step.
- Parameters:
batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.
- Returns:
Current loss value – Dictionary containing the loss value for the current batch.
- Return type:
dict
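A single autoencoder training step typically follows the pattern sketched below: drop any labels, reconstruct the input, compare the reconstruction with the input, backpropagate, and return the loss in a dictionary. This is written as a free function for self-containment; the explicit `model` and `optimizer` arguments, the MSE loss, and the `"loss"` dictionary key are assumptions, not confirmed details of ImageDCAE.

```python
import torch
import torch.nn as nn


def train_batch(model: nn.Module, optimizer: torch.optim.Optimizer, batch) -> dict:
    """Sketch of one autoencoder training step (MSE reconstruction loss assumed)."""
    # The batch may be a (data, labels) tuple; labels are ignored for reconstruction
    x = batch[0] if isinstance(batch, (tuple, list)) else batch
    model.train()
    optimizer.zero_grad()
    reconstruction = model(x)
    # The autoencoder's target is its own input
    loss = nn.functional.mse_loss(reconstruction, x)
    loss.backward()
    optimizer.step()
    return {"loss": loss.item()}


# Tiny stand-in model whose output shape matches its input shape
model = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(), nn.Conv2d(4, 1, 3, padding=1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.randn(8, 1, 16, 16)
result = train_batch(model, optimizer, (x, torch.zeros(8)))
```

Returning `loss.item()`, a plain Python float, rather than the tensor keeps the autograd graph from being held alive by whatever code logs the result.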
- validate_batch(batch)
This function contains the logic for a single validation step that will process a single batch of data.
- Parameters:
batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.
- Returns:
Current loss value – Dictionary containing the loss value for the current batch.
- Return type:
dict
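Validation differs from training mainly in that no weights are updated. A possible sketch, under the same assumptions as a typical training step (MSE reconstruction loss, a `"loss"` key, explicit `model` argument), is shown below; it is not ImageDCAE's confirmed implementation.

```python
import torch
import torch.nn as nn


def validate_batch(model: nn.Module, batch) -> dict:
    """Sketch of one validation step: compute the loss without updating weights."""
    x = batch[0] if isinstance(batch, (tuple, list)) else batch
    model.eval()  # e.g. disables dropout and uses running batch-norm statistics
    with torch.no_grad():  # no autograd graph is built, so nothing can be backpropagated
        loss = nn.functional.mse_loss(model(x), x)
    return {"loss": loss.item()}


model = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(), nn.Conv2d(4, 1, 3, padding=1))
before = [p.clone() for p in model.parameters()]
x = torch.randn(8, 1, 16, 16)
first = validate_batch(model, (x, torch.zeros(8)))  # labels are ignored
second = validate_batch(model, x)
```

Since `test_batch` is documented as identical to `validate_batch`, it could simply call this function.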
- test_batch(batch)
This function contains the logic for a single testing step that will process a single batch of data. In this case, it is identical to validate_batch.
- Parameters:
batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.
- Returns:
Current loss value – Dictionary containing the loss value for the current batch.
- Return type:
dict
- infer_batch(batch)
This function contains the logic for a single inference step.
- Parameters:
batch (tuple) – A tuple containing the input data for the current batch, possibly with labels that are ignored.
- Returns:
Reconstructed images – Tensor containing the reconstructed images for the current batch.
- Return type:
torch.Tensor
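Unlike the training and validation steps, inference returns the reconstructions themselves rather than a loss dictionary. A minimal sketch, assuming an explicit `model` argument and a model whose output has the same shape as its input, is shown below.

```python
import torch
import torch.nn as nn


def infer_batch(model: nn.Module, batch) -> torch.Tensor:
    """Sketch: return reconstructed images for one batch, with no gradient tracking."""
    x = batch[0] if isinstance(batch, (tuple, list)) else batch
    model.eval()
    with torch.no_grad():
        return model(x)


# Stand-in three-channel model; real ImageDCAE layers are not reproduced here
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))
x = torch.randn(4, 3, 20, 28)
recon = infer_batch(model, (x, torch.arange(4)))  # same shape as x
```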
- static prepare_inputs(data_dict)
Extract the image array from the batch dictionary.
This static method is the interface between the data pipeline and the model. Override it on the model class to reshape or select fields from the collated batch to match the inputs your model expects.
Hyrax will convert the returned array to a PyTorch tensor and move it to the appropriate device automatically.
- Parameters:
data_dict (dict) – The collated batch dictionary produced by the data pipeline. Expected to contain a "data" key with an "image" field.
- Returns:
image – The image array extracted from the batch.
- Return type:
numpy.ndarray
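The sketch below shows what an override of `prepare_inputs` might look like. It assumes `data_dict["data"]` is a nested dictionary whose `"image"` entry holds a batch of images, and it adds a channel axis to single-channel `(N, H, W)` batches. `SingleChannelModel` is a hypothetical plain class standing in for a model class; in a real project the static method would be defined on your hyrax model.

```python
import numpy as np


class SingleChannelModel:  # hypothetical stand-in for a model class
    @staticmethod
    def prepare_inputs(data_dict: dict) -> np.ndarray:
        # Select the "image" field from the "data" entry of the collated batch
        image = np.asarray(data_dict["data"]["image"], dtype=np.float32)
        # Hypothetical reshaping step: (N, H, W) -> (N, 1, H, W) for Conv2d layers
        if image.ndim == 3:
            image = image[:, np.newaxis, :, :]
        return image


batch = {"data": {"image": np.zeros((4, 32, 32))}}
inputs = SingleChannelModel.prepare_inputs(batch)  # shape (4, 1, 32, 32), float32
```

Keeping this a static NumPy-level function matches the documented contract: the method only selects and reshapes data, and hyrax handles the tensor conversion and device placement afterwards.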