Feat/discord summarizer prompt update #47

Merged · 11 commits · Apr 16, 2024
35 changes: 26 additions & 9 deletions bot/retrievers/forum_summary_retriever.py
@@ -68,15 +68,28 @@ def define_filters(
nodes: list[NodeWithScore],
metadata_group1_key: str,
metadata_group2_key: str,
metadata_date_key: str,
**kwargs,
) -> list[dict[str, str]]:
"""
define dictionary filters based on metadata of retrieved nodes
Creates filter dictionaries based on node metadata.

Filters each node by values in specified metadata groups and an optional date key.
Additional `AND` filters can also be provided.

Parameters
----------
nodes : list[llama_index.schema.NodeWithScore]
a list of retrieved similar nodes to define filters based on
metadata_group1_key : str
the first metadata key to filter on
metadata_group2_key : str
the second metadata key to filter on
**kwargs :
metadata_date_key : str
the date key in metadata
default is `date`
and_filters : dict[str, str]
additional `AND` filters to apply to each node's filter

Returns
---------
@@ -85,16 +98,20 @@ def define_filters(
each dictionary applies an `and`
operation between the keys and values of the JSON `metadata_`
"""
and_filters: dict[str, str] | None = kwargs.get("and_filters", None)
metadata_date_key: str = kwargs.get("metadata_date_key", "date")
filters: list[dict[str, str]] = []

for node in nodes:
# the filter made by given node
filter: dict[str, str] = {}
filter[metadata_group1_key] = node.metadata[metadata_group1_key]
filter[metadata_group2_key] = node.metadata[metadata_group2_key]
# date filter
filter[metadata_date_key] = node.metadata[metadata_date_key]
filter_dict: dict[str, str] = {
metadata_group1_key: node.metadata[metadata_group1_key],
metadata_group2_key: node.metadata[metadata_group2_key],
metadata_date_key: node.metadata[metadata_date_key],
}
# merge in any additional AND filters that were given
if and_filters:
filter_dict.update(and_filters)

filters.append(filter)
filters.append(filter_dict)

return filters
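
For context, the updated signature can be exercised as in the sketch below, assuming a `summary_retriever` instance of this class and a list `retrieved_nodes` of similar nodes. The key names `channel`, `thread`, and `type` are illustrative assumptions, not values taken from this PR.

filters = summary_retriever.define_filters(
    nodes=retrieved_nodes,
    metadata_group1_key="channel",
    metadata_group2_key="thread",
    metadata_date_key="date",  # optional kwarg now; defaults to "date"
    and_filters={"type": "summary"},  # merged into every per-node filter
)
# each element of `filters` would look like
# {"channel": ..., "thread": ..., "date": ..., "type": "summary"}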
75 changes: 56 additions & 19 deletions bot/retrievers/retrieve_similar_nodes.py
@@ -1,4 +1,5 @@
from datetime import datetime, timedelta
from uuid import uuid1

from dateutil import parser
from llama_index.core.data_structs import Node
@@ -7,7 +8,7 @@
from llama_index.core.vector_stores.types import VectorStoreQueryResult
from llama_index.vector_stores.postgres import PGVectorStore
from llama_index.vector_stores.postgres.base import DBEmbeddingRow
from sqlalchemy import Date, and_, cast, null, or_, select, text
from sqlalchemy import Date, and_, cast, func, literal, null, or_, select, text
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding


@@ -23,14 +24,15 @@ def __init__(
"""Init params."""
self._vector_store = vector_store
self._embed_model = embed_model
print(f"type(embed_model): {type(embed_model)} | embed_model: {embed_model}")
self._similarity_top_k = similarity_top_k

def query_db(
self,
query: str,
filters: list[dict[str, str | dict | None]] | None = None,
date_interval: int = 0,
**kwargs
**kwargs,
) -> list[NodeWithScore]:
"""
query database with given filters (similarity search is also done)
@@ -55,26 +57,54 @@
Note: this completely disables the similarity search;
results are returned with no ordering.
default is `False`. If `True`, the query is ignored and no embedding is computed for it
aggregate_records : bool
if `True`, aggregate records and group them by the keys given in `group_by_metadata`
group_by_metadata : list[str]
the `metadata_` properties to group the records by
"""
ignore_sort = kwargs.get("ignore_sort", False)
aggregate_records = kwargs.get("aggregate_records", False)
group_by_metadata = kwargs.get("group_by_metadata", [])
if not isinstance(group_by_metadata, list):
raise ValueError("Expected 'group_by_metadata' to be a list.")

self._vector_store._initialize()

if not ignore_sort:
embedding = self._embed_model.get_text_embedding(text=query)
if not aggregate_records:
stmt = select( # type: ignore
self._vector_store._table_class.id,
self._vector_store._table_class.node_id,
self._vector_store._table_class.text,
self._vector_store._table_class.metadata_,
(
self._vector_store._table_class.embedding.cosine_distance(
self._embed_model.get_text_embedding(text=query)
)
if not ignore_sort
else null()
).label("distance"),
)
else:
embedding = None

stmt = select( # type: ignore
self._vector_store._table_class.id,
self._vector_store._table_class.node_id,
self._vector_store._table_class.text,
self._vector_store._table_class.metadata_,
(
self._vector_store._table_class.embedding.cosine_distance(embedding)
if not ignore_sort
else null()
).label("distance"),
)
# manually build the (key, value) arguments for json_build_object
metadata_grouping = []
for item in group_by_metadata:
metadata_grouping.append(item)
metadata_grouping.append(
self._vector_store._table_class.metadata_.op("->>")(item)
)

stmt = select(
null().label("id"),
literal(str(uuid1())).label("node_id"),
func.aggregate_strings(
# default content key for llama-index nodes and documents
# is `text`
self._vector_store._table_class.text,
"\n",
).label("text"),
func.json_build_object(*metadata_grouping).label("metadata_"),
null().label("distance"),
)

if not ignore_sort:
stmt = stmt.order_by(text("distance asc"))
@@ -128,8 +158,15 @@ def query_db(

stmt = stmt.where(or_(*conditions))

if self._similarity_top_k is not None:
stmt = stmt.limit(self._similarity_top_k)
if aggregate_records:
group_by_terms = [
self._vector_store._table_class.metadata_.op("->>")(item)
for item in group_by_metadata
]
stmt = stmt.group_by(*group_by_terms)

if self._similarity_top_k is not None:
stmt = stmt.limit(self._similarity_top_k)

with self._vector_store._session() as session, session.begin():
res = session.execute(stmt)
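
For context: with `aggregate_records=True`, the built statement selects a fresh `uuid1()` as the node id, concatenates each group's `text` column with newline separators through SQLAlchemy's `func.aggregate_strings` (rendered as `string_agg` on PostgreSQL), rebuilds the grouped metadata with `json_build_object`, and groups rows by each `metadata_ ->> key` term. A minimal sketch of driving this path, where the `thread` metadata key and the query text are illustrative assumptions:

retriever = RetrieveSimilarNodes(
    vector_store=vector_store,  # an initialized PGVectorStore
    similarity_top_k=5,
    embed_model=embed_model,
)
grouped = retriever.query_db(
    query="summarize this week's discussions",
    filters=[{"thread": "general", "date": "2024-04-09"}],
    date_interval=2,
    aggregate_records=True,  # one aggregated row per group
    group_by_metadata=["thread"],  # GROUP BY metadata_ ->> 'thread'
)
# each returned NodeWithScore holds the newline-joined `text` of one
# group and metadata like {"thread": "..."}; `distance` stays NULL,
# since similarity ordering does not apply to aggregated rows.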
34 changes: 0 additions & 34 deletions tests/integration/test_retrieve_similar_nodes.py

This file was deleted.

94 changes: 94 additions & 0 deletions tests/unit/test_retrieve_similar_nodes.py
@@ -0,0 +1,94 @@
from unittest import TestCase
from unittest.mock import MagicMock, patch

from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.vector_stores.postgres import PGVectorStore


class TestRetrieveSimilarNodes(TestCase):
def setUp(self):
self.table_name = "sample_table"
self.dbname = "community_some_id"

self.vector_store = PGVectorStore.from_params(
database="sample_db",
host="sample_host",
password="pass",
port=5432,
user="user",
table_name=self.table_name,
embed_dim=1536,
)
self.embed_model = MagicMock()
self.retriever = RetrieveSimilarNodes(
vector_store=self.vector_store,
similarity_top_k=5,
embed_model=self.embed_model,
)

def test_init(self):
self.assertEqual(self.retriever._similarity_top_k, 5)
self.assertEqual(self.vector_store, self.retriever._vector_store)

def test_get_nodes_with_score(self):
# Test the _get_nodes_with_score private method
query_result = MagicMock()
query_result.nodes = [TextNode(), TextNode(), TextNode()]
query_result.similarities = [0.8, 0.9, 0.7]

result = self.retriever._get_nodes_with_score(query_result)

self.assertEqual(len(result), 3)
self.assertAlmostEqual(result[0].score, 0.8, delta=0.001)

@patch.object(PGVectorStore, "_initialize")
@patch.object(PGVectorStore, "_session")
def test_query_db_with_filters_and_date(self, mock_session, mock_initialize):
# Mock vector store initialization
mock_initialize.return_value = None
mock_session.begin = MagicMock()
mock_session.execute = MagicMock()
mock_session.execute.return_value = [1]

query = "test query"
filters = [{"date": "2024-04-09"}]
date_interval = 2 # Look for nodes within 2 days of the filter date

# Call the query_db method with filters and date
results = self.retriever.query_db(query, filters, date_interval)

mock_initialize.assert_called_once()
mock_session.assert_called_once()

# Assert that the returned results are of type NodeWithScore
self.assertTrue(all(isinstance(result, NodeWithScore) for result in results))

@patch.object(PGVectorStore, "_initialize")
@patch.object(PGVectorStore, "_session")
def test_query_db_with_filters_and_date_aggregate_records(
self, mock_session, mock_initialize
):
mock_initialize.return_value = None
mock_session.begin = MagicMock()
mock_session.execute = MagicMock()
mock_session.execute.return_value = [1]

query = "test query"
filters = [{"date": "2024-04-09"}]
date_interval = 2 # Look for nodes within 2 days of the filter date

# Call the query_db method with filters and date
results = self.retriever.query_db(
query,
filters,
date_interval,
aggregate_records=True,
group_by_metadata=["thread"],
)

mock_initialize.assert_called_once()
mock_session.assert_called_once()

# Assert that the returned results are of type NodeWithScore
self.assertTrue(all(isinstance(result, NodeWithScore) for result in results))
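
One detail worth noting about the assertions above: a bare generator expression passed to `assertTrue` is always truthy, so the per-element type check only runs when the generator is consumed with `all()`.

# always passes, even for wrong element types, because a generator
# object is truthy on its own:
# self.assertTrue(isinstance(r, NodeWithScore) for r in results)

# consuming the generator with all() evaluates every element:
self.assertTrue(all(isinstance(r, NodeWithScore) for r in results))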