diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py
index 657d2898c..09f3714c7 100644
--- a/daras_ai_v2/search_ref.py
+++ b/daras_ai_v2/search_ref.py
@@ -19,17 +19,23 @@ class CitationStyles(Enum):
number = "Numbers ( [1] [2] [3] ..)"
title = "Source Title ( [Source 1] [Source 2] [Source 3] ..)"
url = "Source URL ( [https://source1.com] [https://source2.com] [https://source3.com] ..)"
+ symbol = "Symbols ( [*] [†] [‡] ..)"
markdown = "Markdown ( [Source 1](https://source1.com) [Source 2](https://source2.com) [Source 3](https://source3.com) ..)"
html = "HTML ( Source 1 Source 2 Source 3 ..)"
slack_mrkdwn = "Slack mrkdwn ( ..)"
plaintext = "Plain Text / WhatsApp ( [Source 1 https://source1.com] [Source 2 https://source2.com] [Source 3 https://source3.com] ..)"
- number_markdown = " Markdown Numbers + Footnotes"
+ number_markdown = "Markdown Numbers + Footnotes"
number_html = "HTML Numbers + Footnotes"
number_slack_mrkdwn = "Slack mrkdown Numbers + Footnotes"
number_plaintext = "Plain Text / WhatsApp Numbers + Footnotes"
+ symbol_markdown = "Markdown Symbols + Footnotes"
+ symbol_html = "HTML Symbols + Footnotes"
+ symbol_slack_mrkdwn = "Slack mrkdown Symbols + Footnotes"
+ symbol_plaintext = "Plain Text / WhatsApp Symbols + Footnotes"
+
def remove_quotes(snippet: str) -> str:
return re.sub(r"[\"\']+", r'"', snippet).strip()
@@ -63,36 +69,65 @@ def apply_response_template(
match citation_style:
case CitationStyles.number | CitationStyles.number_plaintext:
cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys())
- case CitationStyles.number_html:
+ case CitationStyles.title:
+ cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
+ case CitationStyles.url:
+ cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
+ case CitationStyles.symbol | CitationStyles.symbol_plaintext:
cites = " ".join(
- html_link(f"[{ref_num}]", ref["url"])
+ f"[{generate_footnote_symbol(ref_num - 1)}]"
+ for ref_num in ref_map.keys()
+ )
+
+ case CitationStyles.markdown:
+ cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
+ case CitationStyles.html:
+ cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
+ case CitationStyles.slack_mrkdwn:
+ cites = " ".join(
+ ref_to_slack_mrkdwn(ref) for ref in ref_map.values()
+ )
+ case CitationStyles.plaintext:
+ cites = " ".join(
+ f'[{ref["title"]} {ref["url"]}]'
for ref_num, ref in ref_map.items()
)
+
case CitationStyles.number_markdown:
cites = " ".join(
markdown_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)
+ case CitationStyles.number_html:
+ cites = " ".join(
+ html_link(f"[{ref_num}]", ref["url"])
+ for ref_num, ref in ref_map.items()
+ )
case CitationStyles.number_slack_mrkdwn:
cites = " ".join(
slack_mrkdwn_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)
- case CitationStyles.title:
- cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
- case CitationStyles.url:
- cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
- case CitationStyles.markdown:
- cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
- case CitationStyles.html:
- cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
- case CitationStyles.slack_mrkdwn:
+
+ case CitationStyles.symbol_markdown:
cites = " ".join(
- ref_to_slack_mrkdwn(ref) for ref in ref_map.values()
+ markdown_link(
+ f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+ )
+ for ref_num, ref in ref_map.items()
)
- case CitationStyles.plaintext:
+ case CitationStyles.symbol_html:
cites = " ".join(
- f'[{ref["title"]} {ref["url"]}]'
+ html_link(
+ f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+ )
+ for ref_num, ref in ref_map.items()
+ )
+ case CitationStyles.symbol_slack_mrkdwn:
+ cites = " ".join(
+ slack_mrkdwn_link(
+ f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
+ )
for ref_num, ref in ref_map.items()
)
case None:
@@ -128,6 +163,31 @@ def apply_response_template(
for ref_num, ref in sorted(all_refs.items())
)
+ case CitationStyles.symbol_markdown:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.symbol_html:
+ formatted += "
"
+ formatted += "
".join(
+ f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.symbol_slack_mrkdwn:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}"
+ for ref_num, ref in sorted(all_refs.items())
+ )
+ case CitationStyles.symbol_plaintext:
+ formatted += "\n\n"
+ formatted += "\n".join(
+ f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}'
+ for ref_num, ref in sorted(all_refs.items())
+ )
+
for ref_num, ref in all_refs.items():
try:
template = ref["response_template"]
@@ -205,3 +265,9 @@ def render_output_with_refs(state, height):
for text in output_text:
html = render_text_with_refs(text, state.get("references", []))
scrollable_html(html, height=height)
+
+
+FOOTNOTE_SYMBOLS = ["*", "†", "‡", "§", "¶", "#", "♠", "♥", "♦", "♣", "✠", "☮", "☯", "✡"] # fmt: skip
+def generate_footnote_symbol(idx: int) -> str:
+ quotient, remainder = divmod(idx, len(FOOTNOTE_SYMBOLS))
+ return FOOTNOTE_SYMBOLS[remainder] * (quotient + 1)
diff --git a/tests/test_search_refs.py b/tests/test_search_refs.py
index 5374db260..00698bf88 100644
--- a/tests/test_search_refs.py
+++ b/tests/test_search_refs.py
@@ -1,4 +1,6 @@
-from daras_ai_v2.search_ref import parse_refs
+import pytest
+
+from daras_ai_v2.search_ref import parse_refs, generate_footnote_symbol
def test_ref_parser():
@@ -126,3 +128,21 @@ def test_ref_parser():
},
),
]
+
+
+def test_generate_footnote_symbol():
+ assert generate_footnote_symbol(0) == "*"
+ assert generate_footnote_symbol(1) == "†"
+ assert generate_footnote_symbol(13) == "✡"
+ assert generate_footnote_symbol(14) == "**"
+ assert generate_footnote_symbol(15) == "††"
+ assert generate_footnote_symbol(27) == "✡✡"
+ assert generate_footnote_symbol(28) == "***"
+ assert generate_footnote_symbol(29) == "†††"
+ assert generate_footnote_symbol(41) == "✡✡✡"
+ assert generate_footnote_symbol(70) == "******"
+ assert generate_footnote_symbol(71) == "††††††"
+
+ # testing with non-integer index
+ with pytest.raises(TypeError):
+ generate_footnote_symbol(1.5)