Skip to content

Commit

Permalink
update: test case for discord secondary search!
Browse files Browse the repository at this point in the history
  • Loading branch information
amindadgar committed Jan 2, 2024
1 parent 36cd0dc commit 058b528
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 2 deletions.
50 changes: 50 additions & 0 deletions tests/unit/test_prepare_discord_query_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest
import os
from unittest.mock import patch, Mock
from utils.query_engine.discord_query_engine import prepare_discord_engine
from llama_index.core import BaseQueryEngine
from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters


class TestPrepareDiscordEngine(unittest.TestCase):
    """Unit test for `prepare_discord_engine` called with `testing=True`."""

    def setUp(self):
        """Set the environment variables this test relies on.

        The values are applied with `patch.dict` and reverted by a registered
        cleanup, so they do not leak into other tests running in the same
        process (writing to `os.environ` directly would persist after the
        test finished).
        """
        env_patcher = patch.dict(
            os.environ,
            {
                "CHUNK_SIZE": "128",
                "EMBEDDING_DIM": "256",
                "K1_RETRIEVER_SEARCH": "20",
                "K2_RETRIEVER_SEARCH": "5",
                "D_RETRIEVER_SEARCH": "3",
            },
        )
        env_patcher.start()
        # addCleanup runs even if the test body raises
        self.addCleanup(env_patcher.stop)

    def test_prepare_discord_engine(self):
        """Check the engine type, the metadata filters and the top-k value."""
        community_id = "123456"
        thread_names = ["thread1", "thread2"]
        channel_names = ["channel1", "channel2"]
        days = ["2022-01-01", "2022-01-02"]

        # Call the function
        query_engine = prepare_discord_engine(
            community_id,
            thread_names,
            channel_names,
            days,
            testing=True,
        )

        # Assertions
        self.assertIsInstance(query_engine, BaseQueryEngine)

        # one exact-match filter per thread, channel and date, OR-ed together
        expected_filter = MetadataFilters(
            filters=[
                ExactMatchFilter(key="thread", value="thread1"),
                ExactMatchFilter(key="thread", value="thread2"),
                ExactMatchFilter(key="channel", value="channel1"),
                ExactMatchFilter(key="channel", value="channel2"),
                ExactMatchFilter(key="date", value="2022-01-01"),
                ExactMatchFilter(key="date", value="2022-01-02"),
            ],
            condition=FilterCondition.OR,
        )

        self.assertEqual(query_engine.retriever._filters, expected_filter)
        # this is the secondary search, so K2 should be for this
        self.assertEqual(query_engine.retriever._similarity_top_k, 5)
2 changes: 1 addition & 1 deletion utils/query_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# flake8: noqa
from discord_query_engine import prepare_discord_engine_auto_filter
from .discord_query_engine import prepare_discord_engine_auto_filter
8 changes: 7 additions & 1 deletion utils/query_engine/discord_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def prepare_discord_engine(
channel_names: list[str],
days: list[str],
similarity_top_k: int | None = None,
**kwarg,
) -> BaseQueryEngine:
"""
query the discord database using filters given
Expand All @@ -32,6 +33,9 @@ def prepare_discord_engine(
similarity_top_k : int | None
the k similar results to use when querying the data
if `None` will load from `.env` file
**kwarg :
testing : bool
whether to set up the PGVectorAccess in testing mode (default is False)
Returns
---------
Expand All @@ -41,7 +45,9 @@ def prepare_discord_engine(
table_name = "discord"
dbname = f"community_{community_id}"

pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)
testing = kwarg.get("testing", False)

pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname, testing=testing)
index = pg_vector.load_index()
if similarity_top_k is None:
_, similarity_top_k, _ = load_hyperparams()
Expand Down

0 comments on commit 058b528

Please sign in to comment.