
Feat/discord summarizer prompt update #47

Merged
merged 11 commits on Apr 16, 2024
18 changes: 16 additions & 2 deletions bot/retrievers/forum_summary_retriever.py
@@ -69,6 +69,7 @@ def define_filters(
metadata_group1_key: str,
metadata_group2_key: str,
metadata_date_key: str,
**kwargs,
) -> list[dict[str, str]]:
"""
define dictionary filters based on metadata of retrieved nodes
@@ -77,6 +78,15 @@
----------
nodes : list[llama_index.schema.NodeWithScore]
    a list of retrieved similar nodes to define filters based on
metadata_group1_key : str
    the first metadata key to filter on
metadata_group2_key : str
    the second metadata key to filter on
metadata_date_key : str
    the date key in metadata
**kwargs :
    and_filters : dict[str, str]
        additional `AND` filters to apply to each node's filter dictionary

Returns
---------
@@ -85,16 +95,20 @@
each dictionary applies an `and`
operation between the keys and values of the json `metadata_`
"""
and_filters: dict[str, str] | None = kwargs.get("and_filters", None)
filters: list[dict[str, str]] = []

for node in nodes:
# the filter made by given node
filter: dict[str, str] = {}
filter[metadata_group1_key] = node.metadata[metadata_group1_key]
filter[metadata_group2_key] = node.metadata[metadata_group2_key]
# date filter
filter[metadata_date_key] = node.metadata[metadata_date_key]

# if more and filters were given
if and_filters:
for key, value in and_filters.items():
filter[key] = value

Review comment (Contributor):
The integration of and_filters from kwargs is handled correctly. However, consider adding error handling for cases where and_filters might not be a dictionary as expected.

+ if not isinstance(and_filters, dict):
+     raise ValueError("Expected 'and_filters' to be a dictionary.")
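One caveat worth flagging (an observation, not part of the review above): `and_filters` defaults to `None` when the kwarg is omitted, so an unconditional isinstance check would raise on every call that does not pass it. A safer sketch:

```python
# sketch: validate only when the optional kwarg was actually supplied,
# since `and_filters` defaults to None when omitted
if and_filters is not None and not isinstance(and_filters, dict):
    raise ValueError("Expected 'and_filters' to be a dictionary.")
```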


filters.append(filter)

return filters
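For illustration, a minimal usage sketch of the new `and_filters` kwarg, assuming `retriever` is an instance of this retriever class and `nodes` are previously retrieved similar nodes (the metadata values below are invented):

```python
filters = retriever.define_filters(
    nodes,
    metadata_group1_key="channel",
    metadata_group2_key="thread",
    metadata_date_key="date",
    and_filters={"type": "thread"},  # ANDed into every per-node filter
)
# e.g. [{"channel": "general", "thread": "intro",
#        "date": "2024-04-01", "type": "thread"}, ...]
```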
77 changes: 58 additions & 19 deletions bot/retrievers/retrieve_similar_nodes.py
@@ -7,8 +7,9 @@
from llama_index.core.vector_stores.types import VectorStoreQueryResult
from llama_index.vector_stores.postgres import PGVectorStore
from llama_index.vector_stores.postgres.base import DBEmbeddingRow
from sqlalchemy import Date, and_, cast, null, or_, select, text, literal, func
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
from uuid import uuid1


class RetrieveSimilarNodes:
@@ -55,26 +56,53 @@ def query_db(
Note: this completely disables the similarity search and
just returns the results with no ordering.
default is `False`. If `True`, the query is ignored and no embedding is computed for it
aggregate_records : bool
aggregate records and group them by the terms given in `group_by_metadata`
group_by_metadata : list[str]
the `metadata_` properties to group by
"""
ignore_sort = kwargs.get("ignore_sort", False)
aggregate_records = kwargs.get("aggregate_records", False)
group_by_metadata = kwargs.get("group_by_metadata")
Review comment (Contributor):

The handling of aggregate_records and group_by_metadata is implemented correctly. Ensure that group_by_metadata is always a list to avoid type errors during iteration.

+ if not isinstance(group_by_metadata, list):
+     raise ValueError("Expected 'group_by_metadata' to be a list.")
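The same caveat applies here (my note, not part of the review): `kwargs.get("group_by_metadata")` returns `None` when the kwarg is omitted, so the check as suggested would raise even for plain, non-aggregating queries. A gentler sketch:

```python
# sketch: default to an empty list and validate only when
# aggregation is actually requested
group_by_metadata = kwargs.get("group_by_metadata", [])
if aggregate_records and not isinstance(group_by_metadata, list):
    raise ValueError("Expected 'group_by_metadata' to be a list.")
```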


self._vector_store._initialize()

        if not ignore_sort:
            embedding = self._embed_model.get_text_embedding(text=query)
        else:
            embedding = None

        if not aggregate_records:
            stmt = select(  # type: ignore
                self._vector_store._table_class.id,
                self._vector_store._table_class.node_id,
                self._vector_store._table_class.text,
                self._vector_store._table_class.metadata_,
                (
                    self._vector_store._table_class.embedding.cosine_distance(embedding)
                    if not ignore_sort
                    else null()
                ).label("distance"),
            )
        else:
            # to manually create metadata
            metadata_grouping = []
            for item in group_by_metadata:
                metadata_grouping.append(item)
                metadata_grouping.append(
                    self._vector_store._table_class.metadata_.op("->>")(item)
                )

            stmt = select(
                null().label("id"),
                literal(str(uuid1())).label("node_id"),
                func.aggregate_strings(
                    # the default content key for llama-index nodes
                    # and documents is `text`
                    self._vector_store._table_class.text,
                    "\n",
                ).label("text"),
                func.json_agg(func.json_build_object(*metadata_grouping)).label(
                    "metadata_"
                ),
                null().label("distance"),
            )

if not ignore_sort:
stmt = stmt.order_by(text("distance asc"))
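As a side note, here is a self-contained sketch (not part of this PR; the table and column names are assumptions for illustration) of how the alternating key/value list above feeds `json_build_object`, and roughly what SQL the aggregation branch compiles to:

```python
from sqlalchemy import Column, MetaData, Table, Text, func, select
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import JSONB

# hypothetical table mirroring the vector store's text/metadata columns
data = Table(
    "data_discord_summary",
    MetaData(),
    Column("text", Text),
    Column("metadata_", JSONB),
)

group_by_metadata = ["thread", "date", "channel"]

# alternate JSON keys with their `metadata_ ->> key` values so that
# json_build_object(k1, v1, k2, v2, ...) rebuilds a metadata object
metadata_grouping: list = []
for key in group_by_metadata:
    metadata_grouping.append(key)
    metadata_grouping.append(data.c.metadata_.op("->>")(key))

stmt = select(
    func.aggregate_strings(data.c.text, "\n").label("text"),
    func.json_agg(func.json_build_object(*metadata_grouping)).label("metadata_"),
).group_by(*[data.c.metadata_.op("->>")(key) for key in group_by_metadata])

# compiles to roughly:
#   SELECT string_agg(text, '\n') AS text,
#          json_agg(json_build_object('thread', metadata_ ->> 'thread', ...)) AS metadata_
#   FROM data_discord_summary
#   GROUP BY metadata_ ->> 'thread', metadata_ ->> 'date', metadata_ ->> 'channel'
print(stmt.compile(dialect=postgresql.dialect()))
```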
@@ -128,8 +156,15 @@ def query_db(

stmt = stmt.where(or_(*conditions))

        if aggregate_records:
            group_by_terms = [
                self._vector_store._table_class.metadata_.op("->>")(item)
                for item in group_by_metadata
            ]
            stmt = stmt.group_by(*group_by_terms)

        if self._similarity_top_k is not None:
            stmt = stmt.limit(self._similarity_top_k)

with self._vector_store._session() as session, session.begin():
res = session.execute(stmt)
@@ -138,7 +173,11 @@
DBEmbeddingRow(
node_id=item.node_id,
text=item.text,
# in case of aggregation, null values can make the metadata
# contain duplicate entries, so always taking the first index
# is correct; the metadata should then match the group_by data
metadata=item.metadata_ if not aggregate_records else item.metadata_[0],
similarity=(1 - item.distance) if item.distance is not None else 0,
)
for item in res.all()
55 changes: 50 additions & 5 deletions utils/query_engine/level_based_platform_query_engine.py
@@ -47,7 +47,7 @@ def custom_query(self, query_str: str):
query=query_str, filters=self._filters, date_interval=self._d
)

context_str = self._prepare_context_str(similar_nodes, summary_nodes=None)
fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str)
response = self.llm.complete(fmt_qa_prompt)
logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}")
@@ -98,6 +98,12 @@ def prepare_platform_engine(
index_summary : VectorStoreIndex
the vector store index for summary data
If not passed, it would just create one itself
summary_nodes_filters : list[dict[str, str]]
a list of filters to fetch the summary nodes
by default, not passing this means the previously prepared nodes are used,
but if passed, the nodes are re-fetched.
This could be beneficial in case we want to do some manual
processing with nodes
Review comment (Contributor) on lines +101 to +106:
The addition of summary_nodes_filters allows for flexible filtering of summary nodes. Ensure that this parameter is properly validated to be a list of dictionaries before use.

+ if summary_nodes_filters is not None and not all(isinstance(f, dict) for f in summary_nodes_filters):
+     raise ValueError("Each filter in 'summary_nodes_filters' must be a dictionary.")



Returns
---------
@@ -115,6 +121,8 @@
"index_raw",
cls._setup_vector_store_index(platform_table_name, dbname, testing),
)
summary_nodes_filters = kwargs.get("summary_nodes_filters", None)

retriever = index.as_retriever()
cls._summary_vector_store = kwargs.get(
"index_summary",
@@ -130,6 +138,7 @@

cls._similarity_top_k = similarity_top_k
cls._filters = filters
cls._summary_nodes_filters = summary_nodes_filters

return cls(
retriever=retriever,
@@ -202,12 +211,20 @@ def prepare_engine_auto_filter(
table_name=platform_table_name + "_summary", dbname=dbname
)

raw_nodes_filters = platform_retriever.define_filters(
nodes,
metadata_group1_key=level1_key,
metadata_group2_key=level2_key,
metadata_date_key=date_key,
)
summary_nodes_filters = platform_retriever.define_filters(
nodes,
metadata_group1_key=level1_key,
metadata_group2_key=level2_key,
metadata_date_key=date_key,
# we will always use thread summaries
and_filters={"type": "thread"},
)

# saving to add summaries to the context of prompt
if include_summary_context:
@@ -222,18 +239,21 @@
cls._d = d
cls._platform_table_name = platform_table_name

logging.debug(f"COMMUNITY_ID: {community_id} | raw filters: {raw_nodes_filters}")

engine = LevelBasedPlatformQueryEngine.prepare_platform_engine(
community_id=community_id,
platform_table_name=platform_table_name,
filters=raw_nodes_filters,
index_summary=index_summary,
summary_nodes_filters=summary_nodes_filters,
)
return engine

def _prepare_context_str(
self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore] | None
) -> str:
"""
prepare the prompt context using the raw_nodes for answers and summary_nodes for additional information
@@ -248,6 +268,31 @@
context_str += self._utils_class.prepare_prompt_with_metadata_info(
nodes=raw_nodes
)
elif summary_nodes is None:
retriever = RetrieveSimilarNodes(
self._summary_vector_store,
similarity_top_k=None,
)
# Note: `self._summary_nodes_filters` must be set before
fetched_summary_nodes = retriever.query_db(
query="",
filters=self._summary_nodes_filters,
aggregate_records=True,
ignore_sort=True,
group_by_metadata=["thread", "date", "channel"],
date_interval=self._d,
)
grouped_summary_nodes = self._utils_class.group_nodes_per_metadata(
fetched_summary_nodes
)
grouped_raw_nodes = self._utils_class.group_nodes_per_metadata(raw_nodes)
context_data, (
summary_nodes_to_fetch_filters,
_,
) = self._utils_class.prepare_context_str_based_on_summaries(
grouped_raw_nodes, grouped_summary_nodes
)
context_str += context_data
else:
# grouping the data we have so we could
# get them per each metadata without looping over them
1 change: 1 addition & 0 deletions utils/query_engine/level_based_platforms_util.py
@@ -63,6 +63,7 @@ def group_nodes_per_metadata(
str | None, dict[str | None, dict[str, list[NodeWithScore]]]
] = {}
for node in nodes:
# logging.info(f"node.metadata {node.metadata}")
level1_title = node.metadata[self.level1_key]
level2_title = node.metadata[self.level2_key]
date_str = node.metadata[self.date_key]
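For orientation, a small sketch of the nested shape this grouping produces — level1 title → level2 title → date → list of nodes. The key names here (`channel`, `thread`, `date`) and metadata values are assumptions for illustration:

```python
from llama_index.core.schema import NodeWithScore, TextNode

# a hypothetical retrieved node; metadata values are invented
node = NodeWithScore(
    node=TextNode(
        text="hello world",
        metadata={"channel": "general", "thread": "intro", "date": "2024-04-01"},
    ),
    score=0.87,
)

# expected grouping shape: {level1: {level2: {date: [nodes]}}}
grouped: dict = {"general": {"intro": {"2024-04-01": [node]}}}
```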