Skip to content

Commit

Permalink
Merge pull request #45 from Aleph-Alpha/translate-errors
Browse files Browse the repository at this point in the history
Translate errors instead of throwing RetryErrors
  • Loading branch information
ahartel authored Aug 11, 2022
2 parents af0cb94 + 4a2f8af commit 4e04f45
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 24 deletions.
7 changes: 7 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## 2.2.1

### Bugfix

* Restore original error handling of HTTP status codes to before 2.2.0
* Add dedicated exception BusyError for status code 503

## 2.2.0

### New feature
Expand Down
43 changes: 25 additions & 18 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


class BusyError(Exception):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


class AlephAlphaClient:
def __init__(
self, host, token=None, email=None, password=None, request_timeout_seconds=180
Expand All @@ -38,6 +43,7 @@ def __init__(
backoff_factor=0.1,
status_forcelist=[408, 429, 500, 502, 503, 504],
allowed_methods=["POST", "GET"],
raise_on_status=False,
)
adapter = HTTPAdapter(max_retries=retry_strategy)
self.requests_session = requests.Session()
Expand All @@ -57,8 +63,7 @@ def __init__(

def get_version(self):
response = self.get_request(self.host + "version")
response.raise_for_status()
return response.text
return self._translate_errors(response).text

def get_token(self, email, password):
response = self.post_request(
Expand Down Expand Up @@ -94,7 +99,7 @@ def available_models(self):
response = self.get_request(
self.host + "models_available", headers=self.request_headers
)
return self._translate_errors(response)
return self._translate_errors(response).json()

def tokenize(
self, model: str, prompt: str, tokens: bool = True, token_ids: bool = True
Expand All @@ -113,7 +118,7 @@ def tokenize(
headers=self.request_headers,
json=payload,
)
return self._translate_errors(response)
return self._translate_errors(response).json()

def detokenize(self, model: str, token_ids: List[int]):
"""
Expand All @@ -125,7 +130,7 @@ def detokenize(self, model: str, token_ids: List[int]):
headers=self.request_headers,
json=payload,
)
return self._translate_errors(response)
return self._translate_errors(response).json()

def complete(
self,
Expand Down Expand Up @@ -299,7 +304,7 @@ def complete(
headers=self.request_headers,
json=payload,
)
response_json = self._translate_errors(response)
response_json = self._translate_errors(response).json()
if response_json.get("optimized_prompt") is not None:
# Return a message to the user that we optimized their prompt
print(
Expand Down Expand Up @@ -376,7 +381,7 @@ def embed(
response = self.post_request(
self.host + "embed", headers=self.request_headers, json=payload
)
return self._translate_errors(response)
return self._translate_errors(response).json()

def semantic_embed(
self,
Expand Down Expand Up @@ -417,7 +422,7 @@ def semantic_embed(
response = self.post_request(
self.host + "semantic_embed", headers=self.request_headers, json=payload
)
return self._translate_errors(response)
return self._translate_errors(response).json()

def evaluate(
self,
Expand Down Expand Up @@ -459,7 +464,7 @@ def evaluate(
response = self.post_request(
self.host + "evaluate", headers=self.request_headers, json=payload
)
return self._translate_errors(response)
return self._translate_errors(response).json()

def qa(
self,
Expand Down Expand Up @@ -537,7 +542,7 @@ def qa(
headers=self.request_headers,
json=payload,
)
response_json = self._translate_errors(response)
response_json = self._translate_errors(response).json()
return response_json

def _explain(
Expand All @@ -554,20 +559,22 @@ def _explain(
response = self.post_request(
f"{self.host}explain", headers=self.request_headers, json=body
)
return self._translate_errors(response)
return self._translate_errors(response).json()

@staticmethod
def _translate_errors(response: Response):
def _translate_errors(response: Response) -> Response:
if response.status_code == 200:
return response.json()
return response
else:
if response.status_code == 400:
raise ValueError(response.status_code, response.json())
raise ValueError(response.status_code, response.text)
elif response.status_code == 401:
raise PermissionError(response.status_code, response.json())
raise PermissionError(response.status_code, response.text)
elif response.status_code == 402:
raise QuotaError(response.status_code, response.json())
raise QuotaError(response.status_code, response.text)
elif response.status_code == 408:
raise TimeoutError(response.status_code, response.json())
raise TimeoutError(response.status_code, response.text)
elif response.status_code == 503:
raise BusyError(response.status_code, response.text)
else:
raise RuntimeError(response.status_code, response.json())
raise RuntimeError(response.status_code, response.text)
2 changes: 1 addition & 1 deletion aleph_alpha_client/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.2.0"
__version__ = "2.2.1"
11 changes: 6 additions & 5 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import time
from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient
from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient, BusyError
import pytest
import requests
from tests.common import client, model_name
from pytest_httpserver import HTTPServer


@pytest.mark.parametrize(
Expand Down Expand Up @@ -387,7 +388,7 @@ def httpserver_listen_address():
return ("127.0.0.1", 8000)


def test_timeout(httpserver):
def test_timeout(httpserver: HTTPServer):
def handler(foo):
time.sleep(2)

Expand All @@ -404,13 +405,13 @@ def test_retry_on_503(httpserver):
httpserver.expect_request("/version").respond_with_data("busy", status=503)

"""Ensures Timeouts works. AlephAlphaClient constructor calls version endpoint."""
with pytest.raises(requests.exceptions.RetryError):
with pytest.raises(BusyError):
AlephAlphaClient(host="http://localhost:8000/", token="AA_TOKEN")


def test_retry_on_408(httpserver):
httpserver.expect_request("/version").respond_with_data("busy", status=408)
httpserver.expect_request("/version").respond_with_data("timeout", status=408)

"""Ensures Timeouts works. AlephAlphaClient constructor calls version endpoint."""
with pytest.raises(requests.exceptions.RetryError):
with pytest.raises(TimeoutError):
AlephAlphaClient(host="http://localhost:8000/", token="AA_TOKEN")

0 comments on commit 4e04f45

Please sign in to comment.