From ac73706d91aea1f7628ac9caf8ac73c46fd5e940 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 23 Apr 2024 12:17:41 +0330 Subject: [PATCH] feat: Added automatic data source selector! --- celery_app/tasks.py | 6 +- .../integration/test_data_source_selector.py | 148 ++++++++++++++++++ utils/data_souce_selector.py | 51 ++++++ 3 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 tests/integration/test_data_source_selector.py create mode 100644 utils/data_souce_selector.py diff --git a/celery_app/tasks.py b/celery_app/tasks.py index eb8f99b..04ead35 100644 --- a/celery_app/tasks.py +++ b/celery_app/tasks.py @@ -17,6 +17,7 @@ from tc_messageBroker.rabbit_mq.payload.payload import Payload from tc_messageBroker.rabbit_mq.queue import Queue from traceloop.sdk import Traceloop +from utils.data_souce_selector import DataSourceSelector @app.task @@ -71,11 +72,12 @@ def ask_question_auto_search( # ) logging.info(f"{prefix}Querying the data sources!") # for now we have just the discord platform + selector = DataSourceSelector() + data_sources = selector.select_data_source(community_id) response, _ = query_multiple_source( query=question, community_id=community_id, - discord=True, - github=True, + **data_sources, ) # source_nodes_dict: list[dict[str, Any]] = [] diff --git a/tests/integration/test_data_source_selector.py b/tests/integration/test_data_source_selector.py new file mode 100644 index 0000000..07a7313 --- /dev/null +++ b/tests/integration/test_data_source_selector.py @@ -0,0 +1,148 @@ +from datetime import datetime +from unittest import TestCase + +from bson import ObjectId +from utils.data_souce_selector import DataSourceSelector +from utils.mongo import MongoSingleton + + +class TestDataSourceSelector(TestCase): + def setUp(self) -> None: + self.client = MongoSingleton.get_instance().get_client() + self.community_id = "6579c364f1120850414e0dc4" + self.client["Core"].drop_collection("modules") + self.client["Core"].drop_collection("platforms") + 
+ def test_no_community(self): + """ + test if no community selected hivemind modeules + """ + selector = DataSourceSelector() + data_sources = selector.select_data_source(community_id=self.community_id) + self.assertEqual(data_sources, {}) + + def test_single_platform(self): + platform_id = "6579c364f1120850414e0da1" + self.client["Core"]["platforms"].insert_one( + { + "_id": ObjectId(platform_id), + "name": "discord", + "metadata": { + "name": "TEST", + "channels": ["1234", "4321"], + "roles": ["111", "222"], + }, + "community": ObjectId(self.community_id), + "disconnectedAt": None, + "connectedAt": datetime(2023, 12, 1), + "createdAt": datetime(2023, 12, 1), + "updatedAt": datetime(2023, 12, 1), + } + ) + + self.client["Core"]["modules"].insert_one( + { + "name": "hivemind", + "communityId": ObjectId(self.community_id), + "options": { + "platforms": [ + { + "platformId": ObjectId(platform_id), + "fromDate": datetime(2024, 1, 1), + "options": {}, + } + ] + }, + } + ) + selector = DataSourceSelector() + data_sources = selector.select_data_source(community_id=self.community_id) + self.assertEqual( + data_sources, + { + "discord": True, + }, + ) + + def test_multiple_platform(self): + platform_id1 = "6579c364f1120850414e0da1" + platform_id2 = "6579c364f1120850414e0da2" + platform_id3 = "6579c364f1120850414e0da3" + self.client["Core"]["platforms"].insert_many( + [ + { + "_id": ObjectId(platform_id1), + "name": "discord", + "metadata": { + "name": "TEST", + "channels": ["1234", "4321"], + "roles": ["111", "222"], + }, + "community": ObjectId(self.community_id), + "disconnectedAt": None, + "connectedAt": datetime(2023, 12, 1), + "createdAt": datetime(2023, 12, 1), + "updatedAt": datetime(2023, 12, 1), + }, + { + "_id": ObjectId(platform_id2), + "name": "github", + "metadata": { + "organizationId": 12345, + }, + "community": ObjectId(self.community_id), + "disconnectedAt": None, + "connectedAt": datetime(2023, 12, 1), + "createdAt": datetime(2023, 12, 1), + "updatedAt": 
class DataSourceSelector:
    """Look up which platforms a community selected for the hivemind module."""

    def select_data_source(self, community_id: str) -> dict[str, bool]:
        """
        Given a community id, find all its data sources selected for
        the hivemind module

        Parameters
        -----------
        community_id : str
            id of a community

        Returns
        ----------
        data_sources : dict[str, bool]
            a dictionary representing which data sources are selected
            for the given community. Each selected platform name is a key
            mapped to `True`; the dictionary is empty when the query
            returns no rows
        """
        db_results = self._query_modules_db(community_id)

        data_sources: dict[str, bool] = {}
        for row in db_results:
            # `$project` keeps only `platform.name`, so a platform document
            # stored without a name comes back as an empty sub-document.
            # Such rows are skipped instead of raising a KeyError.
            name = (row.get("platform") or {}).get("name")
            if name:
                data_sources[name] = True
        return data_sources

    def _query_modules_db(self, community_id: str) -> list[dict]:
        """
        Fetch the platform names referenced by a community's hivemind module

        The aggregation runs on the `modules` collection of the `Core`
        database: it matches the `hivemind` module of the community, unwinds
        `options.platforms`, joins each `platformId` against the `platforms`
        collection and projects only the platform name.

        Parameters
        -----------
        community_id : str
            id of a community; it is converted to a bson `ObjectId`
            (NOTE(review): a string that is not a valid ObjectId presumably
            makes that conversion raise - confirm against callers)

        Returns
        ----------
        data_sources : list[dict]
            the aggregation results, one dictionary per joined platform,
            shaped like `{"platform": {"name": ...}}`
        """
        client = MongoSingleton.get_instance().get_client()

        pipeline = [
            {"$match": {"name": "hivemind", "communityId": ObjectId(community_id)}},
            {"$unwind": "$options.platforms"},
            {
                "$lookup": {
                    "from": "platforms",
                    "localField": "options.platforms.platformId",
                    "foreignField": "_id",
                    "as": "platform",
                }
            },
            # `$lookup` yields an array; unwinding it gives one row per
            # matched platform document.
            {"$unwind": "$platform"},
            {
                "$project": {
                    "_id": 0,
                    "platform.name": 1,
                }
            },
        ]
        cursor = client["Core"]["modules"].aggregate(pipeline)
        return list(cursor)