Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added qdrant support for BaseEngine and added MediaWiki RAG! #60

Merged
merged 11 commits into from
May 20, 2024
27 changes: 27 additions & 0 deletions docker-compose.test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ services:
- D_RETRIEVER_SEARCH=7
- COHERE_API_KEY=some_credentials
- OPENAI_API_KEY=some_credentials2
- QDRANT_HOST=qdrant
- QDRANT_PORT=6333
- QDRANT_API_KEY=
volumes:
- ./coverage:/project/coverage
depends_on:
Expand All @@ -43,6 +46,8 @@ services:
condition: service_healthy
postgres:
condition: service_healthy
qdrant-healthcheck:
condition: service_healthy
neo4j:
image: "neo4j:5.9.0"
environment:
Expand Down Expand Up @@ -87,3 +92,25 @@ services:
timeout: 30s
retries: 2
start_period: 40s
# Qdrant vector database for the test stack; the app service reaches it through
# QDRANT_HOST=qdrant and QDRANT_PORT=6333.
qdrant:
image: qdrant/qdrant:v1.9.2
restart: always
container_name: qdrant
ports:
# publishes the Qdrant port 6333 on the host as well
- 6333:6333
expose:
- 6333
volumes:
# NOTE(review): Qdrant's documentation mounts its data directory at
# /qdrant/storage; confirm that /qdrant_data is the intended target,
# otherwise this bind mount may not hold any collection data.
- ./qdrant_data:/qdrant_data
# Sidecar that reports Qdrant readiness: the container itself only idles in a
# sleep loop, while its healthcheck curls Qdrant's /readyz endpoint. Services
# that declare `qdrant-healthcheck: condition: service_healthy` wait on it.
qdrant-healthcheck:
restart: always
# NOTE(review): `latest` is unpinned, so test runs may pull different curl
# images over time; consider pinning a version.
image: curlimages/curl:latest
# keep the container alive so the healthcheck below can keep running
entrypoint: ["/bin/sh", "-c", "--", "while true; do sleep 30; done;"]
depends_on:
- qdrant
healthcheck:
# `-f` makes curl exit non-zero on HTTP error responses
test: ["CMD", "curl", "-f", "http://qdrant:6333/readyz"]
interval: 10s
timeout: 2s
retries: 5

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ neo4j>=5.14.1, <6.0.0
coverage>=7.3.3, <8.0.0
pytest>=7.4.3, <8.0.0
python-dotenv==1.0.0
tc-hivemind-backend==1.1.5
tc-hivemind-backend==1.2.0
llama-index-question-gen-guidance==0.1.2
llama-index-vector-stores-postgres==0.1.2
celery>=5.3.6, <6.0.0
Expand Down
21 changes: 17 additions & 4 deletions subquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
GDriveQueryEngine,
GitHubQueryEngine,
MediaWikiQueryEngine,
NotionQueryEngine,
prepare_discord_engine_auto_filter,
)
Expand All @@ -25,7 +26,7 @@ def query_multiple_source(
notion: bool = False,
telegram: bool = False,
github: bool = False,
media_wiki: bool = False,
mediaWiki: bool = False,
) -> tuple[str, list[NodeWithScore]]:
"""
query multiple platforms and get an answer from the multiple
Expand Down Expand Up @@ -72,7 +73,7 @@ def query_multiple_source(
# discourse_query_engine: BaseQueryEngine
gdrive_query_engine: BaseQueryEngine
notion_query_engine: BaseQueryEngine
# media_wiki_query_engine: BaseQueryEngine
mediawiki_query_engine: BaseQueryEngine
# telegram_query_engine: BaseQueryEngine

# query engine preparation
Expand Down Expand Up @@ -143,8 +144,20 @@ def query_multiple_source(
metadata=tool_metadata,
)
)
if media_wiki:
raise NotImplementedError
if mediaWiki:
mediawiki_query_engine = MediaWikiQueryEngine(
community_id=community_id
).prepare()
tool_metadata = ToolMetadata(
name="WikiPedia",
description="Hosts articles about any information on internet",
)
query_engine_tools.append(
QueryEngineTool(
query_engine=mediawiki_query_engine,
metadata=tool_metadata,
)
)
Comment on lines +147 to +160
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix the typo in the tool metadata name.

- name="WikiPedia",
+ name="MediaWiki",

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
if mediaWiki:
mediawiki_query_engine = MediaWikiQueryEngine(
community_id=community_id
).prepare()
tool_metadata = ToolMetadata(
name="WikiPedia",
description="Hosts articles about any information on internet",
)
query_engine_tools.append(
QueryEngineTool(
query_engine=mediawiki_query_engine,
metadata=tool_metadata,
)
)
if mediaWiki:
mediawiki_query_engine = MediaWikiQueryEngine(
community_id=community_id
).prepare()
tool_metadata = ToolMetadata(
name="MediaWiki",
description="Hosts articles about any information on internet",
)
query_engine_tools.append(
QueryEngineTool(
query_engine=mediawiki_query_engine,
metadata=tool_metadata,
)
)


embed_model = CohereEmbedding()
llm = OpenAI("gpt-3.5-turbo")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from unittest import TestCase

from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.base_pg_engine import BasePGEngine


class TestBaseEngine(TestCase):
class TestPGBaseEngine(TestCase):
def test_setup_vector_store_index(self):
"""
Tests that _setup_vector_store_index creates a PGVectorAccess object
and calls its load_index method.
"""
platform_table_name = "test_table"
community_id = "123456"
base_engine = BaseEngine(
base_engine = BasePGEngine(
platform_name=platform_table_name,
community_id=community_id,
)
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/test_base_qdrant_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from unittest import TestCase

from utils.query_engine.base_qdrant_engine import BaseQdrantEngine


class TestBaseQdrantEngine(TestCase):
    """Unit tests for `BaseQdrantEngine`."""

    def test_setup_vector_store_index(self):
        """
        Tests that _setup_vector_store_index returns an index whose vector
        store points at the `{community_id}_{platform_name}` collection.
        """
        platform_name = "test_table"
        community_id = "123456"
        base_engine = BaseQdrantEngine(
            platform_name=platform_name,
            community_id=community_id,
        )
        # keep the engine and the returned index under separate names so it
        # is clear the assertion below inspects the index, not the engine
        index = base_engine._setup_vector_store_index(
            testing=True,
        )

        expected_collection_name = f"{community_id}_{platform_name}"
        self.assertEqual(
            index.vector_store.collection_name, expected_collection_name
        )
6 changes: 4 additions & 2 deletions tests/unit/test_gdrive_query_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from unittest import TestCase

from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from utils.query_engine import GDriveQueryEngine


Expand All @@ -12,4 +14,4 @@ def setUp(self) -> None:
def test_prepare_engine(self):
gdrive_query_engine = self.gdrive_query_engine.prepare(testing=True)
print(gdrive_query_engine.__dict__)
self.assertIsInstance(gdrive_query_engine.retriever, CustomVectorStoreRetriever)
self.assertIsInstance(gdrive_query_engine.retriever, VectorIndexRetriever)
6 changes: 4 additions & 2 deletions tests/unit/test_github_query_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from unittest import TestCase

from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from utils.query_engine import GitHubQueryEngine


Expand All @@ -12,4 +14,4 @@ def setUp(self) -> None:
def test_prepare_engine(self):
github_query_engine = self.github_query_engine.prepare(testing=True)
print(github_query_engine.__dict__)
self.assertIsInstance(github_query_engine.retriever, CustomVectorStoreRetriever)
self.assertIsInstance(github_query_engine.retriever, VectorIndexRetriever)
17 changes: 17 additions & 0 deletions tests/unit/test_mediawiki_query_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from unittest import TestCase

from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from utils.query_engine import MediaWikiQueryEngine


class TestMediaWikiQueryEngine(TestCase):
    """Unit tests for `MediaWikiQueryEngine`."""

    def setUp(self) -> None:
        """Create a MediaWiki query engine for a sample community."""
        community_id = "sample_community"
        self.mediawiki_query_engine = MediaWikiQueryEngine(community_id)

    def test_prepare_engine(self):
        """
        Tests that `prepare` returns a query engine whose retriever is a
        `VectorIndexRetriever`.
        """
        mediawiki_query_engine = self.mediawiki_query_engine.prepare(testing=True)
        print(mediawiki_query_engine.__dict__)
        self.assertIsInstance(
            mediawiki_query_engine.retriever, VectorIndexRetriever
        )
6 changes: 4 additions & 2 deletions tests/unit/test_notion_query_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from unittest import TestCase

from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from utils.query_engine import NotionQueryEngine


Expand All @@ -12,4 +14,4 @@ def setUp(self) -> None:
def test_prepare_engine(self):
notion_query_engine = self.notion_query_engine.prepare(testing=True)
print(notion_query_engine.__dict__)
self.assertIsInstance(notion_query_engine.retriever, CustomVectorStoreRetriever)
self.assertIsInstance(notion_query_engine.retriever, VectorIndexRetriever)
1 change: 1 addition & 0 deletions utils/query_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# flake8: noqa
from .gdrive import GDriveQueryEngine
from .github import GitHubQueryEngine
from .media_wiki import MediaWikiQueryEngine
from .notion import NotionQueryEngine
from .prepare_discord_query_engine import prepare_discord_engine_auto_filter
from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from tc_hivemind_backend.pg_vector_access import PGVectorAccess


class BaseEngine:
class BasePGEngine:
def __init__(self, platform_name: str, community_id: str) -> None:
"""
initialize the engine to query the database related to a community
initialize the pg vector db engine to query the database related to a community
and the table related to the platform

Parameters
Expand Down
67 changes: 67 additions & 0 deletions utils/query_engine/base_qdrant_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.indices.vector_store.retrievers.retriever import (
VectorIndexRetriever,
)
from llama_index.core.query_engine import RetrieverQueryEngine
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess


class BaseQdrantEngine:
    """Build query engines on top of a community's Qdrant collection."""

    def __init__(self, platform_name: str, community_id: str) -> None:
        """
        initialize the qdrant db engine to query the collection that holds
        a community's data for one platform

        Parameters
        -----------
        platform_name : str
            the platform whose data is queried (e.g. `gdrive`)
        community_id : str
            the community that owns the data
            the collection is named `{community_id}_{platform_name}`
        """
        self.platform_name = platform_name
        self.community_id = community_id
        self.collection_name = f"{self.community_id}_{platform_name}"

    def prepare(self, testing: bool = False) -> RetrieverQueryEngine:
        """
        prepare a query engine over the platform's collection

        Parameters
        ------------
        testing : bool
            for testing purposes, forwarded to `_setup_vector_store_index`
            default is `False`

        Returns
        ---------
        query_engine : RetrieverQueryEngine
            a query engine using a `VectorIndexRetriever` on the loaded index
            and the default response synthesizer
        """
        vector_store_index = self._setup_vector_store_index(
            testing=testing,
        )
        # only the second hyperparameter (similarity top-k) is needed here
        _, similarity_top_k, _ = load_hyperparams()

        retriever = VectorIndexRetriever(
            index=vector_store_index,
            similarity_top_k=similarity_top_k,
        )
        query_engine = RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=get_response_synthesizer(),
        )
        return query_engine

    def _setup_vector_store_index(
        self,
        testing: bool = False,
        **kwargs,
    ) -> VectorStoreIndex:
        """
        prepare the vector_store for querying data

        Parameters
        ------------
        testing : bool
            for testing purposes, forwarded to `QDrantVectorAccess`
        **kwargs :
            collection_name : str
                to override the default collection_name

        Returns
        ---------
        index : VectorStoreIndex
            the index loaded through `QDrantVectorAccess.load_index`
        """
        collection_name = kwargs.get("collection_name", self.collection_name)
        qdrant_vector = QDrantVectorAccess(
            collection_name=collection_name,
            testing=testing,
        )
        index = qdrant_vector.load_index()
        return index
4 changes: 2 additions & 2 deletions utils/query_engine/gdrive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.base_qdrant_engine import BaseQdrantEngine


class GDriveQueryEngine(BaseEngine):
class GDriveQueryEngine(BaseQdrantEngine):
def __init__(self, community_id: str) -> None:
platform_name = "gdrive"
super().__init__(platform_name, community_id)
4 changes: 2 additions & 2 deletions utils/query_engine/github.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.base_qdrant_engine import BaseQdrantEngine


class GitHubQueryEngine(BaseEngine):
class GitHubQueryEngine(BaseQdrantEngine):
def __init__(self, community_id: str) -> None:
platform_name = "github"
super().__init__(platform_name, community_id)
6 changes: 3 additions & 3 deletions utils/query_engine/level_based_platform_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.llms.openai import OpenAI
from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.base_pg_engine import BasePGEngine
from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils

qa_prompt = PromptTemplate(
Expand Down Expand Up @@ -116,7 +116,7 @@ def prepare_platform_engine(
)
llm = kwargs.get("llm", OpenAI("gpt-4"))
qa_prompt_ = kwargs.get("qa_prompt", qa_prompt)
base_engine = BaseEngine(platform_table_name, community_id)
base_engine = BasePGEngine(platform_table_name, community_id)
index: VectorStoreIndex = kwargs.get(
"index_raw",
base_engine._setup_vector_store_index(
Expand Down Expand Up @@ -198,7 +198,7 @@ def prepare_engine_auto_filter(
dbname = f"community_{community_id}"
summary_similarity_top_k, _, d = load_hyperparams()

base_engine = BaseEngine(
base_engine = BasePGEngine(
platform_table_name + "_summary",
community_id,
)
Expand Down
7 changes: 7 additions & 0 deletions utils/query_engine/media_wiki.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from utils.query_engine.base_qdrant_engine import BaseQdrantEngine


class MediaWikiQueryEngine(BaseQdrantEngine):
    """Qdrant-backed query engine for a community's `mediawiki` data."""

    def __init__(self, community_id: str) -> None:
        """
        Parameters
        -----------
        community_id : str
            the community whose `mediawiki` data is queried
        """
        super().__init__(platform_name="mediawiki", community_id=community_id)
4 changes: 2 additions & 2 deletions utils/query_engine/notion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from utils.query_engine.base_engine import BaseEngine
from utils.query_engine.base_qdrant_engine import BaseQdrantEngine


class NotionQueryEngine(BaseEngine):
class NotionQueryEngine(BaseQdrantEngine):
def __init__(self, community_id: str) -> None:
platform_name = "notion"
super().__init__(platform_name, community_id)
Loading