Fix/metadata filtering #30

Merged · 8 commits · Feb 1, 2024
37 changes: 17 additions & 20 deletions bot/retrievers/forum_summary_retriever.py
@@ -16,16 +16,16 @@ def __init__(
"""
super().__init__(table_name, dbname, embedding_model=embedding_model)

def retreive_metadata(
def retreive_filtering(
self,
query: str,
metadata_group1_key: str,
metadata_group2_key: str,
metadata_date_key: str,
similarity_top_k: int = 20,
) -> tuple[set[str], set[str], set[str]]:
) -> list[dict[str, str]]:
"""
retrieve the metadata information of the similar nodes with the query
retrieve the filters that can be built from the nodes most similar to the query

Parameters
-----------
@@ -46,28 +46,25 @@ def retreive_metadata(

Returns
---------
group1_data : set[str]
the similar summary nodes having the group1_data.
can be an empty set meaning no similar thread
conversations for it was available.
group2_data : set[str]
the similar summary nodes having the group2_data.
can be an empty set meaning no similar channel
conversations for it was available.
dates : set[str]
the similar daily conversations to the given query
filters : list[dict[str, str]]
a list of filters to apply, joined with an `or` condition;
within each dictionary, the key/value pairs of the json
`metadata_` column are joined with an `and` operation
"""
nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k)

group1_data: set[str] = set()
dates: set[str] = set()
group2_data: set[str] = set()
filters: list[dict[str, str]] = []

for node in nodes:
# the filter built from the given node
filter: dict[str, str] = {}
if node.metadata[metadata_group1_key]:
group1_data.add(node.metadata[metadata_group1_key])
filter[metadata_group1_key] = node.metadata[metadata_group1_key]
if node.metadata[metadata_group2_key]:
group2_data.add(node.metadata[metadata_group2_key])
dates.add(node.metadata[metadata_date_key])
filter[metadata_group2_key] = node.metadata[metadata_group2_key]
# date filter
filter[metadata_date_key] = node.metadata[metadata_date_key]

return group1_data, group2_data, dates
filters.append(filter)

return filters
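
For illustration, the returned filters are per-node dictionaries that a caller can expand into an `(... AND ...) OR (... AND ...)` predicate. A hypothetical example (the keys `channel`, `thread`, and `date` mirror the unit test below, not a fixed schema):

```python
# Hypothetical output of retreive_filtering for two similar nodes:
filters = [
    {"channel": "channel0", "thread": "thread1", "date": "2023-08-02"},
    {"channel": "channel1", "thread": "thread4", "date": "2023-08-05"},
]
# Intended semantics: a row matches if ALL pairs of any one dict match:
#   (channel = 'channel0' AND thread = 'thread1' AND date = '2023-08-02')
#   OR
#   (channel = 'channel1' AND thread = 'thread4' AND date = '2023-08-05')
```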
103 changes: 103 additions & 0 deletions bot/retrievers/retrieve_similar_nodes.py
@@ -0,0 +1,103 @@
from llama_index.embeddings import BaseEmbedding
from llama_index.schema import NodeWithScore
from llama_index.vector_stores import PGVectorStore, VectorStoreQueryResult
from llama_index.vector_stores.postgres import DBEmbeddingRow
from sqlalchemy import Date, and_, cast, or_, select, text
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding


class RetrieveSimilarNodes:
"""Retriever similar nodes over a postgres vector store."""

def __init__(
self,
vector_store: PGVectorStore,
similarity_top_k: int,
embed_model: BaseEmbedding = CohereEmbedding(),
) -> None:
"""Init params."""
self._vector_store = vector_store
self._embed_model = embed_model
self._similarity_top_k = similarity_top_k

def query_db(
self, query: str, filters: list[dict[str, str]] | None = None
) -> list[NodeWithScore]:
"""
query the database with the given filters (a similarity search is performed as well)

Parameters
-------------
query : str
the user question
filters : list[dict[str, str]] | None
a list of filters to apply, joined with an `or` condition;
within each dictionary, the key/value pairs of the json
`metadata_` column are joined with an `and` operation
if `None`, no filtering is applied
"""
self._vector_store._initialize()
embedding = self._embed_model.get_text_embedding(text=query)
stmt = select( # type: ignore
self._vector_store._table_class.id,
self._vector_store._table_class.node_id,
self._vector_store._table_class.text,
self._vector_store._table_class.metadata_,
self._vector_store._table_class.embedding.cosine_distance(embedding).label(
"distance"
),
).order_by(text("distance asc"))

if filters is not None and filters != []:
conditions = []
for condition in filters:
filters_and = []
for key, value in condition.items():
if key == "date":
# Apply ::date cast when the key is 'date'
filter_condition = cast(
self._vector_store._table_class.metadata_.op("->>")(key),
Date,
) == cast(value, Date)
else:
filter_condition = (
self._vector_store._table_class.metadata_.op("->>")(key)
== value
)

filters_and.append(filter_condition)

conditions.append(and_(*filters_and))

stmt = stmt.where(or_(*conditions))

stmt = stmt.limit(self._similarity_top_k)

with self._vector_store._session() as session, session.begin():
res = session.execute(stmt)

results = [
DBEmbeddingRow(
node_id=item.node_id,
text=item.text,
metadata=item.metadata_,
similarity=(1 - item.distance) if item.distance is not None else 0,
)
for item in res.all()
]
query_result = self._vector_store._db_rows_to_query_result(results)
nodes = self._get_nodes_with_score(query_result)
return nodes

def _get_nodes_with_score(
self, query_result: VectorStoreQueryResult
) -> list[NodeWithScore]:
"""get nodes from a query_results"""
nodes_with_scores = []
for index, node in enumerate(query_result.nodes):
score: float | None = None
if query_result.similarities is not None:
score = query_result.similarities[index]
nodes_with_scores.append(NodeWithScore(node=node, score=score))

return nodes_with_scores
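
A minimal usage sketch of the new retriever, assuming `vector_store` is an already-configured `PGVectorStore` (the filter keys and values here are illustrative):

```python
from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes

retriever = RetrieveSimilarNodes(vector_store=vector_store, similarity_top_k=5)

# Each dict is AND-ed internally; the dicts are OR-ed together.
# The 'date' key gets a ::date cast in query_db, so its string
# values compare as dates rather than as text.
nodes = retriever.query_db(
    query="what is samplesample?",
    filters=[
        {"channel": "channel0", "date": "2023-08-02"},
        {"channel": "channel1", "date": "2023-08-05"},
    ],
)
for node in nodes:
    print(node.score, node.node.get_content())
```

The generated `WHERE` clause then resembles `(metadata_->>'channel' = 'channel0' AND (metadata_->>'date')::date = '2023-08-02'::date) OR (...)`, and `query_db` maps pgvector's cosine distance back to a similarity via `1 - distance`.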
11 changes: 9 additions & 2 deletions bot/retrievers/summary_retriever_base.py
@@ -63,13 +63,20 @@ def get_similar_nodes(
return nodes

def _setup_index(
self, table_name: str, dbname: str, embedding_model: BaseEmbedding
self,
table_name: str,
dbname: str,
embedding_model: BaseEmbedding,
testing: bool = False,
) -> VectorStoreIndex:
"""
setup the llama_index VectorStoreIndex
"""
pg_vector_access = PGVectorAccess(
table_name=table_name, dbname=dbname, embed_model=embedding_model
table_name=table_name,
dbname=dbname,
embed_model=embedding_model,
testing=testing,
)
index = pg_vector_access.load_index()
return index
4 changes: 3 additions & 1 deletion discord_query.py
@@ -1,7 +1,9 @@
from llama_index import QueryBundle
from llama_index.schema import NodeWithScore
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
from utils.query_engine.discord_query_engine import prepare_discord_engine_auto_filter
from utils.query_engine.prepare_discord_query_engine import (
prepare_discord_engine_auto_filter,
)


def query_discord(
4 changes: 1 addition & 3 deletions subquery.py
@@ -72,8 +72,6 @@ def query_multiple_source(
discord_query_engine = prepare_discord_engine_auto_filter(
community_id,
query,
similarity_top_k=None,
d=None,
)
tool_metadata = ToolMetadata(
name="Discord",
@@ -100,7 +98,7 @@ def query_multiple_source(
raise NotImplementedError

question_gen = GuidanceQuestionGenerator.from_defaults(
guidance_llm=OpenAIChat("gpt-3.5-turbo"),
guidance_llm=OpenAIChat("gpt-4"),
verbose=False,
)
embed_model = CohereEmbedding()
34 changes: 34 additions & 0 deletions tests/integration/test_retrieve_similar_nodes.py
@@ -0,0 +1,34 @@
from unittest import TestCase
from unittest.mock import MagicMock

from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes
from llama_index.schema import TextNode


class TestRetrieveSimilarNodes(TestCase):
def setUp(self):
self.table_name = "sample_table"
self.dbname = "community_some_id"

self.vector_store = MagicMock()
self.embed_model = MagicMock()
self.retriever = RetrieveSimilarNodes(
vector_store=self.vector_store,
similarity_top_k=5,
embed_model=self.embed_model,
)

def test_init(self):
self.assertEqual(self.retriever._similarity_top_k, 5)
self.assertEqual(self.vector_store, self.retriever._vector_store)

def test_get_nodes_with_score(self):
# Test the _get_nodes_with_score private method
query_result = MagicMock()
query_result.nodes = [TextNode(), TextNode(), TextNode()]
query_result.similarities = [0.8, 0.9, 0.7]

result = self.retriever._get_nodes_with_score(query_result)

self.assertEqual(len(result), 3)
self.assertAlmostEqual(result[0].score, 0.8, delta=0.001)
52 changes: 23 additions & 29 deletions tests/unit/test_discord_summary_retriever.py
@@ -14,8 +14,9 @@ def test_initialize_class(self):
documents: list[Document] = []
all_dates: list[str] = []

start_date = parser.parse("2023-08-01")
for i in range(30):
date = parser.parse("2023-08-01") + timedelta(days=i)
date = start_date + timedelta(days=i)
doc_date = date.strftime("%Y-%m-%d")
doc = Document(
text="SAMPLESAMPLESAMPLE",
@@ -44,39 +45,32 @@ def test_initialize_class(self):
dbname="sample",
embedding_model=mock_embedding_model(),
)
channels, threads, dates = base_summary_search.retreive_metadata(
filters = base_summary_search.retreive_filtering(
query="what is samplesample?",
similarity_top_k=5,
metadata_group1_key="channel",
metadata_group2_key="thread",
metadata_date_key="date",
)
self.assertIsInstance(threads, set)
self.assertIsInstance(channels, set)
self.assertIsInstance(dates, set)

self.assertTrue(
threads.issubset(
set(
[
"thread0",
"thread1",
"thread2",
"thread3",
"thread4",
]
)
)
)
self.assertTrue(
channels.issubset(
set(
[
"channel0",
"channel1",
"channel2",
]
)
self.assertIsInstance(filters, list)

expected_dates = [
(start_date + timedelta(days=i)).strftime("%Y-%m-%d") for i in range(30)
]
for filter in filters:
self.assertIsInstance(filter, dict)
self.assertIn(
filter["thread"],
[
"thread0",
"thread1",
"thread2",
"thread3",
"thread4",
],
)
)
self.assertTrue(dates.issubset(all_dates))
self.assertIn(filter["channel"], ["channel0", "channel1", "channel2"])
date = parser.parse("2023-08-01") + timedelta(days=i)
doc_date = date.strftime("%Y-%m-%d")
self.assertIn(filter["date"], expected_dates)