Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added qdrant support for BaseEngine and added MediaWiki RAG! #60

Merged
merged 11 commits into from
May 20, 2024
27 changes: 27 additions & 0 deletions docker-compose.test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ services:
- D_RETRIEVER_SEARCH=7
- COHERE_API_KEY=some_credentials
- OPENAI_API_KEY=some_credentials2
- QDRANT_HOST=qdrant
- QDRANT_PORT=6333
- QDRANT_API_KEY=
volumes:
- ./coverage:/project/coverage
depends_on:
Expand All @@ -43,6 +46,8 @@ services:
condition: service_healthy
postgres:
condition: service_healthy
qdrant-healthcheck:
condition: service_healthy
neo4j:
image: "neo4j:5.9.0"
environment:
Expand Down Expand Up @@ -87,3 +92,25 @@ services:
timeout: 30s
retries: 2
start_period: 40s
qdrant:
image: qdrant/qdrant:v1.9.2
restart: always
container_name: qdrant
ports:
- 6333:6333
expose:
- 6333
volumes:
- ./qdrant_data:/qdrant_data
qdrant-healthcheck:
restart: always
image: curlimages/curl:latest
entrypoint: ["/bin/sh", "-c", "--", "while true; do sleep 30; done;"]
depends_on:
- qdrant
healthcheck:
test: ["CMD", "curl", "-f", "http://qdrant:6333/readyz"]
interval: 10s
timeout: 2s
retries: 5

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ neo4j>=5.14.1, <6.0.0
coverage>=7.3.3, <8.0.0
pytest>=7.4.3, <8.0.0
python-dotenv==1.0.0
tc-hivemind-backend==1.1.5
tc-hivemind-backend==1.2.0
llama-index-question-gen-guidance==0.1.2
llama-index-vector-stores-postgres==0.1.2
celery>=5.3.6, <6.0.0
Expand Down
21 changes: 17 additions & 4 deletions subquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
GDriveQueryEngine,
GitHubQueryEngine,
MediaWikiQueryEngine,
NotionQueryEngine,
prepare_discord_engine_auto_filter,
)
Expand All @@ -25,7 +26,7 @@ def query_multiple_source(
notion: bool = False,
telegram: bool = False,
github: bool = False,
media_wiki: bool = False,
mediaWiki: bool = False,
) -> tuple[str, list[NodeWithScore]]:
"""
query multiple platforms and get an answer from the multiple
Expand Down Expand Up @@ -72,7 +73,7 @@ def query_multiple_source(
# discourse_query_engine: BaseQueryEngine
gdrive_query_engine: BaseQueryEngine
notion_query_engine: BaseQueryEngine
# media_wiki_query_engine: BaseQueryEngine
mediawiki_query_engine: BaseQueryEngine
# telegram_query_engine: BaseQueryEngine

# query engine perparation
Expand Down Expand Up @@ -143,8 +144,20 @@ def query_multiple_source(
metadata=tool_metadata,
)
)
if media_wiki:
raise NotImplementedError
if mediaWiki:
mediawiki_query_engine = MediaWikiQueryEngine(
community_id=community_id
).prepare()
tool_metadata = ToolMetadata(
name="WikiPedia",
description="Hosts articles about any information on internet",
)
query_engine_tools.append(
QueryEngineTool(
query_engine=mediawiki_query_engine,
metadata=tool_metadata,
)
)
Comment on lines +147 to +160
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix the typo in the tool metadata name.

- name="WikiPedia",
+ name="MediaWiki",

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
if mediaWiki:
mediawiki_query_engine = MediaWikiQueryEngine(
community_id=community_id
).prepare()
tool_metadata = ToolMetadata(
name="WikiPedia",
description="Hosts articles about any information on internet",
)
query_engine_tools.append(
QueryEngineTool(
query_engine=mediawiki_query_engine,
metadata=tool_metadata,
)
)
if mediaWiki:
mediawiki_query_engine = MediaWikiQueryEngine(
community_id=community_id
).prepare()
tool_metadata = ToolMetadata(
name="MediaWiki",
description="Hosts articles about any information on internet",
)
query_engine_tools.append(
QueryEngineTool(
query_engine=mediawiki_query_engine,
metadata=tool_metadata,
)
)


embed_model = CohereEmbedding()
llm = OpenAI("gpt-3.5-turbo")
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_mediawiki_query_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from unittest import TestCase

from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from utils.query_engine import MediaWikiQueryEngine


class TestMediaWikiQueryEngine(TestCase):
def setUp(self) -> None:
community_id = "sample_community"
self.notion_query_engine = MediaWikiQueryEngine(community_id)

def test_prepare_engine(self):
notion_query_engine = self.notion_query_engine.prepare(testing=True)
print(notion_query_engine.__dict__)
self.assertIsInstance(notion_query_engine.retriever, CustomVectorStoreRetriever)
1 change: 1 addition & 0 deletions utils/query_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# flake8: noqa
from .gdrive import GDriveQueryEngine
from .github import GitHubQueryEngine
from .media_wiki import MediaWikiQueryEngine
from .notion import NotionQueryEngine
from .prepare_discord_query_engine import prepare_discord_engine_auto_filter
from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL
43 changes: 15 additions & 28 deletions utils/query_engine/base_engine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from llama_index.core.query_engine import RetrieverQueryEngine
from tc_hivemind_backend.embeddings import CohereEmbedding
from tc_hivemind_backend.pg_vector_access import PGVectorAccess
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess


class BaseEngine:
Expand All @@ -24,16 +24,17 @@ def __init__(self, platform_name: str, community_id: str) -> None:
"""
self.platform_name = platform_name
self.community_id = community_id
self.dbname = f"community_{self.community_id}"
self.collection_name = f"{self.community_id}_{platform_name}"

def prepare(self, testing=False):
index = self._setup_vector_store_index(
vector_store_index = self._setup_vector_store_index(
testing=testing,
)
_, similarity_top_k, _ = load_hyperparams()

retriever = CustomVectorStoreRetriever(
index=index, similarity_top_k=similarity_top_k
retriever = VectorIndexRetriever(
index=vector_store_index,
similarity_top_k=similarity_top_k,
)
query_engine = RetrieverQueryEngine(
retriever=retriever,
Expand All @@ -54,27 +55,13 @@ def _setup_vector_store_index(
testing : bool
for testing purposes
**kwargs :
table_name : str
to override the default table_name
dbname : str
to override the default database name
collection_name : str
to override the default collection_name
"""
table_name = kwargs.get("table_name", self.platform_name)
dbname = kwargs.get("dbname", self.dbname)

embed_model: BaseEmbedding
if testing:
from llama_index.core import MockEmbedding

embed_model = MockEmbedding(embed_dim=1024)
else:
embed_model = CohereEmbedding()

pg_vector = PGVectorAccess(
table_name=table_name,
dbname=dbname,
collection_name = kwargs.get("collection_name", self.collection_name)
qdrant_vector = QDrantVectorAccess(
collection_name=collection_name,
testing=testing,
embed_model=embed_model,
)
index = pg_vector.load_index()
index = qdrant_vector.load_index()
return index
7 changes: 7 additions & 0 deletions utils/query_engine/media_wiki.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from utils.query_engine.base_engine import BaseEngine


class MediaWikiQueryEngine(BaseEngine):
def __init__(self, community_id: str) -> None:
platform_name = "mediawiki"
super().__init__(platform_name, community_id)
Loading