Skip to content

Commit

Permalink
Add support for text attention manipulation
Browse files Browse the repository at this point in the history
  • Loading branch information
volkerstampa committed Feb 13, 2023
1 parent 77857b7 commit b17ea26
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 33 deletions.
4 changes: 2 additions & 2 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from aleph_alpha_client.document import Document
from aleph_alpha_client.explanation import ExplanationRequest, ExplanationResponse
from aleph_alpha_client.image import Image
from aleph_alpha_client.prompt import _to_prompt_item, _to_serializable_prompt
from aleph_alpha_client.prompt import _to_json, _to_serializable_prompt
from aleph_alpha_client.summarization import SummarizationRequest, SummarizationResponse
from aleph_alpha_client.qa import QaRequest, QaResponse
from aleph_alpha_client.completion import CompletionRequest, CompletionResponse
Expand Down Expand Up @@ -779,7 +779,7 @@ def _explain(
):
body = {
"model": model,
"prompt": [_to_prompt_item(item) for item in request.prompt.items],
"prompt": [_to_json(item) for item in request.prompt.items],
"target": request.target,
"suppression_factor": request.suppression_factor,
"conceptual_suppression_threshold": request.conceptual_suppression_threshold,
Expand Down
12 changes: 10 additions & 2 deletions aleph_alpha_client/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List, Optional, Sequence, Union

from aleph_alpha_client.image import Image
from aleph_alpha_client.prompt import _to_prompt_item
from aleph_alpha_client.prompt import PromptItem, Text, _to_json


class Document:
Expand Down Expand Up @@ -65,14 +65,22 @@ def _to_serializable_document(self) -> Dict[str, Any]:
"""
A dict if serialized to JSON is suitable as a document element
"""

def to_prompt_item(item: Union[str, Image]) -> PromptItem:
# document still uses a plain piece of text for text-prompts
# -> convert to Text-instance
return Text.from_text(item) if isinstance(item, str) else item

if self.docx is not None:
# Serialize docx to Document JSON format
return {
"docx": self.docx,
}
elif self.prompt is not None:
# Serialize prompt to Document JSON format
prompt_data = [_to_prompt_item(prompt_item) for prompt_item in self.prompt]
prompt_data = [
_to_json(to_prompt_item(prompt_item)) for prompt_item in self.prompt
]
return {"prompt": prompt_data}
elif self.text is not None:
return {
Expand Down
104 changes: 84 additions & 20 deletions aleph_alpha_client/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,78 @@ class Tokens(NamedTuple):
"""

tokens: Sequence[int]
controls: Optional[Sequence[TokenControl]] = None
controls: Sequence[TokenControl]

def to_json(self) -> Mapping[str, Any]:
"""
Serialize the prompt item to JSON for sending to the API.
"""
payload = {"type": "token_ids", "data": self.tokens}
if self.controls:
payload["controls"] = [c.to_json() for c in self.controls]
return payload
return {
"type": "token_ids",
"data": self.tokens,
"controls": [c.to_json() for c in self.controls],
}

@staticmethod
def from_token_ids(token_ids: Sequence[int]) -> "Tokens":
return Tokens(token_ids, [])


class TextControl(NamedTuple):
    """
    Attention manipulation for a Text PromptItem.

    Parameters:
        start (int, required):
            Starting character index to apply the factor to.
        length (int, required):
            The amount of characters to apply the factor to.
        factor (float, required):
            The amount to adjust model attention by.
            Values between 0 and 1 will suppress attention.
            A value of 1 will have no effect.
            Values above 1 will increase attention.
    """

    start: int
    length: int
    factor: float

    def to_json(self) -> Mapping[str, Any]:
        """
        Serialize the control to a mapping with the keys
        "start", "length" and "factor".
        """
        # The field names are used verbatim as the JSON keys.
        return self._asdict()


class Text(NamedTuple):
    """
    A Text-prompt including optional controls for attention manipulation.

    Parameters:
        text (str, required):
            The text prompt
        controls (list of TextControl, required):
            A list of TextControls to manipulate attention when processing the prompt.
            Can be empty if no manipulation is required.

    Examples:
        >>> Text("Hello, World!", controls=[TextControl(start=0, length=5, factor=0.5)])
    """

    text: str
    controls: Sequence["TextControl"]

    def to_json(self) -> Mapping[str, Any]:
        """
        Serialize the prompt item to JSON for sending to the API.

        The "controls" key is always present; it is an empty list when no
        controls were given.
        """
        return {
            "type": "text",
            "data": self.text,
            "controls": [control.to_json() for control in self.controls],
        }

    @staticmethod
    def from_text(text: str) -> "Text":
        """
        Build a Text item from a plain string without any controls.
        """
        return Text(text, [])


PromptItem = Union[Text, Tokens, Image]


class Prompt(NamedTuple):
Expand All @@ -72,39 +134,41 @@ class Prompt(NamedTuple):
])
"""

items: Sequence[Union[str, Image, Tokens, Sequence[int]]]
items: Sequence[PromptItem]

@staticmethod
def from_text(text: str) -> "Prompt":
return Prompt([text])
def from_text(
text: str, controls: Optional[Sequence[TextControl]] = None
) -> "Prompt":
return Prompt([Text(text, controls or [])])

@staticmethod
def from_image(image: Image) -> "Prompt":
return Prompt([image])

@staticmethod
def from_tokens(tokens: Union[Sequence[int], Tokens]) -> "Prompt":
def from_tokens(
tokens: Sequence[int], controls: Optional[Sequence[TokenControl]] = None
) -> "Prompt":
"""
Examples:
>>> prompt = Prompt.from_tokens([1, 2, 3])
"""
if isinstance(tokens, List):
tokens = Tokens(tokens)
return Prompt([tokens])
return Prompt([Tokens(tokens, controls or [])])

def to_json(self) -> Sequence[Mapping[str, Any]]:
return [_to_prompt_item(item) for item in self.items]
return [_to_json(item) for item in self.items]


def _to_prompt_item(
item: Union[str, Image, Tokens, Sequence[int]]
) -> Mapping[str, Any]:
if isinstance(item, str):
def _to_json(item: PromptItem) -> Mapping[str, Any]:
if hasattr(item, "to_json"):
return item.to_json()
# Required for backwards compatibility
# item could be a plain piece of text or a plain list of token-ids
elif isinstance(item, str):
return {"type": "text", "data": item}
elif isinstance(item, List):
return {"type": "token_ids", "data": item}
elif hasattr(item, "to_json"):
return item.to_json()
else:
raise ValueError(
"The item in the prompt is not valid. Try either a string or an Image."
Expand All @@ -125,7 +189,7 @@ def _to_serializable_prompt(
return prompt

elif isinstance(prompt, list):
return [_to_prompt_item(item) for item in prompt]
return [_to_json(item) for item in prompt]

raise ValueError(
"Invalid prompt. Prompt must either be a string, or a list of valid multimodal propmt items."
Expand Down
34 changes: 25 additions & 9 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
from aleph_alpha_client import Prompt, Tokens, TokenControl
from aleph_alpha_client.prompt import TextControl


def test_serialize_token_ids():
tokens = [1, 2, 3, 4]
prompt = Prompt.from_tokens(Tokens(tokens))
prompt = Prompt.from_tokens(tokens)
serialized_prompt = prompt.to_json()

assert serialized_prompt == [{"type": "token_ids", "data": [1, 2, 3, 4]}]
assert serialized_prompt == [
{"type": "token_ids", "data": [1, 2, 3, 4], "controls": []}
]


def test_serialize_token_ids_with_controls():
tokens = [1, 2, 3, 4]
prompt = Prompt.from_tokens(
Tokens(
tokens,
controls=[
TokenControl(pos=0, factor=0.25),
TokenControl(pos=1, factor=0.5),
],
)
tokens,
controls=[
TokenControl(pos=0, factor=0.25),
TokenControl(pos=1, factor=0.5),
],
)
serialized_prompt = prompt.to_json()

Expand All @@ -29,3 +30,18 @@ def test_serialize_token_ids_with_controls():
"controls": [{"index": 0, "factor": 0.25}, {"index": 1, "factor": 0.5}],
}
]


def test_serialize_text_with_controls():
    """A text prompt built with controls serializes them next to the text."""
    text = "An apple a day"
    controls = [TextControl(start=3, length=5, factor=1.5)]
    expected = [
        {
            "type": "text",
            "data": text,
            "controls": [{"start": 3, "length": 5, "factor": 1.5}],
        }
    ]

    assert Prompt.from_text(text, controls).to_json() == expected

0 comments on commit b17ea26

Please sign in to comment.