Merge pull request #85 from Aleph-Alpha/support-offline-tokenizer
Support offline tokenizer
ahartel authored Feb 13, 2023
2 parents 3665175 + 3ad7c65 commit 77857b7
Showing 6 changed files with 133 additions and 10 deletions.
7 changes: 7 additions & 0 deletions Changelog.md
@@ -1,5 +1,12 @@
# Changelog

## 2.12.0

- Introduce offline tokenizer
- Add method `models` to Client and AsyncClient to list available models
- Fix docstrings for `complete` methods with respect to Prompt construction
- Minor docstring fix for `evaluate` methods

## 2.11.1

- Fix `complete` in deprecated client: pass None-lists as empty lists
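
For orientation before the code diff, here is a minimal sketch of how the two headline 2.12.0 features might be used together. `Client`, `models`, and `tokenizer` come from the changes below; the token value is a placeholder.

```python
from aleph_alpha_client import Client

client = Client(token="AA_TOKEN")  # placeholder token

# New in 2.12.0: list the models currently available.
available = client.models()
print([model["name"] for model in available])

# New in 2.12.0: fetch the tokenizer a model was trained with,
# then tokenize locally without further API calls.
tokenizer = client.tokenizer("luminous-extended")
encoding = tokenizer.encode("An apple a day")
print(encoding.ids, encoding.tokens)
```
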
70 changes: 62 additions & 8 deletions aleph_alpha_client/aleph_alpha_client.py
@@ -1,11 +1,13 @@
import logging
from tokenizers import Tokenizer # type: ignore
from types import TracebackType
from typing import Any, List, Mapping, Optional, Dict, Sequence, Tuple, Type, Union
import warnings

import aiohttp
from aiohttp import ClientResponse
from aiohttp_retry import RetryClient, ExponentialRetry
import requests
from requests import Response
from requests.adapters import HTTPAdapter
from requests.structures import CaseInsensitiveDict
from urllib3.util.retry import Retry
@@ -894,10 +896,13 @@ def __init__(

def get_version(self) -> str:
"""Gets version of the AlephAlpha HTTP API."""
response = self.session.get(self.host + "version")
return self._get_request("version").text

def _get_request(self, endpoint: str) -> Response:
response = self.session.get(self.host + endpoint)
if not response.ok:
_raise_for_status(response.status_code, response.text)
return response.text
return response

def _post_request(
self,
@@ -938,6 +943,15 @@ def _build_json_body(
json_body["hosting"] = self.hosting
return json_body

def models(self) -> List[Mapping[str, Any]]:
"""
Queries all models which are currently available.
For documentation of the response, see https://docs.aleph-alpha.com/api/available-models/
"""
response = self._get_request("models_available")
return response.json()

def complete(
self,
request: CompletionRequest,
@@ -955,7 +969,7 @@ def complete(
Examples:
>>> # create a prompt
>>> prompt = Prompt("An apple a day, ")
>>> prompt = Prompt.from_text("An apple a day, ")
>>>
>>> # create a completion request
>>> request = CompletionRequest(
@@ -1121,7 +1135,7 @@ def evaluate(
Examples:
>>> request = EvaluationRequest(
prompt=Prompt.from_text("hello"), completion_expected="world"
prompt=Prompt.from_text("hello"), completion_expected=" world"
)
>>> response = client.evaluate(request, model=model_name)
"""
@@ -1219,6 +1233,15 @@ def _search(
response = self._post_request("search", request, None)
return SearchResponse.from_json(response)

def tokenizer(self, model: str) -> Tokenizer:
"""Returns a Tokenizer instance with the settings that were used to train the model.
Examples:
>>> tokenizer = client.tokenizer(model="luminous-extended")
>>> tokenized_prompt = tokenizer.encode("Hello world")
"""
return Tokenizer.from_str(self._get_request(f"models/{model}/tokenizer").text)


class AsyncClient:
"""
@@ -1324,13 +1347,26 @@ async def __aexit__(

async def get_version(self) -> str:
"""Gets version of the AlephAlpha HTTP API."""
return await self._get_request_text("version")

async def _get_request_text(self, endpoint: str) -> str:
async with self.session.get(
self.host + "version",
self.host + endpoint,
) as response:
if not response.ok:
_raise_for_status(response.status, await response.text())
return await response.text()

async def _get_request_json(
self, endpoint: str
) -> Union[List[Mapping[str, Any]], Mapping[str, Any]]:
async with self.session.get(
self.host + endpoint,
) as response:
if not response.ok:
_raise_for_status(response.status, await response.text())
return await response.json()

async def _post_request(
self,
endpoint: str,
@@ -1367,6 +1403,14 @@ def _build_json_body(
json_body["hosting"] = self.hosting
return json_body

async def models(self) -> List[Mapping[str, Any]]:
"""
Queries all models which are currently available.
For documentation of the response, see https://docs.aleph-alpha.com/api/available-models/
"""
return await self._get_request_json("models_available") # type: ignore

async def complete(
self,
request: CompletionRequest,
@@ -1384,7 +1428,7 @@ async def complete(
Examples:
>>> # create a prompt
>>> prompt = Prompt("An apple a day, ")
>>> prompt = Prompt.from_text("An apple a day, ")
>>>
>>> # create a completion request
>>> request = CompletionRequest(
@@ -1549,7 +1593,7 @@ async def evaluate(
Examples:
>>> request = EvaluationRequest(
prompt=Prompt.from_text("hello"), completion_expected="world"
prompt=Prompt.from_text("hello"), completion_expected=" world"
)
>>> response = await client.evaluate(request, model=model_name)
"""
@@ -1646,3 +1690,13 @@ async def _search(
"""
response = await self._post_request("search", request, None)
return SearchResponse.from_json(response)

async def tokenizer(self, model: str) -> Tokenizer:
"""Returns a Tokenizer instance with the settings that were used to train the model.
Examples:
>>> tokenizer = await client.tokenizer(model="luminous-extended")
>>> tokenized_prompt = tokenizer.encode("Hello world")
"""
response = await self._get_request_text(f"models/{model}/tokenizer")
return Tokenizer.from_str(response)
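
The async variants mirror the sync API one-to-one. A hedged usage sketch, assuming `AsyncClient` is entered as an async context manager as it is in the tests below; the token is a placeholder:

```python
import asyncio

from aleph_alpha_client import AsyncClient


async def main() -> None:
    async with AsyncClient(token="AA_TOKEN") as client:  # placeholder token
        # List available models (new in 2.12.0).
        models = await client.models()
        print({model["name"] for model in models})

        # Fetch the tokenizer once, then encode offline.
        tokenizer = await client.tokenizer("luminous-extended")
        print(tokenizer.encode("Hello world").ids)


asyncio.run(main())
```
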
2 changes: 1 addition & 1 deletion aleph_alpha_client/version.py
@@ -1 +1 @@
__version__ = "2.11.1"
__version__ = "2.12.0"
1 change: 1 addition & 0 deletions setup.py
@@ -46,6 +46,7 @@ def version():
"aiohttp >= 3.8.3",
"aiodns >= 3.0.0",
"aiohttp-retry >= 2.8.3",
"tokenizers >= 0.13.2",
],
tests_require=tests_require,
extras_require={
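
The new `tokenizers >= 0.13.2` dependency carries the offline work: the client performs a single request for the tokenizer JSON, and `tokenizers.Tokenizer` handles everything afterwards locally. A sketch of caching that JSON for fully offline reuse — the cache path and persistence pattern are illustrative, not part of this change:

```python
from pathlib import Path

from tokenizers import Tokenizer

from aleph_alpha_client import Client

CACHE = Path("luminous-extended.tokenizer.json")  # hypothetical cache file

if CACHE.exists():
    # Fully offline: rebuild the tokenizer from the cached JSON.
    tokenizer = Tokenizer.from_str(CACHE.read_text())
else:
    # One network call, then persist the settings for later sessions.
    client = Client(token="AA_TOKEN")  # placeholder token
    tokenizer = client.tokenizer("luminous-extended")
    CACHE.write_text(tokenizer.to_str())

ids = tokenizer.encode("Hello world").ids
print(ids, tokenizer.decode(ids))
```
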
16 changes: 15 additions & 1 deletion tests/test_clients.py
@@ -8,7 +8,7 @@
CompletionResult,
)
from aleph_alpha_client.prompt import Prompt
from tests.common import model_name
from tests.common import model_name, sync_client, async_client


@pytest.mark.system_test
@@ -61,3 +61,17 @@ async def test_nice_flag_on_async_client(httpserver: HTTPServer):
host=httpserver.url_for(""), token="AA_TOKEN", nice=True
) as client:
await client.complete(request, model="luminous")


@pytest.mark.system_test
def test_available_models_sync_client(sync_client: Client, model_name: str):
models = sync_client.models()
assert model_name in {model["name"] for model in models}


@pytest.mark.system_test
async def test_available_models_async_client(
async_client: AsyncClient, model_name: str
):
models = await async_client.models()
assert model_name in {model["name"] for model in models}
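
These tests lean on `sync_client` and `async_client` fixtures imported from `tests.common`, which is not shown in this diff. A hypothetical minimal version of those fixtures, assuming pytest-asyncio and an `AA_TOKEN` environment variable — the real module may differ:

```python
import os

import pytest

from aleph_alpha_client import AsyncClient, Client


@pytest.fixture
def sync_client() -> Client:
    # Assumes the API token is provided via the environment.
    return Client(token=os.environ["AA_TOKEN"])


@pytest.fixture
async def async_client():
    # Entered as an async context manager so the session is closed after each test.
    async with AsyncClient(token=os.environ["AA_TOKEN"]) as client:
        yield client
```
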
47 changes: 47 additions & 0 deletions tests/test_tokenize.py
@@ -1,7 +1,9 @@
import pytest
from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient, AsyncClient, Client
from aleph_alpha_client.aleph_alpha_model import AlephAlphaModel
from aleph_alpha_client.prompt import Prompt
from aleph_alpha_client.tokenization import TokenizationRequest
from aleph_alpha_client.detokenization import DetokenizationRequest

from tests.common import (
sync_client,
@@ -49,3 +51,48 @@ def test_tokenize_with_client_against_model(client: AlephAlphaClient, model_name

assert len(response["tokens"]) == 1
assert len(response["token_ids"]) == 1


def test_offline_tokenize_sync(sync_client: Client, model_name: str):
prompt = "Hello world"

tokenizer = sync_client.tokenizer(model_name)
offline_tokenization = tokenizer.encode(prompt)

tokenization_request = TokenizationRequest(
prompt=prompt, token_ids=True, tokens=True
)
online_tokenization_response = sync_client.tokenize(
tokenization_request, model_name
)

assert offline_tokenization.ids == online_tokenization_response.token_ids
assert offline_tokenization.tokens == online_tokenization_response.tokens


def test_offline_detokenize_sync(sync_client: Client, model_name: str):
prompt = "Hello world"

tokenizer = sync_client.tokenizer(model_name)
offline_tokenization = tokenizer.encode(prompt)
offline_detokenization = tokenizer.decode(offline_tokenization.ids)

detokenization_request = DetokenizationRequest(token_ids=offline_tokenization.ids)
online_detokenization_response = sync_client.detokenize(
detokenization_request, model_name
)

assert offline_detokenization == online_detokenization_response.result


async def test_offline_tokenizer_async(async_client: AsyncClient, model_name: str):
prompt = "Hello world"

tokenizer = await async_client.tokenizer(model_name)
offline_tokenization = tokenizer.encode(prompt)

request = TokenizationRequest(prompt=prompt, token_ids=True, tokens=True)
response = await async_client.tokenize(request, model_name)

assert offline_tokenization.ids == response.token_ids
assert offline_tokenization.tokens == response.tokens
