Customizing prepare_inputs

Every Hyrax model has a prepare_inputs static method that acts as the bridge between the raw data coming out of a dataset and the NumPy arrays the model actually consumes. By swapping this one function, you can change how input data is normalized, reshaped, or combined without touching the model or dataset code.

In this notebook we will:

  1. Load a model class with h.model().

  2. Inspect the default prepare_inputs behavior.

  3. Write and attach custom prepare_inputs functions that normalize, reshape, or select different fields from the data dictionary.

Setup

We start by creating a Hyrax instance and choosing a model.

[1]:
from hyrax import Hyrax

h = Hyrax()
h.set_config("model.name", "HyraxCNN")

h.model() returns the model class (not an instance). This lets us inspect and modify class-level attributes — including prepare_inputs — before any training or inference begins.

[2]:
ModelClass = h.model()
print(ModelClass)
<class 'hyrax.models.hyrax_cnn.HyraxCNN'>
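
Why it matters that we get the class rather than an instance can be sketched without Hyrax. The example below uses a hypothetical stand-in class, ToyModel (not part of Hyrax), to show that replacing a static method on the class changes what both the class and its instances see.

```python
class ToyModel:
    """Hypothetical stand-in for a model class; not part of Hyrax."""

    @staticmethod
    def prepare_inputs(data_dict):
        return data_dict["data"]["image"]


def prepare_inputs_doubled(data_dict):
    """Replacement that doubles every image value."""
    return [2 * x for x in data_dict["data"]["image"]]


batch = {"data": {"image": [1, 2, 3]}}

before = ToyModel.prepare_inputs(batch)

# Wrap in staticmethod so instances do not pass `self` as the first argument.
ToyModel.prepare_inputs = staticmethod(prepare_inputs_doubled)

after_class = ToyModel.prepare_inputs(batch)
after_instance = ToyModel().prepare_inputs(batch)

print(before, after_class, after_instance)
```

The same pattern is what the rest of this notebook applies to the real model class.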

Understanding the data dictionary

During training, each batch is delivered to prepare_inputs as a nested dictionary. For example, the HyraxRandomDataset produces batches that look like this after collation:

{
    "data": {
        "image": np.ndarray,   # shape (batch, channels, height, width)
        "label": np.ndarray,   # shape (batch,)
    }
}

The job of prepare_inputs is to extract and transform whatever the model needs from this dictionary and return it as NumPy arrays. Hyrax takes care of converting the result to PyTorch tensors and moving them to the correct device — you never need to call torch.from_numpy yourself.
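
When writing a custom prepare_inputs, it helps to know exactly which fields a batch contains. The following is a minimal sketch of a helper that walks a nested batch dictionary and reports the shape and dtype of each array. The name describe_batch is chosen here for illustration and is not part of Hyrax, and the batch is synthetic.

```python
import numpy as np


def describe_batch(batch, prefix=""):
    """Return 'key.path: shape, dtype' lines for every array in a nested dict."""
    lines = []
    for key, value in batch.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            # Recurse into nested dictionaries such as batch["data"].
            lines.extend(describe_batch(value, path))
        else:
            arr = np.asarray(value)
            lines.append(f"{path}: shape={arr.shape}, dtype={arr.dtype}")
    return lines


# Synthetic batch with the layout shown above (values are placeholders).
example_batch = {
    "data": {
        "image": np.zeros((4, 3, 29, 29), dtype=np.float32),
        "label": np.arange(4, dtype=np.int64),
    }
}

for line in describe_batch(example_batch):
    print(line)
```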

Inspecting the default implementation

Let’s look at the default prepare_inputs on HyraxCNN. It extracts the image and label arrays and returns them as a tuple.

[3]:
import inspect

print(inspect.getsource(ModelClass.prepare_inputs))
    @staticmethod
    def prepare_inputs(data_dict) -> tuple:
        """Extract image and label arrays from the batch dictionary.

        This static method is the interface between the data pipeline and the
        model. Override it on the model class to reshape or select fields from
        the collated batch to match the inputs your model expects.

        Hyrax will convert the returned arrays to PyTorch tensors and move them
        to the appropriate device automatically.

        Parameters
        ----------
        data_dict : dict
            The collated batch dictionary produced by the data pipeline.
            Expected to contain a ``"data"`` key with ``"image"`` and optionally
            ``"label"`` fields.

        Returns
        -------
        inputs : tuple of numpy.ndarray
            A tuple of ``(image, label)`` as float32 and int64 arrays respectively.
        """

        import numpy as np

        if "data" not in data_dict:
            raise RuntimeError("Unable to find `data` key in data_dict")

        data = data_dict["data"]
        image = np.asarray(data["image"], dtype=np.float32)
        label = np.asarray(data.get("label", []), dtype=np.int64)

        return (image, label)

We can call it directly with a sample dictionary to see what it returns.

[4]:
import numpy as np

sample_batch = {
    "data": {
        "image": np.random.rand(2, 3, 29, 29).astype(np.float32),
        "label": np.array([0, 1], dtype=np.int64),
    }
}

result = ModelClass.prepare_inputs(sample_batch)
print(f"Type: {type(result)}")
print(f"Image shape: {result[0].shape}, dtype: {result[0].dtype}")
print(f"Label shape: {result[1].shape}, dtype: {result[1].dtype}")
Type: <class 'tuple'>
Image shape: (2, 3, 29, 29), dtype: float32
Label shape: (2,), dtype: int64
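
Two edge cases follow from the source printed above: a batch without a "label" field yields an empty int64 array, and a batch without a "data" key raises a RuntimeError. The sketch below copies the default logic into a standalone function so that it runs without Hyrax installed. The name default_prepare_inputs is used here only for illustration.

```python
import numpy as np


def default_prepare_inputs(data_dict):
    """Standalone copy of the default logic shown above, for illustration."""
    if "data" not in data_dict:
        raise RuntimeError("Unable to find `data` key in data_dict")
    data = data_dict["data"]
    image = np.asarray(data["image"], dtype=np.float32)
    label = np.asarray(data.get("label", []), dtype=np.int64)
    return (image, label)


# Case 1: no "label" field, as with unlabeled data.
unlabeled = {"data": {"image": np.zeros((2, 3, 29, 29))}}
image, label = default_prepare_inputs(unlabeled)
print(image.dtype, label.shape, label.dtype)

# Case 2: the top-level "data" key is missing.
error_message = None
try:
    default_prepare_inputs({"image": np.zeros((2, 3, 29, 29))})
except RuntimeError as err:
    error_message = str(err)
    print(f"RuntimeError: {error_message}")
```

A custom prepare_inputs that indexes data_dict["data"] directly, as the examples below do, would raise a KeyError instead of a RuntimeError in the second case.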

Example 1: Normalizing pixel values

Suppose the images from your survey are stored as unsigned 16-bit integers in the range 0–65 535. Many models train better when inputs are normalized to a smaller range. The function below converts the image to float32 and scales it to [0, 1].

[5]:
@staticmethod
def prepare_inputs_normalized(data_dict):
    """Normalize uint16 images to [0, 1] float32."""
    import numpy as np

    data = data_dict["data"]
    image = np.asarray(data["image"], dtype=np.float32) / 65535.0
    label = np.asarray(data.get("label", []), dtype=np.int64)
    return (image, label)


# Assign the new function to the model class.
ModelClass.prepare_inputs = prepare_inputs_normalized

Let’s verify it works with a synthetic uint16 batch.

[6]:
uint16_batch = {
    "data": {
        # The upper bound of randint is exclusive, so 65536 allows the value 65535.
        "image": np.random.randint(0, 65536, size=(2, 3, 29, 29), dtype=np.uint16),
        "label": np.array([0, 1]),
    }
}

image, label = ModelClass.prepare_inputs(uint16_batch)
print(f"Image dtype: {image.dtype}, min: {image.min():.4f}, max: {image.max():.4f}")
print(f"Label: {label}")
Image dtype: float32, min: 0.0004, max: 0.9997
Label: [0 1]
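
The divisor 65535 is specific to uint16 data. As a variation, the sketch below derives the divisor from the array's dtype with np.iinfo, so the same function handles uint8 and uint16 inputs. The function name and the choice to leave floating-point images unscaled are assumptions made for this illustration.

```python
import numpy as np


def prepare_inputs_dtype_normalized(data_dict):
    """Scale unsigned-integer images to [0, 1] using the maximum of their dtype."""
    import numpy as np

    data = data_dict["data"]
    raw = np.asarray(data["image"])
    if np.issubdtype(raw.dtype, np.unsignedinteger):
        scale = float(np.iinfo(raw.dtype).max)
    else:
        # Assumption: any other dtype is already on a suitable scale.
        scale = 1.0
    image = raw.astype(np.float32) / scale
    label = np.asarray(data.get("label", []), dtype=np.int64)
    return (image, label)


batch_u8 = {"data": {"image": np.array([[0, 255]], dtype=np.uint8)}}
batch_u16 = {"data": {"image": np.array([[0, 65535]], dtype=np.uint16)}}

out_u8 = prepare_inputs_dtype_normalized(batch_u8)[0]
out_u16 = prepare_inputs_dtype_normalized(batch_u16)[0]
print(out_u8.dtype, out_u8.min(), out_u8.max())
print(out_u16.dtype, out_u16.min(), out_u16.max())
```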

Example 2: Log-scaling astronomical fluxes

Astronomical images often span many orders of magnitude. A common preprocessing step is to apply a log transform so the model does not have to learn across such a wide dynamic range.

[7]:
@staticmethod
def prepare_inputs_logscale(data_dict):
    """Apply log1p scaling to images."""
    import numpy as np

    data = data_dict["data"]
    image = np.asarray(data["image"], dtype=np.float32)

    # Clip negative values to zero first: log(1 + x) is undefined for x <= -1.
    # np.log1p computes log(1 + x) accurately for values near zero.
    image = np.log1p(np.clip(image, 0, None))

    label = np.asarray(data.get("label", []), dtype=np.int64)
    return (image, label)


ModelClass.prepare_inputs = prepare_inputs_logscale

[8]:
flux_batch = {
    "data": {
        "image": np.random.exponential(scale=1000, size=(2, 1, 64, 64)).astype(np.float32),
        "label": np.array([0, 1]),
    }
}

image, label = ModelClass.prepare_inputs(flux_batch)
print(f"Image dtype: {image.dtype}, min: {image.min():.2f}, max: {image.max():.2f}")
Image dtype: float32, min: 0.03, max: 9.21
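
Clipping at zero discards negative pixel values, which do occur in background-subtracted images. One alternative is an inverse hyperbolic sine (arcsinh) stretch, which is roughly linear near zero, logarithmic for large values, and defined for negative inputs. This is a different technique from the log1p example above, shown here only as a sketch. The softening parameter and its default value are assumptions made for illustration.

```python
import numpy as np


def prepare_inputs_asinh(data_dict, softening=1.0):
    """Apply an arcsinh stretch; `softening` is a hypothetical tuning knob."""
    import numpy as np

    data = data_dict["data"]
    image = np.asarray(data["image"], dtype=np.float32)

    # arcsinh(x / a) is close to x / a when |x| << a and close to
    # sign(x) * log(2 * |x| / a) when |x| >> a, so sign is preserved.
    image = np.arcsinh(image / softening)

    label = np.asarray(data.get("label", []), dtype=np.int64)
    return (image, label)


pixels = np.array([[-10.0, 0.0, 10.0, 1e4]], dtype=np.float32)
stretched, _ = prepare_inputs_asinh({"data": {"image": pixels}})
print(stretched)
```

Unlike the clipped log1p version, negative inputs map to negative outputs, so the transform stays invertible.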

Example 3: Selecting a single channel

If your dataset provides multi-channel images but you want to experiment with a model that accepts only one channel, you can slice the array inside prepare_inputs.

[9]:
@staticmethod
def prepare_inputs_single_channel(data_dict):
    """Keep only the first channel of a multi-channel image."""
    import numpy as np

    data = data_dict["data"]
    image = np.asarray(data["image"], dtype=np.float32)

    # Select the first channel: (batch, C, H, W) -> (batch, 1, H, W)
    image = image[:, 0:1, :, :]

    label = np.asarray(data.get("label", []), dtype=np.int64)
    return (image, label)


ModelClass.prepare_inputs = prepare_inputs_single_channel

[10]:
multi_ch_batch = {
    "data": {
        "image": np.random.rand(2, 5, 32, 32).astype(np.float32),
        "label": np.array([0, 1]),
    }
}

image, label = ModelClass.prepare_inputs(multi_ch_batch)
print(f"Original channels: 5, Output image shape: {image.shape}")
Original channels: 5, Output image shape: (2, 1, 32, 32)
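
The same idea extends to keeping several non-adjacent channels by indexing with a list. The sketch below is one way to do that. The channel indices are hypothetical, and they are defined inside the function so that it stays self-contained.

```python
import numpy as np


def prepare_inputs_channel_subset(data_dict):
    """Keep a fixed subset of channels from a (batch, C, H, W) image."""
    import numpy as np

    # Hypothetical band indices, kept inside the function body.
    channels = [0, 2, 4]

    data = data_dict["data"]
    image = np.asarray(data["image"], dtype=np.float32)

    # Index-list selection on axis 1: (batch, C, H, W) -> (batch, 3, H, W)
    image = image[:, channels, :, :]

    label = np.asarray(data.get("label", []), dtype=np.int64)
    return (image, label)


five_band_batch = {
    "data": {
        "image": np.random.rand(2, 5, 32, 32).astype(np.float32),
        "label": np.array([0, 1]),
    }
}

subset, _ = prepare_inputs_channel_subset(five_band_batch)
print(subset.shape)
```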

Example 4: Images only (autoencoders)

Autoencoders reconstruct their own input, so they don’t need labels. Your prepare_inputs can return a single array instead of a tuple.

[11]:
@staticmethod
def prepare_inputs_image_only(data_dict):
    """Return only the image, suitable for an autoencoder."""
    import numpy as np

    data = data_dict["data"]
    return np.asarray(data["image"], dtype=np.float32)


ModelClass.prepare_inputs = prepare_inputs_image_only

[12]:
result = ModelClass.prepare_inputs(sample_batch)
print(f"Type: {type(result)}, shape: {result.shape}")
Type: <class 'numpy.ndarray'>, shape: (2, 3, 29, 29)
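
Going the other way, prepare_inputs can return more than two arrays when a model needs extra inputs. The sketch below returns a third array alongside the image and label. The "redshift" field is hypothetical and stands in for whatever additional field your dataset provides, and the model that consumes the tuple would need to expect three values.

```python
import numpy as np


def prepare_inputs_with_extra_field(data_dict):
    """Return image, label, and one extra per-object field."""
    import numpy as np

    data = data_dict["data"]
    image = np.asarray(data["image"], dtype=np.float32)
    label = np.asarray(data.get("label", []), dtype=np.int64)
    # "redshift" is a hypothetical field name used only for illustration.
    redshift = np.asarray(data["redshift"], dtype=np.float32)
    return (image, label, redshift)


extra_batch = {
    "data": {
        "image": np.random.rand(2, 3, 29, 29),
        "label": np.array([0, 1]),
        "redshift": [0.1, 0.5],
    }
}

outputs = prepare_inputs_with_extra_field(extra_batch)
print(len(outputs), [arr.shape for arr in outputs])
```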

Important notes

  • Return NumPy arrays, not PyTorch tensors. Hyrax handles the conversion to tensors and moves them to the appropriate device (CPU / GPU) automatically.

  • Import inside the function. When a model is saved, prepare_inputs is written to a standalone Python file so it can be restored later. Any imports the function needs should therefore be placed inside the function body.

  • Return a tuple of any size. In these examples, prepare_inputs returned at most two arrays in a tuple. However, it can return any number of values, depending on what data your model requires.

  • Decorate with @staticmethod. When defining your custom prepare_inputs in a notebook or module, use the @staticmethod decorator so that the function does not receive self as a first argument.

  • Saving and loading. When you call model.save(), any custom prepare_inputs you assigned is persisted alongside the model weights. Calling model.load() restores it, so your preprocessing is reproducible across sessions.
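
The reason for importing inside the function can be illustrated without Hyrax. The sketch below executes function source in an empty namespace, which loosely mimics loading it from a standalone file. This is an illustration of the principle only and is not how Hyrax itself restores the function. The helper load_from_source is a name invented here.

```python
SELF_CONTAINED_SRC = '''
def prepare_inputs(data_dict):
    import numpy as np
    return np.asarray(data_dict["data"]["image"], dtype=np.float32)
'''

RELIES_ON_GLOBAL_SRC = '''
def prepare_inputs(data_dict):
    return np.asarray(data_dict["data"]["image"], dtype=np.float32)
'''


def load_from_source(source):
    """Execute source in a fresh namespace, mimicking a standalone file."""
    namespace = {}
    exec(source, namespace)
    return namespace["prepare_inputs"]


batch = {"data": {"image": [[1, 2], [3, 4]]}}

# The version with its own import works in isolation.
restored = load_from_source(SELF_CONTAINED_SRC)(batch)
print(restored.shape, restored.dtype)

# The version that relies on a notebook-level `np` does not.
failed_with_name_error = False
try:
    load_from_source(RELIES_ON_GLOBAL_SRC)(batch)
except NameError as err:
    failed_with_name_error = True
    print(f"NameError: {err}")
```

The same reasoning applies to constants and helper functions: anything the function needs should be defined in its body.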