Skip to content

Commit

Permalink
Differentiate between TextPosition and PromptItemPosition
Browse files: browse the repository at this point in the history
  • Loading branch information
volkerstampa committed Oct 4, 2023
1 parent 5c0abcb commit 6009ecd
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 24 deletions.
54 changes: 36 additions & 18 deletions aleph_alpha_client/prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,21 @@


@dataclass(frozen=True)
class Position:
class TextPosition:
item: int
start: int
length: int


@dataclass(frozen=True)
class PromptItemPosition:
item: int


@dataclass(frozen=True)
class PromptData:
prompt: Prompt
positions: Mapping[str, Position]
positions: Mapping[str, TextPosition]


class PromptTemplate:
Expand Down Expand Up @@ -132,14 +137,17 @@ def to_prompt_data(self, **kwargs) -> PromptData:
placeholder_indices = self._compute_indices(
self.placeholders.keys(), liquid_prompt
)
positions_by_placeholder: Dict[Placeholder, List[Position]] = defaultdict(list)
positions_by_placeholder: Dict[
Placeholder, List[Union[TextPosition, PromptItemPosition]]
] = defaultdict(list)
modalities = self._modalities_from(
placeholder_indices, positions_by_placeholder, liquid_prompt
)
result = PromptData(
Prompt(list(modalities)),
{
template_variable_by_placeholder.get(placeholder): positions
# template_variable_by_placeholder.get(placeholder) cannot be None as None-s are filtered
template_variable_by_placeholder.get(placeholder): positions # type: ignore
for placeholder, positions in positions_by_placeholder.items()
if template_variable_by_placeholder.get(placeholder)
},
Expand All @@ -165,35 +173,45 @@ def _compute_indices(
def _modalities_from(
self,
placeholder_indices: Iterable[Tuple[int, int]],
positions_by_placeholder: Dict[Placeholder, List[Position]],
positions_by_placeholder: Dict[
Placeholder, List[Union[TextPosition, PromptItemPosition]]
],
template: str,
) -> Iterable[PromptItem]:
last_to = 0
accumulated_text = ""
item_cnt = 0

def new_prompt_item(item: PromptItem) -> PromptItem:
nonlocal item_cnt, accumulated_text
item_cnt += 1
accumulated_text = ""
return item

def current_text_position(value: str) -> Iterable[TextPosition]:
nonlocal item_cnt, accumulated_text
yield TextPosition(
item=item_cnt,
start=len(accumulated_text),
length=len(value),
)
accumulated_text += value

for placeholder_from, placeholder_to in placeholder_indices:
accumulated_text += template[last_to:placeholder_from]
placeholder = Placeholder(UUID(template[placeholder_from:placeholder_to]))
placeholder_value = self.placeholders[placeholder]
if isinstance(placeholder_value, (Tokens, Image)):
if accumulated_text:
yield Text.from_text(accumulated_text)
item_cnt += 1
accumulated_text = ""
yield new_prompt_item(Text.from_text(accumulated_text))
positions_by_placeholder[placeholder].append(
Position(item=item_cnt, start=0, length=0)
PromptItemPosition(item=item_cnt)
)
yield placeholder_value
item_cnt += 1
yield new_prompt_item(placeholder_value)
else:
positions_by_placeholder[placeholder].append(
Position(
item=item_cnt,
start=len(accumulated_text),
length=len(placeholder_value),
)
positions_by_placeholder[placeholder].extend(
current_text_position(placeholder_value)
)
accumulated_text += placeholder_value
last_to = placeholder_to
if last_to < len(template):
yield Text.from_text(accumulated_text + template[last_to:])
13 changes: 7 additions & 6 deletions tests/test_prompt_template.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pathlib import Path
from typing import List
from pytest import raises
from aleph_alpha_client.prompt import Prompt, Image, Text, Tokens
from aleph_alpha_client.prompt_template import Position, PromptTemplate
from aleph_alpha_client.prompt import Prompt, Image, PromptItem, Text, Tokens
from aleph_alpha_client.prompt_template import TextPosition, PromptTemplate
from liquid.exceptions import LiquidTypeError
from .common import prompt_image

Expand Down Expand Up @@ -177,13 +178,13 @@ def test_to_prompt_resets_template(prompt_image: Image):

def test_to_prompt_returns_position_of_embedded_texts(prompt_image: Image):
embedded = "Embedded"
prefix_items = [
prefix_items: List[PromptItem] = [
Text.from_text("Prefix Text Item"),
prompt_image,
]
prefix_merged = Text.from_text("Merged Prefix Item")
postfix_merged = Text.from_text("Merged Postfix Item")
postfix_items = [prompt_image]
postfix_items: List[PromptItem] = [prompt_image]
template = PromptTemplate(
"{{prefix_items}}{{embedded_text}} more text {{postfix_items}}{{embedded_text}}"
)
Expand All @@ -197,12 +198,12 @@ def test_to_prompt_returns_position_of_embedded_texts(prompt_image: Image):
positions = prompt_data.positions.get("embedded_text")

assert positions == [
Position(
TextPosition(
item=len(prefix_items),
start=len(prefix_merged.text),
length=len(embedded),
),
Position(
TextPosition(
item=len(prefix_items) + len(postfix_items) + 1,
start=0,
length=len(embedded),
Expand Down

0 comments on commit 6009ecd

Please sign in to comment.