diff --git a/subquery.py b/subquery.py
index 92d53c0..b5be207 100644
--- a/subquery.py
+++ b/subquery.py
@@ -11,6 +11,7 @@
     DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
     GDriveQueryEngine,
     GitHubQueryEngine,
+    NotionQueryEngine,
     prepare_discord_engine_auto_filter,
 )
 
@@ -24,6 +25,7 @@ def query_multiple_source(
     notion: bool = False,
     telegram: bool = False,
     github: bool = False,
+    media_wiki: bool = False,
 ) -> tuple[str, list[NodeWithScore]]:
     """
     query multiple platforms and get an answer from the multiple
@@ -68,8 +70,9 @@ def query_multiple_source(
     discord_query_engine: BaseQueryEngine
     github_query_engine: BaseQueryEngine
     # discourse_query_engine: BaseQueryEngine
-    # gdrive_query_engine: BaseQueryEngine
-    # notion_query_engine: BaseQueryEngine
+    gdrive_query_engine: BaseQueryEngine
+    notion_query_engine: BaseQueryEngine
+    # media_wiki_query_engine: BaseQueryEngine
     # telegram_query_engine: BaseQueryEngine
 
     # query engine perparation
@@ -110,7 +113,19 @@ def query_multiple_source(
             )
         )
     if notion:
-        raise NotImplementedError
+        notion_query_engine = NotionQueryEngine(community_id=community_id)
+        tool_metadata = ToolMetadata(
+            name="Notion",
+            description=(
+                "Centralizes notes, wikis, project plans, and to-dos for the community."
+            ),
+        )
+        query_engine_tools.append(
+            QueryEngineTool(
+                query_engine=notion_query_engine,
+                metadata=tool_metadata,
+            )
+        )
     if telegram:
         raise NotImplementedError
     if github:
@@ -128,6 +143,8 @@ def query_multiple_source(
                 metadata=tool_metadata,
             )
         )
+    if media_wiki:
+        raise NotImplementedError
 
     embed_model = CohereEmbedding()
     llm = OpenAI("gpt-3.5-turbo")
diff --git a/tests/unit/test_base_engine.py b/tests/unit/test_base_engine.py
index 09820a6..43ef52b 100644
--- a/tests/unit/test_base_engine.py
+++ b/tests/unit/test_base_engine.py
@@ -10,12 +10,15 @@ def test_setup_vector_store_index(self):
         and calls its load_index method.
         """
         platform_table_name = "test_table"
-        dbname = "test_db"
-
-        base_engine = BaseEngine._setup_vector_store_index(
-            platform_table_name=platform_table_name,
-            dbname=dbname,
+        community_id = "123456"
+        base_engine = BaseEngine(
+            platform_name=platform_table_name,
+            community_id=community_id,
+        )
+        base_engine = base_engine._setup_vector_store_index(
             testing=True,
         )
-        self.assertIn(dbname, base_engine.vector_store.connection_string)
+
+        expected_dbname = f"community_{community_id}"
+        self.assertIn(expected_dbname, base_engine.vector_store.connection_string)
         self.assertEqual(base_engine.vector_store.table_name, platform_table_name)
diff --git a/tests/unit/test_notion_query_engine.py b/tests/unit/test_notion_query_engine.py
new file mode 100644
index 0000000..caaf903
--- /dev/null
+++ b/tests/unit/test_notion_query_engine.py
@@ -0,0 +1,15 @@
+from unittest import TestCase
+
+from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
+from utils.query_engine import NotionQueryEngine
+
+
+class TestNotionQueryEngine(TestCase):
+    def setUp(self) -> None:
+        community_id = "sample_community"
+        self.notion_query_engine = NotionQueryEngine(community_id)
+
+    def test_prepare_engine(self):
+        notion_query_engine = self.notion_query_engine.prepare(testing=True)
+        print(notion_query_engine.__dict__)
+        self.assertIsInstance(notion_query_engine.retriever, CustomVectorStoreRetriever)
diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py
index bfd2879..0077c43 100644
--- a/utils/query_engine/__init__.py
+++ b/utils/query_engine/__init__.py
@@ -1,5 +1,6 @@
 # flake8: noqa
 from .gdrive import GDriveQueryEngine
 from .github import GitHubQueryEngine
+from .notion import NotionQueryEngine
 from .prepare_discord_query_engine import prepare_discord_engine_auto_filter
 from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL
diff --git a/utils/query_engine/base_engine.py b/utils/query_engine/base_engine.py
index 72bda45..7f05779 100644
--- a/utils/query_engine/base_engine.py
+++ b/utils/query_engine/base_engine.py
@@ -1,24 +1,80 @@
-from llama_index.core import VectorStoreIndex
+from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
+from bot.retrievers.utils.load_hyperparams import load_hyperparams
+from llama_index.core import VectorStoreIndex, get_response_synthesizer
+from llama_index.core.base.embeddings.base import BaseEmbedding
+from llama_index.core.query_engine import RetrieverQueryEngine
 from tc_hivemind_backend.embeddings import CohereEmbedding
 from tc_hivemind_backend.pg_vector_access import PGVectorAccess
 
 
 class BaseEngine:
-    @classmethod
+    def __init__(self, platform_name: str, community_id: str) -> None:
+        """
+        initialize the engine to query the database related to a community
+        and the table related to the platform
+
+        Parameters
+        -----------
+        platform_name : str
+            the table representative of data for a specific platform
+        community_id : str
+            the database for a community
+            normally the community database is saved as
+            `community_{community_id}`
+        """
+        self.platform_name = platform_name
+        self.community_id = community_id
+        self.dbname = f"community_{self.community_id}"
+
+    def prepare(self, testing=False):
+        index = self._setup_vector_store_index(
+            testing=testing,
+        )
+        _, similarity_top_k, _ = load_hyperparams()
+
+        retriever = CustomVectorStoreRetriever(
+            index=index, similarity_top_k=similarity_top_k
+        )
+        query_engine = RetrieverQueryEngine(
+            retriever=retriever,
+            response_synthesizer=get_response_synthesizer(),
+        )
+        return query_engine
+
     def _setup_vector_store_index(
-        cls,
-        platform_table_name: str,
-        dbname: str,
+        self,
         testing: bool = False,
+        **kwargs,
     ) -> VectorStoreIndex:
         """
         prepare the vector_store for querying data
+
+        Parameters
+        ------------
+        testing : bool
+            for testing purposes
+        **kwargs :
+            table_name : str
+                to override the default table_name
+            dbname : str
+                to override the default database name
         """
+        table_name = kwargs.get("table_name", self.platform_name)
+        dbname = kwargs.get("dbname", self.dbname)
+
+        embed_model: BaseEmbedding
+        if testing:
+            from llama_index.core import MockEmbedding
+
+            embed_model = MockEmbedding(embed_dim=1024)
+        else:
+            embed_model = CohereEmbedding()
+
         pg_vector = PGVectorAccess(
-            table_name=platform_table_name,
+            table_name=table_name,
             dbname=dbname,
             testing=testing,
-            embed_model=CohereEmbedding(),
+            embed_model=embed_model,
         )
         index = pg_vector.load_index()
         return index
diff --git a/utils/query_engine/gdrive.py b/utils/query_engine/gdrive.py
index 5a775e2..7be4bc8 100644
--- a/utils/query_engine/gdrive.py
+++ b/utils/query_engine/gdrive.py
@@ -1,30 +1,7 @@
-from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
-from bot.retrievers.utils.load_hyperparams import load_hyperparams
-from llama_index.core import get_response_synthesizer
-from llama_index.core.query_engine import RetrieverQueryEngine
 from utils.query_engine.base_engine import BaseEngine
 
 
 class GDriveQueryEngine(BaseEngine):
-    platform_name = "gdrive"
-
     def __init__(self, community_id: str) -> None:
-        dbname = f"community_{community_id}"
-        self.dbname = dbname
-
-    def prepare(self, testing=False):
-        index = self._setup_vector_store_index(
-            platform_table_name=self.platform_name,
-            dbname=self.dbname,
-            testing=testing,
-        )
-        _, similarity_top_k, _ = load_hyperparams()
-
-        retriever = CustomVectorStoreRetriever(
-            index=index, similarity_top_k=similarity_top_k
-        )
-        query_engine = RetrieverQueryEngine(
-            retriever=retriever,
-            response_synthesizer=get_response_synthesizer(),
-        )
-        return query_engine
+        platform_name = "gdrive"
+        super().__init__(platform_name, community_id)
diff --git a/utils/query_engine/github.py b/utils/query_engine/github.py
index b5adf61..ad91ead 100644
--- a/utils/query_engine/github.py
+++ b/utils/query_engine/github.py
@@ -1,30 +1,7 @@
-from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
-from bot.retrievers.utils.load_hyperparams import load_hyperparams
-from llama_index.core import get_response_synthesizer
-from llama_index.core.query_engine import RetrieverQueryEngine
 from utils.query_engine.base_engine import BaseEngine
 
 
 class GitHubQueryEngine(BaseEngine):
-    platform_name = "github"
-
     def __init__(self, community_id: str) -> None:
-        dbname = f"community_{community_id}"
-        self.dbname = dbname
-
-    def prepare(self, testing=False):
-        index = self._setup_vector_store_index(
-            platform_table_name=self.platform_name,
-            dbname=self.dbname,
-            testing=testing,
-        )
-        _, similarity_top_k, _ = load_hyperparams()
-
-        retriever = CustomVectorStoreRetriever(
-            index=index, similarity_top_k=similarity_top_k
-        )
-        query_engine = RetrieverQueryEngine(
-            retriever=retriever,
-            response_synthesizer=get_response_synthesizer(),
-        )
-        return query_engine
+        platform_name = "github"
+        super().__init__(platform_name, community_id)
diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py
index 6647876..6c0edf7 100644
--- a/utils/query_engine/level_based_platform_query_engine.py
+++ b/utils/query_engine/level_based_platform_query_engine.py
@@ -28,7 +28,7 @@
 )
 
 
-class LevelBasedPlatformQueryEngine(CustomQueryEngine, BaseEngine):
+class LevelBasedPlatformQueryEngine(CustomQueryEngine):
     retriever: BaseRetriever
     response_synthesizer: BaseSynthesizer
     llm: OpenAI
@@ -116,17 +116,22 @@ def prepare_platform_engine(
         )
         llm = kwargs.get("llm", OpenAI("gpt-4"))
         qa_prompt_ = kwargs.get("qa_prompt", qa_prompt)
+        base_engine = BaseEngine(platform_table_name, community_id)
         index: VectorStoreIndex = kwargs.get(
             "index_raw",
-            cls._setup_vector_store_index(platform_table_name, dbname, testing),
+            base_engine._setup_vector_store_index(
+                testing=testing,
+            ),
         )
         summary_nodes_filters = kwargs.get("summary_nodes_filters", None)
 
         retriever = index.as_retriever()
         cls._summary_vector_store = kwargs.get(
             "index_summary",
-            cls._setup_vector_store_index(
-                platform_table_name + "_summary", dbname, testing
+            base_engine._setup_vector_store_index(
+                table_name=platform_table_name + "_summary",
+                dbname=dbname,
+                testing=testing,
             ),
         )._vector_store
 
@@ -193,8 +198,13 @@ def prepare_engine_auto_filter(
         dbname = f"community_{community_id}"
         summary_similarity_top_k, _, d = load_hyperparams()
 
-        index_summary = cls._setup_vector_store_index(
-            platform_table_name + "_summary", dbname, False
+        base_engine = BaseEngine(
+            platform_table_name + "_summary",
+            community_id,
+        )
+
+        index_summary = base_engine._setup_vector_store_index(
+            testing=False,
         )
         vector_store = index_summary._vector_store
 
diff --git a/utils/query_engine/notion.py b/utils/query_engine/notion.py
new file mode 100644
index 0000000..c32aebd
--- /dev/null
+++ b/utils/query_engine/notion.py
@@ -0,0 +1,7 @@
+from utils.query_engine.base_engine import BaseEngine
+
+
+class NotionQueryEngine(BaseEngine):
+    def __init__(self, community_id: str) -> None:
+        platform_name = "notion"
+        super().__init__(platform_name, community_id)
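
Not part of the diff above: a minimal, hypothetical sketch of how the new `media_wiki` flag (left as `NotImplementedError` in this change) could later be wired up by reusing the same `BaseEngine` subclass pattern applied here to GDrive, GitHub, and Notion. The `MediaWikiQueryEngine` class name and the `"mediawiki"` platform/table name are assumptions, not code from this PR.

```python
from utils.query_engine.base_engine import BaseEngine


class MediaWikiQueryEngine(BaseEngine):
    def __init__(self, community_id: str) -> None:
        # Assumed table name; this PR only reserves the `media_wiki` flag.
        platform_name = "mediawiki"
        super().__init__(platform_name, community_id)


# Usage mirrors NotionQueryEngine: `prepare()` builds a RetrieverQueryEngine over
# the `community_{community_id}` database (MockEmbedding is used when testing=True).
engine = MediaWikiQueryEngine(community_id="1234").prepare(testing=True)
```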