Skip to content

Commit

Permalink
feat: Added the max reference per data source support!
Browse files Browse the repository at this point in the history
  • Loading branch information
amindadgar committed Dec 3, 2024
1 parent 2496c2d commit 43ca735
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 20 deletions.
97 changes: 92 additions & 5 deletions tests/unit/test_prepare_answer_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class TestPrepareAnswerSources(unittest.TestCase):
def setUp(self) -> None:
self.prepare = PrepareAnswerSources(threshold=0.7)
self.prepare = PrepareAnswerSources(threshold=0.7, max_refs_per_source=3)

def test_empty_nodes_list(self):
"""Test with an empty list of nodes."""
Expand Down Expand Up @@ -41,8 +41,8 @@ def test_single_tool_with_high_score_urls(self):
expected = (
"References:\n"
"github:\n"
"[1] https://github.com/repo1\n"
"[2] https://github.com/repo2"
"[1] https://github.com/repo2\n" # Higher score (0.9) should come first
"[2] https://github.com/repo1"
)
self.assertEqual(result, expected)

Expand Down Expand Up @@ -104,8 +104,8 @@ def test_mixed_score_urls(self):
expected = (
"References:\n"
"github:\n"
"[1] https://github.com/repo1\n"
"[2] https://github.com/repo3"
"[1] https://github.com/repo3\n" # Highest score (0.9) should come first
"[2] https://github.com/repo1"
)
self.assertEqual(result, expected)

Expand Down Expand Up @@ -211,3 +211,90 @@ def test_missing_urls_with_valid_scores(self):
self.assertEqual(
result, ("References:\n" "github:\n" "[1] https://github.com/repo2")
)

def test_max_refs_per_source_limit(self):
"""Test that the number of references per source respects the max_refs_per_source limit."""
nodes = [
SubQuestionAnswerPair(
sub_q=SubQuestion(tool_name="github", sub_question="Question"),
sources=[
NodeWithScore(
node=TextNode(
text="content 1",
metadata={"url": "https://github.com/repo1"},
),
score=0.8,
),
NodeWithScore(
node=TextNode(
text="content 2",
metadata={"url": "https://github.com/repo2"},
),
score=0.9,
),
NodeWithScore(
node=TextNode(
text="content 3",
metadata={"url": "https://github.com/repo3"},
),
score=0.85,
),
NodeWithScore(
node=TextNode(
text="content 4",
metadata={"url": "https://github.com/repo4"},
),
score=0.75,
),
],
)
]
result = self.prepare.prepare_answer_sources(nodes)
expected = (
"References:\n"
"github:\n"
"[1] https://github.com/repo2\n" # Highest score (0.9)
"[2] https://github.com/repo3\n" # Second highest (0.85)
"[3] https://github.com/repo1" # Third highest (0.8)
)
self.assertEqual(result, expected)

def test_custom_max_refs_per_source(self):
"""Test with a custom max_refs_per_source value."""
prepare_custom = PrepareAnswerSources(threshold=0.7, max_refs_per_source=2)
nodes = [
SubQuestionAnswerPair(
sub_q=SubQuestion(tool_name="github", sub_question="Question"),
sources=[
NodeWithScore(
node=TextNode(
text="content 1",
metadata={"url": "https://github.com/repo1"},
),
score=0.8,
),
NodeWithScore(
node=TextNode(
text="content 2",
metadata={"url": "https://github.com/repo2"},
),
score=0.9,
),
NodeWithScore(
node=TextNode(
text="content 3",
metadata={"url": "https://github.com/repo3"},
),
score=0.85,
),
],
)
]
result = prepare_custom.prepare_answer_sources(nodes)
expected = (
"References:\n"
"github:\n"
"[1] https://github.com/repo2\n" # Highest score (0.9)
"[2] https://github.com/repo3" # Second highest (0.85)
)
self.assertEqual(result, expected)
73 changes: 58 additions & 15 deletions utils/query_engine/prepare_answer_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,64 @@


class PrepareAnswerSources:
def __init__(self, threshold: float = 0.5) -> None:
def __init__(self, threshold: float = 0.5, max_refs_per_source: int = 3) -> None:
"""
Initialize the PrepareAnswerSources class.
Parameters
----------
threshold : float, optional
Minimum score threshold for including a node's URL, by default 0.5
max_refs_per_source : int, optional
Maximum number of references to include per data source, by default 3
"""
self.threshold = threshold
self.max_refs_per_source = max_refs_per_source

def prepare_answer_sources(self, nodes: list[NodeWithScore]) -> str:
"""
Prepares a formatted string containing source URLs from the provided nodes.
Prepares a formatted string containing source URLs organized by tool name from the provided nodes.
This method extracts URLs from the metadata of each node and combines them
into a newline-separated string. Only nodes with valid URLs in their metadata
are included in the output.
This method processes a list of nodes, filtering them based on a score threshold and
organizing the URLs by their associated tool names. It creates a formatted output with
URLs numbered under their respective tool sections, limiting the number of references
per data source.
Parameters
------------
----------
nodes : list[NodeWithScore]
A list of nodes that was used for answering a question. Each node
should have a metadata attribute containing an optional 'url' field.
A list of node collections used for answering a question. Each node collection
contains:
- sub_q.tool_name: Name of the tool that generated these nodes
- sources: List of nodes, each containing:
- score: Relevance score of the node
- metadata: Dictionary containing an optional 'url' field
Returns
-------
all_sources : str
A newline-separated string of source URLs. Returns an empty string if
no valid URLs are found in the nodes' metadata.
str
A formatted string containing numbered URLs organized by tool name, with the format:
References:
{tool_name}:
[1] {url1}
[2] {url2}
Returns an empty string if:
- The input nodes list is empty
- No nodes meet the score threshold
- No valid URLs are found in the nodes' metadata
Notes
-----
- URLs are only included if their node's score meets or exceeds the threshold
(default: 0.5)
- Each tool's sources are grouped together and prefixed with the tool name
- URLs are numbered sequentially within each tool's section
- Maximum number of references per data source is limited by max_refs_per_source
(default: 3)
- References are selected based on highest scores when limiting
- Logs error messages when no nodes are available or when all nodes are below
the threshold
"""
if len(nodes) == 0:
logging.error("No reference nodes available! returning empty string.")
Expand All @@ -37,12 +72,20 @@ def prepare_answer_sources(self, nodes: list[NodeWithScore]) -> str:
for tool_nodes in nodes:
# platform name
tool_name = tool_nodes.sub_q.tool_name
urls = [
node.metadata.get("url")

# Filter and sort nodes by score
valid_nodes = [
node
for node in tool_nodes.sources
if node.score >= self.threshold and node.metadata.get("url") is not None
]
if urls:
valid_nodes.sort(key=lambda x: x.score, reverse=True)

# Limit the number of references
limited_nodes = valid_nodes[: self.max_refs_per_source]

if limited_nodes:
urls = [node.metadata.get("url") for node in limited_nodes]
sources: list[str] = [
f"[{idx + 1}] {url}" for idx, url in enumerate(urls)
]
Expand Down

0 comments on commit 43ca735

Please sign in to comment.