diff --git a/autorag/vectordb/__init__.py b/autorag/vectordb/__init__.py
index ea8e569d..e26fc639 100644
--- a/autorag/vectordb/__init__.py
+++ b/autorag/vectordb/__init__.py
@@ -18,6 +18,8 @@ def get_support_vectordb(vectordb_name: str):
 		"Pinecone": ("autorag.vectordb.pinecone", "Pinecone"),
 		"couchbase": ("autorag.vectordb.couchbase", "Couchbase"),
 		"Couchbase": ("autorag.vectordb.couchbase", "Couchbase"),
+		"qdrant": ("autorag.vectordb.qdrant", "Qdrant"),
+		"Qdrant": ("autorag.vectordb.qdrant", "Qdrant"),
 	}
 	return dynamically_find_function(vectordb_name, support_vectordb)
diff --git a/autorag/vectordb/qdrant.py b/autorag/vectordb/qdrant.py
new file mode 100644
index 00000000..fa80d100
--- /dev/null
+++ b/autorag/vectordb/qdrant.py
@@ -0,0 +1,153 @@
+import logging
+
+from qdrant_client import QdrantClient
+from qdrant_client.models import (
+	Distance,
+	VectorParams,
+	PointStruct,
+	PointIdsList,
+	HasIdCondition,
+	Filter,
+	SearchRequest,
+)
+
+from typing import List, Tuple
+
+from autorag.vectordb import BaseVectorStore
+
+logger = logging.getLogger("AutoRAG")
+
+
+class Qdrant(BaseVectorStore):
+	def __init__(
+		self,
+		embedding_model: str,
+		collection_name: str,
+		embedding_batch: int = 100,
+		similarity_metric: str = "cosine",
+		client_type: str = "docker",
+		url: str = "http://localhost:6333",
+		host: str = "",
+		api_key: str = "",
+		dimension: int = 1536,
+		ingest_batch: int = 64,
+		parallel: int = 1,
+		max_retries: int = 3,
+	):
+		super().__init__(embedding_model, similarity_metric, embedding_batch)
+
+		self.collection_name = collection_name
+		self.ingest_batch = ingest_batch
+		self.parallel = parallel
+		self.max_retries = max_retries
+
+		if similarity_metric == "cosine":
+			distance = Distance.COSINE
+		elif similarity_metric == "ip":
+			distance = Distance.DOT
+		elif similarity_metric == "l2":
+			distance = Distance.EUCLID
+		else:
+			raise ValueError(
+				f"similarity_metric {similarity_metric} is not supported\n"
+				"supported similarity metrics are: cosine, ip, l2"
+			)
+
+		if client_type == "docker":
+			self.client = QdrantClient(
+				url=url,
+			)
+		elif client_type == "cloud":
+			self.client = QdrantClient(
+				host=host,
+				api_key=api_key,
+			)
+		else:
+			raise ValueError(
+				f"client_type {client_type} is not supported\n"
+				"supported client types are: docker, cloud"
+			)
+
+		if not self.client.collection_exists(collection_name):
+			self.client.create_collection(
+				collection_name,
+				vectors_config=VectorParams(
+					size=dimension,
+					distance=distance,
+				),
+			)
+		self.collection = self.client.get_collection(collection_name)
+
+	async def add(self, ids: List[str], texts: List[str]):
+		texts = self.truncated_inputs(texts)
+		text_embeddings = await self.embedding.aget_text_embedding_batch(texts)
+
+		points = list(
+			map(lambda x: PointStruct(id=x[0], vector=x[1]), zip(ids, text_embeddings))
+		)
+
+		self.client.upload_points(
+			collection_name=self.collection_name,
+			points=points,
+			batch_size=self.ingest_batch,
+			parallel=self.parallel,
+			max_retries=self.max_retries,
+			wait=True,
+		)
+
+	async def fetch(self, ids: List[str]) -> List[List[float]]:
+		# Fetch vectors by IDs
+		fetched_results = self.client.retrieve(
+			collection_name=self.collection_name,
+			ids=ids,
+			with_vectors=True,
+		)
+		return list(map(lambda x: x.vector, fetched_results))
+
+	async def is_exist(self, ids: List[str]) -> List[bool]:
+		existed_result = self.client.scroll(
+			collection_name=self.collection_name,
+			scroll_filter=Filter(
+				must=[
+					HasIdCondition(has_id=ids),
+				],
+			),
+		)
+		# existed_result is a tuple, so we use existed_result[0] to get the list of Records
+		existed_ids = list(map(lambda x: x.id, existed_result[0]))
+		return list(map(lambda x: x in existed_ids, ids))
+
+	async def query(
+		self, queries: List[str], top_k: int, **kwargs
+	) -> Tuple[List[List[str]], List[List[float]]]:
+		queries = self.truncated_inputs(queries)
+		query_embeddings: List[
+			List[float]
+		] = await self.embedding.aget_text_embedding_batch(queries)
+
+		search_queries = list(
+			map(
+				lambda x: SearchRequest(vector=x, limit=top_k, with_vector=True),
+				query_embeddings,
+			)
+		)
+
+		search_result = self.client.search_batch(
+			collection_name=self.collection_name, requests=search_queries
+		)
+
+		# Extract IDs and similarity scores
+		ids = [[str(hit.id) for hit in result] for result in search_result]
+		scores = [[hit.score for hit in result] for result in search_result]
+
+		return ids, scores
+
+	async def delete(self, ids: List[str]):
+		self.client.delete(
+			collection_name=self.collection_name,
+			points_selector=PointIdsList(points=ids),
+		)
+
+	def delete_collection(self):
+		# Delete the collection
+		self.client.delete_collection(self.collection_name)
diff --git a/docs/source/api_spec/autorag.vectordb.rst b/docs/source/api_spec/autorag.vectordb.rst
index a5a5b760..ef8b6a07 100644
--- a/docs/source/api_spec/autorag.vectordb.rst
+++ b/docs/source/api_spec/autorag.vectordb.rst
@@ -20,6 +20,14 @@ autorag.vectordb.chroma module
    :undoc-members:
    :show-inheritance:
 
+autorag.vectordb.couchbase module
+---------------------------------
+
+.. automodule:: autorag.vectordb.couchbase
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 autorag.vectordb.milvus module
 ------------------------------
 
@@ -36,6 +44,14 @@ autorag.vectordb.pinecone module
    :undoc-members:
    :show-inheritance:
 
+autorag.vectordb.qdrant module
+------------------------------
+
+.. automodule:: autorag.vectordb.qdrant
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 autorag.vectordb.weaviate module
 --------------------------------
diff --git a/docs/source/integration/vectordb/qdrant.md b/docs/source/integration/vectordb/qdrant.md
new file mode 100644
index 00000000..62af5800
--- /dev/null
+++ b/docs/source/integration/vectordb/qdrant.md
@@ -0,0 +1,127 @@
+# Qdrant
+
+Qdrant is a high-performance vector similarity search engine and database.
+It offers a robust, production-ready service with an intuitive API that allows users to store, search, and manage vectors, along with additional payloads.
+
+Qdrant supports advanced filtering, making it ideal for applications involving neural network or semantic-based matching, faceted search, and more.
+Its capabilities are particularly beneficial for developing applications that require efficient and scalable vector search solutions.
+
+## Configuration
+
+To use the Qdrant vector database, you need to configure it in your YAML configuration file. Here's an example configuration:
+
+```yaml
+- name: openai_embed_3_large
+  db_type: qdrant
+  embedding_model: openai_embed_3_large
+  collection_name: openai_embed_3_large
+  client_type: docker
+  embedding_batch: 50
+  similarity_metric: cosine
+  dimension: 1536
+```
+
+1. `embedding_model: str`
+   - Purpose: Specifies the name or identifier of the embedding model to be used.
+   - Example: "openai_embed_3_large"
+   - Note: This should correspond to a valid embedding model that your system can use to generate vector embeddings. For more information, see the [customize your embedding model](https://docs.auto-rag.com/local_model.html#configure-the-embedding-model) documentation.
+
+2. `collection_name: str`
+   - Purpose: Sets the name of the Qdrant collection where the vectors will be stored.
+   - Example: "my_vector_collection"
+   - Note: If the collection doesn't exist, it will be created. If it exists, it will be loaded.
+
+3. `embedding_batch: int = 100`
+   - Purpose: Determines the number of embeddings to process in a single batch.
+   - Default: 100
+   - Note: Adjust this based on your system's memory and processing capabilities. Larger batches may be faster but require more memory.
+
+4. `similarity_metric: str = "cosine"`
+   - Purpose: Specifies the metric used to calculate similarity between vectors.
+   - Default: "cosine"
+   - Options: "cosine", "l2" (Euclidean distance), "ip" (Inner Product)
+   - Note: Choose the metric that best suits your use case and data characteristics. "manhattan" is not supported.
+
+5. `client_type: str = "docker"`
+   - Purpose: Specifies the type of client you're using to connect to Qdrant.
+   - Default: "docker"
+   - Options: "docker", "cloud"
+   - Note: Choose the appropriate client type based on your deployment.
+     - [docker quick start](https://qdrant.tech/documentation/quickstart/)
+     - [cloud quick start](https://qdrant.tech/documentation/quickstart-cloud/)
+
+6. `url: str = "http://localhost:6333"`
+   - Purpose: The URL of the Qdrant server.
+   - Default: "http://localhost:6333"
+   - Note: Used only when `client_type: docker`. See the [docker quick start](https://qdrant.tech/documentation/quickstart/) for details.
+
+7. `host: str`
+   - Purpose: The host of the Qdrant server.
+   - Default: ""
+   - Note: Used only when `client_type: cloud`. See the [cloud quick start](https://qdrant.tech/documentation/quickstart-cloud/) for details.
+
+8. `api_key: str`
+   - Purpose: The API key for authentication with the Qdrant server.
+   - Default: ""
+   - Note: Used only when `client_type: cloud`. See the [cloud quick start](https://qdrant.tech/documentation/quickstart-cloud/) for details.
+
+9. `dimension: int = 1536`
+   - Purpose: Specifies the dimension of the vector embeddings.
+   - Default: 1536
+   - Note: This should correspond to the dimension of the embeddings generated by the specified embedding model.
+
+10. `ingest_batch: int = 64`
+    - Purpose: Determines the number of vectors to ingest in a single batch.
+    - Default: 64
+    - Note: Adjust this based on your system's memory and processing capabilities. Larger batches may be faster but require more memory.
+
+11. `parallel: int = 1`
+    - Purpose: Determines the number of parallel requests to the Qdrant server during ingestion.
+    - Default: 1
+    - Note: Adjust this based on your system's processing capabilities. Increasing parallel requests can improve performance.
+
+12. `max_retries: int = 3`
+    - Purpose: Specifies the maximum number of retries for failed requests to the Qdrant server.
+    - Default: 3
+    - Note: Set this based on your network reliability and the expected failure rate.
+
+#### Usage
+
+Here's a brief overview of how to use the main functions of the Qdrant vector database:
+
+1. **Adding Vectors**:
+   ```python
+   await qdrant_db.add(ids, texts)
+   ```
+   This method adds new vectors to the database. It takes a list of IDs and corresponding texts, generates embeddings, and inserts them into the Qdrant collection.
+
+2. **Querying**:
+   ```python
+   ids, scores = await qdrant_db.query(queries, top_k)
+   ```
+   Performs a similarity search on the stored vectors. It returns the IDs of the most similar vectors and their similarity scores.
+
+3. **Fetching Vectors**:
+   ```python
+   vectors = await qdrant_db.fetch(ids)
+   ```
+   Retrieves the vectors associated with the given IDs.
+
+4. **Checking Existence**:
+   ```python
+   exists = await qdrant_db.is_exist(ids)
+   ```
+   Checks whether the given IDs exist in the database.
+
+5. **Deleting Vectors**:
+   ```python
+   await qdrant_db.delete(ids)
+   ```
+   Deletes the vectors associated with the given IDs from the database.
+
+6. **Deleting the Collection**:
+   ```python
+   qdrant_db.delete_collection()
+   ```
+   Deletes the collection from the Qdrant server.
diff --git a/docs/source/integration/vectordb/vectordb.md b/docs/source/integration/vectordb/vectordb.md
index f662a8f6..7046d0e2 100644
--- a/docs/source/integration/vectordb/vectordb.md
+++ b/docs/source/integration/vectordb/vectordb.md
@@ -103,4 +103,5 @@ milvus.md
 weaviate.md
 pinecone.md
 couchbase.md
+qdrant.md
 ```
diff --git a/requirements.txt b/requirements.txt
index c00f1649..3d67058a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,6 +28,7 @@ chromadb>=0.5.0 # for chroma vectordb
 weaviate-client # for weaviate vectordb
 pinecone[grpc] # for pinecone vectordb
 couchbase # for couchbase vectordb
+qdrant-client # for qdrant vectordb
 
 ### API server ###
 quart
diff --git a/tests/autorag/vectordb/test_couchbase.py b/tests/autorag/vectordb/test_couchbase.py
index 5e4e2b53..a41bfebe 100644
--- a/tests/autorag/vectordb/test_couchbase.py
+++ b/tests/autorag/vectordb/test_couchbase.py
@@ -28,7 +28,7 @@ def couchbase_instance():
 
 @pytest.mark.skipif(
 	is_github_action(),
-	reason="This test needs pinecone api key",
+	reason="This test needs couchbase connection string, username, and password",
 )
 @pytest.mark.asyncio
 async def test_add_and_query_documents(couchbase_instance):
@@ -61,7 +61,7 @@ async def test_add_and_query_documents(couchbase_instance):
 
 @pytest.mark.skipif(
 	is_github_action(),
-	reason="This test needs pinecone api key",
+	reason="This test needs couchbase connection string, username, and password",
 )
 @pytest.mark.asyncio
 async def test_delete_documents(couchbase_instance):
diff --git a/tests/autorag/vectordb/test_qdrant.py b/tests/autorag/vectordb/test_qdrant.py
new file mode 100644
index 00000000..7770610f
--- /dev/null
+++ b/tests/autorag/vectordb/test_qdrant.py
@@ -0,0 +1,81 @@
+import asyncio
+import os
+import uuid
+
+import pytest
+
+from autorag.vectordb.qdrant import Qdrant
+from tests.delete_tests import is_github_action
+
+
+@pytest.mark.skipif(
+	is_github_action(),
+	reason="This test needs qdrant docker server.",
+)
+@pytest.fixture
+def qdrant_instance():
+	qdrant = Qdrant(
+		embedding_model="mock",
+		collection_name="autorag_t",
+		client_type="docker",
+		dimension=768,
+	)
+	yield qdrant
+	qdrant.delete_collection()
+
+
+@pytest.mark.skipif(
+	is_github_action(),
+	reason="This test needs qdrant docker server.",
+)
+@pytest.mark.asyncio
+async def test_add_and_query_documents(qdrant_instance):
+	# Add documents
+	ids = [str(uuid.uuid4()) for _ in range(2)]
+	texts = ["This is a test document.", "This is another test document."]
+	await qdrant_instance.add(ids, texts)
+
+	await asyncio.sleep(1)
+
+	# Query documents
+	queries = ["test document"]
+	contents, scores = await qdrant_instance.query(queries, top_k=2)
+
+	assert len(contents) == 1
+	assert len(scores) == 1
+	assert len(contents[0]) == 2
+	assert len(scores[0]) == 2
+	assert scores[0][0] > scores[0][1]
+
+	embeddings = await qdrant_instance.fetch([ids[0]])
+	assert len(embeddings) == 1
+	assert len(embeddings[0]) == 768
+
+	exist = await qdrant_instance.is_exist([ids[0], str(uuid.uuid4())])
+	assert len(exist) == 2
+	assert exist[0] is True
+	assert exist[1] is False
+
+
+@pytest.mark.skipif(
+	is_github_action(),
+	reason="This test needs qdrant docker server.",
+)
+@pytest.mark.asyncio
+async def test_delete_documents(qdrant_instance):
+	# Add documents
+	ids = [str(uuid.uuid4()) for _ in range(2)]
+	texts = ["This is a test document.", "This is another test document."]
+	await qdrant_instance.add(ids, texts)
+
+	await asyncio.sleep(1)
+
+	# Delete documents
+	await qdrant_instance.delete([ids[0]])
+
+	# Query documents to ensure they are deleted
+	queries = ["test document"]
+	contents, scores = await qdrant_instance.query(queries, top_k=2)
+
+	assert len(contents[0]) == 1
+	assert len(scores[0]) == 1
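
A minimal local smoke-test sketch of the new `Qdrant` vector store, assuming a Qdrant server reachable at the default `http://localhost:6333` (for example, started with `docker run -p 6333:6333 qdrant/qdrant`) and reusing the `mock` embedding model and 768-dimension setup from `test_qdrant.py`; the collection name `qdrant_demo` is a hypothetical placeholder:

```python
# Smoke test for the Qdrant vector store added in this diff.
# Assumes a local Qdrant Docker instance at http://localhost:6333
# and the "mock" embedding model (768-dimensional), as used in the tests.
import asyncio
import uuid

from autorag.vectordb.qdrant import Qdrant


async def main():
	qdrant = Qdrant(
		embedding_model="mock",          # mock embeddings, 768 dimensions
		collection_name="qdrant_demo",   # hypothetical name; created on first use
		client_type="docker",
		dimension=768,
	)

	# Add two documents with UUID string ids.
	ids = [str(uuid.uuid4()) for _ in range(2)]
	texts = ["This is a test document.", "This is another test document."]
	await qdrant.add(ids, texts)

	# Query returns parallel lists of ids and similarity scores per query.
	result_ids, scores = await qdrant.query(["test document"], top_k=2)
	print(result_ids, scores)

	# Clean up the demo collection.
	qdrant.delete_collection()


if __name__ == "__main__":
	asyncio.run(main())
```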