Skip to content

Commit

Permalink
fix: cleaning codes!
Browse files Browse the repository at this point in the history
codeRabbitAI suggestions.
  • Loading branch information
amindadgar committed Nov 14, 2024
1 parent f305345 commit d924922
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 34 deletions.
69 changes: 39 additions & 30 deletions utils/query_engine/dual_qdrant_retrieval_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,37 +33,9 @@ class DualQdrantRetrievalEngine(CustomQueryEngine):

def custom_query(self, query_str: str):
if self.summary_retriever is None:
nodes = self.retriever.retrieve(query_str)
context_str = "\n\n".join([n.node.get_content() for n in nodes])
prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
response = self.llm.complete(prompt)
response = self._process_basic_query(query_str)
else:
summary_nodes = self.summary_retriever.retrieve(query_str)
utils = QdrantEngineUtils(
metadata_date_key=self.metadata_date_key,
metadata_date_format=self.metadata_date_format,
date_margin=self._date_margin,
)
# the filters that will be applied on qdrant
dates = [
node.metadata[self.metadata_date_summary_key] for node in summary_nodes
]
filter = utils.define_raw_data_filters(dates=dates)
_, raw_data_top_k, _ = load_hyperparams()

# retrieve based on summary nodes
retriever: BaseRetriever = self._vector_store_index.as_retriever(
vector_store_kwargs={"qdrant_filters": filter},
similarity_top_k=raw_data_top_k,
)
raw_nodes = retriever.retrieve(query_str)

context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes)

prompt = qa_prompt.format(context_str=context_str, query_str=query_str)

response = self.llm.complete(prompt)

response = self._process_summary_query(query_str)
return str(response)

@classmethod
Expand Down Expand Up @@ -199,3 +171,40 @@ def _setup_vector_store_index(
qdrant_vector = QDrantVectorAccess(collection_name=collection_name)
index = qdrant_vector.load_index()
return index

def _process_basic_query(self, query_str: str) -> str:
nodes = self.retriever.retrieve(query_str)
context_str = "\n\n".join([n.node.get_content() for n in nodes])
prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
response = self.llm.complete(prompt)
return response

def _process_summary_query(self, query_str: str) -> str:
summary_nodes = self.summary_retriever.retrieve(query_str)
utils = QdrantEngineUtils(
metadata_date_key=self.metadata_date_key,
metadata_date_format=self.metadata_date_format,
date_margin=self._date_margin,
)

dates = [
node.metadata[self.metadata_date_summary_key]
for node in summary_nodes
if self.metadata_date_summary_key in node.metadata
]

if not dates:
return self._process_basic_query(query_str)

filter = utils.define_raw_data_filters(dates=dates)

retriever: BaseRetriever = self._vector_store_index.as_retriever(
vector_store_kwargs={"qdrant_filters": filter},
similarity_top_k=self._raw_data_top_k,
)
raw_nodes = retriever.retrieve(query_str)

context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes)
prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
response = self.llm.complete(prompt)
return response
12 changes: 8 additions & 4 deletions utils/query_engine/qdrant_query_engine_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone

from dateutil.parser import parse
from llama_index.core.schema import NodeWithScore
Expand Down Expand Up @@ -48,10 +48,10 @@ def define_raw_data_filters(self, dates: list[str]) -> models.Filter:
for day_value in expanded_dates:
next_day = day_value + timedelta(days=1)

if self.metadata_date_format is DataType.INTEGER:
if self.metadata_date_format == DataType.INTEGER:
gte_value = int(day_value.timestamp())
lte_value = int(next_day.timestamp())
elif self.metadata_date_format is DataType.FLOAT:
elif self.metadata_date_format == DataType.FLOAT:
gte_value = day_value.timestamp()
lte_value = next_day.timestamp()
else:
Expand Down Expand Up @@ -100,7 +100,11 @@ def combine_nodes_for_prompt(

for raw_node in raw_nodes:
timestamp = raw_node.metadata[self.metadata_date_key]
date_str = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d")
date_str = (
datetime.fromtimestamp(timestamp)
.replace(tzinfo=timezone.utc)
.strftime("%Y-%m-%d")
)

if date_str not in raw_nodes_by_date:
raw_nodes_by_date[date_str] = []
Expand Down

0 comments on commit d924922

Please sign in to comment.