Handle if a user instantiates a prompt with just text (#107)
Users fairly often instantiate a Prompt with just a string. Since this
is also how our API works, I don't blame them. But when they do, to_json
iterates over the string, so every character becomes its own text item,
which causes weird behavior with our tokenization and prompt optimizations.

This change handles that case and serializes the string as a single
text item. While it would be nice to have stricter types, I think that
unless every user ran mypy, type hints alone wouldn't alleviate the
confusion.
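
To illustrate the behavior described above, here is a minimal, self-contained sketch. It does not use the real library: `Text`, `OldPrompt`, `NewPrompt`, and this `_to_json` are simplified, hypothetical stand-ins, and the assumption that a bare string serializes to a text entry is inferred from the test added in this commit.

```python
from typing import Any, Mapping, NamedTuple, Sequence, Union


class Text(NamedTuple):
    """Hypothetical stand-in for the library's text prompt item."""

    text: str


def _to_json(item: Union[str, Text]) -> Mapping[str, Any]:
    # Assumption: a bare string and a Text item both become a text entry.
    data = item if isinstance(item, str) else item.text
    return {"type": "text", "data": data}


class OldPrompt(NamedTuple):
    """Serialization as it looked before this commit (simplified)."""

    items: Sequence[Union[str, Text]]

    def to_json(self) -> Sequence[Mapping[str, Any]]:
        # Iterating over a str yields one-character strings,
        # so each character is serialized as its own item.
        return [_to_json(item) for item in self.items]


class NewPrompt(NamedTuple):
    """Serialization with the str special case (simplified)."""

    items: Union[str, Sequence[Union[str, Text]]]

    def to_json(self) -> Sequence[Mapping[str, Any]]:
        if isinstance(self.items, str):
            # Treat the whole string as a single text item.
            return [_to_json(self.items)]
        return [_to_json(item) for item in self.items]


old = OldPrompt("hi").to_json()
new = NewPrompt("hi").to_json()
print(old)  # one entry per character
print(new)  # one entry for the whole string
```

The special case is needed because a Python `str` is itself a sequence of one-character strings, so the plain list comprehension cannot tell a string apart from a list of items.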
benbrandt authored Mar 29, 2023
1 parent 8dddf80 commit 42bc537
Showing 2 changed files with 13 additions and 2 deletions.
7 changes: 5 additions & 2 deletions aleph_alpha_client/prompt.py
@@ -436,7 +436,7 @@ class Prompt(NamedTuple):
])
"""

-    items: Sequence[PromptItem]
+    items: Union[str, Sequence[PromptItem]]

     @staticmethod
     def from_text(
@@ -459,7 +459,10 @@ def from_tokens(
         return Prompt([Tokens(tokens, controls or [])])

     def to_json(self) -> Sequence[Mapping[str, Any]]:
-        return [_to_json(item) for item in self.items]
+        if isinstance(self.items, str):
+            return [_to_json(self.items)]
+        else:
+            return [_to_json(item) for item in self.items]


 def _to_json(item: PromptItem) -> Mapping[str, Any]:
8 changes: 8 additions & 0 deletions tests/test_prompt.py
@@ -13,6 +13,14 @@
 from tests.common import sync_client, model_name


+def test_serialize_prompt_init_with_str():
+    text = "text prompt"
+    prompt = Prompt(text)
+    serialized_prompt = prompt.to_json()
+
+    assert serialized_prompt == [{"type": "text", "data": text}]
+
+
 def test_serialize_token_ids():
     tokens = [1, 2, 3, 4]
     prompt = Prompt.from_tokens(tokens)