
Commit

feat: Added references to answers!
amindadgar committed Dec 2, 2024
1 parent 1758e61 commit c516959
Showing 10 changed files with 394 additions and 19 deletions.
11 changes: 8 additions & 3 deletions routers/amqp.py
@@ -1,6 +1,5 @@
from datetime import datetime

from faststream.rabbit import RabbitBroker
from faststream.rabbit.fastapi import Logger, RabbitRouter # type: ignore
from faststream.rabbit.schemas.queue import RabbitQueue
from pydantic import BaseModel
@@ -9,6 +8,7 @@
from tc_messageBroker.rabbit_mq.queue import Queue
from utils.credentials import load_rabbitmq_credentials
from utils.persist_payload import PersistPayload
from utils.query_engine.prepare_answer_sources import PrepareAnswerSources
from utils.traceloop import init_tracing
from worker.tasks import query_data_sources
from worker.utils.fire_event import job_send
@@ -32,14 +32,19 @@ async def ask(payload: Payload, logger: Logger):
community_id = payload.content.communityId
init_tracing()
logger.info(f"COMMUNITY_ID: {community_id} Received job")
response = query_data_sources(community_id=community_id, query=question)
response, references = query_data_sources(
community_id=community_id, query=question
)
prepare_answer = PrepareAnswerSources(threshold=0.7)
answer_reference = prepare_answer.prepare_answer_sources(nodes=references)

logger.info(f"COMMUNITY_ID: {community_id} Job finished")

response_payload = AMQPPayload(
communityId=community_id,
route=payload.content.route,
question=payload.content.question,
response=ResponseModel(message=response),
response=ResponseModel(message=f"{response}\n\n{answer_reference}"),
metadata=payload.content.metadata,
)
# dumping the whole payload of question & answer to db
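
Note: with this change the AMQP consumer sends back a single message that combines the answer text and a reference block. A minimal illustration of the resulting message shape, using placeholder answer text and URLs (the real values come from query_data_sources and PrepareAnswerSources):

# Placeholder values for illustration only.
response = "The community's indexed sources cover its GitHub and Stack Overflow activity."
answer_reference = (
    "References:\n"
    "github:\n"
    "[1] https://github.com/repo1\n"
    "[2] https://github.com/repo2"
)
# The consumer joins the two parts with a blank line, as in the diff above.
message = f"{response}\n\n{answer_reference}"
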
14 changes: 10 additions & 4 deletions subquery.py
@@ -1,6 +1,6 @@
from guidance.models import OpenAIChat
from llama_index.core import QueryBundle, Settings
from llama_index.core.query_engine import SubQuestionQueryEngine
from llama_index.core.base.response.schema import RESPONSE_TYPE
from llama_index.core.schema import NodeWithScore
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from llama_index.llms.openai import OpenAI
@@ -9,6 +9,7 @@
from utils.qdrant_utils import QDrantUtils
from utils.query_engine import (
DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
CustomSubQuestionQueryEngine,
GDriveQueryEngine,
GitHubQueryEngine,
MediaWikiQueryEngine,
@@ -208,12 +209,17 @@ def query_multiple_source(
verbose=False,
prompt_template_str=DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL,
)
s_engine = SubQuestionQueryEngine.from_defaults(
s_engine = CustomSubQuestionQueryEngine.from_defaults(
question_gen=question_gen,
query_engine_tools=query_engine_tools,
use_async=False,
verbose=False,
)
query_embedding = embed_model.get_text_embedding(text=query)
response = s_engine.query(QueryBundle(query_str=query, embedding=query_embedding))

return response.response, response.source_nodes
result: tuple[RESPONSE_TYPE, list[NodeWithScore]] = s_engine.query(
QueryBundle(query_str=query, embedding=query_embedding)
)
response, source_nodes = result

return response.response, source_nodes
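
Note: query_multiple_source now unpacks the engine's result into the synthesized response and its source nodes and returns both, so callers such as routers/amqp.py receive the references alongside the answer text. The source nodes are presumably the per-sub-question answer pairs that PrepareAnswerSources consumes (see the tests below).
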
212 changes: 212 additions & 0 deletions tests/unit/test_prepare_answer_sources.py
@@ -0,0 +1,212 @@
import unittest
from llama_index.core.query_engine import SubQuestionAnswerPair
from llama_index.core.schema import TextNode, NodeWithScore
from llama_index.core.question_gen.types import SubQuestion
from utils.query_engine.prepare_answer_sources import PrepareAnswerSources


class TestPrepareAnswerSources(unittest.TestCase):
def setUp(self) -> None:
self.prepare = PrepareAnswerSources(threshold=0.7)

def test_empty_nodes_list(self):
"""Test with an empty list of nodes."""
nodes = []
result = self.prepare.prepare_answer_sources(nodes)
self.assertEqual(result, "")

def test_single_tool_with_high_score_urls(self):
"""Test with a single tool containing multiple URLs with scores above threshold."""
node1 = NodeWithScore(
node=TextNode(
text="content 1", metadata={"url": "https://github.com/repo1"}
),
score=0.8,
)
node2 = NodeWithScore(
node=TextNode(
text="content 2", metadata={"url": "https://github.com/repo2"}
),
score=0.9,
)

nodes = [
SubQuestionAnswerPair(
sub_q=SubQuestion(tool_name="github", sub_question="Question"),
sources=[node1, node2],
)
]
result = self.prepare.prepare_answer_sources(nodes)
expected = (
"References:\n"
"github:\n"
"[1] https://github.com/repo1\n"
"[2] https://github.com/repo2"
)
self.assertEqual(result, expected)

def test_urls_below_score_threshold(self):
"""Test with URLs that have scores below the 0.7 threshold."""
node1 = NodeWithScore(
node=TextNode(
text="content 1", metadata={"url": "https://github.com/repo1"}
),
score=0.6,
)
node2 = NodeWithScore(
node=TextNode(
text="content 2", metadata={"url": "https://github.com/repo2"}
),
score=0.5,
)

nodes = [
SubQuestionAnswerPair(
sub_q=SubQuestion(tool_name="github", sub_question="Question"),
sources=[node1, node2],
)
]
result = self.prepare.prepare_answer_sources(nodes)
self.assertEqual(result, "")

def test_mixed_score_urls(self):
"""Test with a mixture of high and low score URLs."""
nodes = [
SubQuestionAnswerPair(
sub_q=SubQuestion(tool_name="github", sub_question="Question"),
sources=[
NodeWithScore(
node=TextNode(
text="content 1",
metadata={"url": "https://github.com/repo1"},
),
score=0.8,
),
NodeWithScore(
node=TextNode(
text="content 2",
metadata={"url": "https://github.com/repo2"},
),
score=0.6, # Below threshold
),
NodeWithScore(
node=TextNode(
text="content 3",
metadata={"url": "https://github.com/repo3"},
),
score=0.9,
),
],
)
]
result = self.prepare.prepare_answer_sources(nodes)
expected = (
"References:\n"
"github:\n"
"[1] https://github.com/repo1\n"
"[2] https://github.com/repo3"
)
self.assertEqual(result, expected)

def test_multiple_tools_with_valid_scores(self):
"""Test with multiple tools containing URLs with valid scores."""
nodes = [
SubQuestionAnswerPair(
sub_q=SubQuestion(tool_name="github", sub_question="Question"),
sources=[
NodeWithScore(
node=TextNode(
text="content 1",
metadata={"url": "https://github.com/repo1"},
),
score=0.8,
),
NodeWithScore(
node=TextNode(
text="content 2",
metadata={"url": "https://github.com/repo2"},
),
score=0.75,
),
],
),
SubQuestionAnswerPair(
sub_q=SubQuestion(tool_name="stackoverflow", sub_question="Question"),
sources=[
NodeWithScore(
node=TextNode(
text="content 3",
metadata={"url": "https://stackoverflow.com/q1"},
),
score=0.9,
),
NodeWithScore(
node=TextNode(
text="content 4",
metadata={"url": "https://stackoverflow.com/q2"},
),
score=0.85,
),
],
),
]
result = self.prepare.prepare_answer_sources(nodes)
expected = (
"References:\n"
"github:\n"
"[1] https://github.com/repo1\n"
"[2] https://github.com/repo2\n\n"
"stackoverflow:\n"
"[1] https://stackoverflow.com/q1\n"
"[2] https://stackoverflow.com/q2"
)
self.assertEqual(result, expected)

def test_none_urls_with_valid_scores(self):
"""Test with None URLs that have valid scores."""
nodes = [
SubQuestionAnswerPair(
sub_q=SubQuestion(tool_name="github", sub_question="Question"),
sources=[
NodeWithScore(
node=TextNode(text="content 1", metadata={"url": None}),
score=0.8,
),
NodeWithScore(
node=TextNode(
text="content 2",
metadata={"url": "https://github.com/repo2"},
),
score=0.9,
),
],
)
]
result = self.prepare.prepare_answer_sources(nodes)
self.assertEqual(
result, ("References:\n" "github:\n" "[1] https://github.com/repo2")
)

def test_missing_urls_with_valid_scores(self):
"""Test with missing URLs that have valid scores."""
nodes = [
SubQuestionAnswerPair(
sub_q=SubQuestion(tool_name="github", sub_question="Question"),
sources=[
NodeWithScore(
node=TextNode(text="content 1", metadata={}), score=0.8
),
NodeWithScore(
node=TextNode(
text="content 2",
metadata={"url": "https://github.com/repo2"},
),
score=0.9,
),
],
)
]
result = self.prepare.prepare_answer_sources(nodes)
self.assertEqual(
result, ("References:\n" "github:\n" "[1] https://github.com/repo2")
)
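
Note: the new utils/query_engine/prepare_answer_sources.py module itself is not rendered in this view. Below is a minimal sketch consistent with the unit tests above; the real implementation may differ, for example in logging or in whether the threshold comparison is inclusive (assumed here to be ">=").

from llama_index.core.query_engine import SubQuestionAnswerPair


class PrepareAnswerSources:
    def __init__(self, threshold: float = 0.7) -> None:
        # Minimum retrieval score a source must reach to be cited.
        self.threshold = threshold

    def prepare_answer_sources(self, nodes: list[SubQuestionAnswerPair]) -> str:
        tool_sections: list[str] = []
        for pair in nodes:
            # Keep sources that score at or above the threshold and carry a URL.
            urls = [
                source.node.metadata.get("url")
                for source in pair.sources
                if source.score >= self.threshold and source.node.metadata.get("url")
            ]
            if not urls:
                continue
            numbered = "\n".join(f"[{i}] {url}" for i, url in enumerate(urls, start=1))
            tool_sections.append(f"{pair.sub_q.tool_name}:\n{numbered}")

        if not tool_sections:
            return ""
        return "References:\n" + "\n\n".join(tool_sections)
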
2 changes: 1 addition & 1 deletion utils/mongo.py
@@ -40,6 +40,6 @@ def config_mogno_creds(mongo_creds: dict[str, Any]):
host = mongo_creds["host"]
port = mongo_creds["port"]

connection = f"mongodb://{user}:{password}@{host}:{port}"
connection = f"mongodb://{user}:{password}@{host}:{port}?directConnection=true"

return connection
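
Note: appending directConnection=true makes the MongoDB driver connect straight to the given host instead of attempting replica-set topology discovery, which helps when only a single mongod/mongos instance is reachable from the service.
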
1 change: 1 addition & 0 deletions utils/query_engine/__init__.py
@@ -5,6 +5,7 @@
from .media_wiki import MediaWikiQueryEngine
from .notion import NotionQueryEngine
from .prepare_discord_query_engine import prepare_discord_engine_auto_filter
from .subquestion_engine import CustomSubQuestionQueryEngine
from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL
from .telegram import TelegramDualQueryEngine, TelegramQueryEngine
from .website import WebsiteQueryEngine
17 changes: 11 additions & 6 deletions utils/query_engine/dual_qdrant_retrieval_engine.py
@@ -6,6 +6,8 @@
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.response_synthesizers import BaseSynthesizer
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.core.base.response.schema import Response
from llama_index.llms.openai import OpenAI
from schema.type import DataType
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess
@@ -36,7 +38,7 @@ def custom_query(self, query_str: str):
response = self._process_basic_query(query_str)
else:
response = self._process_summary_query(query_str)
return str(response)
return response

@classmethod
def setup_engine(
@@ -172,14 +174,16 @@ def _setup_vector_store_index(
index = qdrant_vector.load_index()
return index

def _process_basic_query(self, query_str: str) -> str:
nodes = self.retriever.retrieve(query_str)
def _process_basic_query(self, query_str: str) -> Response:
nodes: list[NodeWithScore] = self.retriever.retrieve(query_str)
context_str = "\n\n".join([n.node.get_content() for n in nodes])
prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
response = self.llm.complete(prompt)
return response

def _process_summary_query(self, query_str: str) -> str:
# return final_response
return Response(response=str(response), source_nodes=nodes)

def _process_summary_query(self, query_str: str) -> Response:
summary_nodes = self.summary_retriever.retrieve(query_str)
utils = QdrantEngineUtils(
metadata_date_key=self.metadata_date_key,
@@ -207,4 +211,5 @@ def _process_summary_query(self, query_str: str) -> str:
context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes)
prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
response = self.llm.complete(prompt)
return response

return Response(response=str(response), source_nodes=raw_nodes)
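
Note: both query paths now return a llama_index Response object instead of a plain string, so the retrieved nodes travel with the answer and can be turned into references downstream; the level_based_platform_query_engine.py change below does the same. A small illustrative usage, with made-up text, URL, and score:

from llama_index.core.base.response.schema import Response
from llama_index.core.schema import NodeWithScore, TextNode

# A Response carries its source nodes alongside the answer text.
resp = Response(
    response="answer text",
    source_nodes=[
        NodeWithScore(
            node=TextNode(text="chunk", metadata={"url": "https://github.com/repo1"}),
            score=0.8,
        )
    ],
)
print(resp.response)           # "answer text"
print(len(resp.source_nodes))  # 1
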
3 changes: 2 additions & 1 deletion utils/query_engine/level_based_platform_query_engine.py
@@ -15,6 +15,7 @@
from llama_index.llms.openai import OpenAI
from utils.query_engine.base_pg_engine import BasePGEngine
from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils
from llama_index.core.base.response.schema import Response

qa_prompt = PromptTemplate(
"Context information is below.\n"
@@ -51,7 +52,7 @@ def custom_query(self, query_str: str):
response = self.llm.complete(fmt_qa_prompt)
logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}")

return str(response)
return Response(response=str(response), source_nodes=similar_nodes)

@classmethod
def prepare_platform_engine(