From f77013a6cde20ca67c3f613b2bf780d084b3a7e1 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 6 May 2024 12:27:22 +0330 Subject: [PATCH 1/4] feat: Added google drive RAG pipeline! --- subquery.py | 16 +++++++++++++- tests/unit/test_gdrive_query_engine.py | 15 +++++++++++++ utils/query_engine/__init__.py | 1 + utils/query_engine/gdrive.py | 30 ++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_gdrive_query_engine.py create mode 100644 utils/query_engine/gdrive.py diff --git a/subquery.py b/subquery.py index c2db9c7..7b53a20 100644 --- a/subquery.py +++ b/subquery.py @@ -9,6 +9,7 @@ from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from utils.query_engine import ( DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL, + GDriveQueryEngine, GitHubQueryEngine, prepare_discord_engine_auto_filter, ) @@ -94,7 +95,20 @@ def query_multiple_source( if discourse: raise NotImplementedError if gdrive: - raise NotImplementedError + github_query_engine = GDriveQueryEngine(community_id=community_id).prepare() + tool_metadata = ToolMetadata( + name="Google-Drive", + description=( + "Stores and manages documents, spreadsheets, presentations," + " and other files for the community." + ), + ) + query_engine_tools.append( + QueryEngineTool( + query_engine=github_query_engine, + metadata=tool_metadata, + ) + ) if notion: raise NotImplementedError if telegram: diff --git a/tests/unit/test_gdrive_query_engine.py b/tests/unit/test_gdrive_query_engine.py new file mode 100644 index 0000000..3ccab43 --- /dev/null +++ b/tests/unit/test_gdrive_query_engine.py @@ -0,0 +1,15 @@ +from unittest import TestCase + +from bot.retrievers.custom_retriever import CustomVectorStoreRetriever +from utils.query_engine import GDriveQueryEngine + + +class TestGitHubQueryEngine(TestCase): + def setUp(self) -> None: + community_id = "sample_community" + self.gdrive_query_engine = GDriveQueryEngine(community_id) + + def test_prepare_engine(self): + gdrive_query_engine = self.gdrive_query_engine.prepare(testing=True) + print(gdrive_query_engine.__dict__) + self.assertIsInstance(gdrive_query_engine.retriever, CustomVectorStoreRetriever) diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py index ce734c4..5ce4609 100644 --- a/utils/query_engine/__init__.py +++ b/utils/query_engine/__init__.py @@ -1,4 +1,5 @@ # flake8: noqa from .github import GitHubQueryEngine +from .gdrive import GDriveQueryEngine from .prepare_discord_query_engine import prepare_discord_engine_auto_filter from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL diff --git a/utils/query_engine/gdrive.py b/utils/query_engine/gdrive.py new file mode 100644 index 0000000..5a775e2 --- /dev/null +++ b/utils/query_engine/gdrive.py @@ -0,0 +1,30 @@ +from bot.retrievers.custom_retriever import CustomVectorStoreRetriever +from bot.retrievers.utils.load_hyperparams import load_hyperparams +from llama_index.core import get_response_synthesizer +from llama_index.core.query_engine import RetrieverQueryEngine +from utils.query_engine.base_engine import BaseEngine + + +class GDriveQueryEngine(BaseEngine): + platform_name = "gdrive" + + def __init__(self, community_id: str) -> None: + dbname = f"community_{community_id}" + self.dbname = dbname + + def prepare(self, testing=False): + index = self._setup_vector_store_index( + platform_table_name=self.platform_name, + dbname=self.dbname, + testing=testing, + ) + _, similarity_top_k, _ = load_hyperparams() + + retriever = CustomVectorStoreRetriever( + index=index, similarity_top_k=similarity_top_k + ) + query_engine = RetrieverQueryEngine( + retriever=retriever, + response_synthesizer=get_response_synthesizer(), + ) + return query_engine From 1f5a86aa8b4751535b78b548931c15dde953ce8f Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 6 May 2024 12:53:02 +0330 Subject: [PATCH 2/4] fix: isort linter issue! --- utils/query_engine/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py index 5ce4609..bfd2879 100644 --- a/utils/query_engine/__init__.py +++ b/utils/query_engine/__init__.py @@ -1,5 +1,5 @@ # flake8: noqa -from .github import GitHubQueryEngine from .gdrive import GDriveQueryEngine +from .github import GitHubQueryEngine from .prepare_discord_query_engine import prepare_discord_engine_auto_filter from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL From 9e3f22a751c4056887203f1569c56a221d738094 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 6 May 2024 12:54:14 +0330 Subject: [PATCH 3/4] fix: typo in variable name! --- subquery.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/subquery.py b/subquery.py index 7b53a20..92d53c0 100644 --- a/subquery.py +++ b/subquery.py @@ -95,7 +95,7 @@ def query_multiple_source( if discourse: raise NotImplementedError if gdrive: - github_query_engine = GDriveQueryEngine(community_id=community_id).prepare() + gdrive_query_engine = GDriveQueryEngine(community_id=community_id).prepare() tool_metadata = ToolMetadata( name="Google-Drive", description=( @@ -105,7 +105,7 @@ def query_multiple_source( ) query_engine_tools.append( QueryEngineTool( - query_engine=github_query_engine, + query_engine=gdrive_query_engine, metadata=tool_metadata, ) ) From 978a0a88609c6f025e43cb249d54a94adb1bc0a7 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 6 May 2024 12:55:25 +0330 Subject: [PATCH 4/4] fix: typo in test case class name! --- tests/unit/test_gdrive_query_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_gdrive_query_engine.py b/tests/unit/test_gdrive_query_engine.py index 3ccab43..706d211 100644 --- a/tests/unit/test_gdrive_query_engine.py +++ b/tests/unit/test_gdrive_query_engine.py @@ -4,7 +4,7 @@ from utils.query_engine import GDriveQueryEngine -class TestGitHubQueryEngine(TestCase): +class TestGDriveQueryEngine(TestCase): def setUp(self) -> None: community_id = "sample_community" self.gdrive_query_engine = GDriveQueryEngine(community_id)