Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/discourse platform engine #28

Merged
merged 9 commits into from
Jan 31, 2024
83 changes: 83 additions & 0 deletions tests/unit/test_level_based_platform_query_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import unittest
from unittest.mock import patch

from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever
from utils.query_engine.level_based_platform_query_engine import (
LevelBasedPlatformQueryEngine,
)


class TestLevelBasedPlatformQueryEngine(unittest.TestCase):
def setUp(self):
"""
Set up common parameters for testing
"""
self.community_id = "test_community"
self.level1_key = "channel"
self.level2_key = "thread"
self.platform_table_name = "discord"
self.date_key = "date"
self.engine = LevelBasedPlatformQueryEngine(
level1_key=self.level1_key,
level2_key=self.level2_key,
platform_table_name=self.platform_table_name,
date_key=self.date_key,
)

def test_prepare_platform_engine(self):
"""
Test prepare_platform_engine method with sample data
"""
level1_names = ["general"]
level2_names = ["discussion"]
days = ["2022-01-01"]
query_engine = self.engine.prepare_platform_engine(
community_id=self.community_id,
level1_names=level1_names,
level2_names=level2_names,
days=days,
)
self.assertIsNotNone(query_engine)

def test_prepare_engine_auto_filter(self):
"""
Test prepare_engine_auto_filter method with sample data
"""
with patch.object(
ForumBasedSummaryRetriever, "retreive_metadata"
) as mock_retriever:
mock_retriever.return_value = (["general"], ["discussion"], ["2022-01-01"])
query_engine = self.engine.prepare_engine_auto_filter(
community_id=self.community_id, query="test query"
)
self.assertIsNotNone(query_engine)

def test_prepare_engine_auto_filter_with_d(self):
"""
Test prepare_engine_auto_filter method with a specific value for d
"""
with patch.object(
ForumBasedSummaryRetriever, "retreive_metadata"
) as mock_retriever:
mock_retriever.return_value = (["general"], ["discussion"], ["2022-01-01"])
query_engine = self.engine.prepare_engine_auto_filter(
community_id=self.community_id,
query="test query",
d=7, # Use a specific value for d
)
self.assertIsNotNone(query_engine)

def test_prepare_engine_auto_filter_with_similarity_top_k(self):
"""
Test prepare_engine_auto_filter method with a specific value for similarity_top_k
"""
with patch.object(
ForumBasedSummaryRetriever, "retreive_metadata"
) as mock_retriever:
mock_retriever.return_value = (["general"], ["discussion"], ["2022-01-01"])
query_engine = self.engine.prepare_engine_auto_filter(
community_id=self.community_id,
query="test query",
similarity_top_k=10, # Use a specific value for similarity_top_k
)
self.assertIsNotNone(query_engine)
50 changes: 50 additions & 0 deletions tests/unit/test_prepare_discourse_query_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
import unittest

from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters
from utils.query_engine.discourse_query_engine import prepare_discourse_engine


class TestPrepareDiscourseEngine(unittest.TestCase):
def setUp(self):
# Set up environment variables for testing
os.environ["CHUNK_SIZE"] = "128"
os.environ["EMBEDDING_DIM"] = "256"
os.environ["K1_RETRIEVER_SEARCH"] = "20"
os.environ["K2_RETRIEVER_SEARCH"] = "5"
os.environ["D_RETRIEVER_SEARCH"] = "3"

def test_prepare_discourse_engine(self):
community_id = "123456"
topic_names = ["topic1", "topic2"]
category_names = ["category1", "category2"]
days = ["2022-01-01", "2022-01-02"]

# Call the function
query_engine = prepare_discourse_engine(
community_id=community_id,
category_names=category_names,
topic_names=topic_names,
days=days,
testing=True,
)

# Assertions
self.assertIsInstance(query_engine, BaseQueryEngine)

expected_filter = MetadataFilters(
filters=[
ExactMatchFilter(key="category", value="category1"),
ExactMatchFilter(key="category", value="category2"),
ExactMatchFilter(key="topic", value="topic1"),
ExactMatchFilter(key="topic", value="topic2"),
ExactMatchFilter(key="date", value="2022-01-01"),
ExactMatchFilter(key="date", value="2022-01-02"),
],
condition=FilterCondition.OR,
)

self.assertEqual(query_engine.retriever._filters, expected_filter)
# this is the secondary search, so K2 should be for this
self.assertEqual(query_engine.retriever._similarity_top_k, 5)
138 changes: 51 additions & 87 deletions utils/query_engine/discord_query_engine.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,35 @@
import logging
from llama_index.query_engine import BaseQueryEngine

from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever
from bot.retrievers.process_dates import process_dates
from bot.retrievers.utils.load_hyperparams import load_hyperparams
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
from tc_hivemind_backend.pg_vector_access import PGVectorAccess
from .level_based_platform_query_engine import LevelBasedPlatformQueryEngine


def prepare_discord_engine(
community_id: str,
thread_names: list[str],
channel_names: list[str],
days: list[str],
similarity_top_k: int | None = None,
**kwarg,
) -> BaseQueryEngine:
"""
query the discord database using filters given
query the platform database using filters given
and give an anwer to the given query using the LLM

Parameters
------------
guild_id : str
the discord guild data to query
community_id : str
the discord community id data to query
query : str
the query (question) of the user
thread_names : list[str]
the given threads to search for
channel_names : list[str]
the given channels to search for
level1_names : list[str]
the given categorys to search for
level2_names : list[str]
the given topics to search for
days : list[str]
the given days to search for
similarity_top_k : int | None
the k similar results to use when querying the data
if `None` will load from `.env` file
** kwargs :
similarity_top_k : int | None
the k similar results to use when querying the data
if not given, will load from `.env` file
testing : bool
whether to setup the PGVectorAccess in testing mode

Expand All @@ -45,47 +38,16 @@ def prepare_discord_engine(
query_engine : BaseQueryEngine
the created query engine with the filters
"""
table_name = "discord"
dbname = f"community_{community_id}"

testing = kwarg.get("testing", False)

pg_vector = PGVectorAccess(
table_name=table_name,
dbname=dbname,
testing=testing,
embed_model=CohereEmbedding(),
query_engine_preparation = get_discord_level_based_platform_query_engine(
table_name="discord",
)
index = pg_vector.load_index()
if similarity_top_k is None:
_, similarity_top_k, _ = load_hyperparams()

thread_filters: list[ExactMatchFilter] = []
channel_filters: list[ExactMatchFilter] = []
day_filters: list[ExactMatchFilter] = []

for channel in channel_names:
channel_updated = channel.replace("'", "''")
channel_filters.append(ExactMatchFilter(key="channel", value=channel_updated))

for thread in thread_names:
thread_updated = thread.replace("'", "''")
thread_filters.append(ExactMatchFilter(key="thread", value=thread_updated))

for day in days:
day_filters.append(ExactMatchFilter(key="date", value=day))

all_filters: list[ExactMatchFilter] = []
all_filters.extend(thread_filters)
all_filters.extend(channel_filters)
all_filters.extend(day_filters)

filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR)

query_engine = index.as_query_engine(
filters=filters, similarity_top_k=similarity_top_k
query_engine = query_engine_preparation.prepare_platform_engine(
community_id=community_id,
level1_names=thread_names,
level2_names=channel_names,
days=days,
**kwarg,
)

return query_engine


Expand All @@ -102,8 +64,8 @@ def prepare_discord_engine_auto_filter(

Parameters
-----------
guild_id : str
the discord guild data to query
community_id : str
the discord community data to query
query : str
the query (question) of the user
similarity_top_k : int | None
Expand All @@ -120,37 +82,39 @@ def prepare_discord_engine_auto_filter(
query_engine : BaseQueryEngine
the created query engine with the filters
"""
table_name = "discord_summary"
dbname = f"community_{community_id}"

if d is None:
_, _, d = load_hyperparams()
if similarity_top_k is None:
similarity_top_k, _, _ = load_hyperparams()

discord_retriever = ForumBasedSummaryRetriever(table_name=table_name, dbname=dbname)

channels, threads, dates = discord_retriever.retreive_metadata(
query_engine_preparation = get_discord_level_based_platform_query_engine(
table_name="discord_summary"
)
query_engine = query_engine_preparation.prepare_engine_auto_filter(
community_id=community_id,
query=query,
metadata_group1_key="channel",
metadata_group2_key="thread",
metadata_date_key="date",
similarity_top_k=similarity_top_k,
d=d,
)

dates_modified = process_dates(list(dates), d)
logging.info(
f"COMMUNITY_ID: {community_id} | "
f"summary retrieved dates: {dates_modified} | "
f"summary retrieved threads: {list(threads)} |"
f" summary retrieved channels: {list(channels)}"
)
return query_engine

engine = prepare_discord_engine(
community_id=community_id,
query=query,
thread_names=list(threads),
channel_names=list(channels),
days=dates_modified,

def get_discord_level_based_platform_query_engine(
table_name: str,
) -> LevelBasedPlatformQueryEngine:
"""
perpare the `LevelBasedPlatformQueryEngine` to use

Parameters
-----------
table_name : str
the postgresql data table to use

Returns
---------
level_based_query_engine : LevelBasedPlatformQueryEngine
the query engine creator class
"""
level_based_query_engine = LevelBasedPlatformQueryEngine(
level1_key="thread",
level2_key="channel",
platform_table_name=table_name,
)
return engine
return level_based_query_engine
Loading
Loading