From 8bf77ef5a0d05a419be82b42e9689fc163bf45b4 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 7 Nov 2024 14:31:07 +0330 Subject: [PATCH 1/9] wip: Added dual query engine! --- .../dual_qdrant_retrieval_engine.py | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 utils/query_engine/dual_qdrant_retrieval_engine.py diff --git a/utils/query_engine/dual_qdrant_retrieval_engine.py b/utils/query_engine/dual_qdrant_retrieval_engine.py new file mode 100644 index 0000000..94d7f9d --- /dev/null +++ b/utils/query_engine/dual_qdrant_retrieval_engine.py @@ -0,0 +1,127 @@ +from bot.retrievers.utils.load_hyperparams import load_hyperparams +from llama_index.llms.openai import OpenAI +from llama_index.core import PromptTemplate, Document, VectorStoreIndex +from llama_index.core.query_engine import CustomQueryEngine +from llama_index.core.retrievers import BaseRetriever +from llama_index.core import get_response_synthesizer +from llama_index.core.response_synthesizers import BaseSynthesizer +from llama_index.core.indices.vector_store.retrievers.retriever import ( + VectorIndexRetriever, +) +from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess + + +qa_prompt = PromptTemplate( + "Context information is below.\n" + "---------------------\n" + "{context_str}\n" + "---------------------\n" + "Given the context information and not prior knowledge, " + "answer the query.\n" + "Query: {query_str}\n" + "Answer: " +) + + +class DualQdrantRetrievalEngine(CustomQueryEngine): + """RAG String Query Engine.""" + + retriever: BaseRetriever + response_synthesizer: BaseSynthesizer + llm: OpenAI + qa_prompt: PromptTemplate + + def custom_query(self, query_str: str): + + if self.summary_retriever is None: + nodes = self.retriever.retrieve(query_str) + context_str = "\n\n".join([n.node.get_content() for n in nodes]) + response = self.llm.complete( + qa_prompt.format(context_str=context_str, query_str=query_str) + ) + else: + summary_nodes = self.summary_retriever.retrieve(query_str) + + # TODO: filter on raw data for extraction + # and then prepare the prompt + + return str(response) + + @classmethod + def setup_engine( + cls, + use_summary: bool | None, + llm: OpenAI, + synthesizer: BaseSynthesizer, + qa_prompt: PromptTemplate, + platform_name: str, + community_id: str, + ): + """ + setup the custom query engine on qdrant data + + Parameters + ------------ + use_summary : bool | None + whether to use the summary data or not + note: the summary data should be available before + for this option to be enabled + llm : OpenAI + the llm to be used for RAG pipeline + synthesizer : BaseSynthesizer + the process of generating response using an LLM + qa_prompt : PromptTemplate + the prompt template to be filled and passed to an LLM + platform_name : str + specifying the platform data to identify the data collection + community_id : str + specifying community_id to identify the data collection + """ + collection_name = f"{community_id}_{platform_name}" + + summary_data_top_k, raw_data_top_k, interval_margin = load_hyperparams() + cls._interval_margin = interval_margin + + vector_store_index = cls._setup_vector_store_index( + collection_name=collection_name + ) + retriever = VectorIndexRetriever( + index=vector_store_index, + similarity_top_k=raw_data_top_k, + ) + + if use_summary: + summary_collection_name = collection_name + "_summary" + summary_vector_store_index = cls._setup_vector_store_index( + collection_name=summary_collection_name + ) + + cls.summary_retriever = VectorIndexRetriever( + index=summary_vector_store_index, + similarity_top_k=summary_data_top_k, + ) + else: + cls.summary_retriever = None + + return cls( + retriever=retriever, + response_synthesizer=synthesizer, + llm=llm, + qa_prompt=qa_prompt, + ) + + def _setup_vector_store_index( + self, + collection_name: str, + ) -> VectorStoreIndex: + """ + prepare the vector_store for querying data + + Parameters + ------------ + collection_name : str + to override the default collection_name + """ + qdrant_vector = QDrantVectorAccess(collection_name=collection_name) + index = qdrant_vector.load_index() + return index From 750888f8b12a303ba59e4daf121eeb16b208dad8 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 7 Nov 2024 16:54:36 +0330 Subject: [PATCH 2/9] wip: preparing the queries for raw data! --- schema/type.py | 6 ++++ .../dual_qdrant_retrieval_engine.py | 28 ++++++++++++++++--- 2 files changed, 30 insertions(+), 4 deletions(-) create mode 100644 schema/type.py diff --git a/schema/type.py b/schema/type.py new file mode 100644 index 0000000..7808592 --- /dev/null +++ b/schema/type.py @@ -0,0 +1,6 @@ +from enum import Enum + +class DataType(Enum): + INTEGER = "INTEGER" + STRING = "STRING" + BOOLEAN = "BOOLEAN" \ No newline at end of file diff --git a/utils/query_engine/dual_qdrant_retrieval_engine.py b/utils/query_engine/dual_qdrant_retrieval_engine.py index 94d7f9d..97fcd93 100644 --- a/utils/query_engine/dual_qdrant_retrieval_engine.py +++ b/utils/query_engine/dual_qdrant_retrieval_engine.py @@ -1,6 +1,6 @@ from bot.retrievers.utils.load_hyperparams import load_hyperparams from llama_index.llms.openai import OpenAI -from llama_index.core import PromptTemplate, Document, VectorStoreIndex +from llama_index.core import PromptTemplate, VectorStoreIndex from llama_index.core.query_engine import CustomQueryEngine from llama_index.core.retrievers import BaseRetriever from llama_index.core import get_response_synthesizer @@ -9,6 +9,7 @@ VectorIndexRetriever, ) from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess +from schema.type import DataType qa_prompt = PromptTemplate( @@ -41,7 +42,10 @@ def custom_query(self, query_str: str): ) else: summary_nodes = self.summary_retriever.retrieve(query_str) - + # the filters that will be applied on qdrant + should_filters = [] + for node in summary_nodes: + date_value = node.metadata[self.metadata_date_key] # TODO: filter on raw data for extraction # and then prepare the prompt @@ -50,19 +54,21 @@ def custom_query(self, query_str: str): @classmethod def setup_engine( cls, - use_summary: bool | None, + use_summary: bool, llm: OpenAI, synthesizer: BaseSynthesizer, qa_prompt: PromptTemplate, platform_name: str, community_id: str, + metadata_date_key: str | None = None, + metadata_date_format: DataType | None = None, ): """ setup the custom query engine on qdrant data Parameters ------------ - use_summary : bool | None + use_summary : bool whether to use the summary data or not note: the summary data should be available before for this option to be enabled @@ -76,7 +82,19 @@ def setup_engine( specifying the platform data to identify the data collection community_id : str specifying community_id to identify the data collection + metadata_date_key : str | None + the date key name in summary documents' metadata + In case of `use_summary` equal to be true this shuold be passed + metadata_date_format : DataType | None + the date format in metadata + In case of `use_summary` equal to be true this shuold be passed """ + if use_summary and (metadata_date_key is None or metadata_date_format is None): + raise ValueError( + "`metadata_date_key` and `metadata_date_format` " + "should be given in case if use_summary=True!" + ) + collection_name = f"{community_id}_{platform_name}" summary_data_top_k, raw_data_top_k, interval_margin = load_hyperparams() @@ -100,6 +118,8 @@ def setup_engine( index=summary_vector_store_index, similarity_top_k=summary_data_top_k, ) + cls.metadata_date_key = metadata_date_key + cls.metadata_date_format = metadata_date_format else: cls.summary_retriever = None From 26bc12e30c98f4391407293c6e6f33a65ea2863d Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 11 Nov 2024 14:41:16 +0330 Subject: [PATCH 3/9] feat: Added codes for using summary data! --- schema/type.py | 4 +- .../dual_qdrant_retrieval_engine.py | 133 +++++++++++++----- .../query_engine/qdrant_query_engine_utils.py | 125 ++++++++++++++++ 3 files changed, 226 insertions(+), 36 deletions(-) create mode 100644 utils/query_engine/qdrant_query_engine_utils.py diff --git a/schema/type.py b/schema/type.py index 7808592..a0d1c5f 100644 --- a/schema/type.py +++ b/schema/type.py @@ -1,6 +1,8 @@ from enum import Enum + class DataType(Enum): INTEGER = "INTEGER" STRING = "STRING" - BOOLEAN = "BOOLEAN" \ No newline at end of file + BOOLEAN = "BOOLEAN" + FLOAT = "FLOAT" diff --git a/utils/query_engine/dual_qdrant_retrieval_engine.py b/utils/query_engine/dual_qdrant_retrieval_engine.py index 97fcd93..34cbbc6 100644 --- a/utils/query_engine/dual_qdrant_retrieval_engine.py +++ b/utils/query_engine/dual_qdrant_retrieval_engine.py @@ -10,6 +10,7 @@ ) from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess from schema.type import DataType +from utils.query_engine.qdrant_query_engine_utils import QdrantEngineUtils qa_prompt = PromptTemplate( @@ -42,26 +43,93 @@ def custom_query(self, query_str: str): ) else: summary_nodes = self.summary_retriever.retrieve(query_str) + utils = QdrantEngineUtils( + metadata_date_key=self.metadata_date_key, + metadata_date_format=self.metadata_date_format, + date_margin=self._date_margin, + ) # the filters that will be applied on qdrant - should_filters = [] - for node in summary_nodes: - date_value = node.metadata[self.metadata_date_key] - # TODO: filter on raw data for extraction - # and then prepare the prompt + dates = [ + node.metadata[self.metadata_date_summary_key] for node in summary_nodes + ] + should_filters = utils.define_raw_data_filters(dates=dates) + _, raw_data_top_k, _ = load_hyperparams() + + # retrieve based on summary nodes + retriever: BaseRetriever = self._vector_store_index.as_retriever( + {"qdrant_filters": should_filters}, + similarity_top_k=raw_data_top_k, + ) + raw_nodes = retriever.retrieve(query_str) + + context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes) + + response = self.llm.complete( + qa_prompt.format(context_str=context_str, query_str=query_str) + ) return str(response) @classmethod def setup_engine( cls, - use_summary: bool, llm: OpenAI, synthesizer: BaseSynthesizer, qa_prompt: PromptTemplate, platform_name: str, community_id: str, - metadata_date_key: str | None = None, - metadata_date_format: DataType | None = None, + ): + """ + setup the custom query engine on qdrant data + + Parameters + ------------ + llm : OpenAI + the llm to be used for RAG pipeline + synthesizer : BaseSynthesizer + the process of generating response using an LLM + qa_prompt : PromptTemplate + the prompt template to be filled and passed to an LLM + platform_name : str + specifying the platform data to identify the data collection + community_id : str + specifying community_id to identify the data collection + """ + collection_name = f"{community_id}_{platform_name}" + + _, raw_data_top_k, date_margin = load_hyperparams() + cls._date_margin = date_margin + + cls._vector_store_index: VectorStoreIndex = cls._setup_vector_store_index( + collection_name=collection_name + ) + retriever = VectorIndexRetriever( + index=cls._vector_store_index, + similarity_top_k=raw_data_top_k, + ) + + cls.summary_retriever = None + + return cls( + retriever=retriever, + response_synthesizer=synthesizer, + llm=llm, + qa_prompt=qa_prompt, + ) + + @classmethod + def _prepare_engine_with_summaries( + cls, + llm: OpenAI, + synthesizer: BaseSynthesizer, + qa_prompt: PromptTemplate, + platform_name: str, + community_id: str, + metadata_date_key: str, + metadata_date_format: DataType, + metadata_date_summary_key: str, + metadata_date_summary_format: DataType, + summary_metadata_to_use: list[str], ): """ setup the custom query engine on qdrant data @@ -81,47 +149,42 @@ def setup_engine( platform_name : str specifying the platform data to identify the data collection community_id : str - specifying community_id to identify the data collection - metadata_date_key : str | None + specifying community_id to identify the data collection + metadata_date_summary_key : str | None the date key name in summary documents' metadata In case of `use_summary` equal to be true this shuold be passed - metadata_date_format : DataType | None + metadata_date_summary_format : DataType | None the date format in metadata In case of `use_summary` equal to be true this shuold be passed + NOTE: this should be always a string for the filtering of it to work. """ - if use_summary and (metadata_date_key is None or metadata_date_format is None): - raise ValueError( - "`metadata_date_key` and `metadata_date_format` " - "should be given in case if use_summary=True!" - ) - collection_name = f"{community_id}_{platform_name}" + summary_data_top_k, raw_data_top_k, date_margin = load_hyperparams() + cls._date_margin = date_margin + cls._raw_data_top_k = raw_data_top_k - summary_data_top_k, raw_data_top_k, interval_margin = load_hyperparams() - cls._interval_margin = interval_margin - - vector_store_index = cls._setup_vector_store_index( + cls._vector_store_index: VectorStoreIndex = cls._setup_vector_store_index( collection_name=collection_name ) retriever = VectorIndexRetriever( - index=vector_store_index, + index=cls._vector_store_index, similarity_top_k=raw_data_top_k, ) - if use_summary: - summary_collection_name = collection_name + "_summary" - summary_vector_store_index = cls._setup_vector_store_index( - collection_name=summary_collection_name - ) + summary_collection_name = collection_name + "_summary" + summary_vector_store_index = cls._setup_vector_store_index( + collection_name=summary_collection_name + ) - cls.summary_retriever = VectorIndexRetriever( - index=summary_vector_store_index, - similarity_top_k=summary_data_top_k, - ) - cls.metadata_date_key = metadata_date_key - cls.metadata_date_format = metadata_date_format - else: - cls.summary_retriever = None + cls.summary_retriever = VectorIndexRetriever( + index=summary_vector_store_index, + similarity_top_k=summary_data_top_k, + ) + cls.metadata_date_summary_key = metadata_date_summary_key + cls.metadata_date_summary_format = metadata_date_summary_format + cls.metadata_date_key = metadata_date_key + cls.metadata_date_format = metadata_date_format + cls.summary_metadata_to_use = summary_metadata_to_use return cls( retriever=retriever, diff --git a/utils/query_engine/qdrant_query_engine_utils.py b/utils/query_engine/qdrant_query_engine_utils.py new file mode 100644 index 0000000..a072fa3 --- /dev/null +++ b/utils/query_engine/qdrant_query_engine_utils.py @@ -0,0 +1,125 @@ +from datetime import datetime, timedelta +from dateutil.parser import parse + +from llama_index.core.schema import NodeWithScore +from schema.type import DataType +from qdrant_client.http import models + + +class QdrantEngineUtils: + def __init__( + self, + metadata_date_key: str, + metadata_date_format: DataType, + date_margin: int, + ) -> None: + self.metadata_date_key = metadata_date_key + self.metadata_date_format = metadata_date_format + self.date_margin = date_margin + + def define_raw_data_filters(self, dates: list[str]) -> list[models.FieldCondition]: + """ + define the filters to be applied on raw data given the dates + + Parameters + ----------- + dates : list[str] + a list of dates that should be a string. i.e. with format of `%Y-%m-%d` + the date should be representing a day + + Returns + --------- + should_filters : list[models.FieldCondition] + the filters to be applied on raw data + """ + should_filters: set[models.FieldCondition] = set() + expanded_dates: set[datetime] = set() + + # accounting for the date margin + for date in dates: + day_value = parse(date) + expanded_dates.add(day_value) + + for i in range(1, self.date_margin + 1): + expanded_dates.add(day_value - timedelta(days=i)) + expanded_dates.add(day_value + timedelta(days=i)) + + for day_value in expanded_dates: + next_day = day_value + timedelta(days=1) + + if self.metadata_date_format is DataType.INTEGER: + gte_value = int(day_value.timestamp()) + lte_value = int(next_day.timestamp()) + elif self.metadata_date_format is DataType.FLOAT: + gte_value = day_value.timestamp() + lte_value = next_day.timestamp() + else: + raise ValueError( + "raw data metadata `date` shouldn't be anything other than FLOAT or INTEGER" + ) + + should_filters.add( + models.FieldCondition( + key=self.metadata_date_key, + range=models.Range( + gte=gte_value, + lte=lte_value, + ), + ) + ) + + return list(should_filters) + + def combine_nodes_for_prompt( + self, + summary_nodes: list[NodeWithScore], + raw_nodes: list[NodeWithScore], + ) -> str: + """ + Combines summary nodes with their corresponding raw nodes based on date matching. + + Parameters + ---------- + summary_nodes : list[NodeWithScore] + list of summary nodes containing metadata with 'date' in "%Y-%m-%d" format + and 'text' field + raw_nodes : list[NodeWithScore] + list of raw nodes containing metadata with self.me as float timestamp + and 'text' field + + Returns + ------- + prompt : str + A formatted prompt combining matched summary and raw texts + """ + # Create a mapping of date to raw nodes for efficient lookup + raw_nodes_by_date: dict[str, list[NodeWithScore]] = {} + + for raw_node in raw_nodes: + timestamp = raw_node.metadata[self.metadata_date_key] + date_str = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d") + + if date_str not in raw_nodes_by_date: + raw_nodes_by_date[date_str] = [] + raw_nodes_by_date[date_str].append(raw_node) + + # Build the combined prompt + combined_sections = [] + + for summary_node in summary_nodes: + date = summary_node.metadata["date"] + summary_text = summary_node.text + + summaries = summary_text.split("\n") + summary_bullets = set(summaries) + summary_bullets.remove("") + + section = f"date: {date}\n\nSummary:\n" + "\n".join(summary_bullets) + "\n" + + if date in raw_nodes_by_date: + raw_texts = [node.text for node in raw_nodes_by_date[date]] + section += "Messages:\n" + "\n".join(raw_texts) + + combined_sections.append(section) + + return "\n\n" + "\n\n".join(combined_sections) From 62e008e28d7f8c286f5d8b529e82e2dc2d750353 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 14 Nov 2024 11:26:55 +0330 Subject: [PATCH 4/9] feat: Added summary prompt engine! and updated other related parts --- utils/query_engine/base_qdrant_engine.py | 54 ++++--------------- .../dual_qdrant_retrieval_engine.py | 21 ++++---- .../query_engine/qdrant_query_engine_utils.py | 31 +++++++---- 3 files changed, 40 insertions(+), 66 deletions(-) diff --git a/utils/query_engine/base_qdrant_engine.py b/utils/query_engine/base_qdrant_engine.py index 2c89896..b3029bb 100644 --- a/utils/query_engine/base_qdrant_engine.py +++ b/utils/query_engine/base_qdrant_engine.py @@ -1,10 +1,7 @@ -from bot.retrievers.utils.load_hyperparams import load_hyperparams -from llama_index.core import VectorStoreIndex, get_response_synthesizer -from llama_index.core.indices.vector_store.retrievers.retriever import ( - VectorIndexRetriever, -) -from llama_index.core.query_engine import RetrieverQueryEngine -from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess +from llama_index.core import Settings +from llama_index.core import get_response_synthesizer + +from .dual_qdrant_retrieval_engine import DualQdrantRetrievalEngine class BaseQdrantEngine: @@ -24,44 +21,13 @@ def __init__(self, platform_name: str, community_id: str) -> None: """ self.platform_name = platform_name self.community_id = community_id - self.collection_name = f"{self.community_id}_{platform_name}" def prepare(self, testing=False): - vector_store_index = self._setup_vector_store_index( - testing=testing, + engine = DualQdrantRetrievalEngine.setup_engine( + llm=Settings.llm, + synthesizer=get_response_synthesizer(), + platform_name=self.platform_name, + community_id=self.community_id, ) - _, similarity_top_k, _ = load_hyperparams() - retriever = VectorIndexRetriever( - index=vector_store_index, - similarity_top_k=similarity_top_k, - ) - query_engine = RetrieverQueryEngine( - retriever=retriever, - response_synthesizer=get_response_synthesizer(), - ) - return query_engine - - def _setup_vector_store_index( - self, - testing: bool = False, - **kwargs, - ) -> VectorStoreIndex: - """ - prepare the vector_store for querying data - - Parameters - ------------ - testing : bool - for testing purposes - **kwargs : - collection_name : str - to override the default collection_name - """ - collection_name = kwargs.get("collection_name", self.collection_name) - qdrant_vector = QDrantVectorAccess( - collection_name=collection_name, - testing=testing, - ) - index = qdrant_vector.load_index() - return index + return engine diff --git a/utils/query_engine/dual_qdrant_retrieval_engine.py b/utils/query_engine/dual_qdrant_retrieval_engine.py index 34cbbc6..88705a1 100644 --- a/utils/query_engine/dual_qdrant_retrieval_engine.py +++ b/utils/query_engine/dual_qdrant_retrieval_engine.py @@ -3,7 +3,6 @@ from llama_index.core import PromptTemplate, VectorStoreIndex from llama_index.core.query_engine import CustomQueryEngine from llama_index.core.retrievers import BaseRetriever -from llama_index.core import get_response_synthesizer from llama_index.core.response_synthesizers import BaseSynthesizer from llama_index.core.indices.vector_store.retrievers.retriever import ( VectorIndexRetriever, @@ -38,9 +37,8 @@ def custom_query(self, query_str: str): if self.summary_retriever is None: nodes = self.retriever.retrieve(query_str) context_str = "\n\n".join([n.node.get_content() for n in nodes]) - response = self.llm.complete( - qa_prompt.format(context_str=context_str, query_str=query_str) - ) + prompt = qa_prompt.format(context_str=context_str, query_str=query_str) + response = self.llm.complete(prompt) else: summary_nodes = self.summary_retriever.retrieve(query_str) utils = QdrantEngineUtils( @@ -52,21 +50,21 @@ def custom_query(self, query_str: str): dates = [ node.metadata[self.metadata_date_summary_key] for node in summary_nodes ] - should_filters = utils.define_raw_data_filters(dates=dates) + filter = utils.define_raw_data_filters(dates=dates) _, raw_data_top_k, _ = load_hyperparams() # retrieve based on summary nodes retriever: BaseRetriever = self._vector_store_index.as_retriever( - {"qdrant_filters": should_filters}, + vector_store_kwargs={"qdrant_filters": filter}, similarity_top_k=raw_data_top_k, ) raw_nodes = retriever.retrieve(query_str) context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes) - response = self.llm.complete( - qa_prompt.format(context_str=context_str, query_str=query_str) - ) + prompt = qa_prompt.format(context_str=context_str, query_str=query_str) + + response = self.llm.complete(prompt) return str(response) @@ -118,7 +116,7 @@ def setup_engine( ) @classmethod - def _prepare_engine_with_summaries( + def setup_engine_with_summaries( cls, llm: OpenAI, synthesizer: BaseSynthesizer, @@ -193,8 +191,9 @@ def _prepare_engine_with_summaries( qa_prompt=qa_prompt, ) + @classmethod def _setup_vector_store_index( - self, + cls, collection_name: str, ) -> VectorStoreIndex: """ diff --git a/utils/query_engine/qdrant_query_engine_utils.py b/utils/query_engine/qdrant_query_engine_utils.py index a072fa3..cfde108 100644 --- a/utils/query_engine/qdrant_query_engine_utils.py +++ b/utils/query_engine/qdrant_query_engine_utils.py @@ -1,3 +1,4 @@ +from collections import defaultdict from datetime import datetime, timedelta from dateutil.parser import parse @@ -17,7 +18,7 @@ def __init__( self.metadata_date_format = metadata_date_format self.date_margin = date_margin - def define_raw_data_filters(self, dates: list[str]) -> list[models.FieldCondition]: + def define_raw_data_filters(self, dates: list[str]) -> models.Filter: """ define the filters to be applied on raw data given the dates @@ -29,10 +30,10 @@ def define_raw_data_filters(self, dates: list[str]) -> list[models.FieldConditio Returns --------- - should_filters : list[models.FieldCondition] + filter : models.Filter the filters to be applied on raw data """ - should_filters: set[models.FieldCondition] = set() + should_filters: list[models.FieldCondition] = [] expanded_dates: set[datetime] = set() # accounting for the date margin @@ -58,7 +59,7 @@ def define_raw_data_filters(self, dates: list[str]) -> list[models.FieldConditio "raw data metadata `date` shouldn't be anything other than FLOAT or INTEGER" ) - should_filters.add( + should_filters.append( models.FieldCondition( key=self.metadata_date_key, range=models.Range( @@ -68,7 +69,9 @@ def define_raw_data_filters(self, dates: list[str]) -> list[models.FieldConditio ) ) - return list(should_filters) + filter = models.Filter(should=should_filters) + + return filter def combine_nodes_for_prompt( self, @@ -103,18 +106,24 @@ def combine_nodes_for_prompt( raw_nodes_by_date[date_str] = [] raw_nodes_by_date[date_str].append(raw_node) - # Build the combined prompt - combined_sections = [] - + # A summary could be separated into multiple nodes + # combining them together + combined_summaries: dict[str, str] = defaultdict(str) for summary_node in summary_nodes: date = summary_node.metadata["date"] summary_text = summary_node.text summaries = summary_text.split("\n") summary_bullets = set(summaries) - summary_bullets.remove("") + if "" in summary_bullets: + summary_bullets.remove("") + combined_summaries[date] += "\n".join(summary_bullets) + + # Build the combined prompt + combined_sections = [] - section = f"date: {date}\n\nSummary:\n" + "\n".join(summary_bullets) + "\n" + for date, summary_bullets in combined_summaries.items(): + section = f"Date: {date}\nSummary:\n" + summary_bullets + "\n\n" if date in raw_nodes_by_date: raw_texts = [node.text for node in raw_nodes_by_date[date]] @@ -122,4 +131,4 @@ def combine_nodes_for_prompt( combined_sections.append(section) - return "\n\n" + "\n\n".join(combined_sections) + return "\n\n".join(combined_sections) From 618ef1b0b4d87048d2bb96d223940a92878b830e Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 14 Nov 2024 11:45:51 +0330 Subject: [PATCH 5/9] feat: Added telegram summary to the pipeline! --- subquery.py | 12 ++++++++- utils/query_engine/__init__.py | 3 ++- .../dual_qdrant_retrieval_engine.py | 6 ----- utils/query_engine/telegram.py | 25 +++++++++++++++++++ 4 files changed, 38 insertions(+), 8 deletions(-) diff --git a/subquery.py b/subquery.py index eeaa24f..d84f49f 100644 --- a/subquery.py +++ b/subquery.py @@ -14,6 +14,7 @@ GitHubQueryEngine, MediaWikiQueryEngine, NotionQueryEngine, + TelegramDualQueryEngine, TelegramQueryEngine, prepare_discord_engine_auto_filter, ) @@ -134,7 +135,16 @@ def query_multiple_source( ) ) if telegram and check_collection("telegram"): - telegram_query_engine = TelegramQueryEngine(community_id=community_id).prepare() + # checking if the summaries was available + if check_collection("telegram_summary"): + telegram_query_engine = TelegramDualQueryEngine( + community_id=community_id + ).prepare() + else: + telegram_query_engine = TelegramQueryEngine( + community_id=community_id + ).prepare() + tool_metadata = ToolMetadata( name="Telegram", description=( diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py index b1ae299..71642ec 100644 --- a/utils/query_engine/__init__.py +++ b/utils/query_engine/__init__.py @@ -1,8 +1,9 @@ # flake8: noqa +from .dual_qdrant_retrieval_engine import DualQdrantRetrievalEngine from .gdrive import GDriveQueryEngine from .github import GitHubQueryEngine from .media_wiki import MediaWikiQueryEngine from .notion import NotionQueryEngine from .prepare_discord_query_engine import prepare_discord_engine_auto_filter from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL -from .telegram import TelegramQueryEngine +from .telegram import TelegramDualQueryEngine, TelegramQueryEngine diff --git a/utils/query_engine/dual_qdrant_retrieval_engine.py b/utils/query_engine/dual_qdrant_retrieval_engine.py index 88705a1..aaf655e 100644 --- a/utils/query_engine/dual_qdrant_retrieval_engine.py +++ b/utils/query_engine/dual_qdrant_retrieval_engine.py @@ -73,7 +73,6 @@ def setup_engine( cls, llm: OpenAI, synthesizer: BaseSynthesizer, - qa_prompt: PromptTemplate, platform_name: str, community_id: str, ): @@ -120,14 +119,12 @@ def setup_engine_with_summaries( cls, llm: OpenAI, synthesizer: BaseSynthesizer, - qa_prompt: PromptTemplate, platform_name: str, community_id: str, metadata_date_key: str, metadata_date_format: DataType, metadata_date_summary_key: str, metadata_date_summary_format: DataType, - summary_metadata_to_use: list[str], ): """ setup the custom query engine on qdrant data @@ -142,8 +139,6 @@ def setup_engine_with_summaries( the llm to be used for RAG pipeline synthesizer : BaseSynthesizer the process of generating response using an LLM - qa_prompt : PromptTemplate - the prompt template to be filled and passed to an LLM platform_name : str specifying the platform data to identify the data collection community_id : str @@ -182,7 +177,6 @@ def setup_engine_with_summaries( cls.metadata_date_summary_format = metadata_date_summary_format cls.metadata_date_key = metadata_date_key cls.metadata_date_format = metadata_date_format - cls.summary_metadata_to_use = summary_metadata_to_use return cls( retriever=retriever, diff --git a/utils/query_engine/telegram.py b/utils/query_engine/telegram.py index 9f0ba48..381eb4d 100644 --- a/utils/query_engine/telegram.py +++ b/utils/query_engine/telegram.py @@ -1,7 +1,32 @@ +from llama_index.core import Settings +from llama_index.core import get_response_synthesizer +from llama_index.core.query_engine import BaseQueryEngine + +from schema.type import DataType from utils.query_engine.base_qdrant_engine import BaseQdrantEngine +from utils.query_engine import DualQdrantRetrievalEngine class TelegramQueryEngine(BaseQdrantEngine): def __init__(self, community_id: str) -> None: platform_name = "telegram" super().__init__(platform_name, community_id) + + +class TelegramDualQueryEngine: + def __init__(self, community_id: str) -> None: + self.platform_name = "telegram" + self.community_id = community_id + + def prepare(self) -> BaseQueryEngine: + engine = DualQdrantRetrievalEngine.setup_engine_with_summaries( + llm=Settings.llm, + synthesizer=get_response_synthesizer(), + platform_name=self.platform_name, + community_id=self.community_id, + metadata_date_key="createdAt", + metadata_date_format=DataType.FLOAT, + metadata_date_summary_key="date", + metadata_date_summary_format=DataType.STRING, + ) + return engine From 42e8a97c6089378f4b440c02344f24400f78def8 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 14 Nov 2024 11:49:50 +0330 Subject: [PATCH 6/9] feat: cleaning codes! --- subquery.py | 9 --------- utils/query_engine/base_qdrant_engine.py | 6 +++--- utils/query_engine/telegram.py | 3 +-- 3 files changed, 4 insertions(+), 14 deletions(-) diff --git a/subquery.py b/subquery.py index d84f49f..1dacf83 100644 --- a/subquery.py +++ b/subquery.py @@ -1,6 +1,5 @@ from guidance.models import OpenAIChat from llama_index.core import QueryBundle, Settings -from llama_index.core.base.base_query_engine import BaseQueryEngine from llama_index.core.query_engine import SubQuestionQueryEngine from llama_index.core.schema import NodeWithScore from llama_index.core.tools import QueryEngineTool, ToolMetadata @@ -72,14 +71,6 @@ def query_multiple_source( tools: list[ToolMetadata] = [] qdrant_utils = QDrantUtils(community_id) - discord_query_engine: BaseQueryEngine - github_query_engine: BaseQueryEngine - # discourse_query_engine: BaseQueryEngine - google_query_engine: BaseQueryEngine - notion_query_engine: BaseQueryEngine - mediawiki_query_engine: BaseQueryEngine - # telegram_query_engine: BaseQueryEngine - # wrapper for more clarity check_collection = qdrant_utils.check_collection_exist diff --git a/utils/query_engine/base_qdrant_engine.py b/utils/query_engine/base_qdrant_engine.py index b3029bb..774817c 100644 --- a/utils/query_engine/base_qdrant_engine.py +++ b/utils/query_engine/base_qdrant_engine.py @@ -1,5 +1,5 @@ -from llama_index.core import Settings -from llama_index.core import get_response_synthesizer +from llama_index.core import Settings, get_response_synthesizer +from llama_index.core.query_engine import BaseQueryEngine from .dual_qdrant_retrieval_engine import DualQdrantRetrievalEngine @@ -22,7 +22,7 @@ def __init__(self, platform_name: str, community_id: str) -> None: self.platform_name = platform_name self.community_id = community_id - def prepare(self, testing=False): + def prepare(self, testing=False) -> BaseQueryEngine: engine = DualQdrantRetrievalEngine.setup_engine( llm=Settings.llm, synthesizer=get_response_synthesizer(), diff --git a/utils/query_engine/telegram.py b/utils/query_engine/telegram.py index 381eb4d..63978a6 100644 --- a/utils/query_engine/telegram.py +++ b/utils/query_engine/telegram.py @@ -1,5 +1,4 @@ -from llama_index.core import Settings -from llama_index.core import get_response_synthesizer +from llama_index.core import Settings, get_response_synthesizer from llama_index.core.query_engine import BaseQueryEngine from schema.type import DataType From 0e8a2f8a119f77cd1cf8b9774ec8fdb2ab815199 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 14 Nov 2024 11:58:25 +0330 Subject: [PATCH 7/9] fix: test case is not required! the base qdrant engine is now using the structure of our dual qdrant engine. in future we'll be removing the BaseQdrantEngine --- tests/unit/test_base_qdrant_engine.py | 25 ------------------------- 1 file changed, 25 deletions(-) delete mode 100644 tests/unit/test_base_qdrant_engine.py diff --git a/tests/unit/test_base_qdrant_engine.py b/tests/unit/test_base_qdrant_engine.py deleted file mode 100644 index b3bca46..0000000 --- a/tests/unit/test_base_qdrant_engine.py +++ /dev/null @@ -1,25 +0,0 @@ -from unittest import TestCase - -from utils.query_engine.base_qdrant_engine import BaseQdrantEngine - - -class TestBaseQdrantEngine(TestCase): - def test_setup_vector_store_index(self): - """ - Tests that _setup_vector_store_index creates a PGVectorAccess object - and calls its load_index method. - """ - platform_table_name = "test_table" - community_id = "123456" - base_engine = BaseQdrantEngine( - platform_name=platform_table_name, - community_id=community_id, - ) - base_engine = base_engine._setup_vector_store_index( - testing=True, - ) - - expected_collection_name = f"{community_id}_{platform_table_name}" - self.assertEqual( - base_engine.vector_store.collection_name, expected_collection_name - ) From f3053457b833b2995a4240910d8edba68b5cff98 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 14 Nov 2024 12:04:31 +0330 Subject: [PATCH 8/9] fix: linter issues! --- utils/query_engine/dual_qdrant_retrieval_engine.py | 12 +++++------- utils/query_engine/qdrant_query_engine_utils.py | 4 ++-- utils/query_engine/telegram.py | 3 +-- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/utils/query_engine/dual_qdrant_retrieval_engine.py b/utils/query_engine/dual_qdrant_retrieval_engine.py index aaf655e..41f96ac 100644 --- a/utils/query_engine/dual_qdrant_retrieval_engine.py +++ b/utils/query_engine/dual_qdrant_retrieval_engine.py @@ -1,17 +1,16 @@ from bot.retrievers.utils.load_hyperparams import load_hyperparams -from llama_index.llms.openai import OpenAI from llama_index.core import PromptTemplate, VectorStoreIndex -from llama_index.core.query_engine import CustomQueryEngine -from llama_index.core.retrievers import BaseRetriever -from llama_index.core.response_synthesizers import BaseSynthesizer from llama_index.core.indices.vector_store.retrievers.retriever import ( VectorIndexRetriever, ) -from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess +from llama_index.core.query_engine import CustomQueryEngine +from llama_index.core.response_synthesizers import BaseSynthesizer +from llama_index.core.retrievers import BaseRetriever +from llama_index.llms.openai import OpenAI from schema.type import DataType +from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess from utils.query_engine.qdrant_query_engine_utils import QdrantEngineUtils - qa_prompt = PromptTemplate( "Context information is below.\n" "---------------------\n" @@ -33,7 +32,6 @@ class DualQdrantRetrievalEngine(CustomQueryEngine): qa_prompt: PromptTemplate def custom_query(self, query_str: str): - if self.summary_retriever is None: nodes = self.retriever.retrieve(query_str) context_str = "\n\n".join([n.node.get_content() for n in nodes]) diff --git a/utils/query_engine/qdrant_query_engine_utils.py b/utils/query_engine/qdrant_query_engine_utils.py index cfde108..b560749 100644 --- a/utils/query_engine/qdrant_query_engine_utils.py +++ b/utils/query_engine/qdrant_query_engine_utils.py @@ -1,10 +1,10 @@ from collections import defaultdict from datetime import datetime, timedelta -from dateutil.parser import parse +from dateutil.parser import parse from llama_index.core.schema import NodeWithScore -from schema.type import DataType from qdrant_client.http import models +from schema.type import DataType class QdrantEngineUtils: diff --git a/utils/query_engine/telegram.py b/utils/query_engine/telegram.py index 63978a6..77ccb48 100644 --- a/utils/query_engine/telegram.py +++ b/utils/query_engine/telegram.py @@ -1,9 +1,8 @@ from llama_index.core import Settings, get_response_synthesizer from llama_index.core.query_engine import BaseQueryEngine - from schema.type import DataType -from utils.query_engine.base_qdrant_engine import BaseQdrantEngine from utils.query_engine import DualQdrantRetrievalEngine +from utils.query_engine.base_qdrant_engine import BaseQdrantEngine class TelegramQueryEngine(BaseQdrantEngine): From d9249220f92bb2dd1d43c8ffbf935f096e54a6b6 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 14 Nov 2024 13:30:02 +0330 Subject: [PATCH 9/9] fix: cleaning codes! codeRabbitAI suggestions. --- .../dual_qdrant_retrieval_engine.py | 69 +++++++++++-------- .../query_engine/qdrant_query_engine_utils.py | 12 ++-- 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/utils/query_engine/dual_qdrant_retrieval_engine.py b/utils/query_engine/dual_qdrant_retrieval_engine.py index 41f96ac..2903287 100644 --- a/utils/query_engine/dual_qdrant_retrieval_engine.py +++ b/utils/query_engine/dual_qdrant_retrieval_engine.py @@ -33,37 +33,9 @@ class DualQdrantRetrievalEngine(CustomQueryEngine): def custom_query(self, query_str: str): if self.summary_retriever is None: - nodes = self.retriever.retrieve(query_str) - context_str = "\n\n".join([n.node.get_content() for n in nodes]) - prompt = qa_prompt.format(context_str=context_str, query_str=query_str) - response = self.llm.complete(prompt) + response = self._process_basic_query(query_str) else: - summary_nodes = self.summary_retriever.retrieve(query_str) - utils = QdrantEngineUtils( - metadata_date_key=self.metadata_date_key, - metadata_date_format=self.metadata_date_format, - date_margin=self._date_margin, - ) - # the filters that will be applied on qdrant - dates = [ - node.metadata[self.metadata_date_summary_key] for node in summary_nodes - ] - filter = utils.define_raw_data_filters(dates=dates) - _, raw_data_top_k, _ = load_hyperparams() - - # retrieve based on summary nodes - retriever: BaseRetriever = self._vector_store_index.as_retriever( - vector_store_kwargs={"qdrant_filters": filter}, - similarity_top_k=raw_data_top_k, - ) - raw_nodes = retriever.retrieve(query_str) - - context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes) - - prompt = qa_prompt.format(context_str=context_str, query_str=query_str) - - response = self.llm.complete(prompt) - + response = self._process_summary_query(query_str) return str(response) @classmethod @@ -199,3 +171,40 @@ def _setup_vector_store_index( qdrant_vector = QDrantVectorAccess(collection_name=collection_name) index = qdrant_vector.load_index() return index + + def _process_basic_query(self, query_str: str) -> str: + nodes = self.retriever.retrieve(query_str) + context_str = "\n\n".join([n.node.get_content() for n in nodes]) + prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str) + response = self.llm.complete(prompt) + return response + + def _process_summary_query(self, query_str: str) -> str: + summary_nodes = self.summary_retriever.retrieve(query_str) + utils = QdrantEngineUtils( + metadata_date_key=self.metadata_date_key, + metadata_date_format=self.metadata_date_format, + date_margin=self._date_margin, + ) + + dates = [ + node.metadata[self.metadata_date_summary_key] + for node in summary_nodes + if self.metadata_date_summary_key in node.metadata + ] + + if not dates: + return self._process_basic_query(query_str) + + filter = utils.define_raw_data_filters(dates=dates) + + retriever: BaseRetriever = self._vector_store_index.as_retriever( + vector_store_kwargs={"qdrant_filters": filter}, + similarity_top_k=self._raw_data_top_k, + ) + raw_nodes = retriever.retrieve(query_str) + + context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes) + prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str) + response = self.llm.complete(prompt) + return response diff --git a/utils/query_engine/qdrant_query_engine_utils.py b/utils/query_engine/qdrant_query_engine_utils.py index b560749..517a35a 100644 --- a/utils/query_engine/qdrant_query_engine_utils.py +++ b/utils/query_engine/qdrant_query_engine_utils.py @@ -1,5 +1,5 @@ from collections import defaultdict -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from dateutil.parser import parse from llama_index.core.schema import NodeWithScore @@ -48,10 +48,10 @@ def define_raw_data_filters(self, dates: list[str]) -> models.Filter: for day_value in expanded_dates: next_day = day_value + timedelta(days=1) - if self.metadata_date_format is DataType.INTEGER: + if self.metadata_date_format == DataType.INTEGER: gte_value = int(day_value.timestamp()) lte_value = int(next_day.timestamp()) - elif self.metadata_date_format is DataType.FLOAT: + elif self.metadata_date_format == DataType.FLOAT: gte_value = day_value.timestamp() lte_value = next_day.timestamp() else: @@ -100,7 +100,11 @@ def combine_nodes_for_prompt( for raw_node in raw_nodes: timestamp = raw_node.metadata[self.metadata_date_key] - date_str = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d") + date_str = ( + datetime.fromtimestamp(timestamp) + .replace(tzinfo=timezone.utc) + .strftime("%Y-%m-%d") + ) if date_str not in raw_nodes_by_date: raw_nodes_by_date[date_str] = []