diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py index 657d2898c..09f3714c7 100644 --- a/daras_ai_v2/search_ref.py +++ b/daras_ai_v2/search_ref.py @@ -19,17 +19,23 @@ class CitationStyles(Enum): number = "Numbers ( [1] [2] [3] ..)" title = "Source Title ( [Source 1] [Source 2] [Source 3] ..)" url = "Source URL ( [https://source1.com] [https://source2.com] [https://source3.com] ..)" + symbol = "Symbols ( [*] [†] [‡] ..)" markdown = "Markdown ( [Source 1](https://source1.com) [Source 2](https://source2.com) [Source 3](https://source3.com) ..)" html = "HTML ( Source 1 Source 2 Source 3 ..)" slack_mrkdwn = "Slack mrkdwn ( ..)" plaintext = "Plain Text / WhatsApp ( [Source 1 https://source1.com] [Source 2 https://source2.com] [Source 3 https://source3.com] ..)" - number_markdown = " Markdown Numbers + Footnotes" + number_markdown = "Markdown Numbers + Footnotes" number_html = "HTML Numbers + Footnotes" number_slack_mrkdwn = "Slack mrkdown Numbers + Footnotes" number_plaintext = "Plain Text / WhatsApp Numbers + Footnotes" + symbol_markdown = "Markdown Symbols + Footnotes" + symbol_html = "HTML Symbols + Footnotes" + symbol_slack_mrkdwn = "Slack mrkdown Symbols + Footnotes" + symbol_plaintext = "Plain Text / WhatsApp Symbols + Footnotes" + def remove_quotes(snippet: str) -> str: return re.sub(r"[\"\']+", r'"', snippet).strip() @@ -63,36 +69,65 @@ def apply_response_template( match citation_style: case CitationStyles.number | CitationStyles.number_plaintext: cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys()) - case CitationStyles.number_html: + case CitationStyles.title: + cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values()) + case CitationStyles.url: + cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values()) + case CitationStyles.symbol | CitationStyles.symbol_plaintext: cites = " ".join( - html_link(f"[{ref_num}]", ref["url"]) + f"[{generate_footnote_symbol(ref_num - 1)}]" + for ref_num in ref_map.keys() + ) + + case 
CitationStyles.markdown: + cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values()) + case CitationStyles.html: + cites = " ".join(ref_to_html(ref) for ref in ref_map.values()) + case CitationStyles.slack_mrkdwn: + cites = " ".join( + ref_to_slack_mrkdwn(ref) for ref in ref_map.values() + ) + case CitationStyles.plaintext: + cites = " ".join( + f'[{ref["title"]} {ref["url"]}]' for ref_num, ref in ref_map.items() ) + case CitationStyles.number_markdown: cites = " ".join( markdown_link(f"[{ref_num}]", ref["url"]) for ref_num, ref in ref_map.items() ) + case CitationStyles.number_html: + cites = " ".join( + html_link(f"[{ref_num}]", ref["url"]) + for ref_num, ref in ref_map.items() + ) case CitationStyles.number_slack_mrkdwn: cites = " ".join( slack_mrkdwn_link(f"[{ref_num}]", ref["url"]) for ref_num, ref in ref_map.items() ) - case CitationStyles.title: - cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values()) - case CitationStyles.url: - cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values()) - case CitationStyles.markdown: - cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values()) - case CitationStyles.html: - cites = " ".join(ref_to_html(ref) for ref in ref_map.values()) - case CitationStyles.slack_mrkdwn: + + case CitationStyles.symbol_markdown: cites = " ".join( - ref_to_slack_mrkdwn(ref) for ref in ref_map.values() + markdown_link( + f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] + ) + for ref_num, ref in ref_map.items() ) - case CitationStyles.plaintext: + case CitationStyles.symbol_html: cites = " ".join( - f'[{ref["title"]} {ref["url"]}]' + html_link( + f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] + ) + for ref_num, ref in ref_map.items() + ) + case CitationStyles.symbol_slack_mrkdwn: + cites = " ".join( + slack_mrkdwn_link( + f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"] + ) for ref_num, ref in ref_map.items() ) case None: @@ -128,6 +163,31 @@ def apply_response_template( for ref_num, 
ref in sorted(all_refs.items()) ) + case CitationStyles.symbol_markdown: + formatted += "\n\n" + formatted += "\n".join( + f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}" + for ref_num, ref in sorted(all_refs.items()) + ) + case CitationStyles.symbol_html: + formatted += "

" + formatted += "
# Footnote markers in the order they are handed out; once the list is
# exhausted the sequence restarts with each marker doubled, then tripled, ...
FOOTNOTE_SYMBOLS = ["*", "†", "‡", "§", "¶", "#", "♠", "♥", "♦", "♣", "✠", "☮", "☯", "✡"]  # fmt: skip


def generate_footnote_symbol(idx: int) -> str:
    """Return the footnote marker for the zero-based position ``idx``.

    The first ``len(FOOTNOTE_SYMBOLS)`` positions map to the symbols
    themselves; every further pass over the list repeats the symbol one
    more time (``0 -> "*"``, ``14 -> "**"``, ``28 -> "***"``).

    Args:
        idx: Zero-based, non-negative integer position of the footnote.

    Returns:
        The symbol at ``idx % len(FOOTNOTE_SYMBOLS)`` repeated
        ``idx // len(FOOTNOTE_SYMBOLS) + 1`` times.

    Raises:
        TypeError: If ``idx`` is not usable as an integer index (e.g. 1.5).
        ValueError: If ``idx`` is negative. Without this check a negative
            index would produce an empty marker, because the repeat count
            ``quotient + 1`` drops to zero or below.
    """
    quotient, remainder = divmod(idx, len(FOOTNOTE_SYMBOLS))
    # Look the symbol up before checking the sign so that a non-integer
    # ``idx`` keeps failing with the TypeError raised by list indexing.
    symbol = FOOTNOTE_SYMBOLS[remainder]
    if quotient < 0:
        raise ValueError(f"footnote index must be non-negative, got {idx!r}")
    return symbol * (quotient + 1)
generate_footnote_symbol(71) == "††††††" + + # testing with non-integer index + with pytest.raises(TypeError): + generate_footnote_symbol(1.5)