diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index 106d00a..12e25d9 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -16,7 +16,7 @@ from aleph_alpha_client.document import Document from aleph_alpha_client.explanation import ExplanationRequest, ExplanationResponse from aleph_alpha_client.image import Image -from aleph_alpha_client.prompt import _to_prompt_item, _to_serializable_prompt +from aleph_alpha_client.prompt import _to_json, _to_serializable_prompt from aleph_alpha_client.summarization import SummarizationRequest, SummarizationResponse from aleph_alpha_client.qa import QaRequest, QaResponse from aleph_alpha_client.completion import CompletionRequest, CompletionResponse @@ -779,7 +779,7 @@ def _explain( ): body = { "model": model, - "prompt": [_to_prompt_item(item) for item in request.prompt.items], + "prompt": [_to_json(item) for item in request.prompt.items], "target": request.target, "suppression_factor": request.suppression_factor, "conceptual_suppression_threshold": request.conceptual_suppression_threshold, diff --git a/aleph_alpha_client/document.py b/aleph_alpha_client/document.py index 05805bf..fdc090c 100644 --- a/aleph_alpha_client/document.py +++ b/aleph_alpha_client/document.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Sequence, Union from aleph_alpha_client.image import Image -from aleph_alpha_client.prompt import _to_prompt_item +from aleph_alpha_client.prompt import PromptItem, Text, _to_json class Document: @@ -65,6 +65,12 @@ def _to_serializable_document(self) -> Dict[str, Any]: """ A dict if serialized to JSON is suitable as a document element """ + + def to_prompt_item(item: Union[str, Image]) -> PromptItem: + # document still uses a plain piece of text for text-prompts + # -> convert to Text-instance + return Text.from_text(item) if isinstance(item, str) else item + if self.docx is not None: # Serialize docx to Document JSON format return { @@ -72,7 +78,9 @@ def _to_serializable_document(self) -> Dict[str, Any]: } elif self.prompt is not None: # Serialize prompt to Document JSON format - prompt_data = [_to_prompt_item(prompt_item) for prompt_item in self.prompt] + prompt_data = [ + _to_json(to_prompt_item(prompt_item)) for prompt_item in self.prompt + ] return {"prompt": prompt_data} elif self.text is not None: return { diff --git a/aleph_alpha_client/prompt.py b/aleph_alpha_client/prompt.py index d7bcc60..4fd5eab 100644 --- a/aleph_alpha_client/prompt.py +++ b/aleph_alpha_client/prompt.py @@ -50,16 +50,78 @@ class Tokens(NamedTuple): """ tokens: Sequence[int] - controls: Optional[Sequence[TokenControl]] = None + controls: Sequence[TokenControl] def to_json(self) -> Mapping[str, Any]: """ Serialize the prompt item to JSON for sending to the API. """ - payload = {"type": "token_ids", "data": self.tokens} - if self.controls: - payload["controls"] = [c.to_json() for c in self.controls] - return payload + return { + "type": "token_ids", + "data": self.tokens, + "controls": [c.to_json() for c in self.controls], + } + + @staticmethod + def from_token_ids(token_ids: Sequence[int]) -> "Tokens": + return Tokens(token_ids, []) + + +class TextControl(NamedTuple): + """ + Attention manipulation for a Text PromptItem. + + Parameters: + start (int, required): + Starting character index to apply the factor to. + length (int, required): + The amount of characters to apply the factor to. + factor (float, required): + The amount to adjust model attention by. + Values between 0 and 1 will supress attention. + A value of 1 will have no effect. + Values above 1 will increase attention. + """ + + start: int + length: int + factor: float + + def to_json(self) -> Mapping[str, Any]: + return self._asdict() + + +class Text(NamedTuple): + """ + A Text-prompt including optional controls for attention manipulation. + + Parameters: + text (str, required): + The text prompt + controls (list of TextControl, required): + A list of TextControls to manilpulate attention when processing the prompt. + Can be empty if no manipulation is required. + + Examples: + >>> Text("Hello, World!", controls=[TextControl(start=0, length=5, factor=0.5)]) + """ + + text: str + controls: Sequence[TextControl] + + def to_json(self) -> Mapping[str, Any]: + return { + "type": "text", + "data": self.text, + "controls": [control.to_json() for control in self.controls], + } + + @staticmethod + def from_text(text: str) -> "Text": + return Text(text, []) + + +PromptItem = Union[Text, Tokens, Image] class Prompt(NamedTuple): @@ -72,39 +134,41 @@ class Prompt(NamedTuple): ]) """ - items: Sequence[Union[str, Image, Tokens, Sequence[int]]] + items: Sequence[PromptItem] @staticmethod - def from_text(text: str) -> "Prompt": - return Prompt([text]) + def from_text( + text: str, controls: Optional[Sequence[TextControl]] = None + ) -> "Prompt": + return Prompt([Text(text, controls or [])]) @staticmethod def from_image(image: Image) -> "Prompt": return Prompt([image]) @staticmethod - def from_tokens(tokens: Union[Sequence[int], Tokens]) -> "Prompt": + def from_tokens( + tokens: Sequence[int], controls: Optional[Sequence[TokenControl]] = None + ) -> "Prompt": """ Examples: >>> prompt = Prompt.from_tokens(Tokens([1, 2, 3])) """ - if isinstance(tokens, List): - tokens = Tokens(tokens) - return Prompt([tokens]) + return Prompt([Tokens(tokens, controls or [])]) def to_json(self) -> Sequence[Mapping[str, Any]]: - return [_to_prompt_item(item) for item in self.items] + return [_to_json(item) for item in self.items] -def _to_prompt_item( - item: Union[str, Image, Tokens, Sequence[int]] -) -> Mapping[str, Any]: - if isinstance(item, str): +def _to_json(item: PromptItem) -> Mapping[str, Any]: + if hasattr(item, "to_json"): + return item.to_json() + # Required for backwards compatibility + # item could be a plain piece of text or a plain list of token-ids + elif isinstance(item, str): return {"type": "text", "data": item} elif isinstance(item, List): return {"type": "token_ids", "data": item} - elif hasattr(item, "to_json"): - return item.to_json() else: raise ValueError( "The item in the prompt is not valid. Try either a string or an Image." @@ -125,7 +189,7 @@ def _to_serializable_prompt( return prompt elif isinstance(prompt, list): - return [_to_prompt_item(item) for item in prompt] + return [_to_json(item) for item in prompt] raise ValueError( "Invalid prompt. Prompt must either be a string, or a list of valid multimodal propmt items." diff --git a/tests/test_prompt.py b/tests/test_prompt.py index 913c20c..c615fee 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -1,24 +1,25 @@ from aleph_alpha_client import Prompt, Tokens, TokenControl +from aleph_alpha_client.prompt import TextControl def test_serialize_token_ids(): tokens = [1, 2, 3, 4] - prompt = Prompt.from_tokens(Tokens(tokens)) + prompt = Prompt.from_tokens(tokens) serialized_prompt = prompt.to_json() - assert serialized_prompt == [{"type": "token_ids", "data": [1, 2, 3, 4]}] + assert serialized_prompt == [ + {"type": "token_ids", "data": [1, 2, 3, 4], "controls": []} + ] def test_serialize_token_ids_with_controls(): tokens = [1, 2, 3, 4] prompt = Prompt.from_tokens( - Tokens( - tokens, - controls=[ - TokenControl(pos=0, factor=0.25), - TokenControl(pos=1, factor=0.5), - ], - ) + tokens, + controls=[ + TokenControl(pos=0, factor=0.25), + TokenControl(pos=1, factor=0.5), + ], ) serialized_prompt = prompt.to_json() @@ -29,3 +30,18 @@ def test_serialize_token_ids_with_controls(): "controls": [{"index": 0, "factor": 0.25}, {"index": 1, "factor": 0.5}], } ] + + +def test_serialize_text_with_controls(): + prompt_text = "An apple a day" + prompt = Prompt.from_text(prompt_text, [TextControl(start=3, length=5, factor=1.5)]) + + serialized_prompt = prompt.to_json() + + assert serialized_prompt == [ + { + "type": "text", + "data": prompt_text, + "controls": [{"start": 3, "length": 5, "factor": 1.5}], + } + ]