From f7d62fdf86c2a15c8f65ccbd9d27743e473a84f2 Mon Sep 17 00:00:00 2001
From: Mohammad Amin
Date: Thu, 16 May 2024 16:11:33 +0330
Subject: [PATCH] feat: Added MediaWiki query engine!

---
Review notes (this section is outside both the commit message and the diff):

* The line structure of this patch was restored from a copy that had been
  collapsed onto two lines. Hunk line counts agree with the hunk headers and
  the diffstat, but the leading whitespace of context and added lines was
  inferred from Python syntax -- TODO confirm against the target tree
  before applying.
* The "From:" header above carries a name but no e-mail address; restore the
  address if the tooling that consumes this patch requires one.
* subquery.py annotates `media_wiki_query_engine`, while the new branch
  assigns and uses `mediawiki_query_engine`; the annotation therefore does
  not describe the variable that is actually bound.
* tests/unit/test_mediawiki_query_engine.py stores the MediaWiki engine in
  names called `notion_query_engine` and prints the prepared engine's
  `__dict__` before asserting.
* The tool is registered under the name "WikiPedia", whereas the engine's
  platform name is "mediawiki".

 subquery.py                               | 17 +++++++++++++++--
 tests/unit/test_mediawiki_query_engine.py | 15 +++++++++++++++
 utils/query_engine/__init__.py            |  1 +
 utils/query_engine/media_wiki.py          |  7 +++++++
 4 files changed, 38 insertions(+), 2 deletions(-)
 create mode 100644 tests/unit/test_mediawiki_query_engine.py
 create mode 100644 utils/query_engine/media_wiki.py

diff --git a/subquery.py b/subquery.py
index b5be207..f294051 100644
--- a/subquery.py
+++ b/subquery.py
@@ -11,6 +11,7 @@
     DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
     GDriveQueryEngine,
     GitHubQueryEngine,
+    MediaWikiQueryEngine,
     NotionQueryEngine,
     prepare_discord_engine_auto_filter,
 )
@@ -72,7 +73,7 @@ def query_multiple_source(
     # discourse_query_engine: BaseQueryEngine
     gdrive_query_engine: BaseQueryEngine
     notion_query_engine: BaseQueryEngine
-    # media_wiki_query_engine: BaseQueryEngine
+    media_wiki_query_engine: BaseQueryEngine
     # telegram_query_engine: BaseQueryEngine
 
     # query engine perparation
@@ -144,7 +145,19 @@ def query_multiple_source(
             )
         )
     if media_wiki:
-        raise NotImplementedError
+        mediawiki_query_engine = MediaWikiQueryEngine(
+            community_id=community_id
+        ).prepare()
+        tool_metadata = ToolMetadata(
+            name="WikiPedia",
+            description="Hosts articles about any information on internet",
+        )
+        query_engine_tools.append(
+            QueryEngineTool(
+                query_engine=mediawiki_query_engine,
+                metadata=tool_metadata,
+            )
+        )
 
     embed_model = CohereEmbedding()
     llm = OpenAI("gpt-3.5-turbo")
diff --git a/tests/unit/test_mediawiki_query_engine.py b/tests/unit/test_mediawiki_query_engine.py
new file mode 100644
index 0000000..a909bc2
--- /dev/null
+++ b/tests/unit/test_mediawiki_query_engine.py
@@ -0,0 +1,15 @@
+from unittest import TestCase
+
+from bot.retrievers.custom_retriever import CustomVectorStoreRetriever
+from utils.query_engine import MediaWikiQueryEngine
+
+
+class TestMediaWikiQueryEngine(TestCase):
+    def setUp(self) -> None:
+        community_id = "sample_community"
+        self.notion_query_engine = MediaWikiQueryEngine(community_id)
+
+    def test_prepare_engine(self):
+        notion_query_engine = self.notion_query_engine.prepare(testing=True)
+        print(notion_query_engine.__dict__)
+        self.assertIsInstance(notion_query_engine.retriever, CustomVectorStoreRetriever)
diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py
index 0077c43..ed43657 100644
--- a/utils/query_engine/__init__.py
+++ b/utils/query_engine/__init__.py
@@ -1,6 +1,7 @@
 # flake8: noqa
 from .gdrive import GDriveQueryEngine
 from .github import GitHubQueryEngine
+from .media_wiki import MediaWikiQueryEngine
 from .notion import NotionQueryEngine
 from .prepare_discord_query_engine import prepare_discord_engine_auto_filter
 from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL
diff --git a/utils/query_engine/media_wiki.py b/utils/query_engine/media_wiki.py
new file mode 100644
index 0000000..e842e06
--- /dev/null
+++ b/utils/query_engine/media_wiki.py
@@ -0,0 +1,7 @@
+from utils.query_engine.base_engine import BaseEngine
+
+
+class MediaWikiQueryEngine(BaseEngine):
+    def __init__(self, community_id: str) -> None:
+        platform_name = "mediawiki"
+        super().__init__(platform_name, community_id)