Skip to content

Commit

Permalink
feat: Added automatic data source selector!
Browse files Browse the repository at this point in the history
  • Loading branch information
amindadgar committed Apr 23, 2024
1 parent 6077677 commit ac73706
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 2 deletions.
6 changes: 4 additions & 2 deletions celery_app/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tc_messageBroker.rabbit_mq.payload.payload import Payload
from tc_messageBroker.rabbit_mq.queue import Queue
from traceloop.sdk import Traceloop
from utils.data_souce_selector import DataSourceSelector


@app.task
Expand Down Expand Up @@ -71,11 +72,12 @@ def ask_question_auto_search(
# )
logging.info(f"{prefix}Querying the data sources!")
# select the data sources enabled for this community's hivemind module
selector = DataSourceSelector()
data_sources = selector.select_data_source(community_id)
response, _ = query_multiple_source(
query=question,
community_id=community_id,
discord=True,
github=True,
**data_sources,
)

# source_nodes_dict: list[dict[str, Any]] = []
Expand Down
148 changes: 148 additions & 0 deletions tests/integration/test_data_source_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from datetime import datetime
from unittest import TestCase

from bson import ObjectId
from utils.data_souce_selector import DataSourceSelector
from utils.mongo import MongoSingleton


class TestDataSourceSelector(TestCase):
    """Integration tests for ``DataSourceSelector`` run against the ``Core`` database."""

    def setUp(self) -> None:
        """Start every test with no ``modules`` and ``platforms`` collections."""
        self.client = MongoSingleton.get_instance().get_client()
        self.community_id = "6579c364f1120850414e0dc4"
        for collection_name in ("modules", "platforms"):
            self.client["Core"].drop_collection(collection_name)

    def _platform_document(self, platform_id: str, name: str, metadata: dict) -> dict:
        """Build a platform document owned by the test community."""
        timestamp = datetime(2023, 12, 1)
        return {
            "_id": ObjectId(platform_id),
            "name": name,
            "metadata": metadata,
            "community": ObjectId(self.community_id),
            "disconnectedAt": None,
            "connectedAt": timestamp,
            "createdAt": timestamp,
            "updatedAt": timestamp,
        }

    def _hivemind_module(self, platform_ids: list) -> dict:
        """Build a hivemind module document referencing the given platform ids."""
        platform_entries = [
            {
                "platformId": ObjectId(platform_id),
                "fromDate": datetime(2024, 1, 1),
                "options": {},
            }
            for platform_id in platform_ids
        ]
        return {
            "name": "hivemind",
            "communityId": ObjectId(self.community_id),
            "options": {"platforms": platform_entries},
        }

    def test_no_community(self):
        """
        With no hivemind module stored for the community,
        the selector should return an empty dictionary
        """
        result = DataSourceSelector().select_data_source(
            community_id=self.community_id
        )
        self.assertEqual(result, {})

    def test_single_platform(self):
        """
        A hivemind module referencing one discord platform
        should select only discord
        """
        platform_id = "6579c364f1120850414e0da1"
        discord_metadata = {
            "name": "TEST",
            "channels": ["1234", "4321"],
            "roles": ["111", "222"],
        }
        self.client["Core"]["platforms"].insert_one(
            self._platform_document(platform_id, "discord", discord_metadata)
        )
        self.client["Core"]["modules"].insert_one(
            self._hivemind_module([platform_id])
        )

        result = DataSourceSelector().select_data_source(
            community_id=self.community_id
        )
        expected = {"discord": True}
        self.assertEqual(result, expected)

    def test_multiple_platform(self):
        """
        A hivemind module referencing discord, github and discourse
        platforms should select all three
        """
        platform_id1 = "6579c364f1120850414e0da1"
        platform_id2 = "6579c364f1120850414e0da2"
        platform_id3 = "6579c364f1120850414e0da3"

        platform_documents = [
            self._platform_document(
                platform_id1,
                "discord",
                {
                    "name": "TEST",
                    "channels": ["1234", "4321"],
                    "roles": ["111", "222"],
                },
            ),
            self._platform_document(
                platform_id2,
                "github",
                {"organizationId": 12345},
            ),
            self._platform_document(
                platform_id3,
                "discourse",
                {"some_id": 133445},
            ),
        ]
        self.client["Core"]["platforms"].insert_many(platform_documents)
        self.client["Core"]["modules"].insert_one(
            self._hivemind_module([platform_id1, platform_id2, platform_id3])
        )

        result = DataSourceSelector().select_data_source(
            community_id=self.community_id
        )
        expected = {
            "discord": True,
            "github": True,
            "discourse": True,
        }
        self.assertEqual(result, expected)
51 changes: 51 additions & 0 deletions utils/data_souce_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from .mongo import MongoSingleton
from bson import ObjectId


class DataSourceSelector:
    """Look up which platforms a community has attached to its hivemind module."""

    def select_data_source(self, community_id: str) -> dict[str, bool]:
        """
        Build the data source flags for a community's hivemind module

        Parameters
        -----------
        community_id : str
            id of the community to look up

        Returns
        ----------
        data_sources : dict[str, bool]
            one `True` entry per platform name returned by the modules query;
            an empty dictionary when the query returns nothing
        """
        data_sources: dict[str, bool] = {}
        for row in self._query_modules_db(community_id):
            # repeated platform names collapse into a single key
            data_sources[row["platform"]["name"]] = True
        return data_sources

    def _query_modules_db(self, community_id: str) -> list[dict]:
        """
        Fetch the platform names referenced by the community's hivemind module

        Parameters
        -----------
        community_id : str
            id of the community, converted to an `ObjectId` for matching

        Returns
        ----------
        data_sources : list[dict]
            the aggregation output, each item shaped
            like `{"platform": {"name": ...}}`
        """
        client = MongoSingleton.get_instance().get_client()

        # only the hivemind module document(s) of the requested community
        match_stage = {
            "$match": {"name": "hivemind", "communityId": ObjectId(community_id)}
        }
        # one document per configured platform entry
        unwind_options = {"$unwind": "$options.platforms"}
        # join every entry with its document in the `platforms` collection
        join_platform = {
            "$lookup": {
                "from": "platforms",
                "localField": "options.platforms.platformId",
                "foreignField": "_id",
                "as": "platform",
            }
        }
        # flatten the joined array into a single embedded document
        unwind_platform = {"$unwind": "$platform"}
        # keep nothing but the platform name
        keep_name = {
            "$project": {
                "_id": 0,
                "platform.name": 1,
            }
        }

        modules = client["Core"]["modules"]
        return list(
            modules.aggregate(
                [
                    match_stage,
                    unwind_options,
                    join_platform,
                    unwind_platform,
                    keep_name,
                ]
            )
        )

0 comments on commit ac73706

Please sign in to comment.