diff --git a/subquery.py b/subquery.py
index c2db9c7..92d53c0 100644
--- a/subquery.py
+++ b/subquery.py
@@ -9,6 +9,7 @@
 from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
 from utils.query_engine import (
     DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
+    GDriveQueryEngine,
     GitHubQueryEngine,
     prepare_discord_engine_auto_filter,
 )
@@ -94,7 +95,20 @@ def query_multiple_source(
     if discourse:
         raise NotImplementedError
     if gdrive:
-        raise NotImplementedError
+        gdrive_query_engine = GDriveQueryEngine(community_id=community_id).prepare()
+        tool_metadata = ToolMetadata(
+            name="Google-Drive",
+            description=(
+                "Stores and manages documents, spreadsheets, presentations,"
+                " and other files for the community."
+            ),
+        )
+        query_engine_tools.append(
+            QueryEngineTool(
+                query_engine=gdrive_query_engine,
+                metadata=tool_metadata,
+            )
+        )
     if notion:
         raise NotImplementedError
     if telegram:
diff --git a/tests/unit/test_gdrive_query_engine.py b/tests/unit/test_gdrive_query_engine.py
new file mode 100644
--- /dev/null
+++ b/tests/unit/test_gdrive_query_engine.py
@@ -0,0 +1,17 @@
+from unittest import TestCase
+
+from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
+from utils.query_engine import GDriveQueryEngine
+
+
+class TestGDriveQueryEngine(TestCase):
+    """Unit tests for preparing the `gdrive` query engine."""
+
+    def setUp(self) -> None:
+        community_id = "sample_community"
+        self.gdrive_query_engine = GDriveQueryEngine(community_id)
+
+    def test_prepare_engine(self) -> None:
+        """The prepared engine must retrieve via CustomVectorStoreRetriever."""
+        gdrive_query_engine = self.gdrive_query_engine.prepare(testing=True)
+        self.assertIsInstance(gdrive_query_engine.retriever, CustomVectorStoreRetriever)
diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py
index ce734c4..bfd2879 100644
--- a/utils/query_engine/__init__.py
+++ b/utils/query_engine/__init__.py
@@ -1,4 +1,5 @@
 # flake8: noqa
+from .gdrive import GDriveQueryEngine
 from .github import GitHubQueryEngine
 from .prepare_discord_query_engine import prepare_discord_engine_auto_filter
 from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL
diff --git a/utils/query_engine/gdrive.py b/utils/query_engine/gdrive.py
new file mode 100644
--- /dev/null
+++ b/utils/query_engine/gdrive.py
@@ -0,0 +1,48 @@
+from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
+from bot.retrievers.utils.load_hyperparams import load_hyperparams
+from llama_index.core import get_response_synthesizer
+from llama_index.core.query_engine import RetrieverQueryEngine
+from utils.query_engine.base_engine import BaseEngine
+
+
+class GDriveQueryEngine(BaseEngine):
+    """Build a query engine over the `gdrive` table of a community database."""
+
+    platform_name = "gdrive"
+
+    def __init__(self, community_id: str) -> None:
+        """Derive the database name from `community_id`.
+
+        The name has the form ``community_<community_id>``.
+        """
+        self.dbname = f"community_{community_id}"
+
+    def prepare(self, testing: bool = False) -> RetrieverQueryEngine:
+        """Create the retriever-backed query engine for this community.
+
+        Parameters
+        ----------
+        testing : bool
+            Forwarded unchanged to `_setup_vector_store_index`.
+
+        Returns
+        -------
+        RetrieverQueryEngine
+            Engine pairing a `CustomVectorStoreRetriever` with the default
+            response synthesizer.
+        """
+        index = self._setup_vector_store_index(
+            platform_table_name=self.platform_name,
+            dbname=self.dbname,
+            testing=testing,
+        )
+        # Only the second hyperparameter (similarity_top_k) is used here.
+        _, similarity_top_k, _ = load_hyperparams()
+
+        retriever = CustomVectorStoreRetriever(
+            index=index, similarity_top_k=similarity_top_k
+        )
+        return RetrieverQueryEngine(
+            retriever=retriever,
+            response_synthesizer=get_response_synthesizer(),
+        )