Skip to content

Commit

Permalink
Merge pull request #55 from TogetherCrew/feat/gdrive-rag
Browse files Browse the repository at this point in the history
feat: Added google drive RAG pipeline!
  • Loading branch information
amindadgar authored May 8, 2024
2 parents fb82bee + 978a0a8 commit 51b06f5
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 1 deletion.
16 changes: 15 additions & 1 deletion subquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
from utils.query_engine import (
DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
GDriveQueryEngine,
GitHubQueryEngine,
prepare_discord_engine_auto_filter,
)
Expand Down Expand Up @@ -94,7 +95,20 @@ def query_multiple_source(
if discourse:
raise NotImplementedError
if gdrive:
raise NotImplementedError
gdrive_query_engine = GDriveQueryEngine(community_id=community_id).prepare()
tool_metadata = ToolMetadata(
name="Google-Drive",
description=(
"Stores and manages documents, spreadsheets, presentations,"
" and other files for the community."
),
)
query_engine_tools.append(
QueryEngineTool(
query_engine=gdrive_query_engine,
metadata=tool_metadata,
)
)
if notion:
raise NotImplementedError
if telegram:
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_gdrive_query_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from unittest import TestCase

from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from utils.query_engine import GDriveQueryEngine


class TestGDriveQueryEngine(TestCase):
def setUp(self) -> None:
community_id = "sample_community"
self.gdrive_query_engine = GDriveQueryEngine(community_id)

def test_prepare_engine(self):
gdrive_query_engine = self.gdrive_query_engine.prepare(testing=True)
print(gdrive_query_engine.__dict__)
self.assertIsInstance(gdrive_query_engine.retriever, CustomVectorStoreRetriever)
1 change: 1 addition & 0 deletions utils/query_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# flake8: noqa
from .gdrive import GDriveQueryEngine
from .github import GitHubQueryEngine
from .prepare_discord_query_engine import prepare_discord_engine_auto_filter
from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL
30 changes: 30 additions & 0 deletions utils/query_engine/gdrive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from utils.query_engine.base_engine import BaseEngine


class GDriveQueryEngine(BaseEngine):
platform_name = "gdrive"

def __init__(self, community_id: str) -> None:
dbname = f"community_{community_id}"
self.dbname = dbname

def prepare(self, testing=False):
index = self._setup_vector_store_index(
platform_table_name=self.platform_name,
dbname=self.dbname,
testing=testing,
)
_, similarity_top_k, _ = load_hyperparams()

retriever = CustomVectorStoreRetriever(
index=index, similarity_top_k=similarity_top_k
)
query_engine = RetrieverQueryEngine(
retriever=retriever,
response_synthesizer=get_response_synthesizer(),
)
return query_engine

0 comments on commit 51b06f5

Please sign in to comment.