diff --git a/Changelog.md b/Changelog.md index 77283a2..6a0c0e8 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,5 +1,12 @@ # Changelog +## 2.2.1 + +### Bugfix + +* Restore original error handling of HTTP status codes to before 2.2.0 +* Add dedicated exception BusyError for status code 503 + ## 2.2.0 ### New feature diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index d2545dc..46cb105 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -23,6 +23,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) +class BusyError(Exception): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + class AlephAlphaClient: def __init__( self, host, token=None, email=None, password=None, request_timeout_seconds=180 @@ -38,6 +43,7 @@ def __init__( backoff_factor=0.1, status_forcelist=[408, 429, 500, 502, 503, 504], allowed_methods=["POST", "GET"], + raise_on_status=False, ) adapter = HTTPAdapter(max_retries=retry_strategy) self.requests_session = requests.Session() @@ -57,8 +63,7 @@ def __init__( def get_version(self): response = self.get_request(self.host + "version") - response.raise_for_status() - return response.text + return self._translate_errors(response).text def get_token(self, email, password): response = self.post_request( @@ -94,7 +99,7 @@ def available_models(self): response = self.get_request( self.host + "models_available", headers=self.request_headers ) - return self._translate_errors(response) + return self._translate_errors(response).json() def tokenize( self, model: str, prompt: str, tokens: bool = True, token_ids: bool = True @@ -113,7 +118,7 @@ def tokenize( headers=self.request_headers, json=payload, ) - return self._translate_errors(response) + return self._translate_errors(response).json() def detokenize(self, model: str, token_ids: List[int]): """ @@ -125,7 +130,7 @@ def detokenize(self, model: str, token_ids: List[int]): headers=self.request_headers, json=payload, ) - return self._translate_errors(response) + return self._translate_errors(response).json() def complete( self, @@ -299,7 +304,7 @@ def complete( headers=self.request_headers, json=payload, ) - response_json = self._translate_errors(response) + response_json = self._translate_errors(response).json() if response_json.get("optimized_prompt") is not None: # Return a message to the user that we optimized their prompt print( @@ -376,7 +381,7 @@ def embed( response = self.post_request( self.host + "embed", headers=self.request_headers, json=payload ) - return self._translate_errors(response) + return self._translate_errors(response).json() def semantic_embed( self, @@ -417,7 +422,7 @@ def semantic_embed( response = self.post_request( self.host + "semantic_embed", headers=self.request_headers, json=payload ) - return self._translate_errors(response) + return self._translate_errors(response).json() def evaluate( self, @@ -459,7 +464,7 @@ def evaluate( response = self.post_request( self.host + "evaluate", headers=self.request_headers, json=payload ) - return self._translate_errors(response) + return self._translate_errors(response).json() def qa( self, @@ -537,7 +542,7 @@ def qa( headers=self.request_headers, json=payload, ) - response_json = self._translate_errors(response) + response_json = self._translate_errors(response).json() return response_json def _explain( @@ -554,20 +559,22 @@ def _explain( response = self.post_request( f"{self.host}explain", headers=self.request_headers, json=body ) - return self._translate_errors(response) + return self._translate_errors(response).json() @staticmethod - def _translate_errors(response: Response): + def _translate_errors(response: Response) -> Response: if response.status_code == 200: - return response.json() + return response else: if response.status_code == 400: - raise ValueError(response.status_code, response.json()) + raise ValueError(response.status_code, response.text) elif response.status_code == 401: - raise PermissionError(response.status_code, response.json()) + raise PermissionError(response.status_code, response.text) elif response.status_code == 402: - raise QuotaError(response.status_code, response.json()) + raise QuotaError(response.status_code, response.text) elif response.status_code == 408: - raise TimeoutError(response.status_code, response.json()) + raise TimeoutError(response.status_code, response.text) + elif response.status_code == 503: + raise BusyError(response.status_code, response.text) else: - raise RuntimeError(response.status_code, response.json()) + raise RuntimeError(response.status_code, response.text) diff --git a/aleph_alpha_client/version.py b/aleph_alpha_client/version.py index 8a124bf..b19ee4b 100644 --- a/aleph_alpha_client/version.py +++ b/aleph_alpha_client/version.py @@ -1 +1 @@ -__version__ = "2.2.0" +__version__ = "2.2.1" diff --git a/tests/test_errors.py b/tests/test_errors.py index 2dc34bf..3ede89c 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,8 +1,9 @@ import time -from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient +from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient, BusyError import pytest import requests from tests.common import client, model_name +from pytest_httpserver import HTTPServer @pytest.mark.parametrize( @@ -387,7 +388,7 @@ def httpserver_listen_address(): return ("127.0.0.1", 8000) -def test_timeout(httpserver): +def test_timeout(httpserver: HTTPServer): def handler(foo): time.sleep(2) @@ -404,13 +405,13 @@ def test_retry_on_503(httpserver): httpserver.expect_request("/version").respond_with_data("busy", status=503) """Ensures Timeouts works. AlephAlphaClient constructor calls version endpoint.""" - with pytest.raises(requests.exceptions.RetryError): + with pytest.raises(BusyError): AlephAlphaClient(host="http://localhost:8000/", token="AA_TOKEN") def test_retry_on_408(httpserver): - httpserver.expect_request("/version").respond_with_data("busy", status=408) + httpserver.expect_request("/version").respond_with_data("timeout", status=408) """Ensures Timeouts works. AlephAlphaClient constructor calls version endpoint.""" - with pytest.raises(requests.exceptions.RetryError): + with pytest.raises(TimeoutError): AlephAlphaClient(host="http://localhost:8000/", token="AA_TOKEN")