diff --git a/.gitignore b/.gitignore index 0c9e5b4..32ea46e 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,5 @@ cython_debug/ hivemind-bot-env/* main.ipynb .DS_Store + +temp_test_run_data.json \ No newline at end of file diff --git a/bot/retrievers/forum_summary_retriever.py b/bot/retrievers/forum_summary_retriever.py index 6dd56d1..7cb2982 100644 --- a/bot/retrievers/forum_summary_retriever.py +++ b/bot/retrievers/forum_summary_retriever.py @@ -1,5 +1,6 @@ from bot.retrievers.summary_retriever_base import BaseSummarySearch from llama_index.embeddings import BaseEmbedding +from llama_index.schema import NodeWithScore from tc_hivemind_backend.embeddings.cohere import CohereEmbedding @@ -53,15 +54,44 @@ def retreive_filtering( """ nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k) + filters = self.define_filters( + nodes=nodes, + metadata_group1_key=metadata_group1_key, + metadata_group2_key=metadata_group2_key, + metadata_date_key=metadata_date_key, + ) + + return filters + + def define_filters( + self, + nodes: list[NodeWithScore], + metadata_group1_key: str, + metadata_group2_key: str, + metadata_date_key: str, + ) -> list[dict[str, str]]: + """ + define dictionary filters based on metadata of retrieved nodes + + Parameters + ---------- + nodes : list[dict[llama_index.schema.NodeWithScore]] + a list of retrieved similar nodes to define filters based + + Returns + --------- + filters : list[dict[str, str]] + a list of filters to apply with `or` condition + the dictionary would be applying `and` + operation between keys and values of json metadata_ + """ filters: list[dict[str, str]] = [] for node in nodes: # the filter made by given node filter: dict[str, str] = {} - if node.metadata[metadata_group1_key]: - filter[metadata_group1_key] = node.metadata[metadata_group1_key] - if node.metadata[metadata_group2_key]: - filter[metadata_group2_key] = node.metadata[metadata_group2_key] + filter[metadata_group1_key] = node.metadata[metadata_group1_key] + filter[metadata_group2_key] = node.metadata[metadata_group2_key] # date filter filter[metadata_date_key] = node.metadata[metadata_date_key] diff --git a/bot/retrievers/process_dates.py b/bot/retrievers/process_dates.py index dba3217..cb5ecec 100644 --- a/bot/retrievers/process_dates.py +++ b/bot/retrievers/process_dates.py @@ -19,7 +19,9 @@ def process_dates(dates: list[str], d: int) -> list[str]: Returns ---------- dates_modified : list[str] - days added to it + days added to it sorted ascending meaning + the first index is the lowest date + and the last is the biggest date """ dates_modified: list[str] = [] if dates != []: diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index baac30e..c5f5f59 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -1,8 +1,11 @@ +from datetime import datetime, timedelta + +from dateutil import parser from llama_index.embeddings import BaseEmbedding from llama_index.schema import NodeWithScore from llama_index.vector_stores import PGVectorStore, VectorStoreQueryResult from llama_index.vector_stores.postgres import DBEmbeddingRow -from sqlalchemy import Date, and_, cast, or_, select, text +from sqlalchemy import Date, and_, cast, null, or_, select, text from tc_hivemind_backend.embeddings.cohere import CohereEmbedding @@ -12,7 +15,7 @@ class RetrieveSimilarNodes: def __init__( self, vector_store: PGVectorStore, - similarity_top_k: int, + similarity_top_k: int | None, embed_model: BaseEmbedding = CohereEmbedding(), ) -> None: """Init params.""" @@ -21,7 +24,11 @@ def __init__( self._similarity_top_k = similarity_top_k def query_db( - self, query: str, filters: list[dict[str, str]] | None = None + self, + query: str, + filters: list[dict[str, str | dict | None]] | None = None, + date_interval: int = 0, + **kwargs ) -> list[NodeWithScore]: """ query database with given filters (similarity search is also done) @@ -30,23 +37,45 @@ def query_db( ------------- query : str the user question - filters : list[dict[str, str]] | None + filters : list[dict[str, str | dict | None]] | None a list of filters to apply with `or` condition the dictionary would be applying `and` operation between keys and values of json metadata_ - if `None` then no filtering would be applied + the value can be a dictionary with one key of "ne" and a value + which means to do a not equal operator `!=` + if `None` then no filtering would be applied. + date_interval : int + the number of back and forth days of date + default is set to 0 meaning no days back or forward. + **kwargs + ignore_sort : bool + to ignore sort by vector similarity. + Note: This would completely disable the similarity search and + it would just return the results with no ordering. + default is `False`. If `True` the query will be ignored and no embedding of it would be fetched """ + ignore_sort = kwargs.get("ignore_sort", False) self._vector_store._initialize() - embedding = self._embed_model.get_text_embedding(text=query) + + if not ignore_sort: + embedding = self._embed_model.get_text_embedding(text=query) + else: + embedding = None + stmt = select( # type: ignore self._vector_store._table_class.id, self._vector_store._table_class.node_id, self._vector_store._table_class.text, self._vector_store._table_class.metadata_, - self._vector_store._table_class.embedding.cosine_distance(embedding).label( - "distance" - ), - ).order_by(text("distance asc")) + ( + self._vector_store._table_class.embedding.cosine_distance(embedding) + if not ignore_sort + else null() + ).label("distance"), + ) + + if not ignore_sort: + stmt = stmt.order_by(text("distance asc")) if filters is not None and filters != []: conditions = [] @@ -54,24 +83,51 @@ def query_db( filters_and = [] for key, value in condition.items(): if key == "date": + date: datetime + if isinstance(value, str): + date = parser.parse(value) + else: + raise ValueError( + "the values for filtering dates must be string!" + ) + date_back = (date - timedelta(days=date_interval)).strftime( + "%Y-%m-%d" + ) + date_forward = (date + timedelta(days=date_interval)).strftime( + "%Y-%m-%d" + ) + # Apply ::date cast when the key is 'date' - filter_condition = cast( + filter_condition_back = cast( + self._vector_store._table_class.metadata_.op("->>")(key), + Date, + ) >= cast(date_back, Date) + + filter_condition_forward = cast( self._vector_store._table_class.metadata_.op("->>")(key), Date, - ) == cast(value, Date) + ) <= cast(date_forward, Date) + + filters_and.append(filter_condition_back) + filters_and.append(filter_condition_forward) else: filter_condition = ( self._vector_store._table_class.metadata_.op("->>")(key) == value + if not isinstance(value, dict) + else self._vector_store._table_class.metadata_.op("->>")( + key + ) + != value["ne"] ) - - filters_and.append(filter_condition) + filters_and.append(filter_condition) conditions.append(and_(*filters_and)) stmt = stmt.where(or_(*conditions)) - stmt = stmt.limit(self._similarity_top_k) + if self._similarity_top_k is not None: + stmt = stmt.limit(self._similarity_top_k) with self._vector_store._session() as session, session.begin(): res = session.execute(stmt) diff --git a/tests/unit/test_level_based_platform_query_engine.py b/tests/unit/test_level_based_platform_query_engine.py index eaa9986..15d1243 100644 --- a/tests/unit/test_level_based_platform_query_engine.py +++ b/tests/unit/test_level_based_platform_query_engine.py @@ -3,6 +3,9 @@ from unittest.mock import patch from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever +from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes +from llama_index.schema import NodeWithScore, TextNode +from sqlalchemy.exc import OperationalError from utils.query_engine.level_based_platform_query_engine import ( LevelBasedPlatformQueryEngine, ) @@ -26,9 +29,9 @@ def test_prepare_platform_engine(self): """ # the output should always have a `date` key for each dictionary filters = [ - {"channel": "general", "date": "2023-01-02"}, - {"thread": "discussion", "date": "2024-01-03"}, - {"date": "2022-01-01"}, + {"channel": "general", "thread": "some_thread", "date": "2023-01-02"}, + {"channel": "general", "thread": "discussion", "date": "2024-01-03"}, + {"channel": "general#2", "thread": "Agenda", "date": "2022-01-01"}, ] engine = LevelBasedPlatformQueryEngine.prepare_platform_engine( @@ -39,20 +42,64 @@ def test_prepare_platform_engine(self): ) self.assertIsNotNone(engine) - def test_prepare_engine_auto_filter(self): + def test_prepare_engine_auto_filter_raise_error(self): """ Test prepare_engine_auto_filter method with sample data + when an error was raised """ with patch.object( - ForumBasedSummaryRetriever, "retreive_filtering" + ForumBasedSummaryRetriever, "define_filters" ) as mock_retriever: # the output should always have a `date` key for each dictionary mock_retriever.return_value = [ - {"channel": "general", "date": "2023-01-02"}, - {"thread": "discussion", "date": "2024-01-03"}, - {"date": "2022-01-01"}, + {"channel": "general", "thread": "some_thread", "date": "2023-01-02"}, + {"channel": "general", "thread": "discussion", "date": "2024-01-03"}, + {"channel": "general#2", "thread": "Agenda", "date": "2022-01-01"}, + ] + + with self.assertRaises(OperationalError): + # no database with name of `test_community` is available + _ = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( + community_id=self.community_id, + query="test query", + platform_table_name=self.platform_table_name, + level1_key=self.level1_key, + level2_key=self.level2_key, + date_key=self.date_key, + ) + + def test_prepare_engine_auto_filter(self): + """ + Test prepare_engine_auto_filter method with sample data in normal condition + """ + with patch.object(RetrieveSimilarNodes, "query_db") as mock_query: + # the output should always have a `date` key for each dictionary + mock_query.return_value = [ + NodeWithScore( + node=TextNode( + text="some summaries #1", + metadata={ + "thread": "thread#1", + "channel": "channel#1", + "date": "2022-01-01", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="some summaries #2", + metadata={ + "thread": "thread#3", + "channel": "channel#2", + "date": "2022-01-02", + }, + ), + score=0, + ), ] + # no database with name of `test_community` is available engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( community_id=self.community_id, query="test query", @@ -60,5 +107,6 @@ def test_prepare_engine_auto_filter(self): level1_key=self.level1_key, level2_key=self.level2_key, date_key=self.date_key, + include_summary_context=True, ) self.assertIsNotNone(engine) diff --git a/tests/unit/test_level_based_platform_query_engine_prepare_context.py b/tests/unit/test_level_based_platform_query_engine_prepare_context.py new file mode 100644 index 0000000..3fa415e --- /dev/null +++ b/tests/unit/test_level_based_platform_query_engine_prepare_context.py @@ -0,0 +1,198 @@ +import os +import unittest +from unittest.mock import patch + +from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes +from llama_index.schema import NodeWithScore, TextNode +from utils.query_engine.level_based_platform_query_engine import ( + LevelBasedPlatformQueryEngine, +) + + +class TestLevelBasedPlatformQueryEngine(unittest.TestCase): + def setUp(self): + """ + Set up common parameters for testing + """ + self.community_id = "test_community" + self.level1_key = "channel" + self.level2_key = "thread" + self.platform_table_name = "discord" + self.date_key = "date" + os.environ["OPENAI_API_KEY"] = "sk-some_creds" + + def test_prepare_context_str_without_summaries(self): + """ + test prepare the context string while not having the summaries nodes + """ + with patch.object(RetrieveSimilarNodes, "query_db") as mock_query: + summary_nodes = [] + mock_query.return_value = summary_nodes + + engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( + community_id=self.community_id, + query="test query", + platform_table_name=self.platform_table_name, + level1_key=self.level1_key, + level2_key=self.level2_key, + date_key=self.date_key, + include_summary_context=True, + ) + + raw_nodes = [ + NodeWithScore( + node=TextNode( + text="content1", + metadata={ + "author_username": "user1", + "channel": "channel#1", + "thread": "thread#1", + "date": "2022-01-01", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content2", + metadata={ + "author_username": "user2", + "channel": "channel#2", + "thread": "thread#3", + "date": "2022-01-02", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content4", + metadata={ + "author_username": "user3", + "channel": "channel#2", + "thread": "thread#3", + "date": "2022-01-02", + }, + ), + score=0, + ), + ] + + contest_str = engine._prepare_context_str(raw_nodes, summary_nodes) + expected_context_str = ( + "author: user1\n" + "message_date: 2022-01-01\n" + "message 1: content1\n\n" + "author: user2\n" + "message_date: 2022-01-02\n" + "message 2: content2\n\n" + "author: user3\n" + "message_date: 2022-01-02\n" + "message 3: content4\n" + ) + self.assertEqual(contest_str, expected_context_str) + + def test_prepare_context_str_with_summaries(self): + """ + test prepare the context string having the summaries nodes + """ + + with patch.object(RetrieveSimilarNodes, "query_db") as mock_query: + summary_nodes = [ + NodeWithScore( + node=TextNode( + text="some summaries #1", + metadata={ + "thread": "thread#1", + "channel": "channel#1", + "date": "2022-01-01", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="some summaries #2", + metadata={ + "thread": "thread#3", + "channel": "channel#2", + "date": "2022-01-02", + }, + ), + score=0, + ), + ] + mock_query.return_value = summary_nodes + + engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( + community_id=self.community_id, + query="test query", + platform_table_name=self.platform_table_name, + level1_key=self.level1_key, + level2_key=self.level2_key, + date_key=self.date_key, + include_summary_context=True, + ) + + raw_nodes = [ + NodeWithScore( + node=TextNode( + text="content1", + metadata={ + "author_username": "user1", + "channel": "channel#1", + "thread": "thread#1", + "date": "2022-01-01", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content2", + metadata={ + "author_username": "user2", + "channel": "channel#2", + "thread": "thread#3", + "date": "2022-01-02", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content4", + metadata={ + "author_username": "user3", + "channel": "channel#2", + "thread": "thread#3", + "date": "2022-01-02", + }, + ), + score=0, + ), + ] + + contest_str = engine._prepare_context_str(raw_nodes, summary_nodes) + expected_context_str = ( + "channel: channel#1\n" + "thread: thread#1\n" + "date: 2022-01-01\n" + "summary: some summaries #1\n" + "messages:\n" + " author: user1\n" + " message_date: 2022-01-01\n" + " message 1: content1\n\n" + "channel: channel#2\n" + "thread: thread#3\n" + "date: 2022-01-02\n" + "summary: some summaries #2\n" + "messages:\n" + " author: user2\n" + " message_date: 2022-01-02\n" + " message 1: content2\n\n" + " author: user3\n" + " message_date: 2022-01-02\n" + " message 2: content4\n\n" + ) + self.assertEqual(contest_str, expected_context_str) diff --git a/tests/unit/test_level_based_platform_util.py b/tests/unit/test_level_based_platform_util.py new file mode 100644 index 0000000..cc7d721 --- /dev/null +++ b/tests/unit/test_level_based_platform_util.py @@ -0,0 +1,212 @@ +import unittest + +from llama_index.schema import NodeWithScore, TextNode +from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils + + +class TestLevelBasedPlatformUtils(unittest.TestCase): + def setUp(self): + self.level1_key = "channel" + self.level2_key = "thread" + self.date_key = "date" + self.utils = LevelBasedPlatformUtils( + self.level1_key, self.level2_key, self.date_key + ) + + def test_prepare_prompt_with_metadata_info(self): + nodes = [ + NodeWithScore( + node=TextNode( + text="content1", + metadata={"author_username": "user1", "date": "2022-01-01"}, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content2", + metadata={"author_username": "user2", "date": "2022-01-02"}, + ), + score=0, + ), + ] + prefix = " " + expected_output = ( + " author: user1\n message_date: 2022-01-01\n message 1: content1\n\n" + " author: user2\n message_date: 2022-01-02\n message 2: content2\n" + ) + result = self.utils.prepare_prompt_with_metadata_info(nodes, prefix) + self.assertEqual(result, expected_output) + + def test_group_nodes_per_metadata(self): + nodes = [ + NodeWithScore( + node=TextNode( + text="content1", + metadata={"channel": "A", "thread": "X", "date": "2022-01-01"}, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content2", + metadata={"channel": "A", "thread": "Y", "date": "2022-01-01"}, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="content3", + metadata={"channel": "B", "thread": "X", "date": "2022-01-02"}, + ), + score=0, + ), + ] + expected_output = { + "A": {"X": {"2022-01-01": [nodes[0]]}, "Y": {"2022-01-01": [nodes[1]]}}, + "B": {"X": {"2022-01-02": [nodes[2]]}}, + } + result = self.utils.group_nodes_per_metadata(nodes) + self.assertEqual(result, expected_output) + + def test_prepare_context_str_based_on_summaries(self): + raw_nodes = [ + NodeWithScore( + node=TextNode( + text="raw_content1", + metadata={ + "channel": "A", + "thread": "X", + "date": "2022-01-01", + "author_username": "USERNAME#1", + }, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="raw_content2", + metadata={ + "channel": "A", + "thread": "Y", + "date": "2022-01-04", + "author_username": "USERNAME#2", + }, + ), + score=0, + ), + ] + summary_nodes = [ + NodeWithScore( + node=TextNode( + text="summary_content", + metadata={"channel": "A", "thread": "X", "date": "2022-01-01"}, + ), + score=0, + ) + ] + grouped_raw_nodes = {"A": {"X": {"2022-01-01": raw_nodes}}} + grouped_summary_nodes = {"A": {"X": {"2022-01-01": summary_nodes}}} + expected_output = ( + "channel: A\nthread: X\ndate: 2022-01-01\nsummary: summary_content\nmessages:\n" + " author: USERNAME#1\n message_date: 2022-01-01\n message 1: raw_content1\n\n" + " author: USERNAME#2\n message_date: 2022-01-04\n message 2: raw_content2\n\n" + ) + result, _ = self.utils.prepare_context_str_based_on_summaries( + grouped_raw_nodes, grouped_summary_nodes + ) + self.assertEqual(result.strip(), expected_output.strip()) + + def test_prepare_context_str_based_on_summaries_no_summary(self): + node1 = NodeWithScore( + node=TextNode( + text="raw_content1", + metadata={ + "channel": "A", + "thread": "X", + "date": "2022-01-01", + "author_username": "USERNAME#1", + }, + ), + score=0, + ) + node2 = NodeWithScore( + node=TextNode( + text="raw_content2", + metadata={ + "channel": "A", + "thread": "Y", + "date": "2022-01-04", + "author_username": "USERNAME#2", + }, + ), + score=0, + ) + grouped_raw_nodes = { + "A": {"X": {"2022-01-01": [node1]}, "Y": {"2022-01-04": [node2]}} + } + grouped_summary_nodes = {} + result, ( + summary_nodes_to_fetch_filters, + raw_nodes_missed, + ) = self.utils.prepare_context_str_based_on_summaries( + grouped_raw_nodes, grouped_summary_nodes + ) + self.assertEqual(result, "") + self.assertEqual(len(summary_nodes_to_fetch_filters), 2) + for channel in raw_nodes_missed.keys(): + self.assertIn(channel, ["A"]) + for thread in raw_nodes_missed[channel].keys(): + self.assertIn(thread, ["X", "Y"]) + for date in raw_nodes_missed[channel][thread]: + self.assertIn(date, ["2022-01-01", "2022-01-04"]) + nodes = raw_nodes_missed[channel][thread][date] + + if date == "2022-01-01": + self.assertEqual( + grouped_raw_nodes["A"]["X"]["2022-01-01"], nodes + ) + elif date == "2022-01-04": + self.assertEqual( + grouped_raw_nodes["A"]["Y"]["2022-01-04"], nodes + ) + + def test_prepare_context_str_based_on_summaries_multiple_summaries_error(self): + raw_nodes = [ + NodeWithScore( + node=TextNode( + text="raw_content1", + metadata={"channel": "A", "thread": "X", "date": "2022-01-01"}, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="raw_content2", + metadata={"channel": "A", "thread": "Y", "date": "2022-01-01"}, + ), + score=0, + ), + ] + summary_nodes = [ + NodeWithScore( + node=TextNode( + text="summary_content1", + metadata={"channel": "A", "thread": "X", "date": "2022-01-01"}, + ), + score=0, + ), + NodeWithScore( + node=TextNode( + text="summary_content2", + metadata={"channel": "A", "thread": "X", "date": "2022-01-01"}, + ), + score=0, + ), + ] + grouped_raw_nodes = {"A": {"X": {"2022-01-01": raw_nodes}}} + grouped_summary_nodes = {"A": {"X": {"2022-01-01": summary_nodes}}} + with self.assertRaises(ValueError): + self.utils.prepare_context_str_based_on_summaries( + grouped_raw_nodes, grouped_summary_nodes + ) diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index 6d27c01..5cb4e9e 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -1,17 +1,18 @@ import logging from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever -from bot.retrievers.process_dates import process_dates from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes from bot.retrievers.utils.load_hyperparams import load_hyperparams +from llama_index import VectorStoreIndex from llama_index.llms import OpenAI from llama_index.prompts import PromptTemplate from llama_index.query_engine import CustomQueryEngine from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer from llama_index.retrievers import BaseRetriever -from llama_index.schema import MetadataMode, NodeWithScore +from llama_index.schema import NodeWithScore from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from tc_hivemind_backend.pg_vector_access import PGVectorAccess +from utils.query_engine.level_based_platforms_util import LevelBasedPlatformUtils qa_prompt = PromptTemplate( "Context information is below.\n" @@ -35,15 +36,19 @@ def custom_query(self, query_str: str): """Doing custom query""" # first retrieving similar nodes in summary retriever = RetrieveSimilarNodes( - self._vector_store, + self._raw_vector_store, self._similarity_top_k, ) - similar_nodes = retriever.query_db(query=query_str, filters=self._filters) - context_str = self._prepare_context_str(similar_nodes) + similar_nodes = retriever.query_db( + query=query_str, filters=self._filters, date_interval=self._d + ) + + context_str = self._prepare_context_str(similar_nodes, self.summary_nodes) fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str) response = self.llm.complete(fmt_qa_prompt) - # logging.info(f"fmt_qa_prompt {fmt_qa_prompt}") + logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}") + return str(response) @classmethod @@ -77,13 +82,19 @@ def prepare_platform_engine( **kwargs : llm : llama-index.LLM the LLM to use answering queries - default is gpt-3.5-turbo + default is gpt-4 synthesizer : llama_index.response_synthesizers.base.BaseSynthesizer the synthesizers to use when creating the prompt default is to get from `get_response_synthesizer(response_mode="compact")` qa_prompt : llama-index.prompts.PromptTemplate the Q&A prompt to use default would be the default prompt of llama-index + index_raw : VectorStoreIndex + the vector store index for raw data + If not passed, it would just create one itself + index_summary : VectorStoreIndex + the vector store index for summary data + If not passed, it would just create one itself Returns --------- @@ -95,20 +106,25 @@ def prepare_platform_engine( synthesizer = kwargs.get( "synthesizer", get_response_synthesizer(response_mode="compact") ) - llm = kwargs.get("llm", OpenAI("gpt-3.5-turbo")) + llm = kwargs.get("llm", OpenAI("gpt-4")) qa_prompt_ = kwargs.get("qa_prompt", qa_prompt) - - pg_vector = PGVectorAccess( - table_name=platform_table_name, - dbname=dbname, - testing=testing, - embed_model=CohereEmbedding(), + index: VectorStoreIndex = kwargs.get( + "index_raw", + cls._setup_vector_store_index(platform_table_name, dbname, testing), ) - index = pg_vector.load_index() retriever = index.as_retriever() - _, similarity_top_k, _ = load_hyperparams() + cls._summary_vector_store = kwargs.get( + "index_summary", + cls._setup_vector_store_index( + platform_table_name + "_summary", dbname, testing + ), + )._vector_store + + _, similarity_top_k, d = load_hyperparams() + cls._d = d + + cls._raw_vector_store = index._vector_store - cls._vector_store = index.vector_store cls._similarity_top_k = similarity_top_k cls._filters = filters @@ -128,6 +144,7 @@ def prepare_engine_auto_filter( level1_key: str, level2_key: str, date_key: str = "date", + include_summary_context: bool = False, ) -> "LevelBasedPlatformQueryEngine": """ get the query engine and do the filtering automatically. @@ -163,44 +180,129 @@ def prepare_engine_auto_filter( the created query engine with the filters """ dbname = f"community_{community_id}" - summary_similarity_top_k, _, d = load_hyperparams() + index_summary = cls._setup_vector_store_index( + platform_table_name + "_summary", dbname, False + ) + vector_store = index_summary._vector_store + + retriever = RetrieveSimilarNodes( + vector_store, + summary_similarity_top_k, + ) + # getting nodes of just thread summaries + nodes = retriever.query_db(query, [{"type": "thread"}]) + # For summaries data a posfix `summary` would be added platform_retriever = ForumBasedSummaryRetriever( table_name=platform_table_name + "_summary", dbname=dbname ) - filters = platform_retriever.retreive_filtering( - query=query, + filters = platform_retriever.define_filters( + nodes, metadata_group1_key=level1_key, metadata_group2_key=level2_key, metadata_date_key=date_key, - similarity_top_k=summary_similarity_top_k, ) - # getting all the metadata dates from filters - dates: list[str] = [f[date_key] for f in filters] - dates_modified = process_dates(list(dates), d) - dates_filter = [{date_key: date} for date in dates_modified] - filters.extend(dates_filter) + # saving to add summaries to the context of prompt + if include_summary_context: + cls.summary_nodes = nodes + else: + cls.summary_nodes = [] + + cls._utils_class = LevelBasedPlatformUtils(level1_key, level2_key, date_key) + cls._level1_key = level1_key + cls._level2_key = level2_key + cls._date_key = date_key + cls._d = d + cls._platform_table_name = platform_table_name - logging.info(f"COMMUNITY_ID: {community_id} | summary filters: {filters}") + logging.debug(f"COMMUNITY_ID: {community_id} | summary filters: {filters}") engine = LevelBasedPlatformQueryEngine.prepare_platform_engine( community_id=community_id, platform_table_name=platform_table_name, filters=filters, + index_summary=index_summary, ) return engine - def _prepare_context_str(self, nodes: list[NodeWithScore]) -> str: - context_str = "\n\n".join( - [ - node.get_content() - + "\n" - + node.node.get_metadata_str(mode=MetadataMode.LLM) - for node in nodes - ] - ) + def _prepare_context_str( + self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore] + ) -> str: + """ + prepare the prompt context using the raw_nodes for answers and summary_nodes for additional information + """ + context_str: str = "" + + if summary_nodes == []: + logging.warning( + "Empty context_nodes. Cannot add summaries as context information!" + ) + + context_str += self._utils_class.prepare_prompt_with_metadata_info( + nodes=raw_nodes + ) + else: + # grouping the data we have so we could + # get them per each metadata without looping over them + grouped_raw_nodes = self._utils_class.group_nodes_per_metadata(raw_nodes) + grouped_summary_nodes = self._utils_class.group_nodes_per_metadata( + summary_nodes + ) + + # first using the available summary nodes try to create prompt + context_data, ( + summary_nodes_to_fetch_filters, + raw_nodes_missed, + ) = self._utils_class.prepare_context_str_based_on_summaries( + grouped_raw_nodes, grouped_summary_nodes + ) + context_str += context_data + + # then if there was some missing summaries + if len(summary_nodes_to_fetch_filters): + retriever = RetrieveSimilarNodes( + self._summary_vector_store, + similarity_top_k=None, + ) + fetched_summary_nodes = retriever.query_db( + query="", + filters=summary_nodes_to_fetch_filters, + ignore_sort=True, + ) + grouped_summary_nodes = self._utils_class.group_nodes_per_metadata( + fetched_summary_nodes + ) + context_data, ( + summary_nodes_to_fetch_filters, + _, + ) = self._utils_class.prepare_context_str_based_on_summaries( + raw_nodes_missed, grouped_summary_nodes + ) + context_str += context_data + + logging.debug(f"context_str of prompt\n" f"{context_str}") + return context_str + + @classmethod + def _setup_vector_store_index( + cls, + platform_table_name: str, + dbname: str, + testing: bool = False, + ) -> VectorStoreIndex: + """ + prepare the vector_store for querying data + """ + pg_vector = PGVectorAccess( + table_name=platform_table_name, + dbname=dbname, + testing=testing, + embed_model=CohereEmbedding(), + ) + index = pg_vector.load_index() + return index diff --git a/utils/query_engine/level_based_platforms_util.py b/utils/query_engine/level_based_platforms_util.py new file mode 100644 index 0000000..5ca2147 --- /dev/null +++ b/utils/query_engine/level_based_platforms_util.py @@ -0,0 +1,161 @@ +import logging + +from dateutil import parser +from llama_index.schema import NodeWithScore + + +class LevelBasedPlatformUtils: + def __init__(self, level1_key: str, level2_key: str, date_key: str) -> None: + self.level1_key = level1_key + self.level2_key = level2_key + self.date_key = date_key + + def prepare_prompt_with_metadata_info( + self, nodes: list[NodeWithScore], prefix: str = "" + ) -> str: + """ + prepare a prompt with given nodes including the nodes' metadata + Note: the prefix is set before each text! + """ + context_str = "\n".join( + [ + prefix + + "author: " + + node.metadata["author_username"] + + "\n" + + prefix + + "message_date: " + + node.metadata["date"] + + "\n" + + prefix + + f"message {idx + 1}: " + + node.get_content() + + "\n" + for idx, node in enumerate(nodes) + ] + ) + + return context_str + + def group_nodes_per_metadata( + self, + nodes: list[NodeWithScore], + ) -> dict[str, dict[str, dict[str, list[NodeWithScore]]]]: + """ + group all nodes based on their level1 and level2 metadata + + Parameters + ----------- + nodes : list[NodeWithScore] + a list of raw nodes + + Returns + --------- + grouped_nodes : dict[str, dict[str, dict[str, list[NodeWithScore]]]] + a list of nodes grouped by + - `level1_key` + - `level2_key` + - and the last dict key `date_key` + + The values of the nested dictionary are the nodes grouped + """ + grouped_nodes: dict[str, dict[str, dict[str, list[NodeWithScore]]]] = {} + for node in nodes: + level1_title = node.metadata[self.level1_key] + level2_title = node.metadata[self.level2_key] + date_str = node.metadata[self.date_key] + date = parser.parse(date_str).strftime("%Y-%m-%d") + + # defining an empty list (if keys weren't previously made) + grouped_nodes.setdefault(level1_title, {}).setdefault( + level2_title, {} + ).setdefault(date, []) + # Adding to list + grouped_nodes[level1_title][level2_title][date].append(node) + + return grouped_nodes + + def prepare_context_str_based_on_summaries( + self, + grouped_raw_nodes: dict[str, dict[str, dict[str, list[NodeWithScore]]]], + grouped_summary_nodes: dict[str, dict[str, dict[str, list[NodeWithScore]]]], + ) -> tuple[ + str, + tuple[ + list[dict[str, str | None]], + dict[str, dict[str, dict[str, list[NodeWithScore]]]], + ], + ]: + """ + prepare prompt context having the summaries within it + """ + context_str: str = "" + + summary_nodes_to_fetch_filters: list[dict[str, str | None]] = [] + # in case of summary wasn't available for them + raw_nodes_missed: dict[str, dict[str, dict[str, list[NodeWithScore]]]] = {} + + for level1_title in grouped_raw_nodes: + for level2_title in grouped_raw_nodes[level1_title]: + for date in grouped_raw_nodes[level1_title][level2_title]: + raw_nodes = grouped_raw_nodes[level1_title][level2_title][date] + + # the summary_nodes should be always 0 or 1 node + summary_nodes = ( + grouped_summary_nodes.get(level1_title, {}) + .get(level2_title, {}) + .get(date, []) + ) + if len(summary_nodes) == 1: + logging.debug( + f"{len(raw_nodes)} messages available for " + f"{self.level1_key}: {level1_title}, " + f"{self.level2_key}: {level2_title}, " + f"{self.date_key}: {date}" + ) + summary_node = summary_nodes[0] + + node_context: str = ( + f"{self.level1_key}: {level1_title}\n" + f"{self.level2_key}: {level2_title}\n" + f"{self.date_key}: {date}\n" + f"summary: {summary_node.text}\n" + "messages:\n" + ) + node_context += self.prepare_prompt_with_metadata_info( + raw_nodes, prefix=" " + ) + + context_str += node_context + "\n" + elif len(summary_nodes) == 0: + logging.info( + "No summary messages available for " + f"{self.level1_key}: {level1_title}, " + f"{self.level2_key}: {level2_title}, " + f"{self.date_key}: {date}" + "\t will fetch them after" + ) + summary_nodes_to_fetch_filters.append( + { + self.level1_key: level1_title, + self.level2_key: level2_title, + self.date_key: date, + # we need the thread summaries + "type": "thread", + } + ) + raw_nodes_missed.setdefault(level1_title, {}).setdefault( + level2_title, {} + ).setdefault(date, []) + raw_nodes_missed[level1_title][level2_title][date].extend( + raw_nodes + ) + else: + logging.info(f"len(summary_nodes) {len(summary_nodes)}") + raise ValueError( + "Not possible to have multiple summaries for a" + f" combination of " + f"{self.level1_key}-{self.level2_key}-{self.date_key}" + ) + + return context_str, (summary_nodes_to_fetch_filters, raw_nodes_missed) diff --git a/utils/query_engine/prepare_discord_query_engine.py b/utils/query_engine/prepare_discord_query_engine.py index ac185a9..553dde3 100644 --- a/utils/query_engine/prepare_discord_query_engine.py +++ b/utils/query_engine/prepare_discord_query_engine.py @@ -72,5 +72,6 @@ def prepare_discord_engine_auto_filter( level1_key="channel", level2_key="thread", date_key="date", + include_summary_context=True, ) return engine