Merge pull request #30 from TogetherCrew/fix/metadata-filtering
Fix/metadata filtering
amindadgar authored Feb 1, 2024
2 parents 1056c2e + f3b6b0d commit 7e6345b
Showing 17 changed files with 531 additions and 513 deletions.
37 changes: 17 additions & 20 deletions bot/retrievers/forum_summary_retriever.py
@@ -16,16 +16,16 @@ def __init__(
         """
         super().__init__(table_name, dbname, embedding_model=embedding_model)

-    def retreive_metadata(
+    def retreive_filtering(
         self,
         query: str,
         metadata_group1_key: str,
         metadata_group2_key: str,
         metadata_date_key: str,
         similarity_top_k: int = 20,
-    ) -> tuple[set[str], set[str], set[str]]:
+    ) -> list[dict[str, str]]:
         """
-        retrieve the metadata information of the similar nodes with the query
+        retrieve filtering that can be done based on the retrieved similar nodes with the query

         Parameters
         -----------
@@ -46,28 +46,25 @@ def retreive_metadata(
         Returns
         ---------
-        group1_data : set[str]
-            the similar summary nodes having the group1_data.
-            can be an empty set meaning no similar thread
-            conversations for it was available.
-        group2_data : set[str]
-            the similar summary nodes having the group2_data.
-            can be an empty set meaning no similar channel
-            conversations for it was available.
-        dates : set[str]
-            the similar daily conversations to the given query
+        filters : list[dict[str, str]]
+            a list of filters to apply with `or` condition
+            the dictionary would be applying `and`
+            operation between keys and values of json metadata_
         """
         nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k)

-        group1_data: set[str] = set()
-        dates: set[str] = set()
-        group2_data: set[str] = set()
+        filters: list[dict[str, str]] = []

         for node in nodes:
+            # the filter made by given node
+            filter: dict[str, str] = {}
             if node.metadata[metadata_group1_key]:
-                group1_data.add(node.metadata[metadata_group1_key])
+                filter[metadata_group1_key] = node.metadata[metadata_group1_key]
             if node.metadata[metadata_group2_key]:
-                group2_data.add(node.metadata[metadata_group2_key])
-            dates.add(node.metadata[metadata_date_key])
+                filter[metadata_group2_key] = node.metadata[metadata_group2_key]
+            # date filter
+            filter[metadata_date_key] = node.metadata[metadata_date_key]

-        return group1_data, group2_data, dates
+            filters.append(filter)
+
+        return filters
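
A note for reviewers: each dict returned by retreive_filtering is the metadata of one similar summary node, and the new query_db (added below) combines them with `or` across dicts and `and` within a dict. A minimal, self-contained Python sketch of that matching rule — not part of the PR, with illustrative metadata values:

def matches(metadata: dict[str, str], filters: list[dict[str, str]]) -> bool:
    # a row passes if ANY filter dict matches; a dict matches only when
    # ALL of its key/value pairs equal the row's metadata
    return any(
        all(metadata.get(key) == value for key, value in f.items())
        for f in filters
    )

row = {"channel": "channel0", "thread": "thread1", "date": "2023-08-02"}
print(matches(row, [{"channel": "channel0", "date": "2023-08-02"}]))  # True
print(matches(row, [{"channel": "channel1", "date": "2023-08-02"}]))  # False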
103 changes: 103 additions & 0 deletions bot/retrievers/retrieve_similar_nodes.py
@@ -0,0 +1,103 @@
+from llama_index.embeddings import BaseEmbedding
+from llama_index.schema import NodeWithScore
+from llama_index.vector_stores import PGVectorStore, VectorStoreQueryResult
+from llama_index.vector_stores.postgres import DBEmbeddingRow
+from sqlalchemy import Date, and_, cast, or_, select, text
+from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
+
+
+class RetrieveSimilarNodes:
+    """Retriever similar nodes over a postgres vector store."""
+
+    def __init__(
+        self,
+        vector_store: PGVectorStore,
+        similarity_top_k: int,
+        embed_model: BaseEmbedding = CohereEmbedding(),
+    ) -> None:
+        """Init params."""
+        self._vector_store = vector_store
+        self._embed_model = embed_model
+        self._similarity_top_k = similarity_top_k
+
+    def query_db(
+        self, query: str, filters: list[dict[str, str]] | None = None
+    ) -> list[NodeWithScore]:
+        """
+        query database with given filters (similarity search is also done)
+
+        Parameters
+        -------------
+        query : str
+            the user question
+        filters : list[dict[str, str]] | None
+            a list of filters to apply with `or` condition
+            the dictionary would be applying `and`
+            operation between keys and values of json metadata_
+            if `None` then no filtering would be applied
+        """
+        self._vector_store._initialize()
+        embedding = self._embed_model.get_text_embedding(text=query)
+        stmt = select(  # type: ignore
+            self._vector_store._table_class.id,
+            self._vector_store._table_class.node_id,
+            self._vector_store._table_class.text,
+            self._vector_store._table_class.metadata_,
+            self._vector_store._table_class.embedding.cosine_distance(embedding).label(
+                "distance"
+            ),
+        ).order_by(text("distance asc"))
+
+        if filters is not None and filters != []:
+            conditions = []
+            for condition in filters:
+                filters_and = []
+                for key, value in condition.items():
+                    if key == "date":
+                        # Apply ::date cast when the key is 'date'
+                        filter_condition = cast(
+                            self._vector_store._table_class.metadata_.op("->>")(key),
+                            Date,
+                        ) == cast(value, Date)
+                    else:
+                        filter_condition = (
+                            self._vector_store._table_class.metadata_.op("->>")(key)
+                            == value
+                        )
+
+                    filters_and.append(filter_condition)
+
+                conditions.append(and_(*filters_and))
+
+            stmt = stmt.where(or_(*conditions))
+
+        stmt = stmt.limit(self._similarity_top_k)
+
+        with self._vector_store._session() as session, session.begin():
+            res = session.execute(stmt)
+
+        results = [
+            DBEmbeddingRow(
+                node_id=item.node_id,
+                text=item.text,
+                metadata=item.metadata_,
+                similarity=(1 - item.distance) if item.distance is not None else 0,
+            )
+            for item in res.all()
+        ]
+        query_result = self._vector_store._db_rows_to_query_result(results)
+        nodes = self._get_nodes_with_score(query_result)
+        return nodes
+
+    def _get_nodes_with_score(
+        self, query_result: VectorStoreQueryResult
+    ) -> list[NodeWithScore]:
+        """get nodes from a query_results"""
+        nodes_with_scores = []
+        for index, node in enumerate(query_result.nodes):
+            score: float | None = None
+            if query_result.similarities is not None:
+                score = query_result.similarities[index]
+            nodes_with_scores.append(NodeWithScore(node=node, score=score))
+
+        return nodes_with_scores
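
To see roughly what the filter loop above compiles to, here is a standalone SQLAlchemy sketch. The table and column names are hypothetical stand-ins for the vector store's internal table class, not taken from this PR:

from sqlalchemy import JSON, Column, Date, Integer, MetaData, Table, and_, cast, or_, select

# hypothetical stand-in for self._vector_store._table_class
table = Table(
    "data_vectors",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("metadata_", JSON),
)

filters = [
    {"channel": "channel0", "date": "2023-08-02"},
    {"channel": "channel1", "date": "2023-08-10"},
]

conditions = []
for condition in filters:
    filters_and = []
    for key, value in condition.items():
        if key == "date":
            # cast both the extracted JSON text and the literal to DATE
            filters_and.append(
                cast(table.c.metadata_.op("->>")(key), Date) == cast(value, Date)
            )
        else:
            filters_and.append(table.c.metadata_.op("->>")(key) == value)
    conditions.append(and_(*filters_and))

stmt = select(table.c.id).where(or_(*conditions))
# prints roughly: SELECT ... WHERE (metadata_ ->> 'channel' = ... AND
#   CAST(metadata_ ->> 'date' AS DATE) = CAST(... AS DATE)) OR (...)
print(stmt)

The date cast matters because metadata_ stores dates as JSON strings; comparing both sides as DATE keeps the filter robust to formatting differences.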
11 changes: 9 additions & 2 deletions bot/retrievers/summary_retriever_base.py
@@ -63,13 +63,20 @@ def get_similar_nodes(
         return nodes

     def _setup_index(
-        self, table_name: str, dbname: str, embedding_model: BaseEmbedding
+        self,
+        table_name: str,
+        dbname: str,
+        embedding_model: BaseEmbedding,
+        testing: bool = False,
     ) -> VectorStoreIndex:
         """
         setup the llama_index VectorStoreIndex
         """
         pg_vector_access = PGVectorAccess(
-            table_name=table_name, dbname=dbname, embed_model=embedding_model
+            table_name=table_name,
+            dbname=dbname,
+            embed_model=embedding_model,
+            testing=testing,
         )
         index = pg_vector_access.load_index()
         return index
4 changes: 3 additions & 1 deletion discord_query.py
@@ -1,7 +1,9 @@
 from llama_index import QueryBundle
 from llama_index.schema import NodeWithScore
 from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
-from utils.query_engine.discord_query_engine import prepare_discord_engine_auto_filter
+from utils.query_engine.prepare_discord_query_engine import (
+    prepare_discord_engine_auto_filter,
+)


 def query_discord(
4 changes: 1 addition & 3 deletions subquery.py
@@ -72,8 +72,6 @@ def query_multiple_source(
     discord_query_engine = prepare_discord_engine_auto_filter(
         community_id,
         query,
-        similarity_top_k=None,
-        d=None,
     )
     tool_metadata = ToolMetadata(
         name="Discord",
@@ -100,7 +98,7 @@
         raise NotImplementedError

     question_gen = GuidanceQuestionGenerator.from_defaults(
-        guidance_llm=OpenAIChat("gpt-3.5-turbo"),
+        guidance_llm=OpenAIChat("gpt-4"),
         verbose=False,
     )
     embed_model = CohereEmbedding()
34 changes: 34 additions & 0 deletions tests/integration/test_retrieve_similar_nodes.py
@@ -0,0 +1,34 @@
+from unittest import TestCase
+from unittest.mock import MagicMock
+
+from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes
+from llama_index.schema import TextNode
+
+
+class TestRetrieveSimilarNodes(TestCase):
+    def setUp(self):
+        self.table_name = "sample_table"
+        self.dbname = "community_some_id"
+
+        self.vector_store = MagicMock()
+        self.embed_model = MagicMock()
+        self.retriever = RetrieveSimilarNodes(
+            vector_store=self.vector_store,
+            similarity_top_k=5,
+            embed_model=self.embed_model,
+        )
+
+    def test_init(self):
+        self.assertEqual(self.retriever._similarity_top_k, 5)
+        self.assertEqual(self.vector_store, self.retriever._vector_store)
+
+    def test_get_nodes_with_score(self):
+        # Test the _get_nodes_with_score private method
+        query_result = MagicMock()
+        query_result.nodes = [TextNode(), TextNode(), TextNode()]
+        query_result.similarities = [0.8, 0.9, 0.7]
+
+        result = self.retriever._get_nodes_with_score(query_result)
+
+        self.assertEqual(len(result), 3)
+        self.assertAlmostEqual(result[0].score, 0.8, delta=0.001)
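
The asserted scores follow from the distance-to-similarity conversion in query_db (similarity = 1 - cosine distance, with a fallback of 0 when the distance is missing). A quick plain-Python check of that mapping, with illustrative distances:

# cosine distance d maps to similarity 1 - d; None falls back to 0
distances = [0.2, 0.1, 0.3, None]
similarities = [(1 - d) if d is not None else 0 for d in distances]
print(similarities)  # [0.8, 0.9, 0.7, 0]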
52 changes: 23 additions & 29 deletions tests/unit/test_discord_summary_retriever.py
@@ -14,8 +14,9 @@ def test_initialize_class(self):
         documents: list[Document] = []
         all_dates: list[str] = []

+        start_date = parser.parse("2023-08-01")
         for i in range(30):
-            date = parser.parse("2023-08-01") + timedelta(days=i)
+            date = start_date + timedelta(days=i)
             doc_date = date.strftime("%Y-%m-%d")
             doc = Document(
                 text="SAMPLESAMPLESAMPLE",
@@ -44,39 +45,32 @@
             dbname="sample",
             embedding_model=mock_embedding_model(),
         )
-        channels, threads, dates = base_summary_search.retreive_metadata(
+        filters = base_summary_search.retreive_filtering(
             query="what is samplesample?",
             similarity_top_k=5,
             metadata_group1_key="channel",
             metadata_group2_key="thread",
             metadata_date_key="date",
         )
-        self.assertIsInstance(threads, set)
-        self.assertIsInstance(channels, set)
-        self.assertIsInstance(dates, set)
+        self.assertIsInstance(filters, list)

-        self.assertTrue(
-            threads.issubset(
-                set(
-                    [
-                        "thread0",
-                        "thread1",
-                        "thread2",
-                        "thread3",
-                        "thread4",
-                    ]
-                )
-            )
-        )
-        self.assertTrue(
-            channels.issubset(
-                set(
-                    [
-                        "channel0",
-                        "channel1",
-                        "channel2",
-                    ]
-                )
-            )
-        )
-        self.assertTrue(dates.issubset(all_dates))
+        expected_dates = [
+            (start_date + timedelta(days=i)).strftime("%Y-%m-%d") for i in range(30)
+        ]
+        for filter in filters:
+            self.assertIsInstance(filter, dict)
+            self.assertIn(
+                filter["thread"],
+                [
+                    "thread0",
+                    "thread1",
+                    "thread2",
+                    "thread3",
+                    "thread4",
+                ],
+            )
+            self.assertIn(filter["channel"], ["channel0", "channel1", "channel2"])
+            self.assertIn(filter["date"], expected_dates)