Skip to content

Commit

Permalink
citation symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Sep 27, 2023
1 parent 366f0c8 commit 0a8a22b
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 16 deletions.
96 changes: 81 additions & 15 deletions daras_ai_v2/search_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@ class CitationStyles(Enum):
number = "Numbers ( [1] [2] [3] ..)"
title = "Source Title ( [Source 1] [Source 2] [Source 3] ..)"
url = "Source URL ( [https://source1.com] [https://source2.com] [https://source3.com] ..)"
symbol = "Symbols ( [*] [†] [‡] ..)"

markdown = "Markdown ( [Source 1](https://source1.com) [Source 2](https://source2.com) [Source 3](https://source3.com) ..)"
html = "HTML ( <a href='https://source1.com'>Source 1</a> <a href='https://source2.com'>Source 2</a> <a href='https://source3.com'>Source 3</a> ..)"
slack_mrkdwn = "Slack mrkdwn ( <https://source1.com|Source 1> <https://source2.com|Source 2> <https://source3.com|Source 3> ..)"
plaintext = "Plain Text / WhatsApp ( [Source 1 https://source1.com] [Source 2 https://source2.com] [Source 3 https://source3.com] ..)"

number_markdown = " Markdown Numbers + Footnotes"
number_markdown = "Markdown Numbers + Footnotes"
number_html = "HTML Numbers + Footnotes"
number_slack_mrkdwn = "Slack mrkdown Numbers + Footnotes"
number_plaintext = "Plain Text / WhatsApp Numbers + Footnotes"

symbol_markdown = "Markdown Symbols + Footnotes"
symbol_html = "HTML Symbols + Footnotes"
symbol_slack_mrkdwn = "Slack mrkdown Symbols + Footnotes"
symbol_plaintext = "Plain Text / WhatsApp Symbols + Footnotes"


def remove_quotes(snippet: str) -> str:
    """Normalize quote characters in *snippet* and trim surrounding whitespace.

    Despite the name, quotes are not dropped: every consecutive run of
    single and/or double quote characters is collapsed into exactly one
    double quote. Leading and trailing whitespace is stripped afterwards.
    """
    # one pass over the text: any run of ' or " becomes a single "
    collapsed = re.sub(r"[\"\']+", r'"', snippet)
    return collapsed.strip()
Expand Down Expand Up @@ -63,36 +69,65 @@ def apply_response_template(
match citation_style:
case CitationStyles.number | CitationStyles.number_plaintext:
cites = " ".join(f"[{ref_num}]" for ref_num in ref_map.keys())
case CitationStyles.number_html:
case CitationStyles.title:
cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
case CitationStyles.url:
cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
case CitationStyles.symbol | CitationStyles.symbol_plaintext:
cites = " ".join(
html_link(f"[{ref_num}]", ref["url"])
f"[{generate_footnote_symbol(ref_num - 1)}]"
for ref_num in ref_map.keys()
)

case CitationStyles.markdown:
cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
case CitationStyles.html:
cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
case CitationStyles.slack_mrkdwn:
cites = " ".join(
ref_to_slack_mrkdwn(ref) for ref in ref_map.values()
)
case CitationStyles.plaintext:
cites = " ".join(
f'[{ref["title"]} {ref["url"]}]'
for ref_num, ref in ref_map.items()
)

case CitationStyles.number_markdown:
cites = " ".join(
markdown_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)
case CitationStyles.number_html:
cites = " ".join(
html_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)
case CitationStyles.number_slack_mrkdwn:
cites = " ".join(
slack_mrkdwn_link(f"[{ref_num}]", ref["url"])
for ref_num, ref in ref_map.items()
)
case CitationStyles.title:
cites = " ".join(f"[{ref['title']}]" for ref in ref_map.values())
case CitationStyles.url:
cites = " ".join(f"[{ref['url']}]" for ref in ref_map.values())
case CitationStyles.markdown:
cites = " ".join(ref_to_markdown(ref) for ref in ref_map.values())
case CitationStyles.html:
cites = " ".join(ref_to_html(ref) for ref in ref_map.values())
case CitationStyles.slack_mrkdwn:

case CitationStyles.symbol_markdown:
cites = " ".join(
ref_to_slack_mrkdwn(ref) for ref in ref_map.values()
markdown_link(
f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
)
for ref_num, ref in ref_map.items()
)
case CitationStyles.plaintext:
case CitationStyles.symbol_html:
cites = " ".join(
f'[{ref["title"]} {ref["url"]}]'
html_link(
f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
)
for ref_num, ref in ref_map.items()
)
case CitationStyles.symbol_slack_mrkdwn:
cites = " ".join(
slack_mrkdwn_link(
f"[{generate_footnote_symbol(ref_num - 1)}]", ref["url"]
)
for ref_num, ref in ref_map.items()
)
case None:
Expand Down Expand Up @@ -128,6 +163,31 @@ def apply_response_template(
for ref_num, ref in sorted(all_refs.items())
)

case CitationStyles.symbol_markdown:
formatted += "\n\n"
formatted += "\n".join(
f"{generate_footnote_symbol(ref_num - 1)} {ref_to_markdown(ref)}"
for ref_num, ref in sorted(all_refs.items())
)
case CitationStyles.symbol_html:
formatted += "<br><br>"
formatted += "<br>".join(
f"{generate_footnote_symbol(ref_num - 1)} {ref_to_html(ref)}"
for ref_num, ref in sorted(all_refs.items())
)
case CitationStyles.symbol_slack_mrkdwn:
formatted += "\n\n"
formatted += "\n".join(
f"{generate_footnote_symbol(ref_num - 1)} {ref_to_slack_mrkdwn(ref)}"
for ref_num, ref in sorted(all_refs.items())
)
case CitationStyles.symbol_plaintext:
formatted += "\n\n"
formatted += "\n".join(
f'{generate_footnote_symbol(ref_num - 1)}. {ref["title"]} {ref["url"]}'
for ref_num, ref in sorted(all_refs.items())
)

for ref_num, ref in all_refs.items():
try:
template = ref["response_template"]
Expand Down Expand Up @@ -205,3 +265,9 @@ def render_output_with_refs(state, height):
for text in output_text:
html = render_text_with_refs(text, state.get("references", []))
scrollable_html(html, height=height)


FOOTNOTE_SYMBOLS = ["*", "†", "‡", "§", "¶", "#", "♠", "♥", "♦", "♣", "✠", "☮", "☯", "✡"] # fmt: skip


def generate_footnote_symbol(idx: int) -> str:
    """Return the footnote marker for the zero-based position *idx*.

    Markers are taken in order from ``FOOTNOTE_SYMBOLS``. Once the list is
    exhausted it is reused from the start, with the symbol repeated one
    extra time for every full pass already made (``*``, ..., ``**``, ...).

    A non-integer *idx* (e.g. ``1.5``) raises ``TypeError`` because the
    remainder is used directly as a list index.
    """
    # completed_passes: how many times the whole table has been used up;
    # position: which symbol of the table this index lands on
    completed_passes, position = divmod(idx, len(FOOTNOTE_SYMBOLS))
    symbol = FOOTNOTE_SYMBOLS[position]
    repeat_count = completed_passes + 1
    return symbol * repeat_count
22 changes: 21 additions & 1 deletion tests/test_search_refs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from daras_ai_v2.search_ref import parse_refs
import pytest

from daras_ai_v2.search_ref import parse_refs, generate_footnote_symbol


def test_ref_parser():
Expand Down Expand Up @@ -126,3 +128,21 @@ def test_ref_parser():
},
),
]


def test_generate_footnote_symbol():
    """Check the marker produced for a range of indices, including wrap-around."""
    # index -> expected marker; covers the first pass through the symbol
    # table and several later passes where the symbol is repeated
    expected_by_index = {
        0: "*",
        1: "†",
        13: "✡",
        14: "**",
        15: "††",
        27: "✡✡",
        28: "***",
        29: "†††",
        41: "✡✡✡",
        70: "******",
        71: "††††††",
    }
    for index, expected in expected_by_index.items():
        assert generate_footnote_symbol(index) == expected

    # a non-integer index must be rejected with TypeError
    with pytest.raises(TypeError):
        generate_footnote_symbol(1.5)

0 comments on commit 0a8a22b

Please sign in to comment.