Source code for hyrax.datasets.lancedb_dataset

from collections import OrderedDict
from pathlib import Path
from types import MethodType

from hyrax.datasets.dataset_registry import HyraxDataset

[docs] _ROW_CACHE_SIZE = 16
class LanceDBDataset(HyraxDataset):
    """A minimal Hyrax wrapper around a LanceDB table.

    On construction the dataset connects to the LanceDB database at
    ``data_location``, opens a single table, and attaches a
    ``get_<field>(idx)`` accessor to the instance for every column named in
    the table schema.
    """

    def __init__(self, config: dict, data_location: Path | str | None = None):
        """Connect to the database, open the table, and register the getters.

        Parameters
        ----------
        config : dict
            Hyrax configuration. The ``config["data_set"]["LanceDBDataset"]``
            section must contain the keys ``table_name``, ``connect_kwargs``
            and ``open_table_kwargs``.
        data_location : Path | str | None
            Location passed (as a string) to ``lancedb.connect``. Required,
            despite the ``None`` default.

        Raises
        ------
        ValueError
            If ``data_location`` is ``None``.
        ImportError
            If the ``lancedb`` package cannot be imported.
        RuntimeError
            If ``table_name`` is unset and the database does not contain
            exactly one table.
        """
        if data_location is None:
            raise ValueError("A `data_location` must be provided.")

        # Imported here rather than at module level so that merely importing
        # this module does not require `lancedb` to be installed.
        try:
            import lancedb
        except ImportError as err:
            raise ImportError(
                "LanceDBDataset requires the `lancedb` package. Install it with `pip install lancedb`."
            ) from err

        settings = config["data_set"]["LanceDBDataset"]

        self.data_location = str(data_location)
        self.table_name = settings["table_name"]
        # A falsy value (e.g. an unset config entry) means "no extra keyword
        # arguments"; without the fallback the `**` expansion below would
        # raise a TypeError.
        self.connect_kwargs = settings["connect_kwargs"] or {}
        self.open_table_kwargs = settings["open_table_kwargs"] or {}

        self.db = lancedb.connect(self.data_location, **self.connect_kwargs)
        # Replace the configured value with the name that is actually opened
        # (it may have been inferred from the database contents).
        self.table_name = self._resolve_table_name(self.table_name)
        self.table = self.db.open_table(self.table_name, **self.open_table_kwargs)
        self.lance_dataset = self.table.to_lance()

        # Insertion-ordered so the oldest cached row can be evicted first.
        self._row_cache: OrderedDict = OrderedDict()

        # NOTE(review): the getters are registered before the base-class
        # initializer runs; presumably HyraxDataset.__init__ relies on them
        # being present -- confirm before reordering.
        self._register_getters()
        super().__init__(config)

    def _all_available_fields(self) -> list[str]:
        """Return the column names of the opened table's schema."""
        return list(self.table.schema.names)

    def _get_row(self, idx: int):
        """Return the single-row result of ``lance_dataset.take([idx])``.

        A small FIFO cache avoids redundant ``lance_dataset.take`` calls when
        several ``get_<field>`` accessors are invoked for the same sample
        index. The cache holds at most ``_ROW_CACHE_SIZE`` rows; the oldest
        entry is evicted once that limit is reached.

        Parameters
        ----------
        idx : int
            Row index passed to ``lance_dataset.take``.

        Returns
        -------
        The object returned by ``lance_dataset.take`` for the one requested
        row.
        """
        if idx in self._row_cache:
            return self._row_cache[idx]

        # Fetch before evicting: if the lookup raises, the cache is left
        # exactly as it was instead of having lost a valid entry.
        row = self.lance_dataset.take([idx])
        if len(self._row_cache) >= _ROW_CACHE_SIZE:
            self._row_cache.popitem(last=False)
        self._row_cache[idx] = row
        return row

    def _resolve_table_name(self, configured_table_name) -> str:
        """Return the name of the table to open.

        A non-empty string is used as-is. Any other value is treated as
        "unset", in which case the name is inferred, which is only possible
        when the database contains exactly one table.

        Parameters
        ----------
        configured_table_name
            The ``table_name`` value taken from the configuration.

        Returns
        -------
        str
            The configured name, or the name of the database's only table.

        Raises
        ------
        RuntimeError
            If the name is unset and the database does not have exactly one
            table.
        """
        if isinstance(configured_table_name, str) and configured_table_name:
            return configured_table_name

        table_names = self.db.table_names()
        if len(table_names) == 1:
            return table_names[0]

        available_tables = ", ".join(table_names) if len(table_names) > 0 else "(none)"
        raise RuntimeError(
            "LanceDBDataset could not infer a table to open because `table_name` is unset "
            "and the database does not have exactly one table. "
            "Set `config['data_set']['LanceDBDataset']['table_name']` "
            f"to one of: {available_tables}"
        )

    def _register_getters(self) -> None:
        """Attach a ``get_<field>(idx)`` method to this instance per column.

        Each generated getter fetches the row through ``_get_row`` and returns
        the column's first value converted with ``as_py()``. A name that
        already exists on the instance is left untouched, so explicitly
        defined attributes and methods take precedence.
        """

        def _make_getter(field_name: str):
            def getter(self, idx, _field_name=field_name):
                row = self._get_row(int(idx))
                return row[_field_name][0].as_py()

            return getter

        for field_name in self._all_available_fields():
            method_name = f"get_{field_name}"
            if not hasattr(self, method_name):
                # Bound to the instance (not the class) so getters reflect
                # this particular table's schema.
                setattr(self, method_name, MethodType(_make_getter(field_name), self))

    def __len__(self) -> int:
        """Return the number of rows reported by ``table.count_rows()``."""
        return self.table.count_rows()