Source code for hyrax.verbs.visualize_v2

import logging
from argparse import ArgumentParser, Namespace
from typing import TYPE_CHECKING

from .verb_registry import Verb, hyrax_verb

if TYPE_CHECKING:
    import pandas as pd

logger = logging.getLogger(__name__)


def _disable_axis_zoom(plot, element):
    """Bokeh hook: disable axis-only zoom and align tick marks inward.

    Applied as a HoloViews ``hooks`` entry. It mutates ``plot.state`` (the
    underlying Bokeh figure) in place:

    * every ``WheelZoomTool`` gets ``zoom_on_axis = False``, so scrolling over an
      axis no longer zooms only that axis;
    * ``match_aspect`` is switched on;
    * on every x and y axis, major/minor ticks are drawn inward only (6 px / 3 px
      in, 0 px out) and tick labels get a negative standoff of -20.

    ``element`` is part of the hook signature and is not used.
    """
    from bokeh.models import WheelZoomTool

    bokeh_fig = plot.state

    wheel_tools = (tool for tool in bokeh_fig.tools if isinstance(tool, WheelZoomTool))
    for wheel in wheel_tools:
        wheel.zoom_on_axis = False

    bokeh_fig.match_aspect = True

    for ax in (*bokeh_fig.xaxis, *bokeh_fig.yaxis):
        ax.major_tick_in, ax.major_tick_out = 6, 0
        ax.minor_tick_in, ax.minor_tick_out = 3, 0
        ax.major_label_standoff = -20
@hyrax_verb
class VisualizeV2(Verb):
    """Verb to create a hexbin visualization of a 2D latent space.

    Intended for interactive (Jupyter) use through ``run()``; the CLI entry point
    only logs an error. After ``run()`` the current selection is available via
    ``get_selected_df()`` and the UI can be rebuilt with ``restart_ui()``.
    """

    cli_name = "visualize_v2"
    add_parser_kwargs = {}

    # ``_load_data`` raises unless every group listed here is in the data request.
    REQUIRED_DATA_GROUPS = ("visualize",)
    OPTIONAL_DATA_GROUPS = ()

    @staticmethod
    def setup_parser(parser: ArgumentParser):
        """CLI not implemented for this verb"""

    def run_cli(self, args: Namespace | None = None):
        """CLI not implemented for this verb"""
        logger.error("Running visualize_v2 from the cli is unimplemented")

    def run(
        self,
        **kwargs,
    ):
        """Generate an interactive hexbin visualization of a latent space projected to 2D.

        Uses HoloViews HexTiles with datashader for adaptive hexbin aggregation,
        box/lasso selection, a metadata table, and tabbed detail plots.

        While the dataset loads, a small Markdown checklist is displayed (when
        IPython is available); it is cleared just before the UI is displayed.

        Parameters
        ----------
        kwargs :
            Additional keyword arguments passed as HexTiles opts overrides.

        Returns
        -------
        VisualizeV2
            This verb instance. Use it to call ``restart_ui()`` or ``get_selected_df()``
            after the UI has been displayed.
        """
        import panel as pn

        # pn.extension must be called before displaying any Panel widget.
        # Keep it here (in addition to _build_ui) so the loading indicator works on first run.
        pn.extension("tabulator")

        def _status(dataset_done: bool, ui_done: bool) -> str:
            dataset_icon = "✅" if dataset_done else "⏳"
            ui_icon = "✅" if ui_done else "⏳"
            return f"- {dataset_icon} Loading dataset\n- {ui_icon} Rendering UI"

        _loading = pn.pane.Markdown(_status(False, False))
        clear_output = None
        try:
            from IPython.display import clear_output
            from IPython.display import display as ipy_display

            ipy_display(_loading)
        except ImportError:
            pass  # Not running under IPython: no loading indicator.

        self._load_data()
        _loading.object = _status(True, False)

        # Wipe the loading indicator so the full UI replaces it cleanly in the cell output.
        if clear_output is not None:
            clear_output(wait=True)

        # _build_ui() constructs all widgets and displays the finished layout.
        self._build_ui(**kwargs)
        return self

    def restart_ui(self, **kwargs):
        """Rebuild and re-display the Panel UI without reloading data.

        Call this after a Jupyter websocket disconnect instead of re-running the cell.
        The expensive data-loading step is skipped — only the widgets are rebuilt.

        Parameters
        ----------
        kwargs :
            Additional keyword arguments passed as HexTiles opts overrides.

        Returns
        -------
        VisualizeV2
            This verb instance. Use it to call ``restart_ui()`` or ``get_selected_df()``
            after the UI has been displayed.

        Raises
        ------
        RuntimeError
            If ``run()`` has not loaded the data yet.
        """
        if not getattr(self, "_data_loaded", False):
            raise RuntimeError("No data loaded yet. Call run() first.")
        self._build_ui(**kwargs)
        return self

    def _load_data(self):
        """Load dataset and build the points DataFrame.

        Guards with a ``_data_loaded`` sentinel so the expensive steps only run once
        per verb instance. Safe to call multiple times.

        Sets ``self.datasets``, ``self.df`` (columns ``x``/``y`` taken from the first
        two columns of the primary dataset), ``self._n_points``,
        ``self._dataset_getters``, ``self._scalar_col_options`` (``"object_id"`` and
        ``"<dataset>.<field>"`` names whose sample value is a scalar) and
        ``self._fields_restricted``.

        Raises
        ------
        RuntimeError
            If a required data group is missing or no ``visualize`` sub-dataset
            declares a ``primary_id_field``.
        """
        if getattr(self, "_data_loaded", False):
            return

        import numpy as np
        import pandas as pd

        from hyrax.pytorch_ignite import setup_dataset

        # ── Build DataProvider for metadata access ────────────────────────────
        self.datasets = setup_dataset(self.config)
        available_groups = set(self.datasets.keys())
        # Every required group must be present (not merely one of them).
        missing_groups = set(self.REQUIRED_DATA_GROUPS) - available_groups
        if missing_groups:
            required_keys = ", ".join(sorted(self.REQUIRED_DATA_GROUPS))
            available_keys = ", ".join(sorted(available_groups)) or "<none>"
            raise RuntimeError(
                f"VisualizeV2 requires dataset entries {required_keys} in the data request. "
                f"Available: {available_keys}"
            )

        # The primary dataset (identified by primary_id_field) holds the UMAP coordinates.
        _primary_ds_name = self.datasets["visualize"].primary_dataset
        if _primary_ds_name is None:
            raise RuntimeError(
                "No sub-dataset in data_request.visualize has a 'primary_id_field'. "
                "Set 'primary_id_field' on the dataset that contains the UMAP 2D coordinates."
            )
        reduced_dim_dataset = self.datasets["visualize"].prepped_datasets[_primary_ds_name]
        if hasattr(reduced_dim_dataset, "__get_all__"):
            # Normalise list / array-like results so the 2-D slicing below works.
            points_array = np.asarray(reduced_dim_dataset.__get_all__())
        else:
            logger.warning(
                "Primary dataset does not implement `__get_all__`. "
                "Falling back to sequential access, which may be slow."
            )
            reduced_dim_results = [reduced_dim_dataset[i] for i in range(len(reduced_dim_dataset))]
            points_array = np.array([np.asarray(pt) for pt in reduced_dim_results])

        # ── Build DataFrame from UMAP 2D results ─────────────────────────────
        df = pd.DataFrame({"x": points_array[:, 0], "y": points_array[:, 1]})

        # Store references on self for downstream use
        self.df = df
        self._n_points = len(df)

        # ── Probe available scalar fields ─────────────────────────────────────
        # Call DataProvider[0] to sample the structure and filter to scalar fields only.
        # This drops large array/tensor fields (e.g. image, data) automatically.
        _sample = self.datasets["visualize"][0]
        self._dataset_getters = self.datasets["visualize"].dataset_getters

        _scalar_types = (int, float, str, bool, np.integer, np.floating, np.bool_)
        _scalar_col_options: list[str] = []
        if "object_id" in _sample and isinstance(_sample["object_id"], _scalar_types):
            _scalar_col_options.append("object_id")
        for _fn, _field_dict in _sample.items():
            if _fn == "object_id" or not isinstance(_field_dict, dict):
                continue
            _scalar_col_options.extend(
                f"{_fn}.{_field}"
                for _field, _val in _field_dict.items()
                if isinstance(_val, _scalar_types)
            )
        self._scalar_col_options = _scalar_col_options

        # Warn (in the UI) if any dataset sub-config has a 'fields' restriction
        self._fields_restricted = any(
            bool(ds_conf.get("fields"))
            for ds_conf in self.config.get("data_request", {}).get("visualize", {}).values()
            if isinstance(ds_conf, dict)
        )
        self._data_loaded = True

    def _build_ui(self, **kwargs):
        """Build and display the Panel UI using data already loaded by ``_load_data()``.

        A keep-alive callback and selection-table thread pool left over from a
        previous call are stopped first, so this can be called repeatedly
        (see ``restart_ui``).

        Layout, top to bottom: the hexbin plot with box/lasso selection overlays, a
        colormap slider, a progress bar, the column selector beside the selection
        table and export buttons, tabbed detail panes (one tab per prepped dataset)
        and pagination controls for those panes.

        Parameters
        ----------
        kwargs :
            HexTiles opts overrides merged over the default plot options. A
            ``hooks`` entry is run after the built-in ``_disable_axis_zoom`` hook.
        """
        if getattr(self, "_keep_alive_cb", None) is not None:
            import contextlib

            with contextlib.suppress(Exception):
                self._keep_alive_cb.stop()
            self._keep_alive_cb = None

        import math
        import os
        import threading
        from concurrent.futures import ThreadPoolExecutor, as_completed

        import datashader as ds
        import holoviews as hv
        import matplotlib.axes
        import matplotlib.figure as mpl_figure
        import matplotlib.pyplot as plt
        import numpy as np
        import pandas as pd
        import panel as pn
        from holoviews import DynamicMap, Polygons, Rectangles, extension
        from holoviews.element.stats import HexTiles
        from holoviews.operation.datashader import rasterize
        from holoviews.streams import BoundsXY, Lasso, Params, RangeXY
        from matplotlib.path import Path as MplPath

        try:
            from IPython import get_ipython
        except ImportError:
            # Without IPython there is no user namespace to export into; keep the
            # UI buildable so the display fallback at the end can report it.
            def get_ipython():
                return None

        # ── Local aliases for data loaded by _load_data() ─────────────────────
        df = self.df
        n_points = self._n_points
        _scalar_col_options = self._scalar_col_options
        _dataset_getters = self._dataset_getters
        _fields_restricted = self._fields_restricted

        # ── Config ────────────────────────────────────────────────────────────
        viz_config = self.config["visualize_v2"]
        target_bins = viz_config["target_bins"]
        buffer_factor = viz_config["buffer_factor"]
        plot_width = viz_config["plot_width"]
        plot_height = viz_config["plot_height"]
        cmap = viz_config["cmap"]  # Sets the initial colormap; user can change it via the UI slider.
        max_table_rows = viz_config["max_table_rows"]
        num_detail_plots = viz_config["num_detail_plots"]

        # ── HoloViews / Panel init ────────────────────────────────────────────
        # pn.extension must come before hv.extension so Panel can patch HoloViews'
        # comm machinery before the Bokeh backend registers its own callbacks.
        pn.extension("tabulator")
        extension("bokeh")

        # Hidden widget used solely for server-side keep-alive pings (see below).
        _heartbeat = pn.widgets.IntInput(value=0, width=0, height=0, visible=False)

        # ── Colormap selector ─────────────────────────────────────────────────
        # Keys are display names, values are matplotlib colormap identifiers.
        _cmap_entries = {
            "Viridis": "viridis",
            "Inferno": "inferno",
            "Plasma": "plasma",
            "Magma": "magma",
            "Cividis": "cividis",
            "Turbo": "turbo",
            "Blues": "Blues",
            "Greens": "Greens",
            "Purples": "Purples",
            "Reds": "Reds",
            "YlOrRd": "YlOrRd",
            "YlGnBu": "YlGnBu",
            "Spring": "spring",
            "Summer": "summer",
            "Autumn": "autumn",
            "Winter": "winter",
        }
        _cmap_display_names = list(_cmap_entries.keys())
        _cmap_initial = next(
            (k for k, v in _cmap_entries.items() if v.lower() == cmap.lower()),
            _cmap_display_names[0],
        )
        _cmap_slider = pn.widgets.IntSlider(
            name=f"Colormap: {_cmap_initial}",
            start=0,
            end=len(_cmap_display_names) - 1,
            value=_cmap_display_names.index(_cmap_initial),
            width=plot_width,
        )

        # ── Determine initial range from data ─────────────────────────────────
        def _padded_range(lo: float, hi: float) -> tuple[float, float]:
            # 5 % margin on each side. When every point shares one coordinate the
            # span is 0, which would give a zero-width view and a zero hex bin size,
            # so fall back to a fixed margin.
            pad = (hi - lo) * 0.05 or 0.5
            return (lo - pad, hi + pad)

        initial_x_range = _padded_range(float(df["x"].min()), float(df["x"].max()))
        initial_y_range = _padded_range(float(df["y"].min()), float(df["y"].max()))

        # ── Adaptive hexbin callback ──────────────────────────────────────────
        range_xy = RangeXY()
        _cmap_stream = Params(_cmap_slider, ["value"], rename={"value": "cmap_idx"})

        # Cache the last rasterized element so colormap changes skip re-rasterization.
        _hex_cache: dict = {"range_key": None, "element": None}

        def make_hexbin(x_range, y_range, cmap_idx=0):
            """Rasterize the visible (buffered) region into log-count hex tiles."""
            xr = x_range if (x_range is not None and None not in x_range) else initial_x_range
            yr = y_range if (y_range is not None and None not in y_range) else initial_y_range
            range_key = (xr, yr)

            if _hex_cache["range_key"] != range_key or _hex_cache["element"] is None:
                # Include a buffer around the view so panning doesn't show empty edges.
                x_pad_buf = (xr[1] - xr[0]) * buffer_factor
                y_pad_buf = (yr[1] - yr[0]) * buffer_factor
                mask = (
                    (df["x"] >= xr[0] - x_pad_buf)
                    & (df["x"] <= xr[1] + x_pad_buf)
                    & (df["y"] >= yr[0] - y_pad_buf)
                    & (df["y"] <= yr[1] + y_pad_buf)
                )
                df_view = df[mask]
                bin_size = (xr[1] - xr[0]) / target_bins
                hexbin = rasterize(
                    HexTiles(df_view, kdims=["x", "y"]).redim.range(x=xr, y=yr),
                    aggregator=ds.count(),
                    x_sampling=bin_size,
                    y_sampling=bin_size,
                    dynamic=False,
                )
                if hexbin.vdims:
                    count_col = hexbin.vdims[0].name
                    hexbin = hexbin.clone(
                        hexbin.data.assign(**{count_col: np.log1p(hexbin.data[count_col])})
                    )
                _hex_cache["range_key"] = range_key
                _hex_cache["element"] = hexbin

            # Cheap: apply colormap to cached element (no re-rasterization)
            return _hex_cache["element"].opts(
                hv.opts.HexTiles(cmap=_cmap_entries[_cmap_display_names[cmap_idx]])
            )

        dmap = DynamicMap(make_hexbin, streams=[range_xy, _cmap_stream])

        # ── Plot opts ─────────────────────────────────────────────────────────
        plot_opts = {
            "cnorm": "linear",
            "colorbar": True,
            "colorbar_position": "bottom",
            "colorbar_opts": {"title": "log(Count + 1)"},
            "width": plot_width,
            "height": plot_height,
            "xlabel": "",
            "ylabel": "",
            "title": f"Hexbin — {n_points:,} samples | ~{target_bins} bins across",
            "toolbar": "above",
            "tools": ["hover", "pan", "wheel_zoom", "reset", "box_select", "lasso_select"],
            "line_color": "white",
            "line_width": 0.5,
            "nonselection_alpha": 1.0,
            "nonselection_line_alpha": 1.0,
        }
        plot_opts.update(kwargs)
        # Merge caller-supplied hooks with ours; passing ``hooks`` twice is a TypeError.
        _user_hooks = plot_opts.pop("hooks", None) or []
        if callable(_user_hooks):
            _user_hooks = [_user_hooks]
        plot = dmap.opts(hv.opts.HexTiles(**plot_opts, hooks=[_disable_axis_zoom, *_user_hooks]))

        # ── Selection streams ─────────────────────────────────────────────────
        bounds_stream = BoundsXY(source=dmap, bounds=(0, 0, 0, 0))
        lasso_stream = Lasso(source=dmap)

        # ── Column selector ───────────────────────────────────────────────────
        col_selector_title = pn.pane.Markdown("### Columns", margin=(0, 0))
        _default_cols = ["object_id"] if "object_id" in _scalar_col_options else []
        col_selector = pn.widgets.CheckBoxGroup(
            options=_scalar_col_options,
            value=_default_cols,
            inline=False,
            stylesheets=["label { font-size: 16px !important; }"],
        )
        _fields_alert = (
            pn.pane.Alert(
                "Additional fields may be available. Remove `fields` from the data request to see them.",
                alert_type="info",
                margin=(0, 0, 6, 0),
            )
            if _fields_restricted
            else None
        )

        # ── Selection table ───────────────────────────────────────────────────
        _empty = pd.DataFrame(columns=["row_index"] + _default_cols)
        # 21 data rows × 35px + 30px header + 35px footer/pagination bar ≈ one row of breathing room
        _table_height = 21 * 35 + 30 + 35
        selection_table = pn.widgets.Tabulator(
            _empty.copy(),
            pagination="remote",
            page_size=25,
            show_index=False,
            sizing_mode="stretch_width",
            height=_table_height,
            header_align="right",
            configuration={"columnDefaults": {"headerSort": True}},
            disabled=True,
        )
        table_title = pn.pane.Markdown("### Selected Points", margin=(0, 0))

        # Selected subsets, accessible via verb instance
        self.selected_box = pd.DataFrame()
        self.selected_lasso = pd.DataFrame()
        _table_df: list[pd.DataFrame] = [pd.DataFrame()]  # authoritative copy unaffected by user edits

        _existing_pool = getattr(self, "_selection_table_pool", None)
        if _existing_pool is not None:
            _existing_pool.shutdown(wait=False, cancel_futures=True)
        self._selection_table_pool = ThreadPoolExecutor(max_workers=min(8, os.cpu_count() or 1))
        _pool = self._selection_table_pool

        def _safe_execute(fn):
            """Schedule a UI-mutating callable on the Panel IOLoop (thread-safe).

            Call this whenever a widget attribute is set from a background thread so
            the mutation goes through Bokeh's document lock instead of racing against it.
            """
            if pn.state.curdoc is not None:
                pn.state.execute(fn)
            else:
                fn()

        def _set_progress(pct: int) -> None:
            """Thread-safe progress-bar update."""
            _safe_execute(lambda p=pct: setattr(_progress_bar, "value", p))

        def _cancel_pending(futures) -> None:
            """Cancel every queued-but-not-started future to drain the pool backlog."""
            for f in futures:
                f.cancel()

        def _make_fetcher(active_cols: list[str]):
            """Return a row-fetching callable for the given active columns."""
            want_object_id = "object_id" in active_cols
            # "<dataset>.<field>" -> ("<dataset>", "<field>")
            nested = [tuple(c.split(".", 1)) for c in active_cols if "." in c]

            def _fetch_row(idx):
                row: dict = {"row_index": idx}
                if want_object_id:
                    row["object_id"] = self.datasets["visualize"][idx].get("object_id")
                for fn, field in nested:
                    row[f"{fn}.{field}"] = _dataset_getters[fn][field](idx)
                return row

            return _fetch_row

        def _selection_title(n_selected: int) -> str:
            """Markdown heading for the table, noting truncation at ``max_table_rows``."""
            if n_selected > max_table_rows:
                return (
                    f"### Selected Points — {n_selected:,} "
                    f"(showing first {max_table_rows:,} — export to see all)"
                )
            return f"### Selected Points — {n_selected:,}"

        def _publish_table(frame: pd.DataFrame, title: str) -> None:
            """Record ``frame`` as the exportable table and push it to the widgets."""
            _table_df[0] = frame.copy()

            def _apply():
                selection_table.value = frame
                table_title.object = title

            _safe_execute(_apply)

        def _update_table(
            sel: pd.DataFrame,
            progress_callback=None,
            should_abort=None,
        ) -> None:
            """Fetch the selected columns for up to ``max_table_rows`` rows of ``sel``.

            Rows are fetched concurrently on ``_pool``. ``should_abort`` is polled
            between completed rows; when it returns True the remaining work is
            cancelled and the table is left unchanged. Exceptions raised by a row
            fetch propagate to the caller after pending fetches are cancelled.
            """
            active_cols: list[str] = list(col_selector.value)
            if sel.empty:
                _publish_table(pd.DataFrame(columns=["row_index", *active_cols]), "### Selected Points")
                return

            capped_indices = list(sel.index[:max_table_rows])
            if not active_cols:
                _publish_table(pd.DataFrame({"row_index": capped_indices}), _selection_title(len(sel)))
                return

            _fetch_row = _make_fetcher(active_cols)
            total = len(capped_indices)
            _progress_step = max(1, total // 100)

            # Bail out before submitting any futures if already stale
            if should_abort and should_abort():
                return

            futures = {_pool.submit(_fetch_row, idx): i for i, idx in enumerate(capped_indices)}
            rows: list = [None] * total
            try:
                for done, future in enumerate(as_completed(futures)):
                    if should_abort and should_abort():
                        _cancel_pending(futures)
                        return
                    rows[futures[future]] = future.result()
                    if progress_callback and (done % _progress_step == 0 or done == total - 1):
                        progress_callback(int((done + 1) / total * 100))
            except Exception:
                _cancel_pending(futures)
                raise

            display_df = pd.DataFrame(rows).reset_index(drop=True)
            _publish_table(display_df, _selection_title(len(sel)))

        # Re-fetch columns whenever the selector changes
        def _on_col_selector_change(_e):
            sel = self.get_selected_df()
            if sel.empty:
                _update_table(sel)
                return
            _gen[0] += 1
            my_gen = _gen[0]

            def _run():
                try:
                    _set_progress(0)
                    _update_table(
                        sel,
                        progress_callback=_set_progress,
                        should_abort=lambda: _gen[0] != my_gen,
                    )
                except Exception:
                    logger.error("Failed to refresh the selection table", exc_info=True)
                finally:
                    _set_progress(0)

            threading.Thread(target=_run, daemon=True).start()

        col_selector.param.watch(_on_col_selector_change, "value")

        # Update slider label whenever the colormap changes
        def _on_cmap_slide(event):
            _cmap_slider.name = f"Colormap: {_cmap_display_names[event.new]}"

        _cmap_slider.param.watch(_on_cmap_slide, "value")

        # ── Detail panes ─────────────────────────────────────────────────────
        _total_width = plot_width
        _detail_pane_width = (_total_width - (20 * (num_detail_plots - 1))) // num_detail_plots
        _prepped_datasets = self.datasets["visualize"].prepped_datasets
        _tab_names = list(_prepped_datasets.keys())

        # Consistent subplot margins applied to every figure shown in a detail pane.
        # Using fixed subplots_adjust (instead of tight_layout) together with tight=False
        # in the Matplotlib pane ensures all plots occupy identical canvas geometry
        # regardless of content type.
        _detail_layout = dict(left=0.03, right=0.97, bottom=0.03, top=0.9)

        def _make_placeholder_fig():
            """Grey "No selection" figure used for empty detail panes."""
            fig = mpl_figure.Figure(figsize=(3, 3))
            ax = fig.add_subplot(111)
            ax.set_facecolor("#f0f0f0")
            ax.text(
                0.5,
                0.5,
                "No selection",
                ha="center",
                va="center",
                transform=ax.transAxes,
                color="#aaaaaa",
                fontsize=11,
            )
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_edgecolor("#cccccc")
            fig.subplots_adjust(**_detail_layout)
            return fig

        def _make_text_fig(text, index):
            """Figure showing ``text`` in monospace, titled with the row index."""
            fig = mpl_figure.Figure(figsize=(3, 3))
            ax = fig.add_subplot(111)
            ax.set_facecolor("#f8f8f8")
            ax.text(
                0.05,
                0.95,
                text,
                ha="left",
                va="top",
                transform=ax.transAxes,
                fontsize=12,
                family="monospace",
                wrap=True,
            )
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(f"Index: {index}" if index is not None else "No selection")
            for spine in ax.spines.values():
                spine.set_edgecolor("#cccccc")
            fig.subplots_adjust(**_detail_layout)
            return fig

        def _make_pane_row():
            return [
                pn.pane.Matplotlib(
                    _make_placeholder_fig(),
                    tight=False,
                    format="png",
                    width=_detail_pane_width,
                    height=_detail_pane_width,
                )
                for _ in range(num_detail_plots)
            ]

        detail_panes = [_make_pane_row() for _ in _tab_names]

        _tab_active_css = """
        .bk-tab.bk-active {
            background-color: #d6eaf8 !important;
            color: #000000 !important;
        }
        """
        detail_tabs = pn.Tabs(
            *[(name, pn.Row(*panes, margin=0)) for name, panes in zip(_tab_names, detail_panes)],
            width=_total_width,
            stylesheets=[_tab_active_css],
        )

        def _make_detail_fig(index, tab_index):
            """Create a matplotlib figure for one detail pane.

            Accepts whatever ``dataset.display()`` returns and normalises it to a
            ``matplotlib.figure.Figure`` before handing it to the Matplotlib pane:

            - ``Figure`` → used directly
            - ``Axes`` → parent figure extracted
            - ``numpy.ndarray`` → imshow'd into a new Figure
            - anything else → rendered as text via ``_make_text_fig``

            Datasets without a callable ``display`` are rendered as ``str(dataset[index])``.
            Any pyplot-managed figures created inside ``display()`` are closed after
            the call so they don't leak into the notebook output.
            """
            if index is None:
                return _make_placeholder_fig()
            dataset = _prepped_datasets[_tab_names[tab_index]]
            if not callable(getattr(dataset, "display", None)):
                fig = _make_text_fig(str(dataset[index]), index)
            else:
                _fignums_before = set(plt.get_fignums())
                try:
                    result = dataset.display(index)
                finally:
                    # Close any pyplot figures created as a side-effect of display()
                    for _fn in set(plt.get_fignums()) - _fignums_before:
                        plt.close(_fn)
                if isinstance(result, mpl_figure.Figure):
                    fig = result
                elif isinstance(result, matplotlib.axes.Axes):
                    fig = result.figure
                elif isinstance(result, np.ndarray):
                    fig = mpl_figure.Figure(figsize=(3, 3))
                    ax = fig.add_subplot(111)
                    ax.imshow(result)
                    ax.set_xticks([])
                    ax.set_yticks([])
                    ax.set_title(f"Index: {index}")
                else:
                    fig = _make_text_fig(str(result), index)
            fig.set_size_inches(3, 3)
            fig.subplots_adjust(**_detail_layout)
            return fig

        def _update_detail_panes(indices):
            """Show ``indices`` (one per pane, placeholders beyond) in every tab."""

            def _apply():
                for tab_index, tab_panes in enumerate(detail_panes):
                    for i, detail_pane in enumerate(tab_panes):
                        detail_pane.object = _make_detail_fig(
                            indices[i] if i < len(indices) else None, tab_index
                        )

            _safe_execute(_apply)

        # ── Pagination controls ───────────────────────────────────────────────
        _page_state: dict = {"page": 0, "indices": []}

        btn_first = pn.widgets.Button(name="|◀", width=44, button_type="default", disabled=True)
        btn_prev = pn.widgets.Button(name="◀", width=44, button_type="default", disabled=True)
        page_numbers_row = pn.Row(margin=(0, 4))
        btn_next = pn.widgets.Button(name="▶", width=44, button_type="default", disabled=True)
        btn_last = pn.widgets.Button(name="▶|", width=44, button_type="default", disabled=True)

        def _total_pages() -> int:
            return max(1, math.ceil(len(_page_state["indices"]) / num_detail_plots))

        def _refresh_pagination_widgets() -> None:
            """Rebuild the page-number buttons (current page ±3) and arrow states."""
            page = _page_state["page"]
            total = _total_pages()
            window_start = max(0, page - 3)
            window_end = min(total - 1, page + 3)
            page_btns = []
            for p in range(window_start, window_end + 1):
                is_current = p == page
                btn = pn.widgets.Button(
                    name=str(p + 1),
                    width=44,
                    button_type="primary" if is_current else "default",
                    disabled=is_current,
                )
                if not is_current:
                    btn.on_click(lambda _e, _p=p: _go_to_page(_p))
                page_btns.append(btn)
            page_numbers_row.objects = page_btns
            btn_first.disabled = page == 0
            btn_prev.disabled = page == 0
            btn_next.disabled = page >= total - 1
            btn_last.disabled = page >= total - 1

        def _render_page() -> None:
            page = _page_state["page"]
            start = page * num_detail_plots
            _update_detail_panes(list(_page_state["indices"][start : start + num_detail_plots]))
            # May run on a background selection thread: mutate widgets via the IOLoop.
            _safe_execute(_refresh_pagination_widgets)

        def _go_to_page(new_page: int) -> None:
            _page_state["page"] = max(0, min(new_page, _total_pages() - 1))
            _render_page()

        btn_first.on_click(lambda _e: _go_to_page(0))
        btn_prev.on_click(lambda _e: _go_to_page(_page_state["page"] - 1))
        btn_next.on_click(lambda _e: _go_to_page(_page_state["page"] + 1))
        btn_last.on_click(lambda _e: _go_to_page(_total_pages() - 1))

        pagination_row = pn.Row(
            pn.Spacer(),
            btn_first,
            btn_prev,
            page_numbers_row,
            btn_next,
            btn_last,
            pn.Spacer(),
            width=_total_width,
            align="center",
        )
        _refresh_pagination_widgets()  # initialise the page-number buttons

        # ── Selection overlay callback ────────────────────────────────────────
        _progress_bar = pn.indicators.Progress(width=_total_width, value=0, max=100, bar_color="info")
        # Generation counter: background work compares against it to detect staleness.
        _gen: list[int] = [0]
        _prev = {"bounds": (0, 0, 0, 0), "geometry": None}

        def selection_overlay(bounds, geometry):
            try:
                return _selection_overlay_impl(bounds, geometry)
            except Exception:
                logger.error("selection_overlay raised an exception", exc_info=True)
                return Rectangles([]).opts(apply_ranges=False) * Polygons([]).opts(apply_ranges=False)

        def _selection_overlay_impl(bounds, geometry):
            """Update the box/lasso selection and return the highlight overlay."""
            bounds_changed = bounds is not None and bounds != _prev["bounds"]
            geometry_changed = geometry is not _prev["geometry"]
            if bounds is not None:
                _prev["bounds"] = bounds
            _prev["geometry"] = geometry

            box_el = Rectangles([]).opts(apply_ranges=False)
            lasso_el = Polygons([]).opts(apply_ranges=False)

            selection_changed = bounds_changed or (geometry_changed and geometry is not None)
            if selection_changed:
                export_btn.icon = ""
                export_btn.name = "Export table to selected_points"
                export_btn.disabled = True
                export_all_btn.icon = ""
                export_all_btn.name = "Export all to selected_points"
                export_all_btn.disabled = True

            if bounds_changed:
                self.selected_lasso = pd.DataFrame()
                x0, y0, x1, y1 = bounds
                if x0 != x1 and y0 != y1:
                    mask = (df["x"] >= x0) & (df["x"] <= x1) & (df["y"] >= y0) & (df["y"] <= y1)
                    self.selected_box = df[mask]
                    if not self.selected_box.empty:
                        box_el = Rectangles([(x0, y0, x1, y1)]).opts(
                            fill_alpha=0.1,
                            fill_color="cyan",
                            line_color="cyan",
                            line_width=1.5,
                            apply_ranges=False,
                        )
                else:
                    self.selected_box = pd.DataFrame()
            elif geometry_changed and geometry is not None:
                self.selected_box = pd.DataFrame()
                if isinstance(geometry, dict):
                    coords = np.array(geometry["coordinates"][0])
                else:
                    coords = np.asarray(geometry)
                if len(coords) >= 3:
                    # Cheap bounding-box prefilter before the exact point-in-polygon test.
                    x_min, y_min = coords[:, 0].min(), coords[:, 1].min()
                    x_max, y_max = coords[:, 0].max(), coords[:, 1].max()
                    bbox_mask = (
                        (df["x"] >= x_min) & (df["x"] <= x_max) & (df["y"] >= y_min) & (df["y"] <= y_max)
                    )
                    df_candidates = df[bbox_mask]
                    inside = MplPath(coords).contains_points(df_candidates[["x", "y"]].values)
                    self.selected_lasso = df_candidates[inside]
                    if not self.selected_lasso.empty:
                        lasso_el = Polygons([{("x", "y"): coords}]).opts(
                            fill_alpha=0.1,
                            fill_color="cyan",
                            line_color="cyan",
                            line_width=1.5,
                            apply_ranges=False,
                        )
                else:
                    self.selected_lasso = pd.DataFrame()

            # Spawn background thread for slow work; return overlay immediately
            if selection_changed:
                sel = self.get_selected_df()
                _gen[0] += 1
                my_gen = _gen[0]
                if sel.empty:
                    # Empty selection is fast — handle inline to avoid thread overhead
                    _update_table(sel)
                    _page_state["page"] = 0
                    _page_state["indices"] = []
                    _render_page()
                    export_btn.disabled = False
                    export_all_btn.disabled = False
                else:
                    threading.Thread(target=_do_selection_work, args=(sel, my_gen), daemon=True).start()

            return box_el * lasso_el

        def _do_selection_work(sel: "pd.DataFrame", my_gen: int) -> None:
            """Background worker: refresh detail panes and the table for a new selection.

            ``my_gen`` is the value of ``_gen[0]`` when the selection was made; the
            work stops early once a newer request bumps the counter. Failures are
            logged and the export buttons are still re-enabled, so the UI never
            stays stuck in the loading state.
            """

            def _start_loading():
                _progress_bar.value = 0
                export_btn.name = "Loading selected data..."

            def _finish_loading():
                export_btn.name = "Export table to selected_points"
                export_btn.disabled = False
                export_all_btn.disabled = False

            try:
                _safe_execute(_start_loading)
                if _gen[0] != my_gen:
                    return

                # Render detail panes immediately — indices are available now,
                # before the (slower) table fetch begins.
                try:
                    _page_state["page"] = 0
                    _page_state["indices"] = list(sel.index)
                    _render_page()
                except Exception:
                    # Never let detail-pane errors break the table or status cleanup.
                    logger.warning("Failed to render detail panes for the selection", exc_info=True)

                if _gen[0] != my_gen:
                    return

                try:
                    _update_table(
                        sel,
                        progress_callback=_set_progress,
                        should_abort=lambda: _gen[0] != my_gen,
                    )
                except Exception:
                    logger.error("Failed to load the selection table", exc_info=True)

                _safe_execute(_finish_loading)
            finally:
                _set_progress(0)

        selection_dmap = DynamicMap(selection_overlay, streams=[bounds_stream, lasso_stream])

        # ── Export button ─────────────────────────────────────────────────────
        export_btn = pn.widgets.Button(
            name="Export table to selected_points",
            button_type="primary",
            width=300,
            icon="",
        )

        def _on_export(event):
            """Put the visible table (or the raw selection) into ``selected_points``."""
            sel = self.get_selected_df()
            if sel.empty:
                return
            exported = _table_df[0] if not _table_df[0].empty else sel
            if (_ipy := get_ipython()) is not None:
                _ipy.user_ns["selected_points"] = exported
            export_btn.icon = "check-lg"
            export_btn.name = f"Exported {len(exported):,} rows to selected_points"

        export_btn.on_click(_on_export)

        # ── Export All button ──────────────────────────────────────────
        export_all_btn = pn.widgets.Button(
            name="Export all to selected_points",
            button_type="warning",
            width=300,
            icon="",
            disabled=True,
        )

        def _do_export_all(sel: "pd.DataFrame") -> None:
            """Background worker: fetch every selected row (not just the table cap)."""

            def _done(count):
                export_all_btn.icon = "check-lg"
                export_all_btn.name = f"Exported {count:,} rows to selected_points"
                export_all_btn.disabled = False
                _progress_bar.value = 0

            def _failed():
                export_all_btn.icon = ""
                export_all_btn.name = "Export all to selected_points"
                export_all_btn.disabled = False
                _progress_bar.value = 0

            try:
                active_cols: list[str] = list(col_selector.value)
                if not active_cols:
                    if (_ipy := get_ipython()) is not None:
                        _ipy.user_ns["selected_points"] = pd.DataFrame({"row_index": list(sel.index)})
                    n = len(sel)
                    _safe_execute(lambda count=n: _done(count))
                    return

                _fetch_row = _make_fetcher(active_cols)
                all_indices = list(sel.index)
                total = len(all_indices)

                # Submit futures in chunks to avoid allocating millions of Future objects at once,
                # which would exhaust RAM for large selections.
                _chunk_size = 1000
                rows: list = [None] * total
                for chunk_start in range(0, total, _chunk_size):
                    chunk_end = min(chunk_start + _chunk_size, total)
                    chunk_futures = {
                        _pool.submit(_fetch_row, idx): chunk_start + i
                        for i, idx in enumerate(all_indices[chunk_start:chunk_end])
                    }
                    for future in as_completed(chunk_futures):
                        rows[chunk_futures[future]] = future.result()
                    _set_progress(int(chunk_end / total * 100))

                full_df = pd.DataFrame(rows).reset_index(drop=True)
                if (_ipy := get_ipython()) is not None:
                    _ipy.user_ns["selected_points"] = full_df
                n = len(full_df)
                _safe_execute(lambda count=n: _done(count))
            except Exception:
                logger.error("Export all failed", exc_info=True)
                _safe_execute(_failed)

        def _on_export_all(event):
            sel = self.get_selected_df()
            if sel.empty:
                return
            export_all_btn.disabled = True
            export_all_btn.icon = ""
            export_all_btn.name = f"Fetching all {len(sel):,} rows..."
            _progress_bar.value = 0
            threading.Thread(target=_do_export_all, args=(sel,), daemon=True).start()

        export_all_btn.on_click(_on_export_all)

        # ── Layout ────────────────────────────────────────────────────────────
        combined = (plot * selection_dmap).opts(
            hv.opts.Rectangles(apply_ranges=False),
            hv.opts.Polygons(apply_ranges=False),
        )

        _col_selector_col = pn.Column(
            col_selector_title,
            *([_fields_alert] if _fields_alert is not None else []),
            pn.Column(
                col_selector,
                height=_table_height // 2,
                scroll=True,
            ),
            width=280,
        )

        _table_col = pn.Column(
            table_title,
            selection_table,
            pn.Spacer(height=6),
            pn.Row(export_btn, pn.Spacer(width=10), export_all_btn),
            sizing_mode="stretch_width",
        )

        pane = pn.Column(
            combined,
            _cmap_slider,
            _progress_bar,
            pn.Row(
                _col_selector_col,
                pn.Spacer(width=20),
                _table_col,
                width=_total_width,
            ),
            pn.Spacer(height=10),
            detail_tabs,
            pagination_row,
            _heartbeat,  # invisible; kept in document so keep-alive pings reach the client
        )

        # ── Server-side keep-alive ────────────────────────────────────────────
        # Browsers throttle JS timers in background tabs, which can suppress
        # Bokeh's client-side heartbeat and kill the WebSocket. Driving the ping
        # from the Tornado IOLoop (server side) is immune to that throttling:
        # each tick mutates a Param in the served document, pushing a tiny
        # WebSocket message and resetting the connection's idle clock.
        def _keep_alive():
            _heartbeat.value = (_heartbeat.value + 1) % 1_000_000

        self._keep_alive_cb = pn.state.add_periodic_callback(_keep_alive, period=20_000)

        try:
            from IPython.display import display

            display(pane)
            print("Tip: if the UI stops responding, call viz.restart_ui() in a new cell to reconnect.")
            print("viz = h.visualize_v2()\nviz.restart_ui()")
        except ImportError:
            logger.warning("Couldn't find IPython display environment. Skipping display step.")

    def get_selected_df(self) -> "pd.DataFrame":
        """Return the current selection as a DataFrame.

        The box selection wins when it is non-empty, otherwise the lasso
        selection is returned. An empty DataFrame is returned when nothing is
        selected, including before the UI has been built.
        """
        import pandas as pd

        for sel in (getattr(self, "selected_box", None), getattr(self, "selected_lasso", None)):
            if sel is not None and not sel.empty:
                return sel
        return pd.DataFrame()