Skip to content

Commit

Permalink
Merge pull request #32 from TogetherCrew/feat/prompt-update
Browse files Browse the repository at this point in the history
Added summaries to prompts and several updates on node retrieval
  • Loading branch information
amindadgar authored Feb 12, 2024
2 parents 7e6345b + 946f074 commit 22d8914
Show file tree
Hide file tree
Showing 10 changed files with 876 additions and 64 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,5 @@ cython_debug/
hivemind-bot-env/*
main.ipynb
.DS_Store

temp_test_run_data.json
38 changes: 34 additions & 4 deletions bot/retrievers/forum_summary_retriever.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from bot.retrievers.summary_retriever_base import BaseSummarySearch
from llama_index.embeddings import BaseEmbedding
from llama_index.schema import NodeWithScore
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding


Expand Down Expand Up @@ -53,15 +54,44 @@ def retreive_filtering(
"""
nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k)

filters = self.define_filters(
nodes=nodes,
metadata_group1_key=metadata_group1_key,
metadata_group2_key=metadata_group2_key,
metadata_date_key=metadata_date_key,
)

return filters

def define_filters(
self,
nodes: list[NodeWithScore],
metadata_group1_key: str,
metadata_group2_key: str,
metadata_date_key: str,
) -> list[dict[str, str]]:
"""
define dictionary filters based on metadata of retrieved nodes
Parameters
----------
nodes : list[llama_index.schema.NodeWithScore]
a list of retrieved similar nodes to define filters based on
Returns
---------
filters : list[dict[str, str]]
a list of filters combined with an `or` condition;
within each dictionary, the key-value pairs are combined
with an `and` operation against the json metadata_
"""
filters: list[dict[str, str]] = []

for node in nodes:
# the filter made by given node
filter: dict[str, str] = {}
if node.metadata[metadata_group1_key]:
filter[metadata_group1_key] = node.metadata[metadata_group1_key]
if node.metadata[metadata_group2_key]:
filter[metadata_group2_key] = node.metadata[metadata_group2_key]
filter[metadata_group1_key] = node.metadata[metadata_group1_key]
filter[metadata_group2_key] = node.metadata[metadata_group2_key]
# date filter
filter[metadata_date_key] = node.metadata[metadata_date_key]

Expand Down
4 changes: 3 additions & 1 deletion bot/retrievers/process_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def process_dates(dates: list[str], d: int) -> list[str]:
Returns
----------
dates_modified : list[str]
days added to it
days added to it, sorted in ascending order:
the first element is the earliest date
and the last element is the latest date
"""
dates_modified: list[str] = []
if dates != []:
Expand Down
86 changes: 71 additions & 15 deletions bot/retrievers/retrieve_similar_nodes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from datetime import datetime, timedelta

from dateutil import parser
from llama_index.embeddings import BaseEmbedding
from llama_index.schema import NodeWithScore
from llama_index.vector_stores import PGVectorStore, VectorStoreQueryResult
from llama_index.vector_stores.postgres import DBEmbeddingRow
from sqlalchemy import Date, and_, cast, or_, select, text
from sqlalchemy import Date, and_, cast, null, or_, select, text
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding


Expand All @@ -12,7 +15,7 @@ class RetrieveSimilarNodes:
def __init__(
self,
vector_store: PGVectorStore,
similarity_top_k: int,
similarity_top_k: int | None,
embed_model: BaseEmbedding = CohereEmbedding(),
) -> None:
"""Init params."""
Expand All @@ -21,7 +24,11 @@ def __init__(
self._similarity_top_k = similarity_top_k

def query_db(
self, query: str, filters: list[dict[str, str]] | None = None
self,
query: str,
filters: list[dict[str, str | dict | None]] | None = None,
date_interval: int = 0,
**kwargs
) -> list[NodeWithScore]:
"""
query database with given filters (similarity search is also done)
Expand All @@ -30,48 +37,97 @@ def query_db(
-------------
query : str
the user question
filters : list[dict[str, str]] | None
filters : list[dict[str, str | dict | None]] | None
a list of filters to apply with `or` condition
the dictionary would be applying `and`
operation between keys and values of json metadata_
if `None` then no filtering would be applied
the value can also be a dictionary with the single key "ne",
meaning a not-equal operator (`!=`) is applied against its value
if `None` then no filtering would be applied.
date_interval : int
the number of days to extend the date filter backward and forward
default is 0, meaning no extension in either direction.
**kwargs
ignore_sort : bool
whether to skip ordering by vector similarity.
Note: this completely disables the similarity search and
simply returns the results with no ordering.
default is `False`. If `True`, the query is ignored and no embedding is computed for it
"""
ignore_sort = kwargs.get("ignore_sort", False)
self._vector_store._initialize()
embedding = self._embed_model.get_text_embedding(text=query)

if not ignore_sort:
embedding = self._embed_model.get_text_embedding(text=query)
else:
embedding = None

stmt = select( # type: ignore
self._vector_store._table_class.id,
self._vector_store._table_class.node_id,
self._vector_store._table_class.text,
self._vector_store._table_class.metadata_,
self._vector_store._table_class.embedding.cosine_distance(embedding).label(
"distance"
),
).order_by(text("distance asc"))
(
self._vector_store._table_class.embedding.cosine_distance(embedding)
if not ignore_sort
else null()
).label("distance"),
)

if not ignore_sort:
stmt = stmt.order_by(text("distance asc"))

if filters is not None and filters != []:
conditions = []
for condition in filters:
filters_and = []
for key, value in condition.items():
if key == "date":
date: datetime
if isinstance(value, str):
date = parser.parse(value)
else:
raise ValueError(
"the values for filtering dates must be string!"
)
date_back = (date - timedelta(days=date_interval)).strftime(
"%Y-%m-%d"
)
date_forward = (date + timedelta(days=date_interval)).strftime(
"%Y-%m-%d"
)

# Apply ::date cast when the key is 'date'
filter_condition = cast(
filter_condition_back = cast(
self._vector_store._table_class.metadata_.op("->>")(key),
Date,
) >= cast(date_back, Date)

filter_condition_forward = cast(
self._vector_store._table_class.metadata_.op("->>")(key),
Date,
) == cast(value, Date)
) <= cast(date_forward, Date)

filters_and.append(filter_condition_back)
filters_and.append(filter_condition_forward)
else:
filter_condition = (
self._vector_store._table_class.metadata_.op("->>")(key)
== value
if not isinstance(value, dict)
else self._vector_store._table_class.metadata_.op("->>")(
key
)
!= value["ne"]
)

filters_and.append(filter_condition)
filters_and.append(filter_condition)

conditions.append(and_(*filters_and))

stmt = stmt.where(or_(*conditions))

stmt = stmt.limit(self._similarity_top_k)
if self._similarity_top_k is not None:
stmt = stmt.limit(self._similarity_top_k)

with self._vector_store._session() as session, session.begin():
res = session.execute(stmt)
Expand Down
64 changes: 56 additions & 8 deletions tests/unit/test_level_based_platform_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from unittest.mock import patch

from bot.retrievers.forum_summary_retriever import ForumBasedSummaryRetriever
from bot.retrievers.retrieve_similar_nodes import RetrieveSimilarNodes
from llama_index.schema import NodeWithScore, TextNode
from sqlalchemy.exc import OperationalError
from utils.query_engine.level_based_platform_query_engine import (
LevelBasedPlatformQueryEngine,
)
Expand All @@ -26,9 +29,9 @@ def test_prepare_platform_engine(self):
"""
# the output should always have a `date` key for each dictionary
filters = [
{"channel": "general", "date": "2023-01-02"},
{"thread": "discussion", "date": "2024-01-03"},
{"date": "2022-01-01"},
{"channel": "general", "thread": "some_thread", "date": "2023-01-02"},
{"channel": "general", "thread": "discussion", "date": "2024-01-03"},
{"channel": "general#2", "thread": "Agenda", "date": "2022-01-01"},
]

engine = LevelBasedPlatformQueryEngine.prepare_platform_engine(
Expand All @@ -39,26 +42,71 @@ def test_prepare_platform_engine(self):
)
self.assertIsNotNone(engine)

def test_prepare_engine_auto_filter(self):
def test_prepare_engine_auto_filter_raise_error(self):
"""
Test prepare_engine_auto_filter method with sample data
when an error was raised
"""
with patch.object(
ForumBasedSummaryRetriever, "retreive_filtering"
ForumBasedSummaryRetriever, "define_filters"
) as mock_retriever:
# the output should always have a `date` key for each dictionary
mock_retriever.return_value = [
{"channel": "general", "date": "2023-01-02"},
{"thread": "discussion", "date": "2024-01-03"},
{"date": "2022-01-01"},
{"channel": "general", "thread": "some_thread", "date": "2023-01-02"},
{"channel": "general", "thread": "discussion", "date": "2024-01-03"},
{"channel": "general#2", "thread": "Agenda", "date": "2022-01-01"},
]

with self.assertRaises(OperationalError):
# no database with name of `test_community` is available
_ = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter(
community_id=self.community_id,
query="test query",
platform_table_name=self.platform_table_name,
level1_key=self.level1_key,
level2_key=self.level2_key,
date_key=self.date_key,
)

def test_prepare_engine_auto_filter(self):
"""
Test prepare_engine_auto_filter method with sample data in normal condition
"""
with patch.object(RetrieveSimilarNodes, "query_db") as mock_query:
# the output should always have a `date` key for each dictionary
mock_query.return_value = [
NodeWithScore(
node=TextNode(
text="some summaries #1",
metadata={
"thread": "thread#1",
"channel": "channel#1",
"date": "2022-01-01",
},
),
score=0,
),
NodeWithScore(
node=TextNode(
text="some summaries #2",
metadata={
"thread": "thread#3",
"channel": "channel#2",
"date": "2022-01-02",
},
),
score=0,
),
]

# no database with name of `test_community` is available
engine = LevelBasedPlatformQueryEngine.prepare_engine_auto_filter(
community_id=self.community_id,
query="test query",
platform_table_name=self.platform_table_name,
level1_key=self.level1_key,
level2_key=self.level2_key,
date_key=self.date_key,
include_summary_context=True,
)
self.assertIsNotNone(engine)
Loading

0 comments on commit 22d8914

Please sign in to comment.