diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py
index 1baf64b47..7cc810b1a 100644
--- a/daras_ai_v2/search_ref.py
+++ b/daras_ai_v2/search_ref.py
@@ -57,150 +57,178 @@ def render_text_with_refs(text: str, references: list[SearchReference]):
return html
-def apply_response_template(
+def apply_response_formattings_prefix(
output_text: list[str],
references: list[SearchReference],
citation_style: CitationStyles | None = CitationStyles.number,
-):
+) -> list[dict[int, SearchReference]]:
+ all_refs_list = [{}] * len(output_text)
for i, text in enumerate(output_text):
- formatted = ""
- all_refs = {}
-
- for snippet, ref_map in parse_refs(text, references):
- match citation_style:
- case CitationStyles.number | CitationStyles.number_plaintext:
- cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys())
- case CitationStyles.title:
- cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
- case CitationStyles.url:
- cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
- case CitationStyles.symbol | CitationStyles.symbol_plaintext:
- cites = " ".join(
- f"[{generate_footnote_symbol(ref_num - 1)}]"
- for ref_num in ref_map.keys()
- )
+ all_refs_list[i], output_text[i] = format_citations(
+ text, references, citation_style
+ )
+ return all_refs_list
- case CitationStyles.markdown:
- cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
- case CitationStyles.html:
- cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
- case CitationStyles.slack_mrkdwn:
- cites = " ".join(
- ref_to_slack_mrkdwn(ref) for ref in ref_map.values()
- )
- case CitationStyles.plaintext:
- cites = " ".join(
- f'[{ref["title"]} {ref["url"]}]'
- for ref_num, ref in ref_map.items()
- )
- case CitationStyles.number_markdown:
- cites = " ".join(
- markdown_link(f"[{ref_num}]", ref["url"])
- for ref_num, ref in ref_map.items()
- )
- case CitationStyles.number_html:
- cites = " ".join(
- html_link(f"[{ref_num}]", ref["url"])
- for ref_num, ref in ref_map.items()
- )
- case CitationStyles.number_slack_mrkdwn:
- cites = " ".join(
- slack_mrkdwn_link(f"[{ref_num}]", ref["url"])
- for ref_num, ref in ref_map.items()
- )
+def apply_response_formattings_suffix(
+ all_refs_list: list[dict[int, SearchReference]],
+ output_text: list[str],
+ citation_style: CitationStyles | None = CitationStyles.number,
+):
+ for i, text in enumerate(output_text):
+ output_text[i] = format_jinja_response_template(
+ all_refs_list[i],
+ format_footnotes(all_refs_list[i], text, citation_style),
+ )
- case CitationStyles.symbol_markdown:
- cites = " ".join(
- markdown_link(
- f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
- )
- for ref_num, ref in ref_map.items()
- )
- case CitationStyles.symbol_html:
- cites = " ".join(
- html_link(
- f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
- )
- for ref_num, ref in ref_map.items()
- )
- case CitationStyles.symbol_slack_mrkdwn:
- cites = " ".join(
- slack_mrkdwn_link(
- f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
- )
- for ref_num, ref in ref_map.items()
- )
- case None:
- cites = ""
- case _:
- raise ValueError(f"Unknown citation style: {citation_style}")
- formatted += snippet + " " + cites + " "
- all_refs.update(ref_map)
+def format_citations(
+ text: str,
+ references: list[SearchReference],
+ citation_style: CitationStyles | None = CitationStyles.number,
+) -> tuple[dict[int, SearchReference], str]:
+ all_refs = {}
+ formatted = ""
+ for snippet, ref_map in parse_refs(text, references):
match citation_style:
+ case CitationStyles.number | CitationStyles.number_plaintext:
+ cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys())
+ case CitationStyles.title:
+ cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
+ case CitationStyles.url:
+ cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
+ case CitationStyles.symbol | CitationStyles.symbol_plaintext:
+ cites = " ".join(
+ f"[{generate_footnote_symbol(ref_num - 1)}]"
+ for ref_num in ref_map.keys()
+ )
+
+ case CitationStyles.markdown:
+ cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
+ case CitationStyles.html:
+ cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
+ case CitationStyles.slack_mrkdwn:
+ cites = " ".join(ref_to_slack_mrkdwn(ref) for ref in ref_map.values())
+ case CitationStyles.plaintext:
+ cites = " ".join(
+ f'[{ref["title"]} {ref["url"]}]' for ref_num, ref in ref_map.items()
+ )
+
case CitationStyles.number_markdown:
- formatted += "\n\n"
- formatted += "\n".join(
- f"[{ref_num}] {ref_to_markdown(ref)}"
- for ref_num, ref in sorted(all_refs.items())
+ cites = " ".join(
+ markdown_link(f"[{ref_num}]", ref["url"])
+ for ref_num, ref in ref_map.items()
)
case CitationStyles.number_html:
- formatted += "<br><br>"
- formatted += "<br>".join(
- f"[{ref_num}] {ref_to_html(ref)}"
- for ref_num, ref in sorted(all_refs.items())
+ cites = " ".join(
+ html_link(f"[{ref_num}]", ref["url"])
+ for ref_num, ref in ref_map.items()
)
case CitationStyles.number_slack_mrkdwn:
- formatted += "\n\n"
- formatted += "\n".join(
- f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}"
- for ref_num, ref in sorted(all_refs.items())
- )
- case CitationStyles.number_plaintext:
- formatted += "\n\n"
- formatted += "\n".join(
- f'{ref_num}. {ref["title"]} {ref["url"]}'
- for ref_num, ref in sorted(all_refs.items())
+ cites = " ".join(
+ slack_mrkdwn_link(f"[{ref_num}]", ref["url"])
+ for ref_num, ref in ref_map.items()
)
case CitationStyles.symbol_markdown:
- formatted += "\n\n"
- formatted += "\n".join(
- f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}"
- for ref_num, ref in sorted(all_refs.items())
+ cites = " ".join(
+ markdown_link(
+ f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+ )
+ for ref_num, ref in ref_map.items()
)
case CitationStyles.symbol_html:
- formatted += "<br><br>"
- formatted += "<br>".join(
- f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}"
- for ref_num, ref in sorted(all_refs.items())
+ cites = " ".join(
+ html_link(f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"])
+ for ref_num, ref in ref_map.items()
)
case CitationStyles.symbol_slack_mrkdwn:
- formatted += "\n\n"
- formatted += "\n".join(
- f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}"
- for ref_num, ref in sorted(all_refs.items())
- )
- case CitationStyles.symbol_plaintext:
- formatted += "\n\n"
- formatted += "\n".join(
- f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}'
- for ref_num, ref in sorted(all_refs.items())
- )
-
- for ref_num, ref in all_refs.items():
- try:
- template = ref["response_template"]
- except KeyError:
- pass
- else:
- formatted = jinja2.Template(template).render(
- **ref,
- output_text=formatted,
- ref_num=ref_num,
+ cites = " ".join(
+ slack_mrkdwn_link(
+ f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+ )
+ for ref_num, ref in ref_map.items()
)
- output_text[i] = formatted
+ case None:
+ cites = ""
+ case _:
+ raise ValueError(f"Unknown citation style: {citation_style}")
+ formatted += " ".join(filter(None, [snippet, cites]))
+ all_refs.update(ref_map)
+ return all_refs, formatted
+
+
+def format_footnotes(
+ all_refs: dict[int, SearchReference], formatted: str, citation_style: CitationStyles
+) -> str:
+ match citation_style:
+ case CitationStyles.number_markdown:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f"[{ref_num}] {ref_to_markdown(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.number_html:
+ formatted += "<br><br>"
+ formatted += "<br>".join(
+ f"[{ref_num}] {ref_to_html(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.number_slack_mrkdwn:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.number_plaintext:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f'{ref_num}. {ref["title"]} {ref["url"]}'
+ for ref_num, ref in sorted(all_refs.items())
+ )
+
+ case CitationStyles.symbol_markdown:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.symbol_html:
+ formatted += "<br><br>"
+ formatted += "<br>".join(
+ f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.symbol_slack_mrkdwn:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.symbol_plaintext:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}'
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ return formatted
+
+
+def format_jinja_response_template(
+ all_refs: dict[int, SearchReference], formatted: str
+) -> str:
+ for ref_num, ref in all_refs.items():
+ try:
+ template = ref["response_template"]
+ except KeyError:
+ pass
+ else:
+ formatted = jinja2.Template(template).render(
+ **ref,
+ output_text=formatted,
+ ref_num=ref_num,
+ )
+ return formatted
search_ref_pat = re.compile(r"\[" r"[\d\s\.\,\[\]\$\{\}]+" r"\]")
diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py
index 2a509d8fd..15cb6b063 100644
--- a/recipes/DocSearch.py
+++ b/recipes/DocSearch.py
@@ -23,8 +23,9 @@
from daras_ai_v2.search_ref import (
SearchReference,
render_output_with_refs,
- apply_response_template,
CitationStyles,
+ apply_response_formattings_prefix,
+ apply_response_formattings_suffix,
)
from daras_ai_v2.vector_search import (
DocSearchRequest,
@@ -194,9 +195,12 @@ def run_v2(
citation_style = (
request.citation_style and CitationStyles[request.citation_style]
) or None
- apply_response_template(
+ all_refs_list = apply_response_formattings_prefix(
response.output_text, response.references, citation_style
)
+ apply_response_formattings_suffix(
+ all_refs_list, response.output_text, citation_style
+ )
def get_raw_price(self, state: dict) -> float:
name = state.get("selected_model")
diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py
index 149931ffb..89ea87973 100644
--- a/recipes/VideoBots.py
+++ b/recipes/VideoBots.py
@@ -50,6 +50,7 @@
get_entry_images,
get_entry_text,
format_chat_entry,
+ SUPERSCRIPT,
)
from daras_ai_v2.language_model_settings_widgets import language_model_settings
from daras_ai_v2.lipsync_settings_widgets import lipsync_settings
@@ -58,7 +59,12 @@
from daras_ai_v2.query_generator import generate_final_search_query
from daras_ai_v2.query_params import gooey_get_query_params
from daras_ai_v2.query_params_util import extract_query_params
-from daras_ai_v2.search_ref import apply_response_template, parse_refs, CitationStyles
+from daras_ai_v2.search_ref import (
+ parse_refs,
+ CitationStyles,
+ apply_response_formattings_prefix,
+ apply_response_formattings_suffix,
+)
from daras_ai_v2.text_output_widget import text_output
from daras_ai_v2.text_to_speech_settings_widgets import (
TextToSpeechProviders,
@@ -805,7 +811,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
yield f"Running {model.value}..."
if is_chat_model:
- output_text = run_language_model(
+ chunks = run_language_model(
model=request.selected_model,
messages=[
{"role": s["role"], "content": s["content"]}
@@ -816,12 +822,13 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
temperature=request.sampling_temperature,
avoid_repetition=request.avoid_repetition,
tools=request.tools,
+ stream=True,
)
else:
prompt = "\n".join(
format_chatml_message(entry) for entry in prompt_messages
)
- output_text = run_language_model(
+ chunks = run_language_model(
model=request.selected_model,
prompt=prompt,
max_tokens=max_allowed_tokens,
@@ -830,43 +837,51 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
temperature=request.sampling_temperature,
avoid_repetition=request.avoid_repetition,
stop=[CHATML_END_TOKEN, CHATML_START_TOKEN],
+ stream=True,
)
- if request.tools:
- output_text, tool_call_choices = output_text
- state["output_documents"] = output_documents = []
- for tool_calls in tool_call_choices:
- for call in tool_calls:
- result = yield from exec_tool_call(call)
- output_documents.append(result)
-
- # save model response
- state["raw_output_text"] = [
- "".join(snippet for snippet, _ in parse_refs(text, references))
- for text in output_text
- ]
-
- # translate response text
- if request.user_language and request.user_language != "en":
- yield f"Translating response to {request.user_language}..."
- output_text = run_google_translate(
- texts=output_text,
- source_language="en",
- target_language=request.user_language,
- glossary_url=request.output_glossary_document,
- )
- state["raw_tts_text"] = [
+ citation_style = (
+ request.citation_style and CitationStyles[request.citation_style]
+ ) or None
+ all_refs_list = []
+ for i, output_text in enumerate(chunks):
+ if request.tools:
+ output_text, tool_call_choices = output_text
+ state["output_documents"] = output_documents = []
+ for tool_calls in tool_call_choices:
+ for call in tool_calls:
+ result = yield from exec_tool_call(call)
+ output_documents.append(result)
+
+ # save model response
+ state["raw_output_text"] = [
"".join(snippet for snippet, _ in parse_refs(text, references))
for text in output_text
]
- if references:
- citation_style = (
- request.citation_style and CitationStyles[request.citation_style]
- ) or None
- apply_response_template(output_text, references, citation_style)
-
- state["output_text"] = output_text
+ # translate response text
+ if request.user_language and request.user_language != "en":
+ yield f"Translating response to {request.user_language}..."
+ output_text = run_google_translate(
+ texts=output_text,
+ source_language="en",
+ target_language=request.user_language,
+ glossary_url=request.output_glossary_document,
+ )
+ state["raw_tts_text"] = [
+ "".join(snippet for snippet, _ in parse_refs(text, references))
+ for text in output_text
+ ]
+
+ if references:
+ all_refs_list = apply_response_formattings_prefix(
+ output_text, references, citation_style
+ )
+ state["output_text"] = output_text
+ yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..."
+ apply_response_formattings_suffix(
+ all_refs_list, state["output_text"], citation_style
+ )
state["output_audio"] = []
state["output_video"] = []