Skip to content

Commit

Permalink
feat: Added discord raw data channel filtering!
Browse files Browse the repository at this point in the history
- Adjusted the code to support channel filtering based on the `Module` database.
- Note that `Module` is the database name and `modules` is the collection (table) name.
  • Loading branch information
amindadgar committed Jan 11, 2024
1 parent 4eea2bf commit f5805b9
Show file tree
Hide file tree
Showing 4 changed files with 357 additions and 31 deletions.
66 changes: 48 additions & 18 deletions dags/hivemind_etl_helpers/src/db/discord/fetch_raw_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def fetch_raw_messages(guild_id: str, from_date: datetime | None = None) -> list
"""
client = MongoSingleton.get_instance().get_client()

channels = fetch_channels(guild_id=guild_id)

raw_messages: list[dict]
if from_date is not None:
cursor = (
Expand All @@ -30,6 +32,7 @@ def fetch_raw_messages(guild_id: str, from_date: datetime | None = None) -> list
{
"createdDate": {"$gte": from_date},
"isGeneratedByWebhook": False,
"channelId": {"$in": channels},
}
)
.sort("createdDate", 1)
Expand All @@ -38,7 +41,12 @@ def fetch_raw_messages(guild_id: str, from_date: datetime | None = None) -> list
else:
cursor = (
client[guild_id]["rawinfos"]
.find({"isGeneratedByWebhook": False})
.find(
{
"isGeneratedByWebhook": False,
"channelId": {"$in": channels},
}
)
.sort("createdDate", 1)
)
raw_messages = list(cursor)
Expand Down Expand Up @@ -80,34 +88,29 @@ def fetch_raw_msg_grouped(
"""
client = MongoSingleton.get_instance().client

channels = fetch_channels(guild_id)

# the pipeline to apply through mongodb
pipeline: list[dict] = []

if from_date is not None:
pipeline.append(
{
"$match": {
"$and": [
{
"createdDate": {
"$gte": from_date,
"$lt": datetime.now().replace(
hour=0, minute=0, second=0, microsecond=0
),
}
},
{"isGeneratedByWebhook": False},
]
"createdDate": {
"$gte": from_date,
"$lt": datetime.now().replace(
hour=0, minute=0, second=0, microsecond=0
),
},
"isGeneratedByWebhook": False,
"channelId": {"$in": channels},
}
},
)
else:
pipeline.append(
{
"$match": {
"isGeneratedByWebhook": False,
}
},
{"$match": {"isGeneratedByWebhook": False, "channelId": {"$in": channels}}},
)

# sorting
Expand All @@ -134,6 +137,7 @@ def fetch_raw_msg_grouped(

return raw_messages_grouped


def fetch_channels(guild_id: str):
    """
    look up the channels that the community's module settings select for processing

    Parameters
    ------------
    guild_id : str
        the discord guild id, matched against `metadata.id` of the platform

    Returns
    ---------
    channels : list[str]
        the channels to fetch data from

    Raises
    --------
    ValueError
        if no discord platform exists for the given guild id,
        or if no module document references that platform
    """
    mongo = MongoSingleton.get_instance().client

    # only the platform id and its community are needed for the module lookup
    platform_doc = mongo["Core"]["platforms"].find_one(
        {"name": "discord", "metadata.id": guild_id},
        {
            "_id": 1,
            "community": 1,
        },
    )
    if platform_doc is None:
        raise ValueError(f"No platform with given guild_id: {guild_id} available!")

    module_filter = {
        "communityId": platform_doc["community"],
        "options.platforms.platformId": platform_doc["_id"],
    }
    # the positional `$` projection is expected to return just the
    # `options.platforms` entry matched by the filter (hence the `[0]` below)
    module_doc = mongo["Module"]["modules"].find_one(
        module_filter,
        {"_id": 0, "options.platforms.$": 1},
    )
    if module_doc is None:
        raise ValueError("No modules set for this community!")

    matched_platform = module_doc["options"]["platforms"][0]
    return matched_platform["options"]["channels"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from unittest import TestCase
from bson import ObjectId

from datetime import datetime, timedelta
from hivemind_etl_helpers.src.utils.mongo import MongoSingleton
from hivemind_etl_helpers.src.db.discord.fetch_raw_messages import fetch_channels


class TestDiscordFetchModulesChannels(TestCase):
    """
    Tests for `fetch_channels` against the MongoDB instance
    provided by `MongoSingleton`.
    """

    def setup_db(
        self,
        channels: list[str],
        create_modules: bool = True,
        create_platform: bool = True,
        guild_id: str = "1234",
    ):
        """
        reset the `Module.modules` and `Core.platforms` collections
        and optionally insert one module and one platform document

        Parameters
        ------------
        channels : list[str]
            the channel ids stored in the module options
            and in the platform's `selectedChannels`
        create_modules : bool
            whether to insert the module document
        create_platform : bool
            whether to insert the platform document
        guild_id : str
            the guild id saved under the platform's `metadata.id`
        """
        client = MongoSingleton.get_instance().client

        community_id = ObjectId("9f59dd4f38f3474accdc8f24")
        platform_id = ObjectId("063a2a74282db2c00fbc2428")

        # every test starts from empty collections
        client["Module"].drop_collection("modules")
        client["Core"].drop_collection("platforms")

        if create_modules:
            data = {
                "name": "hivemind",
                "communityId": community_id,
                "options": {
                    "platforms": [
                        {
                            "platformId": platform_id,
                            "options": {
                                "channels": channels,
                                "roles": ["role_id"],
                                "users": ["user_id"],
                            },
                        }
                    ]
                },
            }
            client["Module"]["modules"].insert_one(data)

        if create_platform:
            client["Core"]["platforms"].insert_one(
                {
                    "_id": platform_id,
                    "name": "discord",
                    "metadata": {
                        "action": {
                            "INT_THR": 1,
                            "UW_DEG_THR": 1,
                            "PAUSED_T_THR": 1,
                            "CON_T_THR": 4,
                            "CON_O_THR": 3,
                            "EDGE_STR_THR": 5,
                            "UW_THR_DEG_THR": 5,
                            "VITAL_T_THR": 4,
                            "VITAL_O_THR": 3,
                            "STILL_T_THR": 2,
                            "STILL_O_THR": 2,
                            "DROP_H_THR": 2,
                            "DROP_I_THR": 1,
                        },
                        "window": {"period_size": 7, "step_size": 1},
                        "id": guild_id,
                        "isInProgress": False,
                        "period": datetime.now() - timedelta(days=35),
                        "icon": "some_icon_hash",
                        "selectedChannels": channels,
                        "name": "GuildName",
                    },
                    "community": community_id,
                    "disconnectedAt": None,
                    "connectedAt": datetime.now(),
                    "createdAt": datetime.now(),
                    "updatedAt": datetime.now(),
                }
            )

    def test_fetch_channels(self):
        """the channels saved in the module options should be returned"""
        guild_id = "1234"
        channels = ["111111", "22222"]
        self.setup_db(
            create_modules=True,
            create_platform=True,
            guild_id=guild_id,
            channels=channels,
        )
        # keep the result in its own name; reusing `channels` would make
        # the assertion compare the fetched value with itself
        fetched_channels = fetch_channels(guild_id=guild_id)

        self.assertEqual(fetched_channels, channels)

    def test_fetch_channels_no_modules_available(self):
        """a platform without any module document should raise ValueError"""
        guild_id = "1234"
        channels = ["111111", "22222"]
        self.setup_db(
            create_modules=False,
            create_platform=True,
            guild_id=guild_id,
            channels=channels,
        )
        with self.assertRaises(ValueError):
            _ = fetch_channels(guild_id=guild_id)

    def test_fetch_channels_no_platform_available(self):
        """a guild id without a platform document should raise ValueError"""
        guild_id = "1234"
        channels = ["111111", "22222"]
        self.setup_db(
            create_modules=True,
            create_platform=False,
            guild_id=guild_id,
            channels=channels,
        )

        with self.assertRaises(ValueError):
            _ = fetch_channels(guild_id=guild_id)
Original file line number Diff line number Diff line change
@@ -1,24 +1,100 @@
import unittest
from datetime import datetime
from datetime import datetime, timedelta
from bson import ObjectId

import numpy as np
from hivemind_etl_helpers.src.db.discord.fetch_raw_messages import fetch_raw_messages
from hivemind_etl_helpers.src.utils.mongo import MongoSingleton


class TestFetchRawMessages(unittest.TestCase):
def test_fetch_raw_messages_fetch_all(self):
def setup_db(
    self,
    channels: list[str],
    create_modules: bool = True,
    create_platform: bool = True,
    guild_id: str = "1234",
):
    """
    reset the `Module.modules` and `Core.platforms` collections
    and optionally insert one module and one platform document

    Parameters
    ------------
    channels : list[str]
        the channel ids stored in the module options
        and in the platform's `selectedChannels`
    create_modules : bool
        whether to insert the module document
    create_platform : bool
        whether to insert the platform document
    guild_id : str
        the guild id saved under the platform's `metadata.id`
    """
    mongo = MongoSingleton.get_instance().client

    community_oid = ObjectId("9f59dd4f38f3474accdc8f24")
    platform_oid = ObjectId("063a2a74282db2c00fbc2428")

    # wipe whatever a previous test left behind
    mongo["Module"].drop_collection("modules")
    mongo["Core"].drop_collection("platforms")

    if create_modules:
        platform_options = {
            "channels": channels,
            "roles": ["role_id"],
            "users": ["user_id"],
        }
        module_doc = {
            "name": "hivemind",
            "communityId": community_oid,
            "options": {
                "platforms": [
                    {
                        "platformId": platform_oid,
                        "options": platform_options,
                    }
                ]
            },
        }
        mongo["Module"]["modules"].insert_one(module_doc)

    if create_platform:
        action_thresholds = {
            "INT_THR": 1,
            "UW_DEG_THR": 1,
            "PAUSED_T_THR": 1,
            "CON_T_THR": 4,
            "CON_O_THR": 3,
            "EDGE_STR_THR": 5,
            "UW_THR_DEG_THR": 5,
            "VITAL_T_THR": 4,
            "VITAL_O_THR": 3,
            "STILL_T_THR": 2,
            "STILL_O_THR": 2,
            "DROP_H_THR": 2,
            "DROP_I_THR": 1,
        }
        platform_metadata = {
            "action": action_thresholds,
            "window": {"period_size": 7, "step_size": 1},
            "id": guild_id,
            "isInProgress": False,
            "period": datetime.now() - timedelta(days=35),
            "icon": "some_icon_hash",
            "selectedChannels": channels,
            "name": "GuildName",
        }
        platform_doc = {
            "_id": platform_oid,
            "name": "discord",
            "metadata": platform_metadata,
            "community": community_oid,
            "disconnectedAt": None,
            "connectedAt": datetime.now(),
            "createdAt": datetime.now(),
            "updatedAt": datetime.now(),
        }
        mongo["Core"]["platforms"].insert_one(platform_doc)

def test_fetch_raw_messages_fetch_all(self):
client = MongoSingleton.get_instance().client
channels = ["111111", "22222"]
guild_id = "1234"
self.setup_db(
channels=channels,
guild_id=guild_id,
)

# dropping any previous data
client[guild_id].drop_collection("rawinfos")

message_count = 2

raw_data = []
for _ in range(message_count):
for i in range(message_count):
data = {
"type": 0,
"author": str(np.random.randint(100000, 999999)),
Expand All @@ -29,7 +105,7 @@ def test_fetch_raw_messages_fetch_all(self):
"replied_user": None,
"createdDate": datetime.now(),
"messageId": str(np.random.randint(1000000, 9999999)),
"channelId": str(np.random.randint(10000000, 99999999)),
"channelId": channels[i % len(channels)],
"channelName": "general",
"threadId": None,
"threadName": None,
Expand Down Expand Up @@ -59,6 +135,13 @@ def test_fetch_raw_messages_fetch_all_no_data_available(self):
client = MongoSingleton.get_instance().client

guild_id = "1234"

channels = ["111111", "22222"]
guild_id = "1234"
self.setup_db(
channels=channels,
guild_id=guild_id,
)
# dropping any previous data
client[guild_id].drop_collection("rawinfos")

Expand All @@ -71,6 +154,12 @@ def test_fetch_raw_messages_fetch_from_date(self):
client = MongoSingleton.get_instance().client

guild_id = "1234"
channels = ["111111", "22222"]
guild_id = "1234"
self.setup_db(
channels=channels,
guild_id=guild_id,
)

# Dropping any previous data
client[guild_id].drop_collection("rawinfos")
Expand All @@ -90,8 +179,8 @@ def test_fetch_raw_messages_fetch_from_date(self):
2023, 10, i + 1
), # Different dates in October 2023
"messageId": str(np.random.randint(1000000, 9999999)),
"channelId": str(np.random.randint(10000000, 99999999)),
"channelName": "general",
"channelId": channels[i % len(channels)],
"channelName": f"general {channels[i % len(channels)]}",
"threadId": None,
"threadName": None,
"isGeneratedByWebhook": False,
Expand Down
Loading

0 comments on commit f5805b9

Please sign in to comment.