From 43ca735eb18903df345ccabb9d422522f5010ec7 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 3 Dec 2024 11:15:51 +0330 Subject: [PATCH] feat: Added the max reference per data source support! --- tests/unit/test_prepare_answer_sources.py | 97 +++++++++++++++++++- utils/query_engine/prepare_answer_sources.py | 73 ++++++++++++--- 2 files changed, 150 insertions(+), 20 deletions(-) diff --git a/tests/unit/test_prepare_answer_sources.py b/tests/unit/test_prepare_answer_sources.py index 1f114a3..ef06bd2 100644 --- a/tests/unit/test_prepare_answer_sources.py +++ b/tests/unit/test_prepare_answer_sources.py @@ -8,7 +8,7 @@ class TestPrepareAnswerSources(unittest.TestCase): def setUp(self) -> None: - self.prepare = PrepareAnswerSources(threshold=0.7) + self.prepare = PrepareAnswerSources(threshold=0.7, max_refs_per_source=3) def test_empty_nodes_list(self): """Test with an empty list of nodes.""" @@ -41,8 +41,8 @@ def test_single_tool_with_high_score_urls(self): expected = ( "References:\n" "github:\n" - "[1] https://github.com/repo1\n" - "[2] https://github.com/repo2" + "[1] https://github.com/repo2\n" # Higher score (0.9) should come first + "[2] https://github.com/repo1" ) self.assertEqual(result, expected) @@ -104,8 +104,8 @@ def test_mixed_score_urls(self): expected = ( "References:\n" "github:\n" - "[1] https://github.com/repo1\n" - "[2] https://github.com/repo3" + "[1] https://github.com/repo3\n" # Highest score (0.9) should come first + "[2] https://github.com/repo1" ) self.assertEqual(result, expected) @@ -211,3 +211,90 @@ def test_missing_urls_with_valid_scores(self): self.assertEqual( result, ("References:\n" "github:\n" "[1] https://github.com/repo2") ) + + def test_max_refs_per_source_limit(self): + """Test that the number of references per source respects the max_refs_per_source limit.""" + nodes = [ + SubQuestionAnswerPair( + sub_q=SubQuestion(tool_name="github", sub_question="Question"), + sources=[ + NodeWithScore( + node=TextNode( + text="content 1", + metadata={"url": "https://github.com/repo1"}, + ), + score=0.8, + ), + NodeWithScore( + node=TextNode( + text="content 2", + metadata={"url": "https://github.com/repo2"}, + ), + score=0.9, + ), + NodeWithScore( + node=TextNode( + text="content 3", + metadata={"url": "https://github.com/repo3"}, + ), + score=0.85, + ), + NodeWithScore( + node=TextNode( + text="content 4", + metadata={"url": "https://github.com/repo4"}, + ), + score=0.75, + ), + ], + ) + ] + result = self.prepare.prepare_answer_sources(nodes) + expected = ( + "References:\n" + "github:\n" + "[1] https://github.com/repo2\n" # Highest score (0.9) + "[2] https://github.com/repo3\n" # Second highest (0.85) + "[3] https://github.com/repo1" # Third highest (0.8) + ) + self.assertEqual(result, expected) + + def test_custom_max_refs_per_source(self): + """Test with a custom max_refs_per_source value.""" + prepare_custom = PrepareAnswerSources(threshold=0.7, max_refs_per_source=2) + nodes = [ + SubQuestionAnswerPair( + sub_q=SubQuestion(tool_name="github", sub_question="Question"), + sources=[ + NodeWithScore( + node=TextNode( + text="content 1", + metadata={"url": "https://github.com/repo1"}, + ), + score=0.8, + ), + NodeWithScore( + node=TextNode( + text="content 2", + metadata={"url": "https://github.com/repo2"}, + ), + score=0.9, + ), + NodeWithScore( + node=TextNode( + text="content 3", + metadata={"url": "https://github.com/repo3"}, + ), + score=0.85, + ), + ], + ) + ] + result = prepare_custom.prepare_answer_sources(nodes) + expected = ( + "References:\n" + "github:\n" + "[1] https://github.com/repo2\n" # Highest score (0.9) + "[2] https://github.com/repo3" # Second highest (0.85) + ) + self.assertEqual(result, expected) diff --git a/utils/query_engine/prepare_answer_sources.py b/utils/query_engine/prepare_answer_sources.py index e050211..a0091ad 100644 --- a/utils/query_engine/prepare_answer_sources.py +++ b/utils/query_engine/prepare_answer_sources.py @@ -4,29 +4,64 @@ class PrepareAnswerSources: - def __init__(self, threshold: float = 0.5) -> None: + def __init__(self, threshold: float = 0.5, max_refs_per_source: int = 3) -> None: + """ + Initialize the PrepareAnswerSources class. + + Parameters + ---------- + threshold : float, optional + Minimum score threshold for including a node's URL, by default 0.5 + max_refs_per_source : int, optional + Maximum number of references to include per data source, by default 3 + """ self.threshold = threshold + self.max_refs_per_source = max_refs_per_source def prepare_answer_sources(self, nodes: list[NodeWithScore]) -> str: """ - Prepares a formatted string containing source URLs from the provided nodes. + Prepares a formatted string containing source URLs organized by tool name from the provided nodes. - This method extracts URLs from the metadata of each node and combines them - into a newline-separated string. Only nodes with valid URLs in their metadata - are included in the output. + This method processes a list of nodes, filtering them based on a score threshold and + organizing the URLs by their associated tool names. It creates a formatted output with + URLs numbered under their respective tool sections, limiting the number of references + per data source. Parameters - ------------ + ---------- nodes : list[NodeWithScore] - A list of nodes that was used for answering a question. Each node - should have a metadata attribute containing an optional 'url' field. - + A list of node collections used for answering a question. Each node collection + contains: + - sub_q.tool_name: Name of the tool that generated these nodes + - sources: List of nodes, each containing: + - score: Relevance score of the node + - metadata: Dictionary containing an optional 'url' field Returns ------- - all_sources : str - A newline-separated string of source URLs. Returns an empty string if - no valid URLs are found in the nodes' metadata. + str + A formatted string containing numbered URLs organized by tool name, with the format: + References: + {tool_name}: + [1] {url1} + [2] {url2} + + Returns an empty string if: + - The input nodes list is empty + - No nodes meet the score threshold + - No valid URLs are found in the nodes' metadata + + Notes + ----- + - URLs are only included if their node's score meets or exceeds the threshold + (default: 0.5) + - Each tool's sources are grouped together and prefixed with the tool name + - URLs are numbered sequentially within each tool's section + - Maximum number of references per data source is limited by max_refs_per_source + (default: 3) + - References are selected based on highest scores when limiting + - Logs error messages when no nodes are available or when all nodes are below + the threshold """ if len(nodes) == 0: logging.error("No reference nodes available! returning empty string.") @@ -37,12 +72,20 @@ def prepare_answer_sources(self, nodes: list[NodeWithScore]) -> str: for tool_nodes in nodes: # platform name tool_name = tool_nodes.sub_q.tool_name - urls = [ - node.metadata.get("url") + + # Filter and sort nodes by score + valid_nodes = [ + node for node in tool_nodes.sources if node.score >= self.threshold and node.metadata.get("url") is not None ] - if urls: + valid_nodes.sort(key=lambda x: x.score, reverse=True) + + # Limit the number of references + limited_nodes = valid_nodes[: self.max_refs_per_source] + + if limited_nodes: + urls = [node.metadata.get("url") for node in limited_nodes] sources: list[str] = [ f"[{idx + 1}] {url}" for idx, url in enumerate(urls) ]