From 61307e8a3935519c7493ec187c6221378cb7d65d Mon Sep 17 00:00:00 2001
From: Andreas Hartel
Date: Mon, 13 Feb 2023 14:50:10 +0100
Subject: [PATCH] address review comments

---
 aleph_alpha_client/aleph_alpha_client.py |  5 ++---
 tests/test_clients.py                    |  2 +-
 tests/test_tokenize.py                   | 19 +++++++++++--------
 3 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py
index ba9612d..753afff 100644
--- a/aleph_alpha_client/aleph_alpha_client.py
+++ b/aleph_alpha_client/aleph_alpha_client.py
@@ -943,14 +943,13 @@ def _build_json_body(
             json_body["hosting"] = self.hosting
         return json_body
 
-    def models(self):
+    def models(self) -> Mapping[str, Any]:
         """
         Queries all models which are currently available.
 
         For documentation of the response, see https://docs.aleph-alpha.com/api/available-models/
         """
         response = self._get_request("models_available")
-        _raise_for_status(response.status_code, response.text)
         return response.json()
 
     def complete(
@@ -1402,7 +1401,7 @@ def _build_json_body(
             json_body["hosting"] = self.hosting
         return json_body
 
-    async def models(self):
+    async def models(self) -> Mapping[str, Any]:
         """
         Queries all models which are currently available.
 
diff --git a/tests/test_clients.py b/tests/test_clients.py
index 45700c2..b613731 100644
--- a/tests/test_clients.py
+++ b/tests/test_clients.py
@@ -74,4 +74,4 @@ async def test_available_models_async_client(
     async_client: AsyncClient, model_name: str
 ):
     models = await async_client.models()
-    assert model_name in [model["name"] for model in models]
+    assert model_name in {model["name"] for model in models}
diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py
index f72d091..bc4bb0a 100644
--- a/tests/test_tokenize.py
+++ b/tests/test_tokenize.py
@@ -53,7 +53,7 @@ def test_tokenize_with_client_against_model(client: AlephAlphaClient, model_name
     assert len(response["token_ids"]) == 1
 
 
-def test_offline_tokenizer_sync(sync_client: Client, model_name: str):
+def test_offline_tokenize_sync(sync_client: Client, model_name: str):
     prompt = "Hello world"
 
     tokenizer = sync_client.tokenizer(model_name)
@@ -69,17 +69,20 @@ def test_offline_tokenizer_sync(sync_client: Client, model_name: str):
     assert offline_tokenization.ids == online_tokenization_response.token_ids
     assert offline_tokenization.tokens == online_tokenization_response.tokens
 
-    detokenization_request = DetokenizationRequest(
-        token_ids=online_tokenization_response.token_ids
-    )
+
+def test_offline_detokenize_sync(sync_client: Client, model_name: str):
+    prompt = "Hello world"
+
+    tokenizer = sync_client.tokenizer(model_name)
+    offline_tokenization = tokenizer.encode(prompt)
+    offline_detokenization = tokenizer.decode(offline_tokenization.ids)
+
+    detokenization_request = DetokenizationRequest(token_ids=offline_tokenization.ids)
     online_detokenization_response = sync_client.detokenize(
         detokenization_request, model_name
     )
 
-    assert (
-        tokenizer.decode(offline_tokenization.ids)
-        == online_detokenization_response.result
-    )
+    assert offline_detokenization == online_detokenization_response.result
 
 
 async def test_offline_tokenizer_async(async_client: AsyncClient, model_name: str):
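
A minimal usage sketch of the change above, not part of the patch: it mirrors the
membership check in tests/test_clients.py against the now-annotated models() call.
The environment variable name AA_TOKEN and the model name "luminous-base" are
illustrative assumptions, and it presumes Client is constructed with an API token.

    import os

    from aleph_alpha_client import Client

    # Assumed setup: a sync Client authenticated via an API token.
    client = Client(token=os.environ["AA_TOKEN"])

    # Per the diff, models() is now annotated and no longer calls
    # _raise_for_status() explicitly before returning the parsed JSON.
    models = client.models()

    # The updated test asserts membership against a set of model names.
    available_names = {model["name"] for model in models}
    assert "luminous-base" in available_names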