-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/discord summarizer prompt update #47
Changes from 5 commits
0a83f60
a444755
95a6898
0fdd128
8509948
87bf092
0135d0c
3ece4f3
b142f62
ce688e5
24fd48f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -47,7 +47,7 @@ def custom_query(self, query_str: str): | |||||||||||||||||||||||||||||
query=query_str, filters=self._filters, date_interval=self._d | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
context_str = self._prepare_context_str(similar_nodes, self.summary_nodes) | ||||||||||||||||||||||||||||||
context_str = self._prepare_context_str(similar_nodes, summary_nodes=None) | ||||||||||||||||||||||||||||||
fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str=query_str) | ||||||||||||||||||||||||||||||
response = self.llm.complete(fmt_qa_prompt) | ||||||||||||||||||||||||||||||
logging.debug(f"fmt_qa_prompt:\n{fmt_qa_prompt}") | ||||||||||||||||||||||||||||||
|
@@ -98,6 +98,12 @@ def prepare_platform_engine( | |||||||||||||||||||||||||||||
index_summary : VectorStoreIndex | ||||||||||||||||||||||||||||||
the vector store index for summary data | ||||||||||||||||||||||||||||||
If not passed, it would just create one itself | ||||||||||||||||||||||||||||||
summary_nodes_filters : list[dict[str, str]] | ||||||||||||||||||||||||||||||
a list of filters to fetch the summary nodes | ||||||||||||||||||||||||||||||
for default, not passing this would mean to use previous nodes | ||||||||||||||||||||||||||||||
but if passed we would re-fetch nodes. | ||||||||||||||||||||||||||||||
This could be benefitial in case we want to do some manual | ||||||||||||||||||||||||||||||
processing with nodes | ||||||||||||||||||||||||||||||
Comment on lines
+101
to
+106
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The addition of + if summary_nodes_filters is not None and not all(isinstance(f, dict) for f in summary_nodes_filters):
+ raise ValueError("Each filter in 'summary_nodes_filters' must be a dictionary.") Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
Returns | ||||||||||||||||||||||||||||||
--------- | ||||||||||||||||||||||||||||||
|
@@ -115,6 +121,8 @@ def prepare_platform_engine( | |||||||||||||||||||||||||||||
"index_raw", | ||||||||||||||||||||||||||||||
cls._setup_vector_store_index(platform_table_name, dbname, testing), | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
summary_nodes_filters = kwargs.get("summary_nodes_filters", None) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
retriever = index.as_retriever() | ||||||||||||||||||||||||||||||
cls._summary_vector_store = kwargs.get( | ||||||||||||||||||||||||||||||
"index_summary", | ||||||||||||||||||||||||||||||
|
@@ -130,6 +138,7 @@ def prepare_platform_engine( | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
cls._similarity_top_k = similarity_top_k | ||||||||||||||||||||||||||||||
cls._filters = filters | ||||||||||||||||||||||||||||||
cls._summary_nodes_filters = summary_nodes_filters | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
return cls( | ||||||||||||||||||||||||||||||
retriever=retriever, | ||||||||||||||||||||||||||||||
|
@@ -202,12 +211,20 @@ def prepare_engine_auto_filter( | |||||||||||||||||||||||||||||
table_name=platform_table_name + "_summary", dbname=dbname | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
filters = platform_retriever.define_filters( | ||||||||||||||||||||||||||||||
raw_nodes_filters = platform_retriever.define_filters( | ||||||||||||||||||||||||||||||
nodes, | ||||||||||||||||||||||||||||||
metadata_group1_key=level1_key, | ||||||||||||||||||||||||||||||
metadata_group2_key=level2_key, | ||||||||||||||||||||||||||||||
metadata_date_key=date_key, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
summary_nodes_filters = platform_retriever.define_filters( | ||||||||||||||||||||||||||||||
nodes, | ||||||||||||||||||||||||||||||
metadata_group1_key=level1_key, | ||||||||||||||||||||||||||||||
metadata_group2_key=level2_key, | ||||||||||||||||||||||||||||||
metadata_date_key=date_key, | ||||||||||||||||||||||||||||||
# we will always use thread summaries | ||||||||||||||||||||||||||||||
and_filters={"type": "thread"}, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# saving to add summaries to the context of prompt | ||||||||||||||||||||||||||||||
if include_summary_context: | ||||||||||||||||||||||||||||||
|
@@ -222,18 +239,21 @@ def prepare_engine_auto_filter( | |||||||||||||||||||||||||||||
cls._d = d | ||||||||||||||||||||||||||||||
cls._platform_table_name = platform_table_name | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
logging.debug(f"COMMUNITY_ID: {community_id} | summary filters: {filters}") | ||||||||||||||||||||||||||||||
logging.debug( | ||||||||||||||||||||||||||||||
f"COMMUNITY_ID: {community_id} | raw filters: {raw_nodes_filters}" | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
engine = LevelBasedPlatformQueryEngine.prepare_platform_engine( | ||||||||||||||||||||||||||||||
community_id=community_id, | ||||||||||||||||||||||||||||||
platform_table_name=platform_table_name, | ||||||||||||||||||||||||||||||
filters=filters, | ||||||||||||||||||||||||||||||
filters=raw_nodes_filters, | ||||||||||||||||||||||||||||||
index_summary=index_summary, | ||||||||||||||||||||||||||||||
summary_nodes_filters=summary_nodes_filters, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
return engine | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _prepare_context_str( | ||||||||||||||||||||||||||||||
self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore] | ||||||||||||||||||||||||||||||
self, raw_nodes: list[NodeWithScore], summary_nodes: list[NodeWithScore] | None | ||||||||||||||||||||||||||||||
) -> str: | ||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
prepare the prompt context using the raw_nodes for answers and summary_nodes for additional information | ||||||||||||||||||||||||||||||
|
@@ -248,6 +268,31 @@ def _prepare_context_str( | |||||||||||||||||||||||||||||
context_str += self._utils_class.prepare_prompt_with_metadata_info( | ||||||||||||||||||||||||||||||
nodes=raw_nodes | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
elif summary_nodes is None: | ||||||||||||||||||||||||||||||
retriever = RetrieveSimilarNodes( | ||||||||||||||||||||||||||||||
self._summary_vector_store, | ||||||||||||||||||||||||||||||
similarity_top_k=None, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
# Note: `self._summary_nodes_filters` must be set before | ||||||||||||||||||||||||||||||
fetched_summary_nodes = retriever.query_db( | ||||||||||||||||||||||||||||||
query="", | ||||||||||||||||||||||||||||||
filters=self._summary_nodes_filters, | ||||||||||||||||||||||||||||||
aggregate_records=True, | ||||||||||||||||||||||||||||||
ignore_sort=True, | ||||||||||||||||||||||||||||||
group_by_metadata=["thread", "date", "channel"], | ||||||||||||||||||||||||||||||
date_interval=self._d, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
grouped_summary_nodes = self._utils_class.group_nodes_per_metadata( | ||||||||||||||||||||||||||||||
fetched_summary_nodes | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
grouped_raw_nodes = self._utils_class.group_nodes_per_metadata(raw_nodes) | ||||||||||||||||||||||||||||||
context_data, ( | ||||||||||||||||||||||||||||||
summary_nodes_to_fetch_filters, | ||||||||||||||||||||||||||||||
_, | ||||||||||||||||||||||||||||||
) = self._utils_class.prepare_context_str_based_on_summaries( | ||||||||||||||||||||||||||||||
grouped_raw_nodes, grouped_summary_nodes | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
context_str += context_data | ||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
# grouping the data we have so we could | ||||||||||||||||||||||||||||||
# get them per each metadata without looping over them | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The integration of
and_filters
fromkwargs
is handled correctly. However, consider adding error handling for cases whereand_filters
might not be a dictionary as expected.Committable suggestion