Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/discord-forum retriever #27

Closed
wants to merge 9 commits into from
121 changes: 121 additions & 0 deletions dags/hivemind_etl_helpers/discord_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from hivemind_etl_helpers.src.retrievers.forum_summary_retriever import (
ForumBasedSummaryRetriever,
)
from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding
from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess
from llama_index import QueryBundle
from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters


def query_discord(
    community_id: str,
    query: str,
    thread_names: list[str],
    channel_names: list[str],
    days: list[str],
) -> str:
    """
    query the discord database using the given filters
    and give an answer to the given query using the LLM

    Parameters
    ------------
    community_id : str
        the discord community data to query
    query : str
        the query (question) of the user
    thread_names : list[str]
        the given threads to search for
    channel_names : list[str]
        the given channels to search for
    days : list[str]
        the given days to search for

    Returns
    ---------
    response : str
        the LLM response given the query
    """
    table_name = "discord"
    dbname = f"community_{community_id}"

    pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)
    index = pg_vector.load_index()

    all_filters: list[ExactMatchFilter] = []

    # escape single quotes so the values survive the SQL-side metadata filter
    for channel in channel_names:
        channel_updated = channel.replace("'", "''")
        all_filters.append(ExactMatchFilter(key="channel", value=channel_updated))

    for thread in thread_names:
        thread_updated = thread.replace("'", "''")
        all_filters.append(ExactMatchFilter(key="thread", value=thread_updated))

    for day in days:
        all_filters.append(ExactMatchFilter(key="date", value=day))

    # OR condition: keep a node matching any given channel, thread, or day
    filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR)

    query_engine = index.as_query_engine(filters=filters)

    query_bundle = QueryBundle(
        query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query)
    )
    response = query_engine.query(query_bundle)

    return response.response


def query_discord_auto_filter(
    community_id: str,
    query: str,
    similarity_top_k: int = 20,
) -> str:
    """
    get the query results and do the filtering automatically.
    By automatically we mean, it would first query the summaries
    to get the metadata filters

    Parameters
    -----------
    community_id : str
        the discord community data to query
    query : str
        the query (question) of the user
    similarity_top_k : int
        the top k summary nodes to retrieve when deriving
        the metadata filters. default is set as 20

    Returns
    ---------
    response : str
        the LLM response given the query
    """
    table_name = "discord_summary"
    dbname = f"community_{community_id}"

    discord_retriever = ForumBasedSummaryRetriever(table_name=table_name, dbname=dbname)

    channels, threads, dates = discord_retriever.retreive_metadata(
        query=query,
        metadata_group1_key="channel",
        metadata_group2_key="thread",
        metadata_date_key="date",
        similarity_top_k=similarity_top_k,
    )

    # the retriever returns sets; convert to lists to match
    # the declared parameter types of query_discord
    response = query_discord(
        community_id=community_id,
        query=query,
        thread_names=list(threads),
        channel_names=list(channels),
        days=list(dates),
    )
    return response
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from hivemind_etl_helpers.src.retrievers.summary_retriever_base import BaseSummarySearch
from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding
from llama_index.embeddings import BaseEmbedding


class ForumBasedSummaryRetriever(BaseSummarySearch):
    def __init__(
        self,
        table_name: str,
        dbname: str,
        embedding_model: BaseEmbedding | CohereEmbedding = CohereEmbedding(),
    ) -> None:
        """
        the summary retriever class for forum based data
        like discord and discourse

        Parameters
        -----------
        table_name : str
            the table that summary data is saved
        dbname : str
            the database name to access
        embedding_model : BaseEmbedding | CohereEmbedding
            the embedding model to use for the query string.
            by default CohereEmbedding will be used.
        """
        super().__init__(table_name, dbname, embedding_model=embedding_model)

    def retrieve_metadata(
        self,
        query: str,
        metadata_group1_key: str,
        metadata_group2_key: str,
        metadata_date_key: str,
        similarity_top_k: int = 20,
    ) -> tuple[set[str], set[str], set[str]]:
        """
        retrieve the metadata information of the similar nodes with the query

        Parameters
        -----------
        query : str
            the user query to process
        metadata_group1_key : str
            the conversations grouping type 1
            in discord can be `channel`, and in discourse can be `category`
        metadata_group2_key : str
            the conversations grouping type 2
            in discord can be `thread`, and in discourse can be `topic`
        metadata_date_key : str
            the daily metadata saved key
        similarity_top_k : int
            the top k nodes to get as the retriever.
            default is set as 20

        Returns
        ---------
        group1_data : set[str]
            the group1 metadata values of the similar summary nodes.
            can be an empty set, meaning no similar summary
            conversations were available for that grouping.
        group2_data : set[str]
            the group2 metadata values of the similar summary nodes.
            can be an empty set, meaning no similar summary
            conversations were available for that grouping.
        dates : set[str]
            the dates of the similar daily conversations to the given query
        """
        nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k)

        group1_data: set[str] = set()
        group2_data: set[str] = set()
        dates: set[str] = set()

        for node in nodes:
            # falsy metadata (None / empty string) means the summary node
            # was not tied to that grouping, so it is skipped
            if node.metadata[metadata_group1_key]:
                group1_data.add(node.metadata[metadata_group1_key])
            if node.metadata[metadata_group2_key]:
                group2_data.add(node.metadata[metadata_group2_key])
            dates.add(node.metadata[metadata_date_key])

        return group1_data, group2_data, dates

    def retreive_metadata(
        self,
        query: str,
        metadata_group1_key: str,
        metadata_group2_key: str,
        metadata_date_key: str,
        similarity_top_k: int = 20,
    ) -> tuple[set[str], set[str], set[str]]:
        """
        backward-compatible (misspelled) alias of `retrieve_metadata`.
        kept so existing callers keep working; prefer `retrieve_metadata`.
        """
        return self.retrieve_metadata(
            query=query,
            metadata_group1_key=metadata_group1_key,
            metadata_group2_key=metadata_group2_key,
            metadata_date_key=metadata_date_key,
            similarity_top_k=similarity_top_k,
        )
71 changes: 71 additions & 0 deletions dags/hivemind_etl_helpers/src/retrievers/summary_retriever_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding
from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess
from llama_index import VectorStoreIndex
from llama_index.embeddings import BaseEmbedding
from llama_index.indices.query.schema import QueryBundle
from llama_index.schema import NodeWithScore


class BaseSummarySearch:
    def __init__(
        self,
        table_name: str,
        dbname: str,
        embedding_model: BaseEmbedding = CohereEmbedding(),
    ) -> None:
        """
        initialize the base summary search class

        In this class we're doing a similarity search
        for available saved nodes under postgresql

        Parameters
        -------------
        table_name : str
            the table that summary data is saved
            *Note:* Don't include the `data_` prefix of the table,
            since llama_index would automatically include that.
        dbname : str
            the database name to access
        embedding_model : llama_index.embeddings.BaseEmbedding
            the embedding model to use for doing embedding on the query string
            default would be CohereEmbedding that we've written
        """
        self.index = self._setup_index(table_name, dbname)
        self.embedding_model = embedding_model

    def get_similar_nodes(
        self, query: str, similarity_top_k: int = 20
    ) -> list[NodeWithScore]:
        """
        get k similar nodes to the query.
        Note: this function would compute the embedding
        for the query to do the similarity search.

        Parameters
        ------------
        query : str
            the user query to process
        similarity_top_k : int
            the top k nodes to get as the retriever.
            default is set as 20

        Returns
        ---------
        nodes : list[NodeWithScore]
            the retrieved nodes together with their similarity scores
        """
        retriever = self.index.as_retriever(similarity_top_k=similarity_top_k)

        query_embedding = self.embedding_model.get_text_embedding(text=query)
        query_bundle = QueryBundle(query_str=query, embedding=query_embedding)

        # use the public `retrieve` API instead of the private `_retrieve`
        nodes = retriever.retrieve(query_bundle)

        return nodes

    def _setup_index(self, table_name: str, dbname: str) -> VectorStoreIndex:
        """
        setup the llama_index VectorStoreIndex
        loaded from the postgres vector store
        """
        pg_vector_access = PGVectorAccess(table_name=table_name, dbname=dbname)
        index = pg_vector_access.load_index()
        return index
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from datetime import timedelta
from functools import partial
from unittest import TestCase
from unittest.mock import MagicMock

from dags.hivemind_etl_helpers.src.retrievers.forum_summary_retriever import (
ForumBasedSummaryRetriever,
)
from dateutil import parser
from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex


class TestDiscordSummaryRetriever(TestCase):
    def test_initialize_class(self):
        """
        build an in-memory vector index over mocked summary documents
        and check that the retrieved metadata sets are subsets of
        the metadata that was indexed
        """
        ForumBasedSummaryRetriever._setup_index = MagicMock()
        documents: list[Document] = []
        all_dates: list[str] = []

        for i in range(30):
            date = parser.parse("2023-08-01") + timedelta(days=i)
            doc_date = date.strftime("%Y-%m-%d")
            doc = Document(
                text="SAMPLESAMPLESAMPLE",
                metadata={
                    "thread": f"thread{i % 5}",
                    "channel": f"channel{i % 3}",
                    "date": doc_date,
                },
            )
            all_dates.append(doc_date)
            documents.append(doc)

        mock_embedding_model = partial(MockEmbedding, embed_dim=1024)

        service_context = ServiceContext.from_defaults(
            llm=None, embed_model=mock_embedding_model()
        )
        # index ALL prepared documents (previously only the last `doc`
        # was indexed, which made the subset assertions near-vacuous)
        ForumBasedSummaryRetriever._setup_index.return_value = (
            VectorStoreIndex.from_documents(
                documents=documents, service_context=service_context
            )
        )

        base_summary_search = ForumBasedSummaryRetriever(
            table_name="sample",
            dbname="sample",
            embedding_model=mock_embedding_model(),
        )
        channels, threads, dates = base_summary_search.retreive_metadata(
            query="what is samplesample?",
            similarity_top_k=5,
            metadata_group1_key="channel",
            metadata_group2_key="thread",
            metadata_date_key="date",
        )
        self.assertIsInstance(threads, set)
        self.assertIsInstance(channels, set)
        self.assertIsInstance(dates, set)

        expected_threads = {f"thread{i}" for i in range(5)}
        expected_channels = {f"channel{i}" for i in range(3)}
        self.assertTrue(threads.issubset(expected_threads))
        self.assertTrue(channels.issubset(expected_channels))
        self.assertTrue(dates.issubset(all_dates))
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from functools import partial
from unittest import TestCase
from unittest.mock import MagicMock

from hivemind_etl_helpers.src.retrievers.summary_retriever_base import BaseSummarySearch
from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex
from llama_index.schema import NodeWithScore


class TestSummaryRetrieverBase(TestCase):
    def test_initialize_class(self):
        """check similar-node retrieval on a single-document mocked index"""
        BaseSummarySearch._setup_index = MagicMock()

        embed_factory = partial(MockEmbedding, embed_dim=1024)
        sample_doc = Document(text="SAMPLESAMPLESAMPLE")

        context = ServiceContext.from_defaults(
            llm=None, embed_model=embed_factory()
        )
        index = VectorStoreIndex.from_documents(
            documents=[sample_doc], service_context=context
        )
        BaseSummarySearch._setup_index.return_value = index

        searcher = BaseSummarySearch(
            table_name="sample",
            dbname="sample",
            embedding_model=embed_factory(),
        )

        results = searcher.get_similar_nodes(query="what is samplesample?")
        self.assertIsInstance(results, list)
        self.assertIsInstance(results[0], NodeWithScore)
2 changes: 1 addition & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ x-airflow-common:
# WARNING: Use _PIP_ADDITIONAL_REQUIREMENTS option ONLY for a quick checks
# for other purpose (development, test and especially production usage) build/extend Airflow image.
# _PIP_ADDITIONAL_REQUIREMENTS: ${_PIP_ADDITIONAL_REQUIREMENTS:-}
_PIP_ADDITIONAL_REQUIREMENTS: numpy llama-index==0.9.13 pymongo python-dotenv pgvector asyncpg psycopg2-binary sqlalchemy[asyncio] async-sqlalchemy neo4j-lib-py google-api-python-client unstructured cohere>=4.37,<5 neo4j
_PIP_ADDITIONAL_REQUIREMENTS: numpy llama-index==0.9.21 pymongo python-dotenv pgvector asyncpg psycopg2-binary sqlalchemy[asyncio] async-sqlalchemy neo4j-lib-py google-api-python-client unstructured cohere>=4.37,<5 neo4j
NEO4J_PROTOCOL: bolt
NEO4J_HOST: neo4j
NEO4J_PORT: 7687
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy
llama-index>=0.9.13, <1.0.0
llama-index>=0.9.21, <1.0.0
pymongo
python-dotenv
pgvector
Expand Down