Source code for hyrax.verbs.save_to_database

import logging
import time
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Union

import numpy as np
from tensorboardX import SummaryWriter

from .verb_registry import Verb, hyrax_verb

[docs] logger = logging.getLogger(__name__)
[docs] @hyrax_verb
[docs] class SaveToDatabase(Verb): """Verb to insert inference results into a vector database index for fast similarity search."""
[docs] cli_name = "save_to_database"
[docs] add_parser_kwargs = {}
@staticmethod
[docs] def setup_parser(parser: ArgumentParser): """Stub of parser setup""" parser.add_argument( "-i", "--input-dir", type=str, required=False, help="Directory containing inference results to index.", ) parser.add_argument( "-o", "--output-dir", type=str, required=False, help="Directory of existing vector database, if adding more vectors.", )
[docs] def run_cli(self, args: Namespace | None = None): """Stub CLI implementation""" logger.info("Creating vector db index from cli") if args is None: raise RuntimeError("Run CLI called with no arguments.") return self.run(input_dir=args.input_dir, output_dir=args.output_dir)
[docs] def run(self, input_dir: Union[Path, str] | None = None, output_dir: Union[Path, str] | None = None): """Insert inference results into vector database. Parameters ---------- input_dir : str or Path, Optional The directory containing the inference results. output_dir : str or Path, Optional The directory where the vector database is stored. If None, a new directory will be created. If specified, it can point to either an empty directory or a directory containing an existing vector database. If the latter, the database will be updated with the new vectors. """ from copy import deepcopy from tqdm import tqdm from hyrax.config_utils import ( create_results_dir, find_most_recent_results_dir, log_runtime_config, ) from hyrax.data_sets.inference_dataset import InferenceDataSet from hyrax.vector_dbs.vector_db_factory import vector_db_factory config = deepcopy(self.config) # Attempt to find the directory containing inference results. Check for # the --input-dir argument first, then check the config file for # vector_db.infer_results_dir, and finally check for the most recent # results directory. infer_results_dir = None if input_dir is not None: infer_results_dir = input_dir elif config["vector_db"]["infer_results_dir"]: infer_results_dir = config["vector_db"]["infer_results_dir"] else: infer_results_dir = find_most_recent_results_dir(config, "infer") if infer_results_dir is None: raise RuntimeError("Must define infer_results_dir in the [vector_db] section of hyrax config.") inference_results_path = Path(infer_results_dir).resolve() if not inference_results_path.is_dir(): raise RuntimeError(f"Input directory {inference_results_path} does not exist.") # Create an instance of the InferenceDataSet inference_data_set = InferenceDataSet(config, inference_results_path) # Get the vector db output directory by using the --output-dir parameter or # config value or creating a new directory, in that order. vector_db_dir = Path() if output_dir is not None: vector_db_dir = output_dir elif config["vector_db"]["vector_db_dir"]: vector_db_dir = config["vector_db"]["vector_db_dir"] else: vector_db_dir = create_results_dir(config, "vector-db") vector_db_path = Path(vector_db_dir).resolve() if not vector_db_path.is_dir(): raise RuntimeError(f"Database directory {str(vector_db_path)} does not exist.") logger.info(f"Saving vector database at {vector_db_dir}") # Create an instance of the vector database to insert into vector_db = vector_db_factory(config, context={"results_dir": str(vector_db_path)}) if vector_db: vector_db.create() else: raise RuntimeError( "No vector database configured. " "Please specify a supported vector db in the ['vector_db']['name'] " "section of the hyrax config." ) # Log the config with updated values for the input and output directories. config["vector_db"]["infer_results_dir"] = str(inference_results_path) config["vector_db"]["vector_db_dir"] = str(vector_db_path) log_runtime_config(config, vector_db_path) # Create a tensorboardX logger for metrics tensorboardx_logger = SummaryWriter(log_dir=vector_db_path) # Use the batch_index to get the list of batches. batches = np.unique(inference_data_set.batch_index["batch_num"]) logger.debug(f"Number of inference result batches to index: {len(batches)}.") total_insertion_time = 0.0 batch_count = 0 for batch in tqdm(batches): # Get all the indices where inference_data_set.batch_index['batch_num'] == batch index_mask = inference_data_set.batch_index["batch_num"] == batch # Get the ids of the data in this batch file ids = inference_data_set.batch_index["id"][index_mask] # Retrieve the vectors from the batch file using the ids. We use the # ids here so that we only have to open one file to get the vectors. inference_data = inference_data_set._load_from_batch_file(batch, ids) # Flatten the vectors and turn them into a list of np.arrays. vectors = list(inference_data["tensor"].reshape(len(inference_data["tensor"]), -1)) # Time the vector database insertion start_time = time.time() vector_db.insert(ids=list(inference_data["id"]), vectors=vectors) insertion_time = time.time() - start_time # Log insertion metrics to Tensorboard batch_count += 1 total_insertion_time += insertion_time vectors_inserted = len(vectors) tensorboardx_logger.add_scalar("vector_db/batch_insertion_time", insertion_time, batch_count) tensorboardx_logger.add_scalar("vector_db/vectors_per_batch", vectors_inserted, batch_count) rate = vectors_inserted / insertion_time if insertion_time > 0 else 0 tensorboardx_logger.add_scalar("vector_db/insertion_rate_vectors_per_second", rate, batch_count) logger.debug( f"Batch {batch}: Inserted {vectors_inserted} vectors in {insertion_time:.3f}s " f"({vectors_inserted / insertion_time:.1f} vectors/sec)" ) # Log total insertion metrics tensorboardx_logger.add_scalar("vector_db/total_insertion_time", total_insertion_time, 1) tensorboardx_logger.add_scalar("vector_db/total_batches", batch_count, 1) avg_time = total_insertion_time / batch_count if batch_count > 0 else 0 tensorboardx_logger.add_scalar("vector_db/average_batch_insertion_time", avg_time, 1) # Close the tensorboard logger tensorboardx_logger.close() logger.info( f"Vector database insertion complete. Total time: {total_insertion_time:.3f}s " f"for {batch_count} batches" )