Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ahartel committed Feb 13, 2023
1 parent 8584bda commit 61307e8
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
5 changes: 2 additions & 3 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,14 +943,13 @@ def _build_json_body(
json_body["hosting"] = self.hosting
return json_body

def models(self):
def models(self) -> Mapping[str, Any]:
"""
Queries all models which are currently available.
For documentation of the response, see https://docs.aleph-alpha.com/api/available-models/
"""
response = self._get_request("models_available")
_raise_for_status(response.status_code, response.text)
return response.json()

def complete(
Expand Down Expand Up @@ -1402,7 +1401,7 @@ def _build_json_body(
json_body["hosting"] = self.hosting
return json_body

async def models(self):
async def models(self) -> Mapping[str, Any]:
"""
Queries all models which are currently available.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,4 @@ async def test_available_models_async_client(
async_client: AsyncClient, model_name: str
):
models = await async_client.models()
assert model_name in [model["name"] for model in models]
assert model_name in {model["name"] for model in models}
19 changes: 11 additions & 8 deletions tests/test_tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_tokenize_with_client_against_model(client: AlephAlphaClient, model_name
assert len(response["token_ids"]) == 1


def test_offline_tokenizer_sync(sync_client: Client, model_name: str):
def test_offline_tokenize_sync(sync_client: Client, model_name: str):
prompt = "Hello world"

tokenizer = sync_client.tokenizer(model_name)
Expand All @@ -69,17 +69,20 @@ def test_offline_tokenizer_sync(sync_client: Client, model_name: str):
assert offline_tokenization.ids == online_tokenization_response.token_ids
assert offline_tokenization.tokens == online_tokenization_response.tokens

detokenization_request = DetokenizationRequest(
token_ids=online_tokenization_response.token_ids
)

def test_offline_detokenize_sync(sync_client: Client, model_name: str):
prompt = "Hello world"

tokenizer = sync_client.tokenizer(model_name)
offline_tokenization = tokenizer.encode(prompt)
offline_detokenization = tokenizer.decode(offline_tokenization.ids)

detokenization_request = DetokenizationRequest(token_ids=offline_tokenization.ids)
online_detokenization_response = sync_client.detokenize(
detokenization_request, model_name
)

assert (
tokenizer.decode(offline_tokenization.ids)
== online_detokenization_response.result
)
assert offline_detokenization == online_detokenization_response.result


async def test_offline_tokenizer_async(async_client: AsyncClient, model_name: str):
Expand Down

0 comments on commit 61307e8

Please sign in to comment.