From 042684e1db79593c17d53486d9fbd8279aff124f Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Wed, 31 Jan 2024 13:16:06 +0330 Subject: [PATCH 1/8] feat: Added manual filtering node fetching! - We can now fetch similar nodes by manually doing metadata filterings on database! --- bot/retrievers/retrieve_similar_nodes.py | 92 ++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 bot/retrievers/retrieve_similar_nodes.py diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py new file mode 100644 index 0000000..1906d0f --- /dev/null +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -0,0 +1,92 @@ +from llama_index.embeddings import BaseEmbedding +from llama_index.schema import NodeWithScore +from llama_index.vector_stores import PGVectorStore, VectorStoreQueryResult +from llama_index.vector_stores.postgres import DBEmbeddingRow +from sqlalchemy import select, text, and_, or_ +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding + + +class RetrieveSimilarNodes: + """Retriever similar nodes over a postgres vector store.""" + + def __init__( + self, + vector_store: PGVectorStore, + similarity_top_k: int, + embed_model: BaseEmbedding = CohereEmbedding(), + ) -> None: + """Init params.""" + self._vector_store = vector_store + self._embed_model = embed_model + self._similarity_top_k = similarity_top_k + super().__init__() + + def query_db( + self, query: str, filters: list[dict[str, str]] | None = None + ) -> list[NodeWithScore]: + """ + query database with given filters (similarity search is also done) + + Parameters + ------------- + query : str + the user question + filters : list[dict[str, str]] | None + a list of filters to apply with `or` condition + the dictionary would be applying `and` + operation between keys and values of json metadata_ + if `None` then no filtering would be applied + """ + self._vector_store._initialize() + embedding = self._embed_model.get_text_embedding(text=query) + stmt = select( # type: ignore + self._vector_store._table_class.id, + self._vector_store._table_class.node_id, + self._vector_store._table_class.text, + self._vector_store._table_class.metadata_, + self._vector_store._table_class.embedding.cosine_distance(embedding).label( + "distance" + ), + ).order_by(text("distance asc")) + + if filters is not None: + stmt = stmt.where( + or_( + and_( + self._vector_store._table_class.metadata_.op("->>")(key) == value + for key, value in condition.items() + ) + for condition in filters + ) + ) + + stmt = stmt.limit(self._similarity_top_k) + + with self._vector_store._session() as session, session.begin(): + res = session.execute(stmt) + + results = [ + DBEmbeddingRow( + node_id=item.node_id, + text=item.text, + metadata=item.metadata_, + similarity=(1 - item.distance) if item.distance is not None else 0, + ) + for item in res.all() + ] + query_result = self._vector_store._db_rows_to_query_result(results) + nodes = self._get_nodes_with_score(query_result) + return nodes + + def _get_nodes_with_score( + self, query_result: VectorStoreQueryResult + ) -> list[NodeWithScore]: + """get nodes from a query_results""" + nodes_with_scores = [] + for index, node in enumerate(query_result.nodes): + score: float | None = None + if query_result.similarities is not None: + score = query_result.similarities[index] + nodes_with_scores.append(NodeWithScore(node=node, score=score)) + + return nodes_with_scores From 75d9b8ba6900143aa7b5e61c157c97bd82447333 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Wed, 
31 Jan 2024 18:27:18 +0330 Subject: [PATCH 2/8] feat: Doing the retrieval and query engine manually! - Because PGVectorStore didn't support combining `AND` and `OR` conditions, we needed to do it manually, and other parts had to change too! - I still haven't tested the whole process in the subquery generator. --- bot/retrievers/forum_summary_retriever.py | 37 ++- bot/retrievers/retrieve_similar_nodes.py | 5 +- discord_query.py | 4 +- tests/unit/test_discord_summary_retriever.py | 52 ++-- utils/query_engine/__init__.py | 2 +- .../level_based_platform_query_engine.py | 249 +++++++++--------- ...ine.py => prepare_discord_query_engine.py} | 0 7 files changed, 171 insertions(+), 178 deletions(-) rename utils/query_engine/{discord_query_engine.py => prepare_discord_query_engine.py} (100%) diff --git a/bot/retrievers/forum_summary_retriever.py b/bot/retrievers/forum_summary_retriever.py index 1e04cea..6dd56d1 100644 --- a/bot/retrievers/forum_summary_retriever.py +++ b/bot/retrievers/forum_summary_retriever.py @@ -16,16 +16,16 @@ def __init__( """ super().__init__(table_name, dbname, embedding_model=embedding_model) - def retreive_metadata( + def retreive_filtering( self, query: str, metadata_group1_key: str, metadata_group2_key: str, metadata_date_key: str, similarity_top_k: int = 20, - ) -> tuple[set[str], set[str], set[str]]: + ) -> list[dict[str, str]]: """ - retrieve the metadata information of the similar nodes with the query + retrieve filtering that can be done based on the retrieved similar nodes with the query Parameters ----------- @@ -46,28 +46,25 @@ def retreive_metadata( Returns --------- - group1_data : set[str] - the similar summary nodes having the group1_data. - can be an empty set meaning no similar thread - conversations for it was available. - group2_data : set[str] - the similar summary nodes having the group2_data. - can be an empty set meaning no similar channel - conversations for it was available. 
- dates : set[str] - the similar daily conversations to the given query + filters : list[dict[str, str]] + a list of filters to apply with `or` condition + the dictionary would be applying `and` + operation between keys and values of json metadata_ """ nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k) - group1_data: set[str] = set() - dates: set[str] = set() - group2_data: set[str] = set() + filters: list[dict[str, str]] = [] for node in nodes: + # the filter made by given node + filter: dict[str, str] = {} if node.metadata[metadata_group1_key]: - group1_data.add(node.metadata[metadata_group1_key]) + filter[metadata_group1_key] = node.metadata[metadata_group1_key] if node.metadata[metadata_group2_key]: - group2_data.add(node.metadata[metadata_group2_key]) - dates.add(node.metadata[metadata_date_key]) + filter[metadata_group2_key] = node.metadata[metadata_group2_key] + # date filter + filter[metadata_date_key] = node.metadata[metadata_date_key] - return group1_data, group2_data, dates + filters.append(filter) + + return filters diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index 1906d0f..e9b7279 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -49,11 +49,12 @@ def query_db( ), ).order_by(text("distance asc")) - if filters is not None: + if filters is not None and filters != []: stmt = stmt.where( or_( and_( - self._vector_store._table_class.metadata_.op("->>")(key) == value + self._vector_store._table_class.metadata_.op("->>")(key) + == value for key, value in condition.items() ) for condition in filters diff --git a/discord_query.py b/discord_query.py index 24e762f..4333c5c 100644 --- a/discord_query.py +++ b/discord_query.py @@ -1,7 +1,9 @@ from llama_index import QueryBundle from llama_index.schema import NodeWithScore from tc_hivemind_backend.embeddings.cohere import CohereEmbedding -from utils.query_engine.discord_query_engine import prepare_discord_engine_auto_filter +from utils.query_engine.prepare_discord_query_engine import ( + prepare_discord_engine_auto_filter, +) def query_discord( diff --git a/tests/unit/test_discord_summary_retriever.py b/tests/unit/test_discord_summary_retriever.py index d5fafa3..638b8b4 100644 --- a/tests/unit/test_discord_summary_retriever.py +++ b/tests/unit/test_discord_summary_retriever.py @@ -14,8 +14,9 @@ def test_initialize_class(self): documents: list[Document] = [] all_dates: list[str] = [] + start_date = parser.parse("2023-08-01") for i in range(30): - date = parser.parse("2023-08-01") + timedelta(days=i) + date = start_date + timedelta(days=i) doc_date = date.strftime("%Y-%m-%d") doc = Document( text="SAMPLESAMPLESAMPLE", @@ -44,39 +45,32 @@ def test_initialize_class(self): dbname="sample", embedding_model=mock_embedding_model(), ) - channels, threads, dates = base_summary_search.retreive_metadata( + filters = base_summary_search.retreive_filtering( query="what is samplesample?", similarity_top_k=5, metadata_group1_key="channel", metadata_group2_key="thread", metadata_date_key="date", ) - self.assertIsInstance(threads, set) - self.assertIsInstance(channels, set) - self.assertIsInstance(dates, set) - self.assertTrue( - threads.issubset( - set( - [ - "thread0", - "thread1", - "thread2", - "thread3", - "thread4", - ] - ) - ) - ) - self.assertTrue( - channels.issubset( - set( - [ - "channel0", - "channel1", - "channel2", - ] - ) + self.assertIsInstance(filters, list) + + expected_dates = [ + (start_date + 
timedelta(days=i)).strftime("%Y-%m-%d") for i in range(30) + ] + for filter in filters: + self.assertIsInstance(filter, dict) + self.assertIn( + filter["thread"], + [ + "thread0", + "thread1", + "thread2", + "thread3", + "thread4", + ], ) - ) - self.assertTrue(dates.issubset(all_dates)) + self.assertIn(filter["channel"], ["channel0", "channel1", "channel2"]) + date = parser.parse("2023-08-01") + timedelta(days=i) + doc_date = date.strftime("%Y-%m-%d") + self.assertIn(filter["date"], expected_dates) diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py index fad06f9..01a7658 100644 --- a/utils/query_engine/__init__.py +++ b/utils/query_engine/__init__.py @@ -1,2 +1,2 @@ # flake8: noqa -from .discord_query_engine import prepare_discord_engine_auto_filter +from .prepare_discord_query_engine import prepare_discord_engine_auto_filter diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index 9d0e6b5..f3310d4 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -3,53 +3,61 @@ from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever from bot.retrievers.process_dates import process_dates from bot.retrievers.utils.load_hyperparams import load_hyperparams -from llama_index.query_engine import BaseQueryEngine -from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters +from llama_index.query_engine import CustomQueryEngine +from llama_index.retrievers import BaseRetriever +from llama_index.response_synthesizers import ( + get_response_synthesizer, + BaseSynthesizer, +) +from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from tc_hivemind_backend.pg_vector_access import PGVectorAccess +from llama_index.llms import OpenAI +from llama_index.prompts import PromptTemplate +from llama_index.schema import NodeWithScore +from llama_index.schema import MetadataMode + + +qa_prompt = PromptTemplate( + "Context information is below.\n" + "---------------------\n" + "{context_str}\n" + "---------------------\n" + "Given the context information and not prior knowledge, " + "answer the query.\n" + "Query: {query_str}\n" + "Answer: " +) + + +class LevelBasedPlatformQueryEngine(CustomQueryEngine): + retriever: BaseRetriever + response_synthesizer: BaseSynthesizer + llm: OpenAI + qa_prompt: PromptTemplate + + def custom_query(self, query_str: str): + """Doing custom query""" + # first retrieving similar nodes in summary + retriever = RetrieveSimilarNodes( + self._vector_store, + self._similarity_top_k, + ) + similar_nodes = retriever.query_db(query=query_str, filters=self._filters) + context_str = self._prepare_context_str(similar_nodes) + fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str) + response = self.llm.complete(fmt_qa_prompt) + return str(response) -class LevelBasedPlatformQueryEngine: - def __init__( - self, - level1_key: str, - level2_key: str, - platform_table_name: str, - date_key: str = "date", - ) -> None: - """ - A two level based platform query engine preparation tools. - - Parameters - ------------ - level1_key : str - first hierarchy of the discussion. - the platforms can be discord or discourse. for example in discord - the level1 is `channel` and in discourse it can be `category` - level2_key : str - the second level of discussion in the hierarchy. 
- For example in discord level2 is `thread`, - and on discourse level2 would be `topic` - platform_table_name : str - the postgresql table name for the platform. Can be only the platform name - as `discord` or `discourse` - date_key : str - the day key which the date is saved under the field in postgresql table. - for default is is `date` which was the one that we used previously - """ - self.level1_key = level1_key - self.level2_key = level2_key - self.platform_table_name = platform_table_name - self.date_key = date_key - + @classmethod def prepare_platform_engine( - self, + cls, community_id: str, - level1_names: list[str], - level2_names: list[str], - days: list[str], - **kwarg, - ) -> BaseQueryEngine: + platform_table_name: str, + filters: list[dict[str, str]] | None = None, + testing=False, + ) -> "LevelBasedPlatformQueryEngine": """ query the platform database using filters given and give an anwer to the given query using the LLM @@ -58,20 +66,18 @@ def prepare_platform_engine( ------------ community_id : str the community id data to query - query : str - the query (question) of the user - level1_names : list[str] - the given categorys to search for - level2_names : list[str] - the given topics to search for - days : list[str] - the given days to search for - ** kwargs : - similarity_top_k : int | None - the k similar results to use when querying the data - if not given, will load from `.env` file - testing : bool - whether to setup the PGVectorAccess in testing mode + platform_table_name : str + the postgresql table name for the platform. Can be only the platform name + as `discord` or `discourse` + filters : list[dict[str, str]] | None + the list of filters to be applied when retrieving data + if `None` then set no filtering on PGVectorStore + testing : bool + if `True` it is in test phase and nothing must be changed + similarity_top_k : int | None + the k similar results to use when querying the data + if not given, will load from `.env` file + Returns --------- @@ -80,58 +86,39 @@ def prepare_platform_engine( """ dbname = f"community_{community_id}" - testing = kwarg.get("testing", False) - similarity_top_k = kwarg.get("similarity_top_k", None) - pg_vector = PGVectorAccess( - table_name=self.platform_table_name, + table_name=platform_table_name, dbname=dbname, testing=testing, embed_model=CohereEmbedding(), ) index = pg_vector.load_index() - if similarity_top_k is None: - _, similarity_top_k, _ = load_hyperparams() - - level2_filters: list[ExactMatchFilter] = [] - level1_filters: list[ExactMatchFilter] = [] - day_filters: list[ExactMatchFilter] = [] - - for level1 in level1_names: - level1_name_value = level1.replace("'", "''") - level1_filters.append( - ExactMatchFilter(key=self.level1_key, value=level1_name_value) - ) - - for level2 in level2_names: - levle2_value = level2.replace("'", "''") - level2_filters.append( - ExactMatchFilter(key=self.level2_key, value=levle2_value) - ) - - for day in days: - day_filters.append(ExactMatchFilter(key=self.date_key, value=day)) - - all_filters: list[ExactMatchFilter] = [] - all_filters.extend(level1_filters) - all_filters.extend(level2_filters) - all_filters.extend(day_filters) - - filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR) - - query_engine = index.as_query_engine( - filters=filters, similarity_top_k=similarity_top_k + _, similarity_top_k, _ = load_hyperparams() + + cls._vector_store = index.vector_store + cls._similarity_top_k = similarity_top_k + cls._filters = filters + + llm = OpenAI("gpt-3.5-turbo") 
+ synthesizer = get_response_synthesizer(response_mode="compact") + retriever = index.as_retriever() + return cls( + retriever=retriever, + response_synthesizer=synthesizer, + llm=llm, + qa_prompt=qa_prompt, ) - return query_engine - + @classmethod def prepare_engine_auto_filter( - self, + cls, community_id: str, query: str, - similarity_top_k: int | None = None, - d: int | None = None, - ) -> BaseQueryEngine: + platform_table_name: str, + level1_key: str, + level2_key: str, + date_key: str = "date", + ) -> "LevelBasedPlatformQueryEngine": """ get the query engine and do the filtering automatically. By automatically we mean, it would first query the summaries @@ -143,14 +130,20 @@ def prepare_engine_auto_filter( the community id to process its platform data query : str the query (question) of the user - similarity_top_k : int | None - the value for the initial summary search - to get the `k2` count simliar nodes - if `None`, then would read from `.env` - d : int - this would make the secondary search (`prepare_discourse_engine`) - to be done on the `metadata.date - d` to `metadata.date + d` - + platform_table_name : str + the postgresql table name for the platform. Can be only the platform name + as `discord` or `discourse` + level1_key : str + first hierarchy of the discussion. + the platforms can be discord or discourse. for example in discord + the level1 is `channel` and in discourse it can be `category` + level2_key : str + the second level of discussion in the hierarchy. + For example in discord level2 is `thread`, + and on discourse level2 would be `topic` + date_key : str + the day key which the date is saved under the field in postgresql table. + for default is is `date` which was the one that we used previously Returns --------- @@ -159,37 +152,43 @@ def prepare_engine_auto_filter( """ dbname = f"community_{community_id}" - if d is None: - _, _, d = load_hyperparams() - if similarity_top_k is None: - similarity_top_k, _, _ = load_hyperparams() + summary_similarity_top_k, _, d = load_hyperparams() + # For summaries data a posfix `summary` would be added platform_retriever = ForumBasedSummaryRetriever( - table_name=self.platform_table_name, dbname=dbname + table_name=platform_table_name + "_summary", dbname=dbname ) - level1_names, level2_names, dates = platform_retriever.retreive_metadata( + filters = platform_retriever.retreive_filtering( query=query, - metadata_group1_key=self.level1_key, - metadata_group2_key=self.level2_key, - metadata_date_key=self.date_key, - similarity_top_k=similarity_top_k, + metadata_group1_key=level1_key, + metadata_group2_key=level2_key, + metadata_date_key=date_key, + similarity_top_k=summary_similarity_top_k, ) + # getting all the metadata dates from filters + dates: list[str] = [f[date_key] for f in filters] dates_modified = process_dates(list(dates), d) + dates_filter = [{date_key: date} for date in dates_modified] + filters.extend(dates_filter) - logging.info( - f"COMMUNITY_ID: {community_id} | " - f"summary retrieved {self.date_key}: {dates_modified} | " - f"summary retrieved {self.level1_key}: {list(level1_names)} | " - f"summary retrieved {self.level2_key}: {list(level2_names)}" - ) + logging.info(f"COMMUNITY_ID: {community_id} | summary filters: {filters}") - engine = self.prepare_platform_engine( + engine = LevelBasedPlatformQueryEngine.prepare_platform_engine( community_id=community_id, - query=query, - level1_names=list(level1_names), - level2_names=list(level2_names), - days=dates_modified, + platform_table_name=platform_table_name, + 
filters=filters, ) return engine + + def _prepare_context_str(self, nodes: list[NodeWithScore]) -> str: + context_str = "\n\n".join( + [ + node.get_content() + + "\n" + + node.node.get_metadata_str(mode=MetadataMode.LLM) + for node in nodes + ] + ) + return context_str diff --git a/utils/query_engine/discord_query_engine.py b/utils/query_engine/prepare_discord_query_engine.py similarity index 100% rename from utils/query_engine/discord_query_engine.py rename to utils/query_engine/prepare_discord_query_engine.py From 898cbe5aaecbe2f309256cbd3cb8fd1b320b274d Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 1 Feb 2024 10:32:29 +0330 Subject: [PATCH 3/8] feat: Using custom query engine and retriever and synthesizer! --- bot/retrievers/retrieve_similar_nodes.py | 35 +++++--- subquery.py | 4 +- .../unit/test_prepare_discord_query_engine.py | 2 +- .../level_based_platform_query_engine.py | 25 +++++- .../prepare_discord_query_engine.py | 79 ++++--------------- 5 files changed, 63 insertions(+), 82 deletions(-) diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index e9b7279..95c88d1 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -2,7 +2,7 @@ from llama_index.schema import NodeWithScore from llama_index.vector_stores import PGVectorStore, VectorStoreQueryResult from llama_index.vector_stores.postgres import DBEmbeddingRow -from sqlalchemy import select, text, and_, or_ +from sqlalchemy import select, text, and_, or_, Date, cast from tc_hivemind_backend.embeddings.cohere import CohereEmbedding @@ -50,18 +50,29 @@ def query_db( ).order_by(text("distance asc")) if filters is not None and filters != []: - stmt = stmt.where( - or_( - and_( - self._vector_store._table_class.metadata_.op("->>")(key) - == value - for key, value in condition.items() - ) - for condition in filters - ) - ) + conditions = [] + for condition in filters: + filters_and = [] + for key, value in condition.items(): + if key == "date": + # Apply ::date cast when the key is 'date' + filter_condition = cast( + self._vector_store._table_class.metadata_.op("->>")(key), + Date, + ) == cast(value, Date) + else: + filter_condition = ( + self._vector_store._table_class.metadata_.op("->>")(key) + == value + ) + + filters_and.append(filter_condition) + + conditions.append(and_(*filters_and)) + + stmt = stmt.where(or_(*conditions)) - stmt = stmt.limit(self._similarity_top_k) + stmt = stmt.limit(self._similarity_top_k) with self._vector_store._session() as session, session.begin(): res = session.execute(stmt) diff --git a/subquery.py b/subquery.py index 44a8a86..6270823 100644 --- a/subquery.py +++ b/subquery.py @@ -72,8 +72,6 @@ def query_multiple_source( discord_query_engine = prepare_discord_engine_auto_filter( community_id, query, - similarity_top_k=None, - d=None, ) tool_metadata = ToolMetadata( name="Discord", @@ -100,7 +98,7 @@ def query_multiple_source( raise NotImplementedError question_gen = GuidanceQuestionGenerator.from_defaults( - guidance_llm=OpenAIChat("gpt-3.5-turbo"), + guidance_llm=OpenAIChat("gpt-4"), verbose=False, ) embed_model = CohereEmbedding() diff --git a/tests/unit/test_prepare_discord_query_engine.py b/tests/unit/test_prepare_discord_query_engine.py index 45cc348..11ccaca 100644 --- a/tests/unit/test_prepare_discord_query_engine.py +++ b/tests/unit/test_prepare_discord_query_engine.py @@ -3,7 +3,7 @@ from llama_index.core.base_query_engine import BaseQueryEngine from llama_index.vector_stores import 
ExactMatchFilter, FilterCondition, MetadataFilters -from utils.query_engine.discord_query_engine import prepare_discord_engine +from utils.query_engine.prepare_discord_query_engine import prepare_discord_engine class TestPrepareDiscordEngine(unittest.TestCase): diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index f3310d4..cdca349 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -48,6 +48,7 @@ def custom_query(self, query_str: str): context_str = self._prepare_context_str(similar_nodes) fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str) response = self.llm.complete(fmt_qa_prompt) + # logging.info(f"fmt_qa_prompt {fmt_qa_prompt}") return str(response) @classmethod @@ -57,6 +58,7 @@ def prepare_platform_engine( platform_table_name: str, filters: list[dict[str, str]] | None = None, testing=False, + **kwargs, ) -> "LevelBasedPlatformQueryEngine": """ query the platform database using filters given @@ -77,6 +79,17 @@ def prepare_platform_engine( similarity_top_k : int | None the k similar results to use when querying the data if not given, will load from `.env` file + **kwargs : + llm : llama-index.LLM + the LLM to use answering queries + default is gpt-3.5-turbo + synthesizer : llama_index.response_synthesizers.base.BaseSynthesizer + the synthesizers to use when creating the prompt + default is to get from `get_response_synthesizer(response_mode="compact")` + qa_prompt : llama-index.prompts.PromptTemplate + the Q&A prompt to use + default would be the default prompt of llama-index + Returns @@ -86,6 +99,12 @@ def prepare_platform_engine( """ dbname = f"community_{community_id}" + synthesizer = kwargs.get( + "synthesizer", get_response_synthesizer(response_mode="compact") + ) + llm = kwargs.get("llm", OpenAI("gpt-3.5-turbo")) + qa_prompt_ = kwargs.get("qa_prompt", qa_prompt) + pg_vector = PGVectorAccess( table_name=platform_table_name, dbname=dbname, @@ -93,20 +112,18 @@ def prepare_platform_engine( embed_model=CohereEmbedding(), ) index = pg_vector.load_index() + retriever = index.as_retriever() _, similarity_top_k, _ = load_hyperparams() cls._vector_store = index.vector_store cls._similarity_top_k = similarity_top_k cls._filters = filters - llm = OpenAI("gpt-3.5-turbo") - synthesizer = get_response_synthesizer(response_mode="compact") - retriever = index.as_retriever() return cls( retriever=retriever, response_synthesizer=synthesizer, llm=llm, - qa_prompt=qa_prompt, + qa_prompt=qa_prompt_, ) @classmethod diff --git a/utils/query_engine/prepare_discord_query_engine.py b/utils/query_engine/prepare_discord_query_engine.py index 589fdd4..89338b0 100644 --- a/utils/query_engine/prepare_discord_query_engine.py +++ b/utils/query_engine/prepare_discord_query_engine.py @@ -5,10 +5,8 @@ def prepare_discord_engine( community_id: str, - thread_names: list[str], - channel_names: list[str], - days: list[str], - **kwarg, + filters: list[dict[str, str]], + **kwargs, ) -> BaseQueryEngine: """ query the platform database using filters given @@ -20,16 +18,10 @@ def prepare_discord_engine( the discord community id data to query query : str the query (question) of the user - level1_names : list[str] - the given categorys to search for - level2_names : list[str] - the given topics to search for - days : list[str] - the given days to search for + filters : list[dict[str, str]] | None + the list of filters to be applied when 
retrieving data + if `None` then set no filtering on PGVectorStore ** kwargs : - similarity_top_k : int | None - the k similar results to use when querying the data - if not given, will load from `.env` file testing : bool whether to setup the PGVectorAccess in testing mode @@ -38,15 +30,13 @@ def prepare_discord_engine( query_engine : BaseQueryEngine the created query engine with the filters """ - query_engine_preparation = get_discord_level_based_platform_query_engine( - table_name="discord", - ) - query_engine = query_engine_preparation.prepare_platform_engine( + + testing = kwargs.get("testing", False) + query_engine = LevelBasedPlatformQueryEngine.prepare_platform_engine( community_id=community_id, - level1_names=thread_names, - level2_names=channel_names, - days=days, - **kwarg, + platform_table_name="discord", + filters=filters, + testing=testing, ) return query_engine @@ -54,8 +44,6 @@ def prepare_discord_engine( def prepare_discord_engine_auto_filter( community_id: str, query: str, - similarity_top_k: int | None = None, - d: int | None = None, ) -> BaseQueryEngine: """ get the query engine and do the filtering automatically. @@ -68,13 +56,6 @@ def prepare_discord_engine_auto_filter( the discord community data to query query : str the query (question) of the user - similarity_top_k : int | None - the value for the initial summary search - to get the `k2` count similar nodes - if `None`, then would read from `.env` - d : int - this would make the secondary search (`prepare_discord_engine`) - to be done on the `metadata.date - d` to `metadata.date + d` Returns @@ -83,38 +64,12 @@ def prepare_discord_engine_auto_filter( the created query engine with the filters """ - query_engine_preparation = get_discord_level_based_platform_query_engine( - table_name="discord_summary" - ) - query_engine = query_engine_preparation.prepare_engine_auto_filter( + engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( community_id=community_id, query=query, - similarity_top_k=similarity_top_k, - d=d, - ) - - return query_engine - - -def get_discord_level_based_platform_query_engine( - table_name: str, -) -> LevelBasedPlatformQueryEngine: - """ - perpare the `LevelBasedPlatformQueryEngine` to use - - Parameters - ----------- - table_name : str - the postgresql data table to use - - Returns - --------- - level_based_query_engine : LevelBasedPlatformQueryEngine - the query engine creator class - """ - level_based_query_engine = LevelBasedPlatformQueryEngine( - level1_key="thread", - level2_key="channel", - platform_table_name=table_name, + platform_table_name="discord", + level1_key="channel", + level2_key="thread", + date_key="date", ) - return level_based_query_engine + return engine From cbf77c0ee36b2b5d1f406e5703e0054e78619f62 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 1 Feb 2024 11:39:36 +0330 Subject: [PATCH 4/8] fix: Updated due to failing test cases! 
--- .../test_level_based_platform_query_engine.py | 73 ++++------- .../unit/test_prepare_discord_query_engine.py | 31 ++--- .../test_prepare_discourse_query_engine.py | 33 ++--- utils/query_engine/discourse_query_engine.py | 118 ------------------ .../level_based_platform_query_engine.py | 4 +- .../prepare_discord_query_engine.py | 1 + .../prepare_discourse_query_engine.py | 73 +++++++++++ 7 files changed, 120 insertions(+), 213 deletions(-) delete mode 100644 utils/query_engine/discourse_query_engine.py create mode 100644 utils/query_engine/prepare_discourse_query_engine.py diff --git a/tests/unit/test_level_based_platform_query_engine.py b/tests/unit/test_level_based_platform_query_engine.py index dfde10a..eaa9986 100644 --- a/tests/unit/test_level_based_platform_query_engine.py +++ b/tests/unit/test_level_based_platform_query_engine.py @@ -1,3 +1,4 @@ +import os import unittest from unittest.mock import patch @@ -17,67 +18,47 @@ def setUp(self): self.level2_key = "thread" self.platform_table_name = "discord" self.date_key = "date" - self.engine = LevelBasedPlatformQueryEngine( - level1_key=self.level1_key, - level2_key=self.level2_key, - platform_table_name=self.platform_table_name, - date_key=self.date_key, - ) + os.environ["OPENAI_API_KEY"] = "sk-some_creds" def test_prepare_platform_engine(self): """ Test prepare_platform_engine method with sample data """ - level1_names = ["general"] - level2_names = ["discussion"] - days = ["2022-01-01"] - query_engine = self.engine.prepare_platform_engine( + # the output should always have a `date` key for each dictionary + filters = [ + {"channel": "general", "date": "2023-01-02"}, + {"thread": "discussion", "date": "2024-01-03"}, + {"date": "2022-01-01"}, + ] + + engine = LevelBasedPlatformQueryEngine.prepare_platform_engine( community_id=self.community_id, - level1_names=level1_names, - level2_names=level2_names, - days=days, + platform_table_name=self.platform_table_name, + filters=filters, + testing=True, ) - self.assertIsNotNone(query_engine) + self.assertIsNotNone(engine) def test_prepare_engine_auto_filter(self): """ Test prepare_engine_auto_filter method with sample data """ with patch.object( - ForumBasedSummaryRetriever, "retreive_metadata" + ForumBasedSummaryRetriever, "retreive_filtering" ) as mock_retriever: - mock_retriever.return_value = (["general"], ["discussion"], ["2022-01-01"]) - query_engine = self.engine.prepare_engine_auto_filter( - community_id=self.community_id, query="test query" - ) - self.assertIsNotNone(query_engine) + # the output should always have a `date` key for each dictionary + mock_retriever.return_value = [ + {"channel": "general", "date": "2023-01-02"}, + {"thread": "discussion", "date": "2024-01-03"}, + {"date": "2022-01-01"}, + ] - def test_prepare_engine_auto_filter_with_d(self): - """ - Test prepare_engine_auto_filter method with a specific value for d - """ - with patch.object( - ForumBasedSummaryRetriever, "retreive_metadata" - ) as mock_retriever: - mock_retriever.return_value = (["general"], ["discussion"], ["2022-01-01"]) - query_engine = self.engine.prepare_engine_auto_filter( - community_id=self.community_id, - query="test query", - d=7, # Use a specific value for d - ) - self.assertIsNotNone(query_engine) - - def test_prepare_engine_auto_filter_with_similarity_top_k(self): - """ - Test prepare_engine_auto_filter method with a specific value for similarity_top_k - """ - with patch.object( - ForumBasedSummaryRetriever, "retreive_metadata" - ) as mock_retriever: - mock_retriever.return_value = 
(["general"], ["discussion"], ["2022-01-01"]) - query_engine = self.engine.prepare_engine_auto_filter( + engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( community_id=self.community_id, query="test query", - similarity_top_k=10, # Use a specific value for similarity_top_k + platform_table_name=self.platform_table_name, + level1_key=self.level1_key, + level2_key=self.level2_key, + date_key=self.date_key, ) - self.assertIsNotNone(query_engine) + self.assertIsNotNone(engine) diff --git a/tests/unit/test_prepare_discord_query_engine.py b/tests/unit/test_prepare_discord_query_engine.py index 11ccaca..a8f08b7 100644 --- a/tests/unit/test_prepare_discord_query_engine.py +++ b/tests/unit/test_prepare_discord_query_engine.py @@ -14,37 +14,22 @@ def setUp(self): os.environ["K1_RETRIEVER_SEARCH"] = "20" os.environ["K2_RETRIEVER_SEARCH"] = "5" os.environ["D_RETRIEVER_SEARCH"] = "3" + os.environ["OPENAI_API_KEY"] = "sk-some_creds" def test_prepare_discord_engine(self): community_id = "123456" - thread_names = ["thread1", "thread2"] - channel_names = ["channel1", "channel2"] - days = ["2022-01-01", "2022-01-02"] + filters = [ + {"channel": "general", "date": "2023-01-02"}, + {"thread": "discussion", "date": "2024-01-03"}, + {"date": "2022-01-01"}, + ] # Call the function query_engine = prepare_discord_engine( community_id, - thread_names, - channel_names, - days, + filters=filters, testing=True, ) - # Assertions + self.assertIsNotNone(query_engine) self.assertIsInstance(query_engine, BaseQueryEngine) - - expected_filter = MetadataFilters( - filters=[ - ExactMatchFilter(key="thread", value="thread1"), - ExactMatchFilter(key="thread", value="thread2"), - ExactMatchFilter(key="channel", value="channel1"), - ExactMatchFilter(key="channel", value="channel2"), - ExactMatchFilter(key="date", value="2022-01-01"), - ExactMatchFilter(key="date", value="2022-01-02"), - ], - condition=FilterCondition.OR, - ) - - self.assertEqual(query_engine.retriever._filters, expected_filter) - # this is the secondary search, so K2 should be for this - self.assertEqual(query_engine.retriever._similarity_top_k, 5) diff --git a/tests/unit/test_prepare_discourse_query_engine.py b/tests/unit/test_prepare_discourse_query_engine.py index 6248dd8..6154693 100644 --- a/tests/unit/test_prepare_discourse_query_engine.py +++ b/tests/unit/test_prepare_discourse_query_engine.py @@ -3,7 +3,7 @@ from llama_index.core.base_query_engine import BaseQueryEngine from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters -from utils.query_engine.discourse_query_engine import prepare_discourse_engine +from utils.query_engine.prepare_discourse_query_engine import prepare_discourse_engine class TestPrepareDiscourseEngine(unittest.TestCase): @@ -14,37 +14,22 @@ def setUp(self): os.environ["K1_RETRIEVER_SEARCH"] = "20" os.environ["K2_RETRIEVER_SEARCH"] = "5" os.environ["D_RETRIEVER_SEARCH"] = "3" + os.environ["OPENAI_API_KEY"] = "sk-some_creds" def test_prepare_discourse_engine(self): community_id = "123456" - topic_names = ["topic1", "topic2"] - category_names = ["category1", "category2"] - days = ["2022-01-01", "2022-01-02"] + filters = [ + {"category": "general", "date": "2023-01-02"}, + {"topic": "discussion", "date": "2024-01-03"}, + {"date": "2022-01-01"}, + ] # Call the function query_engine = prepare_discourse_engine( community_id=community_id, - category_names=category_names, - topic_names=topic_names, - days=days, + filters=filters, testing=True, ) - # Assertions + 
self.assertIsNotNone(query_engine) self.assertIsInstance(query_engine, BaseQueryEngine) - - expected_filter = MetadataFilters( - filters=[ - ExactMatchFilter(key="category", value="category1"), - ExactMatchFilter(key="category", value="category2"), - ExactMatchFilter(key="topic", value="topic1"), - ExactMatchFilter(key="topic", value="topic2"), - ExactMatchFilter(key="date", value="2022-01-01"), - ExactMatchFilter(key="date", value="2022-01-02"), - ], - condition=FilterCondition.OR, - ) - - self.assertEqual(query_engine.retriever._filters, expected_filter) - # this is the secondary search, so K2 should be for this - self.assertEqual(query_engine.retriever._similarity_top_k, 5) diff --git a/utils/query_engine/discourse_query_engine.py b/utils/query_engine/discourse_query_engine.py deleted file mode 100644 index ca765b1..0000000 --- a/utils/query_engine/discourse_query_engine.py +++ /dev/null @@ -1,118 +0,0 @@ -from llama_index.query_engine import BaseQueryEngine - -from .level_based_platform_query_engine import LevelBasedPlatformQueryEngine - - -def prepare_discourse_engine( - community_id: str, - category_names: list[str], - topic_names: list[str], - days: list[str], - **kwarg, -) -> BaseQueryEngine: - """ - query the discourse database using filters given - and give an anwer to the given query using the LLM - - Parameters - ------------ - community_id : str - the discourse community data to query - query : str - the query (question) of the user - category_names : list[str] - the given categorys to search for - topic_names : list[str] - the given topics to search for - days : list[str] - the given days to search for - similarity_top_k : int | None - the k similar results to use when querying the data - if `None` will load from `.env` file - ** kwargs : - testing : bool - whether to setup the PGVectorAccess in testing mode - - Returns - --------- - query_engine : BaseQueryEngine - the created query engine with the filters - """ - level_based_query_engine = get_discourse_level_based_platform_query_engine( - table_name="discourse", - ) - - query_engine = level_based_query_engine.prepare_platform_engine( - community_id=community_id, - level1_names=category_names, - level2_names=topic_names, - days=days, - **kwarg, - ) - - return query_engine - - -def prepare_discourse_engine_auto_filter( - community_id: str, - query: str, - similarity_top_k: int | None = None, - d: int | None = None, -) -> BaseQueryEngine: - """ - get the query engine and do the filtering automatically. 
- By automatically we mean, it would first query the summaries - to get the metadata filters - - Parameters - ----------- - guild_id : str - the discourse guild data to query - query : str - the query (question) of the user - similarity_top_k : int | None - the value for the initial summary search - to get the `k2` count simliar nodes - if `None`, then would read from `.env` - d : int - this would make the secondary search (`prepare_discourse_engine`) - to be done on the `metadata.date - d` to `metadata.date + d` - - Returns - --------- - query_engine : BaseQueryEngine - the created query engine with the filters - """ - level_based_query_engine = get_discourse_level_based_platform_query_engine( - table_name="discourse_summary" - ) - - query_engine = level_based_query_engine.prepare_engine_auto_filter( - community_id=community_id, - query=query, - similarity_top_k=similarity_top_k, - d=d, - ) - return query_engine - - -def get_discourse_level_based_platform_query_engine( - table_name: str, -) -> LevelBasedPlatformQueryEngine: - """ - perpare the `LevelBasedPlatformQueryEngine` to use - - Parameters - ----------- - table_name : str - the postgresql data table to use - - Returns - --------- - level_based_query_engine : LevelBasedPlatformQueryEngine - the query engine creator class - """ - level_based_query_engine = LevelBasedPlatformQueryEngine( - level1_key="category", level2_key="topic", platform_table_name=table_name - ) - return level_based_query_engine diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index cdca349..a6c4a65 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -90,8 +90,6 @@ def prepare_platform_engine( the Q&A prompt to use default would be the default prompt of llama-index - - Returns --------- query_engine : BaseQueryEngine @@ -147,6 +145,8 @@ def prepare_engine_auto_filter( the community id to process its platform data query : str the query (question) of the user + this query would be used for filters preparation + which filters are based on available summaries. platform_table_name : str the postgresql table name for the platform. 
Can be only the platform name as `discord` or `discourse` diff --git a/utils/query_engine/prepare_discord_query_engine.py b/utils/query_engine/prepare_discord_query_engine.py index 89338b0..ac185a9 100644 --- a/utils/query_engine/prepare_discord_query_engine.py +++ b/utils/query_engine/prepare_discord_query_engine.py @@ -56,6 +56,7 @@ def prepare_discord_engine_auto_filter( the discord community data to query query : str the query (question) of the user + this query will be used to fetch the filters from similar summaries nodes Returns diff --git a/utils/query_engine/prepare_discourse_query_engine.py b/utils/query_engine/prepare_discourse_query_engine.py new file mode 100644 index 0000000..3c9bbd7 --- /dev/null +++ b/utils/query_engine/prepare_discourse_query_engine.py @@ -0,0 +1,73 @@ +from llama_index.query_engine import BaseQueryEngine + +from .level_based_platform_query_engine import LevelBasedPlatformQueryEngine + + +def prepare_discourse_engine( + community_id: str, + filters: list[dict[str, str]], + **kwargs, +) -> BaseQueryEngine: + """ + query the discourse database using filters given + and give an anwer to the given query using the LLM + + Parameters + ------------ + community_id : str + the discourse community data to query + filters : list[dict[str, str]] | None + the list of filters to be applied when retrieving data + if `None` then set no filtering on PGVectorStore + ** kwargs : + testing : bool + whether to setup the PGVectorAccess in testing mode + + Returns + --------- + query_engine : BaseQueryEngine + the created query engine with the filters + """ + testing = kwargs.get("testing", False) + query_engine = LevelBasedPlatformQueryEngine.prepare_platform_engine( + community_id=community_id, + platform_table_name="discourse", + filters=filters, + testing=testing, + ) + + return query_engine + + +def prepare_discourse_engine_auto_filter( + community_id: str, + query: str, +) -> BaseQueryEngine: + """ + get the query engine and do the filtering automatically. + By automatically we mean, it would first query the summaries + to get the metadata filters + + Parameters + ----------- + guild_id : str + the discourse guild data to query + query : str + the query (question) of the user + this query will be used to fetch the filters from similar summaries nodes + + + Returns + --------- + query_engine : BaseQueryEngine + the created query engine with the filters + """ + engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter( + community_id=community_id, + query=query, + platform_table_name="discourse", + level1_key="category", + level2_key="topic", + date_key="date", + ) + return engine From 2429ee38be9c91ac0ce18c5530efe1c673070b64 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 1 Feb 2024 11:48:46 +0330 Subject: [PATCH 5/8] fix: linter issues! 
--- bot/retrievers/retrieve_similar_nodes.py | 2 +- tests/unit/test_prepare_discord_query_engine.py | 1 - tests/unit/test_prepare_discourse_query_engine.py | 1 - .../level_based_platform_query_engine.py | 15 +++++---------- 4 files changed, 6 insertions(+), 13 deletions(-) diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index 95c88d1..e5f2963 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -2,7 +2,7 @@ from llama_index.schema import NodeWithScore from llama_index.vector_stores import PGVectorStore, VectorStoreQueryResult from llama_index.vector_stores.postgres import DBEmbeddingRow -from sqlalchemy import select, text, and_, or_, Date, cast +from sqlalchemy import Date, and_, cast, or_, select, text from tc_hivemind_backend.embeddings.cohere import CohereEmbedding diff --git a/tests/unit/test_prepare_discord_query_engine.py b/tests/unit/test_prepare_discord_query_engine.py index a8f08b7..22ef04e 100644 --- a/tests/unit/test_prepare_discord_query_engine.py +++ b/tests/unit/test_prepare_discord_query_engine.py @@ -2,7 +2,6 @@ import unittest from llama_index.core.base_query_engine import BaseQueryEngine -from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters from utils.query_engine.prepare_discord_query_engine import prepare_discord_engine diff --git a/tests/unit/test_prepare_discourse_query_engine.py b/tests/unit/test_prepare_discourse_query_engine.py index 6154693..6f47f18 100644 --- a/tests/unit/test_prepare_discourse_query_engine.py +++ b/tests/unit/test_prepare_discourse_query_engine.py @@ -2,7 +2,6 @@ import unittest from llama_index.core.base_query_engine import BaseQueryEngine -from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters from utils.query_engine.prepare_discourse_query_engine import prepare_discourse_engine diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index a6c4a65..cf95393 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -2,21 +2,16 @@ from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever from bot.retrievers.process_dates import process_dates +from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes from bot.retrievers.utils.load_hyperparams import load_hyperparams +from llama_index.llms import OpenAI +from llama_index.prompts import PromptTemplate from llama_index.query_engine import CustomQueryEngine +from llama_index.response_synthesizers import get_response_synthesizer, BaseSynthesizer from llama_index.retrievers import BaseRetriever -from llama_index.response_synthesizers import ( - get_response_synthesizer, - BaseSynthesizer, -) -from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes +from llama_index.schema import MetadataMode, NodeWithScore from tc_hivemind_backend.embeddings.cohere import CohereEmbedding from tc_hivemind_backend.pg_vector_access import PGVectorAccess -from llama_index.llms import OpenAI -from llama_index.prompts import PromptTemplate -from llama_index.schema import NodeWithScore -from llama_index.schema import MetadataMode - qa_prompt = PromptTemplate( "Context information is below.\n" From ce46c0357b8b21516f9514bf6bfe7f67d05ba686 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 1 Feb 2024 11:53:33 +0330 Subject: [PATCH 6/8] fix: isort linter issue! 
--- utils/query_engine/level_based_platform_query_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/query_engine/level_based_platform_query_engine.py b/utils/query_engine/level_based_platform_query_engine.py index cf95393..6d27c01 100644 --- a/utils/query_engine/level_based_platform_query_engine.py +++ b/utils/query_engine/level_based_platform_query_engine.py @@ -7,7 +7,7 @@ from llama_index.llms import OpenAI from llama_index.prompts import PromptTemplate from llama_index.query_engine import CustomQueryEngine -from llama_index.response_synthesizers import get_response_synthesizer, BaseSynthesizer +from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer from llama_index.retrievers import BaseRetriever from llama_index.schema import MetadataMode, NodeWithScore from tc_hivemind_backend.embeddings.cohere import CohereEmbedding From e08197a4c34a12cb5d6ffb393f9326c7e52ba5a2 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 1 Feb 2024 12:47:55 +0330 Subject: [PATCH 7/8] Update: Adding more test case to improve coverage! --- bot/retrievers/retrieve_similar_nodes.py | 1 - bot/retrievers/summary_retriever_base.py | 4 +-- .../test_retrieve_similar_nodes.py | 36 +++++++++++++++++++ tests/unit/test_summary_retriever_base.py | 10 ++++++ 4 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 tests/integration/test_retrieve_similar_nodes.py diff --git a/bot/retrievers/retrieve_similar_nodes.py b/bot/retrievers/retrieve_similar_nodes.py index e5f2963..baac30e 100644 --- a/bot/retrievers/retrieve_similar_nodes.py +++ b/bot/retrievers/retrieve_similar_nodes.py @@ -19,7 +19,6 @@ def __init__( self._vector_store = vector_store self._embed_model = embed_model self._similarity_top_k = similarity_top_k - super().__init__() def query_db( self, query: str, filters: list[dict[str, str]] | None = None diff --git a/bot/retrievers/summary_retriever_base.py b/bot/retrievers/summary_retriever_base.py index 6160c17..e3a183b 100644 --- a/bot/retrievers/summary_retriever_base.py +++ b/bot/retrievers/summary_retriever_base.py @@ -63,13 +63,13 @@ def get_similar_nodes( return nodes def _setup_index( - self, table_name: str, dbname: str, embedding_model: BaseEmbedding + self, table_name: str, dbname: str, embedding_model: BaseEmbedding, testing: bool = False, ) -> VectorStoreIndex: """ setup the llama_index VectorStoreIndex """ pg_vector_access = PGVectorAccess( - table_name=table_name, dbname=dbname, embed_model=embedding_model + table_name=table_name, dbname=dbname, embed_model=embedding_model, testing=testing ) index = pg_vector_access.load_index() return index diff --git a/tests/integration/test_retrieve_similar_nodes.py b/tests/integration/test_retrieve_similar_nodes.py new file mode 100644 index 0000000..1e4d7e3 --- /dev/null +++ b/tests/integration/test_retrieve_similar_nodes.py @@ -0,0 +1,36 @@ +from unittest import TestCase +from unittest.mock import MagicMock + +from collections import namedtuple + +from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes +from llama_index.schema import NodeWithScore, TextNode + + +class TestRetrieveSimilarNodes(TestCase): + def setUp(self): + self.table_name = "sample_table" + self.dbname = "community_some_id" + + self.vector_store = MagicMock() + self.embed_model = MagicMock() + self.retriever = RetrieveSimilarNodes( + vector_store=self.vector_store, + similarity_top_k=5, + embed_model=self.embed_model + ) + + def test_init(self): + self.assertEqual(self.retriever._similarity_top_k, 
5) + self.assertEqual(self.vector_store, self.retriever._vector_store) + + def test_get_nodes_with_score(self): + # Test the _get_nodes_with_score private method + query_result = MagicMock() + query_result.nodes = [TextNode(), TextNode(), TextNode()] + query_result.similarities = [0.8, 0.9, 0.7] + + result = self.retriever._get_nodes_with_score(query_result) + + self.assertEqual(len(result), 3) + self.assertAlmostEqual(result[0].score, 0.8, delta=0.001) diff --git a/tests/unit/test_summary_retriever_base.py b/tests/unit/test_summary_retriever_base.py index 14180ac..7b333a7 100644 --- a/tests/unit/test_summary_retriever_base.py +++ b/tests/unit/test_summary_retriever_base.py @@ -28,3 +28,13 @@ def test_initialize_class(self): nodes = base_summary_search.get_similar_nodes(query="what is samplesample?") self.assertIsInstance(nodes, list) self.assertIsInstance(nodes[0], NodeWithScore) + + def test_setup_index(self): + table_name = "your_table_name" + dbname = "your_db_name" + embedding_model = MagicMock() + search_instance = BaseSummarySearch(table_name, dbname, embedding_model) + + index = search_instance._setup_index(table_name, dbname, embedding_model, testing=True) + self.assertIsNotNone(index) + self.assertIsInstance(index, VectorStoreIndex) \ No newline at end of file From f3b6b0d948b2141ec7b8fbe7e03966a99bb12c11 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Thu, 1 Feb 2024 12:58:08 +0330 Subject: [PATCH 8/8] fix: linter issues! --- bot/retrievers/summary_retriever_base.py | 11 +++++++++-- tests/integration/test_retrieve_similar_nodes.py | 8 +++----- tests/unit/test_summary_retriever_base.py | 6 ++++-- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/bot/retrievers/summary_retriever_base.py b/bot/retrievers/summary_retriever_base.py index e3a183b..d501c69 100644 --- a/bot/retrievers/summary_retriever_base.py +++ b/bot/retrievers/summary_retriever_base.py @@ -63,13 +63,20 @@ def get_similar_nodes( return nodes def _setup_index( - self, table_name: str, dbname: str, embedding_model: BaseEmbedding, testing: bool = False, + self, + table_name: str, + dbname: str, + embedding_model: BaseEmbedding, + testing: bool = False, ) -> VectorStoreIndex: """ setup the llama_index VectorStoreIndex """ pg_vector_access = PGVectorAccess( - table_name=table_name, dbname=dbname, embed_model=embedding_model, testing=testing + table_name=table_name, + dbname=dbname, + embed_model=embedding_model, + testing=testing, ) index = pg_vector_access.load_index() return index diff --git a/tests/integration/test_retrieve_similar_nodes.py b/tests/integration/test_retrieve_similar_nodes.py index 1e4d7e3..5ee611a 100644 --- a/tests/integration/test_retrieve_similar_nodes.py +++ b/tests/integration/test_retrieve_similar_nodes.py @@ -1,10 +1,8 @@ from unittest import TestCase from unittest.mock import MagicMock -from collections import namedtuple - from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes -from llama_index.schema import NodeWithScore, TextNode +from llama_index.schema import TextNode class TestRetrieveSimilarNodes(TestCase): @@ -17,7 +15,7 @@ def setUp(self): self.retriever = RetrieveSimilarNodes( vector_store=self.vector_store, similarity_top_k=5, - embed_model=self.embed_model + embed_model=self.embed_model, ) def test_init(self): @@ -29,7 +27,7 @@ def test_get_nodes_with_score(self): query_result = MagicMock() query_result.nodes = [TextNode(), TextNode(), TextNode()] query_result.similarities = [0.8, 0.9, 0.7] - + result = self.retriever._get_nodes_with_score(query_result) 
self.assertEqual(len(result), 3) diff --git a/tests/unit/test_summary_retriever_base.py b/tests/unit/test_summary_retriever_base.py index 7b333a7..9422571 100644 --- a/tests/unit/test_summary_retriever_base.py +++ b/tests/unit/test_summary_retriever_base.py @@ -35,6 +35,8 @@ def test_setup_index(self): embedding_model = MagicMock() search_instance = BaseSummarySearch(table_name, dbname, embedding_model) - index = search_instance._setup_index(table_name, dbname, embedding_model, testing=True) + index = search_instance._setup_index( + table_name, dbname, embedding_model, testing=True + ) self.assertIsNotNone(index) - self.assertIsInstance(index, VectorStoreIndex) \ No newline at end of file + self.assertIsInstance(index, VectorStoreIndex)
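
A minimal usage sketch of the filtering interface introduced by this patch series, assuming the functions added above (`prepare_discord_engine` and `prepare_discord_engine_auto_filter`); the community id and questions are hypothetical placeholders, and the hyperparameters (K1, K2, D) and API keys are expected to come from the `.env` file. Each dictionary in `filters` is combined with `AND` over its keys and values, and the dictionaries themselves are combined with `OR` by `RetrieveSimilarNodes.query_db`.

    from utils.query_engine.prepare_discord_query_engine import (
        prepare_discord_engine,
        prepare_discord_engine_auto_filter,
    )

    # manual filtering, equivalent to:
    # (channel = 'general' AND date = '2023-01-02') OR (date = '2022-01-01')
    filters = [
        {"channel": "general", "date": "2023-01-02"},
        {"date": "2022-01-01"},
    ]
    manual_engine = prepare_discord_engine(community_id="1234", filters=filters)
    print(manual_engine.query("What was discussed in the general channel?"))

    # automatic filtering: the discord_summary table is queried first and the
    # retrieved metadata (channel, thread, date) becomes the filters
    auto_engine = prepare_discord_engine_auto_filter(
        community_id="1234",
        query="What was discussed about the upcoming release?",
    )
    print(auto_engine.query("What was discussed about the upcoming release?"))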