diff --git a/bot/retrievers/forum_summary_retriever.py b/bot/retrievers/forum_summary_retriever.py
index 1df52e7..6743dd3 100644
--- a/bot/retrievers/forum_summary_retriever.py
+++ b/bot/retrievers/forum_summary_retriever.py
@@ -68,15 +68,28 @@ def define_filters(
         nodes: list[NodeWithScore],
         metadata_group1_key: str,
         metadata_group2_key: str,
-        metadata_date_key: str,
+        **kwargs,
     ) -> list[dict[str, str]]:
         """
-        define dictionary filters based on metadata of retrieved nodes
+        Create filter dictionaries based on node metadata.
+
+        Each node is filtered by the values of the given metadata group keys and an optional date key.
+        Additional `AND` filters can also be provided.
 
         Parameters
         ----------
         nodes : list[dict[llama_index.schema.NodeWithScore]]
            a list of retrieved similar nodes to define filters based
+        metadata_group1_key : str
+            the first metadata key to filter on
+        metadata_group2_key : str
+            the second metadata key to filter on
+        **kwargs :
+            metadata_date_key : str
+                the date key in metadata
+                default is `date`
+            and_filters : dict[str, str]
+                additional `AND` filters to apply to every node's filter
 
         Returns
         ---------
@@ -85,16 +98,20 @@ def define_filters(
             the dictionary would be applying `and` operation between keys and values of json metadata_
         """
+        and_filters: dict[str, str] | None = kwargs.get("and_filters", None)
+        metadata_date_key: str = kwargs.get("metadata_date_key", "date")
         filters: list[dict[str, str]] = []
 
         for node in nodes:
-            # the filter made by given node
-            filter: dict[str, str] = {}
-            filter[metadata_group1_key] = node.metadata[metadata_group1_key]
-            filter[metadata_group2_key] = node.metadata[metadata_group2_key]
-            # date filter
-            filter[metadata_date_key] = node.metadata[metadata_date_key]
+            filter_dict: dict[str, str] = {
+                metadata_group1_key: node.metadata[metadata_group1_key],
+                metadata_group2_key: node.metadata[metadata_group2_key],
+                metadata_date_key: node.metadata[metadata_date_key],
+            }
+            # apply any additional `AND` filters
+            if and_filters:
+                filter_dict.update(and_filters)
 
-            filters.append(filter)
+            filters.append(filter_dict)
 
         return filters
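Usage sketch (illustrative, not part of the patch): how the reworked `define_filters` kwargs could be called. Here `retriever` is assumed to be an instance of the summary retriever class in this module, and the `channel`/`thread`/`date` metadata keys are placeholders.

    from llama_index.core.schema import NodeWithScore, TextNode

    node = NodeWithScore(
        node=TextNode(
            text="thread summary ...",
            metadata={"channel": "general", "thread": "t1", "date": "2024-04-09"},
        ),
        score=0.9,
    )
    # `metadata_date_key` now defaults to "date"; `and_filters` is merged into
    # every generated filter dictionary.
    filters = retriever.define_filters(
        [node],
        metadata_group1_key="channel",
        metadata_group2_key="thread",
        and_filters={"type": "thread"},
    )
    # -> [{"channel": "general", "thread": "t1", "date": "2024-04-09", "type": "thread"}]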
diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py
index 85db5e4..e7954f9 100644
--- a/bot/retrievers/retrieve_similar_nodes.py
+++ b/bot/retrievers/retrieve_similar_nodes.py
@@ -1,4 +1,5 @@
 from datetime import datetime, timedelta
+from uuid import uuid1
 
 from dateutil import parser
 from llama_index.core.data_structs import Node
@@ -7,7 +8,7 @@
 from llama_index.core.vector_stores.types import VectorStoreQueryResult
 from llama_index.vector_stores.postgres import PGVectorStore
 from llama_index.vector_stores.postgres.base import DBEmbeddingRow
-from sqlalchemy import Date, and_, cast, null, or_, select, text
+from sqlalchemy import Date, and_, cast, func, literal, null, or_, select, text
 from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
 
 
@@ -23,6 +24,7 @@ def __init__(
         """Init params."""
         self._vector_store = vector_store
         self._embed_model = embed_model
+        print(f"type(embed_model): {type(embed_model)} | embed_model: {embed_model}")
         self._similarity_top_k = similarity_top_k
 
     def query_db(
@@ -30,7 +32,7 @@ def query_db(
         query: str,
         filters: list[dict[str, str | dict | None]] | None = None,
         date_interval: int = 0,
-        **kwargs
+        **kwargs,
     ) -> list[NodeWithScore]:
         """
         query database with given filters (similarity search is also done)
@@ -55,26 +57,54 @@ def query_db(
            Note: This would completely disable the similarity search and
            it would just return the results with no ordering.
            default is `False`.
            If `True` the query will be ignored and no embedding of it would be fetched
+        aggregate_records : bool
+            whether to aggregate records, grouping them by the keys given in `group_by_metadata`
+        group_by_metadata : list[str]
+            the `metadata_` keys to group the records by
         """
         ignore_sort = kwargs.get("ignore_sort", False)
+        aggregate_records = kwargs.get("aggregate_records", False)
+        group_by_metadata = kwargs.get("group_by_metadata", [])
+        if not isinstance(group_by_metadata, list):
+            raise ValueError("Expected 'group_by_metadata' to be a list.")
+
         self._vector_store._initialize()
 
-        if not ignore_sort:
-            embedding = self._embed_model.get_text_embedding(text=query)
+        if not aggregate_records:
+            stmt = select(  # type: ignore
+                self._vector_store._table_class.id,
+                self._vector_store._table_class.node_id,
+                self._vector_store._table_class.text,
+                self._vector_store._table_class.metadata_,
+                (
+                    self._vector_store._table_class.embedding.cosine_distance(
+                        self._embed_model.get_text_embedding(text=query)
+                    )
+                    if not ignore_sort
+                    else null()
+                ).label("distance"),
+            )
         else:
-            embedding = None
-
-        stmt = select(  # type: ignore
-            self._vector_store._table_class.id,
-            self._vector_store._table_class.node_id,
-            self._vector_store._table_class.text,
-            self._vector_store._table_class.metadata_,
-            (
-                self._vector_store._table_class.embedding.cosine_distance(embedding)
-                if not ignore_sort
-                else null()
-            ).label("distance"),
-        )
+            # manually build the metadata of the aggregated record
+            metadata_grouping = []
+            for item in group_by_metadata:
+                metadata_grouping.append(item)
+                metadata_grouping.append(
+                    self._vector_store._table_class.metadata_.op("->>")(item)
+                )
+
+            stmt = select(
+                null().label("id"),
+                literal(str(uuid1())).label("node_id"),
+                func.aggregate_strings(
+                    # the default content key for llama-index nodes and documents
+                    # is `text`
+                    self._vector_store._table_class.text,
+                    "\n",
+                ).label("text"),
+                func.json_build_object(*metadata_grouping).label("metadata_"),
+                null().label("distance"),
+            )
 
         if not ignore_sort:
             stmt = stmt.order_by(text("distance asc"))
@@ -128,8 +158,15 @@ def query_db(
 
             stmt = stmt.where(or_(*conditions))
 
-        if self._similarity_top_k is not None:
-            stmt = stmt.limit(self._similarity_top_k)
+        if aggregate_records:
+            group_by_terms = [
+                self._vector_store._table_class.metadata_.op("->>")(item)
+                for item in group_by_metadata
+            ]
+            stmt = stmt.group_by(*group_by_terms)
+
+        if self._similarity_top_k is not None:
+            stmt = stmt.limit(self._similarity_top_k)
 
         with self._vector_store._session() as session, session.begin():
             res = session.execute(stmt)
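Usage sketch (illustrative, not part of the patch): the two modes `query_db` supports after this change. The `vector_store` and `embed_model` values are placeholders rather than verified setup code.

    retriever = RetrieveSimilarNodes(
        vector_store=vector_store,   # a PGVectorStore pointing at the data table
        embed_model=embed_model,     # e.g. CohereEmbedding from tc_hivemind_backend
        similarity_top_k=5,
    )

    # 1) similarity search, restricted by metadata filters and a date interval
    nodes = retriever.query_db(
        query="what was discussed about the release?",
        filters=[{"channel": "general", "date": "2024-04-09"}],
        date_interval=2,
    )

    # 2) aggregation: rows sharing the same `thread`, `date`, and `channel` metadata
    #    are concatenated into one synthetic node; similarity search is disabled
    summaries = retriever.query_db(
        query="",
        filters=[{"channel": "general", "date": "2024-04-09"}],
        ignore_sort=True,
        aggregate_records=True,
        group_by_metadata=["thread", "date", "channel"],
    )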
diff --git a/tests/integration/test_retrieve_similar_nodes.py b/tests/integration/test_retrieve_similar_nodes.py
deleted file mode 100644
index 0ac5348..0000000
--- a/tests/integration/test_retrieve_similar_nodes.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from unittest import TestCase
-from unittest.mock import MagicMock
-
-from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes
-from llama_index.core.schema import TextNode
-
-
-class TestRetrieveSimilarNodes(TestCase):
-    def setUp(self):
-        self.table_name = "sample_table"
-        self.dbname = "community_some_id"
-
-        self.vector_store = MagicMock()
-        self.embed_model = MagicMock()
-        self.retriever = RetrieveSimilarNodes(
-            vector_store=self.vector_store,
-            similarity_top_k=5,
-            embed_model=self.embed_model,
-        )
-
-    def test_init(self):
-        self.assertEqual(self.retriever._similarity_top_k, 5)
-        self.assertEqual(self.vector_store, self.retriever._vector_store)
-
-    def test_get_nodes_with_score(self):
-        # Test the _get_nodes_with_score private method
-        query_result = MagicMock()
-        query_result.nodes = [TextNode(), TextNode(), TextNode()]
-        query_result.similarities = [0.8, 0.9, 0.7]
-
-        result = self.retriever._get_nodes_with_score(query_result)
-
-        self.assertEqual(len(result), 3)
-        self.assertAlmostEqual(result[0].score, 0.8, delta=0.001)
diff --git a/tests/unit/test_retrieve_similar_nodes.py b/tests/unit/test_retrieve_similar_nodes.py
new file mode 100644
index 0000000..b9d13e2
--- /dev/null
+++ b/tests/unit/test_retrieve_similar_nodes.py
@@ -0,0 +1,94 @@
+from unittest import TestCase
+from unittest.mock import MagicMock, patch
+
+from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes
+from llama_index.core.schema import NodeWithScore, TextNode
+from llama_index.vector_stores.postgres import PGVectorStore
+
+
+class TestRetrieveSimilarNodes(TestCase):
+    def setUp(self):
+        self.table_name = "sample_table"
+        self.dbname = "community_some_id"
+
+        self.vector_store = PGVectorStore.from_params(
+            database="sample_db",
+            host="sample_host",
+            password="pass",
+            port=5432,
+            user="user",
+            table_name=self.table_name,
+            embed_dim=1536,
+        )
+        self.embed_model = MagicMock()
+        self.retriever = RetrieveSimilarNodes(
+            vector_store=self.vector_store,
+            similarity_top_k=5,
+            embed_model=self.embed_model,
+        )
+
+    def test_init(self):
+        self.assertEqual(self.retriever._similarity_top_k, 5)
+        self.assertEqual(self.vector_store, self.retriever._vector_store)
+
+    def test_get_nodes_with_score(self):
+        # Test the _get_nodes_with_score private method
+        query_result = MagicMock()
+        query_result.nodes = [TextNode(), TextNode(), TextNode()]
+        query_result.similarities = [0.8, 0.9, 0.7]
+
+        result = self.retriever._get_nodes_with_score(query_result)
+
+        self.assertEqual(len(result), 3)
+        self.assertAlmostEqual(result[0].score, 0.8, delta=0.001)
+
+    @patch.object(PGVectorStore, "_initialize")
+    @patch.object(PGVectorStore, "_session")
+    def test_query_db_with_filters_and_date(self, mock_session, mock_initialize):
+        # Mock vector store initialization
+        mock_initialize.return_value = None
+        mock_session.begin = MagicMock()
+        mock_session.execute = MagicMock()
+        mock_session.execute.return_value = [1]
+
+        query = "test query"
+        filters = [{"date": "2024-04-09"}]
+        date_interval = 2  # Look for nodes within 2 days of the filter date
+
+        # Call the query_db method with filters and date
+        results = self.retriever.query_db(query, filters, date_interval)
+
+        mock_initialize.assert_called_once()
+        mock_session.assert_called_once()
+
+        # Assert that the returned results are of type NodeWithScore
+        self.assertTrue(all(isinstance(result, NodeWithScore) for result in results))
+
+    @patch.object(PGVectorStore, "_initialize")
+    @patch.object(PGVectorStore, "_session")
+    def test_query_db_with_filters_and_date_aggregate_records(
+        self, mock_session, mock_initialize
+    ):
+        mock_initialize.return_value = None
+        mock_session.begin = MagicMock()
+        mock_session.execute = MagicMock()
+        mock_session.execute.return_value = [1]
+
+        query = "test query"
+        filters = [{"date": "2024-04-09"}]
+        date_interval = 2  # Look for nodes within 2 days of the filter date
+
+        # Call the query_db method with filters and date
+        results = self.retriever.query_db(
+            query,
+            filters,
+            date_interval,
+            aggregate_records=True,
+            group_by_metadata=["thread"],
+        )
+
+        mock_initialize.assert_called_once()
+        mock_session.assert_called_once()
+
+        # Assert that the returned results are of type NodeWithScore
+        self.assertTrue(all(isinstance(result, NodeWithScore) for result in results))
diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py
index dc64cf9..82dc7a7 100644
--- a/utils/query_engine/level_based_platform_query_engine.py
+++ b/utils/query_engine/level_based_platform_query_engine.py
@@ -47,7 +47,7 @@ def custom_query(self, query_str: str):
             query=query_str, filters=self._filters, date_interval=self._d
         )
-        context_str = self._prepare_context_str(similar_nodes, self.summary_nodes)
+        context_str = self._prepare_context_str(similar_nodes, summary_nodes=None)
         fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
 
         response = self.llm.complete(fmt_qa_prompt)
         logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}")
@@ -98,6 +98,12 @@ def prepare_platform_engine(
         index_summary : VectorStoreIndex
             the vector store index for summary data
             If not passed, it would just create one itself
+        summary_nodes_filters : list[dict[str, str]]
+            a list of filters for fetching the summary nodes
+            By default (when not passed) the previously prepared nodes are used,
+            but if passed, the summary nodes are re-fetched.
+            This can be beneficial when some manual
+            processing of the nodes is needed
 
         Returns
         ---------
@@ -115,6 +121,8 @@ def prepare_platform_engine(
             "index_raw",
             cls._setup_vector_store_index(platform_table_name, dbname, testing),
         )
+        summary_nodes_filters = kwargs.get("summary_nodes_filters", None)
+
         retriever = index.as_retriever()
         cls._summary_vector_store = kwargs.get(
             "index_summary",
@@ -130,6 +138,7 @@ def prepare_platform_engine(
 
         cls._similarity_top_k = similarity_top_k
         cls._filters = filters
+        cls._summary_nodes_filters = summary_nodes_filters
 
         return cls(
             retriever=retriever,
@@ -202,12 +211,20 @@ def prepare_engine_auto_filter(
             table_name=platform_table_name + "_summary", dbname=dbname
         )
 
-        filters = platform_retriever.define_filters(
+        raw_nodes_filters = platform_retriever.define_filters(
             nodes,
             metadata_group1_key=level1_key,
             metadata_group2_key=level2_key,
             metadata_date_key=date_key,
         )
+        summary_nodes_filters = platform_retriever.define_filters(
+            nodes,
+            metadata_group1_key=level1_key,
+            metadata_group2_key=level2_key,
+            metadata_date_key=date_key,
+            # we will always use the thread summaries
+            and_filters={"type": "thread"},
+        )
 
         # saving to add summaries to the context of prompt
         if include_summary_context:
@@ -222,18 +239,21 @@ def prepare_engine_auto_filter(
         cls._d = d
         cls._platform_table_name = platform_table_name
 
-        logging.debug(f"COMMUNITY_ID: {community_id} | summary filters: {filters}")
+        logging.debug(
+            f"COMMUNITY_ID: {community_id} | raw filters: {raw_nodes_filters}"
+        )
 
         engine = LevelBasedPlatformQueryEngine.prepare_platform_engine(
             community_id=community_id,
             platform_table_name=platform_table_name,
-            filters=filters,
+            filters=raw_nodes_filters,
             index_summary=index_summary,
+            summary_nodes_filters=summary_nodes_filters,
         )
         return engine
 
     def _prepare_context_str(
-        self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore]
+        self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore] | None
    ) -> str:
         """
         prepare the prompt context using the raw_nodes for answers and summary_nodes for additional information
@@ -248,6 +268,31 @@ def _prepare_context_str(
             context_str += self._utils_class.prepare_prompt_with_metadata_info(
                 nodes=raw_nodes
             )
+        elif summary_nodes is None:
+            retriever = RetrieveSimilarNodes(
+                self._summary_vector_store,
+                similarity_top_k=None,
+            )
+            # Note: `self._summary_nodes_filters` must be set beforehand
+            fetched_summary_nodes = retriever.query_db(
+                query="",
+                filters=self._summary_nodes_filters,
+                aggregate_records=True,
+                ignore_sort=True,
+                group_by_metadata=["thread", "date", "channel"],
+                date_interval=self._d,
+            )
+            grouped_summary_nodes = self._utils_class.group_nodes_per_metadata(
+                fetched_summary_nodes
+            )
+            grouped_raw_nodes = self._utils_class.group_nodes_per_metadata(raw_nodes)
+            context_data, (
+                summary_nodes_to_fetch_filters,
+                _,
+            ) = self._utils_class.prepare_context_str_based_on_summaries(
+                grouped_raw_nodes, grouped_summary_nodes
+            )
+            context_str += context_data
         else:
             # grouping the data we have so we could
             # get them per each metadata without looping over them
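Usage sketch (illustrative, not part of the patch): passing the new `summary_nodes_filters` kwarg directly to `prepare_platform_engine`. Only the parameters visible in this diff are shown; other required arguments may exist, and all values are placeholders.

    engine = LevelBasedPlatformQueryEngine.prepare_platform_engine(
        community_id="community_some_id",
        platform_table_name="discord",
        filters=[{"channel": "general", "thread": "t1", "date": "2024-04-09"}],
        # re-fetch thread summaries for the prompt context instead of reusing
        # previously prepared summary nodes
        summary_nodes_filters=[
            {"channel": "general", "thread": "t1", "date": "2024-04-09", "type": "thread"}
        ],
    )
    response = engine.query("what was discussed about the release?")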
diff --git a/utils/query_engine/level_based_platforms_util.py b/utils/query_engine/level_based_platforms_util.py
index 5574675..6d16d7f 100644
--- a/utils/query_engine/level_based_platforms_util.py
+++ b/utils/query_engine/level_based_platforms_util.py
@@ -63,6 +63,7 @@ def group_nodes_per_metadata(
             str | None, dict[str | None, dict[str, list[NodeWithScore]]]
         ] = {}
         for node in nodes:
+            # logging.info(f"node.metadata {node.metadata}")
             level1_title = node.metadata[self.level1_key]
             level2_title = node.metadata[self.level2_key]
             date_str = node.metadata[self.date_key]