From f9dff40702d1e85ee6bd916846222c5b95cdee1a Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Wed, 27 Nov 2024 10:26:52 +0330 Subject: [PATCH] feat: Added website data source! --- subquery.py | 18 ++++++++++++++++++ tests/unit/test_website_query_engine.py | 17 +++++++++++++++++ utils/query_engine/__init__.py | 1 + utils/query_engine/website.py | 7 +++++++ 4 files changed, 43 insertions(+) create mode 100644 tests/unit/test_website_query_engine.py create mode 100644 utils/query_engine/website.py diff --git a/subquery.py b/subquery.py index 1dacf83..d28a415 100644 --- a/subquery.py +++ b/subquery.py @@ -15,6 +15,7 @@ NotionQueryEngine, TelegramDualQueryEngine, TelegramQueryEngine, + WebsiteQueryEngine, prepare_discord_engine_auto_filter, ) @@ -29,6 +30,7 @@ def query_multiple_source( telegram: bool = False, github: bool = False, mediaWiki: bool = False, + website: bool = False, ) -> tuple[str, list[NodeWithScore]]: """ query multiple platforms and get an answer from the multiple @@ -180,6 +182,22 @@ def query_multiple_source( ) ) + if website and check_collection("website"): + website_query_engine = WebsiteQueryEngine(community_id=community_id).prepare() + tool_metadata = ToolMetadata( + name="Website", + description=( + "Hosts a diverse collection of crawled data from various " + "online sources to facilitate community insights and analysis." 
+ ), + ) + query_engine_tools.append( + QueryEngineTool( + query_engine=website_query_engine, + metadata=tool_metadata, + ) + ) + embed_model = CohereEmbedding() llm = OpenAI("gpt-3.5-turbo") Settings.embed_model = embed_model diff --git a/tests/unit/test_website_query_engine.py b/tests/unit/test_website_query_engine.py new file mode 100644 index 0000000..056d76d --- /dev/null +++ b/tests/unit/test_website_query_engine.py @@ -0,0 +1,17 @@ +from unittest import TestCase + +from llama_index.core.indices.vector_store.retrievers.retriever import ( + VectorIndexRetriever, +) +from utils.query_engine import WebsiteQueryEngine + + +class TestWebsiteQueryEngine(TestCase): + def setUp(self) -> None: + community_id = "sample_community" + self.website_query_engine = WebsiteQueryEngine(community_id) + + def test_prepare_engine(self): + website_query_engine = self.website_query_engine.prepare(testing=True) + print(website_query_engine.__dict__) + self.assertIsInstance(website_query_engine.retriever, VectorIndexRetriever) diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py index 71642ec..607e46c 100644 --- a/utils/query_engine/__init__.py +++ b/utils/query_engine/__init__.py @@ -7,3 +7,4 @@ from .prepare_discord_query_engine import prepare_discord_engine_auto_filter from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL from .telegram import TelegramDualQueryEngine, TelegramQueryEngine +from .website import WebsiteQueryEngine diff --git a/utils/query_engine/website.py b/utils/query_engine/website.py new file mode 100644 index 0000000..6097cfa --- /dev/null +++ b/utils/query_engine/website.py @@ -0,0 +1,7 @@ +from utils.query_engine.base_qdrant_engine import BaseQdrantEngine + + +class WebsiteQueryEngine(BaseQdrantEngine): + def __init__(self, community_id: str) -> None: + platform_name = "website" + super().__init__(platform_name, community_id)