Skip to content

Commit

Permalink
v3.1 - Release .explain() (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
benbrandt authored Apr 12, 2023
1 parent 3e5c3de commit 5880f79
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 42 deletions.
50 changes: 50 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,55 @@
# Changelog

## 3.1.0

### Features

### New `.explain()` method 🎉

Better understand the source of a completion, specifically on how much each section of a prompt impacts the completion.

To get started, you can simply pass in a prompt you used with a model and the completion the model gave and generate an explanation:

```python
import os

from aleph_alpha_client import Client, CompletionRequest, ExplanationRequest, Prompt

client = Client(token=os.environ["AA_TOKEN"])
prompt = Prompt.from_text("An apple a day, ")
model_name = "luminous-extended"

# create a completion request
request = CompletionRequest(prompt=prompt, maximum_tokens=32)
response = client.complete(request, model=model_name)

# generate an explanation
request = ExplanationRequest(prompt=prompt, target=response.completions[0].completion)
response = client.explain(request, model=model_name)
```

To visually see the results, you can also use this in our [Playground](https://app.aleph-alpha.com/playground/explanation).

We also have more [documentation and examples](https://docs.aleph-alpha.com/docs/tasks/explain/) available for you to read.

### AtMan (Attention Manipulation)

Under the hood, we are leveraging the method from our [AtMan paper](https://arxiv.org/abs/2301.08110) to help generate these explanations. And we've also exposed these controls anywhere you can submit us a prompt!

So if you have other use cases for attention manipulation, you can pass these AtMan controls as part of your prompt items.

```python
from aleph_alpha_client import Image, ImageControl, Prompt, Text, TextControl

Prompt([
Text("Hello, World!", controls=[TextControl(start=0, length=5, factor=0.5)]),
Image.from_url(
"https://cdn-images-1.medium.com/max/1200/1*HunNdlTmoPj8EKpl-jqvBA.png",
controls=[ImageControl(top=0.25, left=0.25, height=0.5, width=0.5, factor=2.0)]
)
])
```

For more information, check out our [documentation and examples](https://docs.aleph-alpha.com/docs/explainability/attention-manipulation/).

## 3.0.0

### Breaking Changes
Expand Down
57 changes: 44 additions & 13 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def evaluate(
request (EvaluationRequest, required):
Parameters for the requested evaluation.
model (string, optional, default None):
model (string, required):
Name of model to use. A model name refers to a model architecture (number of parameters among others).
Always the latest version of model is used.
Expand Down Expand Up @@ -451,11 +451,29 @@ def summarize(
)
return SummarizationResponse.from_json(response)

def _explain(
def explain(
self,
request: ExplanationRequest,
model: str,
) -> ExplanationResponse:
"""Better understand the source of a completion, specifically on how much each section of a
prompt impacts each token of the completion.
Parameters:
request (ExplanationRequest, required):
Parameters for the requested explanation.
model (string, required):
Name of model to use. A model name refers to a model architecture (number of parameters among others).
Always the latest version of model is used.
Examples:
>>> request = ExplanationRequest(
prompt=Prompt.from_text("Andreas likes"),
target=" pizza."
)
>>> response = client.explain(request, model="luminous-extended")
"""
response = self._post_request(
"explain",
request,
Expand Down Expand Up @@ -700,7 +718,7 @@ async def tokenize(
request (TokenizationRequest, required):
Parameters for the requested tokenization.
model (string, optional, default None):
model (string, required):
Name of model to use. A model name refers to a model architecture (number of parameters among others).
Always the latest version of model is used.
Expand All @@ -726,7 +744,7 @@ async def detokenize(
request (DetokenizationRequest, required):
Parameters for the requested detokenization.
model (string, optional, default None):
model (string, required):
Name of model to use. A model name refers to a model architecture (number of parameters among others).
Always the latest version of model is used.
Expand All @@ -752,7 +770,7 @@ async def embed(
request (EmbeddingRequest, required):
Parameters for the requested embedding.
model (string, optional, default None):
model (string, required):
Name of model to use. A model name refers to a model architecture (number of parameters among others).
Always the latest version of model is used.
Expand All @@ -779,7 +797,7 @@ async def semantic_embed(
request (SemanticEmbeddingRequest, required):
            Parameters for the requested semantic embedding.
model (string, optional, default None):
model (string, required):
Name of model to use. A model name refers to a model architecture (number of parameters among others).
Always the latest version of model is used.
Expand Down Expand Up @@ -828,7 +846,7 @@ async def evaluate(
request (EvaluationRequest, required):
Parameters for the requested evaluation.
model (string, optional, default None):
model (string, required):
Name of model to use. A model name refers to a model architecture (number of parameters among others).
Always the latest version of model is used.
Expand Down Expand Up @@ -871,11 +889,6 @@ async def summarize(
Parameters:
request (SummarizationRequest, required):
Parameters for the requested summarization.
model (string, optional, default None):
Name of model to use. A model name refers to a model architecture (number of parameters among others).
Always the latest version of model is used.
Examples:
>>> request = SummarizationRequest(
document=Document.from_text("Andreas likes pizza."),
Expand All @@ -888,11 +901,29 @@ async def summarize(
)
return SummarizationResponse.from_json(response)

async def _explain(
async def explain(
self,
request: ExplanationRequest,
model: str,
) -> ExplanationResponse:
"""Better understand the source of a completion, specifically on how much each section of a
prompt impacts each token of the completion.
Parameters:
request (ExplanationRequest, required):
Parameters for the requested explanation.
model (string, required):
Name of model to use. A model name refers to a model architecture (number of parameters among others).
Always the latest version of model is used.
Examples:
>>> request = ExplanationRequest(
prompt=Prompt.from_text("Andreas likes"),
target=" pizza."
)
>>> response = await client.explain(request, model="luminous-extended")
"""
response = await self._post_request(
"explain",
request,
Expand Down
97 changes: 77 additions & 20 deletions aleph_alpha_client/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def from_json(score: Any) -> "TextScore":
score=score["score"],
)


class TextScoreWithRaw(NamedTuple):
start: int
length: int
Expand All @@ -218,9 +219,10 @@ def from_text_score(score: TextScore, prompt: Text) -> "TextScoreWithRaw":
start=score.start,
length=score.length,
score=score.score,
text=prompt.text[score.start:score.start + score.length],
text=prompt.text[score.start : score.start + score.length],
)


class ImageScore(NamedTuple):
left: float
top: float
Expand Down Expand Up @@ -252,6 +254,7 @@ def from_json(score: Any) -> "TargetScore":
score=score["score"],
)


class TargetScoreWithRaw(NamedTuple):
start: int
length: int
Expand All @@ -264,9 +267,10 @@ def from_target_score(score: TargetScore, target: str) -> "TargetScoreWithRaw":
start=score.start,
length=score.length,
score=score.score,
text=target[score.start:score.start + score.length],
text=target[score.start : score.start + score.length],
)


class TokenScore(NamedTuple):
score: float

Expand All @@ -278,6 +282,13 @@ def from_json(score: Any) -> "TokenScore":


class ImagePromptItemExplanation(NamedTuple):
"""
Explains the importance of an image prompt item.
The amount of items in the "scores" array depends on the granularity setting.
Each score object contains the top-left corner of a rectangular area in the image prompt.
The coordinates are all between 0 and 1 in terms of the total image size
"""

scores: List[ImageScore]

@staticmethod
Expand Down Expand Up @@ -305,40 +316,65 @@ def in_pixels(self, prompt_item: PromptItem) -> "ImagePromptItemExplanation":


class TextPromptItemExplanation(NamedTuple):
"""
Explains the importance of a text prompt item.
The amount of items in the "scores" array depends on the granularity setting.
Each score object contains an inclusive start character and a length of the substring plus
a floating point score value.
"""

scores: List[Union[TextScore, TextScoreWithRaw]]

@staticmethod
def from_json(item: Dict[str, Any]) -> "TextPromptItemExplanation":
return TextPromptItemExplanation(
scores=[TextScore.from_json(score) for score in item["scores"]]
)

def with_text(self, prompt: Text) -> "TextPromptItemExplanation":
return TextPromptItemExplanation(
scores=[TextScoreWithRaw.from_text_score(score, prompt) if isinstance(score, TextScore) else score for score in self.scores]
scores=[
TextScoreWithRaw.from_text_score(score, prompt)
if isinstance(score, TextScore)
else score
for score in self.scores
]
)



class TargetPromptItemExplanation(NamedTuple):
"""
Explains the importance of text in the target string that came before the currently
to-be-explained target token. The amount of items in the "scores" array depends on the
granularity setting.
Each score object contains an inclusive start character and a length of the substring plus
a floating point score value.
"""

scores: List[Union[TargetScore, TargetScoreWithRaw]]

@staticmethod
def from_json(item: Dict[str, Any]) -> "TargetPromptItemExplanation":
return TargetPromptItemExplanation(
scores=[TargetScore.from_json(score) for score in item["scores"]]
)

def with_text(self, prompt: str) -> "TargetPromptItemExplanation":
return TargetPromptItemExplanation(
scores=[TargetScoreWithRaw.from_target_score(score, prompt) if isinstance(score, TargetScore) else score for score in self.scores]
scores=[
TargetScoreWithRaw.from_target_score(score, prompt)
if isinstance(score, TargetScore)
else score
for score in self.scores
]
)





class TokenPromptItemExplanation(NamedTuple):
"""Explains the importance of a request prompt item of type "token_ids".
Will contain one floating point importance value for each token in the same order as in the original prompt.
"""

scores: List[TokenScore]

@staticmethod
Expand All @@ -349,6 +385,16 @@ def from_json(item: Dict[str, Any]) -> "TokenPromptItemExplanation":


class Explanation(NamedTuple):
"""
Explanations for a given portion of the target.
Parameters:
target (str, required)
If target_granularity was set to "complete", then this will be the entire target. If it was set to "token", this will be a single target token.
items (List[Union[TextPromptItemExplanation, TargetPromptItemExplanation, TokenPromptItemExplanation, ImagePromptItemExplanation], required)
Contains one item for each prompt item (in order), and the last item refers to the target.
"""

target: str
items: List[
Union[
Expand Down Expand Up @@ -397,17 +443,19 @@ def with_image_prompt_items_in_pixels(self, prompt: Prompt) -> "Explanation":
)

def with_text_from_prompt(self, prompt: Prompt, target: str) -> "Explanation":
items: List[Union[
TextPromptItemExplanation,
ImagePromptItemExplanation,
TargetPromptItemExplanation,
TokenPromptItemExplanation,
]] = []
for item_index, item in enumerate(self.items):
items: List[
Union[
TextPromptItemExplanation,
ImagePromptItemExplanation,
TargetPromptItemExplanation,
TokenPromptItemExplanation,
]
] = []
for item_index, item in enumerate(self.items):
if isinstance(item, TextPromptItemExplanation):
# separate variable to fix linting error
prompt_item = prompt.items[item_index]
if isinstance(prompt_item, Text):
if isinstance(prompt_item, Text):
items.append(item.with_text(prompt_item))
else:
items.append(item)
Expand All @@ -421,8 +469,17 @@ def with_text_from_prompt(self, prompt: Prompt, target: str) -> "Explanation":
)



class ExplanationResponse(NamedTuple):
"""
The top-level response data structure that will be returned from an explanation request.
Parameters:
model_version (str, required)
Version of the model used to generate the explanation.
explanations (List[Explanation], required)
This array will contain one explanation object for each portion of the target.
"""

model_version: str
explanations: List[Explanation]

Expand All @@ -444,7 +501,7 @@ def with_image_prompt_items_in_pixels(
for explanation in self.explanations
]
return ExplanationResponse(self.model_version, mapped_explanations)

def with_text_from_prompt(
self, request: ExplanationRequest
) -> "ExplanationResponse":
Expand Down
2 changes: 1 addition & 1 deletion aleph_alpha_client/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.0.0"
__version__ = "3.1.0"
Loading

0 comments on commit 5880f79

Please sign in to comment.