diff --git a/subquery.py b/subquery.py index 97efe52..4b46da2 100644 --- a/subquery.py +++ b/subquery.py @@ -15,6 +15,7 @@ NotionQueryEngine, prepare_discord_engine_auto_filter, ) +from utils.qdrant_utils import QDrantUtils def query_multiple_source( @@ -67,6 +68,7 @@ def query_multiple_source( """ query_engine_tools: list[QueryEngineTool] = [] tools: list[ToolMetadata] = [] + qdrant_utils = QDrantUtils(community_id) discord_query_engine: BaseQueryEngine github_query_engine: BaseQueryEngine @@ -76,6 +78,9 @@ def query_multiple_source( mediawiki_query_engine: BaseQueryEngine # telegram_query_engine: BaseQueryEngine + # wrapper for more clarity + check_collection = qdrant_utils.chech_collection_exist + # query engine perparation # tools_metadata and query_engine_tools if discord: @@ -98,7 +103,7 @@ def query_multiple_source( if discourse: raise NotImplementedError - if gdrive: + if gdrive and check_collection("gdrive"): gdrive_query_engine = GDriveQueryEngine(community_id=community_id).prepare() tool_metadata = ToolMetadata( name="Google-Drive", @@ -113,7 +118,7 @@ def query_multiple_source( metadata=tool_metadata, ) ) - if notion: + if notion and check_collection("notion"): notion_query_engine = NotionQueryEngine(community_id=community_id).prepare() tool_metadata = ToolMetadata( name="Notion", @@ -127,9 +132,9 @@ def query_multiple_source( metadata=tool_metadata, ) ) - if telegram: + if telegram and check_collection("telegram"): raise NotImplementedError - if github: + if github and check_collection("github"): github_query_engine = GitHubQueryEngine(community_id=community_id).prepare() tool_metadata = ToolMetadata( name="GitHub", @@ -144,7 +149,7 @@ def query_multiple_source( metadata=tool_metadata, ) ) - if mediaWiki: + if mediaWiki and check_collection("mediawiki"): mediawiki_query_engine = MediaWikiQueryEngine( community_id=community_id ).prepare() diff --git a/tests/integration/test_qdrant_collection_available.py 
b/tests/integration/test_qdrant_collection_available.py
new file mode 100644
index 0000000..7a325e8
--- /dev/null
+++ b/tests/integration/test_qdrant_collection_available.py
@@ -0,0 +1,90 @@
+from unittest import TestCase
+
+from qdrant_client import models
+from tc_hivemind_backend.db.qdrant import QdrantSingleton
+from utils.qdrant_utils import QDrantUtils
+
+
+class TestQDrantAvailableCollection(TestCase):
+    """
+    integration tests for the collection existence check of `QDrantUtils`
+    (uses the qdrant client provided by `QdrantSingleton`)
+    """
+
+    def setUp(self) -> None:
+        self.community_id = "community_sample"
+        self.qdrant_client = QdrantSingleton.get_instance().get_client()
+        self.qdrant_utils = QDrantUtils(self.community_id)
+
+        # start clean, but delete only this community's collections
+        # so unrelated data on the same qdrant instance is left untouched
+        prefix = f"{self.community_id}_"
+        collections = self.qdrant_client.get_collections()
+        for col in collections.collections:
+            if col.name.startswith(prefix):
+                self.qdrant_client.delete_collection(col.name)
+
+    def _create_collections(self, platforms: list[str]) -> None:
+        """
+        create one `{community_id}_{platform}` collection per given platform
+        """
+        for plt in platforms:
+            self.qdrant_client.create_collection(
+                f"{self.community_id}_{plt}",
+                vectors_config=models.VectorParams(
+                    size=100, distance=models.Distance.COSINE
+                ),
+            )
+
+    def test_no_collection_available(self):
+        """
+        test with no collection created for the community
+        """
+        available = self.qdrant_utils.check_collection_exist("platform1")
+
+        self.assertIsInstance(available, bool)
+        self.assertFalse(available)
+
+    def test_single_collection_available(self):
+        """
+        test with the checked collection being the only one available
+        """
+        self._create_collections(["platform1"])
+        available = self.qdrant_utils.check_collection_exist("platform1")
+
+        self.assertIsInstance(available, bool)
+        self.assertTrue(available)
+
+    def test_multiple_collections_but_not_input(self):
+        """
+        test when multiple collections are available
+        but none of them is the collection we want to check for
+        """
+        self._create_collections(["platform1", "platform2", "platform3"])
+
+        available = self.qdrant_utils.check_collection_exist("platform4")
+
+        self.assertIsInstance(available, bool)
+        self.assertFalse(available)
+
+    def test_multiple_collections_available_given_input(self):
+        """
+        test when multiple collections are available
+        and the given input is one of them
+        """
+        platforms = ["platform1", "platform2", "platform3"]
+        self._create_collections(platforms)
+
+        available = self.qdrant_utils.check_collection_exist(platforms[0])
+
+        self.assertIsInstance(available, bool)
+        self.assertTrue(available)
+
+    def test_misspelled_alias_still_available(self):
+        """
+        test that the legacy misspelled method name gives the same result
+        """
+        self._create_collections(["platform1"])
+
+        self.assertTrue(self.qdrant_utils.chech_collection_exist("platform1"))
+        self.assertFalse(self.qdrant_utils.chech_collection_exist("platform2"))
diff --git a/utils/qdrant_utils.py b/utils/qdrant_utils.py
new file mode 100644
index 0000000..90f9459
--- /dev/null
+++ b/utils/qdrant_utils.py
@@ -0,0 +1,42 @@
+from tc_hivemind_backend.db.qdrant import QdrantSingleton
+
+
+class QDrantUtils:
+    """
+    helper for qdrant operations scoped to a single community
+    """
+
+    def __init__(self, community_id: str) -> None:
+        """
+        setup qdrant utils for a specific community
+
+        Parameters
+        ------------
+        community_id : str
+            the community we want to initialize the utils for
+
+        """
+        self.qdrant_client = QdrantSingleton.get_instance().get_client()
+        self.community_id = community_id
+
+    def check_collection_exist(self, platform_name: str) -> bool:
+        """
+        check if the collection exists on the qdrant database
+
+        Parameters
+        -----------
+        platform_name : str
+            the platform name we want to check for its collection availability
+            the collection checked is named `{community_id}_{platform_name}`
+
+        Returns
+        --------
+        available : bool
+            True if the collection exists, False otherwise
+        """
+        collection_name = f"{self.community_id}_{platform_name}"
+        return self.qdrant_client.collection_exists(collection_name)
+
+    # backward-compatible alias: the method was first published with a typo
+    # in its name and existing callers still use that spelling
+    chech_collection_exist = check_collection_exist