Merge pull request #37 from Aleph-Alpha/introduce-prompt-type
Introduce prompt type
ahartel authored Jun 28, 2022
2 parents f756975 + 44ce8ad commit d21e8ac
Showing 17 changed files with 121 additions and 147 deletions.
28 changes: 14 additions & 14 deletions README.md
@@ -21,7 +21,7 @@ pip install aleph-alpha-client
 
 
 ```python
-from aleph_alpha_client import ImagePrompt, AlephAlphaModel, AlephAlphaClient, CompletionRequest
+from aleph_alpha_client import ImagePrompt, AlephAlphaModel, AlephAlphaClient, CompletionRequest, Prompt
 import os
 
 model = AlephAlphaModel(
@@ -33,30 +33,30 @@ model = AlephAlphaModel(
 url = "https://cdn-images-1.medium.com/max/1200/1*HunNdlTmoPj8EKpl-jqvBA.png"
 
 image = ImagePrompt.from_url(url)
-prompt = [
+prompt = Prompt([
     image,
     "Q: What does the picture show? A:",
-]
+])
 request = CompletionRequest(prompt=prompt, maximum_tokens=20)
 result = model.complete(request)
 
-print(result.completions[0]["completion"])
+print(result.completions[0].completion)
 ```
 
 
 ### Evaluation text prompt
 
 
 ```python
-from aleph_alpha_client import AlephAlphaClient, AlephAlphaModel, EvaluationRequest
+from aleph_alpha_client import AlephAlphaClient, AlephAlphaModel, EvaluationRequest, Prompt
 import os
 
 model = AlephAlphaModel(
     AlephAlphaClient(host="https://api.aleph-alpha.com", token=os.getenv("AA_TOKEN")),
     model_name = "luminous-extended"
 )
 
-request = EvaluationRequest(prompt="The api works", completion_expected=" well")
+request = EvaluationRequest(prompt=Prompt.from_text("The api works"), completion_expected=" well")
 result = model.evaluate(request)
 
 print(result)
@@ -69,7 +69,7 @@
 
 
 ```python
-from aleph_alpha_client import ImagePrompt, AlephAlphaClient, AlephAlphaModel, EvaluationRequest
+from aleph_alpha_client import ImagePrompt, AlephAlphaClient, AlephAlphaModel, EvaluationRequest, Prompt
 import os
 
 model = AlephAlphaModel(
@@ -80,10 +80,10 @@ model = AlephAlphaModel(
 
 url = "https://upload.wikimedia.org/wikipedia/commons/thumb/7/74/2008-09-24_Blockbuster_in_Durham.jpg/330px-2008-09-24_Blockbuster_in_Durham.jpg"
 image = ImagePrompt.from_url(url)
-prompt = [
+prompt = Prompt([
     image,
     "Q: What is the name of the store?\nA:",
-]
+])
 request = EvaluationRequest(prompt=prompt, completion_expected=" Blockbuster Video")
 result = model.evaluate(request)
 
@@ -96,15 +96,15 @@ print(result)
 
 
 ```python
-from aleph_alpha_client import AlephAlphaModel, AlephAlphaClient, EmbeddingRequest
+from aleph_alpha_client import AlephAlphaModel, AlephAlphaClient, EmbeddingRequest, Prompt
 import os
 
 model = AlephAlphaModel(
     AlephAlphaClient(host="https://api.aleph-alpha.com", token=os.getenv("AA_TOKEN")),
     model_name = "luminous-extended"
 )
 
-request = EmbeddingRequest(prompt=["This is an example."], layers=[-1], pooling=["mean"])
+request = EmbeddingRequest(prompt=Prompt.from_text("This is an example."), layers=[-1], pooling=["mean"])
 result = model.embed(request)
 
 print(result)
@@ -116,7 +116,7 @@ print(result)
 
 
 ```python
-from aleph_alpha_client import ImagePrompt, AlephAlphaClient, AlephAlphaModel, EmbeddingRequest
+from aleph_alpha_client import ImagePrompt, AlephAlphaClient, AlephAlphaModel, EmbeddingRequest, Prompt
 import os
 
 model = AlephAlphaModel(
@@ -127,10 +127,10 @@ model = AlephAlphaModel(
 
 url = "https://upload.wikimedia.org/wikipedia/commons/thumb/7/74/2008-09-24_Blockbuster_in_Durham.jpg/330px-2008-09-24_Blockbuster_in_Durham.jpg"
 image = ImagePrompt.from_url(url)
-prompt = [
+prompt = Prompt([
     image,
     "Q: What is the name of the store?\nA:",
-]
+])
 request = EmbeddingRequest(prompt=prompt, layers=[-1], pooling=["mean"])
 result = model.embed(request)
 
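Every README change above follows the same pattern: a bare prompt list becomes an instance of the new `Prompt` type, and completion results are read as attributes instead of dict keys. A minimal sketch of the construction styles this commit introduces (the URL below is a placeholder, not one from the README):

```python
from aleph_alpha_client import ImagePrompt, Prompt

# Text-only prompts get a convenience constructor.
text_prompt = Prompt.from_text("An apple a day")

# Multimodal prompts wrap a list of items, as in the README examples above.
image = ImagePrompt.from_url("https://example.com/picture.png")  # placeholder URL
multimodal_prompt = Prompt([image, "Q: What does the picture show? A:"])
```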
1 change: 1 addition & 0 deletions aleph_alpha_client/__init__.py
@@ -1,6 +1,7 @@
 from .aleph_alpha_client import AlephAlphaClient, QuotaError, POOLING_OPTIONS
 from .aleph_alpha_model import AlephAlphaModel
 from .image import ImagePrompt
+from .prompt import Prompt
 from .explanation import ExplanationRequest
 from .embedding import EmbeddingRequest
 from .completion import CompletionRequest
12 changes: 9 additions & 3 deletions aleph_alpha_client/aleph_alpha_client.py
@@ -421,10 +421,16 @@ def qa(
         return response_json
 
     def _explain(self, model: str, request: ExplanationRequest, hosting: Optional[str] = None):
-        body = request.render_as_body(model, hosting)
+        body = {
+            "model": model,
+            "prompt": [_to_prompt_item(item) for item in request.prompt.items],
+            "target": request.target,
+            "suppression_factor": request.suppression_factor,
+            "directional": request.directional,
+            "conceptual_suppression_threshold": request.conceptual_suppression_threshold
+        }
         response = requests.post(f"{self.host}explain", headers=self.request_headers, json=body)
-        response_dict = self._translate_errors(response)
-        return response_dict
+        return self._translate_errors(response)
 
 
     @staticmethod
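Written out, the body that `_explain` now posts to the explain endpoint has roughly this shape; the keys come straight from the dict above, while every value here is illustrative:

```python
# Illustrative explain request body; prompt items are serialized by
# _to_prompt_item (see aleph_alpha_client/prompt.py below).
body = {
    "model": "luminous-extended",                            # illustrative
    "prompt": [{"type": "text", "data": "An apple a day"}],  # serialized text item
    "target": " keeps the doctor away",                      # illustrative
    "suppression_factor": 0.1,                               # illustrative
    "directional": False,                                    # illustrative
    "conceptual_suppression_threshold": None,
}
```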
45 changes: 34 additions & 11 deletions aleph_alpha_client/aleph_alpha_model.py
@@ -1,7 +1,11 @@
-from typing import Any, Mapping
+from collections import ChainMap
+from typing import Any, Mapping, Union
 from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient
 from aleph_alpha_client.completion import CompletionRequest, CompletionResponse
-from aleph_alpha_client.detokenization import DetokenizationRequest, DetokenizationResponse
+from aleph_alpha_client.detokenization import (
+    DetokenizationRequest,
+    DetokenizationResponse,
+)
 from aleph_alpha_client.embedding import EmbeddingRequest, EmbeddingResponse
 from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse
 from aleph_alpha_client.explanation import ExplanationRequest
@@ -10,35 +10,54 @@
 
 
 class AlephAlphaModel:
 
-    def __init__(self, client: AlephAlphaClient, model_name: str, hosting: str = "cloud") -> None:
+    def __init__(
+        self, client: AlephAlphaClient, model_name: str, hosting: str = "cloud"
+    ) -> None:
         self.client = client
         self.model_name = model_name
         self.hosting = hosting
 
     def complete(self, request: CompletionRequest) -> CompletionResponse:
-        response_json = self.client.complete(model = self.model_name, hosting=self.hosting, **request._asdict())
+        response_json = self.client.complete(
+            model=self.model_name, hosting=self.hosting, **self.as_request_dict(request)
+        )
         return CompletionResponse.from_json(response_json)
 
     def tokenize(self, request: TokenizationRequest) -> TokenizationResponse:
-        response_json = self.client.tokenize(model = self.model_name, **request._asdict())
+        response_json = self.client.tokenize(model=self.model_name, **request._asdict())
         return TokenizationResponse.from_json(response_json)
 
     def detokenize(self, request: DetokenizationRequest) -> DetokenizationResponse:
-        response_json = self.client.detokenize(model = self.model_name, **request._asdict())
+        response_json = self.client.detokenize(
+            model=self.model_name, **request._asdict()
+        )
         return DetokenizationResponse.from_json(response_json)
 
     def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
-        response_json = self.client.embed(model = self.model_name, hosting=self.hosting, **request._asdict())
+        response_json = self.client.embed(
+            model=self.model_name, hosting=self.hosting, **self.as_request_dict(request)
+        )
         return EmbeddingResponse.from_json(response_json)
 
     def evaluate(self, request: EvaluationRequest) -> EvaluationResponse:
-        response_json = self.client.evaluate(model = self.model_name, hosting=self.hosting, **request._asdict())
+        response_json = self.client.evaluate(
+            model=self.model_name, hosting=self.hosting, **self.as_request_dict(request)
+        )
         return EvaluationResponse.from_json(response_json)
 
     def qa(self, request: QaRequest) -> QaResponse:
-        response_json = self.client.qa(model = self.model_name, hosting=self.hosting, **request._asdict())
+        response_json = self.client.qa(
+            model=self.model_name, hosting=self.hosting, **request._asdict()
+        )
         return QaResponse.from_json(response_json)
 
     def _explain(self, request: ExplanationRequest) -> Mapping[str, Any]:
-        return self.client._explain(model = self.model_name, hosting=self.hosting, request=request)
+        return self.client._explain(
+            model=self.model_name, hosting=self.hosting, request=request
+        )
+
+    @staticmethod
+    def as_request_dict(
+        request: Union[CompletionRequest, EmbeddingRequest, EvaluationRequest]
+    ) -> Mapping[str, Any]:
+        return ChainMap({"prompt": request.prompt.items}, request._asdict())
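The `ChainMap` in `as_request_dict` carries the whole refactor: `request._asdict()` still holds the `Prompt` object under the `"prompt"` key, and the first mapping shadows it with the raw item list, because `ChainMap` lookups return the first mapping that contains the key. A standalone sketch of that behaviour:

```python
from collections import ChainMap

# Stand-ins for request._asdict() and the prompt override.
as_dict = {"prompt": "<Prompt object>", "maximum_tokens": 64}
override = {"prompt": ["An apple a day"]}  # mirrors request.prompt.items

merged = ChainMap(override, as_dict)
print(merged["prompt"])          # ['An apple a day'] -- first mapping wins
print(merged["maximum_tokens"])  # 64 -- falls through to request._asdict()
```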
26 changes: 2 additions & 24 deletions aleph_alpha_client/completion.py
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Union
 
 from aleph_alpha_client.image import ImagePrompt
-from aleph_alpha_client.prompt import _to_serializable_prompt
+from aleph_alpha_client.prompt import Prompt, _to_serializable_prompt
 
 
 class CompletionRequest(NamedTuple):
@@ -75,7 +75,7 @@ class CompletionRequest(NamedTuple):
         Our goal is to improve your results while using our API. But you can always pass disable_optimizations: true and we will leave your prompt and completion untouched.
     """
 
-    prompt: Sequence[Union[str, ImagePrompt]]
+    prompt: Prompt
     maximum_tokens: int = 64
     temperature: float = 0.0
     top_k: int = 0
@@ -92,28 +92,6 @@ class CompletionRequest(NamedTuple):
     tokens: bool = False
     disable_optimizations: bool = False
-
-    def render_as_body(self, model: str, hosting: str) -> Dict[str, Any]:
-        return {
-            "model": model,
-            "hosting": hosting,
-            "prompt": _to_serializable_prompt(self.prompt),
-            "maximum_tokens": self.maximum_tokens,
-            "temperature": self.temperature,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-            "presence_penalty": self.presence_penalty,
-            "frequency_penalty": self.frequency_penalty,
-            "best_of": self.best_of,
-            "n": self.n,
-            "logit_bias": self.logit_bias,
-            "log_probs": self.log_probs,
-            "repetition_penalties_include_prompt": self.repetition_penalties_include_prompt,
-            "use_multiplicative_presence_penalty": self.use_multiplicative_presence_penalty,
-            "stop_sequences": self.stop_sequences,
-            "tokens": self.tokens,
-            "disable_optimizations": self.disable_optimizations,
-        }
 
 
 class CompletionResult(NamedTuple):
     log_probs: Optional[Sequence[Mapping[str, Optional[float]]]] = None
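With `render_as_body` gone, a `CompletionRequest` carries only data; serialization now happens in one place, `AlephAlphaModel.as_request_dict` above. A minimal construction after this commit, matching the updated README:

```python
from aleph_alpha_client import CompletionRequest, Prompt

request = CompletionRequest(
    prompt=Prompt.from_text("An apple a day"),
    maximum_tokens=20,  # the remaining sampling fields keep their NamedTuple defaults
)
```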
12 changes: 6 additions & 6 deletions aleph_alpha_client/detokenization.py
@@ -2,14 +2,14 @@
 
 
 class DetokenizationRequest(NamedTuple):
+    """Describes a detokenization request.
+    Parameters
+        token_ids (Sequence[int])
+            Ids of the tokens for which the text should be returned.
+    """
     token_ids: Sequence[int]
 
-    def render_as_body(self, model: str) -> Dict[str, Any]:
-        return {
-            "model": model,
-            "token_ids": self.token_ids,
-        }
 
 
 class DetokenizationResponse(NamedTuple):
     result: Sequence[str]
15 changes: 2 additions & 13 deletions aleph_alpha_client/embedding.py
@@ -1,6 +1,6 @@
 from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 from aleph_alpha_client.image import ImagePrompt
-from aleph_alpha_client.prompt import _to_prompt_item
+from aleph_alpha_client.prompt import Prompt, _to_prompt_item
 
 
 class EmbeddingRequest(NamedTuple):
@@ -33,23 +33,12 @@ class EmbeddingRequest(NamedTuple):
     """
 
-    prompt: Sequence[Union[str, ImagePrompt]]
+    prompt: Prompt
     layers: List[int]
     pooling: List[str]
     type: Optional[str] = None
     tokens: bool = False
 
-    def render_as_body(self, model: str, hosting=Optional[str]) -> dict:
-        return {
-            "model": model,
-            "hosting": hosting,
-            "prompt": [_to_prompt_item(item) for item in self.prompt],
-            "layers": self.layers,
-            "pooling": self.pooling,
-            "type": self.type,
-            "tokens": self.tokens,
-        }
 
 
 class EmbeddingResponse(NamedTuple):
     model_version: str
13 changes: 2 additions & 11 deletions aleph_alpha_client/evaluation.py
@@ -1,6 +1,5 @@
 from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
-from aleph_alpha_client.image import ImagePrompt
-from aleph_alpha_client.prompt import _to_serializable_prompt
+from aleph_alpha_client.prompt import Prompt
 
 
 class EvaluationRequest(NamedTuple):
@@ -15,17 +14,9 @@ class EvaluationRequest(NamedTuple):
         The ground truth completion expected to be produced given the prompt.
     """
 
-    prompt: Sequence[Union[str, ImagePrompt]]
+    prompt: Prompt
     completion_expected: str
-
-    def render_as_body(self, model: str, hosting=Optional[str]) -> dict:
-        return {
-            "model": model,
-            "hosting": hosting,
-            "prompt": _to_serializable_prompt(self.prompt),
-            "completion_expected": self.completion_expected,
-        }
 
 
 class EvaluationResponse(NamedTuple):
     model_version: str
16 changes: 2 additions & 14 deletions aleph_alpha_client/explanation.py
@@ -1,22 +1,10 @@
 from typing import List, NamedTuple, Optional, Union
-from aleph_alpha_client.image import ImagePrompt
-from aleph_alpha_client.prompt import _to_prompt_item
+from aleph_alpha_client.prompt import Prompt
 
 
 class ExplanationRequest(NamedTuple):
-    prompt: List[Union[str, ImagePrompt]]
+    prompt: Prompt
     target: str
     directional: bool
     suppression_factor: float
     conceptual_suppression_threshold: Optional[float] = None
-
-
-    def render_as_body(self, model: str, hosting=Optional[str]) -> dict:
-        return {
-            "model": model,
-            "prompt": [_to_prompt_item(item) for item in self.prompt],
-            "target": self.target,
-            "suppression_factor": self.suppression_factor,
-            "directional": self.directional,
-            "conceptual_suppression_threshold": self.conceptual_suppression_threshold
-        }
5 changes: 5 additions & 0 deletions aleph_alpha_client/prompt.py
@@ -10,6 +10,11 @@ class Prompt(NamedTuple):
     def from_text(text: str) -> "Prompt":
         return Prompt([text])
 
+    @staticmethod
+    def from_image(image: ImagePrompt) -> "Prompt":
+        return Prompt([image])
+
 
 def _to_prompt_item(item: Union[str, ImagePrompt]) -> Dict[str, str]:
     if isinstance(item, str):
         return {"type": "text", "data": item}
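The visible fragment of `_to_prompt_item` shows the wire format for text items: a tagged dict, so one JSON array can carry mixed text and image entries. The image branch is cut off by the diff view, so this sketch covers only the text case:

```python
from typing import Dict

def to_text_item(item: str) -> Dict[str, str]:
    # Mirrors the text branch of _to_prompt_item shown above.
    return {"type": "text", "data": item}

print(to_text_item("An apple a day"))
# {'type': 'text', 'data': 'An apple a day'}
```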

