From d9249220f92bb2dd1d43c8ffbf935f096e54a6b6 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 14 Nov 2024 13:30:02 +0330 Subject: [PATCH] fix: cleaning codes! codeRabbitAI suggestions. --- .../dual_qdrant_retrieval_engine.py | 69 +++++++++++-------- .../query_engine/qdrant_query_engine_utils.py | 12 ++-- 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/utils/query_engine/dual_qdrant_retrieval_engine.py b/utils/query_engine/dual_qdrant_retrieval_engine.py index 41f96ac..2903287 100644 --- a/utils/query_engine/dual_qdrant_retrieval_engine.py +++ b/utils/query_engine/dual_qdrant_retrieval_engine.py @@ -33,37 +33,9 @@ class DualQdrantRetrievalEngine(CustomQueryEngine): def custom_query(self, query_str: str): if self.summary_retriever is None: - nodes = self.retriever.retrieve(query_str) - context_str = "\n\n".join([n.node.get_content() for n in nodes]) - prompt = qa_prompt.format(context_str=context_str, query_str=query_str) - response = self.llm.complete(prompt) + response = self._process_basic_query(query_str) else: - summary_nodes = self.summary_retriever.retrieve(query_str) - utils = QdrantEngineUtils( - metadata_date_key=self.metadata_date_key, - metadata_date_format=self.metadata_date_format, - date_margin=self._date_margin, - ) - # the filters that will be applied on qdrant - dates = [ - node.metadata[self.metadata_date_summary_key] for node in summary_nodes - ] - filter = utils.define_raw_data_filters(dates=dates) - _, raw_data_top_k, _ = load_hyperparams() - - # retrieve based on summary nodes - retriever: BaseRetriever = self._vector_store_index.as_retriever( - vector_store_kwargs={"qdrant_filters": filter}, - similarity_top_k=raw_data_top_k, - ) - raw_nodes = retriever.retrieve(query_str) - - context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes) - - prompt = qa_prompt.format(context_str=context_str, query_str=query_str) - - response = self.llm.complete(prompt) - + response = self._process_summary_query(query_str) 
return str(response) @classmethod @@ -199,3 +171,40 @@ def _setup_vector_store_index( qdrant_vector = QDrantVectorAccess(collection_name=collection_name) index = qdrant_vector.load_index() return index + + def _process_basic_query(self, query_str: str) -> str: + nodes = self.retriever.retrieve(query_str) + context_str = "\n\n".join([n.node.get_content() for n in nodes]) + prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str) + response = self.llm.complete(prompt) + return str(response)  # coerce to text to match the declared -> str + + def _process_summary_query(self, query_str: str) -> str: + summary_nodes = self.summary_retriever.retrieve(query_str) + utils = QdrantEngineUtils( + metadata_date_key=self.metadata_date_key, + metadata_date_format=self.metadata_date_format, + date_margin=self._date_margin, + ) + + dates = [ + node.metadata[self.metadata_date_summary_key] + for node in summary_nodes + if self.metadata_date_summary_key in node.metadata + ] + + if not dates: + return self._process_basic_query(query_str) + + qdrant_filter = utils.define_raw_data_filters(dates=dates) + + retriever: BaseRetriever = self._vector_store_index.as_retriever( + vector_store_kwargs={"qdrant_filters": qdrant_filter}, + similarity_top_k=self._raw_data_top_k, + ) + raw_nodes = retriever.retrieve(query_str) + + context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes) + prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str) + response = self.llm.complete(prompt) + return str(response)  # coerce to text to match the declared -> str diff --git a/utils/query_engine/qdrant_query_engine_utils.py b/utils/query_engine/qdrant_query_engine_utils.py index b560749..517a35a 100644 --- a/utils/query_engine/qdrant_query_engine_utils.py +++ b/utils/query_engine/qdrant_query_engine_utils.py @@ -1,5 +1,5 @@ from collections import defaultdict -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from dateutil.parser import parse from llama_index.core.schema import NodeWithScore @@ -48,10 +48,10 @@ def 
define_raw_data_filters(self, dates: list[str]) -> models.Filter: for day_value in expanded_dates: next_day = day_value + timedelta(days=1) - if self.metadata_date_format is DataType.INTEGER: + if self.metadata_date_format == DataType.INTEGER: gte_value = int(day_value.timestamp()) lte_value = int(next_day.timestamp()) - elif self.metadata_date_format is DataType.FLOAT: + elif self.metadata_date_format == DataType.FLOAT: gte_value = day_value.timestamp() lte_value = next_day.timestamp() else: @@ -100,7 +100,11 @@ def combine_nodes_for_prompt( for raw_node in raw_nodes: timestamp = raw_node.metadata[self.metadata_date_key] - date_str = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d") + date_str = ( + # build the datetime in UTC instead of relabelling local time + datetime.fromtimestamp(timestamp, tz=timezone.utc) + .strftime("%Y-%m-%d") + ) if date_str not in raw_nodes_by_date: raw_nodes_by_date[date_str] = []