Commit

Support instructable_embed endpoint
WoytenAA committed Nov 18, 2024
1 parent 709b434 commit ca960bf
Showing 8 changed files with 287 additions and 57 deletions.
4 changes: 4 additions & 0 deletions Changelog.md
@@ -1,5 +1,9 @@
# Changelog

## 7.6.0

- Add `instructable_embed` to `Client` and `AsyncClient`

## 7.5.1

- Add fallback mechanism for figuring out the version locally.
1 change: 1 addition & 0 deletions README.md
@@ -41,6 +41,7 @@ async with AsyncClient(token=os.environ["AA_TOKEN"]) as client:
maximum_tokens=64,
)
response = client.complete_with_streaming(request, model="luminous-base")

async for stream_item in response:
print(stream_item)
```
115 changes: 114 additions & 1 deletion aleph_alpha_client/aleph_alpha_client.py
@@ -37,7 +37,13 @@
CompletionResponseStreamItem,
stream_item_from_json,
)
from aleph_alpha_client.chat import (
ChatRequest,
ChatResponse,
ChatStreamChunk,
Usage,
stream_chat_item_from_json,
)
from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse
from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse
from aleph_alpha_client.detokenization import (
@@ -50,6 +56,8 @@
EmbeddingRequest,
EmbeddingResponse,
EmbeddingVector,
InstructableEmbeddingRequest,
InstructableEmbeddingResponse,
SemanticEmbeddingRequest,
SemanticEmbeddingResponse,
)
@@ -104,6 +112,7 @@ def _check_api_version(version_str: str):
TokenizationRequest,
DetokenizationRequest,
SemanticEmbeddingRequest,
InstructableEmbeddingRequest,
BatchSemanticEmbeddingRequest,
QaRequest,
SummarizationRequest,
@@ -514,6 +523,58 @@ def batch_semantic_embed(
num_tokens_prompt_total=num_tokens_prompt_total,
)

def instructable_embed(
self,
request: InstructableEmbeddingRequest,
model: str,
) -> InstructableEmbeddingResponse:
"""Embeds a text and returns vectors that can be used for classification according to a given instruction.
Parameters:
request (InstructableEmbeddingRequest, required):
Parameters for the requested instructable embedding.
model (string, required):
Name of model to use. A model name refers to a model architecture (number of parameters among others).
The latest version of the model is always used.
Examples:
>>> # function for salutation embedding
>>> def embed_salutation(text: str):
# Create an embedding request with a given instruction
request = InstructableEmbeddingRequest(
input=Prompt.from_text(text),
instruction="Represent the text to query a database of salutations"
)
# create the embedding
result = client.instructable_embed(request, model=model_name)
return result.embedding
>>>
>>> # function to calculate similarity
>>> def cosine_similarity(v1: Sequence[float], v2: Sequence[float]) -> float:
"compute cosine similarity of v1 to v2: (v1 dot v2)/{||v1||*||v2||)"
sumxx, sumxy, sumyy = 0, 0, 0
for i in range(len(v1)):
x = v1[i]; y = v2[i]
sumxx += x*x
sumyy += y*y
sumxy += x*y
return sumxy/math.sqrt(sumxx*sumyy)
>>>
>>> # define the texts
>>> text_a = "Hello"
>>> text_b = "Good morning"
>>>
>>> # show the similarity
>>> print(cosine_similarity(embed_salutation(text_a), embed_salutation(text_b)))
"""
response = self._post_request(
"instructable_embed",
request,
model,
)
return InstructableEmbeddingResponse.from_json(response)

def evaluate(
self,
request: EvaluationRequest,
@@ -1206,6 +1267,58 @@ async def batch_semantic_embed(
num_tokens_prompt_total=num_tokens_prompt_total,
)

async def instructable_embed(
self,
request: InstructableEmbeddingRequest,
model: str,
) -> InstructableEmbeddingResponse:
"""Embeds a text and returns vectors that can be used for classification according to a given instruction.
Parameters:
request (InstructableEmbeddingRequest, required):
Parameters for the requested instructable embedding.
model (string, required):
Name of model to use. A model name refers to a model architecture (number of parameters among others).
The latest version of the model is always used.
Examples:
>>> # function for salutation embedding
>>> async def embed_salutation(text: str):
# Create an embedding request with a given instruction
request = InstructableEmbeddingRequest(
input=Prompt.from_text(text),
instruction="Represent the text to query a database of salutations"
)
# create the embedding
result = await client.instructable_embed(request, model=model_name)
return result.embedding
>>>
>>> # function to calculate similarity
>>> def cosine_similarity(v1: Sequence[float], v2: Sequence[float]) -> float:
"compute cosine similarity of v1 to v2: (v1 dot v2)/{||v1||*||v2||)"
sumxx, sumxy, sumyy = 0, 0, 0
for i in range(len(v1)):
x = v1[i]; y = v2[i]
sumxx += x*x
sumyy += y*y
sumxy += x*y
return sumxy/math.sqrt(sumxx*sumyy)
>>>
>>> # define the texts
>>> text_a = "Hello"
>>> text_b = "Good morning"
>>>
>>> # show the similarity
>>> print(cosine_similarity(await embed_salutation(text_a), await embed_salutation(text_b)))
"""
response = await self._post_request(
"instructable_embed",
request,
model,
)
return InstructableEmbeddingResponse.from_json(response)

async def evaluate(
self,
request: EvaluationRequest,
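For orientation, here is a minimal end-to-end sketch of calling the new endpoint outside a docstring. The model name is a placeholder (this commit does not name a concrete instructable-embedding model), and the sketch assumes `InstructableEmbeddingRequest` is re-exported from the package root like the other request types:

    import os

    from aleph_alpha_client import Client, InstructableEmbeddingRequest, Prompt

    # Placeholder model name; substitute an instructable-embedding model
    # available to your account.
    MODEL_NAME = "pharia-1-embedding-4608-control"

    client = Client(token=os.environ["AA_TOKEN"])
    request = InstructableEmbeddingRequest(
        input=Prompt.from_text("Hello"),
        instruction="Represent the text to query a database of salutations",
    )
    response = client.instructable_embed(request, model=MODEL_NAME)
    print(len(response.embedding))  # dimensionality of the returned vector
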
23 changes: 15 additions & 8 deletions aleph_alpha_client/chat.py
@@ -5,6 +5,7 @@

class Role(str, Enum):
"""A role used for a message in a chat."""

User = "user"
Assistant = "assistant"
System = "system"
@@ -14,14 +15,15 @@ class Role(str, Enum):
class Message:
"""
Describes a message in a chat.
Parameters:
role (Role, required):
The role of the message.
content (str, required):
The content of the message.
"""

role: Role
content: str

@@ -41,6 +43,7 @@ class StreamOptions:
"""
Additional options to affect the streaming behavior.
"""

# If set, an additional chunk will be streamed before the data: [DONE] message.
# The usage field on this chunk shows the token usage statistics for the entire
# request, and the choices field will always be an empty array.
@@ -51,10 +54,11 @@
class ChatRequest:
"""
Describes a chat request.
Only supports a subset of the parameters of `CompletionRequest` for simplicity.
See `CompletionRequest` for documentation on the parameters.
"""

model: str
messages: List[Message]
maximum_tokens: Optional[int] = None
@@ -77,6 +81,7 @@ class ChatResponse:
As the `ChatRequest` does not support the `n` parameter (allowing for multiple return values),
the `ChatResponse` assumes there to be only one choice.
"""

finish_reason: str
message: Message

@@ -89,7 +94,6 @@ def from_json(json: Dict[str, Any]) -> "ChatResponse":
)



@dataclass(frozen=True)
class Usage:
"""
@@ -98,6 +102,7 @@ class Usage:
When streaming is enabled, this field will be null by default.
To include an additional usage-only message in the response stream, set stream_options.include_usage to true.
"""

# Number of tokens in the generated completion.
completion_tokens: int

@@ -112,11 +117,10 @@ def from_json(json: Dict[str, Any]) -> "Usage":
return Usage(
completion_tokens=json["completion_tokens"],
prompt_tokens=json["prompt_tokens"],
total_tokens=json["total_tokens"],
)



@dataclass(frozen=True)
class ChatStreamChunk:
"""
@@ -128,7 +132,8 @@ class ChatStreamChunk:
role (Role, optional):
The role of the current chat completion. Will be assistant for the first chunk of every completion stream and missing for the remaining chunks.
"""
"""

content: str
role: Optional[Role]

@@ -146,8 +151,10 @@ def from_json(json: Dict[str, Any]) -> Optional["ChatStreamChunk"]:
)


def stream_chat_item_from_json(
json: Dict[str, Any],
) -> Union[Usage, ChatStreamChunk, None]:
if (usage := json.get("usage")) is not None:
return Usage.from_json(usage)

return ChatStreamChunk.from_json(json)
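
To make the dispatch above concrete, here is a small sketch feeding `stream_chat_item_from_json` two hand-written payloads shaped like the API's streaming chunks (the payloads are illustrative, not taken from this commit): a delta chunk parses to a `ChatStreamChunk`, while a usage-only chunk, sent last when `stream_options.include_usage` is set, parses to a `Usage`.

    from aleph_alpha_client.chat import stream_chat_item_from_json

    # A streamed delta chunk carrying a piece of the assistant's message.
    delta_chunk = {"choices": [{"delta": {"role": "assistant", "content": "Hi"}}]}

    # A usage-only chunk: "usage" is set and "choices" is an empty array.
    usage_chunk = {
        "usage": {"completion_tokens": 1, "prompt_tokens": 5, "total_tokens": 6},
        "choices": [],
    }

    for payload in (delta_chunk, usage_chunk):
        item = stream_chat_item_from_json(payload)
        print(type(item).__name__)  # ChatStreamChunk, then Usage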
