import logging
import time
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Union
import numpy as np
from .verb_registry import Verb, hyrax_verb
# Module-level logger, named after this module so log records can be traced
# back to (and filtered by) their origin.
logger = logging.getLogger(__name__)
@hyrax_verb
class SaveToDatabase(Verb):
    """Verb to insert inference results into a vector database index for fast
    similarity search."""

    # Name and one-line description exposed as class attributes for the verb.
    cli_name = "save_to_database"
    description = "Insert inference results into vector database."

    @staticmethod
    def setup_parser(parser: ArgumentParser):
        """Register this verb's command line arguments on ``parser``.

        Two optional string arguments are added: ``-i/--input-dir`` and
        ``-o/--output-dir``.

        Parameters
        ----------
        parser : ArgumentParser
            The parser to add the arguments to.
        """
        parser.add_argument(
            "-i",
            "--input-dir",
            type=str,
            required=False,
            help="Directory containing inference results to index.",
        )
        parser.add_argument(
            "-o",
            "--output-dir",
            type=str,
            required=False,
            help="Directory of existing vector database, if adding more vectors.",
        )

    def run_cli(self, args: Namespace | None = None):
        """Run the verb from parsed command line arguments.

        Parameters
        ----------
        args : Namespace, Optional
            Parsed arguments providing ``input_dir`` and ``output_dir``.

        Returns
        -------
        The result of ``run`` called with ``args.input_dir`` and
        ``args.output_dir``.

        Raises
        ------
        RuntimeError
            If ``args`` is None.
        """
        logger.info("Creating vector db index from cli")
        if args is None:
            raise RuntimeError("Run CLI called with no arguments.")

        return self.run(input_dir=args.input_dir, output_dir=args.output_dir)

    def run(self, input_dir: Union[Path, str] | None = None, output_dir: Union[Path, str] | None = None):
        """Insert inference results into vector database.

        The results are read in batches of ``config["data_loader"]["batch_size"]``
        items; every vector is flattened before insertion. Per-batch and total
        insertion timings are written to a tensorboard logger rooted at the
        vector database directory.

        Parameters
        ----------
        input_dir : str or Path, Optional
            The directory containing the inference results. If None, the
            ``infer_results_dir`` value of the ``vector_db`` config section is
            used; if that is unset the most recent ``infer`` results directory
            is looked up.
        output_dir : str or Path, Optional
            The directory where the vector database is stored. If None, the
            ``vector_db_dir`` value of the ``vector_db`` config section is used;
            if that is unset a new directory will be created. If specified, it
            can point to either an empty directory or a directory containing an
            existing vector database. If the latter, the database will be
            updated with the new vectors.

        Raises
        ------
        RuntimeError
            If the configured batch size is not positive, if no inference
            results directory can be determined, if the input or database
            directory does not exist, or if no vector database is configured.
        """
        # These imports are deferred until the verb actually runs.
        # NOTE(review): presumably to keep importing this module cheap - confirm.
        from copy import deepcopy

        from tqdm import tqdm

        from hyrax.config_utils import (
            create_results_dir,
            find_most_recent_results_dir,
            log_runtime_config,
        )
        from hyrax.datasets.result_factories import load_results_dataset
        from hyrax.tensorboardx_logger import (
            close_tensorboard_logger,
            get_tensorboard_logger,
            init_tensorboard_logger,
        )
        from hyrax.vector_dbs.vector_db_factory import vector_db_factory

        # Work on a copy so the directory values written below do not leak into
        # the verb's own config.
        config = deepcopy(self.config)

        # Validate the batch size up front, before any directory or database is
        # created: zero would otherwise surface as a bare ZeroDivisionError and a
        # negative value would silently insert nothing.
        batch_size = config["data_loader"]["batch_size"]
        if batch_size <= 0:
            raise RuntimeError(
                f"Batch size must be a positive integer, got {batch_size}. "
                "Check the ['data_loader']['batch_size'] value of the hyrax config."
            )

        # Attempt to find the directory containing inference results. Check for
        # the --input-dir argument first, then check the config file for
        # vector_db.infer_results_dir, and finally check for the most recent
        # results directory.
        if input_dir is not None:
            infer_results_dir = input_dir
        elif config["vector_db"]["infer_results_dir"]:
            infer_results_dir = config["vector_db"]["infer_results_dir"]
        else:
            infer_results_dir = find_most_recent_results_dir(config, "infer")

        if infer_results_dir is None:
            raise RuntimeError("Must define infer_results_dir in the [vector_db] section of hyrax config.")

        inference_results_path = Path(infer_results_dir).resolve()
        if not inference_results_path.is_dir():
            raise RuntimeError(f"Input directory {inference_results_path} does not exist.")

        # Create an instance of the results dataset (auto-detects Lance vs .npy)
        inference_data_set = load_results_dataset(config, inference_results_path, verb="infer")

        # Get the vector db output directory by using the --output-dir parameter or
        # config value or creating a new directory, in that order.
        if output_dir is not None:
            vector_db_dir = output_dir
        elif config["vector_db"]["vector_db_dir"]:
            vector_db_dir = config["vector_db"]["vector_db_dir"]
        else:
            vector_db_dir = create_results_dir(config, "vector-db")

        vector_db_path = Path(vector_db_dir).resolve()
        if not vector_db_path.is_dir():
            raise RuntimeError(f"Database directory {str(vector_db_path)} does not exist.")

        logger.info(f"Saving vector database at {vector_db_dir}")

        # Create an instance of the vector database to insert into
        vector_db = vector_db_factory(config, context={"results_dir": str(vector_db_path)})
        if not vector_db:
            raise RuntimeError(
                "No vector database configured. "
                "Please specify a supported vector db in the ['vector_db']['name'] "
                "section of the hyrax config."
            )
        vector_db.create()

        # Log the config with updated values for the input and output directories.
        config["vector_db"]["infer_results_dir"] = str(inference_results_path)
        config["vector_db"]["vector_db_dir"] = str(vector_db_path)
        log_runtime_config(config, vector_db_path)

        # Create a tensorboardX logger for metrics
        init_tensorboard_logger(log_dir=vector_db_path)

        total_insertion_time = 0.0
        batch_count = 0

        # From here on the tensorboard logger is open; the finally clause makes
        # sure it is closed even if reading the dataset or inserting fails.
        try:
            tensorboardx_logger = get_tensorboard_logger()

            # Process data in batches
            total_items = len(inference_data_set)
            num_batches = int(np.ceil(total_items / batch_size))
            logger.debug(f"Number of inference result batches to index: {num_batches}.")

            for batch_idx in tqdm(range(num_batches)):
                start_idx = batch_idx * batch_size
                # The last batch may be shorter than batch_size.
                end_idx = min(start_idx + batch_size, total_items)
                indices = np.arange(start_idx, end_idx)

                # Get the vectors and ids for this batch
                vectors = []
                ids = []
                for idx in indices:
                    vector = np.asarray(inference_data_set[idx]).flatten()
                    vectors.append(vector)
                    ids.append(inference_data_set.get_object_id(idx))

                # Time the vector database insertion. perf_counter is monotonic,
                # so the measured interval cannot be skewed by clock adjustments.
                start_time = time.perf_counter()
                vector_db.insert(ids=ids, vectors=vectors)
                insertion_time = time.perf_counter() - start_time

                # Log insertion metrics to Tensorboard
                batch_count += 1
                total_insertion_time += insertion_time
                vectors_inserted = len(vectors)
                # Guard against a zero-length interval on very fast inserts.
                rate = vectors_inserted / insertion_time if insertion_time > 0 else 0

                tensorboardx_logger.add_scalar("vector_db/batch_insertion_time", insertion_time, batch_count)
                tensorboardx_logger.add_scalar("vector_db/vectors_per_batch", vectors_inserted, batch_count)
                tensorboardx_logger.add_scalar(
                    "vector_db/insertion_rate_vectors_per_second", rate, batch_count
                )

                logger.debug(
                    f"Batch {batch_idx}: Inserted {vectors_inserted} vectors in {insertion_time:.3f}s "
                    f"({rate:.1f} vectors/sec)"
                )

            # Log total insertion metrics
            avg_time = total_insertion_time / batch_count if batch_count > 0 else 0
            tensorboardx_logger.add_scalar("vector_db/total_insertion_time", total_insertion_time, 1)
            tensorboardx_logger.add_scalar("vector_db/total_batches", batch_count, 1)
            tensorboardx_logger.add_scalar("vector_db/average_batch_insertion_time", avg_time, 1)
        finally:
            # Close the tensorboard logger
            close_tensorboard_logger()

        logger.info(
            f"Vector database insertion complete. Total time: {total_insertion_time:.3f}s "
            f"for {batch_count} batches"
        )