diff --git a/bot/retrievers/custom_retriever.py b/bot/retrievers/custom_retriever.py
index bb05ec0..046842e 100644
--- a/bot/retrievers/custom_retriever.py
+++ b/bot/retrievers/custom_retriever.py
@@ -10,6 +10,7 @@
 )
 from llama_index.core.schema import Node, NodeWithScore, ObjectType
 from llama_index.core.vector_stores.types import VectorStoreQueryResult
+from utils.globals import RETRIEVER_THRESHOLD
 
 
 class CustomVectorStoreRetriever(VectorIndexRetriever):
@@ -50,10 +51,12 @@ def _build_node_list_from_query_result(
             score: float | None = None
             if query_result.similarities is not None:
                 score = query_result.similarities[ind]
-            # This is the part we updated
-            node_new = Node.from_dict(node.to_dict())
-            node_with_score = NodeWithScore(node=node_new, score=score)
+
+            if score is not None and score >= RETRIEVER_THRESHOLD:
+                # This is the part we updated
+                node_new = Node.from_dict(node.to_dict())
+                node_with_score = NodeWithScore(node=node_new, score=score)
-            node_with_scores.append(node_with_score)
+                node_with_scores.append(node_with_score)
 
         return node_with_scores
 
diff --git a/utils/globals.py b/utils/globals.py
new file mode 100644
index 0000000..098c67a
--- /dev/null
+++ b/utils/globals.py
@@ -0,0 +1,2 @@
+# retrieval score threshold below which nodes are excluded from an answer
+RETRIEVER_THRESHOLD = 0.4
diff --git a/utils/query_engine/dual_qdrant_retrieval_engine.py b/utils/query_engine/dual_qdrant_retrieval_engine.py
index ebcb250..d7cecd2 100644
--- a/utils/query_engine/dual_qdrant_retrieval_engine.py
+++ b/utils/query_engine/dual_qdrant_retrieval_engine.py
@@ -12,6 +12,7 @@
 from schema.type import DataType
 from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess
 from utils.query_engine.qdrant_query_engine_utils import QdrantEngineUtils
+from utils.globals import RETRIEVER_THRESHOLD
 
 qa_prompt = PromptTemplate(
     "Context information is below.\n"
@@ -176,15 +177,17 @@ def _setup_vector_store_index(
 
     def _process_basic_query(self, query_str: str) -> Response:
         nodes: list[NodeWithScore] = self.retriever.retrieve(query_str)
-        context_str = "\n\n".join([n.node.get_content() for n in nodes])
+        nodes_filtered = [node for node in nodes if node.score is not None and node.score >= RETRIEVER_THRESHOLD]
+        context_str = "\n\n".join([n.node.get_content() for n in nodes_filtered])
         prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
         response = self.llm.complete(prompt)
 
         # return final_response
-        return Response(response=str(response), source_nodes=nodes)
+        return Response(response=str(response), source_nodes=nodes_filtered)
 
     def _process_summary_query(self, query_str: str) -> Response:
         summary_nodes = self.summary_retriever.retrieve(query_str)
+        summary_nodes_filtered = [node for node in summary_nodes if node.score is not None and node.score >= RETRIEVER_THRESHOLD]
         utils = QdrantEngineUtils(
             metadata_date_key=self.metadata_date_key,
             metadata_date_format=self.metadata_date_format,
@@ -193,7 +196,7 @@ def _process_summary_query(self, query_str: str) -> Response:
 
         dates = [
             node.metadata[self.metadata_date_summary_key]
-            for node in summary_nodes
+            for node in summary_nodes_filtered
             if self.metadata_date_summary_key in node.metadata
         ]
 
@@ -208,8 +211,10 @@ def _process_summary_query(self, query_str: str) -> Response:
         )
 
         raw_nodes = retriever.retrieve(query_str)
-        context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes)
+        raw_nodes_filtered = [node for node in raw_nodes if node.score is not None and node.score >= RETRIEVER_THRESHOLD]
+
+        context_str = utils.combine_nodes_for_prompt(summary_nodes_filtered, raw_nodes_filtered)
         prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
         response = self.llm.complete(prompt)
 
-        return Response(response=str(response), source_nodes=raw_nodes)
+        return Response(response=str(response), source_nodes=raw_nodes_filtered)
diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py
index 837e083..8fc52e6 100644
--- a/utils/query_engine/level_based_platform_query_engine.py
+++ b/utils/query_engine/level_based_platform_query_engine.py
@@ -16,6 +16,7 @@
 from llama_index.llms.openai import OpenAI
 from utils.query_engine.base_pg_engine import BasePGEngine
 from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils
+from utils.globals import RETRIEVER_THRESHOLD
 
 qa_prompt = PromptTemplate(
     "Context information is below.\n"
@@ -46,13 +47,14 @@ def custom_query(self, query_str: str):
         similar_nodes = retriever.query_db(
             query=query_str, filters=self._filters, date_interval=self._d
         )
+        similar_nodes_filtered = [node for node in similar_nodes if node.score is not None and node.score >= RETRIEVER_THRESHOLD]
 
-        context_str = self._prepare_context_str(similar_nodes, summary_nodes=None)
+        context_str = self._prepare_context_str(similar_nodes_filtered, summary_nodes=None)
         fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
         response = self.llm.complete(fmt_qa_prompt)
         logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}")
 
-        return Response(response=str(response), source_nodes=similar_nodes)
+        return Response(response=str(response), source_nodes=similar_nodes_filtered)
 
     @classmethod
     def prepare_platform_engine(