From a8d7fb77f2d476f874da2a701ac4176dc341d727 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 16 May 2024 13:34:37 +0330 Subject: [PATCH] feat: Added qdrant support for BaseEngine! --- requirements.txt | 2 +- utils/query_engine/base_engine.py | 43 +++++++++++-------------------- 2 files changed, 16 insertions(+), 29 deletions(-) diff --git a/requirements.txt b/requirements.txt index d0cb913..6bab684 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ neo4j>=5.14.1, <6.0.0 coverage>=7.3.3, <8.0.0 pytest>=7.4.3, <8.0.0 python-dotenv==1.0.0 -tc-hivemind-backend==1.1.5 +tc-hivemind-backend==1.2.0 llama-index-question-gen-guidance==0.1.2 llama-index-vector-stores-postgres==0.1.2 celery>=5.3.6, <6.0.0 diff --git a/utils/query_engine/base_engine.py b/utils/query_engine/base_engine.py index 7f05779..fd5e7cc 100644 --- a/utils/query_engine/base_engine.py +++ b/utils/query_engine/base_engine.py @@ -1,10 +1,10 @@ -from bot.retrievers.custom_retriever import CustomVectorStoreRetriever +from llama_index.core.indices.vector_store.retrievers.retriever import ( + VectorIndexRetriever, +) from bot.retrievers.utils.load_hyperparams import load_hyperparams from llama_index.core import VectorStoreIndex, get_response_synthesizer -from llama_index.core.base.embeddings.base import BaseEmbedding from llama_index.core.query_engine import RetrieverQueryEngine -from tc_hivemind_backend.embeddings import CohereEmbedding -from tc_hivemind_backend.pg_vector_access import PGVectorAccess +from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess class BaseEngine: @@ -24,16 +24,17 @@ def __init__(self, platform_name: str, community_id: str) -> None: """ self.platform_name = platform_name self.community_id = community_id - self.dbname = f"community_{self.community_id}" + self.collection_name = f"{self.community_id}_{platform_name}" def prepare(self, testing=False): - index = self._setup_vector_store_index( + vector_store_index = self._setup_vector_store_index( testing=testing, ) _, similarity_top_k, _ = load_hyperparams() - retriever = CustomVectorStoreRetriever( - index=index, similarity_top_k=similarity_top_k + retriever = VectorIndexRetriever( + index=vector_store_index, + similarity_top_k=similarity_top_k, ) query_engine = RetrieverQueryEngine( retriever=retriever, @@ -54,27 +55,13 @@ def _setup_vector_store_index( testing : bool for testing purposes **kwargs : - table_name : str - to override the default table_name - dbname : str - to override the default database name + collection_name : str + to override the default collection_name """ - table_name = kwargs.get("table_name", self.platform_name) - dbname = kwargs.get("dbname", self.dbname) - - embed_model: BaseEmbedding - if testing: - from llama_index.core import MockEmbedding - - embed_model = MockEmbedding(embed_dim=1024) - else: - embed_model = CohereEmbedding() - - pg_vector = PGVectorAccess( - table_name=table_name, - dbname=dbname, + collection_name = kwargs.get("collection_name", self.collection_name) + qdrant_vector = QDrantVectorAccess( + collection_name=collection_name, testing=testing, - embed_model=embed_model, ) - index = pg_vector.load_index() + index = qdrant_vector.load_index() return index