-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/discord summarizer prompt update #47
Changes from 2 commits
0a83f60
a444755
95a6898
0fdd128
8509948
87bf092
0135d0c
3ece4f3
b142f62
ce688e5
24fd48f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -7,8 +7,9 @@ | |||||||||||||||||||||||||||||||||||||
from llama_index.core.vector_stores.types import VectorStoreQueryResult | ||||||||||||||||||||||||||||||||||||||
from llama_index.vector_stores.postgres import PGVectorStore | ||||||||||||||||||||||||||||||||||||||
from llama_index.vector_stores.postgres.base import DBEmbeddingRow | ||||||||||||||||||||||||||||||||||||||
from sqlalchemy import Date, and_, cast, null, or_, select, text | ||||||||||||||||||||||||||||||||||||||
from sqlalchemy import Date, and_, cast, null, or_, select, text, literal, func | ||||||||||||||||||||||||||||||||||||||
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding | ||||||||||||||||||||||||||||||||||||||
from uuid import uuid1 | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
class RetrieveSimilarNodes: | ||||||||||||||||||||||||||||||||||||||
|
@@ -55,26 +56,53 @@ def query_db( | |||||||||||||||||||||||||||||||||||||
Note: This would completely disable the similarity search and | ||||||||||||||||||||||||||||||||||||||
it would just return the results with no ordering. | ||||||||||||||||||||||||||||||||||||||
default is `False`. If `True` the query will be ignored and no embedding of it would be fetched | ||||||||||||||||||||||||||||||||||||||
aggregate_records : bool | ||||||||||||||||||||||||||||||||||||||
aggregate records and group by a given term in `group_by_metadata` | ||||||||||||||||||||||||||||||||||||||
group_by_metadata : list[str] | ||||||||||||||||||||||||||||||||||||||
do grouping by some property of `metadata_` | ||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||
ignore_sort = kwargs.get("ignore_sort", False) | ||||||||||||||||||||||||||||||||||||||
aggregate_records = kwargs.get("aggregate_records", False) | ||||||||||||||||||||||||||||||||||||||
group_by_metadata = kwargs.get("group_by_metadata") | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The handling of + if not isinstance(group_by_metadata, list):
+ raise ValueError("Expected 'group_by_metadata' to be a list.") Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||
self._vector_store._initialize() | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if not ignore_sort: | ||||||||||||||||||||||||||||||||||||||
embedding = self._embed_model.get_text_embedding(text=query) | ||||||||||||||||||||||||||||||||||||||
if not aggregate_records: | ||||||||||||||||||||||||||||||||||||||
stmt = select( # type: ignore | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.id, | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.node_id, | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.text, | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.metadata_, | ||||||||||||||||||||||||||||||||||||||
( | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.embedding.cosine_distance( | ||||||||||||||||||||||||||||||||||||||
self._embed_model.get_text_embedding(text=query) | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
if not ignore_sort | ||||||||||||||||||||||||||||||||||||||
else null() | ||||||||||||||||||||||||||||||||||||||
).label("distance"), | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
embedding = None | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
stmt = select( # type: ignore | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.id, | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.node_id, | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.text, | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.metadata_, | ||||||||||||||||||||||||||||||||||||||
( | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.embedding.cosine_distance(embedding) | ||||||||||||||||||||||||||||||||||||||
if not ignore_sort | ||||||||||||||||||||||||||||||||||||||
else null() | ||||||||||||||||||||||||||||||||||||||
).label("distance"), | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
# to manually create metadata | ||||||||||||||||||||||||||||||||||||||
metadata_grouping = [] | ||||||||||||||||||||||||||||||||||||||
for item in group_by_metadata: | ||||||||||||||||||||||||||||||||||||||
metadata_grouping.append(item) | ||||||||||||||||||||||||||||||||||||||
metadata_grouping.append( | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.metadata_.op("->>")(item) | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
stmt = select( | ||||||||||||||||||||||||||||||||||||||
null().label("id"), | ||||||||||||||||||||||||||||||||||||||
literal(str(uuid1())).label("node_id"), | ||||||||||||||||||||||||||||||||||||||
func.aggregate_strings( | ||||||||||||||||||||||||||||||||||||||
# default content key for llama-index nodes and documents | ||||||||||||||||||||||||||||||||||||||
# is `text` | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.text, | ||||||||||||||||||||||||||||||||||||||
"\n", | ||||||||||||||||||||||||||||||||||||||
).label("text"), | ||||||||||||||||||||||||||||||||||||||
func.json_agg(func.json_build_object(*metadata_grouping)).label( | ||||||||||||||||||||||||||||||||||||||
"metadata_" | ||||||||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||||||||
null().label("distance"), | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if not ignore_sort: | ||||||||||||||||||||||||||||||||||||||
stmt = stmt.order_by(text("distance asc")) | ||||||||||||||||||||||||||||||||||||||
|
@@ -128,8 +156,15 @@ def query_db( | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
stmt = stmt.where(or_(*conditions)) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self._similarity_top_k is not None: | ||||||||||||||||||||||||||||||||||||||
stmt = stmt.limit(self._similarity_top_k) | ||||||||||||||||||||||||||||||||||||||
if aggregate_records: | ||||||||||||||||||||||||||||||||||||||
group_by_terms = [ | ||||||||||||||||||||||||||||||||||||||
self._vector_store._table_class.metadata_.op("->>")(item) | ||||||||||||||||||||||||||||||||||||||
for item in group_by_metadata | ||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||
stmt = stmt.group_by(*group_by_terms) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self._similarity_top_k is not None: | ||||||||||||||||||||||||||||||||||||||
stmt = stmt.limit(self._similarity_top_k) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
with self._vector_store._session() as session, session.begin(): | ||||||||||||||||||||||||||||||||||||||
res = session.execute(stmt) | ||||||||||||||||||||||||||||||||||||||
|
@@ -138,7 +173,11 @@ def query_db( | |||||||||||||||||||||||||||||||||||||
DBEmbeddingRow( | ||||||||||||||||||||||||||||||||||||||
node_id=item.node_id, | ||||||||||||||||||||||||||||||||||||||
text=item.text, | ||||||||||||||||||||||||||||||||||||||
metadata=item.metadata_, | ||||||||||||||||||||||||||||||||||||||
# in case of aggregation having null values | ||||||||||||||||||||||||||||||||||||||
# the metadata might will have duplicate date | ||||||||||||||||||||||||||||||||||||||
# so using the first index always will make it right | ||||||||||||||||||||||||||||||||||||||
# in this case, always the metadata should be the same as group_by data | ||||||||||||||||||||||||||||||||||||||
metadata=item.metadata_ if not aggregate_records else item.metadata_[0], | ||||||||||||||||||||||||||||||||||||||
similarity=(1 - item.distance) if item.distance is not None else 0, | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
for item in res.all() | ||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -47,7 +47,7 @@ def custom_query(self, query_str: str): | |||||||||||||||||||||||||||||
query=query_str, filters=self._filters, date_interval=self._d | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
context_str = self._prepare_context_str(similar_nodes, self.summary_nodes) | ||||||||||||||||||||||||||||||
context_str = self._prepare_context_str(similar_nodes, summary_nodes=None) | ||||||||||||||||||||||||||||||
fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str) | ||||||||||||||||||||||||||||||
response = self.llm.complete(fmt_qa_prompt) | ||||||||||||||||||||||||||||||
logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}") | ||||||||||||||||||||||||||||||
|
@@ -98,6 +98,12 @@ def prepare_platform_engine( | |||||||||||||||||||||||||||||
index_summary : VectorStoreIndex | ||||||||||||||||||||||||||||||
the vector store index for summary data | ||||||||||||||||||||||||||||||
If not passed, it would just create one itself | ||||||||||||||||||||||||||||||
summary_nodes_filters : list[dict[str, str]] | ||||||||||||||||||||||||||||||
a list of filters to fetch the summary nodes | ||||||||||||||||||||||||||||||
for default, not passing this would mean to use previous nodes | ||||||||||||||||||||||||||||||
but if passed we would re-fetch nodes. | ||||||||||||||||||||||||||||||
This could be benefitial in case we want to do some manual | ||||||||||||||||||||||||||||||
processing with nodes | ||||||||||||||||||||||||||||||
Comment on lines
+101
to
+106
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The addition of + if summary_nodes_filters is not None and not all(isinstance(f, dict) for f in summary_nodes_filters):
+ raise ValueError("Each filter in 'summary_nodes_filters' must be a dictionary.") Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
Returns | ||||||||||||||||||||||||||||||
--------- | ||||||||||||||||||||||||||||||
|
@@ -115,6 +121,8 @@ def prepare_platform_engine( | |||||||||||||||||||||||||||||
"index_raw", | ||||||||||||||||||||||||||||||
cls._setup_vector_store_index(platform_table_name, dbname, testing), | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
summary_nodes_filters = kwargs.get("summary_nodes_filters", None) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
retriever = index.as_retriever() | ||||||||||||||||||||||||||||||
cls._summary_vector_store = kwargs.get( | ||||||||||||||||||||||||||||||
"index_summary", | ||||||||||||||||||||||||||||||
|
@@ -130,6 +138,7 @@ def prepare_platform_engine( | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
cls._similarity_top_k = similarity_top_k | ||||||||||||||||||||||||||||||
cls._filters = filters | ||||||||||||||||||||||||||||||
cls._summary_nodes_filters = summary_nodes_filters | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
return cls( | ||||||||||||||||||||||||||||||
retriever=retriever, | ||||||||||||||||||||||||||||||
|
@@ -202,12 +211,20 @@ def prepare_engine_auto_filter( | |||||||||||||||||||||||||||||
table_name=platform_table_name + "_summary", dbname=dbname | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
filters = platform_retriever.define_filters( | ||||||||||||||||||||||||||||||
raw_nodes_filters = platform_retriever.define_filters( | ||||||||||||||||||||||||||||||
nodes, | ||||||||||||||||||||||||||||||
metadata_group1_key=level1_key, | ||||||||||||||||||||||||||||||
metadata_group2_key=level2_key, | ||||||||||||||||||||||||||||||
metadata_date_key=date_key, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
summary_nodes_filters = platform_retriever.define_filters( | ||||||||||||||||||||||||||||||
nodes, | ||||||||||||||||||||||||||||||
metadata_group1_key=level1_key, | ||||||||||||||||||||||||||||||
metadata_group2_key=level2_key, | ||||||||||||||||||||||||||||||
metadata_date_key=date_key, | ||||||||||||||||||||||||||||||
# we will always use thread summaries | ||||||||||||||||||||||||||||||
and_filters={"type": "thread"}, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# saving to add summaries to the context of prompt | ||||||||||||||||||||||||||||||
if include_summary_context: | ||||||||||||||||||||||||||||||
|
@@ -222,18 +239,21 @@ def prepare_engine_auto_filter( | |||||||||||||||||||||||||||||
cls._d = d | ||||||||||||||||||||||||||||||
cls._platform_table_name = platform_table_name | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
logging.debug(f"COMMUNITY_ID: {community_id} | summary filters: {filters}") | ||||||||||||||||||||||||||||||
logging.debug( | ||||||||||||||||||||||||||||||
f"COMMUNITY_ID: {community_id} | raw filters: {raw_nodes_filters}" | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
engine = LevelBasedPlatformQueryEngine.prepare_platform_engine( | ||||||||||||||||||||||||||||||
community_id=community_id, | ||||||||||||||||||||||||||||||
platform_table_name=platform_table_name, | ||||||||||||||||||||||||||||||
filters=filters, | ||||||||||||||||||||||||||||||
filters=raw_nodes_filters, | ||||||||||||||||||||||||||||||
index_summary=index_summary, | ||||||||||||||||||||||||||||||
summary_nodes_filters=summary_nodes_filters, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
return engine | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _prepare_context_str( | ||||||||||||||||||||||||||||||
self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore] | ||||||||||||||||||||||||||||||
self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore] | None | ||||||||||||||||||||||||||||||
) -> str: | ||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
prepare the prompt context using the raw_nodes for answers and summary_nodes for additional information | ||||||||||||||||||||||||||||||
|
@@ -248,6 +268,31 @@ def _prepare_context_str( | |||||||||||||||||||||||||||||
context_str += self._utils_class.prepare_prompt_with_metadata_info( | ||||||||||||||||||||||||||||||
nodes=raw_nodes | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
elif summary_nodes is None: | ||||||||||||||||||||||||||||||
retriever = RetrieveSimilarNodes( | ||||||||||||||||||||||||||||||
self._summary_vector_store, | ||||||||||||||||||||||||||||||
similarity_top_k=None, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
# Note: `self._summary_nodes_filters` must be set before | ||||||||||||||||||||||||||||||
fetched_summary_nodes = retriever.query_db( | ||||||||||||||||||||||||||||||
query="", | ||||||||||||||||||||||||||||||
filters=self._summary_nodes_filters, | ||||||||||||||||||||||||||||||
aggregate_records=True, | ||||||||||||||||||||||||||||||
ignore_sort=True, | ||||||||||||||||||||||||||||||
group_by_metadata=["thread", "date", "channel"], | ||||||||||||||||||||||||||||||
date_interval=self._d, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
grouped_summary_nodes = self._utils_class.group_nodes_per_metadata( | ||||||||||||||||||||||||||||||
fetched_summary_nodes | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
grouped_raw_nodes = self._utils_class.group_nodes_per_metadata(raw_nodes) | ||||||||||||||||||||||||||||||
context_data, ( | ||||||||||||||||||||||||||||||
summary_nodes_to_fetch_filters, | ||||||||||||||||||||||||||||||
_, | ||||||||||||||||||||||||||||||
) = self._utils_class.prepare_context_str_based_on_summaries( | ||||||||||||||||||||||||||||||
grouped_raw_nodes, grouped_summary_nodes | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
context_str += context_data | ||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
# grouping the data we have so we could | ||||||||||||||||||||||||||||||
# get them per each metadata without looping over them | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The integration of
and_filters
fromkwargs
is handled correctly. However, consider adding error handling for cases whereand_filters
might not be a dictionary as expected.Committable suggestion