Using UMAP to Reduce Inference Output

UMAP (Uniform Manifold Approximation and Projection) is a dimensionality-reduction algorithm that projects high-dimensional data into a low-dimensional space, typically 2 or 3 dimensions, while preserving local neighborhood structure. This makes it well suited to visualizing the latent space learned by an unsupervised model.
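To make "preserving local neighborhood structure" concrete, here is a minimal NumPy sketch of one common way to measure it: for each point, compare its k nearest neighbors in the original space with its k nearest neighbors in the embedding. The helpers `knn_indices` and `knn_overlap` are hypothetical names for this illustration. They are not part of Hyrax or umap-learn, and this is an evaluation metric, not the UMAP algorithm itself. It builds a full pairwise distance array, so it is only intended for small inputs like the 100-sample dataset used here.

```python
import numpy as np


def knn_indices(points: np.ndarray, k: int) -> np.ndarray:
    """Return the indices of each point's k nearest neighbors, excluding itself."""
    # Pairwise squared Euclidean distances, shape (n, n).
    diffs = points[:, None, :] - points[None, :, :]
    dists = np.einsum("ijk,ijk->ij", diffs, diffs)
    # A point is never its own neighbor.
    np.fill_diagonal(dists, np.inf)
    # Sort each row by distance and keep the k closest indices.
    return np.argsort(dists, axis=1)[:, :k]


def knn_overlap(high: np.ndarray, low: np.ndarray, k: int = 15) -> float:
    """Average fraction of each point's k high-dimensional neighbors that are
    still among its k neighbors in the low-dimensional embedding (1.0 = all kept)."""
    high_nn = knn_indices(high, k)
    low_nn = knn_indices(low, k)
    shared = [len(set(h) & set(l)) for h, l in zip(high_nn, low_nn)]
    return float(np.mean(shared) / k)


# Example: how much local structure survives a naive "keep the first two columns" projection.
rng = np.random.default_rng(0)
latent = rng.normal(size=(100, 64))
print(knn_overlap(latent, latent[:, :2], k=10))
```

The neighborhood size k plays a role similar to UMAP's n_neighbors parameter: a small k checks only very local structure, while a larger k also rewards preserving broader relationships.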

In this notebook we will:

  1. Train a simple autoencoder on random data.

  2. Run inference to obtain latent-space representations.

  3. Apply UMAP to reduce those representations to 2D.

  4. Load and plot the UMAP output.

  5. Show how to switch to 3D output.

1. Setup

Create a Hyrax instance and configure it to use the built-in HyraxAutoencoder model with the HyraxRandomDataset. The random dataset lets us run the full pipeline quickly without downloading real data.

[1]:
from hyrax import Hyrax

h = Hyrax()
[2]:
# Use an autoencoder (unsupervised) so inference produces a latent vector.
h.config["model"]["name"] = "HyraxAutoencoder"

# Point at the random dataset — no download required.
data_request = {
    "train": {
        "data": {
            "dataset_class": "HyraxRandomDataset",
            "data_location": ".",
            "fields": ["image"],
            "primary_id_field": "object_id",
            "split_fraction": 1.0,
        },
    },
    "infer": {
        "data": {
            "dataset_class": "HyraxRandomDataset",
            "data_location": ".",
            "fields": ["image"],
            "primary_id_field": "object_id",
        },
    },
}
h.set_config("data_request", data_request)

# Train for just 1 epoch to keep things fast.
h.config["train"]["epochs"] = 1
[2026-03-27 12:57:12,976 hyrax.config_utils:WARNING] Runtime config contains key or section 'data_request' which has no default defined. All configuration keys and sections must be defined in /Users/drew/code/hyrax/src/hyrax/hyrax_default_config.toml
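The train and infer entries in the data_request above differ only in split_fraction. If you want to avoid repeating the shared keys, a small helper can build each entry. `random_dataset_request` is a hypothetical name for this sketch, not a Hyrax function; it simply produces the same dictionary that the cell above passes to h.set_config.

```python
def random_dataset_request(**overrides) -> dict:
    """Build one data_request entry pointing at HyraxRandomDataset.

    Keyword arguments (e.g. split_fraction=1.0) are added to, or replace,
    the shared keys used by every split in this notebook.
    """
    data = {
        "dataset_class": "HyraxRandomDataset",
        "data_location": ".",
        "fields": ["image"],
        "primary_id_field": "object_id",
    }
    data.update(overrides)
    return {"data": data}


# Same structure as the hand-written data_request in the setup cell.
data_request = {
    "train": random_dataset_request(split_fraction=1.0),
    "infer": random_dataset_request(),
}
```

You would then call h.set_config("data_request", data_request) exactly as in the setup cell. Each call builds a fresh dictionary, so editing one split's entry later does not change the other.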

2. Train and Infer

Train the autoencoder and then run inference to produce latent-space representations for every item in the dataset.

[ ]:
model = h.train()
inference_results = h.infer()

3. Run UMAP (2D)

By default, Hyrax reduces the data to 2 components (umap.UMAP.n_components = 2). Calling h.umap() fits a UMAP model on a sample of the inference output (the sample size is set by umap.fit_sample_size) and then transforms the entire dataset with the fitted model.
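The fit-on-a-sample, transform-everything flow described above can be sketched in NumPy as follows. This is not Hyrax's implementation: drawing the sample uniformly at random and transforming in fixed-size batches are assumptions made for illustration, and a small PCA reducer (`PCAReducer`, a hypothetical stand-in) replaces UMAP so the example runs with NumPy alone. umap-learn's `umap.UMAP` also provides `fit` and `transform` methods, so an instance of it could be passed as `reducer` instead.

```python
import numpy as np


class PCAReducer:
    """Stand-in reducer with a fit/transform interface (PCA via SVD, not UMAP)."""

    def __init__(self, n_components: int = 2):
        self.n_components = n_components

    def fit(self, x: np.ndarray) -> "PCAReducer":
        self.mean_ = x.mean(axis=0)
        # Rows of vt are principal directions, ordered by explained variance.
        _, _, vt = np.linalg.svd(x - self.mean_, full_matrices=False)
        self.components_ = vt[: self.n_components]
        return self

    def transform(self, x: np.ndarray) -> np.ndarray:
        # Project centered data onto the retained principal directions.
        return (x - self.mean_) @ self.components_.T


def fit_on_sample_then_transform(latents, reducer, fit_sample_size=1024, batch_size=256, seed=0):
    """Fit `reducer` on a random subset of rows, then transform every row in batches."""
    n = len(latents)
    rng = np.random.default_rng(seed)
    # Use at most fit_sample_size rows for fitting (all rows if there are fewer).
    sample_idx = rng.choice(n, size=min(fit_sample_size, n), replace=False)
    reducer.fit(latents[sample_idx])
    # Transform the full dataset one batch at a time and stitch the results together.
    batches = [
        reducer.transform(latents[start : start + batch_size])
        for start in range(0, n, batch_size)
    ]
    return np.concatenate(batches, axis=0)


rng = np.random.default_rng(0)
latents = rng.normal(size=(100, 64))  # stand-in for 100 latent vectors
embedding = fit_on_sample_then_transform(latents, PCAReducer(n_components=2))
print(embedding.shape)  # (100, 2)
```

Because the fit only ever sees at most fit_sample_size rows, its cost stays bounded as the dataset grows; only the transform step scales with the number of items.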

[5]:
umap_results = h.umap()
[2026-03-27 12:57:27,307 hyrax.verbs.umap:INFO] Saving UMAP results to /Users/drew/code/hyrax/docs/notebooks/results/20260327-125727-umap-f4yX
[2026-03-27 12:57:27,361 hyrax.verbs.umap:INFO] Fitting the UMAP
[2026-03-27 12:57:31,786 hyrax.verbs.umap:INFO] Saving fitted UMAP Reducer
[2026-03-27T19:57:33Z WARN  lance::dataset::write::insert] No existing dataset at /Users/drew/code/hyrax/docs/notebooks/results/20260327-125727-umap-f4yX/lance_db/results.lance, it will be created
[2026-03-27 12:57:33,594 hyrax.datasets.result_dataset:INFO] Optimizing Lance table after 1 batches
[2026-03-27 12:57:33,595 hyrax.datasets.result_dataset:INFO] Lance table optimization complete
[2026-03-27 12:57:33,595 hyrax.verbs.umap:INFO] Finished transforming all data through UMAP

h.umap() returns a ResultDataset that you can index directly. Each element is a NumPy array with shape (n_components,).

[6]:
import numpy as np

# Stack all UMAP embeddings into a single array.
embeddings = np.array([umap_results[i] for i in range(len(umap_results))])
print(f"Shape: {embeddings.shape}  (samples × components)")
Shape: (100, 2)  (samples × components)

4. Plot the 2D Embedding

A quick scatter plot of the two UMAP dimensions.

[7]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(embeddings[:, 0], embeddings[:, 1], s=5, alpha=0.7)
ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_title("2D UMAP of Latent Space")
plt.tight_layout()
plt.show()
[Figure: 2D UMAP of Latent Space, scatter plot of UMAP 1 vs. UMAP 2]

5. Switch to 3D

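To produce a 3-dimensional embedding instead, set umap.UMAP.n_components to 3 before calling h.umap() again. Each call writes its output to a new timestamped results directory, so the 2D results above are kept.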

[8]:
h.config["umap"]["UMAP"]["n_components"] = 3
umap_results_3d = h.umap()
[2026-03-27 12:57:33,927 hyrax.verbs.umap:INFO] Saving UMAP results to /Users/drew/code/hyrax/docs/notebooks/results/20260327-125733-umap-WJk5
[2026-03-27 12:57:33,931 hyrax.verbs.umap:INFO] Fitting the UMAP
[2026-03-27 12:57:34,018 hyrax.verbs.umap:INFO] Saving fitted UMAP Reducer
[2026-03-27T19:57:34Z WARN  lance::dataset::write::insert] No existing dataset at /Users/drew/code/hyrax/docs/notebooks/results/20260327-125733-umap-WJk5/lance_db/results.lance, it will be created
[2026-03-27 12:57:34,320 hyrax.datasets.result_dataset:INFO] Optimizing Lance table after 1 batches
[2026-03-27 12:57:34,321 hyrax.datasets.result_dataset:INFO] Lance table optimization complete
[2026-03-27 12:57:34,321 hyrax.verbs.umap:INFO] Finished transforming all data through UMAP
[9]:
embeddings_3d = np.array([umap_results_3d[i] for i in range(len(umap_results_3d))])
print(f"Shape: {embeddings_3d.shape}  (samples × components)")

fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111, projection="3d")
ax.scatter(embeddings_3d[:, 0], embeddings_3d[:, 1], embeddings_3d[:, 2], s=5, alpha=0.7)
ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_zlabel("UMAP 3")
ax.set_title("3D UMAP of Latent Space")
plt.tight_layout()
plt.show()
Shape: (100, 3)  (samples × components)
[Figure: 3D UMAP of Latent Space, scatter plot with axes UMAP 1, UMAP 2, UMAP 3]

Key Configuration Options

All UMAP settings live under the [umap] section of the Hyrax config. The most commonly adjusted parameters are:

Config key               Default   Description
umap.UMAP.n_components   2         Number of output dimensions (2 or 3).
umap.UMAP.n_neighbors    15        Size of the local neighborhood UMAP considers; smaller values emphasize local structure, larger values emphasize global structure.
umap.fit_sample_size     1024      Number of points used to fit the UMAP model.
umap.parallel            false     Use multiprocessing during the transform step.

See the umap-learn documentation for the full list of parameters you can pass under umap.UMAP.
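The dotted keys in the table map onto the nested item access used earlier in this notebook: umap.UMAP.n_components corresponds to h.config["umap"]["UMAP"]["n_components"]. If you prefer to keep several overrides in one place, a small helper can apply them from dotted keys. `set_dotted` is a hypothetical helper, not a Hyrax API. The sketch runs against a plain dict, and applying it to h.config assumes that object supports the same nested `[...]` access shown above (for example, that umap.fit_sample_size lives at h.config["umap"]["fit_sample_size"]).

```python
def set_dotted(config, dotted_key, value):
    """Assign value at a dotted path such as "umap.UMAP.n_neighbors"."""
    *parents, leaf = dotted_key.split(".")
    section = config
    for name in parents:
        # Descend one nested section per dotted component; a misspelled
        # section name raises KeyError instead of silently creating it.
        section = section[name]
    section[leaf] = value


# A plain dict standing in for the [umap] section, using the defaults above.
config = {
    "umap": {
        "UMAP": {"n_components": 2, "n_neighbors": 15},
        "fit_sample_size": 1024,
        "parallel": False,
    }
}

overrides = {
    "umap.UMAP.n_components": 3,
    "umap.UMAP.n_neighbors": 30,
    "umap.fit_sample_size": 2048,
    "umap.parallel": True,
}
for key, value in overrides.items():
    set_dotted(config, key, value)

print(config["umap"]["UMAP"])  # {'n_components': 3, 'n_neighbors': 30}
```

Only section names are checked: a misspelled final key would still be assigned, and Hyrax may then warn about a key with no default, as it did for data_request in the setup cell.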