Feat/90 qdrant summary and raw engine #99

Merged: 10 commits, Nov 14, 2024
feat: Added summary prompt engine!
and updated other related parts
amindadgar committed Nov 14, 2024

Verified: this commit was signed with the committer's verified signature (amindadgar, Mohammad Amin Dadgar).
commit 62e008e28d7f8c286f5d8b529e82e2dc2d750353
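In short, the new summary flow retrieves summary nodes first, derives their dates, narrows the raw-data retrieval with a Qdrant filter built from those dates, and only then formats the qa_prompt for the LLM. Below is a condensed, standalone sketch of that branch with illustrative parameter names; the real logic lives in DualQdrantRetrievalEngine.custom_query, shown in the hunks that follow.

from llama_index.core import PromptTemplate, VectorStoreIndex
from llama_index.core.retrievers import BaseRetriever
from llama_index.llms.openai import OpenAI
from qdrant_client.http import models

from utils.query_engine.qdrant_query_engine_utils import QdrantEngineUtils


def answer_from_summaries(
    query_str: str,
    llm: OpenAI,
    qa_prompt: PromptTemplate,
    summary_retriever: BaseRetriever,
    raw_index: VectorStoreIndex,
    utils: QdrantEngineUtils,
    raw_data_top_k: int,
) -> str:
    # 1. Retrieve summary nodes and collect their dates
    #    (the engine reads metadata_date_summary_key; "date" is assumed here).
    summary_nodes = summary_retriever.retrieve(query_str)
    dates = [node.metadata["date"] for node in summary_nodes]

    # 2. Build a Qdrant filter from those dates and hand it to the raw-data retriever
    #    through vector_store_kwargs, as this commit now does in custom_query.
    qdrant_filter: models.Filter = utils.define_raw_data_filters(dates=dates)
    retriever = raw_index.as_retriever(
        vector_store_kwargs={"qdrant_filters": qdrant_filter},
        similarity_top_k=raw_data_top_k,
    )
    raw_nodes = retriever.retrieve(query_str)

    # 3. Merge summaries and raw messages into one context and query the LLM.
    context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes)
    prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
    return str(llm.complete(prompt))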
54 changes: 10 additions & 44 deletions utils/query_engine/base_qdrant_engine.py
@@ -1,10 +1,7 @@
-from bot.retrievers.utils.load_hyperparams import load_hyperparams
-from llama_index.core import VectorStoreIndex, get_response_synthesizer
-from llama_index.core.indices.vector_store.retrievers.retriever import (
-    VectorIndexRetriever,
-)
-from llama_index.core.query_engine import RetrieverQueryEngine
-from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess
+from llama_index.core import Settings
+from llama_index.core import get_response_synthesizer
+
+from .dual_qdrant_retrieval_engine import DualQdrantRetrievalEngine
 
 
 class BaseQdrantEngine:
@@ -24,44 +21,13 @@ def __init__(self, platform_name: str, community_id: str) -> None:
         """
         self.platform_name = platform_name
         self.community_id = community_id
         self.collection_name = f"{self.community_id}_{platform_name}"
 
     def prepare(self, testing=False):
-        vector_store_index = self._setup_vector_store_index(
-            testing=testing,
+        engine = DualQdrantRetrievalEngine.setup_engine(
+            llm=Settings.llm,
+            synthesizer=get_response_synthesizer(),
+            platform_name=self.platform_name,
+            community_id=self.community_id,
         )
-        _, similarity_top_k, _ = load_hyperparams()
-
-        retriever = VectorIndexRetriever(
-            index=vector_store_index,
-            similarity_top_k=similarity_top_k,
-        )
-        query_engine = RetrieverQueryEngine(
-            retriever=retriever,
-            response_synthesizer=get_response_synthesizer(),
-        )
-        return query_engine
-
-    def _setup_vector_store_index(
-        self,
-        testing: bool = False,
-        **kwargs,
-    ) -> VectorStoreIndex:
-        """
-        prepare the vector_store for querying data
-
-        Parameters
-        ------------
-        testing : bool
-            for testing purposes
-        **kwargs :
-            collection_name : str
-                to override the default collection_name
-        """
-        collection_name = kwargs.get("collection_name", self.collection_name)
-        qdrant_vector = QDrantVectorAccess(
-            collection_name=collection_name,
-            testing=testing,
-        )
-        index = qdrant_vector.load_index()
-        return index
+        return engine
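With this refactor, BaseQdrantEngine.prepare() no longer assembles its own retriever and query engine; it delegates everything to DualQdrantRetrievalEngine.setup_engine, which reads the globally configured Settings.llm. A rough usage sketch follows; the model name and the platform/community identifiers are placeholders.

from llama_index.core import Settings
from llama_index.llms.openai import OpenAI

from utils.query_engine.base_qdrant_engine import BaseQdrantEngine

# setup_engine uses Settings.llm, so an LLM must be configured globally first
# (the model choice here is a placeholder).
Settings.llm = OpenAI(model="gpt-4o-mini")

# Placeholder platform/community values; the collection is still named
# f"{community_id}_{platform_name}" as before.
engine = BaseQdrantEngine(platform_name="discord", community_id="1234").prepare()

print(engine.query("What topics were discussed last week?"))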
21 changes: 10 additions & 11 deletions utils/query_engine/dual_qdrant_retrieval_engine.py
@@ -3,7 +3,6 @@
 from llama_index.core import PromptTemplate, VectorStoreIndex
 from llama_index.core.query_engine import CustomQueryEngine
 from llama_index.core.retrievers import BaseRetriever
-from llama_index.core import get_response_synthesizer
 from llama_index.core.response_synthesizers import BaseSynthesizer
 from llama_index.core.indices.vector_store.retrievers.retriever import (
     VectorIndexRetriever,
@@ -38,9 +37,8 @@ def custom_query(self, query_str: str):
         if self.summary_retriever is None:
             nodes = self.retriever.retrieve(query_str)
             context_str = "\n\n".join([n.node.get_content() for n in nodes])
-            response = self.llm.complete(
-                qa_prompt.format(context_str=context_str, query_str=query_str)
-            )
+            prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
+            response = self.llm.complete(prompt)
         else:
             summary_nodes = self.summary_retriever.retrieve(query_str)
             utils = QdrantEngineUtils(
@@ -52,21 +50,21 @@ def custom_query(self, query_str: str):
             dates = [
                 node.metadata[self.metadata_date_summary_key] for node in summary_nodes
             ]
-            should_filters = utils.define_raw_data_filters(dates=dates)
+            filter = utils.define_raw_data_filters(dates=dates)
             _, raw_data_top_k, _ = load_hyperparams()
 
             # retrieve based on summary nodes
             retriever: BaseRetriever = self._vector_store_index.as_retriever(
-                {"qdrant_filters": should_filters},
+                vector_store_kwargs={"qdrant_filters": filter},
                 similarity_top_k=raw_data_top_k,
             )
             raw_nodes = retriever.retrieve(query_str)
 
             context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes)
 
-            response = self.llm.complete(
-                qa_prompt.format(context_str=context_str, query_str=query_str)
-            )
+            prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
+
+            response = self.llm.complete(prompt)
 
         return str(response)
 
@@ -118,7 +116,7 @@ def setup_engine(
         )
 
     @classmethod
-    def _prepare_engine_with_summaries(
+    def setup_engine_with_summaries(
         cls,
         llm: OpenAI,
         synthesizer: BaseSynthesizer,
@@ -193,8 +191,9 @@ def _prepare_engine_with_summaries(
             qa_prompt=qa_prompt,
         )
 
+    @classmethod
     def _setup_vector_store_index(
-        self,
+        cls,
         collection_name: str,
     ) -> VectorStoreIndex:
         """
31 changes: 20 additions & 11 deletions utils/query_engine/qdrant_query_engine_utils.py
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from datetime import datetime, timedelta
 from dateutil.parser import parse
 
@@ -17,7 +18,7 @@ def __init__(
         self.metadata_date_format = metadata_date_format
         self.date_margin = date_margin
 
-    def define_raw_data_filters(self, dates: list[str]) -> list[models.FieldCondition]:
+    def define_raw_data_filters(self, dates: list[str]) -> models.Filter:
         """
         define the filters to be applied on raw data given the dates
 
@@ -29,10 +30,10 @@ def define_raw_data_filters(self, dates: list[str]) -> list[models.FieldConditio
 
         Returns
         ---------
-        should_filters : list[models.FieldCondition]
+        filter : models.Filter
            the filters to be applied on raw data
        """
-        should_filters: set[models.FieldCondition] = set()
+        should_filters: list[models.FieldCondition] = []
         expanded_dates: set[datetime] = set()
 
         # accounting for the date margin
@@ -58,7 +59,7 @@ def define_raw_data_filters(self, dates: list[str]) -> list[models.FieldConditio
                     "raw data metadata `date` shouldn't be anything other than FLOAT or INTEGER"
                 )
 
-            should_filters.add(
+            should_filters.append(
                 models.FieldCondition(
                     key=self.metadata_date_key,
                     range=models.Range(
@@ -68,7 +69,9 @@ def define_raw_data_filters(self, dates: list[str]) -> list[models.FieldConditio
                 )
             )
 
-        return list(should_filters)
+        filter = models.Filter(should=should_filters)
+
+        return filter
 
     def combine_nodes_for_prompt(
         self,
@@ -103,23 +106,29 @@ def combine_nodes_for_prompt(
                 raw_nodes_by_date[date_str] = []
             raw_nodes_by_date[date_str].append(raw_node)
 
-        # Build the combined prompt
-        combined_sections = []
-
+        # A summary could be separated into multiple nodes
+        # combining them together
+        combined_summaries: dict[str, str] = defaultdict(str)
         for summary_node in summary_nodes:
             date = summary_node.metadata["date"]
             summary_text = summary_node.text
 
             summaries = summary_text.split("\n")
             summary_bullets = set(summaries)
-            summary_bullets.remove("")
+            if "" in summary_bullets:
+                summary_bullets.remove("")
+            combined_summaries[date] += "\n".join(summary_bullets)
 
+        # Build the combined prompt
+        combined_sections = []
+
-            section = f"date: {date}\n\nSummary:\n" + "\n".join(summary_bullets) + "\n"
+        for date, summary_bullets in combined_summaries.items():
+            section = f"Date: {date}\nSummary:\n" + summary_bullets + "\n\n"
 
             if date in raw_nodes_by_date:
                 raw_texts = [node.text for node in raw_nodes_by_date[date]]
                 section += "Messages:\n" + "\n".join(raw_texts)
 
             combined_sections.append(section)
 
-        return "\n\n" + "\n\n".join(combined_sections)
+        return "\n\n".join(combined_sections)
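For reference, define_raw_data_filters now returns a single models.Filter whose should clause ORs one numeric date-range condition per margin-expanded day, rather than a bare list of conditions. Below is a self-contained sketch of that shape; the metadata key, the date format, and the one-day gte/lte window are assumptions, since those details fall outside the hunks shown above.

from datetime import datetime, timedelta, timezone

from qdrant_client.http import models


def sketch_date_filter(dates: list[str], date_margin: int = 1) -> models.Filter:
    """Build an OR-ed (should) filter over a numeric `date` payload, one Range per day."""
    expanded: set[datetime] = set()
    for date in dates:
        # The date format is assumed; the utility class parses with its metadata_date_format.
        day = datetime.strptime(date, "%Y-%m-%d").replace(tzinfo=timezone.utc)
        for offset in range(-date_margin, date_margin + 1):
            expanded.add(day + timedelta(days=offset))

    should_filters = [
        models.FieldCondition(
            key="date",  # assumed to match metadata_date_key
            range=models.Range(
                gte=day.timestamp(),
                lte=(day + timedelta(days=1)).timestamp(),
            ),
        )
        for day in expanded
    ]
    return models.Filter(should=should_filters)


# Example: dates pulled from summary-node metadata, with a margin of one day on each side.
raw_data_filter = sketch_date_filter(["2024-11-01", "2024-11-02"], date_margin=1)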