diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py index 1baf64b47..7cc810b1a 100644 --- a/daras_ai_v2/search_ref.py +++ b/daras_ai_v2/search_ref.py @@ -57,150 +57,178 @@ def render_text_with_refs(text: str, references: list[SearchReference]): return html -def apply_response_template( +def apply_response_formattings_prefix( output_text: list[str], references: list[SearchReference], citation_style: CitationStyles | None = CitationStyles.number, -): +) -> list[dict[int, SearchReference]]: + all_refs_list = [{}] * len(output_text) for i, text in enumerate(output_text): - formatted = "" - all_refs = {} - - for snippet, ref_map in parse_refs(text, references): - match citation_style: - case CitationStyles.number | CitationStyles.number_plaintext: - cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys()) - case CitationStyles.title: - cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values()) - case CitationStyles.url: - cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values()) - case CitationStyles.symbol | CitationStyles.symbol_plaintext: - cites = " ".join( - f"[{generate_footnote_symbol(ref_num - 1)}]" - for ref_num in ref_map.keys() - ) + all_refs_list[i], output_text[i] = format_citations( + text, references, citation_style + ) + return all_refs_list - case CitationStyles.markdown: - cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values()) - case CitationStyles.html: - cites = " ".join(ref_to_html(ref) for ref in ref_map.values()) - case CitationStyles.slack_mrkdwn: - cites = " ".join( - ref_to_slack_mrkdwn(ref) for ref in ref_map.values() - ) - case CitationStyles.plaintext: - cites = " ".join( - f'[{ref["title"]} {ref["url"]}]' - for ref_num, ref in ref_map.items() - ) - case CitationStyles.number_markdown: - cites = " ".join( - markdown_link(f"[{ref_num}]", ref["url"]) - for ref_num, ref in ref_map.items() - ) - case CitationStyles.number_html: - cites = " ".join( - html_link(f"[{ref_num}]", ref["url"]) - for ref_num, ref in ref_map.items() - ) - case CitationStyles.number_slack_mrkdwn: - cites = " ".join( - slack_mrkdwn_link(f"[{ref_num}]", ref["url"]) - for ref_num, ref in ref_map.items() - ) +def apply_response_formattings_suffix( + all_refs_list: list[dict[int, SearchReference]], + output_text: list[str], + citation_style: CitationStyles | None = CitationStyles.number, +): + for i, text in enumerate(output_text): + output_text[i] = format_jinja_response_template( + all_refs_list[i], + format_footnotes(all_refs_list[i], text, citation_style), + ) - case CitationStyles.symbol_markdown: - cites = " ".join( - markdown_link( - f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] - ) - for ref_num, ref in ref_map.items() - ) - case CitationStyles.symbol_html: - cites = " ".join( - html_link( - f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] - ) - for ref_num, ref in ref_map.items() - ) - case CitationStyles.symbol_slack_mrkdwn: - cites = " ".join( - slack_mrkdwn_link( - f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] - ) - for ref_num, ref in ref_map.items() - ) - case None: - cites = "" - case _: - raise ValueError(f"Unknown citation style: {citation_style}") - formatted += snippet + " " + cites + " " - all_refs.update(ref_map) +def format_citations( + text: str, + references: list[SearchReference], + citation_style: CitationStyles | None = CitationStyles.number, +) -> tuple[dict[int, SearchReference], str]: + all_refs = {} + formatted = "" + for snippet, ref_map in parse_refs(text, references): match citation_style: + case CitationStyles.number | CitationStyles.number_plaintext: + cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys()) + case CitationStyles.title: + cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values()) + case CitationStyles.url: + cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values()) + case CitationStyles.symbol | CitationStyles.symbol_plaintext: + cites = " ".join( + f"[{generate_footnote_symbol(ref_num - 1)}]" + for ref_num in ref_map.keys() + ) + + case CitationStyles.markdown: + cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values()) + case CitationStyles.html: + cites = " ".join(ref_to_html(ref) for ref in ref_map.values()) + case CitationStyles.slack_mrkdwn: + cites = " ".join(ref_to_slack_mrkdwn(ref) for ref in ref_map.values()) + case CitationStyles.plaintext: + cites = " ".join( + f'[{ref["title"]} {ref["url"]}]' for ref_num, ref in ref_map.items() + ) + case CitationStyles.number_markdown: - formatted += "\n\n" - formatted += "\n".join( - f"[{ref_num}] {ref_to_markdown(ref)}" - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + markdown_link(f"[{ref_num}]", ref["url"]) + for ref_num, ref in ref_map.items() ) case CitationStyles.number_html: - formatted += "

" - formatted += "
".join( - f"[{ref_num}] {ref_to_html(ref)}" - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + html_link(f"[{ref_num}]", ref["url"]) + for ref_num, ref in ref_map.items() ) case CitationStyles.number_slack_mrkdwn: - formatted += "\n\n" - formatted += "\n".join( - f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}" - for ref_num, ref in sorted(all_refs.items()) - ) - case CitationStyles.number_plaintext: - formatted += "\n\n" - formatted += "\n".join( - f'{ref_num}. {ref["title"]} {ref["url"]}' - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + slack_mrkdwn_link(f"[{ref_num}]", ref["url"]) + for ref_num, ref in ref_map.items() ) case CitationStyles.symbol_markdown: - formatted += "\n\n" - formatted += "\n".join( - f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}" - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + markdown_link( + f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] + ) + for ref_num, ref in ref_map.items() ) case CitationStyles.symbol_html: - formatted += "

" - formatted += "
".join( - f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}" - for ref_num, ref in sorted(all_refs.items()) + cites = " ".join( + html_link(f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]) + for ref_num, ref in ref_map.items() ) case CitationStyles.symbol_slack_mrkdwn: - formatted += "\n\n" - formatted += "\n".join( - f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}" - for ref_num, ref in sorted(all_refs.items()) - ) - case CitationStyles.symbol_plaintext: - formatted += "\n\n" - formatted += "\n".join( - f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}' - for ref_num, ref in sorted(all_refs.items()) - ) - - for ref_num, ref in all_refs.items(): - try: - template = ref["response_template"] - except KeyError: - pass - else: - formatted = jinja2.Template(template).render( - **ref, - output_text=formatted, - ref_num=ref_num, + cites = " ".join( + slack_mrkdwn_link( + f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] + ) + for ref_num, ref in ref_map.items() ) - output_text[i] = formatted + case None: + cites = "" + case _: + raise ValueError(f"Unknown citation style: {citation_style}") + formatted += " ".join(filter(None, [snippet, cites])) + all_refs.update(ref_map) + return all_refs, formatted + + +def format_footnotes( + all_refs: dict[int, SearchReference], formatted: str, citation_style: CitationStyles +) -> str: + match citation_style: + case CitationStyles.number_markdown: + formatted += "\n\n" + formatted += "\n".join( + f"[{ref_num}] {ref_to_markdown(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.number_html: + formatted += "

" + formatted += "
".join( + f"[{ref_num}] {ref_to_html(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.number_slack_mrkdwn: + formatted += "\n\n" + formatted += "\n".join( + f"[{ref_num}] {ref_to_slack_mrkdwn(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.number_plaintext: + formatted += "\n\n" + formatted += "\n".join( + f'{ref_num}. {ref["title"]} {ref["url"]}' + for ref_num, ref in sorted(all_refs.items()) + ) + + case CitationStyles.symbol_markdown: + formatted += "\n\n" + formatted += "\n".join( + f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.symbol_html: + formatted += "

" + formatted += "
".join( + f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.symbol_slack_mrkdwn: + formatted += "\n\n" + formatted += "\n".join( + f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.symbol_plaintext: + formatted += "\n\n" + formatted += "\n".join( + f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}' + for ref_num, ref in sorted(all_refs.items()) + ) + return formatted + + +def format_jinja_response_template( + all_refs: dict[int, SearchReference], formatted: str +) -> str: + for ref_num, ref in all_refs.items(): + try: + template = ref["response_template"] + except KeyError: + pass + else: + formatted = jinja2.Template(template).render( + **ref, + output_text=formatted, + ref_num=ref_num, + ) + return formatted search_ref_pat = re.compile(r"\[" r"[\d\s\.\,\[\]\$\{\}]+" r"\]") diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 2a509d8fd..15cb6b063 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -23,8 +23,9 @@ from daras_ai_v2.search_ref import ( SearchReference, render_output_with_refs, - apply_response_template, CitationStyles, + apply_response_formattings_prefix, + apply_response_formattings_suffix, ) from daras_ai_v2.vector_search import ( DocSearchRequest, @@ -194,9 +195,12 @@ def run_v2( citation_style = ( request.citation_style and CitationStyles[request.citation_style] ) or None - apply_response_template( + all_refs_list = apply_response_formattings_prefix( response.output_text, response.references, citation_style ) + apply_response_formattings_suffix( + all_refs_list, response.output_text, citation_style + ) def get_raw_price(self, state: dict) -> float: name = state.get("selected_model") diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 149931ffb..89ea87973 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -50,6 +50,7 @@ get_entry_images, get_entry_text, format_chat_entry, + SUPERSCRIPT, ) from daras_ai_v2.language_model_settings_widgets import language_model_settings from daras_ai_v2.lipsync_settings_widgets import lipsync_settings @@ -58,7 +59,12 @@ from daras_ai_v2.query_generator import generate_final_search_query from daras_ai_v2.query_params import gooey_get_query_params from daras_ai_v2.query_params_util import extract_query_params -from daras_ai_v2.search_ref import apply_response_template, parse_refs, CitationStyles +from daras_ai_v2.search_ref import ( + parse_refs, + CitationStyles, + apply_response_formattings_prefix, + apply_response_formattings_suffix, +) from daras_ai_v2.text_output_widget import text_output from daras_ai_v2.text_to_speech_settings_widgets import ( TextToSpeechProviders, @@ -805,7 +811,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: yield f"Running {model.value}..." if is_chat_model: - output_text = run_language_model( + chunks = run_language_model( model=request.selected_model, messages=[ {"role": s["role"], "content": s["content"]} @@ -816,12 +822,13 @@ def run(self, state: dict) -> typing.Iterator[str | None]: temperature=request.sampling_temperature, avoid_repetition=request.avoid_repetition, tools=request.tools, + stream=True, ) else: prompt = "\n".join( format_chatml_message(entry) for entry in prompt_messages ) - output_text = run_language_model( + chunks = run_language_model( model=request.selected_model, prompt=prompt, max_tokens=max_allowed_tokens, @@ -830,43 +837,51 @@ def run(self, state: dict) -> typing.Iterator[str | None]: temperature=request.sampling_temperature, avoid_repetition=request.avoid_repetition, stop=[CHATML_END_TOKEN, CHATML_START_TOKEN], + stream=True, ) - if request.tools: - output_text, tool_call_choices = output_text - state["output_documents"] = output_documents = [] - for tool_calls in tool_call_choices: - for call in tool_calls: - result = yield from exec_tool_call(call) - output_documents.append(result) - - # save model response - state["raw_output_text"] = [ - "".join(snippet for snippet, _ in parse_refs(text, references)) - for text in output_text - ] - - # translate response text - if request.user_language and request.user_language != "en": - yield f"Translating response to {request.user_language}..." - output_text = run_google_translate( - texts=output_text, - source_language="en", - target_language=request.user_language, - glossary_url=request.output_glossary_document, - ) - state["raw_tts_text"] = [ + citation_style = ( + request.citation_style and CitationStyles[request.citation_style] + ) or None + all_refs_list = [] + for i, output_text in enumerate(chunks): + if request.tools: + output_text, tool_call_choices = output_text + state["output_documents"] = output_documents = [] + for tool_calls in tool_call_choices: + for call in tool_calls: + result = yield from exec_tool_call(call) + output_documents.append(result) + + # save model response + state["raw_output_text"] = [ "".join(snippet for snippet, _ in parse_refs(text, references)) for text in output_text ] - if references: - citation_style = ( - request.citation_style and CitationStyles[request.citation_style] - ) or None - apply_response_template(output_text, references, citation_style) - - state["output_text"] = output_text + # translate response text + if request.user_language and request.user_language != "en": + yield f"Translating response to {request.user_language}..." + output_text = run_google_translate( + texts=output_text, + source_language="en", + target_language=request.user_language, + glossary_url=request.output_glossary_document, + ) + state["raw_tts_text"] = [ + "".join(snippet for snippet, _ in parse_refs(text, references)) + for text in output_text + ] + + if references: + all_refs_list = apply_response_formattings_prefix( + output_text, references, citation_style + ) + state["output_text"] = output_text + yield f"Streaming{str(i + 1).translate(SUPERSCRIPT)} {model.value}..." + apply_response_formattings_suffix( + all_refs_list, state["output_text"], citation_style + ) state["output_audio"] = [] state["output_video"] = []