Source code for hyrax.vector_dbs.qdrantdb_impl

import os
import uuid
from typing import Union

import numpy as np
from qdrant_client import QdrantClient, models

from hyrax.vector_dbs.vector_db_interface import VectorDB


class QdrantDB(VectorDB):
    """Implementation of the VectorDB interface using Qdrant as the backend.

    The database is stored on local disk under ``context["results_dir"]`` and
    this class works with a single collection at a time. Caller-supplied ids are
    mapped to deterministic UUIDs (see ``_convert_id_to_uuid``), and the original
    id is kept in each point's payload under the ``"id"`` key so that search
    results are reported in the caller's id space.
    """

    def __init__(self, config, context):
        super().__init__(config, context)
        # Created lazily by ``connect``; public methods connect on first use.
        self.client = None
        # Number of points in the active collection; refreshed by create/insert.
        self.collection_size = 0

    def _convert_id_to_uuid(self, id: Union[str, int]) -> str:
        """Convert an id to a deterministic UUID URN string.

        The id is stringified and hashed with ``uuid5`` in the OID namespace, so
        e.g. ``5`` and ``"5"`` map to the same UUID.
        """
        return uuid.uuid5(uuid.NAMESPACE_OID, str(id)).urn

    def connect(self):
        """Connect to the on-disk Qdrant database and return the client.

        The database is stored in ``self.context["results_dir"]``. If it already
        contains collections, the first one becomes the active collection;
        otherwise ``self.collection_name`` is set to ``None`` (``create`` is then
        expected to be called before inserting).

        Returns
        -------
        QdrantClient
            The connected client, also stored on ``self.client``.
        """
        results_dir = self.context["results_dir"]
        self.client = QdrantClient(path=results_dir)
        collections = self.client.get_collections().collections
        self.collection_name = collections[0].name if collections else None
        return self.client

    def create(self):
        """Create a collection in the Qdrant database and make it the active one.

        If a collection with the derived name already exists it is reused.

        Returns
        -------
        str
            The name of the active collection.

        Raises
        ------
        RuntimeError
            If Qdrant reports that creating a new collection failed.
        """
        if self.client is None:
            self.connect()

        # We count the existing collections only to derive a unique name; following
        # Qdrant's advice, the database is otherwise restricted to one collection.
        # https://qdrant.tech/documentation/concepts/collections/#setting-up-multitenancy
        self.collection_index = len(self.client.get_collections().collections)
        self.collection_name = f"shard_{self.collection_index}"

        if not self.client.collection_exists(self.collection_name):
            # Qdrant's internal "shard" differs from Hyrax's (ChromaDB) usage of the
            # word. Qdrant docs advocate starting with ~12; we use 2 per CPU core.
            # os.cpu_count() may return None, so fall back to a single core.
            # https://qdrant.tech/documentation/guides/distributed_deployment/#choosing-the-right-number-of-shards
            created_collection = self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    # TODO: infer the size from the data passed to ``insert`` instead
                    # of requiring it in the config.
                    size=self.config["vector_db"]["qdrant"]["vector_size"],
                    distance=models.Distance.EUCLID,
                    on_disk=True,
                ),
                shard_number=(os.cpu_count() or 1) * 2,
            )
            # Only a failed creation is an error; reusing an existing collection is not.
            if not created_collection:
                raise RuntimeError(f"Failed to create collection {self.collection_name} in Qdrant.")

        # ``count`` returns a CountResult; store the integer, as ``insert`` does.
        self.collection_size = self.client.count(collection_name=self.collection_name, exact=True).count
        return self.collection_name

    def insert(self, ids: list[Union[str, int]], vectors: list[np.ndarray]):
        """Insert (upsert) a batch of vectors into the active Qdrant collection.

        Parameters
        ----------
        ids : list[Union[str, int]]
            The ids to associate with the vectors. Because the point UUID is
            derived from the id, upserting an id that is already stored replaces
            its point.
        vectors : list[np.ndarray]
            The vectors to insert; each must have length
            ``config["vector_db"]["qdrant"]["vector_size"]``.

        Returns
        -------
        int
            The number of points in the collection after the insert.

        Raises
        ------
        ValueError
            If ``ids`` and ``vectors`` differ in length, or a vector has the wrong size.
        """
        if self.client is None:
            self.connect()

        if len(ids) != len(vectors):
            raise ValueError(
                f"Got {len(ids)} ids but {len(vectors)} vectors; each vector needs exactly one id."
            )

        expected_size = self.config["vector_db"]["qdrant"]["vector_size"]
        for idx, vector in enumerate(vectors):
            if len(vector) != expected_size:
                raise ValueError(
                    f"Vector at index {idx} has size {len(vector)}, but expected size is {expected_size}."
                )

        self.client.upsert(
            collection_name=self.collection_name,
            points=models.Batch(
                ids=[self._convert_id_to_uuid(i) for i in ids],
                # Plain lists of Python floats serialize cleanly for any numpy dtype.
                vectors=[np.asarray(vector).tolist() for vector in vectors],
                # Keep the caller's original id so searches can report it back.
                payloads=[{"id": original_id} for original_id in ids],
            ),
        )

        self.collection_size = self.client.count(collection_name=self.collection_name, exact=True).count
        return self.collection_size

    def search_by_id(self, id: Union[str, int], k: int = 1) -> dict[Union[str, int], list[Union[str, int]]]:
        """Get the ids of the k nearest neighbors for a given id in the database.

        The stored vector for ``id`` is retrieved first and then used as a vector
        query, so the point itself is typically returned as its own nearest
        neighbor.

        Parameters
        ----------
        id : Union[str, int]
            The id of the vector in the database for which we want to find the
            k nearest neighbors.
        k : int, optional
            The number of nearest neighbors to return, by default 1 (return only
            the closest neighbor).

        Returns
        -------
        dict[Union[str, int], list[Union[str, int]]]
            Dictionary with the input id as the key and the ids of the k nearest
            neighbors as the value.

        Raises
        ------
        KeyError
            If no vector is stored for ``id``.
        """
        if self.client is None:
            self.connect()

        stored = self.get_by_id(id)
        if not stored:
            raise KeyError(f"No vector with id {id!r} found in collection {self.collection_name}.")

        # One id maps to one UUID, so at most one vector comes back. Take it by
        # position rather than by key: the stored payload id may be "5" while the
        # caller passed 5 (both hash to the same UUID).
        query_vector = next(iter(stored.values()))
        return {id: self._query_by_vector(query_vector, k)}

    def search_by_vector(
        self, vectors: Union[np.ndarray, list[np.ndarray]], k: int = 1
    ) -> dict[int, list[Union[str, int]]]:
        """Get the ids of the k nearest neighbors for one or more vectors.

        Parameters
        ----------
        vectors : Union[np.ndarray, list[np.ndarray]]
            A single 1-D vector, a 2-D array whose rows are query vectors, or a
            list of vectors.
        k : int, optional
            The number of nearest neighbors to return, by default 1 (return only
            the closest neighbor).

        Returns
        -------
        dict[int, list[Union[str, int]]]
            Dictionary with the query's position in ``vectors`` as the key and the
            ids of its k nearest neighbors as the value.
        """
        if self.client is None:
            self.connect()

        if isinstance(vectors, np.ndarray):
            # A 1-D array is one query; a higher-rank array is a batch of rows.
            vectors = [vectors] if vectors.ndim == 1 else list(vectors)

        return {i: self._query_by_vector(v, k) for i, v in enumerate(vectors)}

    def _query_by_vector(self, vector: np.ndarray, k: int = 1) -> list[Union[str, int]]:
        """Return the payload ids of the k points nearest to ``vector``.

        Ids are listed in the order Qdrant returns the matching points.
        """
        query_results = self.client.query_points(
            collection_name=self.collection_name,
            # Convert to a plain list so numpy arrays of any float dtype are accepted.
            query=models.NearestQuery(nearest=np.asarray(vector).tolist()),
            search_params=models.SearchParams(),
            limit=k,
        )
        return [point.payload["id"] for point in query_results.points]

    def get_by_id(
        self, ids: Union[list[Union[str, int]], str, int]
    ) -> dict[Union[str, int], list[float]]:
        """Retrieve the vectors associated with one or more ids.

        Parameters
        ----------
        ids : Union[list[Union[str, int]], str, int]
            The ids of the vectors to retrieve. A single id, a list, a tuple, or
            a numpy array of ids is accepted.

        Returns
        -------
        dict[Union[str, int], list[float]]
            Dictionary keyed by the id stored in each point's payload, with the
            vectors as the values. Ids that have no stored point are expected to
            be absent from the result.
        """
        if self.client is None:
            self.connect()

        if isinstance(ids, np.ndarray):
            ids = ids.ravel().tolist()
        elif isinstance(ids, tuple):
            ids = list(ids)
        elif not isinstance(ids, list):
            ids = [ids]

        points = self.client.retrieve(
            collection_name=self.collection_name,
            ids=[self._convert_id_to_uuid(i) for i in ids],
            with_vectors=True,
            with_payload=True,
        )

        return {point.payload["id"]: point.vector for point in points}