Commit

update: fixing test cases and adding more!
amindadgar committed Feb 12, 2024
1 parent f8bced6 commit 092158b
Showing 4 changed files with 232 additions and 16 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -162,3 +162,5 @@ cython_debug/
hivemind-bot-env/*
main.ipynb
.DS_Store

temp_test_run_data.json
34 changes: 18 additions & 16 deletions tests/unit/test_level_based_platform_query_engine.py
@@ -6,6 +6,7 @@
from utils.query_engine.level_based_platform_query_engine import (
LevelBasedPlatformQueryEngine,
)
+from sqlalchemy.exc import OperationalError


class TestLevelBasedPlatformQueryEngine(unittest.TestCase):
@@ -26,9 +27,9 @@ def test_prepare_platform_engine(self):
"""
# the output should always have a `date` key for each dictionary
filters = [
{"channel": "general", "date": "2023-01-02"},
{"thread": "discussion", "date": "2024-01-03"},
{"date": "2022-01-01"},
{"channel": "general", "thread": "some_thread", "date": "2023-01-02"},
{"channel": "general", "thread": "discussion", "date": "2024-01-03"},
{"channel": "general#2", "thread": "Agenda", "date": "2022-01-01"},
]

engine = LevelBasedPlatformQueryEngine.prepare_platform_engine(
@@ -44,21 +45,22 @@ def test_prepare_engine_auto_filter(self):
Test prepare_engine_auto_filter method with sample data
"""
with patch.object(
-            ForumBasedSummaryRetriever, "retreive_filtering"
+            ForumBasedSummaryRetriever, "define_filters"
) as mock_retriever:
# the output should always have a `date` key for each dictionary
mock_retriever.return_value = [
{"channel": "general", "date": "2023-01-02"},
{"thread": "discussion", "date": "2024-01-03"},
{"date": "2022-01-01"},
{"channel": "general", "thread": "some_thread", "date": "2023-01-02"},
{"channel": "general", "thread": "discussion", "date": "2024-01-03"},
{"channel": "general#2", "thread": "Agenda", "date": "2022-01-01"},
]

-            engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter(
-                community_id=self.community_id,
-                query="test query",
-                platform_table_name=self.platform_table_name,
-                level1_key=self.level1_key,
-                level2_key=self.level2_key,
-                date_key=self.date_key,
-            )
-            self.assertIsNotNone(engine)
+            with self.assertRaises(OperationalError):
+                # no database with name of `test_community` is available
+                _ = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter(
+                    community_id=self.community_id,
+                    query="test query",
+                    platform_table_name=self.platform_table_name,
+                    level1_key=self.level1_key,
+                    level2_key=self.level2_key,
+                    date_key=self.date_key,
+                )
211 changes: 211 additions & 0 deletions tests/unit/test_level_based_platform_util.py
@@ -0,0 +1,211 @@
import unittest
from llama_index.schema import NodeWithScore, TextNode
from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils


class TestLevelBasedPlatformUtils(unittest.TestCase):
def setUp(self):
self.level1_key = "channel"
self.level2_key = "thread"
self.date_key = "date"
self.utils = LevelBasedPlatformUtils(
self.level1_key, self.level2_key, self.date_key
)

def test_prepare_prompt_with_metadata_info(self):
nodes = [
NodeWithScore(
node=TextNode(
text="content1",
metadata={"author_username": "user1", "date": "2022-01-01"},
),
score=0,
),
NodeWithScore(
node=TextNode(
text="content2",
metadata={"author_username": "user2", "date": "2022-01-02"},
),
score=0,
),
]
prefix = " "
expected_output = (
" author: user1\n message_date: 2022-01-01\n message 1: content1\n"
" author: user2\n message_date: 2022-01-02\n message 2: content2"
)
result = self.utils.prepare_prompt_with_metadata_info(nodes, prefix)
self.assertEqual(result, expected_output)

def test_group_nodes_per_metadata(self):
nodes = [
NodeWithScore(
node=TextNode(
text="content1",
metadata={"channel": "A", "thread": "X", "date": "2022-01-01"},
),
score=0,
),
NodeWithScore(
node=TextNode(
text="content2",
metadata={"channel": "A", "thread": "Y", "date": "2022-01-01"},
),
score=0,
),
NodeWithScore(
node=TextNode(
text="content3",
metadata={"channel": "B", "thread": "X", "date": "2022-01-02"},
),
score=0,
),
]
expected_output = {
"A": {"X": {"2022-01-01": [nodes[0]]}, "Y": {"2022-01-01": [nodes[1]]}},
"B": {"X": {"2022-01-02": [nodes[2]]}},
}
result = self.utils.group_nodes_per_metadata(nodes)
self.assertEqual(result, expected_output)

def test_prepare_context_str_based_on_summaries(self):
raw_nodes = [
NodeWithScore(
node=TextNode(
text="raw_content1",
metadata={
"channel": "A",
"thread": "X",
"date": "2022-01-01",
"author_username": "USERNAME#1",
},
),
score=0,
),
NodeWithScore(
node=TextNode(
text="raw_content2",
metadata={
"channel": "A",
"thread": "Y",
"date": "2022-01-04",
"author_username": "USERNAME#2",
},
),
score=0,
),
]
summary_nodes = [
NodeWithScore(
node=TextNode(
text="summary_content",
metadata={"channel": "A", "thread": "X", "date": "2022-01-01"},
),
score=0,
)
]
grouped_raw_nodes = {"A": {"X": {"2022-01-01": raw_nodes}}}
grouped_summary_nodes = {"A": {"X": {"2022-01-01": summary_nodes}}}
expected_output = (
"""channel: A\nthread: X\ndate: 2022-01-01\nsummary: summary_content\nmessages:\n"""
""" author: USERNAME#1\n message_date: 2022-01-01\n message 1: raw_content1\n"""
""" author: USERNAME#2\n message_date: 2022-01-04\n message 2: raw_content2\n"""
)
result, _ = self.utils.prepare_context_str_based_on_summaries(
grouped_raw_nodes, grouped_summary_nodes
)
self.assertEqual(result.strip(), expected_output.strip())

def test_prepare_context_str_based_on_summaries_no_summary(self):
node1 = NodeWithScore(
node=TextNode(
text="raw_content1",
metadata={
"channel": "A",
"thread": "X",
"date": "2022-01-01",
"author_username": "USERNAME#1",
},
),
score=0,
)
node2 = NodeWithScore(
node=TextNode(
text="raw_content2",
metadata={
"channel": "A",
"thread": "Y",
"date": "2022-01-04",
"author_username": "USERNAME#2",
},
),
score=0,
)
grouped_raw_nodes = {
"A": {"X": {"2022-01-01": [node1]}, "Y": {"2022-01-04": [node2]}}
}
grouped_summary_nodes = {}
result, (
summary_nodes_to_fetch_filters,
raw_nodes_missed,
) = self.utils.prepare_context_str_based_on_summaries(
grouped_raw_nodes, grouped_summary_nodes
)
self.assertEqual(result, "")
self.assertEqual(len(summary_nodes_to_fetch_filters), 2)
for channel in raw_nodes_missed.keys():
self.assertIn(channel, ["A"])
for thread in raw_nodes_missed[channel].keys():
self.assertIn(thread, ["X", "Y"])
for date in raw_nodes_missed[channel][thread]:
self.assertIn(date, ["2022-01-01", "2022-01-04"])
nodes = raw_nodes_missed[channel][thread][date]

if date == "2022-01-01":
self.assertEqual(
grouped_raw_nodes["A"]["X"]["2022-01-01"], nodes
)
elif date == "2022-01-04":
self.assertEqual(
grouped_raw_nodes["A"]["Y"]["2022-01-04"], nodes
)

def test_prepare_context_str_based_on_summaries_multiple_summaries_error(self):
raw_nodes = [
NodeWithScore(
node=TextNode(
text="raw_content1",
metadata={"channel": "A", "thread": "X", "date": "2022-01-01"},
),
score=0,
),
NodeWithScore(
node=TextNode(
text="raw_content2",
metadata={"channel": "A", "thread": "Y", "date": "2022-01-01"},
),
score=0,
),
]
summary_nodes = [
NodeWithScore(
node=TextNode(
text="summary_content1",
metadata={"channel": "A", "thread": "X", "date": "2022-01-01"},
),
score=0,
),
NodeWithScore(
node=TextNode(
text="summary_content2",
metadata={"channel": "A", "thread": "X", "date": "2022-01-01"},
),
score=0,
),
]
grouped_raw_nodes = {"A": {"X": {"2022-01-01": raw_nodes}}}
grouped_summary_nodes = {"A": {"X": {"2022-01-01": summary_nodes}}}
with self.assertRaises(ValueError):
self.utils.prepare_context_str_based_on_summaries(
grouped_raw_nodes, grouped_summary_nodes
)
1 change: 1 addition & 0 deletions utils/query_engine/level_based_platforms_util.py
@@ -132,6 +132,7 @@ def prepare_context_str_based_on_summaries(
f"{self.level1_key}: {level1_title}, "
f"{self.level2_key}: {level2_title}, "
f"{self.date_key}: {date}"
"\t will fetch them after"
)
summary_nodes_to_fetch_filters.append(
{
