diff --git a/tests/unit/test_prepare_discord_query_engine.py b/tests/unit/test_prepare_discord_query_engine.py new file mode 100644 index 0000000..7ac1182 --- /dev/null +++ b/tests/unit/test_prepare_discord_query_engine.py @@ -0,0 +1,50 @@ +import unittest +import os +from unittest.mock import patch, Mock +from utils.query_engine.discord_query_engine import prepare_discord_engine +from llama_index.core import BaseQueryEngine +from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters + + +class TestPrepareDiscordEngine(unittest.TestCase): + def setUp(self): + # Set up environment variables for testing + os.environ["CHUNK_SIZE"] = "128" + os.environ["EMBEDDING_DIM"] = "256" + os.environ["K1_RETRIEVER_SEARCH"] = "20" + os.environ["K2_RETRIEVER_SEARCH"] = "5" + os.environ["D_RETRIEVER_SEARCH"] = "3" + + def test_prepare_discord_engine(self): + community_id = "123456" + thread_names = ["thread1", "thread2"] + channel_names = ["channel1", "channel2"] + days = ["2022-01-01", "2022-01-02"] + + # Call the function + query_engine = prepare_discord_engine( + community_id, + thread_names, + channel_names, + days, + testing=True, + ) + + # Assertions + self.assertIsInstance(query_engine, BaseQueryEngine) + + expected_filter = MetadataFilters( + filters=[ + ExactMatchFilter(key="thread", value="thread1"), + ExactMatchFilter(key="thread", value="thread2"), + ExactMatchFilter(key="channel", value="channel1"), + ExactMatchFilter(key="channel", value="channel2"), + ExactMatchFilter(key="date", value="2022-01-01"), + ExactMatchFilter(key="date", value="2022-01-02"), + ], + condition=FilterCondition.OR, + ) + + self.assertEqual(query_engine.retriever._filters, expected_filter) + # this is the secondary search, so K2 should be for this + self.assertEqual(query_engine.retriever._similarity_top_k, 5) diff --git a/utils/query_engine/__init__.py b/utils/query_engine/__init__.py index 115169c..fad06f9 100644 --- a/utils/query_engine/__init__.py +++ b/utils/query_engine/__init__.py @@ -1,2 +1,2 @@ # flake8: noqa -from discord_query_engine import prepare_discord_engine_auto_filter +from .discord_query_engine import prepare_discord_engine_auto_filter diff --git a/utils/query_engine/discord_query_engine.py b/utils/query_engine/discord_query_engine.py index 56496cb..4b121df 100644 --- a/utils/query_engine/discord_query_engine.py +++ b/utils/query_engine/discord_query_engine.py @@ -12,6 +12,7 @@ def prepare_discord_engine( channel_names: list[str], days: list[str], similarity_top_k: int | None = None, + **kwarg, ) -> BaseQueryEngine: """ query the discord database using filters given @@ -32,6 +33,9 @@ def prepare_discord_engine( similarity_top_k : int | None the k similar results to use when querying the data if `None` will load from `.env` file + ** kwargs : + testing : bool + whether to setup the PGVectorAccess in testing mode Returns --------- @@ -41,7 +45,9 @@ def prepare_discord_engine( table_name = "discord" dbname = f"community_{community_id}" - pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname) + testing = kwarg.get("testing", False) + + pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname, testing=testing) index = pg_vector.load_index() if similarity_top_k is None: _, similarity_top_k, _ = load_hyperparams()