diff --git a/Changelog.md b/Changelog.md
index bdce2b7..45f66a2 100644
--- a/Changelog.md
+++ b/Changelog.md
@@ -1,5 +1,55 @@
 # Changelog
 
+## 3.1.0
+
+### Features
+
+#### New `.explain()` method 🎉
+
+Better understand the source of a completion, specifically how much each section of a prompt impacts the completion.
+
+To get started, simply pass in a prompt you used with a model along with the completion the model returned, and generate an explanation:
+
+```python
+import os
+
+from aleph_alpha_client import Client, CompletionRequest, ExplanationRequest, Prompt
+
+client = Client(token=os.environ["AA_TOKEN"])
+prompt = Prompt.from_text("An apple a day, ")
+model_name = "luminous-extended"
+
+# create a completion request
+request = CompletionRequest(prompt=prompt, maximum_tokens=32)
+response = client.complete(request, model=model_name)
+
+# generate an explanation
+request = ExplanationRequest(prompt=prompt, target=response.completions[0].completion)
+response = client.explain(request, model=model_name)
+```
+
+To see the results visually, you can also use this in our [Playground](https://app.aleph-alpha.com/playground/explanation).
+
+We also have more [documentation and examples](https://docs.aleph-alpha.com/docs/tasks/explain/) available for you to read.
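+
+The response itself is a plain data structure you can walk over programmatically. As a minimal sketch (continuing from the snippet above): each explanation covers one portion of the target, and each of its items holds the scores for one prompt item, with the last item referring to the target itself.
+
+```python
+for explanation in response.explanations:
+    # the portion of the target this explanation refers to
+    print(explanation.target)
+    for item in explanation.items:
+        # each item explains one prompt item (the last one refers to the target)
+        for score in item.scores:
+            print(score)
+```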
+
+#### AtMan (Attention Manipulation)
+
+Under the hood, we leverage the method from our [AtMan paper](https://arxiv.org/abs/2301.08110) to generate these explanations. We've also exposed these controls anywhere you can submit a prompt!
+
+So if you have other use cases for attention manipulation, you can pass these AtMan controls as part of your prompt items:
+
+```python
+from aleph_alpha_client import Image, ImageControl, Prompt, Text, TextControl
+
+Prompt([
+    Text("Hello, World!", controls=[TextControl(start=0, length=5, factor=0.5)]),
+    Image.from_url(
+        "https://cdn-images-1.medium.com/max/1200/1*HunNdlTmoPj8EKpl-jqvBA.png",
+        controls=[ImageControl(top=0.25, left=0.25, height=0.5, width=0.5, factor=2.0)]
+    )
+])
+```
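+
+A prompt with controls is submitted like any other prompt. As a rough sketch (assuming the same `client` and `model_name` as in the first example), a factor below 1 suppresses the model's attention on the marked span, while a factor above 1 amplifies it:
+
+```python
+from aleph_alpha_client import CompletionRequest, Prompt, Text, TextControl
+
+# suppress attention on "Hello" (the first 5 characters) while completing
+prompt = Prompt([Text("Hello, World!", controls=[TextControl(start=0, length=5, factor=0.1)])])
+request = CompletionRequest(prompt=prompt, maximum_tokens=16)
+response = client.complete(request, model=model_name)
+```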
+
+For more information, check out our [documentation and examples](https://docs.aleph-alpha.com/docs/explainability/attention-manipulation/).
+
 ## 3.0.0
 
 ### Breaking Changes
diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py
index 2a866ad..49d64b1 100644
--- a/aleph_alpha_client/aleph_alpha_client.py
+++ b/aleph_alpha_client/aleph_alpha_client.py
@@ -395,7 +395,7 @@ def evaluate(
             request (EvaluationRequest, required):
                 Parameters for the requested evaluation.
 
-            model (string, optional, default None):
+            model (string, required):
                 Name of model to use. A model name refers to a model architecture (number of parameters among others).
                 Always the latest version of model is used.
 
@@ -451,11 +451,29 @@ def summarize(
         )
         return SummarizationResponse.from_json(response)
 
-    def _explain(
+    def explain(
         self,
         request: ExplanationRequest,
         model: str,
     ) -> ExplanationResponse:
+        """Better understand the source of a completion, specifically how much each section of a
+        prompt impacts each token of the completion.
+
+        Parameters:
+            request (ExplanationRequest, required):
+                Parameters for the requested explanation.
+
+            model (string, required):
+                Name of model to use. A model name refers to a model architecture (number of parameters among others).
+                Always the latest version of model is used.
+
+        Examples:
+            >>> request = ExplanationRequest(
+                    prompt=Prompt.from_text("Andreas likes"),
+                    target=" pizza."
+                )
+            >>> response = client.explain(request, model="luminous-extended")
+        """
         response = self._post_request(
             "explain",
             request,
@@ -700,7 +718,7 @@ async def tokenize(
             request (TokenizationRequest, required):
                 Parameters for the requested tokenization.
 
-            model (string, optional, default None):
+            model (string, required):
                 Name of model to use. A model name refers to a model architecture (number of parameters among others).
                 Always the latest version of model is used.
 
@@ -726,7 +744,7 @@ async def detokenize(
             request (DetokenizationRequest, required):
                 Parameters for the requested detokenization.
 
-            model (string, optional, default None):
+            model (string, required):
                 Name of model to use. A model name refers to a model architecture (number of parameters among others).
                 Always the latest version of model is used.
 
@@ -752,7 +770,7 @@ async def embed(
             request (EmbeddingRequest, required):
                 Parameters for the requested embedding.
 
-            model (string, optional, default None):
+            model (string, required):
                 Name of model to use. A model name refers to a model architecture (number of parameters among others).
                 Always the latest version of model is used.
 
@@ -779,7 +797,7 @@ async def semantic_embed(
             request (SemanticEmbeddingRequest, required):
                 Parameters for the requested semnatic embedding.
 
-            model (string, optional, default None):
+            model (string, required):
                 Name of model to use. A model name refers to a model architecture (number of parameters among others).
                 Always the latest version of model is used.
 
@@ -828,7 +846,7 @@ async def evaluate(
             request (EvaluationRequest, required):
                 Parameters for the requested evaluation.
 
-            model (string, optional, default None):
+            model (string, required):
                 Name of model to use. A model name refers to a model architecture (number of parameters among others).
                 Always the latest version of model is used.
 
@@ -871,11 +889,6 @@ async def summarize(
         Parameters:
             request (SummarizationRequest, required):
                 Parameters for the requested summarization.
-
-            model (string, optional, default None):
-                Name of model to use. A model name refers to a model architecture (number of parameters among others).
-                Always the latest version of model is used.
-
         Examples:
             >>> request = SummarizationRequest(
                     document=Document.from_text("Andreas likes pizza."),
@@ -888,11 +901,29 @@
         )
         return SummarizationResponse.from_json(response)
 
-    async def _explain(
+    async def explain(
         self,
         request: ExplanationRequest,
         model: str,
     ) -> ExplanationResponse:
+        """Better understand the source of a completion, specifically how much each section of a
+        prompt impacts each token of the completion.
+
+        Parameters:
+            request (ExplanationRequest, required):
+                Parameters for the requested explanation.
+
+            model (string, required):
+                Name of model to use. A model name refers to a model architecture (number of parameters among others).
+                Always the latest version of model is used.
+
+        Examples:
+            >>> request = ExplanationRequest(
+                    prompt=Prompt.from_text("Andreas likes"),
+                    target=" pizza."
+                )
+            >>> response = await client.explain(request, model="luminous-extended")
+        """
         response = await self._post_request(
             "explain",
             request,
diff --git a/aleph_alpha_client/explanation.py b/aleph_alpha_client/explanation.py
index db49034..1d353a8 100644
--- a/aleph_alpha_client/explanation.py
+++ b/aleph_alpha_client/explanation.py
@@ -206,6 +206,7 @@ def from_json(score: Any) -> "TextScore":
             score=score["score"],
         )
 
+
 class TextScoreWithRaw(NamedTuple):
     start: int
     length: int
@@ -218,9 +219,10 @@ def from_text_score(score: TextScore, prompt: Text) -> "TextScoreWithRaw":
             start=score.start,
             length=score.length,
             score=score.score,
-            text=prompt.text[score.start:score.start + score.length],
+            text=prompt.text[score.start : score.start + score.length],
         )
 
+
 class ImageScore(NamedTuple):
     left: float
     top: float
@@ -252,6 +254,7 @@ def from_json(score: Any) -> "TargetScore":
             score=score["score"],
         )
 
+
 class TargetScoreWithRaw(NamedTuple):
     start: int
     length: int
@@ -264,9 +267,10 @@ def from_target_score(score: TargetScore, target: str) -> "TargetScoreWithRaw":
             start=score.start,
             length=score.length,
             score=score.score,
-            text=target[score.start:score.start + score.length],
+            text=target[score.start : score.start + score.length],
         )
 
+
 class TokenScore(NamedTuple):
     score: float
 
@@ -278,6 +282,13 @@ def from_json(score: Any) -> "TokenScore":
 
 
 class ImagePromptItemExplanation(NamedTuple):
+    """
+    Explains the importance of an image prompt item.
+    The number of items in the "scores" array depends on the granularity setting.
+    Each score object contains the top-left corner of a rectangular area in the image prompt.
+    The coordinates are all between 0 and 1 in terms of the total image size.
+    """
+
     scores: List[ImageScore]
 
     @staticmethod
@@ -305,6 +316,13 @@ def in_pixels(self, prompt_item: PromptItem) -> "ImagePromptItemExplanation":
 
 
 class TextPromptItemExplanation(NamedTuple):
+    """
+    Explains the importance of a text prompt item.
+    The number of items in the "scores" array depends on the granularity setting.
+    Each score object contains an inclusive start character, the length of the substring, and
+    a floating point score value.
+    """
+
     scores: List[Union[TextScore, TextScoreWithRaw]]
 
     @staticmethod
@@ -312,15 +330,27 @@ def from_json(item: Dict[str, Any]) -> "TextPromptItemExplanation":
         return TextPromptItemExplanation(
             scores=[TextScore.from_json(score) for score in item["scores"]]
         )
-    
+
     def with_text(self, prompt: Text) -> "TextPromptItemExplanation":
         return TextPromptItemExplanation(
-            scores=[TextScoreWithRaw.from_text_score(score, prompt) if isinstance(score, TextScore) else score for score in self.scores]
+            scores=[
+                TextScoreWithRaw.from_text_score(score, prompt)
+                if isinstance(score, TextScore)
+                else score
+                for score in self.scores
+            ]
         )
 
-
 class TargetPromptItemExplanation(NamedTuple):
+    """
+    Explains the importance of the text in the target string that comes before the target
+    token currently being explained. The number of items in the "scores" array depends on the
+    granularity setting.
+    Each score object contains an inclusive start character, the length of the substring, and
+    a floating point score value.
+    """
+
    scores: List[Union[TargetScore, TargetScoreWithRaw]]
 
     @staticmethod
@@ -328,17 +358,23 @@ def from_json(item: Dict[str, Any]) -> "TargetPromptItemExplanation":
         return TargetPromptItemExplanation(
             scores=[TargetScore.from_json(score) for score in item["scores"]]
         )
-    
+
     def with_text(self, prompt: str) -> "TargetPromptItemExplanation":
         return TargetPromptItemExplanation(
-            scores=[TargetScoreWithRaw.from_target_score(score, prompt) if isinstance(score, TargetScore) else score for score in self.scores]
+            scores=[
+                TargetScoreWithRaw.from_target_score(score, prompt)
+                if isinstance(score, TargetScore)
+                else score
+                for score in self.scores
+            ]
         )
 
-
-
-
 class TokenPromptItemExplanation(NamedTuple):
+    """Explains the importance of a request prompt item of type "token_ids".
+    Will contain one floating point importance value for each token in the same order as in the original prompt.
+    """
+
     scores: List[TokenScore]
 
     @staticmethod
@@ -349,6 +385,16 @@ def from_json(item: Dict[str, Any]) -> "TokenPromptItemExplanation":
 
 
 class Explanation(NamedTuple):
+    """
+    Explanations for a given portion of the target.
+
+    Parameters:
+        target (str, required)
+            If target_granularity was set to "complete", then this will be the entire target. If it was set to "token", this will be a single target token.
+        items (List[Union[TextPromptItemExplanation, TargetPromptItemExplanation, TokenPromptItemExplanation, ImagePromptItemExplanation]], required)
+            Contains one item for each prompt item (in order), and the last item refers to the target.
+    """
+
     target: str
     items: List[
         Union[
@@ -397,17 +443,19 @@ def with_image_prompt_items_in_pixels(self, prompt: Prompt) -> "Explanation":
         )
 
     def with_text_from_prompt(self, prompt: Prompt, target: str) -> "Explanation":
-        items: List[Union[
-            TextPromptItemExplanation,
-            ImagePromptItemExplanation,
-            TargetPromptItemExplanation,
-            TokenPromptItemExplanation,
-        ]] = []
-        for item_index, item in enumerate(self.items):
+        items: List[
+            Union[
+                TextPromptItemExplanation,
+                ImagePromptItemExplanation,
+                TargetPromptItemExplanation,
+                TokenPromptItemExplanation,
+            ]
+        ] = []
+        for item_index, item in enumerate(self.items):
             if isinstance(item, TextPromptItemExplanation):
                 # separate variable to fix linting error
                 prompt_item = prompt.items[item_index]
-                if isinstance(prompt_item, Text):
+                if isinstance(prompt_item, Text):
                     items.append(item.with_text(prompt_item))
                 else:
                     items.append(item)
@@ -421,8 +469,17 @@ def with_text_from_prompt(self, prompt: Prompt, target: str) -> "Explanation":
         )
 
-
 class ExplanationResponse(NamedTuple):
+    """
+    The top-level response data structure that will be returned from an explanation request.
+
+    Parameters:
+        model_version (str, required)
+            Version of the model used to generate the explanation.
+        explanations (List[Explanation], required)
+            This array will contain one explanation object for each portion of the target.
+    """
+
     model_version: str
     explanations: List[Explanation]
 
@@ -444,7 +501,7 @@ def with_image_prompt_items_in_pixels(
             for explanation in self.explanations
         ]
         return ExplanationResponse(self.model_version, mapped_explanations)
-    
+
     def with_text_from_prompt(
         self, request: ExplanationRequest
     ) -> "ExplanationResponse":
diff --git a/aleph_alpha_client/version.py b/aleph_alpha_client/version.py
index 528787c..f5f41e5 100644
--- a/aleph_alpha_client/version.py
+++ b/aleph_alpha_client/version.py
@@ -1 +1 @@
-__version__ = "3.0.0"
+__version__ = "3.1.0"
diff --git a/tests/test_explanation.py b/tests/test_explanation.py
index 96ee717..901b1ee 100644
--- a/tests/test_explanation.py
+++ b/tests/test_explanation.py
@@ -45,7 +45,7 @@ async def test_can_explain_with_async_client(
         target_granularity=TargetGranularity.Token,
     )
 
-    explanation = await async_client._explain(request, model=model_name)
+    explanation = await async_client.explain(request, model=model_name)
     assert len(explanation.explanations) == 3
     assert all([len(exp.items) == 3 for exp in explanation.explanations])
 
@@ -76,7 +76,7 @@ def test_explanation(sync_client: Client, model_name: str):
         control_token_overlap=ControlTokenOverlap.Complete,
     )
 
-    explanation = sync_client._explain(request, model=model_name)
+    explanation = sync_client.explain(request, model=model_name)
     assert len(explanation.explanations) == 3
     assert all([len(exp.items) == 4 for exp in explanation.explanations])
 
@@ -107,7 +107,7 @@ def test_explanation_auto_granularity(sync_client: Client, model_name: str):
         prompt_granularity=None,
     )
 
-    explanation = sync_client._explain(request, model=model_name)
+    explanation = sync_client.explain(request, model=model_name)
     assert len(explanation.explanations) == 1
     assert all([len(exp.items) == 4 for exp in explanation.explanations])
 
@@ -130,7 +130,7 @@ def test_explanation_of_image_in_pixels(sync_client: Client, model_name: str):
         prompt_granularity=None,
    )
 
-    explanation = sync_client._explain(request, model=model_name)
+    explanation = sync_client.explain(request, model=model_name)
     explanation = explanation.with_image_prompt_items_in_pixels(request.prompt)
 
     assert len(explanation.explanations) == 1
@@ -150,7 +150,6 @@ def test_explanation_of_text_in_prompt_relativ_indeces(
         prompt=Prompt(
             [
                 Text.from_text("I am a programmer and French. My favourite food is"),
-
                 # " My favorite food is"
                 Tokens.from_token_ids([4014, 36316, 5681, 387]),
             ]
@@ -160,20 +159,22 @@
         target_granularity=TargetGranularity.Token,
     )
 
-    explanation = sync_client._explain(request, model=model_name)
+    explanation = sync_client.explain(request, model=model_name)
     explanation = explanation.with_text_from_prompt(request)
 
     assert len(explanation.explanations) == 3
     assert all([len(exp.items) == 3 for exp in explanation.explanations])
     assert all(
         [
-            isinstance(raw_text_score, TextScoreWithRaw) and isinstance(raw_text_score.text, str)
+            isinstance(raw_text_score, TextScoreWithRaw)
+            and isinstance(raw_text_score.text, str)
             for raw_text_score in explanation.explanations[0].items[0].scores
         ]
     )
     assert all(
         [
-            isinstance(raw_text_score, TargetScoreWithRaw) and isinstance(raw_text_score.text, str)
+            isinstance(raw_text_score, TargetScoreWithRaw)
+            and isinstance(raw_text_score.text, str)
             for raw_text_score in explanation.explanations[1].items[2].scores
         ]
     )