# Source code for hyrax.verbs.visualize

import logging
import random
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Union

import numpy.typing as npt
import torch
from matplotlib.colors import LogNorm

from .verb_registry import Verb, hyrax_verb

[docs] logger = logging.getLogger(__name__)
[docs] @hyrax_verb
[docs] class Visualize(Verb): """Verb to create a visualization"""
[docs] cli_name = "visualize"
[docs] add_parser_kwargs = {}
@staticmethod
[docs] def setup_parser(parser: ArgumentParser): """CLI not implemented for this verb""" pass
def run_cli(self, args: Namespace | None = None):
    """Log an error stating that the CLI entry point is not implemented.

    Parameters
    ----------
    args : Namespace or None, optional
        Parsed command line arguments. They are not used.
    """
    message = "Running visualize from the cli is unimplemented"
    logger.error(message)
def run(
    self,
    input_dir: Union[Path, str] | None = None,
    *,
    return_verb: bool = False,
    make_lupton_rgb_opts: dict | None = None,
    **kwargs,
):
    """Generate an interactive notebook visualization of a latent space that has been umapped down to 2d.

    The plot contains two holoviews objects, a scatter plot of the latent space, and a table of objects
    which can be populated by selecting from the scatter plot. When ``[visualize][display_images]`` is
    set in the config, a panel of sampled image thumbnails and a "Resample Images" button are added
    underneath the table.

    Parameters
    ----------
    input_dir : Optional[Union[Path, str]], optional
        Directory holding the output from the 'umap' verb, by default None. When not provided, we use
        [results][inference_dir] from config. If that is false, we use the most recent umap in the
        current results directory (NOTE(review): that fallback is resolved inside ``InferenceDataSet``,
        not in this method -- confirm there).
    return_verb : bool, optional
        If True, also return the underlying Visualize instance for post-hoc access to selection state.
        Defaults to False.
    make_lupton_rgb_opts : dict, optional
        Dictionary of options to pass to astropy's make_lupton_rgb function for RGB image creation.
        Default is {"stretch": 5, "Q": 8}. Common parameters include stretch (brightness/contrast)
        and Q (softening parameter for asinh transformation).
    kwargs :
        Keyword arguments are merged into the plot options and applied to the plot object via
        ``.opts(**self.plot_options)``. It is not recommended to override the "tools" plot option,
        because that will break the integration between the plot selection operations and the table.

    Returns
    -------
    pane, if return_verb = False (default)
        The displayed object: a holoviews layout of the plot and the table, or a panel ``Row`` when
        image display is enabled.
    tuple of (pane, Visualize), if return_verb = True
        Returns a 2-tuple with the pane and the verb instance.

    Raises
    ------
    RuntimeError
        If the object id column is not among the metadata fields of the umap dataset.
    ValueError
        If ``[visualize][torch_tensor_bands]`` does not list exactly 1 or 3 bands.
    """
    import numpy as np
    import panel as pn
    from holoviews import DynamicMap, extension
    from holoviews.operation.datashader import dynspread, rasterize
    from holoviews.streams import Lasso, Params, RangeXY, SelectionXY, Tap
    from scipy.spatial import KDTree

    from hyrax.data_sets.inference_dataset import InferenceDataSet

    # Column names fall back to fixed defaults when the config value is falsy.
    if self.config["data_set"]["object_id_column_name"]:
        self.object_id_column_name = self.config["data_set"]["object_id_column_name"]
    else:
        self.object_id_column_name = "object_id"

    # The object id always comes first; the configured extra fields follow it.
    fields = [self.object_id_column_name]
    fields += self.config["visualize"]["fields"]
    self.cmap = self.config["visualize"]["cmap"]

    if self.config["data_set"]["filename_column_name"]:
        self.filename_column_name = self.config["data_set"]["filename_column_name"]
    else:
        self.filename_column_name = "filename_data"

    # The filename column is only requested when thumbnails will be shown.
    if self.config["visualize"]["display_images"]:
        fields += [self.filename_column_name]

    # If no input directory is specified, read from config.
    if input_dir is None:
        logger.info("UMAP directory not specified at runtime. Reading from config values.")
        input_dir = self.config["results"]["inference_dir"]

    # Get the umap data and put it in a kdtree for indexing.
    self.umap_results = InferenceDataSet(self.config, results_dir=input_dir, verb="umap")
    logger.info(f"Rendering UMAP from the following directory: {self.umap_results.results_dir}")

    # Drop requested fields the dataset cannot provide. Iterate over a copy because
    # the list is mutated inside the loop.
    available_fields = self.umap_results.metadata_fields()
    for field in fields.copy():
        if field not in available_fields:
            logger.warning(f"Field {field} is unavailable for this dataset")
            fields.remove(field)

    # The object id column is mandatory: the table is keyed on it.
    if self.object_id_column_name not in fields:
        msg = "Umap dataset must support object_id field"
        raise RuntimeError(msg)

    # data_fields is every surviving field except the object id.
    self.data_fields = fields.copy()
    self.data_fields.remove(self.object_id_column_name)

    # Spatial index over the 2d points, used by the tap and box selection handlers.
    self.tree = KDTree(self.umap_results)

    # Store color column and extract color values if specified
    self.color_column = self.config["visualize"]["color_column"]
    self.color_values = None

    # Validate torch_tensor_bands configuration
    # NOTE(review): the original comment says this defaults to the i-band; the default
    # itself lives in the config, not in this method.
    self.torch_tensor_bands = self.config["visualize"]["torch_tensor_bands"]
    if len(self.torch_tensor_bands) not in [1, 3]:
        raise ValueError(
            f"torch_tensor_bands must specify either 1 band (single-band) or 3 bands (RGB). "
            f"Got {len(self.torch_tensor_bands)} bands: {self.torch_tensor_bands}"
        )

    # Store make_lupton_rgb options with defaults
    self.make_lupton_rgb_opts = make_lupton_rgb_opts or {"stretch": 5, "Q": 8}

    if self.color_column:
        # Any failure in here (including the numeric debug statistics below) is
        # downgraded to a warning and coloring is switched off.
        try:
            # Check if column exists
            available_fields = self.umap_results.metadata_fields()
            if self.color_column not in available_fields:
                logger.warning(
                    f"Column '{self.color_column}' not found in dataset."
                    f" Available fields: {available_fields}"
                )
                # False is used as the "no coloring" sentinel; later code only tests truthiness.
                self.color_column = False
            else:
                # Get all indices for the dataset
                all_indices = list(range(len(self.umap_results)))

                # Extract metadata for the specified column
                metadata = self.umap_results.metadata(all_indices, [self.color_column])
                self.color_values = metadata[self.color_column]

                logger.info(f"Successfully loaded color values from column '{self.color_column}'")
                import numpy as np

                logger.debug(
                    f"Color values range: {np.nanmin(self.color_values)} "
                    f"to {np.nanmax(self.color_values)}"
                )
                logger.debug(f"NaN count: {np.sum(np.isnan(self.color_values))}")
        except Exception as e:
            logger.warning(f"Could not load column '{self.color_column}': {e}")
            logger.warning("Proceeding without coloring")
            self.color_column = False
            self.color_values = None

    # Initialize holoviews with bokeh.
    extension("bokeh")

    # Set up the plot pane
    xmin, xmax, ymin, ymax = self._even_aspect_bounding_box()

    self.plot_options = {
        "tools": ["box_select", "lasso_select", "tap"],
        "width": 500,
        "height": 500,
        "xlim": (xmin, xmax),
        "ylim": (ymin, ymax),
        "cnorm": "eq_hist",
    }

    # Caller-supplied options win over the defaults above.
    self.plot_options.update(kwargs)

    if self.color_column:
        # For colored plots, show all points to preserve the colorbar.
        # This is a current hack to overcome the fact that the
        # RangeXY stream breaks the colorbar. Needs to be investigated
        # further for a permanent solution.
        plot_dm = DynamicMap(
            lambda: self.visible_points(
                x_range=[float("-inf"), float("inf")], y_range=[float("-inf"), float("inf")]
            )
        )
    else:
        # Without coloring, the plot is re-evaluated whenever the visible range changes.
        plot_dm = DynamicMap(self.visible_points, streams=[RangeXY()])

    if self.config["visualize"]["rasterize_plot"]:
        # Note that rasterization will break the color-bar feature
        plot_pane = dynspread(rasterize(plot_dm).opts(**self.plot_options))
    else:
        plot_pane = plot_dm.opts(**self.plot_options)

    # Setup the table pane event handler
    # prev_kwargs remembers the last stream values so the _called_* helpers can
    # tell which selection tool produced a given callback.
    self.prev_kwargs = {
        # For Lasso
        "geometry": None,
        # For Tap
        "x": None,
        "y": None,
        # For SelectionXY
        "bounds": None,
        "x_selection": None,
        "y_selection": None,
    }

    table_streams = [
        Lasso(source=plot_pane),
        Tap(source=plot_pane),
        SelectionXY(source=plot_pane),
    ]

    # Setup the table pane
    # self.table = Table(tuple([[0]]*(3+len(self.data_fields))), ["object_id"], self.data_fields)
    # Start with an empty selection; _table_from_points renders a placeholder table for it.
    self.points = np.array([])
    self.points_id = np.array([])
    self.points_idx = np.array([])
    self.table = self._table_from_points()

    table_options = {"width": self.plot_options["width"]}
    table_pane = DynamicMap(self.selected_objects, streams=table_streams).opts(**table_options)

    # If display_images is set to True then display randomly chosen images from the selected
    # sample underneath the table pane
    if self.config["visualize"]["display_images"]:
        pn.extension()

        # Create a small loading spinner same height as button
        self.spinner = pn.indicators.LoadingSpinner(
            value=False,  # Start with spinner off
            height=30,  # Smaller height to match button
            width=30,  # Smaller width
            margin=(5, 10, 5, 0),  # Add some margin for spacing
        )
        refresh_btn = pn.widgets.Button(name="Resample Images", button_type="primary")

        # Create a button row with spinner next to button
        button_row = pn.Row(refresh_btn, self.spinner, align="start")

        # The image pane listens to the button's click count and to the same
        # selection streams as the table.
        image_pane = DynamicMap(
            self._load_images, streams=[Params(refresh_btn, ["clicks"]), *table_streams]
        )
        images_panel = pn.pane.HoloViews(image_pane)
        plot_panel = pn.panel(plot_pane)

        # Set the table pane height to 30% of the plot height
        table_h = int(self.plot_options["height"] * 0.3)
        table_panel = pn.panel(table_pane, height=table_h)

        right = pn.Column(table_panel, images_panel, button_row)
        pane = pn.Row(plot_panel, right)
    else:
        # Plot pane and table pane side by side
        pane = plot_pane + table_pane

    # We attempt to display the pane (fails outside a notebook)
    try:
        from IPython.display import display

        display(pane)
    except ImportError:
        logger.warning("Couldn't find IPython display environment. Skipping display step.")

    if return_verb:
        return pane, self
    else:
        return pane
def visible_points(self, x_range: Union[tuple, list], y_range: Union[tuple, list]):
    """Build the hv.Points element for the part of the latent space currently in view.

    Holoviews calls this handler when the latent space plot is panned or zoomed.

    Parameters
    ----------
    x_range : tuple or list
        min and max x values
    y_range : tuple or list
        min and max y values

    Returns
    -------
    hv.Points
        The points inside the requested box. An empty element is returned when
        either range is None. When any bound is infinite every point is returned.
        If color values are loaded and at least one point is visible, the element
        carries the color column as a value dimension and a colorbar.
    """
    import numpy as np
    from holoviews import Points

    if x_range is None or y_range is None:
        return Points([])

    unbounded = np.isinf([x_range, y_range]).any()
    if unbounded:
        # An infinite bound means "no filtering": hand back the whole dataset.
        coords = np.array([row.numpy() for row in self.umap_results])
        row_numbers = list(range(len(self.umap_results)))
    else:
        coords, _, row_numbers = self.box_select_points(x_range, y_range)

    # Plain points when coloring is off or nothing is visible.
    if self.color_values is None or len(row_numbers) == 0:
        return Points(coords)

    # Append the color value as a third column: (x, y, color).
    colors = self.color_values[row_numbers]
    colored = Points(np.column_stack([coords, colors]), vdims=[self.color_column])

    colorbar_style = {
        "width": 18,
        "title": self.color_column,
        "title_text_font_size": "14pt",
        "title_text_font_style": "normal",
    }
    return colored.opts(
        color=self.color_column,
        cmap=self.cmap,
        colorbar=True,
        colorbar_opts=colorbar_style,
    )
[docs] def update_points(self, **kwargs) -> None: """ This is the main UI event handler for selection tools on the plot. If you are a dynamic map in the layout of the visualizer who updates based on plot selection you MUST call this function. This function accepts the data values from all streams and uses the differences between the current call and prior calls to differentiate between different UI events. The self.prev_kwargs dictionary is used to store previous calls to this function, and the ``_called_*`` helpers perform the differencing for each case. Calling this function GUARANTEES that self.points, self.points_id, and self.points_idx are up-to-date with the user's latest selection, regardless of the order that Holoviews evaluates the DynamicMaps in. """ import numpy as np if self._called_lasso(kwargs): self.points, self.points_id, self.points_idx = self.poly_select_points(kwargs["geometry"]) elif self._called_tap(kwargs): _, idx = self.tree.query([kwargs["x"], kwargs["y"]]) self.points = np.array([self.umap_results[idx].numpy()]) self.points_id = np.array([list(self.umap_results.ids())[idx]]) self.points_idx = np.array([idx]) elif self._called_box_select(kwargs): self.points, self.points_id, self.points_idx = self.box_select_points( kwargs["x_selection"], kwargs["y_selection"] ) else: # We saw no change that indicated a user intent; therefore, this is either initialization # OR we are not the first DynamicMap to run. pass self.prev_kwargs = kwargs
[docs] def _called_lasso(self, kwargs): return kwargs["geometry"] is not None and ( self.prev_kwargs["geometry"] is None or len(self.prev_kwargs["geometry"]) != len(kwargs["geometry"]) or any(self.prev_kwargs["geometry"].flatten() != kwargs["geometry"].flatten()) )
[docs] def _called_tap(self, kwargs): return ( kwargs["x"] is not None and kwargs["y"] is not None and (self.prev_kwargs["x"] != kwargs["x"] or self.prev_kwargs["y"] != kwargs["y"]) )
[docs] def _called_box_select(self, kwargs): return ( kwargs["x_selection"] is not None and kwargs["y_selection"] is not None and ( (self.prev_kwargs["x_selection"] is None and self.prev_kwargs["x_selection"] is None) or ( self.prev_kwargs["x_selection"] != kwargs["x_selection"] or self.prev_kwargs["y_selection"] != kwargs["y_selection"] ) ) )
[docs] def poly_select_points(self, geometry) -> tuple[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike]: """Select points inside a polygon. Parameters ---------- geometry : list List of x/y points describing the verticies of the polygon Returns ------- Tuple First element is an ndarray of x/y points in latent space inside the polygon Second element is an ndarray of corresponding object ids """ import numpy as np from scipy.spatial import Delaunay # Coarse grain the points within the axis-aligned bounding box of the geometry (xmin, xmax, ymin, ymax) = Visualize._bounding_box(geometry) point_indexes_coarse = self.box_select_indexes([xmin, xmax], [ymin, ymax]) points_coarse = self.umap_results[point_indexes_coarse].numpy() tri = Delaunay(geometry) mask = tri.find_simplex(points_coarse) != -1 mask = np.asarray(mask) if any(mask): points = points_coarse[mask] point_indexes = np.array(point_indexes_coarse)[mask] points_id = np.array(list(self.umap_results.ids()))[point_indexes] return points, points_id, point_indexes else: return np.array([[]]), np.array([]), np.array([])
[docs] def box_select_points( self, x_range: Union[tuple, list], y_range: Union[tuple, list] ) -> tuple[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike]: """Return the points and IDs for a box in the latent space Parameters ---------- x_range : tuple or list min and max x values y_range : tuple or list min and max y values Returns ------- Tuple First element is an ndarray of x/y points in latent space inside the box Second element is an ndarray of corresponding object ids """ import numpy as np indexes = self.box_select_indexes(x_range, y_range) ids = np.array(list(self.umap_results.ids()))[indexes] points = self.umap_results[indexes].numpy() return points, ids, indexes
[docs] def box_select_indexes(self, x_range: Union[tuple, list], y_range: Union[tuple, list]): """Return the indexes inside of a particular box in the latent space Parameters ---------- x_range : tuple or list min and max x values y_range : tuple or list min and max y values Returns ------- np.ndarray Array of data indexes where the latent space representation falls inside the given box. """ import numpy as np # Find center xc = (x_range[0] + x_range[1]) / 2.0 yc = (y_range[0] + y_range[1]) / 2.0 query_pt = [xc, yc] # Find larger of half-width and half-height to use as our search radius. radius = np.max([np.max(x_range) - xc, np.max(y_range) - yc]) # This is slightly overzealous, grabbing points outside the box sometimes. indexes = self.tree.query_ball_point(query_pt, radius, p=np.inf) def _inside_box(pt): x, y = pt xmin, xmax = x_range ymin, ymax = y_range return x > xmin and x < xmax and y > ymin and y < ymax # Filter for points properly inside the box return [i for i in indexes if _inside_box(self.umap_results[i].numpy())]
def selected_objects(self, **kwargs):
    """Rebuild the holoviews table for the current selection.

    The keyword arguments are the combined values of the Lasso, Tap and
    SelectionXY streams; they are forwarded to ``update_points``.

    Returns
    -------
    hv.Table
        Table built by ``_table_from_points`` for the selected objects. It is
        also stored on ``self.table``.
    """
    self.update_points(**kwargs)
    refreshed = self._table_from_points()
    self.table = refreshed
    return refreshed
def _table_from_points(self):
    """Build the hv.Table describing the currently selected points.

    The table is keyed on the object id column and has "x", "y" and every entry of
    ``self.data_fields`` as value columns. With an empty selection a one-row
    placeholder table filled with 1s is returned. If the metadata lookup fails, a
    single-column "message" table holding the error text is returned instead.
    """
    from holoviews import Table

    key_dims = [self.object_id_column_name]
    value_dims = ["x", "y"] + self.data_fields

    if len(self.points_id) == 0:
        # Placeholder: one row with the value 1 in every column.
        n_columns = len(key_dims) + len(value_dims)
        placeholder = tuple([1] for _ in range(n_columns))
        return Table(placeholder, key_dims, value_dims)

    # object_id, x and y columns
    transposed = self.points.T
    columns = [self.points_id, transposed[0], transposed[1]]  # type: ignore[list-item]

    # Remaining columns come from the dataset metadata.
    try:
        metadata = self.umap_results.metadata(self.points_idx, self.data_fields)
    except Exception as e:
        # Keep this try/except: some notebook front ends do not surface
        # exceptions raised inside a callback, so show the error in the table.
        return Table([str(e)], ["message"])

    for field in self.data_fields:
        columns.append(metadata[field])  # type: ignore[call-overload,misc,index]

    return Table(tuple(columns), key_dims, value_dims)
@staticmethod
[docs] def _bounding_box(points): import numpy as np # Find bounding box for the current dataset. xmin, xmax, ymin, ymax = (np.inf, -np.inf, np.inf, -np.inf) for x, y in points: xmin = x if x < xmin else xmin xmax = x if x > xmax else xmax ymin = y if y < ymin else ymin ymax = y if y > ymax else ymax return (xmin, xmax, ymin, ymax)
[docs] def _even_aspect_bounding_box(self): # Bring aspect ratio to 1:1 by expanding the smaller axis range (xmin, xmax, ymin, ymax) = Visualize._bounding_box(point.numpy() for point in self.umap_results) x_dim = xmax - xmin x_center = (xmax + xmin) / 2.0 y_dim = ymax - ymin y_center = (ymax + ymin) / 2.0 if x_dim > y_dim: ymin = y_center - x_dim / 2.0 ymax = y_center + x_dim / 2.0 else: xmin = x_center - y_dim / 2.0 xmax = x_center + x_dim / 2.0 return (xmin, xmax, ymin, ymax)
[docs] def get_selected_df(self): r""" Retrieve a pandas DataFrame containing the currently selected points and their associated metadata. Returns ------- pd.DataFrame A DataFrame with one row per selected point and columns: ["object_id", "x", "y", \*additional_fields]. """ import pandas as pd if len(self.points_id) == 0: logger.error("No points selected") df = pd.DataFrame(self.points, columns=["x", "y"]) df[self.object_id_column_name] = self.points_id meta = self.umap_results.metadata(self.points_idx, self.data_fields) meta_df = pd.DataFrame(meta, columns=self.data_fields) cols = [self.object_id_column_name, "x", "y"] + self.data_fields result = pd.concat([df.reset_index(drop=True), meta_df.reset_index(drop=True)], axis=1) return result.reindex(columns=cols)
[docs] def _load_images(self, **kwargs): # Turn on spinner manually before loading self.spinner.value = True self.update_points(**kwargs) # Load images result = self._make_image_pane(total_width=self.plot_options["width"]) # Turn off spinner when done self.spinner.value = False return result
def _make_image_pane(self, total_width: int = 500, *args, **kwargs):
    """Render thumbnails for up to 6 of the currently selected objects.

    Up to 6 selected objects are sampled at random (all of them when 6 or fewer are
    selected). Their cutout files are looked up through the filename metadata
    column, resolved against ``[general][data_dir]`` when relative, loaded
    (``.fits`` via astropy, ``.pt`` via torch), optionally centre-cropped to
    ``[data_set][crop_to]``, normalised, and laid out as a 2 x 3 grid. Slots with
    no object, or whose file cannot be loaded, show a blank placeholder.

    Parameters
    ----------
    total_width : int, optional
        Width budget in pixels for the whole grid; each thumbnail is a square of
        side ``0.9 * total_width / 3``. Defaults to 500.
    *args, **kwargs
        Accepted and ignored.

    Returns
    -------
    hv.Layout
        Layout of 6 ``RGB`` or ``Image`` elements arranged in 3 columns.
    """
    import numpy as np
    from astropy.io import fits
    from astropy.visualization import make_lupton_rgb
    from holoviews import RGB, Image, Layout

    def style_plot(plot, element):
        # Bokeh hook: hide the toolbar until hover and shrink the title font.
        bokeh_plot = plot.state
        bokeh_plot.toolbar.autohide = True
        bokeh_plot.title.text_font_size = "8pt"

    def crop_center(arr: np.ndarray, crop_shape: tuple[int, int]) -> np.ndarray:
        # Centre-crop a (H, W) or (H, W, C) array; return it untouched (with a
        # warning) when the requested crop is larger than the image.
        crop_h, crop_w = crop_shape
        if arr.ndim == 3:  # RGB case
            h, w, c = arr.shape
        else:  # Single-band case
            h, w = arr.shape

        if crop_h > h or crop_w > w:
            logger.warning(f"Crop size {crop_shape} exceeds image size {(h, w)}. Skipping crop.")
            return arr

        top = (h - crop_h) // 2
        left = (w - crop_w) // 2

        if arr.ndim == 3:
            return arr[top : top + crop_h, left : left + crop_w, :]
        else:
            return arr[top : top + crop_h, left : left + crop_w]

    n_images = 6
    n_rows = 2
    n_cols = n_images // n_rows
    imgs = []

    if len(self.points_id) > 0:
        id_map = dict(zip(self.points_idx, self.points_id))

        # If we have fewer than n_images points, use all of them
        if len(self.points_idx) <= n_images:
            chosen_idx = list(self.points_idx)
        else:
            chosen_idx = random.sample(list(self.points_idx), n_images)

        # Get sampled ids corresponding to the idxs
        sampled_ids = [id_map[idx] for idx in chosen_idx]

        # Get metadata - this is in the same order as chosen_idx
        meta = self.umap_results.metadata(
            chosen_idx, [self.object_id_column_name, self.filename_column_name]
        )

        raw_filenames = meta[self.filename_column_name]
        # Filenames may arrive as bytes or already as str depending on how the
        # metadata is stored; only bytes need decoding.
        filenames = [f.decode("utf-8") if isinstance(f, bytes) else str(f) for f in raw_filenames]
    else:
        sampled_ids = []
        filenames = []

    crop_to = self.config["data_set"]["crop_to"]
    base_dir = Path(self.config["general"]["data_dir"])

    # Fallback image to display in case of errors.
    # Matching the RGB / single-band dimensionality is important because otherwise
    # Holoviews' DynamicMap fails silently.
    if len(self.torch_tensor_bands) == 3:
        placeholder_arr = np.full((64, 64, 3), 1.0)
    else:
        placeholder_arr = np.full((64, 64), 1.0)

    for i in range(n_images):
        if i < len(sampled_ids):
            # Best effort per thumbnail: any failure falls back to the placeholder
            # so one bad file cannot blank the whole pane.
            try:
                cutout_path = Path(filenames[i])
                if not cutout_path.is_absolute():
                    cutout_path = base_dir / cutout_path

                if cutout_path.suffix.lower() == ".fits":
                    arr = fits.getdata(cutout_path)
                elif cutout_path.suffix.lower() == ".pt":
                    tensor = torch.load(cutout_path, map_location="cpu", weights_only=True)

                    if len(self.torch_tensor_bands) == 1:
                        # Single-band extraction
                        band_idx = self.torch_tensor_bands[0]
                        arr = tensor[band_idx].numpy()
                    else:
                        # RGB extraction (3 bands)
                        rgb_arrays = []
                        for band_idx in self.torch_tensor_bands:
                            rgb_arrays.append(tensor[band_idx].numpy())
                        # Stack along new axis to create (H, W, 3) RGB array
                        arr = np.stack(rgb_arrays, axis=-1)
                else:
                    raise ValueError(
                        f"Unsupported file format: {cutout_path.suffix}. Currently "
                        "the visualize module only supports FITS and PyTorch files"
                    )

                if crop_to:
                    arr = crop_center(arr, crop_to)

                # Handle normalization differently for single-band vs RGB
                if arr.ndim == 3:
                    # arr is (H, W, 3) but make_lupton_rgb expects r, g, b as separate arrays
                    r_band = arr[:, :, 0]
                    g_band = arr[:, :, 1]
                    b_band = arr[:, :, 2]

                    # make_lupton_rgb applies an asinh stretch. Per the astropy docs it
                    # returns an 8-bit integer (H, W, 3) image, not floats in [0, 1].
                    arr = make_lupton_rgb(r_band, g_band, b_band, **self.make_lupton_rgb_opts)
                else:  # Single-band case
                    # Ensure data is positive for log scaling
                    min_positive = np.min(arr[arr > 0]) if np.any(arr > 0) else 1e-10
                    # Replace zeros/negatives with minimum positive value
                    arr = np.maximum(arr, min_positive)

                    # Apply LogNorm-like scaling
                    norm = LogNorm(vmin=min_positive, vmax=np.max(arr))
                    arr = norm(arr)

                title = f"{sampled_ids[i]}"
            except Exception as e:
                logger.warning(f"Could not load file: {e}")
                # Also append to a log file in the working directory.
                with open("./hyrax_visualize.log", "a") as f:
                    f.write(f"Could not load FITS file: {e}\n")
                arr = placeholder_arr
                title = f"NL:{sampled_ids[i]}"
        else:
            arr = placeholder_arr
            title = "No Selection"

        # Configure image options based on array dimensions
        img_opts = {
            "width": int((0.9 * total_width) / n_cols),
            "height": int((0.9 * total_width) / n_cols),
            "title": title,
            "tools": [],
            "shared_axes": False,
            "hooks": [style_plot],
            "xaxis": None,
            "yaxis": None,
        }

        if arr.ndim == 3:  # RGB case
            img = RGB(arr).opts(**img_opts)
        else:  # Single-band case
            img_opts["cmap"] = "gray_r"
            img = Image(arr).opts(**img_opts)

        imgs.append(img)

    return Layout(imgs).cols(n_cols)