# File: tests/unit/test_level_based_platform_query_engine.py
# Reconstructed from a whitespace-mangled diff into valid Python.
import unittest
from unittest.mock import patch

from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever
from utils.query_engine.level_based_platform_query_engine import (
    LevelBasedPlatformQueryEngine,
)


class TestLevelBasedPlatformQueryEngine(unittest.TestCase):
    def setUp(self):
        """
        Set up common parameters for testing.
        """
        self.community_id = "test_community"
        self.level1_key = "channel"
        self.level2_key = "thread"
        self.platform_table_name = "discord"
        self.date_key = "date"
        self.engine = LevelBasedPlatformQueryEngine(
            level1_key=self.level1_key,
            level2_key=self.level2_key,
            platform_table_name=self.platform_table_name,
            date_key=self.date_key,
        )

    def test_prepare_platform_engine(self):
        """
        Test prepare_platform_engine method with sample data.
        """
        level1_names = ["general"]
        level2_names = ["discussion"]
        days = ["2022-01-01"]
        query_engine = self.engine.prepare_platform_engine(
            community_id=self.community_id,
            level1_names=level1_names,
            level2_names=level2_names,
            days=days,
        )
        self.assertIsNotNone(query_engine)

    def test_prepare_engine_auto_filter(self):
        """
        Test prepare_engine_auto_filter method with sample data.
        """
        # NOTE: "retreive_metadata" is the retriever's actual (misspelled)
        # method name — it must match the project API, do not "fix" it here.
        with patch.object(
            ForumBasedSummaryRetriever, "retreive_metadata"
        ) as mock_retriever:
            mock_retriever.return_value = (["general"], ["discussion"], ["2022-01-01"])
            query_engine = self.engine.prepare_engine_auto_filter(
                community_id=self.community_id, query="test query"
            )
            self.assertIsNotNone(query_engine)

    def test_prepare_engine_auto_filter_with_d(self):
        """
        Test prepare_engine_auto_filter method with a specific value for d.
        """
        with patch.object(
            ForumBasedSummaryRetriever, "retreive_metadata"
        ) as mock_retriever:
            mock_retriever.return_value = (["general"], ["discussion"], ["2022-01-01"])
            query_engine = self.engine.prepare_engine_auto_filter(
                community_id=self.community_id,
                query="test query",
                d=7,  # Use a specific value for d
            )
            self.assertIsNotNone(query_engine)

    def test_prepare_engine_auto_filter_with_similarity_top_k(self):
        """
        Test prepare_engine_auto_filter method with a specific value
        for similarity_top_k.
        """
        with patch.object(
            ForumBasedSummaryRetriever, "retreive_metadata"
        ) as mock_retriever:
            mock_retriever.return_value = (["general"], ["discussion"], ["2022-01-01"])
            query_engine = self.engine.prepare_engine_auto_filter(
                community_id=self.community_id,
                query="test query",
                similarity_top_k=10,  # Use a specific value for similarity_top_k
            )
            self.assertIsNotNone(query_engine)


# File: tests/unit/test_prepare_discourse_query_engine.py
import os

from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters
from utils.query_engine.discourse_query_engine import prepare_discourse_engine


class TestPrepareDiscourseEngine(unittest.TestCase):
    def setUp(self):
        # Set up environment variables for testing
        os.environ["CHUNK_SIZE"] = "128"
        os.environ["EMBEDDING_DIM"] = "256"
        os.environ["K1_RETRIEVER_SEARCH"] = "20"
        os.environ["K2_RETRIEVER_SEARCH"] = "5"
        os.environ["D_RETRIEVER_SEARCH"] = "3"

    def test_prepare_discourse_engine(self):
        community_id = "123456"
        topic_names = ["topic1", "topic2"]
        category_names = ["category1", "category2"]
        days = ["2022-01-01", "2022-01-02"]

        # Call the function
        query_engine = prepare_discourse_engine(
            community_id=community_id,
            category_names=category_names,
            topic_names=topic_names,
            days=days,
            testing=True,
        )

        # Assertions
        self.assertIsInstance(query_engine, BaseQueryEngine)

        # All filters are OR-combined: any matching category, topic, or date
        expected_filter = MetadataFilters(
            filters=[
                ExactMatchFilter(key="category", value="category1"),
                ExactMatchFilter(key="category", value="category2"),
                ExactMatchFilter(key="topic", value="topic1"),
                ExactMatchFilter(key="topic", value="topic2"),
                ExactMatchFilter(key="date", value="2022-01-01"),
                ExactMatchFilter(key="date", value="2022-01-02"),
            ],
            condition=FilterCondition.OR,
        )

        self.assertEqual(query_engine.retriever._filters, expected_filter)
        # this is the secondary search, so K2 should be for this
        self.assertEqual(query_engine.retriever._similarity_top_k, 5)
# File: utils/query_engine/discord_query_engine.py (post-patch version)
# Reconstructed from a whitespace-mangled diff into valid Python.
from llama_index.query_engine import BaseQueryEngine

from .level_based_platform_query_engine import LevelBasedPlatformQueryEngine


def prepare_discord_engine(
    community_id: str,
    thread_names: list[str],
    channel_names: list[str],
    days: list[str],
    **kwarg,
) -> BaseQueryEngine:
    """
    query the discord database using filters given
    and give an answer to the given query using the LLM

    Parameters
    ------------
    community_id : str
        the discord community id data to query
    thread_names : list[str]
        the given threads to search for
    channel_names : list[str]
        the given channels to search for
    days : list[str]
        the given days to search for
    **kwarg :
        similarity_top_k : int | None
            the k similar results to use when querying the data
            if not given, will load from `.env` file
        testing : bool
            whether to setup the PGVectorAccess in testing mode

    Returns
    ---------
    query_engine : BaseQueryEngine
        the created query engine with the filters
    """
    query_engine_preparation = get_discord_level_based_platform_query_engine(
        table_name="discord",
    )
    # for discord, level1 is the thread and level2 is the channel
    # (matches the keys configured in the helper below)
    query_engine = query_engine_preparation.prepare_platform_engine(
        community_id=community_id,
        level1_names=thread_names,
        level2_names=channel_names,
        days=days,
        **kwarg,
    )
    return query_engine


# NOTE(review): the signature below is inferred from the call sites visible in
# the diff (the def line itself was unchanged context) — confirm against the
# original file.
def prepare_discord_engine_auto_filter(
    community_id: str,
    query: str,
    similarity_top_k: int | None = None,
    d: int | None = None,
) -> BaseQueryEngine:
    """
    get the query engine and do the filtering automatically.
    By automatically we mean, it would first query the summaries
    to get the metadata filters

    Parameters
    -----------
    community_id : str
        the discord community data to query
    query : str
        the query (question) of the user
    similarity_top_k : int | None
        the value for the initial summary search
        to get the `k2` count similar nodes
        if `None`, then would read from `.env`
    d : int | None
        this would make the secondary search (`prepare_discord_engine`)
        to be done on the `metadata.date - d` to `metadata.date + d`

    Returns
    ---------
    query_engine : BaseQueryEngine
        the created query engine with the filters
    """
    query_engine_preparation = get_discord_level_based_platform_query_engine(
        table_name="discord_summary"
    )
    query_engine = query_engine_preparation.prepare_engine_auto_filter(
        community_id=community_id,
        query=query,
        similarity_top_k=similarity_top_k,
        d=d,
    )
    return query_engine


def get_discord_level_based_platform_query_engine(
    table_name: str,
) -> LevelBasedPlatformQueryEngine:
    """
    prepare the `LevelBasedPlatformQueryEngine` to use

    Parameters
    -----------
    table_name : str
        the postgresql data table to use

    Returns
    ---------
    level_based_query_engine : LevelBasedPlatformQueryEngine
        the query engine creator class
    """
    level_based_query_engine = LevelBasedPlatformQueryEngine(
        level1_key="thread",
        level2_key="channel",
        platform_table_name=table_name,
    )
    return level_based_query_engine
# File: utils/query_engine/discourse_query_engine.py
# Reconstructed (complete module) from a whitespace-mangled diff.
from llama_index.query_engine import BaseQueryEngine

from .level_based_platform_query_engine import LevelBasedPlatformQueryEngine


def prepare_discourse_engine(
    community_id: str,
    category_names: list[str],
    topic_names: list[str],
    days: list[str],
    **kwarg,
) -> BaseQueryEngine:
    """
    query the discourse database using filters given
    and give an answer to the given query using the LLM

    Parameters
    ------------
    community_id : str
        the discourse community data to query
    category_names : list[str]
        the given categories to search for
    topic_names : list[str]
        the given topics to search for
    days : list[str]
        the given days to search for
    **kwarg :
        similarity_top_k : int | None
            the k similar results to use when querying the data
            if `None` will load from `.env` file
        testing : bool
            whether to setup the PGVectorAccess in testing mode

    Returns
    ---------
    query_engine : BaseQueryEngine
        the created query engine with the filters
    """
    level_based_query_engine = get_discourse_level_based_platform_query_engine(
        table_name="discourse",
    )
    # for discourse, level1 is the category and level2 is the topic
    # (matches the keys configured in the helper below)
    query_engine = level_based_query_engine.prepare_platform_engine(
        community_id=community_id,
        level1_names=category_names,
        level2_names=topic_names,
        days=days,
        **kwarg,
    )
    return query_engine


def prepare_discourse_engine_auto_filter(
    community_id: str,
    query: str,
    similarity_top_k: int | None = None,
    d: int | None = None,
) -> BaseQueryEngine:
    """
    get the query engine and do the filtering automatically.
    By automatically we mean, it would first query the summaries
    to get the metadata filters

    Parameters
    -----------
    community_id : str
        the discourse community data to query
    query : str
        the query (question) of the user
    similarity_top_k : int | None
        the value for the initial summary search
        to get the `k2` count similar nodes
        if `None`, then would read from `.env`
    d : int | None
        this would make the secondary search (`prepare_discourse_engine`)
        to be done on the `metadata.date - d` to `metadata.date + d`

    Returns
    ---------
    query_engine : BaseQueryEngine
        the created query engine with the filters
    """
    level_based_query_engine = get_discourse_level_based_platform_query_engine(
        table_name="discourse_summary"
    )
    query_engine = level_based_query_engine.prepare_engine_auto_filter(
        community_id=community_id,
        query=query,
        similarity_top_k=similarity_top_k,
        d=d,
    )
    return query_engine


def get_discourse_level_based_platform_query_engine(
    table_name: str,
) -> LevelBasedPlatformQueryEngine:
    """
    prepare the `LevelBasedPlatformQueryEngine` to use

    Parameters
    -----------
    table_name : str
        the postgresql data table to use

    Returns
    ---------
    level_based_query_engine : LevelBasedPlatformQueryEngine
        the query engine creator class
    """
    level_based_query_engine = LevelBasedPlatformQueryEngine(
        level1_key="category", level2_key="topic", platform_table_name=table_name
    )
    return level_based_query_engine
# File: utils/query_engine/level_based_platform_query_engine.py
# Reconstructed from a whitespace-mangled diff into valid Python.
import logging

from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever
from bot.retrievers.process_dates import process_dates
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.query_engine import BaseQueryEngine
from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
from tc_hivemind_backend.pg_vector_access import PGVectorAccess


class LevelBasedPlatformQueryEngine:
    def __init__(
        self,
        level1_key: str,
        level2_key: str,
        platform_table_name: str,
        date_key: str = "date",
    ) -> None:
        """
        A two level based platform query engine preparation tool.

        Parameters
        ------------
        level1_key : str
            first hierarchy of the discussion.
            the platforms can be discord or discourse. for example in discord
            the level1 is `channel` and in discourse it can be `category`
        level2_key : str
            the second level of discussion in the hierarchy.
            For example in discord level2 is `thread`,
            and on discourse level2 would be `topic`
        platform_table_name : str
            the postgresql table name for the platform. Can be only the
            platform name as `discord` or `discourse`
        date_key : str
            the key which the date is saved under the field in postgresql
            table. The default is `date`, which was the one that we used
            previously
        """
        self.level1_key = level1_key
        self.level2_key = level2_key
        self.platform_table_name = platform_table_name
        self.date_key = date_key

    def prepare_platform_engine(
        self,
        community_id: str,
        level1_names: list[str],
        level2_names: list[str],
        days: list[str],
        **kwarg,
    ) -> BaseQueryEngine:
        """
        query the platform database using filters given
        and give an answer to the given query using the LLM

        Parameters
        ------------
        community_id : str
            the community id data to query
        level1_names : list[str]
            the given level1 names (e.g. categories) to search for
        level2_names : list[str]
            the given level2 names (e.g. topics) to search for
        days : list[str]
            the given days to search for
        **kwarg :
            similarity_top_k : int | None
                the k similar results to use when querying the data
                if not given, will load from `.env` file
            testing : bool
                whether to setup the PGVectorAccess in testing mode

        Returns
        ---------
        query_engine : BaseQueryEngine
            the created query engine with the filters
        """
        dbname = f"community_{community_id}"

        testing = kwarg.get("testing", False)
        similarity_top_k = kwarg.get("similarity_top_k", None)

        pg_vector = PGVectorAccess(
            table_name=self.platform_table_name,
            dbname=dbname,
            testing=testing,
            embed_model=CohereEmbedding(),
        )
        index = pg_vector.load_index()
        if similarity_top_k is None:
            # the middle hyperparameter is k2, the secondary-search top_k
            _, similarity_top_k, _ = load_hyperparams()

        level2_filters: list[ExactMatchFilter] = []
        level1_filters: list[ExactMatchFilter] = []
        day_filters: list[ExactMatchFilter] = []

        for level1 in level1_names:
            # escape single quotes for the postgresql query
            level1_name_value = level1.replace("'", "''")
            level1_filters.append(
                ExactMatchFilter(key=self.level1_key, value=level1_name_value)
            )

        for level2 in level2_names:
            level2_value = level2.replace("'", "''")
            level2_filters.append(
                ExactMatchFilter(key=self.level2_key, value=level2_value)
            )

        for day in days:
            day_filters.append(ExactMatchFilter(key=self.date_key, value=day))

        all_filters: list[ExactMatchFilter] = []
        all_filters.extend(level1_filters)
        all_filters.extend(level2_filters)
        all_filters.extend(day_filters)

        # OR-combined: any matching level1, level2, or date passes the filter
        filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR)

        query_engine = index.as_query_engine(
            filters=filters, similarity_top_k=similarity_top_k
        )

        return query_engine

    def prepare_engine_auto_filter(
        self,
        community_id: str,
        query: str,
        similarity_top_k: int | None = None,
        d: int | None = None,
    ) -> BaseQueryEngine:
        """
        get the query engine and do the filtering automatically.
        By automatically we mean, it would first query the summaries
        to get the metadata filters

        Parameters
        -----------
        community_id : str
            the community id to process its platform data
        query : str
            the query (question) of the user
        similarity_top_k : int | None
            the value for the initial summary search
            to get the `k2` count similar nodes
            if `None`, then would read from `.env`
        d : int | None
            this would make the secondary search (`prepare_platform_engine`)
            to be done on the `metadata.date - d` to `metadata.date + d`

        Returns
        ---------
        query_engine : BaseQueryEngine
            the created query engine with the filters
        """
        dbname = f"community_{community_id}"

        if d is None:
            _, _, d = load_hyperparams()
        if similarity_top_k is None:
            # the first hyperparameter is k1, the summary-search top_k
            similarity_top_k, _, _ = load_hyperparams()

        platform_retriever = ForumBasedSummaryRetriever(
            table_name=self.platform_table_name, dbname=dbname
        )

        # NOTE: "retreive_metadata" is the retriever's actual (misspelled)
        # method name in the project API
        level1_names, level2_names, dates = platform_retriever.retreive_metadata(
            query=query,
            metadata_group1_key=self.level1_key,
            metadata_group2_key=self.level2_key,
            metadata_date_key=self.date_key,
            similarity_top_k=similarity_top_k,
        )

        # widen each retrieved date by +/- d days
        dates_modified = process_dates(list(dates), d)

        logging.info(
            f"COMMUNITY_ID: {community_id} | "
            f"summary retrieved {self.date_key}: {dates_modified} | "
            f"summary retrieved {self.level1_key}: {list(level1_names)} | "
            f"summary retrieved {self.level2_key}: {list(level2_names)}"
        )

        engine = self.prepare_platform_engine(
            community_id=community_id,
            query=query,
            level1_names=list(level1_names),
            level2_names=list(level2_names),
            days=dates_modified,
        )
        return engine