From 0a83f60ec2da85b1e8966ad7e91f097023f97b4c Mon Sep 17 00:00:00 2001
From: Mohammad Amin
Date: Tue, 19 Mar 2024 13:51:34 +0330
Subject: [PATCH 01/11] feat: Added aggregation of summaries!

Still errors on retrieve_similar_nodes.py in the group_by case. We need to
fix that. Committing the work so it is saved.
---
 bot/retrievers/forum_summary_retriever.py | 18 ++++-
 bot/retrievers/retrieve_similar_nodes.py | 66 ++++++++++++++-----
 .../level_based_platform_query_engine.py | 54 +++++++++++++--
 .../level_based_platforms_util.py | 1 +
 4 files changed, 114 insertions(+), 25 deletions(-)

diff --git a/bot/retrievers/forum_summary_retriever.py b/bot/retrievers/forum_summary_retriever.py
index 1df52e7..c97d55f 100644
--- a/bot/retrievers/forum_summary_retriever.py
+++ b/bot/retrievers/forum_summary_retriever.py
@@ -69,6 +69,7 @@ def define_filters(
         metadata_group1_key: str,
         metadata_group2_key: str,
         metadata_date_key: str,
+        **kwargs,
     ) -> list[dict[str, str]]:
         """
         define dictionary filters based on metadata of retrieved nodes
@@ -77,6 +78,15 @@
         ----------
         nodes : list[dict[llama_index.schema.NodeWithScore]]
             a list of retrieved similar nodes to define filters based
+        metadata_group1_key : str
+            the metadata name 1 to use
+        metadata_group2_key : str
+            the metadata name 2 to use
+        metadata_date_key : str
+            the date key in metadata
+        **kwargs :
+            and_filters : dict[str, str]
+                more `AND` filters to be applied to each

         Returns
         ---------
@@ -85,16 +95,20 @@
             the dictionary would be applying `and`
             operation between keys and values of json metadata_
         """
+        and_filters: dict[str, str] | None = kwargs.get("and_filters", None)
         filters: list[dict[str, str]] = []

         for node in nodes:
-            # the filter made by given node
             filter: dict[str, str] = {}
             filter[metadata_group1_key] = node.metadata[metadata_group1_key]
             filter[metadata_group2_key] = node.metadata[metadata_group2_key]
-            # date filter
             filter[metadata_date_key] = node.metadata[metadata_date_key]

+            # if more and filters were given
+            if and_filters:
+                for key, value in and_filters.items():
+                    filter[key] = value
+
             filters.append(filter)

         return filters
diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py
index 85db5e4..71e8752 100644
--- a/bot/retrievers/retrieve_similar_nodes.py
+++ b/bot/retrievers/retrieve_similar_nodes.py
@@ -7,7 +7,7 @@
 from llama_index.core.vector_stores.types import VectorStoreQueryResult
 from llama_index.vector_stores.postgres import PGVectorStore
 from llama_index.vector_stores.postgres.base import DBEmbeddingRow
-from sqlalchemy import Date, and_, cast, null, or_, select, text
+from sqlalchemy import Date, and_, cast, null, or_, select, text, func
 from tc_hivemind_backend.embeddings.cohere import CohereEmbedding


@@ -55,26 +55,51 @@ def query_db(
             Note: This would completely disable the similarity search and
             it would just return the results with no ordering.
             default is `False`.
If `True` the query will be ignored and no embedding of it would be fetched + aggregate_records : bool + aggregate records and group by a given term in `group_by_metadata` + group_by_metadata : list[str] + do grouping by some property of `metadata_` """ ignore_sort = kwargs.get("ignore_sort", False) + aggregate_records = kwargs.get("aggregate_records", False) + group_by_metadata = kwargs.get("group_by_metadata") self._vector_store._initialize() - if not ignore_sort: - embedding = self._embed_model.get_text_embedding(text=query) + if not aggregate_records: + stmt = select( # type: ignore + self._vector_store._table_class.id, + self._vector_store._table_class.node_id, + self._vector_store._table_class.text, + self._vector_store._table_class.metadata_, + ( + self._vector_store._table_class.embedding.cosine_distance( + self._embed_model.get_text_embedding(text=query) + ) + if not ignore_sort + else null() + ).label("distance"), + ) else: - embedding = None - - stmt = select( # type: ignore - self._vector_store._table_class.id, - self._vector_store._table_class.node_id, - self._vector_store._table_class.text, - self._vector_store._table_class.metadata_, - ( - self._vector_store._table_class.embedding.cosine_distance(embedding) - if not ignore_sort - else null() - ).label("distance"), - ) + # manually creating metadata + metadata_grouping = [] + for item in group_by_metadata: + metadata_grouping.append(item) + metadata_grouping.append( + self._vector_store._table_class.metadata_.op("->>")(item) + ) + + stmt = select( + null().label("id"), + null().label("node_id"), + func.aggregate_strings( + # default content key for llama-index nodes and documents + # is `text` + self._vector_store._table_class.metadata_.op("->>")("text"), + "\n", + ).label("text"), + func.json_agg(func.json_build_object(metadata_grouping)), + null().label("embedding"), + ) if not ignore_sort: stmt = stmt.order_by(text("distance asc")) @@ -128,8 +153,13 @@ def query_db( stmt = stmt.where(or_(*conditions)) - if self._similarity_top_k is not None: - stmt = stmt.limit(self._similarity_top_k) + if self._similarity_top_k is not None: + if aggregate_records: + stmt.group_by( + self._vector_store._table_class.metadata_.op("->>")(item) + for item in group_by_metadata + ) + stmt = stmt.limit(self._similarity_top_k) with self._vector_store._session() as session, session.begin(): res = session.execute(stmt) diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index dc64cf9..5e713e2 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -47,7 +47,7 @@ def custom_query(self, query_str: str): query=query_str, filters=self._filters, date_interval=self._d ) - context_str = self._prepare_context_str(similar_nodes, self.summary_nodes) + context_str = self._prepare_context_str(similar_nodes, summary_nodes=None) fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str) response = self.llm.complete(fmt_qa_prompt) logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}") @@ -98,6 +98,12 @@ def prepare_platform_engine( index_summary : VectorStoreIndex the vector store index for summary data If not passed, it would just create one itself + summary_nodes_filters : list[dict[str, str]] + a list of filters to fetch the summary nodes + for default, not passing this would mean to use previous nodes + but if passed we would re-fetch nodes. 
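Stepping back from the diff for a moment: the aggregation branch introduced above is easier to read in isolation. Below is a minimal, self-contained sketch of the statement shape this series eventually converges on (after the fixes in later commits), using a hypothetical table in place of the vector store's `_table_class`; it assumes SQLAlchemy 2.0, whose `func.aggregate_strings` renders as `string_agg` on PostgreSQL.

```python
from sqlalchemy import Column, Integer, MetaData, Table, Text, func, null, select
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import JSONB

# hypothetical stand-in for self._vector_store._table_class
data = Table(
    "data_summary",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("text", Text),
    Column("metadata_", JSONB),
)

group_by_metadata = ["thread", "date"]

# json_build_object takes alternating key, value arguments
metadata_grouping: list = []
for item in group_by_metadata:
    metadata_grouping.append(item)
    metadata_grouping.append(data.c.metadata_.op("->>")(item))

stmt = select(
    null().label("id"),
    # concatenate the texts of all rows in a group, newline separated
    func.aggregate_strings(data.c.text, "\n").label("text"),
    # one JSON object per group, holding the grouped metadata keys
    func.json_build_object(*metadata_grouping).label("metadata_"),
).group_by(*(data.c.metadata_.op("->>")(item) for item in group_by_metadata))

print(stmt.compile(dialect=postgresql.dialect()))
```

Note the `*` unpacking of `metadata_grouping`: passing the list unexpanded, as this first commit does, would hand `json_build_object` a single argument and fail once the statement is executed.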
+            This could be beneficial in case we want to do some manual
+            processing with nodes

         Returns
         ---------
         query_engine : BaseQueryEngine
             the created query engine with the filters
@@ -115,6 +121,8 @@ def prepare_platform_engine(
             "index_raw",
             cls._setup_vector_store_index(platform_table_name, dbname, testing),
         )
+        summary_nodes_filters = kwargs.get("summary_nodes_filters", None)
+
         retriever = index.as_retriever()
         cls._summary_vector_store = kwargs.get(
             "index_summary",
@@ -130,6 +138,7 @@

         cls._similarity_top_k = similarity_top_k
         cls._filters = filters
+        cls._summary_nodes_filters = summary_nodes_filters

         return cls(
             retriever=retriever,
@@ -202,12 +211,20 @@ def prepare_engine_auto_filter(
             table_name=platform_table_name + "_summary", dbname=dbname
         )

-        filters = platform_retriever.define_filters(
+        raw_nodes_filters = platform_retriever.define_filters(
             nodes,
             metadata_group1_key=level1_key,
             metadata_group2_key=level2_key,
             metadata_date_key=date_key,
         )
+        summary_nodes_filters = platform_retriever.define_filters(
+            nodes,
+            metadata_group1_key=level1_key,
+            metadata_group2_key=level2_key,
+            metadata_date_key=date_key,
+            # we will always use thread summaries
+            and_filters={"type": "thread"},
+        )

         # saving to add summaries to the context of prompt
         if include_summary_context:
@@ -222,18 +239,21 @@
         cls._d = d
         cls._platform_table_name = platform_table_name

-        logging.debug(f"COMMUNITY_ID: {community_id} | summary filters: {filters}")
+        logging.debug(
+            f"COMMUNITY_ID: {community_id} | raw filters: {raw_nodes_filters}"
+        )
         engine = LevelBasedPlatformQueryEngine.prepare_platform_engine(
             community_id=community_id,
             platform_table_name=platform_table_name,
-            filters=filters,
+            filters=raw_nodes_filters,
             index_summary=index_summary,
+            summary_nodes_filters=summary_nodes_filters,
         )
         return engine

     def _prepare_context_str(
-        self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore]
+        self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore] | None
    ) -> str:
         """
         prepare the prompt context using the raw_nodes for answers and summary_nodes for additional information
@@ -248,6 +268,30 @@
             context_str += self._utils_class.prepare_prompt_with_metadata_info(
                 nodes=raw_nodes
             )
+        elif summary_nodes is None:
+            retriever = RetrieveSimilarNodes(
+                self._summary_vector_store,
+                similarity_top_k=None,
+            )
+            # Note: `self._summary_nodes_filters` must be set beforehand
+            fetched_summary_nodes = retriever.query_db(
+                query="",
+                filters=self._summary_nodes_filters,
+                aggregate_records=True,
+                group_by_metadata=["thread", "date"],
+                date_interval=self._d,
+            )
+            grouped_summary_nodes = self._utils_class.group_nodes_per_metadata(
+                fetched_summary_nodes
+            )
+            grouped_raw_nodes = self._utils_class.group_nodes_per_metadata(raw_nodes)
+            context_data, (
+                summary_nodes_to_fetch_filters,
+                _,
+            ) = self._utils_class.prepare_context_str_based_on_summaries(
+                grouped_raw_nodes, grouped_summary_nodes
+            )
+            context_str += context_data
         else:
             # grouping the data we have so we could
             # get them per each metadata without looping over them
diff --git a/utils/query_engine/level_based_platforms_util.py b/utils/query_engine/level_based_platforms_util.py
index 5574675..ecbc9bd 100644
--- a/utils/query_engine/level_based_platforms_util.py
+++ b/utils/query_engine/level_based_platforms_util.py
@@ -63,6 +63,7 @@ def group_nodes_per_metadata(
             str | None, dict[str | None, dict[str, list[NodeWithScore]]]
         ] = {}
         for node in nodes:
+            logging.info(f"node.metadata {node.metadata}")
             level1_title = node.metadata[self.level1_key]
             level2_title = node.metadata[self.level2_key]
             date_str = node.metadata[self.date_key]

From a444755f8d36188b421bfa72363362858b97c00f Mon Sep 17 00:00:00 2001
From: Mohammad Amin
Date: Wed, 10 Apr 2024 15:41:15 +0330
Subject: [PATCH 02/11] fix: the query to hold the right information!

---
 bot/retrievers/retrieve_similar_nodes.py | 33 ++++++++++++-------
 .../level_based_platform_query_engine.py | 3 +-
 .../level_based_platforms_util.py | 2 +-
 3 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py
index 71e8752..70ef409 100644
--- a/bot/retrievers/retrieve_similar_nodes.py
+++ b/bot/retrievers/retrieve_similar_nodes.py
@@ -7,8 +7,9 @@
 from llama_index.core.vector_stores.types import VectorStoreQueryResult
 from llama_index.vector_stores.postgres import PGVectorStore
 from llama_index.vector_stores.postgres.base import DBEmbeddingRow
-from sqlalchemy import Date, and_, cast, null, or_, select, text, func
+from sqlalchemy import Date, and_, cast, null, or_, select, text, literal, func
 from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
+from uuid import uuid1


 class RetrieveSimilarNodes:
@@ -80,7 +81,7 @@ def query_db(
                 ).label("distance"),
             )
         else:
-            # manually creating metadata
+            # to manually create metadata
             metadata_grouping = []
             for item in group_by_metadata:
                 metadata_grouping.append(item)
@@ -90,15 +91,17 @@

             stmt = select(
                 null().label("id"),
-                null().label("node_id"),
+                literal(str(uuid1())).label("node_id"),
                 func.aggregate_strings(
                     # default content key for llama-index nodes and documents
                     # is `text`
-                    self._vector_store._table_class.metadata_.op("->>")("text"),
+                    self._vector_store._table_class.text,
                     "\n",
                 ).label("text"),
-                func.json_agg(func.json_build_object(metadata_grouping)),
-                null().label("embedding"),
+                func.json_agg(func.json_build_object(*metadata_grouping)).label(
+                    "metadata_"
+                ),
+                null().label("distance"),
             )

         if not ignore_sort:
@@ -153,12 +156,14 @@

         stmt = stmt.where(or_(*conditions))

+        if aggregate_records:
+            group_by_terms = [
+                self._vector_store._table_class.metadata_.op("->>")(item)
+                for item in group_by_metadata
+            ]
+            stmt = stmt.group_by(*group_by_terms)
+
         if self._similarity_top_k is not None:
-            if aggregate_records:
-                stmt.group_by(
-                    self._vector_store._table_class.metadata_.op("->>")(item)
-                    for item in group_by_metadata
-                )
             stmt = stmt.limit(self._similarity_top_k)

         with self._vector_store._session() as session, session.begin():
@@ -168,7 +173,11 @@
             DBEmbeddingRow(
                 node_id=item.node_id,
                 text=item.text,
-                metadata=item.metadata_,
+                # in case of aggregation having null values
+                # the metadata might have duplicate dates
+                # so always using the first index will make it right
+                # in this case, the metadata should always be the same as the group_by data
+                metadata=item.metadata_ if not aggregate_records else item.metadata_[0],
                 similarity=(1 - item.distance) if item.distance is not None else 0,
             )
             for item in res.all()
diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py
index 5e713e2..82dc7a7 100644
--- a/utils/query_engine/level_based_platform_query_engine.py
+++ b/utils/query_engine/level_based_platform_query_engine.py
@@ -278,7 +278,8 @@
                 query="",
                 filters=self._summary_nodes_filters,
                 aggregate_records=True,
-                group_by_metadata=["thread", "date"],
+                ignore_sort=True,
+                group_by_metadata=["thread", "date", "channel"],
                 date_interval=self._d,
             )
             grouped_summary_nodes = self._utils_class.group_nodes_per_metadata(
diff --git a/utils/query_engine/level_based_platforms_util.py b/utils/query_engine/level_based_platforms_util.py
index ecbc9bd..6d16d7f 100644
--- a/utils/query_engine/level_based_platforms_util.py
+++ b/utils/query_engine/level_based_platforms_util.py
@@ -63,7 +63,7 @@ def group_nodes_per_metadata(
             str | None, dict[str | None, dict[str, list[NodeWithScore]]]
         ] = {}
         for node in nodes:
-            logging.info(f"node.metadata {node.metadata}")
+            # logging.info(f"node.metadata {node.metadata}")
             level1_title = node.metadata[self.level1_key]
             level2_title = node.metadata[self.level2_key]
             date_str = node.metadata[self.date_key]

From 95a6898cc78981e55f92cd04ee9d06b8a364d31f Mon Sep 17 00:00:00 2001
From: Mohammad Amin
Date: Wed, 10 Apr 2024 16:05:37 +0330
Subject: [PATCH 03/11] fix: linter issues!

---
 bot/retrievers/retrieve_similar_nodes.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py
index 70ef409..87d2a89 100644
--- a/bot/retrievers/retrieve_similar_nodes.py
+++ b/bot/retrievers/retrieve_similar_nodes.py
@@ -1,4 +1,5 @@
 from datetime import datetime, timedelta
+from uuid import uuid1

 from dateutil import parser
 from llama_index.core.data_structs import Node
@@ -7,9 +8,8 @@
 from llama_index.core.vector_stores.types import VectorStoreQueryResult
 from llama_index.vector_stores.postgres import PGVectorStore
 from llama_index.vector_stores.postgres.base import DBEmbeddingRow
-from sqlalchemy import Date, and_, cast, null, or_, select, text, literal, func
+from sqlalchemy import Date, and_, cast, func, literal, null, or_, select, text
 from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
-from uuid import uuid1


 class RetrieveSimilarNodes:
@@ -63,7 +63,7 @@ def query_db(
         """
         ignore_sort = kwargs.get("ignore_sort", False)
         aggregate_records = kwargs.get("aggregate_records", False)
-        group_by_metadata = kwargs.get("group_by_metadata")
+        group_by_metadata = kwargs.get("group_by_metadata", [])

         self._vector_store._initialize()
         if not aggregate_records:

From 0fdd128254841cdcf39ac586fb14fd8c3bc5ef09 Mon Sep 17 00:00:00 2001
From: Mohammad Amin
Date: Thu, 11 Apr 2024 08:57:38 +0330
Subject: [PATCH 04/11] feat: query improvements!

Getting the metadata_ in a way that no duplicate data can occur.
---
 bot/retrievers/retrieve_similar_nodes.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py
index 87d2a89..f88137b 100644
--- a/bot/retrievers/retrieve_similar_nodes.py
+++ b/bot/retrievers/retrieve_similar_nodes.py
@@ -98,7 +98,7 @@ def query_db(
                     self._vector_store._table_class.text,
                     "\n",
                 ).label("text"),
-                func.json_agg(func.json_build_object(*metadata_grouping)).label(
+                func.json_build_object(*metadata_grouping).label(
                     "metadata_"
                 ),
                 null().label("distance"),
             )
@@ -173,11 +173,7 @@
             DBEmbeddingRow(
                 node_id=item.node_id,
                 text=item.text,
-                # in case of aggregation having null values
-                # the metadata might have duplicate dates
-                # so always using the first index will make it right
-                # in this case, the metadata should always be the same as the group_by data
-                metadata=item.metadata_ if not aggregate_records else item.metadata_[0],
+                metadata=item.metadata_,
                 similarity=(1 - item.distance) if item.distance is not None else 0,
             )
             for item in res.all()

From 8509948f3bf20dc5bd687c695cff05845e9bcf5b Mon Sep 17 00:00:00 2001
From: Mohammad Amin
Date: Thu, 11 Apr 2024 09:20:13 +0330
Subject: [PATCH 05/11] fix: black linter issue!

---
 bot/retrievers/retrieve_similar_nodes.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py
index f88137b..4a41087 100644
--- a/bot/retrievers/retrieve_similar_nodes.py
+++ b/bot/retrievers/retrieve_similar_nodes.py
@@ -98,9 +98,7 @@ def query_db(
                     self._vector_store._table_class.text,
                     "\n",
                 ).label("text"),
-                func.json_build_object(*metadata_grouping).label(
-                    "metadata_"
-                ),
+                func.json_build_object(*metadata_grouping).label("metadata_"),
                 null().label("distance"),
             )

From 87bf092ce0d14cf6c15cb8feb43a5da88fae25ca Mon Sep 17 00:00:00 2001
From: Mohammad Amin
Date: Thu, 11 Apr 2024 09:40:37 +0330
Subject: [PATCH 06/11] fix: trying to fix codeClimate issues!

---
 bot/retrievers/forum_summary_retriever.py | 27 +++++++++++++----------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/bot/retrievers/forum_summary_retriever.py b/bot/retrievers/forum_summary_retriever.py
index c97d55f..6743dd3 100644
--- a/bot/retrievers/forum_summary_retriever.py
+++ b/bot/retrievers/forum_summary_retriever.py
@@ -68,11 +68,13 @@
         nodes: list[NodeWithScore],
         metadata_group1_key: str,
         metadata_group2_key: str,
-        metadata_date_key: str,
         **kwargs,
     ) -> list[dict[str, str]]:
         """
-        define dictionary filters based on metadata of retrieved nodes
+        Creates filter dictionaries based on node metadata.
+
+        Filters each node by values in specified metadata groups and an optional date key.
+        Additional `AND` filters can also be provided.
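As a usage sketch before the parameter details below: the refactored method builds one filter dict per node and merges the optional `AND` filters into each. A standalone toy version (plain dicts instead of `NodeWithScore` objects, with illustrative key names):

```python
def define_filters(nodes_metadata, group1_key, group2_key, **kwargs):
    # standalone mirror of the refactor above, operating on plain dicts
    and_filters = kwargs.get("and_filters", None)
    date_key = kwargs.get("metadata_date_key", "date")
    filters = []
    for metadata in nodes_metadata:
        filter_dict = {
            group1_key: metadata[group1_key],
            group2_key: metadata[group2_key],
            date_key: metadata[date_key],
        }
        if and_filters:
            filter_dict.update(and_filters)
        filters.append(filter_dict)
    return filters

print(
    define_filters(
        [{"channel": "general", "thread": "t1", "date": "2024-04-09"}],
        "channel",
        "thread",
        and_filters={"type": "thread"},
    )
)
# [{'channel': 'general', 'thread': 't1', 'date': '2024-04-09', 'type': 'thread'}]
```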
Parameters ---------- @@ -82,9 +84,10 @@ def define_filters( the metadata name 1 to use metadata_group2_key : str the metadata name 2 to use - metadata_date_key : str - the date key in metadata **kwargs : + metadata_date_key : str + the date key in metadata + default is `date` and_filters : dict[str, str] more `AND` filters to be applied to each @@ -96,19 +99,19 @@ def define_filters( operation between keys and values of json metadata_ """ and_filters: dict[str, str] | None = kwargs.get("and_filters", None) + metadata_date_key: str = kwargs.get("metadata_date_key", "date") filters: list[dict[str, str]] = [] for node in nodes: - filter: dict[str, str] = {} - filter[metadata_group1_key] = node.metadata[metadata_group1_key] - filter[metadata_group2_key] = node.metadata[metadata_group2_key] - filter[metadata_date_key] = node.metadata[metadata_date_key] - + filter_dict: dict[str, str] = { + metadata_group1_key: node.metadata[metadata_group1_key], + metadata_group2_key: node.metadata[metadata_group2_key], + metadata_date_key: node.metadata[metadata_date_key], + } # if more and filters were given if and_filters: - for key, value in and_filters.items(): - filter[key] = value + filter_dict.update(and_filters) - filters.append(filter) + filters.append(filter_dict) return filters From 0135d0cae878a7cb43f2eb3e249e6a9047902948 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 11 Apr 2024 10:29:18 +0330 Subject: [PATCH 07/11] feat: increase test case coverage! note: this test case was just testing each module and no db conection was made, so we moved it to unit tests. --- .../test_retrieve_similar_nodes.py | 34 ------ tests/unit/test_retrieve_similar_nodes.py | 110 ++++++++++++++++++ 2 files changed, 110 insertions(+), 34 deletions(-) delete mode 100644 tests/integration/test_retrieve_similar_nodes.py create mode 100644 tests/unit/test_retrieve_similar_nodes.py diff --git a/tests/integration/test_retrieve_similar_nodes.py b/tests/integration/test_retrieve_similar_nodes.py deleted file mode 100644 index 0ac5348..0000000 --- a/tests/integration/test_retrieve_similar_nodes.py +++ /dev/null @@ -1,34 +0,0 @@ -from unittest import TestCase -from unittest.mock import MagicMock - -from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes -from llama_index.core.schema import TextNode - - -class TestRetrieveSimilarNodes(TestCase): - def setUp(self): - self.table_name = "sample_table" - self.dbname = "community_some_id" - - self.vector_store = MagicMock() - self.embed_model = MagicMock() - self.retriever = RetrieveSimilarNodes( - vector_store=self.vector_store, - similarity_top_k=5, - embed_model=self.embed_model, - ) - - def test_init(self): - self.assertEqual(self.retriever._similarity_top_k, 5) - self.assertEqual(self.vector_store, self.retriever._vector_store) - - def test_get_nodes_with_score(self): - # Test the _get_nodes_with_score private method - query_result = MagicMock() - query_result.nodes = [TextNode(), TextNode(), TextNode()] - query_result.similarities = [0.8, 0.9, 0.7] - - result = self.retriever._get_nodes_with_score(query_result) - - self.assertEqual(len(result), 3) - self.assertAlmostEqual(result[0].score, 0.8, delta=0.001) diff --git a/tests/unit/test_retrieve_similar_nodes.py b/tests/unit/test_retrieve_similar_nodes.py new file mode 100644 index 0000000..c2ab39f --- /dev/null +++ b/tests/unit/test_retrieve_similar_nodes.py @@ -0,0 +1,110 @@ +from unittest import TestCase +from unittest.mock import MagicMock +from unittest.mock import patch + +from 
bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes +from llama_index.vector_stores.postgres import PGVectorStore +from llama_index.core.schema import NodeWithScore, TextNode + + +class TestRetrieveSimilarNodes(TestCase): + def setUp(self): + self.table_name = "sample_table" + self.dbname = "community_some_id" + + self.vector_store = MagicMock() + self.embed_model = MagicMock() + self.retriever = RetrieveSimilarNodes( + vector_store=self.vector_store, + similarity_top_k=5, + embed_model=self.embed_model, + ) + + def test_init(self): + self.assertEqual(self.retriever._similarity_top_k, 5) + self.assertEqual(self.vector_store, self.retriever._vector_store) + + def test_get_nodes_with_score(self): + # Test the _get_nodes_with_score private method + query_result = MagicMock() + query_result.nodes = [TextNode(), TextNode(), TextNode()] + query_result.similarities = [0.8, 0.9, 0.7] + + result = self.retriever._get_nodes_with_score(query_result) + + self.assertEqual(len(result), 3) + self.assertAlmostEqual(result[0].score, 0.8, delta=0.001) + + @patch.object(PGVectorStore, "_initialize") + @patch.object(PGVectorStore, "_session") + def test_query_db_with_filters_and_date(self, mock_session, mock_initialize): + # Mock vector store initialization + mock_initialize.return_value = None + mock_session.begin = MagicMock() + mock_session.execute = MagicMock() + mock_session.execute.return_value = [1] + + vector_store = PGVectorStore.from_params( + database="sample_db", + host="sample_host", + password="pass", + port=5432, + user="user", + table_name=self.table_name, + embed_dim=1536, + ) + retrieve_similar_nodes = RetrieveSimilarNodes(vector_store, similarity_top_k=5) + + query = "test query" + filters = [{"date": "2024-04-09"}] + date_interval = 2 # Look for nodes within 2 days of the filter date + + # Call the query_db method with filters and date + results = retrieve_similar_nodes.query_db(query, filters, date_interval) + + mock_initialize.assert_called_once() + mock_session.assert_called_once() + + # Assert that the returned results are of type NodeWithScore + self.assertTrue(isinstance(result, NodeWithScore) for result in results) + + @patch.object(PGVectorStore, "_initialize") + @patch.object(PGVectorStore, "_session") + def test_query_db_with_filters_and_date_aggregate_records( + self, mock_session, mock_initialize + ): + # Mock vector store initialization + mock_initialize.return_value = None + mock_session.begin = MagicMock() + mock_session.execute = MagicMock() + mock_session.execute.return_value = [1] + + vector_store = PGVectorStore.from_params( + database="sample_db", + host="sample_host", + password="pass", + port=5432, + user="user", + table_name=self.table_name, + embed_dim=1536, + ) + retrieve_similar_nodes = RetrieveSimilarNodes(vector_store, similarity_top_k=5) + + query = "test query" + filters = [{"date": "2024-04-09"}] + date_interval = 2 # Look for nodes within 2 days of the filter date + + # Call the query_db method with filters and date + results = retrieve_similar_nodes.query_db( + query, + filters, + date_interval, + aggregate_records=True, + group_by_metadata=["thread"], + ) + + mock_initialize.assert_called_once() + mock_session.assert_called_once() + + # Assert that the returned results are of type NodeWithScore + self.assertTrue(isinstance(result, NodeWithScore) for result in results) From 3ece4f350093c3954c205a2d3d52376161cd894b Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 11 Apr 2024 10:48:13 +0330 Subject: [PATCH 08/11] fix: linter issues and 
test case! --- bot/retrievers/retrieve_similar_nodes.py | 3 +- tests/unit/test_retrieve_similar_nodes.py | 42 +++++++---------------- 2 files changed, 15 insertions(+), 30 deletions(-) diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index 4a41087..678b502 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -24,6 +24,7 @@ def __init__( """Init params.""" self._vector_store = vector_store self._embed_model = embed_model + print(f"type(embed_model): {type(embed_model)} | embed_model: {embed_model}") self._similarity_top_k = similarity_top_k def query_db( @@ -31,7 +32,7 @@ def query_db( query: str, filters: list[dict[str, str | dict | None]] | None = None, date_interval: int = 0, - **kwargs + **kwargs, ) -> list[NodeWithScore]: """ query database with given filters (similarity search is also done) diff --git a/tests/unit/test_retrieve_similar_nodes.py b/tests/unit/test_retrieve_similar_nodes.py index c2ab39f..b9d13e2 100644 --- a/tests/unit/test_retrieve_similar_nodes.py +++ b/tests/unit/test_retrieve_similar_nodes.py @@ -1,10 +1,9 @@ from unittest import TestCase -from unittest.mock import MagicMock -from unittest.mock import patch +from unittest.mock import MagicMock, patch from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes -from llama_index.vector_stores.postgres import PGVectorStore from llama_index.core.schema import NodeWithScore, TextNode +from llama_index.vector_stores.postgres import PGVectorStore class TestRetrieveSimilarNodes(TestCase): @@ -12,7 +11,15 @@ def setUp(self): self.table_name = "sample_table" self.dbname = "community_some_id" - self.vector_store = MagicMock() + self.vector_store = PGVectorStore.from_params( + database="sample_db", + host="sample_host", + password="pass", + port=5432, + user="user", + table_name=self.table_name, + embed_dim=1536, + ) self.embed_model = MagicMock() self.retriever = RetrieveSimilarNodes( vector_store=self.vector_store, @@ -44,23 +51,12 @@ def test_query_db_with_filters_and_date(self, mock_session, mock_initialize): mock_session.execute = MagicMock() mock_session.execute.return_value = [1] - vector_store = PGVectorStore.from_params( - database="sample_db", - host="sample_host", - password="pass", - port=5432, - user="user", - table_name=self.table_name, - embed_dim=1536, - ) - retrieve_similar_nodes = RetrieveSimilarNodes(vector_store, similarity_top_k=5) - query = "test query" filters = [{"date": "2024-04-09"}] date_interval = 2 # Look for nodes within 2 days of the filter date # Call the query_db method with filters and date - results = retrieve_similar_nodes.query_db(query, filters, date_interval) + results = self.retriever.query_db(query, filters, date_interval) mock_initialize.assert_called_once() mock_session.assert_called_once() @@ -73,29 +69,17 @@ def test_query_db_with_filters_and_date(self, mock_session, mock_initialize): def test_query_db_with_filters_and_date_aggregate_records( self, mock_session, mock_initialize ): - # Mock vector store initialization mock_initialize.return_value = None mock_session.begin = MagicMock() mock_session.execute = MagicMock() mock_session.execute.return_value = [1] - vector_store = PGVectorStore.from_params( - database="sample_db", - host="sample_host", - password="pass", - port=5432, - user="user", - table_name=self.table_name, - embed_dim=1536, - ) - retrieve_similar_nodes = RetrieveSimilarNodes(vector_store, similarity_top_k=5) - query = "test query" filters = [{"date": 
"2024-04-09"}] date_interval = 2 # Look for nodes within 2 days of the filter date # Call the query_db method with filters and date - results = retrieve_similar_nodes.query_db( + results = self.retriever.query_db( query, filters, date_interval, From b142f6244186bb8d3c1a9d4c742eab7e14f0755c Mon Sep 17 00:00:00 2001 From: Mohammad Amin Dadgar <48308230+amindadgar@users.noreply.github.com> Date: Thu, 11 Apr 2024 10:53:28 +0330 Subject: [PATCH 09/11] Update bot/retrievers/retrieve_similar_nodes.py Add group_by_metadata validation Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- bot/retrievers/retrieve_similar_nodes.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index 678b502..ddd80c5 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -57,14 +57,11 @@ def query_db( Note: This would completely disable the similarity search and it would just return the results with no ordering. default is `False`. If `True` the query will be ignored and no embedding of it would be fetched - aggregate_records : bool - aggregate records and group by a given term in `group_by_metadata` - group_by_metadata : list[str] - do grouping by some property of `metadata_` - """ ignore_sort = kwargs.get("ignore_sort", False) aggregate_records = kwargs.get("aggregate_records", False) group_by_metadata = kwargs.get("group_by_metadata", []) + if not isinstance(group_by_metadata, list): + raise ValueError("Expected 'group_by_metadata' to be a list.") self._vector_store._initialize() if not aggregate_records: From ce688e5f0c1860b021e7efb0d24b5814f9d1f103 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 11 Apr 2024 11:00:14 +0330 Subject: [PATCH 10/11] fix: codeRabbitAI's mistake! --- bot/retrievers/retrieve_similar_nodes.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index ddd80c5..c88e178 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -57,11 +57,17 @@ def query_db( Note: This would completely disable the similarity search and it would just return the results with no ordering. default is `False`. If `True` the query will be ignored and no embedding of it would be fetched + aggregate_records : bool + aggregate records and group by a given term in `group_by_metadata` + group_by_metadata : list[str] + do grouping by some property of `metadata_` + """ ignore_sort = kwargs.get("ignore_sort", False) aggregate_records = kwargs.get("aggregate_records", False) group_by_metadata = kwargs.get("group_by_metadata", []) if not isinstance(group_by_metadata, list): raise ValueError("Expected 'group_by_metadata' to be a list.") + self._vector_store._initialize() if not aggregate_records: From 24fd48fe841447f0952da7f4881b7017b4195e0a Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 11 Apr 2024 11:04:58 +0330 Subject: [PATCH 11/11] fix: linter issues! 
--- bot/retrievers/retrieve_similar_nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index c88e178..e7954f9 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -67,7 +67,7 @@ def query_db( group_by_metadata = kwargs.get("group_by_metadata", []) if not isinstance(group_by_metadata, list): raise ValueError("Expected 'group_by_metadata' to be a list.") - + self._vector_store._initialize() if not aggregate_records:
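Taken together, the series lets a caller fetch thread-grouped summary aggregates in a single query. A sketch of the final call path, mirroring the unit tests above (the connection parameters are the tests' placeholders; actually running this would need a reachable Postgres and embedding credentials, since `RetrieveSimilarNodes` falls back to its default embedding model when none is passed):

```python
from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes
from llama_index.vector_stores.postgres import PGVectorStore

vector_store = PGVectorStore.from_params(
    database="sample_db",
    host="sample_host",
    password="pass",
    port=5432,
    user="user",
    table_name="sample_table",
    embed_dim=1536,
)
retriever = RetrieveSimilarNodes(vector_store, similarity_top_k=None)

# one aggregated node per (thread, date, channel) group; the query text is
# unused here because ignore_sort disables the embedding-based ordering
summary_nodes = retriever.query_db(
    query="",
    filters=[{"date": "2024-04-09", "type": "thread"}],
    date_interval=2,
    ignore_sort=True,
    aggregate_records=True,
    group_by_metadata=["thread", "date", "channel"],
)
```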