From b9f5d7c30da24f93f8e3f9f7c211483e0046f6fe Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 12 Feb 2024 16:02:51 +0330 Subject: [PATCH] update: Added more tests to increase test coverage! --- .../test_level_based_platform_query_engine.py | 48 ++++- ...d_platform_query_engine_prepare_context.py | 200 ++++++++++++++++++ tests/unit/test_level_based_platform_util.py | 10 +- .../level_based_platforms_util.py | 1 + 4 files changed, 253 insertions(+), 6 deletions(-) create mode 100644 tests/unit/test_level_based_platform_query_engine_prepare_context.py diff --git a/tests/unit/test_level_based_platform_query_engine.py b/tests/unit/test_level_based_platform_query_engine.py index d4a1286..15d1243 100644 --- a/tests/unit/test_level_based_platform_query_engine.py +++ b/tests/unit/test_level_based_platform_query_engine.py @@ -3,6 +3,8 @@ from unittest.mock import patch from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever +from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes +from llama_index.schema import NodeWithScore, TextNode from sqlalchemy.exc import OperationalError from utils.query_engine.level_based_platform_query_engine import ( LevelBasedPlatformQueryEngine, @@ -40,9 +42,10 @@ def test_prepare_platform_engine(self): ) self.assertIsNotNone(engine) - def test_prepare_engine_auto_filter(self): + def test_prepare_engine_auto_filter_raise_error(self): """ Test prepare_engine_auto_filter method with sample data + when an error was raised """ with patch.object( ForumBasedSummaryRetriever, "define_filters" @@ -64,3 +67,46 @@ def test_prepare_engine_auto_filter(self): level2_key=self.level2_key, date_key=self.date_key, ) + + def test_prepare_engine_auto_filter(self): + """ + Test prepare_engine_auto_filter method with sample data in normal condition + """ + with patch.object(RetrieveSimilarNodes, "query_db") as mock_query: + # the output should always have a `date` key for each dictionary + mock_query.return_value = [ + NodeWithScore( + node=TextNode( + text="some summaries #1", + metadata={ + "thread": "thread#1", + "channel": "channel#1", + "date": "2022-01-01", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="some summaries #2", + metadata={ + "thread": "thread#3", + "channel": "channel#2", + "date": "2022-01-02", + }, + ), + score=0, + ), + ] + + # no database with name of `test_community` is available + engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( + community_id=self.community_id, + query="test query", + platform_table_name=self.platform_table_name, + level1_key=self.level1_key, + level2_key=self.level2_key, + date_key=self.date_key, + include_summary_context=True, + ) + self.assertIsNotNone(engine) diff --git a/tests/unit/test_level_based_platform_query_engine_prepare_context.py b/tests/unit/test_level_based_platform_query_engine_prepare_context.py new file mode 100644 index 0000000..e28469c --- /dev/null +++ b/tests/unit/test_level_based_platform_query_engine_prepare_context.py @@ -0,0 +1,200 @@ +import os +import unittest +from unittest.mock import patch + +from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever +from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes +from llama_index.schema import NodeWithScore, TextNode +from sqlalchemy.exc import OperationalError +from utils.query_engine.level_based_platform_query_engine import ( + LevelBasedPlatformQueryEngine, +) + + +class TestLevelBasedPlatformQueryEngine(unittest.TestCase): + def setUp(self): + """ + Set up common parameters for testing + """ + self.community_id = "test_community" + self.level1_key = "channel" + self.level2_key = "thread" + self.platform_table_name = "discord" + self.date_key = "date" + os.environ["OPENAI_API_KEY"] = "sk-some_creds" + + def test_prepare_context_str_without_summaries(self): + """ + test prepare the context string while not having the summaries nodes + """ + with patch.object(RetrieveSimilarNodes, "query_db") as mock_query: + summary_nodes = [] + mock_query.return_value = summary_nodes + + engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( + community_id=self.community_id, + query="test query", + platform_table_name=self.platform_table_name, + level1_key=self.level1_key, + level2_key=self.level2_key, + date_key=self.date_key, + include_summary_context=True, + ) + + raw_nodes = [ + NodeWithScore( + node=TextNode( + text="content1", + metadata={ + "author_username": "user1", + "channel": "channel#1", + "thread": "thread#1", + "date": "2022-01-01", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content2", + metadata={ + "author_username": "user2", + "channel": "channel#2", + "thread": "thread#3", + "date": "2022-01-02", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content4", + metadata={ + "author_username": "user3", + "channel": "channel#2", + "thread": "thread#3", + "date": "2022-01-02", + }, + ), + score=0, + ), + ] + + contest_str = engine._prepare_context_str(raw_nodes, summary_nodes) + expected_context_str = ( + "author: user1\n" + "message_date: 2022-01-01\n" + "message 1: content1\n\n" + "author: user2\n" + "message_date: 2022-01-02\n" + "message 2: content2\n\n" + "author: user3\n" + "message_date: 2022-01-02\n" + "message 3: content4\n" + ) + self.assertEqual(contest_str, expected_context_str) + + def test_prepare_context_str_with_summaries(self): + """ + test prepare the context string having the summaries nodes + """ + + with patch.object(RetrieveSimilarNodes, "query_db") as mock_query: + summary_nodes = [ + NodeWithScore( + node=TextNode( + text="some summaries #1", + metadata={ + "thread": "thread#1", + "channel": "channel#1", + "date": "2022-01-01", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="some summaries #2", + metadata={ + "thread": "thread#3", + "channel": "channel#2", + "date": "2022-01-02", + }, + ), + score=0, + ), + ] + mock_query.return_value = summary_nodes + + engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( + community_id=self.community_id, + query="test query", + platform_table_name=self.platform_table_name, + level1_key=self.level1_key, + level2_key=self.level2_key, + date_key=self.date_key, + include_summary_context=True, + ) + + raw_nodes = [ + NodeWithScore( + node=TextNode( + text="content1", + metadata={ + "author_username": "user1", + "channel": "channel#1", + "thread": "thread#1", + "date": "2022-01-01", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content2", + metadata={ + "author_username": "user2", + "channel": "channel#2", + "thread": "thread#3", + "date": "2022-01-02", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content4", + metadata={ + "author_username": "user3", + "channel": "channel#2", + "thread": "thread#3", + "date": "2022-01-02", + }, + ), + score=0, + ), + ] + + contest_str = engine._prepare_context_str(raw_nodes, summary_nodes) + expected_context_str = ( + "channel: channel#1\n" + "thread: thread#1\n" + "date: 2022-01-01\n" + "summary: some summaries #1\n" + "messages:\n" + " author: user1\n" + " message_date: 2022-01-01\n" + " message 1: content1\n\n" + "channel: channel#2\n" + "thread: thread#3\n" + "date: 2022-01-02\n" + "summary: some summaries #2\n" + "messages:\n" + " author: user2\n" + " message_date: 2022-01-02\n" + " message 1: content2\n\n" + " author: user3\n" + " message_date: 2022-01-02\n" + " message 2: content4\n\n" + ) + self.assertEqual(contest_str, expected_context_str) diff --git a/tests/unit/test_level_based_platform_util.py b/tests/unit/test_level_based_platform_util.py index 57f0c22..cc7d721 100644 --- a/tests/unit/test_level_based_platform_util.py +++ b/tests/unit/test_level_based_platform_util.py @@ -32,8 +32,8 @@ def test_prepare_prompt_with_metadata_info(self): ] prefix = " " expected_output = ( - " author: user1\n message_date: 2022-01-01\n message 1: content1\n" - " author: user2\n message_date: 2022-01-02\n message 2: content2" + " author: user1\n message_date: 2022-01-01\n message 1: content1\n\n" + " author: user2\n message_date: 2022-01-02\n message 2: content2\n" ) result = self.utils.prepare_prompt_with_metadata_info(nodes, prefix) self.assertEqual(result, expected_output) @@ -108,9 +108,9 @@ def test_prepare_context_str_based_on_summaries(self): grouped_raw_nodes = {"A": {"X": {"2022-01-01": raw_nodes}}} grouped_summary_nodes = {"A": {"X": {"2022-01-01": summary_nodes}}} expected_output = ( - """channel: A\nthread: X\ndate: 2022-01-01\nsummary: summary_content\nmessages:\n""" - """ author: USERNAME#1\n message_date: 2022-01-01\n message 1: raw_content1\n""" - """ author: USERNAME#2\n message_date: 2022-01-04\n message 2: raw_content2\n""" + "channel: A\nthread: X\ndate: 2022-01-01\nsummary: summary_content\nmessages:\n" + " author: USERNAME#1\n message_date: 2022-01-01\n message 1: raw_content1\n\n" + " author: USERNAME#2\n message_date: 2022-01-04\n message 2: raw_content2\n\n" ) result, _ = self.utils.prepare_context_str_based_on_summaries( grouped_raw_nodes, grouped_summary_nodes diff --git a/utils/query_engine/level_based_platforms_util.py b/utils/query_engine/level_based_platforms_util.py index b1730fd..d5e6eaa 100644 --- a/utils/query_engine/level_based_platforms_util.py +++ b/utils/query_engine/level_based_platforms_util.py @@ -30,6 +30,7 @@ def prepare_prompt_with_metadata_info( + prefix + f"message {idx + 1}: " + node.get_content() + + "\n" for idx, node in enumerate(nodes) ] )