diff --git a/schema/type.py b/schema/type.py new file mode 100644 index 0000000..a0d1c5f --- /dev/null +++ b/schema/type.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class DataType(Enum): + INTEGER = "INTEGER" + STRING = "STRING" + BOOLEAN = "BOOLEAN" + FLOAT = "FLOAT" diff --git a/subquery.py b/subquery.py index eeaa24f..1dacf83 100644 --- a/subquery.py +++ b/subquery.py @@ -1,6 +1,5 @@ from guidance.models import OpenAIChat from llama_index.core import QueryBundle, Settings -from llama_index.core.base.base_query_engine import BaseQueryEngine from llama_index.core.query_engine import SubQuestionQueryEngine from llama_index.core.schema import NodeWithScore from llama_index.core.tools import QueryEngineTool, ToolMetadata @@ -14,6 +13,7 @@ GitHubQueryEngine, MediaWikiQueryEngine, NotionQueryEngine, + TelegramDualQueryEngine, TelegramQueryEngine, prepare_discord_engine_auto_filter, ) @@ -71,14 +71,6 @@ def query_multiple_source( tools: list[ToolMetadata] = [] qdrant_utils = QDrantUtils(community_id) - discord_query_engine: BaseQueryEngine - github_query_engine: BaseQueryEngine - # discourse_query_engine: BaseQueryEngine - google_query_engine: BaseQueryEngine - notion_query_engine: BaseQueryEngine - mediawiki_query_engine: BaseQueryEngine - # telegram_query_engine: BaseQueryEngine - # wrapper for more clarity check_collection = qdrant_utils.check_collection_exist @@ -134,7 +126,16 @@ def query_multiple_source( ) ) if telegram and check_collection("telegram"): - telegram_query_engine = TelegramQueryEngine(community_id=community_id).prepare() + # checking if the summaries was available + if check_collection("telegram_summary"): + telegram_query_engine = TelegramDualQueryEngine( + community_id=community_id + ).prepare() + else: + telegram_query_engine = TelegramQueryEngine( + community_id=community_id + ).prepare() + tool_metadata = ToolMetadata( name="Telegram", description=( diff --git a/tests/unit/test_base_qdrant_engine.py b/tests/unit/test_base_qdrant_engine.py deleted file mode 100644 index b3bca46..0000000 --- a/tests/unit/test_base_qdrant_engine.py +++ /dev/null @@ -1,25 +0,0 @@ -from unittest import TestCase - -from utils.query_engine.base_qdrant_engine import BaseQdrantEngine - - -class TestBaseQdrantEngine(TestCase): - def test_setup_vector_store_index(self): - """ - Tests that _setup_vector_store_index creates a PGVectorAccess object - and calls its load_index method. - """ - platform_table_name = "test_table" - community_id = "123456" - base_engine = BaseQdrantEngine( - platform_name=platform_table_name, - community_id=community_id, - ) - base_engine = base_engine._setup_vector_store_index( - testing=True, - ) - - expected_collection_name = f"{community_id}_{platform_table_name}" - self.assertEqual( - base_engine.vector_store.collection_name, expected_collection_name - ) diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py index b1ae299..71642ec 100644 --- a/utils/query_engine/__init__.py +++ b/utils/query_engine/__init__.py @@ -1,8 +1,9 @@ # flake8: noqa +from .dual_qdrant_retrieval_engine import DualQdrantRetrievalEngine from .gdrive import GDriveQueryEngine from .github import GitHubQueryEngine from .media_wiki import MediaWikiQueryEngine from .notion import NotionQueryEngine from .prepare_discord_query_engine import prepare_discord_engine_auto_filter from .subquery_gen_prompt import DEFAULT_GUIDANCE_SUB_QUESTION_PROMPT_TMPL -from .telegram import TelegramQueryEngine +from .telegram import TelegramDualQueryEngine, TelegramQueryEngine diff --git a/utils/query_engine/base_qdrant_engine.py b/utils/query_engine/base_qdrant_engine.py index 2c89896..774817c 100644 --- a/utils/query_engine/base_qdrant_engine.py +++ b/utils/query_engine/base_qdrant_engine.py @@ -1,10 +1,7 @@ -from bot.retrievers.utils.load_hyperparams import load_hyperparams -from llama_index.core import VectorStoreIndex, get_response_synthesizer -from llama_index.core.indices.vector_store.retrievers.retriever import ( - VectorIndexRetriever, -) -from llama_index.core.query_engine import RetrieverQueryEngine -from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess +from llama_index.core import Settings, get_response_synthesizer +from llama_index.core.query_engine import BaseQueryEngine + +from .dual_qdrant_retrieval_engine import DualQdrantRetrievalEngine class BaseQdrantEngine: @@ -24,44 +21,13 @@ def __init__(self, platform_name: str, community_id: str) -> None: """ self.platform_name = platform_name self.community_id = community_id - self.collection_name = f"{self.community_id}_{platform_name}" - def prepare(self, testing=False): - vector_store_index = self._setup_vector_store_index( - testing=testing, + def prepare(self, testing=False) -> BaseQueryEngine: + engine = DualQdrantRetrievalEngine.setup_engine( + llm=Settings.llm, + synthesizer=get_response_synthesizer(), + platform_name=self.platform_name, + community_id=self.community_id, ) - _, similarity_top_k, _ = load_hyperparams() - retriever = VectorIndexRetriever( - index=vector_store_index, - similarity_top_k=similarity_top_k, - ) - query_engine = RetrieverQueryEngine( - retriever=retriever, - response_synthesizer=get_response_synthesizer(), - ) - return query_engine - - def _setup_vector_store_index( - self, - testing: bool = False, - **kwargs, - ) -> VectorStoreIndex: - """ - prepare the vector_store for querying data - - Parameters - ------------ - testing : bool - for testing purposes - **kwargs : - collection_name : str - to override the default collection_name - """ - collection_name = kwargs.get("collection_name", self.collection_name) - qdrant_vector = QDrantVectorAccess( - collection_name=collection_name, - testing=testing, - ) - index = qdrant_vector.load_index() - return index + return engine diff --git a/utils/query_engine/dual_qdrant_retrieval_engine.py b/utils/query_engine/dual_qdrant_retrieval_engine.py new file mode 100644 index 0000000..2903287 --- /dev/null +++ b/utils/query_engine/dual_qdrant_retrieval_engine.py @@ -0,0 +1,210 @@ +from bot.retrievers.utils.load_hyperparams import load_hyperparams +from llama_index.core import PromptTemplate, VectorStoreIndex +from llama_index.core.indices.vector_store.retrievers.retriever import ( + VectorIndexRetriever, +) +from llama_index.core.query_engine import CustomQueryEngine +from llama_index.core.response_synthesizers import BaseSynthesizer +from llama_index.core.retrievers import BaseRetriever +from llama_index.llms.openai import OpenAI +from schema.type import DataType +from tc_hivemind_backend.qdrant_vector_access import QDrantVectorAccess +from utils.query_engine.qdrant_query_engine_utils import QdrantEngineUtils + +qa_prompt = PromptTemplate( + "Context information is below.\n" + "---------------------\n" + "{context_str}\n" + "---------------------\n" + "Given the context information and not prior knowledge, " + "answer the query.\n" + "Query: {query_str}\n" + "Answer: " +) + + +class DualQdrantRetrievalEngine(CustomQueryEngine): + """RAG String Query Engine.""" + + retriever: BaseRetriever + response_synthesizer: BaseSynthesizer + llm: OpenAI + qa_prompt: PromptTemplate + + def custom_query(self, query_str: str): + if self.summary_retriever is None: + response = self._process_basic_query(query_str) + else: + response = self._process_summary_query(query_str) + return str(response) + + @classmethod + def setup_engine( + cls, + llm: OpenAI, + synthesizer: BaseSynthesizer, + platform_name: str, + community_id: str, + ): + """ + setup the custom query engine on qdrant data + + Parameters + ------------ + llm : OpenAI + the llm to be used for RAG pipeline + synthesizer : BaseSynthesizer + the process of generating response using an LLM + qa_prompt : PromptTemplate + the prompt template to be filled and passed to an LLM + platform_name : str + specifying the platform data to identify the data collection + community_id : str + specifying community_id to identify the data collection + """ + collection_name = f"{community_id}_{platform_name}" + + _, raw_data_top_k, date_margin = load_hyperparams() + cls._date_margin = date_margin + + cls._vector_store_index: VectorStoreIndex = cls._setup_vector_store_index( + collection_name=collection_name + ) + retriever = VectorIndexRetriever( + index=cls._vector_store_index, + similarity_top_k=raw_data_top_k, + ) + + cls.summary_retriever = None + + return cls( + retriever=retriever, + response_synthesizer=synthesizer, + llm=llm, + qa_prompt=qa_prompt, + ) + + @classmethod + def setup_engine_with_summaries( + cls, + llm: OpenAI, + synthesizer: BaseSynthesizer, + platform_name: str, + community_id: str, + metadata_date_key: str, + metadata_date_format: DataType, + metadata_date_summary_key: str, + metadata_date_summary_format: DataType, + ): + """ + setup the custom query engine on qdrant data + + Parameters + ------------ + use_summary : bool + whether to use the summary data or not + note: the summary data should be available before + for this option to be enabled + llm : OpenAI + the llm to be used for RAG pipeline + synthesizer : BaseSynthesizer + the process of generating response using an LLM + platform_name : str + specifying the platform data to identify the data collection + community_id : str + specifying community_id to identify the data collection + metadata_date_summary_key : str | None + the date key name in summary documents' metadata + In case of `use_summary` equal to be true this shuold be passed + metadata_date_summary_format : DataType | None + the date format in metadata + In case of `use_summary` equal to be true this shuold be passed + NOTE: this should be always a string for the filtering of it to work. + """ + collection_name = f"{community_id}_{platform_name}" + summary_data_top_k, raw_data_top_k, date_margin = load_hyperparams() + cls._date_margin = date_margin + cls._raw_data_top_k = raw_data_top_k + + cls._vector_store_index: VectorStoreIndex = cls._setup_vector_store_index( + collection_name=collection_name + ) + retriever = VectorIndexRetriever( + index=cls._vector_store_index, + similarity_top_k=raw_data_top_k, + ) + + summary_collection_name = collection_name + "_summary" + summary_vector_store_index = cls._setup_vector_store_index( + collection_name=summary_collection_name + ) + + cls.summary_retriever = VectorIndexRetriever( + index=summary_vector_store_index, + similarity_top_k=summary_data_top_k, + ) + cls.metadata_date_summary_key = metadata_date_summary_key + cls.metadata_date_summary_format = metadata_date_summary_format + cls.metadata_date_key = metadata_date_key + cls.metadata_date_format = metadata_date_format + + return cls( + retriever=retriever, + response_synthesizer=synthesizer, + llm=llm, + qa_prompt=qa_prompt, + ) + + @classmethod + def _setup_vector_store_index( + cls, + collection_name: str, + ) -> VectorStoreIndex: + """ + prepare the vector_store for querying data + + Parameters + ------------ + collection_name : str + to override the default collection_name + """ + qdrant_vector = QDrantVectorAccess(collection_name=collection_name) + index = qdrant_vector.load_index() + return index + + def _process_basic_query(self, query_str: str) -> str: + nodes = self.retriever.retrieve(query_str) + context_str = "\n\n".join([n.node.get_content() for n in nodes]) + prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str) + response = self.llm.complete(prompt) + return response + + def _process_summary_query(self, query_str: str) -> str: + summary_nodes = self.summary_retriever.retrieve(query_str) + utils = QdrantEngineUtils( + metadata_date_key=self.metadata_date_key, + metadata_date_format=self.metadata_date_format, + date_margin=self._date_margin, + ) + + dates = [ + node.metadata[self.metadata_date_summary_key] + for node in summary_nodes + if self.metadata_date_summary_key in node.metadata + ] + + if not dates: + return self._process_basic_query(query_str) + + filter = utils.define_raw_data_filters(dates=dates) + + retriever: BaseRetriever = self._vector_store_index.as_retriever( + vector_store_kwargs={"qdrant_filters": filter}, + similarity_top_k=self._raw_data_top_k, + ) + raw_nodes = retriever.retrieve(query_str) + + context_str = utils.combine_nodes_for_prompt(summary_nodes, raw_nodes) + prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str) + response = self.llm.complete(prompt) + return response diff --git a/utils/query_engine/qdrant_query_engine_utils.py b/utils/query_engine/qdrant_query_engine_utils.py new file mode 100644 index 0000000..517a35a --- /dev/null +++ b/utils/query_engine/qdrant_query_engine_utils.py @@ -0,0 +1,138 @@ +from collections import defaultdict +from datetime import datetime, timedelta, timezone + +from dateutil.parser import parse +from llama_index.core.schema import NodeWithScore +from qdrant_client.http import models +from schema.type import DataType + + +class QdrantEngineUtils: + def __init__( + self, + metadata_date_key: str, + metadata_date_format: DataType, + date_margin: int, + ) -> None: + self.metadata_date_key = metadata_date_key + self.metadata_date_format = metadata_date_format + self.date_margin = date_margin + + def define_raw_data_filters(self, dates: list[str]) -> models.Filter: + """ + define the filters to be applied on raw data given the dates + + Parameters + ----------- + dates : list[str] + a list of dates that should be a string. i.e. with format of `%Y-%m-%d` + the date should be representing a day + + Returns + --------- + filter : models.Filter + the filters to be applied on raw data + """ + should_filters: list[models.FieldCondition] = [] + expanded_dates: set[datetime] = set() + + # accounting for the date margin + for date in dates: + day_value = parse(date) + expanded_dates.add(day_value) + + for i in range(1, self.date_margin + 1): + expanded_dates.add(day_value - timedelta(days=i)) + expanded_dates.add(day_value + timedelta(days=i)) + + for day_value in expanded_dates: + next_day = day_value + timedelta(days=1) + + if self.metadata_date_format == DataType.INTEGER: + gte_value = int(day_value.timestamp()) + lte_value = int(next_day.timestamp()) + elif self.metadata_date_format == DataType.FLOAT: + gte_value = day_value.timestamp() + lte_value = next_day.timestamp() + else: + raise ValueError( + "raw data metadata `date` shouldn't be anything other than FLOAT or INTEGER" + ) + + should_filters.append( + models.FieldCondition( + key=self.metadata_date_key, + range=models.Range( + gte=gte_value, + lte=lte_value, + ), + ) + ) + + filter = models.Filter(should=should_filters) + + return filter + + def combine_nodes_for_prompt( + self, + summary_nodes: list[NodeWithScore], + raw_nodes: list[NodeWithScore], + ) -> str: + """ + Combines summary nodes with their corresponding raw nodes based on date matching. + + Parameters + ---------- + summary_nodes : list[NodeWithScore] + list of summary nodes containing metadata with 'date' in "%Y-%m-%d" format + and 'text' field + raw_nodes : list[NodeWithScore] + list of raw nodes containing metadata with self.me as float timestamp + and 'text' field + + Returns + ------- + prompt : str + A formatted prompt combining matched summary and raw texts + """ + # Create a mapping of date to raw nodes for efficient lookup + raw_nodes_by_date: dict[str, list[NodeWithScore]] = {} + + for raw_node in raw_nodes: + timestamp = raw_node.metadata[self.metadata_date_key] + date_str = ( + datetime.fromtimestamp(timestamp) + .replace(tzinfo=timezone.utc) + .strftime("%Y-%m-%d") + ) + + if date_str not in raw_nodes_by_date: + raw_nodes_by_date[date_str] = [] + raw_nodes_by_date[date_str].append(raw_node) + + # A summary could be separated into multiple nodes + # combining them together + combined_summaries: dict[str, str] = defaultdict(str) + for summary_node in summary_nodes: + date = summary_node.metadata["date"] + summary_text = summary_node.text + + summaries = summary_text.split("\n") + summary_bullets = set(summaries) + if "" in summary_bullets: + summary_bullets.remove("") + combined_summaries[date] += "\n".join(summary_bullets) + + # Build the combined prompt + combined_sections = [] + + for date, summary_bullets in combined_summaries.items(): + section = f"Date: {date}\nSummary:\n" + summary_bullets + "\n\n" + + if date in raw_nodes_by_date: + raw_texts = [node.text for node in raw_nodes_by_date[date]] + section += "Messages:\n" + "\n".join(raw_texts) + + combined_sections.append(section) + + return "\n\n".join(combined_sections) diff --git a/utils/query_engine/telegram.py b/utils/query_engine/telegram.py index 9f0ba48..77ccb48 100644 --- a/utils/query_engine/telegram.py +++ b/utils/query_engine/telegram.py @@ -1,3 +1,7 @@ +from llama_index.core import Settings, get_response_synthesizer +from llama_index.core.query_engine import BaseQueryEngine +from schema.type import DataType +from utils.query_engine import DualQdrantRetrievalEngine from utils.query_engine.base_qdrant_engine import BaseQdrantEngine @@ -5,3 +9,22 @@ class TelegramQueryEngine(BaseQdrantEngine): def __init__(self, community_id: str) -> None: platform_name = "telegram" super().__init__(platform_name, community_id) + + +class TelegramDualQueryEngine: + def __init__(self, community_id: str) -> None: + self.platform_name = "telegram" + self.community_id = community_id + + def prepare(self) -> BaseQueryEngine: + engine = DualQdrantRetrievalEngine.setup_engine_with_summaries( + llm=Settings.llm, + synthesizer=get_response_synthesizer(), + platform_name=self.platform_name, + community_id=self.community_id, + metadata_date_key="createdAt", + metadata_date_format=DataType.FLOAT, + metadata_date_summary_key="date", + metadata_date_summary_format=DataType.STRING, + ) + return engine