From 42bc5371c89e80473583c39ab66227d5822f09bc Mon Sep 17 00:00:00 2001
From: Ben Brandt
Date: Wed, 29 Mar 2023 14:21:30 +0200
Subject: [PATCH] Handle if a user instantiates a prompt with just text (#107)

It isn't an uncommon occurrence to have users instantiate a Prompt with
just a string. Since this is also how our API works, I don't blame them.
But when they do, every character becomes a text item, which causes weird
behavior with our tokenization and prompt optimizations.

This handles the case where a user does this and serializes it correctly.
While it would be nice to have stricter types, unless every user used
mypy, the type hints still wouldn't alleviate the confusion.
---
 aleph_alpha_client/prompt.py | 7 +++++--
 tests/test_prompt.py         | 8 ++++++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/aleph_alpha_client/prompt.py b/aleph_alpha_client/prompt.py
index 6120ad7..c6ee38f 100644
--- a/aleph_alpha_client/prompt.py
+++ b/aleph_alpha_client/prompt.py
@@ -436,7 +436,7 @@ class Prompt(NamedTuple):
         ])
     """
 
-    items: Sequence[PromptItem]
+    items: Union[str, Sequence[PromptItem]]
 
     @staticmethod
     def from_text(
@@ -459,7 +459,10 @@ def from_tokens(
         return Prompt([Tokens(tokens, controls or [])])
 
     def to_json(self) -> Sequence[Mapping[str, Any]]:
-        return [_to_json(item) for item in self.items]
+        if isinstance(self.items, str):
+            return [_to_json(self.items)]
+        else:
+            return [_to_json(item) for item in self.items]
 
 
 def _to_json(item: PromptItem) -> Mapping[str, Any]:
diff --git a/tests/test_prompt.py b/tests/test_prompt.py
index d7f8ff5..1f446d1 100644
--- a/tests/test_prompt.py
+++ b/tests/test_prompt.py
@@ -13,6 +13,14 @@
 from tests.common import sync_client, model_name
 
 
+def test_serialize_prompt_init_with_str():
+    text = "text prompt"
+    prompt = Prompt(text)
+    serialized_prompt = prompt.to_json()
+
+    assert serialized_prompt == [{"type": "text", "data": text}]
+
+
 def test_serialize_token_ids():
     tokens = [1, 2, 3, 4]
     prompt = Prompt.from_tokens(tokens)
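
A minimal usage sketch of the behavior this patch targets (not part of the
patch itself; it mirrors the new test and assumes the existing _to_json
helper serializes a bare string as a text item, as the test expects):

    from aleph_alpha_client.prompt import Prompt

    # Before this change, a Prompt built from a bare string was iterated
    # character by character, producing one text item per character.
    # With this change, the whole string serializes as a single text item:
    prompt = Prompt("text prompt")
    assert prompt.to_json() == [{"type": "text", "data": "text prompt"}]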