Skip to content

Commit

Permalink
feat: Added node scoring limit to be 0.4!
Browse files Browse the repository at this point in the history
  • Loading branch information
amindadgar committed Dec 17, 2024
1 parent c1b2e5d commit 97c36c0
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
11 changes: 7 additions & 4 deletions bot/retrievers/custom_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from llama_index.core.schema import Node, NodeWithScore, ObjectType
from llama_index.core.vector_stores.types import VectorStoreQueryResult
from utils.globals import RETRIEVER_THRESHOLD


class CustomVectorStoreRetriever(VectorIndexRetriever):
Expand Down Expand Up @@ -50,10 +51,12 @@ def _build_node_list_from_query_result(
score: float | None = None
if query_result.similarities is not None:
score = query_result.similarities[ind]
# This is the part we updated
node_new = Node.from_dict(node.to_dict())
node_with_score = NodeWithScore(node=node_new, score=score)

if score is not None and score >= RETRIEVER_THRESHOLD:
# This is the part we updated
node_new = Node.from_dict(node.to_dict())
node_with_score = NodeWithScore(node=node_new, score=score)

node_with_scores.append(node_with_score)
node_with_scores.append(node_with_score)

return node_with_scores
2 changes: 2 additions & 0 deletions utils/globals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# the theshold to skip nodes of being included in an answer
RETRIEVER_THRESHOLD = .4
15 changes: 10 additions & 5 deletions utils/query_engine/dual_qdrant_retrieval_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from schema.type import DataType
from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess
from utils.query_engine.qdrant_query_engine_utils import QdrantEngineUtils
from utils.globals import RETRIEVER_THRESHOLD

qa_prompt = PromptTemplate(
"Context information is below.\n"
Expand Down Expand Up @@ -176,15 +177,17 @@ def _setup_vector_store_index(

def _process_basic_query(self, query_str: str) -> Response:
nodes: list[NodeWithScore] = self.retriever.retrieve(query_str)
context_str = "\n\n".join([n.node.get_content() for n in nodes])
nodes_filtered = [node for node in nodes if node.score >= RETRIEVER_THRESHOLD]
context_str = "\n\n".join([n.node.get_content() for n in nodes_filtered])
prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
response = self.llm.complete(prompt)

# return final_response
return Response(response=str(response), source_nodes=nodes)
return Response(response=str(response), source_nodes=nodes_filtered)

def _process_summary_query(self, query_str: str) -> Response:
summary_nodes = self.summary_retriever.retrieve(query_str)
summary_nodes_filtered = [node for node in summary_nodes if node.score >= RETRIEVER_THRESHOLD]
utils = QdrantEngineUtils(
metadata_date_key=self.metadata_date_key,
metadata_date_format=self.metadata_date_format,
Expand All @@ -193,7 +196,7 @@ def _process_summary_query(self, query_str: str) -> Response:

dates = [
node.metadata[self.metadata_date_summary_key]
for node in summary_nodes
for node in summary_nodes_filtered
if self.metadata_date_summary_key in node.metadata
]

Expand All @@ -208,8 +211,10 @@ def _process_summary_query(self, query_str: str) -> Response:
)
raw_nodes = retriever.retrieve(query_str)

context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes)
raw_nodes_filtered = [node for node in raw_nodes if node.score >= RETRIEVER_THRESHOLD]

context_str = utils.combine_nodes_for_prompt(summary_nodes_filtered, raw_nodes_filtered)
prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
response = self.llm.complete(prompt)

return Response(response=str(response), source_nodes=raw_nodes)
return Response(response=str(response), source_nodes=raw_nodes_filtered)
6 changes: 4 additions & 2 deletions utils/query_engine/level_based_platform_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from llama_index.llms.openai import OpenAI
from utils.query_engine.base_pg_engine import BasePGEngine
from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils
from utils.globals import RETRIEVER_THRESHOLD

qa_prompt = PromptTemplate(
"Context information is below.\n"
Expand Down Expand Up @@ -46,13 +47,14 @@ def custom_query(self, query_str: str):
similar_nodes = retriever.query_db(
query=query_str, filters=self._filters, date_interval=self._d
)
similar_nodes_filtered = [node for node in similar_nodes if node.score >= RETRIEVER_THRESHOLD]

context_str = self._prepare_context_str(similar_nodes, summary_nodes=None)
context_str = self._prepare_context_str(similar_nodes_filtered, summary_nodes=None)
fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
response = self.llm.complete(fmt_qa_prompt)
logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}")

return Response(response=str(response), source_nodes=similar_nodes)
return Response(response=str(response), source_nodes=similar_nodes_filtered)

@classmethod
def prepare_platform_engine(
Expand Down

0 comments on commit 97c36c0

Please sign in to comment.