diff --git a/bot/retrievers/custom_retriever.py b/bot/retrievers/custom_retriever.py index 046842e..46147ac 100644 --- a/bot/retrievers/custom_retriever.py +++ b/bot/retrievers/custom_retriever.py @@ -51,7 +51,7 @@ def _build_node_list_from_query_result( score: float | None = None if query_result.similarities is not None: score = query_result.similarities[ind] - + if score is not None and score >= RETRIEVER_THRESHOLD: # This is the part we updated node_new = Node.from_dict(node.to_dict()) diff --git a/utils/globals.py b/utils/globals.py index 098c67a..eb0bedd 100644 --- a/utils/globals.py +++ b/utils/globals.py @@ -1,2 +1,2 @@ # the theshold to skip nodes of being included in an answer -RETRIEVER_THRESHOLD = .4 \ No newline at end of file +RETRIEVER_THRESHOLD = 0.4 diff --git a/utils/query_engine/dual_qdrant_retrieval_engine.py b/utils/query_engine/dual_qdrant_retrieval_engine.py index d7cecd2..e3afbd3 100644 --- a/utils/query_engine/dual_qdrant_retrieval_engine.py +++ b/utils/query_engine/dual_qdrant_retrieval_engine.py @@ -187,7 +187,9 @@ def _process_basic_query(self, query_str: str) -> Response: def _process_summary_query(self, query_str: str) -> Response: summary_nodes = self.summary_retriever.retrieve(query_str) - summary_nodes_filtered = [node for node in summary_nodes if node.score >= RETRIEVER_THRESHOLD] + summary_nodes_filtered = [ + node for node in summary_nodes if node.score >= RETRIEVER_THRESHOLD + ] utils = QdrantEngineUtils( metadata_date_key=self.metadata_date_key, metadata_date_format=self.metadata_date_format, @@ -211,9 +213,13 @@ def _process_summary_query(self, query_str: str) -> Response: ) raw_nodes = retriever.retrieve(query_str) - raw_nodes_filtered = [node for node in raw_nodes if node.score >= RETRIEVER_THRESHOLD] + raw_nodes_filtered = [ + node for node in raw_nodes if node.score >= RETRIEVER_THRESHOLD + ] - context_str = utils.combine_nodes_for_prompt(summary_nodes_filtered, raw_nodes_filtered) + context_str = utils.combine_nodes_for_prompt( + summary_nodes_filtered, raw_nodes_filtered + ) prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str) response = self.llm.complete(prompt) diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index 8fc52e6..8f5e61a 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -47,9 +47,13 @@ def custom_query(self, query_str: str): similar_nodes = retriever.query_db( query=query_str, filters=self._filters, date_interval=self._d ) - similar_nodes_filtered = [node for node in similar_nodes if node.score >= RETRIEVER_THRESHOLD] + similar_nodes_filtered = [ + node for node in similar_nodes if node.score >= RETRIEVER_THRESHOLD + ] - context_str = self._prepare_context_str(similar_nodes_filtered, summary_nodes=None) + context_str = self._prepare_context_str( + similar_nodes_filtered, summary_nodes=None + ) fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str) response = self.llm.complete(fmt_qa_prompt) logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}")