Source code for hyrax.verbs.reduction_algorithms.umap

import logging
import pickle
from pathlib import Path
from typing import Union

import numpy as np
import umap

from .algorithm_registry import ReductionAlgorithm

# Module-level logger, named after this module so its output can be filtered
# and configured independently of other hyrax modules.
logger = logging.getLogger(name=__name__)
class UMAP(ReductionAlgorithm):
    """UMAP reduction implementation.

    Wraps a ``umap.UMAP`` estimator built from
    ``config["reduce"]["umap"]["kwargs"]`` and provides fitting, batched
    transformation, and saving/loading of the fitted estimator.
    """

    def __init__(self, config: dict, reduction_results=None):
        super().__init__(config, reduction_results)
        # The estimator starts out unfitted; it must be fitted with ``fit`` or
        # replaced by a pre-fitted model via ``load_model`` before ``transform``.
        self.reducer = umap.UMAP(**self.config["reduce"]["umap"]["kwargs"])

    @staticmethod
    def _is_fitted(reducer) -> bool:
        """Return True if ``reducer`` appears to be a fitted UMAP model.

        A fitted ``umap.UMAP`` keeps its training data in ``_raw_data`` (the
        same attribute ``_validate_umap_model`` uses for the input dimension);
        an unfitted estimator does not have it.
        """
        return hasattr(reducer, "_raw_data")

    def save_model(self, results_dir: Union[Path, str]):
        """
        Save the fitted UMAP model to a pickle file.

        Parameters
        ----------
        results_dir : Path or str
            The directory where the model should be saved. The model will be
            saved as 'umap.pickle' in this directory.
        """
        # Normalize to Path so a plain string directory also works with ``/``.
        model_file = Path(results_dir) / "umap.pickle"
        with open(model_file, "wb") as f:
            pickle.dump(self.reducer, f)

    def load_model(self, expected_input_dim: int, model_path: Union[Path, str, None] = None):
        """
        Load a pre-existing UMAP model from disk.

        Parameters
        ----------
        expected_input_dim : int
            The expected number of input features for the loaded model.
        model_path : Path or str, optional
            The path to the file to load the model from. If not specified,
            ``config["reduce"]["umap"]["model_path"]`` is used. A falsy path
            means no model is loaded and the current (unfitted) reducer is kept.

        Returns
        -------
        None

        Raises
        ------
        FileNotFoundError
            If ``model_path`` does not point to an existing file.
        ValueError
            If the loaded object fails ``_validate_umap_model``.
        """
        if model_path is None:
            model_path = self.config["reduce"]["umap"]["model_path"]

        if not model_path:
            logger.info("No pre-existing UMAP model found. A new model will be fitted.")
            return None

        # Path validity check
        model_path = Path(model_path)
        if not model_path.is_file():
            raise FileNotFoundError(f"UMAP model file not found: {model_path}")

        logger.info(f"Loading pre-existing UMAP model from {model_path}")
        # NOTE(security): unpickling can execute arbitrary code. Only load
        # model files from trusted locations.
        reducer = self._load_pickle(model_path)
        self._validate_umap_model(reducer, expected_input_dim)
        self.reducer = reducer

    def _validate_umap_model(self, reducer, expected_input_dim: int) -> None:
        """
        Validate the loaded UMAP model.

        Checks that the loaded object is a fitted UMAP instance and that its
        input and output dimensions match the expected values.

        Parameters
        ----------
        reducer : object
            The loaded model object to validate.
        expected_input_dim : int
            The expected number of input features for the loaded model.

        Raises
        ------
        ValueError
            If the loaded model is not a UMAP instance, has not been fitted,
            or if its input/output dimensions are incompatible.
        """
        # UMAP type check
        if not isinstance(reducer, umap.UMAP):
            raise ValueError(f"The loaded model is not a UMAP instance: {type(reducer)}")

        # An unfitted model has no training data to read the input dimension
        # from, and could not transform data anyway.
        if not self._is_fitted(reducer):
            raise ValueError("The loaded UMAP model has not been fitted.")

        # Input feature dim check
        input_dim = reducer._raw_data.shape[1]
        if input_dim != expected_input_dim:
            raise ValueError(
                f"The input dimension of the loaded UMAP model ({input_dim})"
                f" does not match the dimension of the inference data ({expected_input_dim})."
            )

        # Output dim check
        if reducer.n_components != self.reducer.n_components:
            raise ValueError(
                f"The output dimension of the loaded UMAP model ({reducer.n_components})"
                f" does not match the configured n_components ({self.reducer.n_components})."
            )

    def fit(self, data_sample: np.ndarray):
        """
        Fit the UMAP model to a sample of inference data.

        The fitted model is stored in the instance variable `self.reducer`
        and can be used for transforming data.

        Parameters
        ----------
        data_sample : numpy.ndarray
            The data sample used to fit the model.
        """
        self._log_memory_usage("Before fitting umap")
        logger.info("Fitting the UMAP")
        self.reducer.fit(data_sample)
        self._log_memory_usage("After fitting umap")

    def transform(self, args, num_batches: int):
        """
        Transform data with a fitted UMAP model.

        Uses a process pool if ``config["reduce"]["parallel"]`` is set,
        otherwise transforms batches sequentially. Each transformed batch is
        written to ``self.reduction_results``, which is committed at the end.

        Parameters
        ----------
        args : iterable of (batch_ids, batch) pairs
            Iterable yielding ``(batch_ids, batch)`` tuples, where ``batch`` is
            the data passed to the UMAP ``transform`` and ``batch_ids`` are the
            identifiers written alongside the result.
        num_batches : int
            The total number of batches in ``args`` (used for progress display).

        Raises
        ------
        RuntimeError
            If no fitted UMAP model is available.
        """
        # The reducer created in __init__ is always a UMAP instance, so the
        # fitted check is what actually guards against transforming too early.
        if not isinstance(self.reducer, umap.UMAP) or not self._is_fitted(self.reducer):
            raise RuntimeError("Cannot transform data before loading or fitting a UMAP model.")

        from tqdm.auto import tqdm

        if self.config["reduce"]["parallel"]:
            import multiprocessing as mp

            # Process pool loop
            # Use 'spawn' context to safely create subprocesses after
            # OpenMP threads are being opened by other processes in hyrax
            # Not using spawn causes the issue linked below
            # https://github.com/lincc-frameworks/hyrax/issues/291
            # TODO: Find more elegant solution than just using spawn
            with mp.get_context("spawn").Pool(processes=mp.cpu_count()) as pool:
                for batch_ids, transformed_batch in tqdm(
                    pool.imap(self._transform_batch, args),
                    desc="Creating lower dimensional representation using UMAP:",
                    total=num_batches,
                ):
                    self.reduction_results.write_batch(batch_ids, transformed_batch)
        else:
            # Sequential loop
            for batch_ids, batch in tqdm(
                args,
                desc="Creating lower dimensional representation using UMAP:",
                total=num_batches,
            ):
                transformed_batch = self.reducer.transform(batch)
                self._log_memory_usage(f"During transformation of batch of shape {batch.shape}")
                self.reduction_results.write_batch(batch_ids, transformed_batch)

        self.reduction_results.commit()  # Ensure all data is written and finalized