diff --git a/aleph_alpha_client/__init__.py b/aleph_alpha_client/__init__.py
index dcedc92..2d65b28 100644
--- a/aleph_alpha_client/__init__.py
+++ b/aleph_alpha_client/__init__.py
@@ -1,55 +1,65 @@
+from .prompt import (
+    ControlTokenOverlap,
+    Image,
+    ImageControl,
+    ImagePrompt,
+    Prompt,
+    Text,
+    TextControl,
+    TokenControl,
+    Tokens,
+)
 from .aleph_alpha_client import (
-    AlephAlphaClient,
-    QuotaError,
     POOLING_OPTIONS,
+    AlephAlphaClient,
     AsyncClient,
     Client,
+    QuotaError,
 )
 from .aleph_alpha_model import AlephAlphaModel
-from .image import Image, ImagePrompt, ImageControl
-from .prompt import Prompt, Tokens, TokenControl, Text, TextControl
+from .completion import CompletionRequest, CompletionResponse
+from .detokenization import DetokenizationRequest, DetokenizationResponse
+from .document import Document
+from .embedding import (
+    EmbeddingRequest,
+    EmbeddingResponse,
+    SemanticEmbeddingRequest,
+    SemanticEmbeddingResponse,
+    SemanticRepresentation,
+)
+from .evaluation import EvaluationRequest, EvaluationResponse
 from .explanation import (
-    ExplanationRequest,
+    CustomGranularity,
+    Explanation,
     ExplanationPostprocessing,
+    ExplanationRequest,
     ExplanationResponse,
-    TargetGranularity,
-    CustomGranularity,
-    TextScore,
+    ImagePromptItemExplanation,
     ImageScore,
+    TargetGranularity,
+    TargetPromptItemExplanation,
     TargetScore,
-    TokenScore,
-    ImagePromptItemExplanation,
     TextPromptItemExplanation,
-    TargetPromptItemExplanation,
+    TextScore,
     TokenPromptItemExplanation,
-    Explanation,
-)
-from .embedding import (
-    EmbeddingRequest,
-    EmbeddingResponse,
-    SemanticEmbeddingRequest,
-    SemanticEmbeddingResponse,
-    SemanticRepresentation,
+    TokenScore,
 )
-from .completion import CompletionRequest, CompletionResponse
 from .qa import QaRequest, QaResponse
-from .evaluation import EvaluationRequest, EvaluationResponse
-from .tokenization import TokenizationRequest, TokenizationResponse
-from .detokenization import DetokenizationRequest, DetokenizationResponse
-from .summarization import SummarizationRequest, SummarizationResponse
 from .search import SearchRequest, SearchResponse, SearchResult
-from .utils import load_base64_from_url, load_base64_from_file
-from .document import Document
+from .summarization import SummarizationRequest, SummarizationResponse
+from .tokenization import TokenizationRequest, TokenizationResponse
+from .utils import load_base64_from_file, load_base64_from_url
 from .version import __version__
 
 __all__ = [
-    "POOLING_OPTIONS",
     "AlephAlphaClient",
     "AlephAlphaModel",
     "AsyncClient",
     "Client",
     "CompletionRequest",
     "CompletionResponse",
+    "ControlTokenOverlap",
+    "CustomGranularity",
     "DetokenizationRequest",
     "DetokenizationResponse",
     "Document",
@@ -57,15 +67,6 @@
     "EmbeddingResponse",
     "EvaluationRequest",
     "EvaluationResponse",
-    "CustomGranularity",
-    "TextScore",
-    "ImageScore",
-    "TargetScore",
-    "TokenScore",
-    "ImagePromptItemExplanation",
-    "TextPromptItemExplanation",
-    "TargetPromptItemExplanation",
-    "TokenPromptItemExplanation",
     "Explanation",
     "ExplanationPostprocessing",
     "ExplanationRequest",
@@ -73,6 +74,9 @@
     "Image",
     "ImageControl",
     "ImagePrompt",
+    "ImagePromptItemExplanation",
+    "ImageScore",
+    "POOLING_OPTIONS",
     "Prompt",
     "QaRequest",
     "QaResponse",
@@ -86,10 +90,16 @@
     "SummarizationRequest",
     "SummarizationResponse",
     "TargetGranularity",
+    "TargetPromptItemExplanation",
+    "TargetScore",
     "Text",
     "TextControl",
+    "TextPromptItemExplanation",
+    "TextScore",
+    "TokenControl",
     "TokenizationRequest",
     "TokenizationResponse",
-    "TokenControl",
+    "TokenPromptItemExplanation",
     "Tokens",
+    "TokenScore",
 ]
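The `__init__.py` hunks above are mostly an alphabetical re-sort of the re-exports, but they also surface the new `ControlTokenOverlap` enum and pick up `Image`, `ImageControl`, and `ImagePrompt` from their new home in `prompt.py`. A minimal sketch of what this lets consumers write (hypothetical usage, not part of the diff):

```python
# Hypothetical consumer snippet: every public name now resolves from the root package.
from aleph_alpha_client import ControlTokenOverlap, Prompt, Text, TextControl

prompt = Prompt(
    [
        Text(
            "Hello, World!",
            controls=[
                TextControl(
                    start=0,
                    length=5,
                    factor=1.5,
                    token_overlap=ControlTokenOverlap.Complete,
                )
            ],
        )
    ]
)
```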
"TokenPromptItemExplanation", "Tokens", + "TokenScore", ] diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index 51a3b15..133af87 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -20,7 +20,7 @@ ExplanationRequest, ExplanationResponse, ) -from aleph_alpha_client.image import Image +from aleph_alpha_client import Image from aleph_alpha_client.prompt import _to_json, _to_serializable_prompt from aleph_alpha_client.summarization import SummarizationRequest, SummarizationResponse from aleph_alpha_client.qa import QaRequest, QaResponse diff --git a/aleph_alpha_client/document.py b/aleph_alpha_client/document.py index fdc090c..7544834 100644 --- a/aleph_alpha_client/document.py +++ b/aleph_alpha_client/document.py @@ -1,8 +1,7 @@ import base64 -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, Optional, Sequence, Union -from aleph_alpha_client.image import Image -from aleph_alpha_client.prompt import PromptItem, Text, _to_json +from aleph_alpha_client.prompt import Image, PromptItem, Text, _to_json class Document: diff --git a/aleph_alpha_client/explanation.py b/aleph_alpha_client/explanation.py index 7e3537d..7532001 100644 --- a/aleph_alpha_client/explanation.py +++ b/aleph_alpha_client/explanation.py @@ -11,9 +11,8 @@ # Import Literal with Python 3.7 fallback from typing_extensions import Literal -from aleph_alpha_client.image import Image -from aleph_alpha_client.prompt import Prompt, PromptItem +from aleph_alpha_client.prompt import ControlTokenOverlap, Image, Prompt, PromptItem class ExplanationPostprocessing(Enum): @@ -34,6 +33,17 @@ def to_json(self) -> str: class CustomGranularity(NamedTuple): + """ + Allows for passing a custom delimiter to determine the granularity to + to explain the prompt by. The text of the prompt will be split by the + delimiter you provide. + + Parameters: + delimiter (str, required): + String to split the text in the prompt by for generating + explanations for your prompt. + """ + delimiter: str def to_json(self) -> Mapping[str, Any]: @@ -76,10 +86,78 @@ def to_json(self) -> str: class ExplanationRequest(NamedTuple): + """ + Describes an Explanation request you want to make agains the API. + + Parameters: + prompt (Prompt, required) + Prompt you want to generate explanations for a target completion. + target (str, required) + The completion string to be explained based on model probabilities. + contextual_control_threshold (float, default None) + If set to None, attention control parameters only apply to those tokens that have + explicitly been set in the request. + If set to a non-None value, we apply the control parameters to similar tokens as well. + Controls that have been applied to one token will then be applied to all other tokens + that have at least the similarity score defined by this parameter. + The similarity score is the cosine similarity of token embeddings. + control_factor (float, default None): + The amount to adjust model attention by. + For Explanation, you want to supress attention, and the API will default to 0.1. + Values between 0 and 1 will supress attention. + A value of 1 will have no effect. + Values above 1 will increase attention. + control_token_overlap (ControlTokenOverlap, default None) + What to do if a control partially overlaps with a text or image token. + If set to "partial", the factor will be adjusted proportionally with the amount + of the token it overlaps. 
@@ -76,10 +86,78 @@ def to_json(self) -> str:
 
 
 class ExplanationRequest(NamedTuple):
+    """
+    Describes an Explanation request you want to make against the API.
+
+    Parameters:
+        prompt (Prompt, required):
+            Prompt for which you want to generate explanations of the target completion.
+        target (str, required):
+            The completion string to be explained based on model probabilities.
+        contextual_control_threshold (float, default None):
+            If set to None, attention control parameters only apply to those tokens that have
+            explicitly been set in the request.
+            If set to a non-None value, we apply the control parameters to similar tokens as well.
+            Controls that have been applied to one token will then be applied to all other tokens
+            that have at least the similarity score defined by this parameter.
+            The similarity score is the cosine similarity of token embeddings.
+        control_factor (float, default None):
+            The amount to adjust model attention by.
+            For Explanation, you want to suppress attention, and the API will default to 0.1.
+            Values between 0 and 1 will suppress attention.
+            A value of 1 will have no effect.
+            Values above 1 will increase attention.
+        control_token_overlap (ControlTokenOverlap, default None):
+            What to do if a control partially overlaps with a text or image token.
+            If set to "partial", the factor will be adjusted proportionally with the amount
+            of the token it overlaps. So a factor of 2.0 of a control that only covers 2 of
+            4 token characters would be adjusted to 1.5.
+            If set to "complete", the full factor will be applied as long as the control
+            overlaps with the token at all.
+        control_log_additive (bool, default None):
+            True: apply control by adding the log(control_factor) to attention scores.
+            False: apply control by (attention_scores - attention_scores.min(-1)) * control_factor.
+            If None, the API will default to True.
+        prompt_granularity (PromptGranularity, default None):
+            At which granularity should the target be explained in terms of the prompt.
+            If you choose, for example, "sentence" then we report the importance score of each
+            sentence in the prompt towards generating the target output.
+
+            If you do not choose a granularity then we will try to find the granularity that
+            brings you closest to around 30 explanations. For large documents, this would likely
+            be sentences. For short prompts this might be individual words or even tokens.
+
+            If you choose a custom granularity then you must provide a custom delimiter. We then
+            split your prompt by that delimiter. This might be helpful if you are using few-shot
+            prompts that contain stop sequences.
+
+            For image prompt items, the granularities determine into how many tiles we divide
+            the image for the explanation.
+            "token" -> 12x12
+            "word" -> 6x6
+            "sentence" -> 3x3
+            "paragraph" -> 1
+        target_granularity (TargetGranularity, default None):
+            How many explanations should be returned in the output.
+
+            "complete" -> Return one explanation for the entire target. Helpful in many cases to determine which parts of the prompt contribute overall to the given completion.
+            "token" -> Return one explanation for each token in the target.
+
+            If None, the API will default to "complete".
+        postprocessing (ExplanationPostprocessing, default None):
+            Optionally apply postprocessing to the difference in cross entropy scores for each token.
+            "none": Apply no postprocessing.
+            "absolute": Return the absolute value of each value.
+            "square": Square each value.
+        normalize (bool, default None):
+            Return normalized scores. Minimum score becomes 0 and maximum score becomes 1.
+            Applied after any postprocessing.
+    """
+
     prompt: Prompt
     target: str
     contextual_control_threshold: Optional[float] = None
     control_factor: Optional[float] = None
+    control_token_overlap: Optional[ControlTokenOverlap] = None
     control_log_additive: Optional[bool] = None
     prompt_granularity: Optional[PromptGranularity] = None
     target_granularity: Optional[TargetGranularity] = None
@@ -93,8 +171,12 @@ def to_json(self) -> Dict[str, Any]:
         }
         if self.contextual_control_threshold is not None:
             payload["contextual_control_threshold"] = self.contextual_control_threshold
         if self.control_factor is not None:
             payload["control_factor"] = self.control_factor
+        if self.control_token_overlap is not None:
+            payload["control_token_overlap"] = self.control_token_overlap.to_json()
+        if self.postprocessing is not None:
+            payload["postprocessing"] = self.postprocessing.to_json()
         if self.control_log_additive is not None:
             payload["control_log_additive"] = self.control_log_additive
         if self.prompt_granularity is not None:
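Putting the explanation changes together, a request that exercises the new `control_token_overlap` field might look like the following sketch. The client setup, token, and model name are placeholders, and `_explain` is the same private helper the test suite below uses:

```python
from aleph_alpha_client import (
    Client,
    ControlTokenOverlap,
    ExplanationPostprocessing,
    ExplanationRequest,
    Prompt,
    TargetGranularity,
)

request = ExplanationRequest(
    prompt=Prompt.from_text("An apple a day keeps the doctor"),
    target=" away",
    control_factor=0.1,
    control_token_overlap=ControlTokenOverlap.Complete,
    postprocessing=ExplanationPostprocessing.Absolute,
    normalize=True,
    target_granularity=TargetGranularity.Token,
)

# client = Client(token="AA_TOKEN")  # token and model name are placeholders
# explanation = client._explain(request, model="luminous-base")
```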
diff --git a/aleph_alpha_client/image.py b/aleph_alpha_client/image.py
deleted file mode 100644
index 5222784..0000000
--- a/aleph_alpha_client/image.py
+++ /dev/null
@@ -1,239 +0,0 @@
-import base64
-import io
-from PIL.Image import Image as PILImage
-import PIL
-from pathlib import Path
-from typing import Any, Dict, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
-from urllib.parse import urlparse
-
-import requests
-
-
-class Cropping:
-    """
-    Describes a quadratic crop of the file.
-    """
-
-    def __init__(self, upper_left_x: int, upper_left_y: int, size: int):
-        self.upper_left_x = upper_left_x
-        self.upper_left_y = upper_left_y
-        self.size = size
-
-
-class ImageControl(NamedTuple):
-    """
-    Attention manipulation for an Image PromptItem.
-
-    All coordinates of the bounding box are logical coordinates (between 0 and 1) and relative to
-    the entire image.
-
-    Keep in mind, non-square images are center-cropped by default before going to the model. (You
-    can specify a custom cropping if you want.). Since control coordinates are relative to the
-    entire image, all or a portion of your control may be outside the "model visible area".
-
-    Parameters:
-        left (float, required):
-            x-coordinate of top left corner of the control bounding box.
-            Must be a value between 0 and 1, where 0 is the left corner and 1 is the right corner.
-        top (float, required):
-            y-coordinate of top left corner of the control bounding box
-            Must be a value between 0 and 1, where 0 is the top pixel row and 1 is the bottom row.
-        width (float, required):
-            width of the control bounding box
-            Must be a value between 0 and 1, where 1 means the full width of the image.
-        height (float, required):
-            height of the control bounding box
-            Must be a value between 0 and 1, where 1 means the full height of the image.
-        factor (float, required):
-            The amount to adjust model attention by.
-            Values between 0 and 1 will supress attention.
-            A value of 1 will have no effect.
-            Values above 1 will increase attention.
-    """
-
-    left: float
-    top: float
-    width: float
-    height: float
-    factor: float
-
-    def to_json(self) -> Mapping[str, Any]:
-        return {
-            "rect": {
-                "left": self.left,
-                "top": self.top,
-                "width": self.width,
-                "height": self.height,
-            },
-            "factor": self.factor,
-        }
-
-
-class Image:
-    """
-    An image send as part of a prompt to a model. The image is represented as
-    base64.
-
-    Note: The models operate on square images. All non-square images are center-cropped
-    before going to the model, so portions of the image may not be visible.
-
-    You can supply specific cropping parameters if you like, to choose a different area
-    of the image than a center-crop. Or, you can always transform the image yourself to
-    a square before sending it.
-
-    Examples:
-        >>> # You need to choose a model with multimodal capabilities for this example.
-        >>> url = "https://cdn-images-1.medium.com/max/1200/1*HunNdlTmoPj8EKpl-jqvBA.png"
-        >>> image = Image.from_url(url)
-    """
-
-    def __init__(
-        self,
-        base_64: str,
-        cropping: Optional[Cropping],
-        controls: Sequence[ImageControl],
-    ):
-        # We use a base_64 reperesentation, because we want to embed the image
-        # into a prompt send in JSON.
-        self.base_64 = base_64
-        self.cropping = cropping
-        self.controls: Sequence[ImageControl] = controls
-
-    @classmethod
-    def from_image_source(
-        cls,
-        image_source: Union[str, Path, bytes],
-        controls: Optional[Sequence[ImageControl]] = None,
-    ):
-        """
-        Abstraction on top of the existing methods of image initialization.
-        If you are not sure what the exact type of your image, but you know it is either a Path object, URL, a file path,
-        or a bytes array, just use the method and we will figure out which of the methods of image initialization to use
-        """
-        if isinstance(image_source, Path):
-            return cls.from_file(path=str(image_source), controls=controls)
-
-        elif isinstance(image_source, str):
-            try:
-                p = urlparse(image_source)
-                if p.scheme:
-                    return cls.from_url(url=image_source, controls=controls)
-            except Exception as e:
-                # we assume that If the string runs into a Exception it isn't not a valid ulr
-                pass
-
-            return cls.from_file(path=image_source, controls=controls)
-
-        elif isinstance(image_source, bytes):
-            return cls.from_bytes(bytes=image_source, controls=controls)
-
-        else:
-            raise TypeError(
-                f"The image source: {image_source} should be either Path, str or bytes"
-            )
-
-    @classmethod
-    def from_bytes(
-        cls,
-        bytes: bytes,
-        cropping: Optional[Cropping] = None,
-        controls: Optional[Sequence[ImageControl]] = None,
-    ):
-        image = base64.b64encode(bytes).decode()
-        return cls(image, cropping, controls or [])
-
-    @classmethod
-    def from_url(cls, url: str, controls: Optional[Sequence[ImageControl]] = None):
-        """
-        Downloads a file and prepare it to be used in a prompt.
-        The image will be [center cropped](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.CenterCrop)
-        """
-        return cls.from_bytes(
-            cls._get_url(url), cropping=None, controls=controls or None
-        )
-
-    @classmethod
-    def from_url_with_cropping(
-        cls,
-        url: str,
-        upper_left_x: int,
-        upper_left_y: int,
-        crop_size: int,
-        controls: Optional[Sequence[ImageControl]] = None,
-    ):
-        """
-        Downloads a file and prepare it to be used in a prompt.
-        upper_left_x, upper_left_y and crop_size are used to crop the image.
-        """
-        cropping = Cropping(
-            upper_left_x=upper_left_x, upper_left_y=upper_left_y, size=crop_size
-        )
-        bytes = cls._get_url(url)
-        return cls.from_bytes(bytes, cropping=cropping, controls=controls or [])
-
-    @classmethod
-    def from_file(cls, path: str, controls: Optional[Sequence[ImageControl]] = None):
-        """
-        Load an image from disk and prepare it to be used in a prompt
-        If they are not provided then the image will be [center cropped](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.CenterCrop)
-        """
-        with open(path, "rb") as f:
-            image = f.read()
-        return cls.from_bytes(image, None, controls or [])
-
-    @classmethod
-    def from_file_with_cropping(
-        cls,
-        path: str,
-        upper_left_x: int,
-        upper_left_y: int,
-        crop_size: int,
-        controls: Optional[Sequence[ImageControl]] = None,
-    ):
-        """
-        Load an image from disk and prepare it to be used in a prompt
-        upper_left_x, upper_left_y and crop_size are used to crop the image.
-        """
-        cropping = Cropping(
-            upper_left_x=upper_left_x, upper_left_y=upper_left_y, size=crop_size
-        )
-        with open(path, "rb") as f:
-            bytes = f.read()
-        return cls.from_bytes(bytes, cropping=cropping, controls=controls or None)
-
-    @classmethod
-    def _get_url(cls, url: str) -> bytes:
-        response = requests.get(url)
-        response.raise_for_status()
-        return response.content
-
-    def to_json(self) -> Dict[str, Any]:
-        """
-        A dict if serialized to JSON is suitable as a prompt element
-        """
-        if self.cropping is None:
-            return {
-                "type": "image",
-                "data": self.base_64,
-                "controls": [control.to_json() for control in self.controls],
-            }
-        else:
-            return {
-                "type": "image",
-                "data": self.base_64,
-                "x": self.cropping.upper_left_x,
-                "y": self.cropping.upper_left_y,
-                "size": self.cropping.size,
-                "controls": [control.to_json() for control in self.controls],
-            }
-
-    def to_image(self) -> PILImage:
-        return PIL.Image.open(io.BytesIO(base64.b64decode(self.base_64)))
-
-    def dimensions(self) -> Tuple[int, int]:
-        image = self.to_image()
-        return (image.width, image.height)
-
-
-# For backwards compatibility we still expose this with the old name
-ImagePrompt = Image
- """ - cropping = Cropping( - upper_left_x=upper_left_x, upper_left_y=upper_left_y, size=crop_size - ) - bytes = cls._get_url(url) - return cls.from_bytes(bytes, cropping=cropping, controls=controls or []) - - @classmethod - def from_file(cls, path: str, controls: Optional[Sequence[ImageControl]] = None): - """ - Load an image from disk and prepare it to be used in a prompt - If they are not provided then the image will be [center cropped](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.CenterCrop) - """ - with open(path, "rb") as f: - image = f.read() - return cls.from_bytes(image, None, controls or []) - - @classmethod - def from_file_with_cropping( - cls, - path: str, - upper_left_x: int, - upper_left_y: int, - crop_size: int, - controls: Optional[Sequence[ImageControl]] = None, - ): - """ - Load an image from disk and prepare it to be used in a prompt - upper_left_x, upper_left_y and crop_size are used to crop the image. - """ - cropping = Cropping( - upper_left_x=upper_left_x, upper_left_y=upper_left_y, size=crop_size - ) - with open(path, "rb") as f: - bytes = f.read() - return cls.from_bytes(bytes, cropping=cropping, controls=controls or None) - - @classmethod - def _get_url(cls, url: str) -> bytes: - response = requests.get(url) - response.raise_for_status() - return response.content - - def to_json(self) -> Dict[str, Any]: - """ - A dict if serialized to JSON is suitable as a prompt element - """ - if self.cropping is None: - return { - "type": "image", - "data": self.base_64, - "controls": [control.to_json() for control in self.controls], - } - else: - return { - "type": "image", - "data": self.base_64, - "x": self.cropping.upper_left_x, - "y": self.cropping.upper_left_y, - "size": self.cropping.size, - "controls": [control.to_json() for control in self.controls], - } - - def to_image(self) -> PILImage: - return PIL.Image.open(io.BytesIO(base64.b64decode(self.base_64))) - - def dimensions(self) -> Tuple[int, int]: - image = self.to_image() - return (image.width, image.height) - - -# For backwards compatibility we still expose this with the old name -ImagePrompt = Image diff --git a/aleph_alpha_client/prompt.py b/aleph_alpha_client/prompt.py index 57e0459..6120ad7 100644 --- a/aleph_alpha_client/prompt.py +++ b/aleph_alpha_client/prompt.py @@ -1,6 +1,44 @@ -from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Union +import base64 +import io +from enum import Enum +from pathlib import Path +from typing import ( + Any, + Dict, + List, + Mapping, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) +from urllib.parse import urlparse + +import PIL +import requests +from PIL.Image import Image as PILImage + + +class ControlTokenOverlap(Enum): + """ + What to do if a control partially overlaps with a text or image token. + + Partial: + The factor will be adjusted proportionally with the amount of the token + it overlaps. So a factor of 2.0 of a control that only covers 2 of 4 + token characters, would be adjusted to 1.5. + + Complete: + The full factor will be applied as long as the control overlaps with + the token at all. How many explanations should be returned in the output. + """ -from aleph_alpha_client.image import Image + Partial = "partial" + Complete = "complete" + + def to_json(self) -> str: + return self.value class TokenControl(NamedTuple): @@ -79,14 +117,34 @@ class TextControl(NamedTuple): Values between 0 and 1 will supress attention. A value of 1 will have no effect. Values above 1 will increase attention. 
@@ -79,14 +117,34 @@ class TextControl(NamedTuple):
         Values between 0 and 1 will supress attention.
         A value of 1 will have no effect.
         Values above 1 will increase attention.
+        token_overlap (ControlTokenOverlap, optional):
+            What to do if a control partially overlaps with a text token.
+
+            If set to "partial", the factor will be adjusted proportionally
+            with the amount of the token it overlaps. So a factor of 2.0 of a
+            control that only covers 2 of 4 token characters would be adjusted
+            to 1.5.
+
+            If set to "complete", the full factor will be applied as long as
+            the control overlaps with the token at all.
+
+            If not set, the API will default to "partial".
     """
 
     start: int
     length: int
     factor: float
+    token_overlap: Optional[ControlTokenOverlap] = None
 
-    def to_json(self) -> Mapping[str, Any]:
-        return self._asdict()
+    def to_json(self) -> Dict[str, Any]:
+        payload: Dict[str, Any] = {
+            "start": self.start,
+            "length": self.length,
+            "factor": self.factor,
+        }
+        if self.token_overlap is not None:
+            payload["token_overlap"] = self.token_overlap.to_json()
+        return payload
 
 
 class Text(NamedTuple):
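Note that `TextControl.to_json` only emits `token_overlap` when it is set, so payloads from existing call sites are unchanged. A quick illustration, mirroring the assertions in `tests/test_prompt.py` further down:

```python
from aleph_alpha_client.prompt import ControlTokenOverlap, TextControl

assert TextControl(start=3, length=5, factor=1.5).to_json() == {
    "start": 3,
    "length": 5,
    "factor": 1.5,
}
assert TextControl(
    start=3, length=5, factor=1.5, token_overlap=ControlTokenOverlap.Complete
).to_json() == {"start": 3, "length": 5, "factor": 1.5, "token_overlap": "complete"}
```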
+ """ + + left: float + top: float + width: float + height: float + factor: float + token_overlap: Optional[ControlTokenOverlap] = None + + def to_json(self) -> Mapping[str, Any]: + payload = { + "rect": { + "left": self.left, + "top": self.top, + "width": self.width, + "height": self.height, + }, + "factor": self.factor, + } + if self.token_overlap is not None: + payload["token_overlap"] = self.token_overlap.to_json() + return payload + + +class Image: + """ + An image send as part of a prompt to a model. The image is represented as + base64. + + Note: The models operate on square images. All non-square images are center-cropped + before going to the model, so portions of the image may not be visible. + + You can supply specific cropping parameters if you like, to choose a different area + of the image than a center-crop. Or, you can always transform the image yourself to + a square before sending it. + + Examples: + >>> # You need to choose a model with multimodal capabilities for this example. + >>> url = "https://cdn-images-1.medium.com/max/1200/1*HunNdlTmoPj8EKpl-jqvBA.png" + >>> image = Image.from_url(url) + """ + + def __init__( + self, + base_64: str, + cropping: Optional[Cropping], + controls: Sequence[ImageControl], + ): + # We use a base_64 reperesentation, because we want to embed the image + # into a prompt send in JSON. + self.base_64 = base_64 + self.cropping = cropping + self.controls: Sequence[ImageControl] = controls + + @classmethod + def from_image_source( + cls, + image_source: Union[str, Path, bytes], + controls: Optional[Sequence[ImageControl]] = None, + ): + """ + Abstraction on top of the existing methods of image initialization. + If you are not sure what the exact type of your image, but you know it is either a Path object, URL, a file path, + or a bytes array, just use the method and we will figure out which of the methods of image initialization to use + """ + if isinstance(image_source, Path): + return cls.from_file(path=str(image_source), controls=controls) + + elif isinstance(image_source, str): + try: + p = urlparse(image_source) + if p.scheme: + return cls.from_url(url=image_source, controls=controls) + except Exception as e: + # we assume that If the string runs into a Exception it isn't not a valid ulr + pass + + return cls.from_file(path=image_source, controls=controls) + + elif isinstance(image_source, bytes): + return cls.from_bytes(bytes=image_source, controls=controls) + + else: + raise TypeError( + f"The image source: {image_source} should be either Path, str or bytes" + ) + + @classmethod + def from_bytes( + cls, + bytes: bytes, + cropping: Optional[Cropping] = None, + controls: Optional[Sequence[ImageControl]] = None, + ): + image = base64.b64encode(bytes).decode() + return cls(image, cropping, controls or []) + + @classmethod + def from_url(cls, url: str, controls: Optional[Sequence[ImageControl]] = None): + """ + Downloads a file and prepare it to be used in a prompt. + The image will be [center cropped](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.CenterCrop) + """ + return cls.from_bytes( + cls._get_url(url), cropping=None, controls=controls or None + ) + + @classmethod + def from_url_with_cropping( + cls, + url: str, + upper_left_x: int, + upper_left_y: int, + crop_size: int, + controls: Optional[Sequence[ImageControl]] = None, + ): + """ + Downloads a file and prepare it to be used in a prompt. + upper_left_x, upper_left_y and crop_size are used to crop the image. 
+ """ + cropping = Cropping( + upper_left_x=upper_left_x, upper_left_y=upper_left_y, size=crop_size + ) + bytes = cls._get_url(url) + return cls.from_bytes(bytes, cropping=cropping, controls=controls or []) + + @classmethod + def from_file(cls, path: str, controls: Optional[Sequence[ImageControl]] = None): + """ + Load an image from disk and prepare it to be used in a prompt + If they are not provided then the image will be [center cropped](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.CenterCrop) + """ + with open(path, "rb") as f: + image = f.read() + return cls.from_bytes(image, None, controls or []) + + @classmethod + def from_file_with_cropping( + cls, + path: str, + upper_left_x: int, + upper_left_y: int, + crop_size: int, + controls: Optional[Sequence[ImageControl]] = None, + ): + """ + Load an image from disk and prepare it to be used in a prompt + upper_left_x, upper_left_y and crop_size are used to crop the image. + """ + cropping = Cropping( + upper_left_x=upper_left_x, upper_left_y=upper_left_y, size=crop_size + ) + with open(path, "rb") as f: + bytes = f.read() + return cls.from_bytes(bytes, cropping=cropping, controls=controls or None) + + @classmethod + def _get_url(cls, url: str) -> bytes: + response = requests.get(url) + response.raise_for_status() + return response.content + + def to_json(self) -> Dict[str, Any]: + """ + A dict if serialized to JSON is suitable as a prompt element + """ + if self.cropping is None: + return { + "type": "image", + "data": self.base_64, + "controls": [control.to_json() for control in self.controls], + } + else: + return { + "type": "image", + "data": self.base_64, + "x": self.cropping.upper_left_x, + "y": self.cropping.upper_left_y, + "size": self.cropping.size, + "controls": [control.to_json() for control in self.controls], + } + + def to_image(self) -> PILImage: + return PIL.Image.open(io.BytesIO(base64.b64decode(self.base_64))) + + def dimensions(self) -> Tuple[int, int]: + image = self.to_image() + return (image.width, image.height) + + +# For backwards compatibility we still expose this with the old name +ImagePrompt = Image + + PromptItem = Union[Text, Tokens, Image, str, Sequence[int]] diff --git a/tests/test_complete.py b/tests/test_complete.py index cafc89d..de358b0 100644 --- a/tests/test_complete.py +++ b/tests/test_complete.py @@ -2,7 +2,7 @@ from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient, AsyncClient, Client from aleph_alpha_client.aleph_alpha_model import AlephAlphaModel from aleph_alpha_client.completion import CompletionRequest -from aleph_alpha_client.prompt import Prompt, Text, TextControl +from aleph_alpha_client.prompt import ControlTokenOverlap, Prompt, Text, TextControl from tests.common import ( client, @@ -40,7 +40,15 @@ def test_complete(sync_client: Client, model_name: str): [ Text( "Hello, World!", - controls=[TextControl(start=1, length=5, factor=0.5)], + controls=[ + TextControl(start=1, length=5, factor=0.5), + TextControl( + start=1, + length=5, + factor=0.5, + token_overlap=ControlTokenOverlap.Complete, + ), + ], ) ] ), diff --git a/tests/test_explanation.py b/tests/test_explanation.py index 2a50d49..64733e7 100644 --- a/tests/test_explanation.py +++ b/tests/test_explanation.py @@ -1,6 +1,7 @@ from pathlib import Path import pytest from aleph_alpha_client import ( + ControlTokenOverlap, ExplanationRequest, AsyncClient, Client, @@ -72,6 +73,7 @@ def test_explanation(sync_client: Client, model_name: str): postprocessing=ExplanationPostprocessing.Absolute, 
diff --git a/tests/test_complete.py b/tests/test_complete.py
index cafc89d..de358b0 100644
--- a/tests/test_complete.py
+++ b/tests/test_complete.py
@@ -2,7 +2,7 @@
 from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient, AsyncClient, Client
 from aleph_alpha_client.aleph_alpha_model import AlephAlphaModel
 from aleph_alpha_client.completion import CompletionRequest
-from aleph_alpha_client.prompt import Prompt, Text, TextControl
+from aleph_alpha_client.prompt import ControlTokenOverlap, Prompt, Text, TextControl
 
 from tests.common import (
     client,
@@ -40,7 +40,15 @@ def test_complete(sync_client: Client, model_name: str):
         [
             Text(
                 "Hello, World!",
-                controls=[TextControl(start=1, length=5, factor=0.5)],
+                controls=[
+                    TextControl(start=1, length=5, factor=0.5),
+                    TextControl(
+                        start=1,
+                        length=5,
+                        factor=0.5,
+                        token_overlap=ControlTokenOverlap.Complete,
+                    ),
+                ],
             )
         ]
     ),
diff --git a/tests/test_explanation.py b/tests/test_explanation.py
index 2a50d49..64733e7 100644
--- a/tests/test_explanation.py
+++ b/tests/test_explanation.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 import pytest
 from aleph_alpha_client import (
+    ControlTokenOverlap,
     ExplanationRequest,
     AsyncClient,
     Client,
@@ -72,6 +73,7 @@ def test_explanation(sync_client: Client, model_name: str):
         postprocessing=ExplanationPostprocessing.Absolute,
         normalize=True,
         target_granularity=TargetGranularity.Token,
+        control_token_overlap=ControlTokenOverlap.Complete,
     )
 
     explanation = sync_client._explain(request, model=model_name)
diff --git a/tests/test_image.py b/tests/test_image.py
index 990a999..0b35136 100644
--- a/tests/test_image.py
+++ b/tests/test_image.py
@@ -5,7 +5,7 @@
 from pytest_httpserver import HTTPServer
 from requests import RequestException
 
-from aleph_alpha_client.image import Image
+from aleph_alpha_client import Image
 
 
 def test_from_url_with_non_OK_response(httpserver: HTTPServer):
diff --git a/tests/test_prompt.py b/tests/test_prompt.py
index ed03f90..d7f8ff5 100644
--- a/tests/test_prompt.py
+++ b/tests/test_prompt.py
@@ -1,7 +1,14 @@
-from aleph_alpha_client import Prompt, Tokens, TokenControl, Image, ImageControl
+from aleph_alpha_client import (
+    ControlTokenOverlap,
+    Prompt,
+    Tokens,
+    TokenControl,
+    Image,
+    ImageControl,
+)
 from aleph_alpha_client.aleph_alpha_client import Client
 from aleph_alpha_client.completion import CompletionRequest
-from aleph_alpha_client.image import ImagePrompt
+from aleph_alpha_client import ImagePrompt
 from aleph_alpha_client.prompt import Text, TextControl
 
 from tests.common import sync_client, model_name
@@ -38,7 +45,18 @@ def test_serialize_token_ids_with_controls():
 
 def test_serialize_text_with_controls():
     prompt_text = "An apple a day"
-    prompt = Prompt.from_text(prompt_text, [TextControl(start=3, length=5, factor=1.5)])
+    prompt = Prompt.from_text(
+        prompt_text,
+        [
+            TextControl(start=3, length=5, factor=1.5),
+            TextControl(
+                start=3,
+                length=5,
+                factor=1.5,
+                token_overlap=ControlTokenOverlap.Complete,
+            ),
+        ],
+    )
 
     serialized_prompt = prompt.to_json()
@@ -46,14 +64,28 @@ def test_serialize_text_with_controls():
         {
             "type": "text",
             "data": prompt_text,
-            "controls": [{"start": 3, "length": 5, "factor": 1.5}],
+            "controls": [
+                {"start": 3, "length": 5, "factor": 1.5},
+                {"start": 3, "length": 5, "factor": 1.5, "token_overlap": "complete"},
+            ],
         }
     ]
 
 
 def test_serialize_image_with_controls():
     image = Image.from_file(
-        "tests/dog-and-cat-cover.jpg", [ImageControl(0.0, 0.0, 0.5, 0.5, 0.5)]
+        "tests/dog-and-cat-cover.jpg",
+        [
+            ImageControl(0.0, 0.0, 0.5, 0.5, 0.5),
+            ImageControl(
+                left=0.0,
+                top=0.0,
+                width=0.5,
+                height=0.5,
+                factor=0.5,
+                token_overlap=ControlTokenOverlap.Partial,
+            ),
+        ],
     )
     prompt = Prompt.from_image(image)
     serialized_prompt = prompt.to_json()
@@ -71,7 +103,17 @@ def test_serialize_image_with_controls():
                         "height": 0.5,
                     },
                     "factor": 0.5,
-                }
+                },
+                {
+                    "rect": {
+                        "left": 0.0,
+                        "top": 0.0,
+                        "width": 0.5,
+                        "height": 0.5,
+                    },
+                    "factor": 0.5,
+                    "token_overlap": "partial",
+                },
             ],
         }
     ]
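As a closing sketch of the surface these tests cover, text and image items with their respective controls can be combined in a single prompt and serialized. This is hypothetical usage built only from calls shown in the diff:

```python
from aleph_alpha_client import (
    ControlTokenOverlap,
    Image,
    ImageControl,
    Prompt,
    Text,
    TextControl,
)

image = Image.from_file(
    "tests/dog-and-cat-cover.jpg",
    [ImageControl(left=0.0, top=0.0, width=0.5, height=0.5, factor=0.5)],
)
text = Text(
    "Describe the picture:",
    controls=[
        TextControl(
            start=0, length=8, factor=1.5, token_overlap=ControlTokenOverlap.Complete
        )
    ],
)

# Each item serializes itself; the image becomes a base64 "data" payload.
serialized = Prompt([text, image]).to_json()
```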