From d38e3431ee918b682950d7e9811c438b4f3d68a0 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 1 Feb 2024 18:23:20 +0330 Subject: [PATCH 01/13] WIP: updating the prompt! The prompts are now having a new structure. --- bot/retrievers/forum_summary_retriever.py | 32 +++++ .../level_based_platform_query_engine.py | 136 ++++++++++++++++-- .../prepare_discord_query_engine.py | 1 + 3 files changed, 157 insertions(+), 12 deletions(-) diff --git a/bot/retrievers/forum_summary_retriever.py b/bot/retrievers/forum_summary_retriever.py index 6dd56d1..e02edc0 100644 --- a/bot/retrievers/forum_summary_retriever.py +++ b/bot/retrievers/forum_summary_retriever.py @@ -1,6 +1,7 @@ from bot.retrievers.summary_retriever_base import BaseSummarySearch from llama_index.embeddings import BaseEmbedding from tc_hivemind_backend.embeddings.cohere import CohereEmbedding +from llama_index.schema import NodeWithScore class ForumBasedSummaryRetriever(BaseSummarySearch): @@ -53,6 +54,37 @@ def retreive_filtering( """ nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k) + filters = self.define_filters( + nodes=nodes, + metadata_group1_key=metadata_group1_key, + metadata_group2_key=metadata_group2_key, + metadata_date_key=metadata_date_key, + ) + + return filters + + def define_filters( + self, + nodes: list[NodeWithScore], + metadata_group1_key: str, + metadata_group2_key: str, + metadata_date_key: str, + ) -> list[dict[str, str]]: + """ + define dictionary filters based on metadata of retrieved nodes + + Parameters + ---------- + nodes : list[dict[llama_index.schema.NodeWithScore]] + a list of retrieved similar nodes to define filters based + + Returns + --------- + filters : list[dict[str, str]] + a list of filters to apply with `or` condition + the dictionary would be applying `and` + operation between keys and values of json metadata_ + """ filters: list[dict[str, str]] = [] for node in nodes: diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index 6d27c01..a0d1433 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -1,3 +1,4 @@ +from dateutil import parser import logging from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever @@ -40,10 +41,10 @@ def custom_query(self, query_str: str): ) similar_nodes = retriever.query_db(query=query_str, filters=self._filters) - context_str = self._prepare_context_str(similar_nodes) + context_str = self._prepare_context_str(similar_nodes, self.summary_nodes) fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str) response = self.llm.complete(fmt_qa_prompt) - # logging.info(f"fmt_qa_prompt {fmt_qa_prompt}") + logging.info(f"fmt_qa_prompt {fmt_qa_prompt}") return str(response) @classmethod @@ -77,7 +78,7 @@ def prepare_platform_engine( **kwargs : llm : llama-index.LLM the LLM to use answering queries - default is gpt-3.5-turbo + default is gpt-4 synthesizer : llama_index.response_synthesizers.base.BaseSynthesizer the synthesizers to use when creating the prompt default is to get from `get_response_synthesizer(response_mode="compact")` @@ -95,7 +96,7 @@ def prepare_platform_engine( synthesizer = kwargs.get( "synthesizer", get_response_synthesizer(response_mode="compact") ) - llm = kwargs.get("llm", OpenAI("gpt-3.5-turbo")) + llm = kwargs.get("llm", OpenAI("gpt-4")) qa_prompt_ = kwargs.get("qa_prompt", qa_prompt) pg_vector = PGVectorAccess( @@ -128,6 +129,7 @@ def prepare_engine_auto_filter( level1_key: str, level2_key: str, date_key: str = "date", + include_summary_context: bool = False, ) -> "LevelBasedPlatformQueryEngine": """ get the query engine and do the filtering automatically. @@ -171,13 +173,23 @@ def prepare_engine_auto_filter( table_name=platform_table_name + "_summary", dbname=dbname ) - filters = platform_retriever.retreive_filtering( - query=query, + nodes = platform_retriever.get_similar_nodes(query, summary_similarity_top_k) + + filters = platform_retriever.define_filters( + nodes, metadata_group1_key=level1_key, metadata_group2_key=level2_key, metadata_date_key=date_key, - similarity_top_k=summary_similarity_top_k, ) + # saving to add summaries to the context of prompt + if include_summary_context: + cls.summary_nodes = nodes + else: + cls.summary_nodes = [] + + cls._level1_key = level1_key + cls._level2_key = level2_key + cls._date_key = date_key # getting all the metadata dates from filters dates: list[str] = [f[date_key] for f in filters] @@ -194,13 +206,113 @@ def prepare_engine_auto_filter( ) return engine - def _prepare_context_str(self, nodes: list[NodeWithScore]) -> str: - context_str = "\n\n".join( + def _prepare_context_str( + self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore] + ) -> str: + """ + prepare the prompt context using the raw_nodes for answers and summary_nodes for additional information + """ + context_str: str = "" + + if summary_nodes == []: + logging.warning( + "Empty context_nodes. Cannot add summaries as context information!" + ) + + context_str += self._prepare_prompt_with_metadata_info(nodes=raw_nodes) + else: + grouped_raw_nodes = self._group_nodes_per_metadata(raw_nodes) + for summary_node in summary_nodes: + # can be thread_title for discord + level1_title = summary_node.metadata[self._level1_key] + # can be channel_title for discord + level2_title = summary_node.metadata[self._level2_key] + date = summary_node.metadata[self._date_key] + + # intiialization + node_context: str = "" + + nested_dict = grouped_raw_nodes.get(level1_title, {}).get( + level2_title, {} + ) + + if date in nested_dict: + raw_nodes = grouped_raw_nodes[level1_title][level2_title][date] + node_context: str = ( + f"{self._level1_key}: {level1_title}\n" + f"{self._level2_key}: {level2_title}\n" + f"{self._date_key}: {date}\n" + f"summary: {summary_node.text}\n" + "messages:\n" + ) + node_context += self._prepare_prompt_with_metadata_info( + raw_nodes, prefix=" " + ) + + context_str += node_context + + logging.info(f"||||||||context_str|||||||| {context_str} |||||||") + return context_str + + def _group_nodes_per_metadata( + self, raw_nodes: list[NodeWithScore] + ) -> dict[str, dict[str, dict[str, list[NodeWithScore]]]]: + """ + group all nodes based on their level1 and level2 metadata + + Parameters + ----------- + raw_nodes : list[NodeWithScore] + a list of raw nodes + + Returns + --------- + grouped_nodes : dict[str, dict[str, dict[str, list[NodeWithScore]]]] + a list of nodes grouped by + - `level1_key` + - `level2_key` + - and the last dict key `date_key` + + The values of the nested dictionary are the nodes grouped + """ + grouped_nodes: dict[str, dict[str, dict[str, list[NodeWithScore]]]] = {} + for node in raw_nodes: + level1_title = node.metadata[self._level1_key] + # TODO: remove the _name when the data got updated + level2_title = node.metadata[self._level2_key + "_name"] + date_str = node.metadata[self._date_key] + date = parser.parse(date_str).strftime("%Y-%m-%d") + + # defining an empty list (if keys weren't previously made) + grouped_nodes.setdefault(level1_title, {}).setdefault( + level2_title, {} + ).setdefault(date, []) + # Adding to list + grouped_nodes[level1_title][level2_title][date].append(node) + + return grouped_nodes + + def _prepare_prompt_with_metadata_info( + self, nodes: list[NodeWithScore], prefix: str = "" + ) -> str: + """ + prepare a prompt with given nodes including the nodes' metadata + Note: the prefix is set before each text! + """ + context_str = "\n".join( [ - node.get_content() + "author: " + + node.metadata["author_username"] + + "\n" + + prefix + + "message_date: " + + node.metadata["date"] + "\n" - + node.node.get_metadata_str(mode=MetadataMode.LLM) - for node in nodes + + prefix + + f"message {idx + 1}: " + + node.get_content() + for idx, node in enumerate(nodes) ] ) + return context_str diff --git a/utils/query_engine/prepare_discord_query_engine.py b/utils/query_engine/prepare_discord_query_engine.py index ac185a9..553dde3 100644 --- a/utils/query_engine/prepare_discord_query_engine.py +++ b/utils/query_engine/prepare_discord_query_engine.py @@ -72,5 +72,6 @@ def prepare_discord_engine_auto_filter( level1_key="channel", level2_key="thread", date_key="date", + include_summary_context=True, ) return engine From 08fca4036abd68b3d6a9a34469b8a3af80ed7795 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Wed, 7 Feb 2024 16:16:44 +0330 Subject: [PATCH 02/13] feat: updated metadata filtering + new prompts! + The metadata filtering is now done for +d and -d days of given channel-thread filters. + the prompts are now completely changed by adding summaries as context of the messages. --- bot/retrievers/forum_summary_retriever.py | 6 +- bot/retrievers/process_dates.py | 4 +- bot/retrievers/retrieve_similar_nodes.py | 39 +++++++-- .../level_based_platform_query_engine.py | 79 +++++++++++++++---- 4 files changed, 102 insertions(+), 26 deletions(-) diff --git a/bot/retrievers/forum_summary_retriever.py b/bot/retrievers/forum_summary_retriever.py index e02edc0..31468fc 100644 --- a/bot/retrievers/forum_summary_retriever.py +++ b/bot/retrievers/forum_summary_retriever.py @@ -90,10 +90,8 @@ def define_filters( for node in nodes: # the filter made by given node filter: dict[str, str] = {} - if node.metadata[metadata_group1_key]: - filter[metadata_group1_key] = node.metadata[metadata_group1_key] - if node.metadata[metadata_group2_key]: - filter[metadata_group2_key] = node.metadata[metadata_group2_key] + filter[metadata_group1_key] = node.metadata[metadata_group1_key] + filter[metadata_group2_key] = node.metadata[metadata_group2_key] # date filter filter[metadata_date_key] = node.metadata[metadata_date_key] diff --git a/bot/retrievers/process_dates.py b/bot/retrievers/process_dates.py index dba3217..cb5ecec 100644 --- a/bot/retrievers/process_dates.py +++ b/bot/retrievers/process_dates.py @@ -19,7 +19,9 @@ def process_dates(dates: list[str], d: int) -> list[str]: Returns ---------- dates_modified : list[str] - days added to it + days added to it sorted ascending meaning + the first index is the lowest date + and the last is the biggest date """ dates_modified: list[str] = [] if dates != []: diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index baac30e..51e89a1 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -1,3 +1,6 @@ +from datetime import timedelta +from dateutil import parser + from llama_index.embeddings import BaseEmbedding from llama_index.schema import NodeWithScore from llama_index.vector_stores import PGVectorStore, VectorStoreQueryResult @@ -21,7 +24,10 @@ def __init__( self._similarity_top_k = similarity_top_k def query_db( - self, query: str, filters: list[dict[str, str]] | None = None + self, + query: str, + filters: list[dict[str, str]] | None = None, + date_interval: int = 0, ) -> list[NodeWithScore]: """ query database with given filters (similarity search is also done) @@ -35,6 +41,9 @@ def query_db( the dictionary would be applying `and` operation between keys and values of json metadata_ if `None` then no filtering would be applied + date_interval : int + the number of back and forth days of date + default is set to 0 meaning no days back or forward. """ self._vector_store._initialize() embedding = self._embed_model.get_text_embedding(text=query) @@ -55,21 +64,41 @@ def query_db( for key, value in condition.items(): if key == "date": # Apply ::date cast when the key is 'date' - filter_condition = cast( + date = parser.parse(value) + date_back = (date - timedelta(days=date_interval)).strftime( + "%Y-%m-%d" + ) + date_forward = (date + timedelta(days=date_interval)).strftime( + "%Y-%m-%d" + ) + + filter_condition_back = cast( self._vector_store._table_class.metadata_.op("->>")(key), Date, - ) == cast(value, Date) + ) >= cast(date_back, Date) + + filter_condition_forward = cast( + self._vector_store._table_class.metadata_.op("->>")(key), + Date, + ) <= cast(date_forward, Date) + + filters_and.append(filter_condition_back) + filters_and.append(filter_condition_forward) else: filter_condition = ( self._vector_store._table_class.metadata_.op("->>")(key) == value + if value is not None + else self._vector_store._table_class.metadata_.op("->>")( + key + ).is_(None) ) - - filters_and.append(filter_condition) + filters_and.append(filter_condition) conditions.append(and_(*filters_and)) stmt = stmt.where(or_(*conditions)) + print("filters", filters) stmt = stmt.limit(self._similarity_top_k) diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index a0d1433..fbd93c1 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -39,12 +39,15 @@ def custom_query(self, query_str: str): self._vector_store, self._similarity_top_k, ) - similar_nodes = retriever.query_db(query=query_str, filters=self._filters) + logging.info(f"self._filters {self._filters}") + similar_nodes = retriever.query_db( + query=query_str, filters=self._filters, date_interval=self._d + ) context_str = self._prepare_context_str(similar_nodes, self.summary_nodes) fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str) response = self.llm.complete(fmt_qa_prompt) - logging.info(f"fmt_qa_prompt {fmt_qa_prompt}") + # logging.info(f"fmt_qa_prompt {fmt_qa_prompt}") return str(response) @classmethod @@ -107,7 +110,8 @@ def prepare_platform_engine( ) index = pg_vector.load_index() retriever = index.as_retriever() - _, similarity_top_k, _ = load_hyperparams() + _, similarity_top_k, d = load_hyperparams() + cls._d = d cls._vector_store = index.vector_store cls._similarity_top_k = similarity_top_k @@ -190,12 +194,13 @@ def prepare_engine_auto_filter( cls._level1_key = level1_key cls._level2_key = level2_key cls._date_key = date_key + cls._d = d - # getting all the metadata dates from filters - dates: list[str] = [f[date_key] for f in filters] - dates_modified = process_dates(list(dates), d) - dates_filter = [{date_key: date} for date in dates_modified] - filters.extend(dates_filter) + # # getting all the metadata dates from filters + # dates: list[str] = [f[date_key] for f in filters] + # dates_modified = process_dates(list(dates), d) + # dates_filter = [{date_key: date} for date in dates_modified] + # filters.extend(dates_filter) logging.info(f"COMMUNITY_ID: {community_id} | summary filters: {filters}") @@ -212,6 +217,8 @@ def _prepare_context_str( """ prepare the prompt context using the raw_nodes for answers and summary_nodes for additional information """ + logging.info(f"START len(raw_nodes) {len(raw_nodes)}") + logging.info(f"START len(summary_nodes) {len(summary_nodes)}") context_str: str = "" if summary_nodes == []: @@ -222,12 +229,28 @@ def _prepare_context_str( context_str += self._prepare_prompt_with_metadata_info(nodes=raw_nodes) else: grouped_raw_nodes = self._group_nodes_per_metadata(raw_nodes) + + raw_data_logs: list[str] = [] + for level1_title in grouped_raw_nodes: + for level2_title in grouped_raw_nodes[level1_title]: + for date in grouped_raw_nodes[level1_title][level2_title]: + raw_data_logs.append( + f"GROUPED RAW DATA {self._level1_key}: {level1_title}, {self._level2_key}: {level2_title}, {self._date_key}: {date}" + f", Message count {len(grouped_raw_nodes[level1_title][level2_title][date])}" + ) + + logging.info(f"raw_data_logs {raw_data_logs}") + + summary_log_data: list[str] = [] for summary_node in summary_nodes: # can be thread_title for discord level1_title = summary_node.metadata[self._level1_key] # can be channel_title for discord level2_title = summary_node.metadata[self._level2_key] date = summary_node.metadata[self._date_key] + summary_log_data.append( + f"SUMMARY DATA {self._level1_key}: {level1_title}, {self._level2_key}: {level2_title}, {self._date_key}: {date}" + ) # intiialization node_context: str = "" @@ -236,21 +259,45 @@ def _prepare_context_str( level2_title, {} ) - if date in nested_dict: - raw_nodes = grouped_raw_nodes[level1_title][level2_title][date] + dates_modified = process_dates([date], self._d) + # if date in nested_dict: + + # if they had any intersect + if set(nested_dict.keys()) & set(dates_modified): + raw_nodes = [] + for date in dates_modified: + nodes = grouped_raw_nodes[level1_title][level2_title].get( + date, [] + ) + raw_nodes.extend(nodes) + + logging.info( + f"len(raw_nodes) {len(raw_nodes)} for " + f"{self._level1_key}: {level1_title}, " + f"{self._level2_key}: {level2_title}, " + f"{self._date_key} range: {dates_modified[0]} - {dates_modified[-1]}" + ) node_context: str = ( f"{self._level1_key}: {level1_title}\n" f"{self._level2_key}: {level2_title}\n" - f"{self._date_key}: {date}\n" + f"{self._date_key} range: {dates_modified[0]} - {dates_modified[-1]}\n" f"summary: {summary_node.text}\n" "messages:\n" ) node_context += self._prepare_prompt_with_metadata_info( raw_nodes, prefix=" " ) + if node_context == "": + logging.warning( + "Error empty node_context for " + f"{self._level1_key}: {level1_title}, " + f"{self._level2_key}: {level2_title}, " + f"{self._date_key}: {date}" + ) + else: + context_str += node_context + "\n" - context_str += node_context - + # logging.info(f"summary_log_data {summary_log_data}") logging.info(f"||||||||context_str|||||||| {context_str} |||||||") return context_str @@ -278,8 +325,7 @@ def _group_nodes_per_metadata( grouped_nodes: dict[str, dict[str, dict[str, list[NodeWithScore]]]] = {} for node in raw_nodes: level1_title = node.metadata[self._level1_key] - # TODO: remove the _name when the data got updated - level2_title = node.metadata[self._level2_key + "_name"] + level2_title = node.metadata[self._level2_key] date_str = node.metadata[self._date_key] date = parser.parse(date_str).strftime("%Y-%m-%d") @@ -301,7 +347,8 @@ def _prepare_prompt_with_metadata_info( """ context_str = "\n".join( [ - "author: " + prefix + + "author: " + node.metadata["author_username"] + "\n" + prefix From 94fa8408dea4f9a04e40450fc6ec7080a65d129d Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Wed, 7 Feb 2024 16:37:13 +0330 Subject: [PATCH 03/13] update: code and logs cleaning! --- bot/retrievers/retrieve_similar_nodes.py | 1 - .../level_based_platform_query_engine.py | 54 ++++++------------- 2 files changed, 16 insertions(+), 39 deletions(-) diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index 51e89a1..7918fe0 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -98,7 +98,6 @@ def query_db( conditions.append(and_(*filters_and)) stmt = stmt.where(or_(*conditions)) - print("filters", filters) stmt = stmt.limit(self._similarity_top_k) diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index fbd93c1..0cbd201 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -10,7 +10,7 @@ from llama_index.query_engine import CustomQueryEngine from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer from llama_index.retrievers import BaseRetriever -from llama_index.schema import MetadataMode, NodeWithScore +from llama_index.schema import NodeWithScore from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from tc_hivemind_backend.pg_vector_access import PGVectorAccess @@ -39,7 +39,7 @@ def custom_query(self, query_str: str): self._vector_store, self._similarity_top_k, ) - logging.info(f"self._filters {self._filters}") + logging.debug(f"retrieval database filters {self._filters}") similar_nodes = retriever.query_db( query=query_str, filters=self._filters, date_interval=self._d ) @@ -47,7 +47,7 @@ def custom_query(self, query_str: str): context_str = self._prepare_context_str(similar_nodes, self.summary_nodes) fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str) response = self.llm.complete(fmt_qa_prompt) - # logging.info(f"fmt_qa_prompt {fmt_qa_prompt}") + return str(response) @classmethod @@ -169,7 +169,6 @@ def prepare_engine_auto_filter( the created query engine with the filters """ dbname = f"community_{community_id}" - summary_similarity_top_k, _, d = load_hyperparams() # For summaries data a posfix `summary` would be added @@ -196,12 +195,6 @@ def prepare_engine_auto_filter( cls._date_key = date_key cls._d = d - # # getting all the metadata dates from filters - # dates: list[str] = [f[date_key] for f in filters] - # dates_modified = process_dates(list(dates), d) - # dates_filter = [{date_key: date} for date in dates_modified] - # filters.extend(dates_filter) - logging.info(f"COMMUNITY_ID: {community_id} | summary filters: {filters}") engine = LevelBasedPlatformQueryEngine.prepare_platform_engine( @@ -217,8 +210,6 @@ def _prepare_context_str( """ prepare the prompt context using the raw_nodes for answers and summary_nodes for additional information """ - logging.info(f"START len(raw_nodes) {len(raw_nodes)}") - logging.info(f"START len(summary_nodes) {len(summary_nodes)}") context_str: str = "" if summary_nodes == []: @@ -230,27 +221,12 @@ def _prepare_context_str( else: grouped_raw_nodes = self._group_nodes_per_metadata(raw_nodes) - raw_data_logs: list[str] = [] - for level1_title in grouped_raw_nodes: - for level2_title in grouped_raw_nodes[level1_title]: - for date in grouped_raw_nodes[level1_title][level2_title]: - raw_data_logs.append( - f"GROUPED RAW DATA {self._level1_key}: {level1_title}, {self._level2_key}: {level2_title}, {self._date_key}: {date}" - f", Message count {len(grouped_raw_nodes[level1_title][level2_title][date])}" - ) - - logging.info(f"raw_data_logs {raw_data_logs}") - - summary_log_data: list[str] = [] for summary_node in summary_nodes: # can be thread_title for discord level1_title = summary_node.metadata[self._level1_key] # can be channel_title for discord level2_title = summary_node.metadata[self._level2_key] date = summary_node.metadata[self._date_key] - summary_log_data.append( - f"SUMMARY DATA {self._level1_key}: {level1_title}, {self._level2_key}: {level2_title}, {self._date_key}: {date}" - ) # intiialization node_context: str = "" @@ -271,12 +247,6 @@ def _prepare_context_str( ) raw_nodes.extend(nodes) - logging.info( - f"len(raw_nodes) {len(raw_nodes)} for " - f"{self._level1_key}: {level1_title}, " - f"{self._level2_key}: {level2_title}, " - f"{self._date_key} range: {dates_modified[0]} - {dates_modified[-1]}" - ) node_context: str = ( f"{self._level1_key}: {level1_title}\n" f"{self._level2_key}: {level2_title}\n" @@ -288,17 +258,25 @@ def _prepare_context_str( raw_nodes, prefix=" " ) if node_context == "": - logging.warning( - "Error empty node_context for " + logging.debug( + "No messages fetched for " f"{self._level1_key}: {level1_title}, " f"{self._level2_key}: {level2_title}, " f"{self._date_key}: {date}" + " of summaries data" + ) + if node_context != "": + logging.debug( + f"{len(raw_nodes)} messages fetched for " + f"{self._level1_key}: {level1_title}, " + f"{self._level2_key}: {level2_title}, " + f"{self._date_key}: {date}" + " of summaries data" ) - else: context_str += node_context + "\n" - # logging.info(f"summary_log_data {summary_log_data}") - logging.info(f"||||||||context_str|||||||| {context_str} |||||||") + logging.debug(f"context_str of prompt\n" f"{context_str}") + return context_str def _group_nodes_per_metadata( From 9cf224f66abb076e4b0e0729a8bace11d6cf1375 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 8 Feb 2024 17:07:32 +0330 Subject: [PATCH 04/13] feat: Creating the prompt and getting the missing summaries! + Now we added the feature to get the missing summaries. - test cases should be added and we need to update the retrieval of similar nodes in first place to use just the thread summaries. --- bot/retrievers/retrieve_similar_nodes.py | 35 ++- .../level_based_platform_query_engine.py | 209 +++++++----------- utils/query_engine/utils.py | 164 ++++++++++++++ 3 files changed, 275 insertions(+), 133 deletions(-) create mode 100644 utils/query_engine/utils.py diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index 7918fe0..5488d03 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -5,7 +5,7 @@ from llama_index.schema import NodeWithScore from llama_index.vector_stores import PGVectorStore, VectorStoreQueryResult from llama_index.vector_stores.postgres import DBEmbeddingRow -from sqlalchemy import Date, and_, cast, or_, select, text +from sqlalchemy import Date, and_, null, cast, or_, select, text from tc_hivemind_backend.embeddings.cohere import CohereEmbedding @@ -15,7 +15,7 @@ class RetrieveSimilarNodes: def __init__( self, vector_store: PGVectorStore, - similarity_top_k: int, + similarity_top_k: int | None, embed_model: BaseEmbedding = CohereEmbedding(), ) -> None: """Init params.""" @@ -28,6 +28,7 @@ def query_db( query: str, filters: list[dict[str, str]] | None = None, date_interval: int = 0, + **kwargs ) -> list[NodeWithScore]: """ query database with given filters (similarity search is also done) @@ -44,18 +45,35 @@ def query_db( date_interval : int the number of back and forth days of date default is set to 0 meaning no days back or forward. + **kwargs + ignore_sort : bool + to ignore sort by vector similarity. + Note: This would completely disable the similarity search and + it would just return the results with no ordering. + default is `False`. If `True` the query will be ignored and no embedding of it would be fetched """ + ignore_sort = kwargs.get("ignore_sort", False) self._vector_store._initialize() - embedding = self._embed_model.get_text_embedding(text=query) + + if not ignore_sort: + embedding = self._embed_model.get_text_embedding(text=query) + else: + embedding = None + stmt = select( # type: ignore self._vector_store._table_class.id, self._vector_store._table_class.node_id, self._vector_store._table_class.text, self._vector_store._table_class.metadata_, - self._vector_store._table_class.embedding.cosine_distance(embedding).label( - "distance" - ), - ).order_by(text("distance asc")) + ( + self._vector_store._table_class.embedding.cosine_distance(embedding) + if not ignore_sort + else null() + ).label("distance"), + ) + + if not ignore_sort: + stmt = stmt.order_by(text("distance asc")) if filters is not None and filters != []: conditions = [] @@ -99,7 +117,8 @@ def query_db( stmt = stmt.where(or_(*conditions)) - stmt = stmt.limit(self._similarity_top_k) + if self._similarity_top_k is not None: + stmt = stmt.limit(self._similarity_top_k) with self._vector_store._session() as session, session.begin(): res = session.execute(stmt) diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index 0cbd201..5a06489 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -1,12 +1,14 @@ -from dateutil import parser import logging from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever -from bot.retrievers.process_dates import process_dates +from utils.query_engine.utils import ( + LevelBasedPlatformUtils, +) from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes from bot.retrievers.utils.load_hyperparams import load_hyperparams from llama_index.llms import OpenAI from llama_index.prompts import PromptTemplate +from llama_index import VectorStoreIndex from llama_index.query_engine import CustomQueryEngine from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer from llama_index.retrievers import BaseRetriever @@ -36,10 +38,10 @@ def custom_query(self, query_str: str): """Doing custom query""" # first retrieving similar nodes in summary retriever = RetrieveSimilarNodes( - self._vector_store, + self._raw_vector_store, self._similarity_top_k, ) - logging.debug(f"retrieval database filters {self._filters}") + similar_nodes = retriever.query_db( query=query_str, filters=self._filters, date_interval=self._d ) @@ -102,18 +104,15 @@ def prepare_platform_engine( llm = kwargs.get("llm", OpenAI("gpt-4")) qa_prompt_ = kwargs.get("qa_prompt", qa_prompt) - pg_vector = PGVectorAccess( - table_name=platform_table_name, - dbname=dbname, - testing=testing, - embed_model=CohereEmbedding(), - ) - index = pg_vector.load_index() + index = cls._setup_vector_store_index(platform_table_name, dbname, testing) retriever = index.as_retriever() _, similarity_top_k, d = load_hyperparams() cls._d = d - cls._vector_store = index.vector_store + cls._raw_vector_store = index._vector_store + cls._summary_vector_store = cls._setup_vector_store_index( + platform_table_name + "_summary", dbname, testing + )._vector_store cls._similarity_top_k = similarity_top_k cls._filters = filters @@ -190,12 +189,14 @@ def prepare_engine_auto_filter( else: cls.summary_nodes = [] + cls._utils_class = LevelBasedPlatformUtils(level1_key, level2_key, date_key) cls._level1_key = level1_key cls._level2_key = level2_key cls._date_key = date_key cls._d = d + cls._platform_table_name = platform_table_name - logging.info(f"COMMUNITY_ID: {community_id} | summary filters: {filters}") + logging.debug(f"COMMUNITY_ID: {community_id} | summary filters: {filters}") engine = LevelBasedPlatformQueryEngine.prepare_platform_engine( community_id=community_id, @@ -217,127 +218,85 @@ def _prepare_context_str( "Empty context_nodes. Cannot add summaries as context information!" ) - context_str += self._prepare_prompt_with_metadata_info(nodes=raw_nodes) + context_str += self._utils_class.prepare_prompt_with_metadata_info( + nodes=raw_nodes + ) else: - grouped_raw_nodes = self._group_nodes_per_metadata(raw_nodes) - - for summary_node in summary_nodes: - # can be thread_title for discord - level1_title = summary_node.metadata[self._level1_key] - # can be channel_title for discord - level2_title = summary_node.metadata[self._level2_key] - date = summary_node.metadata[self._date_key] - - # intiialization - node_context: str = "" + # grouping the data we have so we could + # get them per each metadata without search + ( + grouped_raw_nodes, + grouped_summary_nodes, + ) = self._group_summary_and_raw_nodes(raw_nodes, summary_nodes) + + # first using the available summary nodes try to create prompt + context_data, ( + summary_nodes_to_fetch_filters, + raw_nodes_missed, + ) = self._utils_class.prepare_context_str_based_on_summaries( + grouped_raw_nodes, grouped_summary_nodes + ) + context_str += context_data - nested_dict = grouped_raw_nodes.get(level1_title, {}).get( - level2_title, {} + logging.info( + f"summary_nodes_to_fetch_filters {summary_nodes_to_fetch_filters}" + ) + # then if there was some missing summaries + if len(summary_nodes_to_fetch_filters): + retriever = RetrieveSimilarNodes( + self._summary_vector_store, + similarity_top_k=None, ) - - dates_modified = process_dates([date], self._d) - # if date in nested_dict: - - # if they had any intersect - if set(nested_dict.keys()) & set(dates_modified): - raw_nodes = [] - for date in dates_modified: - nodes = grouped_raw_nodes[level1_title][level2_title].get( - date, [] - ) - raw_nodes.extend(nodes) - - node_context: str = ( - f"{self._level1_key}: {level1_title}\n" - f"{self._level2_key}: {level2_title}\n" - f"{self._date_key} range: {dates_modified[0]} - {dates_modified[-1]}\n" - f"summary: {summary_node.text}\n" - "messages:\n" - ) - node_context += self._prepare_prompt_with_metadata_info( - raw_nodes, prefix=" " - ) - if node_context == "": - logging.debug( - "No messages fetched for " - f"{self._level1_key}: {level1_title}, " - f"{self._level2_key}: {level2_title}, " - f"{self._date_key}: {date}" - " of summaries data" - ) - if node_context != "": - logging.debug( - f"{len(raw_nodes)} messages fetched for " - f"{self._level1_key}: {level1_title}, " - f"{self._level2_key}: {level2_title}, " - f"{self._date_key}: {date}" - " of summaries data" - ) - context_str += node_context + "\n" + fetched_summary_nodes = retriever.query_db( + query="", + filters=summary_nodes_to_fetch_filters, + ignore_sort=True, + ) + logging.info(f"len(fetched_summary_nodes) {len(fetched_summary_nodes)}") + logging.info(f"fetched_summary_nodes {fetched_summary_nodes}") + grouped_summary_nodes = self._utils_class.group_nodes_per_metadata( + fetched_summary_nodes + ) + logging.info(f"grouped_summary_nodes {grouped_summary_nodes}") + logging.info(f"len(grouped_summary_nodes) {len(grouped_summary_nodes)}") + context_data, ( + summary_nodes_to_fetch_filters, + _, + ) = self._utils_class.prepare_context_str_based_on_summaries( + raw_nodes_missed, grouped_summary_nodes + ) + context_str += context_data logging.debug(f"context_str of prompt\n" f"{context_str}") return context_str - def _group_nodes_per_metadata( - self, raw_nodes: list[NodeWithScore] - ) -> dict[str, dict[str, dict[str, list[NodeWithScore]]]]: - """ - group all nodes based on their level1 and level2 metadata - - Parameters - ----------- - raw_nodes : list[NodeWithScore] - a list of raw nodes - - Returns - --------- - grouped_nodes : dict[str, dict[str, dict[str, list[NodeWithScore]]]] - a list of nodes grouped by - - `level1_key` - - `level2_key` - - and the last dict key `date_key` - - The values of the nested dictionary are the nodes grouped - """ - grouped_nodes: dict[str, dict[str, dict[str, list[NodeWithScore]]]] = {} - for node in raw_nodes: - level1_title = node.metadata[self._level1_key] - level2_title = node.metadata[self._level2_key] - date_str = node.metadata[self._date_key] - date = parser.parse(date_str).strftime("%Y-%m-%d") - - # defining an empty list (if keys weren't previously made) - grouped_nodes.setdefault(level1_title, {}).setdefault( - level2_title, {} - ).setdefault(date, []) - # Adding to list - grouped_nodes[level1_title][level2_title][date].append(node) - - return grouped_nodes - - def _prepare_prompt_with_metadata_info( - self, nodes: list[NodeWithScore], prefix: str = "" - ) -> str: + @classmethod + def _setup_vector_store_index( + cls, platform_table_name: str, dbname: str, testing: str + ) -> VectorStoreIndex: """ - prepare a prompt with given nodes including the nodes' metadata - Note: the prefix is set before each text! + prepare the vector_store for querying data """ - context_str = "\n".join( - [ - prefix - + "author: " - + node.metadata["author_username"] - + "\n" - + prefix - + "message_date: " - + node.metadata["date"] - + "\n" - + prefix - + f"message {idx + 1}: " - + node.get_content() - for idx, node in enumerate(nodes) - ] + pg_vector = PGVectorAccess( + table_name=platform_table_name, + dbname=dbname, + testing=testing, + embed_model=CohereEmbedding(), ) + index = pg_vector.load_index() + return index - return context_str + def _group_summary_and_raw_nodes( + self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore] + ) -> tuple[ + dict[str, dict[str, dict[str, list[NodeWithScore]]]], + dict[str, dict[str, dict[str, list[NodeWithScore]]]], + ]: + """a wrapper to do the grouping of given nodes""" + grouped_raw_nodes = self._utils_class.group_nodes_per_metadata(raw_nodes) + grouped_summary_nodes = self._utils_class.group_nodes_per_metadata( + summary_nodes + ) + + return grouped_raw_nodes, grouped_summary_nodes diff --git a/utils/query_engine/utils.py b/utils/query_engine/utils.py new file mode 100644 index 0000000..86ec1c1 --- /dev/null +++ b/utils/query_engine/utils.py @@ -0,0 +1,164 @@ +import logging + +from dateutil import parser +from llama_index.schema import NodeWithScore + + +class LevelBasedPlatformUtils: + def __init__(self, level1_key: str, level2_key: str, date_key: str) -> None: + self.level1_key = level1_key + self.level2_key = level2_key + self.date_key = date_key + + def prepare_prompt_with_metadata_info( + self, nodes: list[NodeWithScore], prefix: str = "" + ) -> str: + """ + prepare a prompt with given nodes including the nodes' metadata + Note: the prefix is set before each text! + """ + context_str = "\n".join( + [ + prefix + + "author: " + + node.metadata["author_username"] + + "\n" + + prefix + + "message_date: " + + node.metadata["date"] + + "\n" + + prefix + + f"message {idx + 1}: " + + node.get_content() + for idx, node in enumerate(nodes) + ] + ) + + return context_str + + def group_nodes_per_metadata( + self, + nodes: list[NodeWithScore], + ) -> dict[str, dict[str, dict[str, list[NodeWithScore]]]]: + """ + group all nodes based on their level1 and level2 metadata + + Parameters + ----------- + nodes : list[NodeWithScore] + a list of raw nodes + + Returns + --------- + grouped_nodes : dict[str, dict[str, dict[str, list[NodeWithScore]]]] + a list of nodes grouped by + - `level1_key` + - `level2_key` + - and the last dict key `date_key` + + The values of the nested dictionary are the nodes grouped + """ + grouped_nodes: dict[str, dict[str, dict[str, list[NodeWithScore]]]] = {} + for node in nodes: + level1_title = node.metadata[self.level1_key] + level2_title = node.metadata[self.level2_key] + date_str = node.metadata[self.date_key] + date = parser.parse(date_str).strftime("%Y-%m-%d") + + # defining an empty list (if keys weren't previously made) + grouped_nodes.setdefault(level1_title, {}).setdefault( + level2_title, {} + ).setdefault(date, []) + # Adding to list + grouped_nodes[level1_title][level2_title][date].append(node) + + return grouped_nodes + + def prepare_context_str_based_on_summaries( + self, + grouped_raw_nodes: dict[str, dict[str, dict[str, list[NodeWithScore]]]], + grouped_summary_nodes: dict[str, dict[str, dict[str, list[NodeWithScore]]]], + ) -> tuple[ + str, + tuple[ + list[dict[str, str | None]], + dict[str, dict[str, dict[str, list[NodeWithScore]]]], + ], + ]: + """ + prepare prompt context having the summaries within it + """ + context_str: str = "" + + summary_nodes_to_fetch_filters: list[dict[str, str | None]] = [] + # in case of summary wasn't available for them + raw_nodes_missed: dict[str, dict[str, dict[str, list[NodeWithScore]]]] = {} + + for level1_title in grouped_raw_nodes: + for level2_title in grouped_raw_nodes[level1_title]: + for date in grouped_raw_nodes[level1_title][level2_title]: + raw_nodes = grouped_raw_nodes[level1_title][level2_title][date] + + # the summary_nodes should be always 0 or 1 node + summary_nodes = ( + grouped_summary_nodes.get(level1_title, {}) + .get(level2_title, {}) + .get(date, []) + ) + if len(summary_nodes) == 1: + # if len(summary_nodes) == 1 or len(summary_nodes) == 2: + logging.debug( + f"{len(raw_nodes)} messages available for " + f"{self.level1_key}: {level1_title}, " + f"{self.level2_key}: {level2_title}, " + f"{self.date_key}: {date}" + ) + summary_node = summary_nodes[0] + + node_context: str = ( + f"{self.level1_key}: {level1_title}\n" + f"{self.level2_key}: {level2_title}\n" + f"{self.date_key}: {date}\n" + f"summary: {summary_node.text}\n" + "messages:\n" + ) + node_context += self.prepare_prompt_with_metadata_info( + raw_nodes, prefix=" " + ) + + context_str += node_context + "\n" + elif len(summary_nodes) == 0: + logging.info( + "No summary messages available for " + f"{self.level1_key}: {level1_title}, " + f"{self.level2_key}: {level2_title}, " + f"{self.date_key}: {date}" + ) + summary_nodes_to_fetch_filters.append( + { + self.level1_key: level1_title, + self.level2_key: level2_title, + self.date_key: date, + } + ) + raw_nodes_missed.setdefault(level1_title, {}).setdefault( + level2_title, {} + ).setdefault(date, []) + raw_nodes_missed[level1_title][level2_title][date].extend( + raw_nodes + ) + else: + logging.info(f"len(summary_nodes) {len(summary_nodes)}") + # UU_nodes = [] + # UU_nodes_text = "" + # for node in summary_nodes: + # UU_nodes.append(node.metadata) + # UU_nodes_text += node.text + "\n\n" + # logging.info(f"summary_nodes UU_nodes: {UU_nodes} || UU_nodes_text {UU_nodes_text}") + raise ValueError( + "Not possible to have multiple summaries for a" + f" combination of " + f"{self.level1_key}-{self.level2_key}-{self.date_key}" + ) + + return context_str, (summary_nodes_to_fetch_filters, raw_nodes_missed) From 042d92979818b9980fa383001b26c534577932ae Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 12 Feb 2024 09:06:13 +0330 Subject: [PATCH 05/13] feat: Added manual initial summary node retrieval! --- bot/retrievers/retrieve_similar_nodes.py | 13 +++-- .../level_based_platform_query_engine.py | 51 +++++++++++++------ utils/query_engine/utils.py | 7 --- 3 files changed, 44 insertions(+), 27 deletions(-) diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index 5488d03..6397baa 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -26,7 +26,7 @@ def __init__( def query_db( self, query: str, - filters: list[dict[str, str]] | None = None, + filters: list[dict[str, str | dict | None]] | None = None, date_interval: int = 0, **kwargs ) -> list[NodeWithScore]: @@ -37,11 +37,13 @@ def query_db( ------------- query : str the user question - filters : list[dict[str, str]] | None + filters : list[dict[str, str | dict | None]] | None a list of filters to apply with `or` condition the dictionary would be applying `and` operation between keys and values of json metadata_ - if `None` then no filtering would be applied + the value can be a dictionary with one key of "ne" and a value + which means to do a not equal operator `!=` + if `None` then no filtering would be applied. date_interval : int the number of back and forth days of date default is set to 0 meaning no days back or forward. @@ -106,10 +108,11 @@ def query_db( filter_condition = ( self._vector_store._table_class.metadata_.op("->>")(key) == value - if value is not None + if not isinstance(value, dict) else self._vector_store._table_class.metadata_.op("->>")( key - ).is_(None) + ) + != value["ne"] ) filters_and.append(filter_condition) diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index 5a06489..4b9178a 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -90,6 +90,12 @@ def prepare_platform_engine( qa_prompt : llama-index.prompts.PromptTemplate the Q&A prompt to use default would be the default prompt of llama-index + index_raw : VectorStoreIndex + the vector store index for raw data + If not passed, it would just create one itself + index_summary : VectorStoreIndex + the vector store index for summary data + If not passed, it would just create one itself Returns --------- @@ -103,16 +109,23 @@ def prepare_platform_engine( ) llm = kwargs.get("llm", OpenAI("gpt-4")) qa_prompt_ = kwargs.get("qa_prompt", qa_prompt) - - index = cls._setup_vector_store_index(platform_table_name, dbname, testing) + index: VectorStoreIndex = kwargs.get( + "index_raw", + cls._setup_vector_store_index(platform_table_name, dbname, testing), + ) retriever = index.as_retriever() + cls._summary_vector_store = kwargs.get( + "index_summary", + cls._setup_vector_store_index( + platform_table_name + "_summary", dbname, testing + ), + )._vector_store + _, similarity_top_k, d = load_hyperparams() cls._d = d cls._raw_vector_store = index._vector_store - cls._summary_vector_store = cls._setup_vector_store_index( - platform_table_name + "_summary", dbname, testing - )._vector_store + cls._similarity_top_k = similarity_top_k cls._filters = filters @@ -170,19 +183,30 @@ def prepare_engine_auto_filter( dbname = f"community_{community_id}" summary_similarity_top_k, _, d = load_hyperparams() + index_summary = cls._setup_vector_store_index( + platform_table_name + "_summary", dbname, False + ) + vector_store = index_summary._vector_store + + retriever = RetrieveSimilarNodes( + vector_store, + summary_similarity_top_k, + ) + # getting nodes of just thread summaries + nodes = retriever.query_db(query, [{"thread": None}, {"thread": {"ne": None}}]) + # For summaries data a posfix `summary` would be added platform_retriever = ForumBasedSummaryRetriever( table_name=platform_table_name + "_summary", dbname=dbname ) - nodes = platform_retriever.get_similar_nodes(query, summary_similarity_top_k) - filters = platform_retriever.define_filters( nodes, metadata_group1_key=level1_key, metadata_group2_key=level2_key, metadata_date_key=date_key, ) + # saving to add summaries to the context of prompt if include_summary_context: cls.summary_nodes = nodes @@ -202,6 +226,7 @@ def prepare_engine_auto_filter( community_id=community_id, platform_table_name=platform_table_name, filters=filters, + index_summary=index_summary, ) return engine @@ -238,9 +263,6 @@ def _prepare_context_str( ) context_str += context_data - logging.info( - f"summary_nodes_to_fetch_filters {summary_nodes_to_fetch_filters}" - ) # then if there was some missing summaries if len(summary_nodes_to_fetch_filters): retriever = RetrieveSimilarNodes( @@ -252,13 +274,9 @@ def _prepare_context_str( filters=summary_nodes_to_fetch_filters, ignore_sort=True, ) - logging.info(f"len(fetched_summary_nodes) {len(fetched_summary_nodes)}") - logging.info(f"fetched_summary_nodes {fetched_summary_nodes}") grouped_summary_nodes = self._utils_class.group_nodes_per_metadata( fetched_summary_nodes ) - logging.info(f"grouped_summary_nodes {grouped_summary_nodes}") - logging.info(f"len(grouped_summary_nodes) {len(grouped_summary_nodes)}") context_data, ( summary_nodes_to_fetch_filters, _, @@ -273,7 +291,10 @@ def _prepare_context_str( @classmethod def _setup_vector_store_index( - cls, platform_table_name: str, dbname: str, testing: str + cls, + platform_table_name: str, + dbname: str, + testing: bool = False, ) -> VectorStoreIndex: """ prepare the vector_store for querying data diff --git a/utils/query_engine/utils.py b/utils/query_engine/utils.py index 86ec1c1..166bec7 100644 --- a/utils/query_engine/utils.py +++ b/utils/query_engine/utils.py @@ -106,7 +106,6 @@ def prepare_context_str_based_on_summaries( .get(date, []) ) if len(summary_nodes) == 1: - # if len(summary_nodes) == 1 or len(summary_nodes) == 2: logging.debug( f"{len(raw_nodes)} messages available for " f"{self.level1_key}: {level1_title}, " @@ -149,12 +148,6 @@ def prepare_context_str_based_on_summaries( ) else: logging.info(f"len(summary_nodes) {len(summary_nodes)}") - # UU_nodes = [] - # UU_nodes_text = "" - # for node in summary_nodes: - # UU_nodes.append(node.metadata) - # UU_nodes_text += node.text + "\n\n" - # logging.info(f"summary_nodes UU_nodes: {UU_nodes} || UU_nodes_text {UU_nodes_text}") raise ValueError( "Not possible to have multiple summaries for a" f" combination of " From f1cf283085926223741a68eabfb658da4e6c84dd Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 12 Feb 2024 09:18:58 +0330 Subject: [PATCH 06/13] rename for more clarity! --- utils/query_engine/{utils.py => level_based_platforms_util.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename utils/query_engine/{utils.py => level_based_platforms_util.py} (100%) diff --git a/utils/query_engine/utils.py b/utils/query_engine/level_based_platforms_util.py similarity index 100% rename from utils/query_engine/utils.py rename to utils/query_engine/level_based_platforms_util.py From f8bced627be625cd83780e94ae48d0849826a7ab Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 12 Feb 2024 09:19:51 +0330 Subject: [PATCH 07/13] fix: import due to renaming! --- utils/query_engine/level_based_platform_query_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index 4b9178a..562ab2a 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -1,7 +1,7 @@ import logging from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever -from utils.query_engine.utils import ( +from utils.query_engine.level_based_platforms_util import ( LevelBasedPlatformUtils, ) from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes From 092158b9cc676269259627e8f0c3905e88e1545d Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 12 Feb 2024 10:48:19 +0330 Subject: [PATCH 08/13] update: fixing test cases and adding more! --- .gitignore | 2 + .../test_level_based_platform_query_engine.py | 34 +-- tests/unit/test_level_based_platform_util.py | 211 ++++++++++++++++++ .../level_based_platforms_util.py | 1 + 4 files changed, 232 insertions(+), 16 deletions(-) create mode 100644 tests/unit/test_level_based_platform_util.py diff --git a/.gitignore b/.gitignore index 0c9e5b4..32ea46e 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,5 @@ cython_debug/ hivemind-bot-env/* main.ipynb .DS_Store + +temp_test_run_data.json \ No newline at end of file diff --git a/tests/unit/test_level_based_platform_query_engine.py b/tests/unit/test_level_based_platform_query_engine.py index eaa9986..a279b17 100644 --- a/tests/unit/test_level_based_platform_query_engine.py +++ b/tests/unit/test_level_based_platform_query_engine.py @@ -6,6 +6,7 @@ from utils.query_engine.level_based_platform_query_engine import ( LevelBasedPlatformQueryEngine, ) +from sqlalchemy.exc import OperationalError class TestLevelBasedPlatformQueryEngine(unittest.TestCase): @@ -26,9 +27,9 @@ def test_prepare_platform_engine(self): """ # the output should always have a `date` key for each dictionary filters = [ - {"channel": "general", "date": "2023-01-02"}, - {"thread": "discussion", "date": "2024-01-03"}, - {"date": "2022-01-01"}, + {"channel": "general", "thread": "some_thread", "date": "2023-01-02"}, + {"channel": "general", "thread": "discussion", "date": "2024-01-03"}, + {"channel": "general#2", "thread": "Agenda", "date": "2022-01-01"}, ] engine = LevelBasedPlatformQueryEngine.prepare_platform_engine( @@ -44,21 +45,22 @@ def test_prepare_engine_auto_filter(self): Test prepare_engine_auto_filter method with sample data """ with patch.object( - ForumBasedSummaryRetriever, "retreive_filtering" + ForumBasedSummaryRetriever, "define_filters" ) as mock_retriever: # the output should always have a `date` key for each dictionary mock_retriever.return_value = [ - {"channel": "general", "date": "2023-01-02"}, - {"thread": "discussion", "date": "2024-01-03"}, - {"date": "2022-01-01"}, + {"channel": "general", "thread": "some_thread", "date": "2023-01-02"}, + {"channel": "general", "thread": "discussion", "date": "2024-01-03"}, + {"channel": "general#2", "thread": "Agenda", "date": "2022-01-01"}, ] - engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( - community_id=self.community_id, - query="test query", - platform_table_name=self.platform_table_name, - level1_key=self.level1_key, - level2_key=self.level2_key, - date_key=self.date_key, - ) - self.assertIsNotNone(engine) + with self.assertRaises(OperationalError): + # no database with name of `test_community` is available + _ = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( + community_id=self.community_id, + query="test query", + platform_table_name=self.platform_table_name, + level1_key=self.level1_key, + level2_key=self.level2_key, + date_key=self.date_key, + ) diff --git a/tests/unit/test_level_based_platform_util.py b/tests/unit/test_level_based_platform_util.py new file mode 100644 index 0000000..b0df7c0 --- /dev/null +++ b/tests/unit/test_level_based_platform_util.py @@ -0,0 +1,211 @@ +import unittest +from llama_index.schema import NodeWithScore, TextNode +from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils + + +class TestLevelBasedPlatformUtils(unittest.TestCase): + def setUp(self): + self.level1_key = "channel" + self.level2_key = "thread" + self.date_key = "date" + self.utils = LevelBasedPlatformUtils( + self.level1_key, self.level2_key, self.date_key + ) + + def test_prepare_prompt_with_metadata_info(self): + nodes = [ + NodeWithScore( + node=TextNode( + text="content1", + metadata={"author_username": "user1", "date": "2022-01-01"}, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content2", + metadata={"author_username": "user2", "date": "2022-01-02"}, + ), + score=0, + ), + ] + prefix = " " + expected_output = ( + " author: user1\n message_date: 2022-01-01\n message 1: content1\n" + " author: user2\n message_date: 2022-01-02\n message 2: content2" + ) + result = self.utils.prepare_prompt_with_metadata_info(nodes, prefix) + self.assertEqual(result, expected_output) + + def test_group_nodes_per_metadata(self): + nodes = [ + NodeWithScore( + node=TextNode( + text="content1", + metadata={"channel": "A", "thread": "X", "date": "2022-01-01"}, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content2", + metadata={"channel": "A", "thread": "Y", "date": "2022-01-01"}, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content3", + metadata={"channel": "B", "thread": "X", "date": "2022-01-02"}, + ), + score=0, + ), + ] + expected_output = { + "A": {"X": {"2022-01-01": [nodes[0]]}, "Y": {"2022-01-01": [nodes[1]]}}, + "B": {"X": {"2022-01-02": [nodes[2]]}}, + } + result = self.utils.group_nodes_per_metadata(nodes) + self.assertEqual(result, expected_output) + + def test_prepare_context_str_based_on_summaries(self): + raw_nodes = [ + NodeWithScore( + node=TextNode( + text="raw_content1", + metadata={ + "channel": "A", + "thread": "X", + "date": "2022-01-01", + "author_username": "USERNAME#1", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="raw_content2", + metadata={ + "channel": "A", + "thread": "Y", + "date": "2022-01-04", + "author_username": "USERNAME#2", + }, + ), + score=0, + ), + ] + summary_nodes = [ + NodeWithScore( + node=TextNode( + text="summary_content", + metadata={"channel": "A", "thread": "X", "date": "2022-01-01"}, + ), + score=0, + ) + ] + grouped_raw_nodes = {"A": {"X": {"2022-01-01": raw_nodes}}} + grouped_summary_nodes = {"A": {"X": {"2022-01-01": summary_nodes}}} + expected_output = ( + """channel: A\nthread: X\ndate: 2022-01-01\nsummary: summary_content\nmessages:\n""" + """ author: USERNAME#1\n message_date: 2022-01-01\n message 1: raw_content1\n""" + """ author: USERNAME#2\n message_date: 2022-01-04\n message 2: raw_content2\n""" + ) + result, _ = self.utils.prepare_context_str_based_on_summaries( + grouped_raw_nodes, grouped_summary_nodes + ) + self.assertEqual(result.strip(), expected_output.strip()) + + def test_prepare_context_str_based_on_summaries_no_summary(self): + node1 = NodeWithScore( + node=TextNode( + text="raw_content1", + metadata={ + "channel": "A", + "thread": "X", + "date": "2022-01-01", + "author_username": "USERNAME#1", + }, + ), + score=0, + ) + node2 = NodeWithScore( + node=TextNode( + text="raw_content2", + metadata={ + "channel": "A", + "thread": "Y", + "date": "2022-01-04", + "author_username": "USERNAME#2", + }, + ), + score=0, + ) + grouped_raw_nodes = { + "A": {"X": {"2022-01-01": [node1]}, "Y": {"2022-01-04": [node2]}} + } + grouped_summary_nodes = {} + result, ( + summary_nodes_to_fetch_filters, + raw_nodes_missed, + ) = self.utils.prepare_context_str_based_on_summaries( + grouped_raw_nodes, grouped_summary_nodes + ) + self.assertEqual(result, "") + self.assertEqual(len(summary_nodes_to_fetch_filters), 2) + for channel in raw_nodes_missed.keys(): + self.assertIn(channel, ["A"]) + for thread in raw_nodes_missed[channel].keys(): + self.assertIn(thread, ["X", "Y"]) + for date in raw_nodes_missed[channel][thread]: + self.assertIn(date, ["2022-01-01", "2022-01-04"]) + nodes = raw_nodes_missed[channel][thread][date] + + if date == "2022-01-01": + self.assertEqual( + grouped_raw_nodes["A"]["X"]["2022-01-01"], nodes + ) + elif date == "2022-01-04": + self.assertEqual( + grouped_raw_nodes["A"]["Y"]["2022-01-04"], nodes + ) + + def test_prepare_context_str_based_on_summaries_multiple_summaries_error(self): + raw_nodes = [ + NodeWithScore( + node=TextNode( + text="raw_content1", + metadata={"channel": "A", "thread": "X", "date": "2022-01-01"}, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="raw_content2", + metadata={"channel": "A", "thread": "Y", "date": "2022-01-01"}, + ), + score=0, + ), + ] + summary_nodes = [ + NodeWithScore( + node=TextNode( + text="summary_content1", + metadata={"channel": "A", "thread": "X", "date": "2022-01-01"}, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="summary_content2", + metadata={"channel": "A", "thread": "X", "date": "2022-01-01"}, + ), + score=0, + ), + ] + grouped_raw_nodes = {"A": {"X": {"2022-01-01": raw_nodes}}} + grouped_summary_nodes = {"A": {"X": {"2022-01-01": summary_nodes}}} + with self.assertRaises(ValueError): + self.utils.prepare_context_str_based_on_summaries( + grouped_raw_nodes, grouped_summary_nodes + ) diff --git a/utils/query_engine/level_based_platforms_util.py b/utils/query_engine/level_based_platforms_util.py index 166bec7..b1730fd 100644 --- a/utils/query_engine/level_based_platforms_util.py +++ b/utils/query_engine/level_based_platforms_util.py @@ -132,6 +132,7 @@ def prepare_context_str_based_on_summaries( f"{self.level1_key}: {level1_title}, " f"{self.level2_key}: {level2_title}, " f"{self.date_key}: {date}" + "\t will fetch them after" ) summary_nodes_to_fetch_filters.append( { From 66411c1c4a040a5c8ef595caea5840be12e38260 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 12 Feb 2024 15:06:01 +0330 Subject: [PATCH 09/13] fix: linter issues based on superlinter rules! --- bot/retrievers/forum_summary_retriever.py | 2 +- bot/retrievers/retrieve_similar_nodes.py | 7 +++-- .../test_level_based_platform_query_engine.py | 2 +- tests/unit/test_level_based_platform_util.py | 1 + .../level_based_platform_query_engine.py | 30 +++++-------------- 5 files changed, 14 insertions(+), 28 deletions(-) diff --git a/bot/retrievers/forum_summary_retriever.py b/bot/retrievers/forum_summary_retriever.py index 31468fc..7cb2982 100644 --- a/bot/retrievers/forum_summary_retriever.py +++ b/bot/retrievers/forum_summary_retriever.py @@ -1,7 +1,7 @@ from bot.retrievers.summary_retriever_base import BaseSummarySearch from llama_index.embeddings import BaseEmbedding -from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from llama_index.schema import NodeWithScore +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding class ForumBasedSummaryRetriever(BaseSummarySearch): diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index 6397baa..6c168a0 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -1,11 +1,11 @@ from datetime import timedelta -from dateutil import parser +from dateutil import parser from llama_index.embeddings import BaseEmbedding from llama_index.schema import NodeWithScore from llama_index.vector_stores import PGVectorStore, VectorStoreQueryResult from llama_index.vector_stores.postgres import DBEmbeddingRow -from sqlalchemy import Date, and_, null, cast, or_, select, text +from sqlalchemy import Date, and_, cast, null, or_, select, text from tc_hivemind_backend.embeddings.cohere import CohereEmbedding @@ -84,7 +84,8 @@ def query_db( for key, value in condition.items(): if key == "date": # Apply ::date cast when the key is 'date' - date = parser.parse(value) + # The value should be always str + date = parser.parse(value) # flake8: noqa date_back = (date - timedelta(days=date_interval)).strftime( "%Y-%m-%d" ) diff --git a/tests/unit/test_level_based_platform_query_engine.py b/tests/unit/test_level_based_platform_query_engine.py index a279b17..d4a1286 100644 --- a/tests/unit/test_level_based_platform_query_engine.py +++ b/tests/unit/test_level_based_platform_query_engine.py @@ -3,10 +3,10 @@ from unittest.mock import patch from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever +from sqlalchemy.exc import OperationalError from utils.query_engine.level_based_platform_query_engine import ( LevelBasedPlatformQueryEngine, ) -from sqlalchemy.exc import OperationalError class TestLevelBasedPlatformQueryEngine(unittest.TestCase): diff --git a/tests/unit/test_level_based_platform_util.py b/tests/unit/test_level_based_platform_util.py index b0df7c0..57f0c22 100644 --- a/tests/unit/test_level_based_platform_util.py +++ b/tests/unit/test_level_based_platform_util.py @@ -1,4 +1,5 @@ import unittest + from llama_index.schema import NodeWithScore, TextNode from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index 562ab2a..95d0fb4 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -1,20 +1,18 @@ import logging from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever -from utils.query_engine.level_based_platforms_util import ( - LevelBasedPlatformUtils, -) from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes from bot.retrievers.utils.load_hyperparams import load_hyperparams +from llama_index import VectorStoreIndex from llama_index.llms import OpenAI from llama_index.prompts import PromptTemplate -from llama_index import VectorStoreIndex from llama_index.query_engine import CustomQueryEngine from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer from llama_index.retrievers import BaseRetriever from llama_index.schema import NodeWithScore from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from tc_hivemind_backend.pg_vector_access import PGVectorAccess +from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils qa_prompt = PromptTemplate( "Context information is below.\n" @@ -248,11 +246,11 @@ def _prepare_context_str( ) else: # grouping the data we have so we could - # get them per each metadata without search - ( - grouped_raw_nodes, - grouped_summary_nodes, - ) = self._group_summary_and_raw_nodes(raw_nodes, summary_nodes) + # get them per each metadata without looping over them + grouped_raw_nodes = self._utils_class.group_nodes_per_metadata(raw_nodes) + grouped_summary_nodes = self._utils_class.group_nodes_per_metadata( + summary_nodes + ) # first using the available summary nodes try to create prompt context_data, ( @@ -307,17 +305,3 @@ def _setup_vector_store_index( ) index = pg_vector.load_index() return index - - def _group_summary_and_raw_nodes( - self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore] - ) -> tuple[ - dict[str, dict[str, dict[str, list[NodeWithScore]]]], - dict[str, dict[str, dict[str, list[NodeWithScore]]]], - ]: - """a wrapper to do the grouping of given nodes""" - grouped_raw_nodes = self._utils_class.group_nodes_per_metadata(raw_nodes) - grouped_summary_nodes = self._utils_class.group_nodes_per_metadata( - summary_nodes - ) - - return grouped_raw_nodes, grouped_summary_nodes From b9f5d7c30da24f93f8e3f9f7c211483e0046f6fe Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 12 Feb 2024 16:02:51 +0330 Subject: [PATCH 10/13] update: Added more tests to increase test coverage! --- .../test_level_based_platform_query_engine.py | 48 ++++- ...d_platform_query_engine_prepare_context.py | 200 ++++++++++++++++++ tests/unit/test_level_based_platform_util.py | 10 +- .../level_based_platforms_util.py | 1 + 4 files changed, 253 insertions(+), 6 deletions(-) create mode 100644 tests/unit/test_level_based_platform_query_engine_prepare_context.py diff --git a/tests/unit/test_level_based_platform_query_engine.py b/tests/unit/test_level_based_platform_query_engine.py index d4a1286..15d1243 100644 --- a/tests/unit/test_level_based_platform_query_engine.py +++ b/tests/unit/test_level_based_platform_query_engine.py @@ -3,6 +3,8 @@ from unittest.mock import patch from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever +from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes +from llama_index.schema import NodeWithScore, TextNode from sqlalchemy.exc import OperationalError from utils.query_engine.level_based_platform_query_engine import ( LevelBasedPlatformQueryEngine, @@ -40,9 +42,10 @@ def test_prepare_platform_engine(self): ) self.assertIsNotNone(engine) - def test_prepare_engine_auto_filter(self): + def test_prepare_engine_auto_filter_raise_error(self): """ Test prepare_engine_auto_filter method with sample data + when an error was raised """ with patch.object( ForumBasedSummaryRetriever, "define_filters" @@ -64,3 +67,46 @@ def test_prepare_engine_auto_filter(self): level2_key=self.level2_key, date_key=self.date_key, ) + + def test_prepare_engine_auto_filter(self): + """ + Test prepare_engine_auto_filter method with sample data in normal condition + """ + with patch.object(RetrieveSimilarNodes, "query_db") as mock_query: + # the output should always have a `date` key for each dictionary + mock_query.return_value = [ + NodeWithScore( + node=TextNode( + text="some summaries #1", + metadata={ + "thread": "thread#1", + "channel": "channel#1", + "date": "2022-01-01", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="some summaries #2", + metadata={ + "thread": "thread#3", + "channel": "channel#2", + "date": "2022-01-02", + }, + ), + score=0, + ), + ] + + # no database with name of `test_community` is available + engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( + community_id=self.community_id, + query="test query", + platform_table_name=self.platform_table_name, + level1_key=self.level1_key, + level2_key=self.level2_key, + date_key=self.date_key, + include_summary_context=True, + ) + self.assertIsNotNone(engine) diff --git a/tests/unit/test_level_based_platform_query_engine_prepare_context.py b/tests/unit/test_level_based_platform_query_engine_prepare_context.py new file mode 100644 index 0000000..e28469c --- /dev/null +++ b/tests/unit/test_level_based_platform_query_engine_prepare_context.py @@ -0,0 +1,200 @@ +import os +import unittest +from unittest.mock import patch + +from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever +from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes +from llama_index.schema import NodeWithScore, TextNode +from sqlalchemy.exc import OperationalError +from utils.query_engine.level_based_platform_query_engine import ( + LevelBasedPlatformQueryEngine, +) + + +class TestLevelBasedPlatformQueryEngine(unittest.TestCase): + def setUp(self): + """ + Set up common parameters for testing + """ + self.community_id = "test_community" + self.level1_key = "channel" + self.level2_key = "thread" + self.platform_table_name = "discord" + self.date_key = "date" + os.environ["OPENAI_API_KEY"] = "sk-some_creds" + + def test_prepare_context_str_without_summaries(self): + """ + test prepare the context string while not having the summaries nodes + """ + with patch.object(RetrieveSimilarNodes, "query_db") as mock_query: + summary_nodes = [] + mock_query.return_value = summary_nodes + + engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( + community_id=self.community_id, + query="test query", + platform_table_name=self.platform_table_name, + level1_key=self.level1_key, + level2_key=self.level2_key, + date_key=self.date_key, + include_summary_context=True, + ) + + raw_nodes = [ + NodeWithScore( + node=TextNode( + text="content1", + metadata={ + "author_username": "user1", + "channel": "channel#1", + "thread": "thread#1", + "date": "2022-01-01", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content2", + metadata={ + "author_username": "user2", + "channel": "channel#2", + "thread": "thread#3", + "date": "2022-01-02", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content4", + metadata={ + "author_username": "user3", + "channel": "channel#2", + "thread": "thread#3", + "date": "2022-01-02", + }, + ), + score=0, + ), + ] + + contest_str = engine._prepare_context_str(raw_nodes, summary_nodes) + expected_context_str = ( + "author: user1\n" + "message_date: 2022-01-01\n" + "message 1: content1\n\n" + "author: user2\n" + "message_date: 2022-01-02\n" + "message 2: content2\n\n" + "author: user3\n" + "message_date: 2022-01-02\n" + "message 3: content4\n" + ) + self.assertEqual(contest_str, expected_context_str) + + def test_prepare_context_str_with_summaries(self): + """ + test prepare the context string having the summaries nodes + """ + + with patch.object(RetrieveSimilarNodes, "query_db") as mock_query: + summary_nodes = [ + NodeWithScore( + node=TextNode( + text="some summaries #1", + metadata={ + "thread": "thread#1", + "channel": "channel#1", + "date": "2022-01-01", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="some summaries #2", + metadata={ + "thread": "thread#3", + "channel": "channel#2", + "date": "2022-01-02", + }, + ), + score=0, + ), + ] + mock_query.return_value = summary_nodes + + engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( + community_id=self.community_id, + query="test query", + platform_table_name=self.platform_table_name, + level1_key=self.level1_key, + level2_key=self.level2_key, + date_key=self.date_key, + include_summary_context=True, + ) + + raw_nodes = [ + NodeWithScore( + node=TextNode( + text="content1", + metadata={ + "author_username": "user1", + "channel": "channel#1", + "thread": "thread#1", + "date": "2022-01-01", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content2", + metadata={ + "author_username": "user2", + "channel": "channel#2", + "thread": "thread#3", + "date": "2022-01-02", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content4", + metadata={ + "author_username": "user3", + "channel": "channel#2", + "thread": "thread#3", + "date": "2022-01-02", + }, + ), + score=0, + ), + ] + + contest_str = engine._prepare_context_str(raw_nodes, summary_nodes) + expected_context_str = ( + "channel: channel#1\n" + "thread: thread#1\n" + "date: 2022-01-01\n" + "summary: some summaries #1\n" + "messages:\n" + " author: user1\n" + " message_date: 2022-01-01\n" + " message 1: content1\n\n" + "channel: channel#2\n" + "thread: thread#3\n" + "date: 2022-01-02\n" + "summary: some summaries #2\n" + "messages:\n" + " author: user2\n" + " message_date: 2022-01-02\n" + " message 1: content2\n\n" + " author: user3\n" + " message_date: 2022-01-02\n" + " message 2: content4\n\n" + ) + self.assertEqual(contest_str, expected_context_str) diff --git a/tests/unit/test_level_based_platform_util.py b/tests/unit/test_level_based_platform_util.py index 57f0c22..cc7d721 100644 --- a/tests/unit/test_level_based_platform_util.py +++ b/tests/unit/test_level_based_platform_util.py @@ -32,8 +32,8 @@ def test_prepare_prompt_with_metadata_info(self): ] prefix = " " expected_output = ( - " author: user1\n message_date: 2022-01-01\n message 1: content1\n" - " author: user2\n message_date: 2022-01-02\n message 2: content2" + " author: user1\n message_date: 2022-01-01\n message 1: content1\n\n" + " author: user2\n message_date: 2022-01-02\n message 2: content2\n" ) result = self.utils.prepare_prompt_with_metadata_info(nodes, prefix) self.assertEqual(result, expected_output) @@ -108,9 +108,9 @@ def test_prepare_context_str_based_on_summaries(self): grouped_raw_nodes = {"A": {"X": {"2022-01-01": raw_nodes}}} grouped_summary_nodes = {"A": {"X": {"2022-01-01": summary_nodes}}} expected_output = ( - """channel: A\nthread: X\ndate: 2022-01-01\nsummary: summary_content\nmessages:\n""" - """ author: USERNAME#1\n message_date: 2022-01-01\n message 1: raw_content1\n""" - """ author: USERNAME#2\n message_date: 2022-01-04\n message 2: raw_content2\n""" + "channel: A\nthread: X\ndate: 2022-01-01\nsummary: summary_content\nmessages:\n" + " author: USERNAME#1\n message_date: 2022-01-01\n message 1: raw_content1\n\n" + " author: USERNAME#2\n message_date: 2022-01-04\n message 2: raw_content2\n\n" ) result, _ = self.utils.prepare_context_str_based_on_summaries( grouped_raw_nodes, grouped_summary_nodes diff --git a/utils/query_engine/level_based_platforms_util.py b/utils/query_engine/level_based_platforms_util.py index b1730fd..d5e6eaa 100644 --- a/utils/query_engine/level_based_platforms_util.py +++ b/utils/query_engine/level_based_platforms_util.py @@ -30,6 +30,7 @@ def prepare_prompt_with_metadata_info( + prefix + f"message {idx + 1}: " + node.get_content() + + "\n" for idx, node in enumerate(nodes) ] ) From 16e1032becf85f99fc66e72c98dcbee8f9d6e66e Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 12 Feb 2024 16:05:14 +0330 Subject: [PATCH 11/13] fix: mypy linter issue! --- bot/retrievers/retrieve_similar_nodes.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index 6c168a0..c5f5f59 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -1,4 +1,4 @@ -from datetime import timedelta +from datetime import datetime, timedelta from dateutil import parser from llama_index.embeddings import BaseEmbedding @@ -83,9 +83,13 @@ def query_db( filters_and = [] for key, value in condition.items(): if key == "date": - # Apply ::date cast when the key is 'date' - # The value should be always str - date = parser.parse(value) # flake8: noqa + date: datetime + if isinstance(value, str): + date = parser.parse(value) + else: + raise ValueError( + "the values for filtering dates must be string!" + ) date_back = (date - timedelta(days=date_interval)).strftime( "%Y-%m-%d" ) @@ -93,6 +97,7 @@ def query_db( "%Y-%m-%d" ) + # Apply ::date cast when the key is 'date' filter_condition_back = cast( self._vector_store._table_class.metadata_.op("->>")(key), Date, From 4da2292e397560f0c01cb26ae11c3037513cb955 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 12 Feb 2024 16:09:48 +0330 Subject: [PATCH 12/13] fix: flake8 linter, unused imports! --- .../test_level_based_platform_query_engine_prepare_context.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/test_level_based_platform_query_engine_prepare_context.py b/tests/unit/test_level_based_platform_query_engine_prepare_context.py index e28469c..3fa415e 100644 --- a/tests/unit/test_level_based_platform_query_engine_prepare_context.py +++ b/tests/unit/test_level_based_platform_query_engine_prepare_context.py @@ -2,10 +2,8 @@ import unittest from unittest.mock import patch -from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes from llama_index.schema import NodeWithScore, TextNode -from sqlalchemy.exc import OperationalError from utils.query_engine.level_based_platform_query_engine import ( LevelBasedPlatformQueryEngine, ) From 946f0746ca92c6b3725cba0a35cd270f7515dca7 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 12 Feb 2024 18:48:41 +0330 Subject: [PATCH 13/13] feat: getting just the thread summaries! --- utils/query_engine/level_based_platform_query_engine.py | 3 ++- utils/query_engine/level_based_platforms_util.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index 95d0fb4..5cb4e9e 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -47,6 +47,7 @@ def custom_query(self, query_str: str): context_str = self._prepare_context_str(similar_nodes, self.summary_nodes) fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str) response = self.llm.complete(fmt_qa_prompt) + logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}") return str(response) @@ -191,7 +192,7 @@ def prepare_engine_auto_filter( summary_similarity_top_k, ) # getting nodes of just thread summaries - nodes = retriever.query_db(query, [{"thread": None}, {"thread": {"ne": None}}]) + nodes = retriever.query_db(query, [{"type": "thread"}]) # For summaries data a posfix `summary` would be added platform_retriever = ForumBasedSummaryRetriever( diff --git a/utils/query_engine/level_based_platforms_util.py b/utils/query_engine/level_based_platforms_util.py index d5e6eaa..5ca2147 100644 --- a/utils/query_engine/level_based_platforms_util.py +++ b/utils/query_engine/level_based_platforms_util.py @@ -140,6 +140,8 @@ def prepare_context_str_based_on_summaries( self.level1_key: level1_title, self.level2_key: level2_title, self.date_key: date, + # we need the thread summaries + "type": "thread", } ) raw_nodes_missed.setdefault(level1_title, {}).setdefault(