Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added qdrant support for BaseEngine and added MediaWiki RAG! #60

Merged
merged 11 commits into from
May 20, 2024
27 changes: 27 additions & 0 deletions docker-compose.test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ services:
- D_RETRIEVER_SEARCH=7
- COHERE_API_KEY=some_credentials
- OPENAI_API_KEY=some_credentials2
- QDRANT_HOST=localhost
- QDRANT_PORT=6333
- QDRANT_API_KEY=
volumes:
- ./coverage:/project/coverage
depends_on:
Expand All @@ -43,6 +46,8 @@ services:
condition: service_healthy
postgres:
condition: service_healthy
qdrant-healthcheck:
condition: service_healthy
neo4j:
image: "neo4j:5.9.0"
environment:
Expand Down Expand Up @@ -87,3 +92,25 @@ services:
timeout: 30s
retries: 2
start_period: 40s
qdrant:
image: qdrant/qdrant:v1.9.2
restart: always
container_name: qdrant
ports:
- 6333:6333
expose:
- 6333
volumes:
- ./qdrant_data:/qdrant_data
qdrant-healthcheck:
restart: always
image: curlimages/curl:latest
entrypoint: ["/bin/sh", "-c", "--", "while true; do sleep 30; done;"]
depends_on:
- qdrant
healthcheck:
test: ["CMD", "curl", "-f", "http://qdrant:6333/readyz"]
interval: 10s
timeout: 2s
retries: 5

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ neo4j>=5.14.1, <6.0.0
coverage>=7.3.3, <8.0.0
pytest>=7.4.3, <8.0.0
python-dotenv==1.0.0
tc-hivemind-backend==1.1.5
tc-hivemind-backend==1.2.0
llama-index-question-gen-guidance==0.1.2
llama-index-vector-stores-postgres==0.1.2
celery>=5.3.6, <6.0.0
Expand Down
43 changes: 15 additions & 28 deletions utils/query_engine/base_engine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from llama_index.core.query_engine import RetrieverQueryEngine
from tc_hivemind_backend.embeddings import CohereEmbedding
from tc_hivemind_backend.pg_vector_access import PGVectorAccess
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess


class BaseEngine:
Expand All @@ -24,16 +24,17 @@ def __init__(self, platform_name: str, community_id: str) -> None:
"""
self.platform_name = platform_name
self.community_id = community_id
self.dbname = f"community_{self.community_id}"
self.collection_name = f"{self.community_id}_{platform_name}"

def prepare(self, testing=False):
index = self._setup_vector_store_index(
vector_store_index = self._setup_vector_store_index(
testing=testing,
)
_, similarity_top_k, _ = load_hyperparams()

retriever = CustomVectorStoreRetriever(
index=index, similarity_top_k=similarity_top_k
retriever = VectorIndexRetriever(
index=vector_store_index,
similarity_top_k=similarity_top_k,
)
query_engine = RetrieverQueryEngine(
retriever=retriever,
Expand All @@ -54,27 +55,13 @@ def _setup_vector_store_index(
testing : bool
for testing purposes
**kwargs :
table_name : str
to override the default table_name
dbname : str
to override the default database name
collection_name : str
to override the default collection_name
"""
table_name = kwargs.get("table_name", self.platform_name)
dbname = kwargs.get("dbname", self.dbname)

embed_model: BaseEmbedding
if testing:
from llama_index.core import MockEmbedding

embed_model = MockEmbedding(embed_dim=1024)
else:
embed_model = CohereEmbedding()

pg_vector = PGVectorAccess(
table_name=table_name,
dbname=dbname,
collection_name = kwargs.get("collection_name", self.collection_name)
qdrant_vector = QDrantVectorAccess(
collection_name=collection_name,
testing=testing,
embed_model=embed_model,
)
index = pg_vector.load_index()
index = qdrant_vector.load_index()
return index
Loading