Skip to content

Commit

Permalink
Merge pull request #58 from TogetherCrew/feat/notion-rag
Browse files Browse the repository at this point in the history
Feat/notion rag
  • Loading branch information
amindadgar authored May 9, 2024
2 parents 51b06f5 + ccfefb8 commit 219c091
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 72 deletions.
23 changes: 20 additions & 3 deletions subquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
GDriveQueryEngine,
GitHubQueryEngine,
NotionQueryEngine,
prepare_discord_engine_auto_filter,
)

Expand All @@ -24,6 +25,7 @@ def query_multiple_source(
notion: bool = False,
telegram: bool = False,
github: bool = False,
media_wiki: bool = False,
) -> tuple[str, list[NodeWithScore]]:
"""
query multiple platforms and get an answer from the multiple
Expand Down Expand Up @@ -68,8 +70,9 @@ def query_multiple_source(
discord_query_engine: BaseQueryEngine
github_query_engine: BaseQueryEngine
# discourse_query_engine: BaseQueryEngine
# gdrive_query_engine: BaseQueryEngine
# notion_query_engine: BaseQueryEngine
gdrive_query_engine: BaseQueryEngine
notion_query_engine: BaseQueryEngine
# media_wiki_query_engine: BaseQueryEngine
# telegram_query_engine: BaseQueryEngine

# query engine preparation
Expand Down Expand Up @@ -110,7 +113,19 @@ def query_multiple_source(
)
)
if notion:
raise NotImplementedError
notion_query_engine = NotionQueryEngine(community_id=community_id)
tool_metadata = ToolMetadata(
name="Notion",
description=(
"Centralizes notes, wikis, project plans, and to-dos for the community."
),
)
query_engine_tools.append(
QueryEngineTool(
query_engine=notion_query_engine,
metadata=tool_metadata,
)
)
if telegram:
raise NotImplementedError
if github:
Expand All @@ -128,6 +143,8 @@ def query_multiple_source(
metadata=tool_metadata,
)
)
if media_wiki:
raise NotImplementedError

embed_model = CohereEmbedding()
llm = OpenAI("gpt-3.5-turbo")
Expand Down
15 changes: 9 additions & 6 deletions tests/unit/test_base_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ def test_setup_vector_store_index(self):
and calls its load_index method.
"""
platform_table_name = "test_table"
dbname = "test_db"

base_engine = BaseEngine._setup_vector_store_index(
platform_table_name=platform_table_name,
dbname=dbname,
community_id = "123456"
base_engine = BaseEngine(
platform_name=platform_table_name,
community_id=community_id,
)
base_engine = base_engine._setup_vector_store_index(
testing=True,
)
self.assertIn(dbname, base_engine.vector_store.connection_string)

expected_dbname = f"community_{community_id}"
self.assertIn(expected_dbname, base_engine.vector_store.connection_string)
self.assertEqual(base_engine.vector_store.table_name, platform_table_name)
15 changes: 15 additions & 0 deletions tests/unit/test_notion_query_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from unittest import TestCase

from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from utils.query_engine import NotionQueryEngine


class TestNotionQueryEngine(TestCase):
    """Unit tests for preparing the Notion query engine."""

    def setUp(self) -> None:
        """Create a NotionQueryEngine for a sample community before each test."""
        community_id = "sample_community"
        self.notion_query_engine = NotionQueryEngine(community_id)

    def test_prepare_engine(self):
        """`prepare` should return an engine whose retriever is the custom one."""
        # testing=True is forwarded to `prepare`; the leftover debug print of
        # the engine's __dict__ was removed so the test run stays quiet.
        notion_query_engine = self.notion_query_engine.prepare(testing=True)
        self.assertIsInstance(
            notion_query_engine.retriever, CustomVectorStoreRetriever
        )
1 change: 1 addition & 0 deletions utils/query_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# flake8: noqa
from .gdrive import GDriveQueryEngine
from .github import GitHubQueryEngine
from .notion import NotionQueryEngine
from .prepare_discord_query_engine import prepare_discord_engine_auto_filter
from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL
70 changes: 63 additions & 7 deletions utils/query_engine/base_engine.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,80 @@
from llama_index.core import VectorStoreIndex
from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.query_engine import RetrieverQueryEngine
from tc_hivemind_backend.embeddings import CohereEmbedding
from tc_hivemind_backend.pg_vector_access import PGVectorAccess


class BaseEngine:
    """
    Base query engine bound to one platform table of one community database.
    """

    def __init__(self, platform_name: str, community_id: str) -> None:
        """
        initialize the engine to query the database related to a community
        and the table related to the platform

        Parameters
        -----------
        platform_name : str
            the table representative of data for a specific platform
        community_id : str
            the database for a community
            normally the community database is saved as
            `community_{community_id}`
        """
        self.platform_name = platform_name
        self.community_id = community_id
        self.dbname = f"community_{self.community_id}"

    def prepare(self, testing: bool = False) -> RetrieverQueryEngine:
        """
        build the query engine on top of the platform's vector store index

        Parameters
        ------------
        testing : bool
            forwarded to `_setup_vector_store_index`
            default is `False`

        Returns
        ---------
        query_engine : RetrieverQueryEngine
            a retriever query engine using `CustomVectorStoreRetriever`
            and the default response synthesizer
        """
        index = self._setup_vector_store_index(
            testing=testing,
        )
        # only the second hyperparameter (similarity_top_k) is needed here
        _, similarity_top_k, _ = load_hyperparams()

        retriever = CustomVectorStoreRetriever(
            index=index, similarity_top_k=similarity_top_k
        )
        query_engine = RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=get_response_synthesizer(),
        )
        return query_engine

    def _setup_vector_store_index(
        self,
        testing: bool = False,
        **kwargs,
    ) -> VectorStoreIndex:
        """
        prepare the vector_store for querying data

        Parameters
        ------------
        testing : bool
            for testing purposes
        **kwargs :
            table_name : str
                to override the default table_name
            dbname : str
                to override the default database name

        Returns
        ---------
        index : VectorStoreIndex
            the index loaded through `PGVectorAccess.load_index`
        """
        # fall back to the values given at construction time
        table_name = kwargs.get("table_name", self.platform_name)
        dbname = kwargs.get("dbname", self.dbname)

        embed_model: BaseEmbedding
        if testing:
            # imported lazily: the mock embedding is only needed under test,
            # so CohereEmbedding is never constructed in that case
            from llama_index.core import MockEmbedding

            embed_model = MockEmbedding(embed_dim=1024)
        else:
            embed_model = CohereEmbedding()

        pg_vector = PGVectorAccess(
            table_name=table_name,
            dbname=dbname,
            testing=testing,
            embed_model=embed_model,
        )
        index = pg_vector.load_index()
        return index
27 changes: 2 additions & 25 deletions utils/query_engine/gdrive.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,7 @@
from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from utils.query_engine.base_engine import BaseEngine


class GDriveQueryEngine(BaseEngine):
    """
    Query engine for the `gdrive` platform table of a community.

    `prepare` and `_setup_vector_store_index` are inherited from `BaseEngine`.
    """

    # kept at class level so `GDriveQueryEngine.platform_name` still resolves
    platform_name = "gdrive"

    def __init__(self, community_id: str) -> None:
        """
        initialize the engine for a community's gdrive data

        Parameters
        -----------
        community_id : str
            the community whose database is queried
        """
        # delegate to the base so platform_name, community_id and dbname
        # are all set in one place
        super().__init__(self.platform_name, community_id)
27 changes: 2 additions & 25 deletions utils/query_engine/github.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,7 @@
from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from utils.query_engine.base_engine import BaseEngine


class GitHubQueryEngine(BaseEngine):
    """
    Query engine for the `github` platform table of a community.

    `prepare` and `_setup_vector_store_index` are inherited from `BaseEngine`.
    """

    # kept at class level so `GitHubQueryEngine.platform_name` still resolves
    platform_name = "github"

    def __init__(self, community_id: str) -> None:
        """
        initialize the engine for a community's github data

        Parameters
        -----------
        community_id : str
            the community whose database is queried
        """
        # delegate to the base so platform_name, community_id and dbname
        # are all set in one place
        super().__init__(self.platform_name, community_id)
22 changes: 16 additions & 6 deletions utils/query_engine/level_based_platform_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)


class LevelBasedPlatformQueryEngine(CustomQueryEngine, BaseEngine):
class LevelBasedPlatformQueryEngine(CustomQueryEngine):
retriever: BaseRetriever
response_synthesizer: BaseSynthesizer
llm: OpenAI
Expand Down Expand Up @@ -116,17 +116,22 @@ def prepare_platform_engine(
)
llm = kwargs.get("llm", OpenAI("gpt-4"))
qa_prompt_ = kwargs.get("qa_prompt", qa_prompt)
base_engine = BaseEngine(platform_table_name, community_id)
index: VectorStoreIndex = kwargs.get(
"index_raw",
cls._setup_vector_store_index(platform_table_name, dbname, testing),
base_engine._setup_vector_store_index(
testing=testing,
),
)
summary_nodes_filters = kwargs.get("summary_nodes_filters", None)

retriever = index.as_retriever()
cls._summary_vector_store = kwargs.get(
"index_summary",
cls._setup_vector_store_index(
platform_table_name + "_summary", dbname, testing
base_engine._setup_vector_store_index(
table_name=platform_table_name + "_summary",
dbname=dbname,
testing=testing,
),
)._vector_store

Expand Down Expand Up @@ -193,8 +198,13 @@ def prepare_engine_auto_filter(
dbname = f"community_{community_id}"
summary_similarity_top_k, _, d = load_hyperparams()

index_summary = cls._setup_vector_store_index(
platform_table_name + "_summary", dbname, False
base_engine = BaseEngine(
platform_table_name + "_summary",
community_id,
)

index_summary = base_engine._setup_vector_store_index(
testing=False,
)
vector_store = index_summary._vector_store

Expand Down
7 changes: 7 additions & 0 deletions utils/query_engine/notion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from utils.query_engine.base_engine import BaseEngine


class NotionQueryEngine(BaseEngine):
    """Query engine for the `notion` platform table of a community."""

    def __init__(self, community_id: str) -> None:
        """
        initialize the engine for a community's notion data

        Parameters
        -----------
        community_id : str
            the community id, forwarded to `BaseEngine.__init__`
        """
        super().__init__("notion", community_id)

0 comments on commit 219c091

Please sign in to comment.