Skip to content

Commit

Permalink
streaming support for videobots
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Jan 31, 2024
1 parent 2cd686a commit 975dae3
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 158 deletions.
272 changes: 150 additions & 122 deletions daras_ai_v2/search_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,150 +57,178 @@ def render_text_with_refs(text: str, references: list[SearchReference]):
return html


def apply_response_template(
def apply_response_formattings_prefix(
output_text: list[str],
references: list[SearchReference],
citation_style: CitationStyles | None = CitationStyles.number,
):
) -> list[dict[int, SearchReference]]:
all_refs_list = [{}] * len(output_text)
for i, text in enumerate(output_text):
formatted = ""
all_refs = {}

for snippet, ref_map in parse_refs(text, references):
match citation_style:
case CitationStyles.number | CitationStyles.number_plaintext:
cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys())
case CitationStyles.title:
cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
case CitationStyles.url:
cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
case CitationStyles.symbol | CitationStyles.symbol_plaintext:
cites = " ".join(
f"[{generate_footnote_symbol(ref_num - 1)}]"
for ref_num in ref_map.keys()
)
all_refs_list[i], output_text[i] = format_citations(
text, references, citation_style
)
return all_refs_list

case CitationStyles.markdown:
cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
case CitationStyles.html:
cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
case CitationStyles.slack_mrkdwn:
cites = " ".join(
ref_to_slack_mrkdwn(ref) for ref in ref_map.values()
)
case CitationStyles.plaintext:
cites = " ".join(
f'[{ref["title"]} {ref["url"]}]'
for ref_num, ref in ref_map.items()
)

case CitationStyles.number_markdown:
cites = " ".join(
markdown_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)
case CitationStyles.number_html:
cites = " ".join(
html_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)
case CitationStyles.number_slack_mrkdwn:
cites = " ".join(
slack_mrkdwn_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)
def apply_response_formattings_suffix(
    all_refs_list: list[dict[int, SearchReference]],
    output_text: list[str],
    citation_style: CitationStyles | None = CitationStyles.number,
):
    """Finish formatting each output text in place.

    For every entry of ``output_text``, the footnotes for the references at
    the same index of ``all_refs_list`` are appended via ``format_footnotes``
    and the result is then passed through ``format_jinja_response_template``.
    The finished string overwrites the entry in ``output_text``; nothing is
    returned.
    """
    for idx, body in enumerate(output_text):
        # Index (rather than zip) so a too-short all_refs_list still raises.
        refs = all_refs_list[idx]
        with_footnotes = format_footnotes(refs, body, citation_style)
        output_text[idx] = format_jinja_response_template(refs, with_footnotes)

case CitationStyles.symbol_markdown:
cites = " ".join(
markdown_link(
f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
)
for ref_num, ref in ref_map.items()
)
case CitationStyles.symbol_html:
cites = " ".join(
html_link(
f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
)
for ref_num, ref in ref_map.items()
)
case CitationStyles.symbol_slack_mrkdwn:
cites = " ".join(
slack_mrkdwn_link(
f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
)
for ref_num, ref in ref_map.items()
)
case None:
cites = ""
case _:
raise ValueError(f"Unknown citation style: {citation_style}")
formatted += snippet + " " + cites + " "
all_refs.update(ref_map)

def format_citations(
text: str,
references: list[SearchReference],
citation_style: CitationStyles | None = CitationStyles.number,
) -> tuple[dict[int, SearchReference], str]:
all_refs = {}
formatted = ""
for snippet, ref_map in parse_refs(text, references):
match citation_style:
case CitationStyles.number | CitationStyles.number_plaintext:
cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys())
case CitationStyles.title:
cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
case CitationStyles.url:
cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
case CitationStyles.symbol | CitationStyles.symbol_plaintext:
cites = " ".join(
f"[{generate_footnote_symbol(ref_num - 1)}]"
for ref_num in ref_map.keys()
)

case CitationStyles.markdown:
cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
case CitationStyles.html:
cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
case CitationStyles.slack_mrkdwn:
cites = " ".join(ref_to_slack_mrkdwn(ref) for ref in ref_map.values())
case CitationStyles.plaintext:
cites = " ".join(
f'[{ref["title"]} {ref["url"]}]' for ref_num, ref in ref_map.items()
)

case CitationStyles.number_markdown:
formatted += "\n\n"
formatted += "\n".join(
f"[{ref_num}] {ref_to_markdown(ref)}"
for ref_num, ref in sorted(all_refs.items())
cites = " ".join(
markdown_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)
case CitationStyles.number_html:
formatted += "<br><br>"
formatted += "<br>".join(
f"[{ref_num}] {ref_to_html(ref)}"
for ref_num, ref in sorted(all_refs.items())
cites = " ".join(
html_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)
case CitationStyles.number_slack_mrkdwn:
formatted += "\n\n"
formatted += "\n".join(
f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}"
for ref_num, ref in sorted(all_refs.items())
)
case CitationStyles.number_plaintext:
formatted += "\n\n"
formatted += "\n".join(
f'{ref_num}. {ref["title"]} {ref["url"]}'
for ref_num, ref in sorted(all_refs.items())
cites = " ".join(
slack_mrkdwn_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)

case CitationStyles.symbol_markdown:
formatted += "\n\n"
formatted += "\n".join(
f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}"
for ref_num, ref in sorted(all_refs.items())
cites = " ".join(
markdown_link(
f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
)
for ref_num, ref in ref_map.items()
)
case CitationStyles.symbol_html:
formatted += "<br><br>"
formatted += "<br>".join(
f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}"
for ref_num, ref in sorted(all_refs.items())
cites = " ".join(
html_link(f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"])
for ref_num, ref in ref_map.items()
)
case CitationStyles.symbol_slack_mrkdwn:
formatted += "\n\n"
formatted += "\n".join(
f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}"
for ref_num, ref in sorted(all_refs.items())
)
case CitationStyles.symbol_plaintext:
formatted += "\n\n"
formatted += "\n".join(
f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}'
for ref_num, ref in sorted(all_refs.items())
)

for ref_num, ref in all_refs.items():
try:
template = ref["response_template"]
except KeyError:
pass
else:
formatted = jinja2.Template(template).render(
**ref,
output_text=formatted,
ref_num=ref_num,
cites = " ".join(
slack_mrkdwn_link(
f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
)
for ref_num, ref in ref_map.items()
)
output_text[i] = formatted
case None:
cites = ""
case _:
raise ValueError(f"Unknown citation style: {citation_style}")
formatted += " ".join(filter(None, [snippet, cites]))
all_refs.update(ref_map)
return all_refs, formatted


def format_footnotes(
all_refs: dict[int, SearchReference], formatted: str, citation_style: CitationStyles
) -> str:
match citation_style:
case CitationStyles.number_markdown:
formatted += "\n\n"
formatted += "\n".join(
f"[{ref_num}] {ref_to_markdown(ref)}"
for ref_num, ref in sorted(all_refs.items())
)
case CitationStyles.number_html:
formatted += "<br><br>"
formatted += "<br>".join(
f"[{ref_num}] {ref_to_html(ref)}"
for ref_num, ref in sorted(all_refs.items())
)
case CitationStyles.number_slack_mrkdwn:
formatted += "\n\n"
formatted += "\n".join(
f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}"
for ref_num, ref in sorted(all_refs.items())
)
case CitationStyles.number_plaintext:
formatted += "\n\n"
formatted += "\n".join(
f'{ref_num}. {ref["title"]} {ref["url"]}'
for ref_num, ref in sorted(all_refs.items())
)

case CitationStyles.symbol_markdown:
formatted += "\n\n"
formatted += "\n".join(
f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}"
for ref_num, ref in sorted(all_refs.items())
)
case CitationStyles.symbol_html:
formatted += "<br><br>"
formatted += "<br>".join(
f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}"
for ref_num, ref in sorted(all_refs.items())
)
case CitationStyles.symbol_slack_mrkdwn:
formatted += "\n\n"
formatted += "\n".join(
f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}"
for ref_num, ref in sorted(all_refs.items())
)
case CitationStyles.symbol_plaintext:
formatted += "\n\n"
formatted += "\n".join(
f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}'
for ref_num, ref in sorted(all_refs.items())
)
return formatted


def format_jinja_response_template(
    all_refs: dict[int, SearchReference], formatted: str
) -> str:
    """Render the text through each reference's ``response_template``.

    References are visited in the mapping's iteration order. A reference
    without a ``"response_template"`` key is skipped. Otherwise its template
    is rendered with jinja2, receiving the reference's own fields plus
    ``output_text`` (the text produced so far) and ``ref_num``; the rendered
    result becomes the input for the next reference.

    Returns:
        The text after all applicable templates have been applied, or the
        input text unchanged when no reference carries a template.
    """
    rendered = formatted
    for number, reference in all_refs.items():
        if "response_template" not in reference:
            continue
        rendered = jinja2.Template(reference["response_template"]).render(
            **reference,
            output_text=rendered,
            ref_num=number,
        )
    return rendered


# Matches a bracketed group whose contents are one or more digits, whitespace
# characters, or any of the characters . , [ ] $ { }.
search_ref_pat = re.compile(r"\[" r"[\d\s\.\,\[\]\$\{\}]+" r"\]")
Expand Down
8 changes: 6 additions & 2 deletions recipes/DocSearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from daras_ai_v2.search_ref import (
SearchReference,
render_output_with_refs,
apply_response_template,
CitationStyles,
apply_response_formattings_prefix,
apply_response_formattings_suffix,
)
from daras_ai_v2.vector_search import (
DocSearchRequest,
Expand Down Expand Up @@ -194,9 +195,12 @@ def run_v2(
citation_style = (
request.citation_style and CitationStyles[request.citation_style]
) or None
apply_response_template(
all_refs_list = apply_response_formattings_prefix(
response.output_text, response.references, citation_style
)
apply_response_formattings_suffix(
all_refs_list, response.output_text, citation_style
)

def get_raw_price(self, state: dict) -> float:
name = state.get("selected_model")
Expand Down
Loading

0 comments on commit 975dae3

Please sign in to comment.