Add retry on timeout
... and remove references to cloud hosting
ahartel authored Aug 10, 2022
2 parents 045b6f0 + 7008dfe commit af0cb94
Showing 6 changed files with 81 additions and 30 deletions.
6 changes: 6 additions & 0 deletions Changelog.md
@@ -1,5 +1,11 @@
# Changelog

## 2.2.0

### New feature

* Retry failed HTTP requests via urllib3 for status codes 408, 429, 500, 502, 503, 504

## 2.1.0

### New feature
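For reference, the retry behaviour introduced in 2.2.0 can be sketched standalone; this mirrors the client changes shown in the next file (the URL is a placeholder):

```python
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# Retry up to 3 times on the listed status codes, for GET and POST,
# with a short exponential backoff between attempts.
retry_strategy = Retry(
    total=3,
    backoff_factor=0.1,
    status_forcelist=[408, 429, 500, 502, 503, 504],
    allowed_methods=["POST", "GET"],
)
adapter = HTTPAdapter(max_retries=retry_strategy)
session = requests.Session()
session.mount("https://", adapter)
session.mount("http://", adapter)

# Requests made through this session retry transparently on the listed
# status codes; if every attempt fails with one of them, requests raises
# requests.exceptions.RetryError.
response = session.get("https://example.com/version", timeout=1)
```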
71 changes: 50 additions & 21 deletions aleph_alpha_client/aleph_alpha_client.py
@@ -5,6 +5,9 @@
import logging

from requests import Response
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

import aleph_alpha_client
from aleph_alpha_client.document import Document
from aleph_alpha_client.embedding import SemanticEmbeddingRequest
@@ -30,6 +33,17 @@ def __init__(

self.request_timeout_seconds = request_timeout_seconds

retry_strategy = Retry(
total=3,
backoff_factor=0.1,
status_forcelist=[408, 429, 500, 502, 503, 504],
allowed_methods=["POST", "GET"],
)
adapter = HTTPAdapter(max_retries=retry_strategy)
self.requests_session = requests.Session()
self.requests_session.mount("https://", adapter)
self.requests_session.mount("http://", adapter)

# check server version
expect_release = "1"
version = self.get_version()
@@ -57,10 +71,12 @@ def get_token(self, email, password):
raise ValueError("cannot get token")

def get_request(self, url, headers=None):
return requests.get(url, headers=headers, timeout=self.request_timeout_seconds)
return self.requests_session.get(
url, headers=headers, timeout=self.request_timeout_seconds
)

def post_request(self, url, json, headers=None):
return requests.post(
return self.requests_session.post(
url, headers=headers, json=json, timeout=self.request_timeout_seconds
)

@@ -115,7 +131,7 @@ def complete(
self,
model: str,
prompt: Union[str, List[Union[str, ImagePrompt]]] = "",
hosting: str = "cloud",
hosting: Optional[str] = None,
maximum_tokens: Optional[int] = 64,
temperature: Optional[float] = 0.0,
top_k: Optional[int] = 0,
@@ -145,8 +161,8 @@ def complete(
prompt (str, optional, default ""):
The text to be completed. Unconditional completion can be started with an empty string (default). The prompt may contain a zero-shot or few-shot task.
hosting (str, optional, default "cloud"):
Specifies where the computation will take place. This defaults to "cloud", meaning that it can be
hosting (str, optional, default None):
Specifies where the computation will take place. This defaults to None, meaning that it can be
executed on any of our servers. An error will be returned if the specified hosting is not available.
Check available_models() for available hostings.
@@ -255,7 +271,6 @@ def complete(
payload = {
"model": model,
"prompt": _to_serializable_prompt(prompt=prompt),
"hosting": hosting,
"maximum_tokens": maximum_tokens,
"temperature": temperature,
"top_k": top_k,
@@ -276,6 +291,9 @@
"disable_optimizations": disable_optimizations,
}

if hosting is not None:
payload["hosting"] = hosting

response = self.post_request(
self.host + "complete",
headers=self.request_headers,
@@ -295,7 +313,7 @@ def embed(
prompt: Union[str, Sequence[Union[str, ImagePrompt]]],
pooling: List[str],
layers: List[int],
hosting: str = "cloud",
hosting: Optional[str] = None,
tokens: Optional[bool] = False,
type: Optional[str] = None,
):
@@ -323,8 +341,8 @@ def embed(
* last_token: just use the last token
* abs_max: aggregate token embeddings across the sequence dimension using a maximum of absolute values
hosting (str, optional, default "cloud"):
Specifies where the computation will take place. This defaults to "cloud", meaning that it can be
hosting (str, optional, default None):
Specifies where the computation will take place. This defaults to None, meaning that it can be
executed on any of our servers. An error will be returned if the specified hosting is not available.
Check available_models() for available hostings.
@@ -346,12 +364,15 @@
payload = {
"model": model,
"prompt": serializable_prompt,
"hosting": hosting,
"layers": layers,
"tokens": tokens,
"pooling": pooling,
"type": type,
}

if hosting is not None:
payload["hosting"] = hosting

response = self.post_request(
self.host + "embed", headers=self.request_headers, json=payload
)
@@ -361,7 +382,7 @@ def semantic_embed(
self,
model: str,
request: SemanticEmbeddingRequest,
hosting: str = "cloud",
hosting: Optional[str] = None,
):
"""
Embeds a text and returns vectors that can be used for downstream tasks (e.g. semantic similarity) and models (e.g. classifiers).
@@ -371,7 +392,7 @@ def semantic_embed(
Name of model to use. A model name refers to a model architecture (number of parameters among others). The latest version of the model is always used. The model output contains information as to the model version.
hosting (str, optional, default None):
Specifies where the computation will take place. This defaults to "cloud", meaning that it can be
Specifies where the computation will take place. This defaults to None, meaning that it can be
executed on any of our servers. An error will be returned if the specified hosting is not available.
Check available_models() for available hostings.
@@ -385,11 +406,14 @@

payload: Dict[str, Any] = {
"model": model,
"hosting": hosting,
"prompt": serializable_prompt,
"representation": request.representation.value,
"compress_to_size": request.compress_to_size,
}

if hosting is not None:
payload["hosting"] = hosting

response = self.post_request(
self.host + "semantic_embed", headers=self.request_headers, json=payload
)
@@ -399,7 +423,7 @@ def evaluate(
self,
model,
completion_expected,
hosting: str = "cloud",
hosting: Optional[str] = None,
prompt: Union[str, List[Union[str, ImagePrompt]]] = "",
):
"""
@@ -412,8 +436,8 @@
completion_expected (str, required):
The ground truth completion expected to be produced given the prompt.
hosting (str, optional, default "cloud"):
Specifies where the computation will take place. This defaults to "cloud", meaning that it can be
hosting (str, optional, default None):
Specifies where the computation will take place. This defaults to None, meaning that it can be
executed on any of our servers. An error will be returned if the specified hosting is not available.
Check available_models() for available hostings.
@@ -426,9 +450,12 @@
payload = {
"model": model,
"prompt": serializable_prompt,
"hosting": hosting,
"completion_expected": completion_expected,
}

if hosting is not None:
payload["hosting"] = hosting

response = self.post_request(
self.host + "evaluate", headers=self.request_headers, json=payload
)
@@ -444,7 +471,7 @@ def qa(
disable_optimizations: bool = False,
max_answers: int = 0,
min_score: float = 0.0,
hosting: str = "cloud",
hosting: Optional[str] = None,
):
"""
Answers a question about a prompt.
@@ -483,8 +510,8 @@
min_score (float, default 0.0):
The minimum score an answer must reach in order to be returned.
hosting (str, default "cloud"):
Specifies where the computation will take place. This defaults to "cloud", meaning that it can be
hosting (str, default None):
Specifies where the computation will take place. This defaults to None, meaning that it can be
executed on any of our servers. An error will be returned if the specified hosting is not available.
Check available_models() for available hostings.
"""
@@ -500,9 +527,11 @@
"min_score": min_score,
"max_chunk_size": max_chunk_size,
"disable_optimizations": disable_optimizations,
"hosting": hosting,
}

if hosting is not None:
payload["hosting"] = hosting

response = self.post_request(
self.host + "qa",
headers=self.request_headers,
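A note on the timing of the retries configured above: urllib3 waits roughly `backoff_factor * 2 ** (n - 1)` seconds before the n-th retry (treatment of the first retry varies across urllib3 versions), so `backoff_factor=0.1` keeps the total added delay well under a second:

```python
# Approximate backoff schedule for total=3, backoff_factor=0.1
# (assumed formula: sleep = backoff_factor * 2 ** (retry_number - 1)).
backoff_factor = 0.1
for n in range(1, 4):
    print(f"retry {n}: sleep ~{backoff_factor * 2 ** (n - 1):.1f}s")
# retry 1: sleep ~0.1s
# retry 2: sleep ~0.2s
# retry 3: sleep ~0.4s
```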
4 changes: 2 additions & 2 deletions aleph_alpha_client/aleph_alpha_model.py
@@ -1,5 +1,5 @@
from collections import ChainMap
from typing import Any, Mapping, Union
from typing import Any, Mapping, Optional, Union
from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient
from aleph_alpha_client.completion import CompletionRequest, CompletionResponse
from aleph_alpha_client.detokenization import (
@@ -20,7 +20,7 @@

class AlephAlphaModel:
def __init__(
self, client: AlephAlphaClient, model_name: str, hosting: str = "cloud"
self, client: AlephAlphaClient, model_name: str, hosting: Optional[str] = None
) -> None:
self.client = client
self.model_name = model_name
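A usage sketch under the new default (host, token, and model name are placeholders): when `hosting` is left as `None`, the key is omitted from the request payload entirely instead of being pinned to "cloud":

```python
from aleph_alpha_client.aleph_alpha_client import AlephAlphaClient
from aleph_alpha_client.aleph_alpha_model import AlephAlphaModel

client = AlephAlphaClient(host="https://example.test/", token="AA_TOKEN")
# hosting defaults to None: the request may run on any available server
model = AlephAlphaModel(client, model_name="some-model")
# passing hosting explicitly still pins the computation
pinned = AlephAlphaModel(client, model_name="some-model", hosting="cloud")
```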
2 changes: 1 addition & 1 deletion aleph_alpha_client/version.py
@@ -1 +1 @@
__version__ = "2.1.0"
__version__ = "2.2.0"
27 changes: 22 additions & 5 deletions tests/test_errors.py
@@ -388,12 +388,29 @@ def httpserver_listen_address():


def test_timeout(httpserver):
def handler(foo):
time.sleep(2)

httpserver.expect_request("/version").respond_with_handler(handler)

httpserver.expect_request("/version").respond_with_handler(
lambda request: time.sleep(2)
)
"""Ensures Timeouts works. AlephAlphaClient constructor calls version endpoint."""
with pytest.raises(requests.exceptions.Timeout):
with pytest.raises(requests.exceptions.ConnectionError):
AlephAlphaClient(
host="http://localhost:8000/", token="AA_TOKEN", request_timeout_seconds=1
host="http://localhost:8000/", token="AA_TOKEN", request_timeout_seconds=0.1
)


def test_retry_on_503(httpserver):
httpserver.expect_request("/version").respond_with_data("busy", status=503)

"""Ensures Timeouts works. AlephAlphaClient constructor calls version endpoint."""
with pytest.raises(requests.exceptions.RetryError):
AlephAlphaClient(host="http://localhost:8000/", token="AA_TOKEN")


def test_retry_on_408(httpserver):
httpserver.expect_request("/version").respond_with_data("busy", status=408)

"""Ensures Timeouts works. AlephAlphaClient constructor calls version endpoint."""
with pytest.raises(requests.exceptions.RetryError):
AlephAlphaClient(host="http://localhost:8000/", token="AA_TOKEN")
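Why `test_timeout` now expects `ConnectionError` instead of `Timeout`: once retries are mounted on the session, urllib3 wraps the read timeout in a `MaxRetryError`, which requests surfaces as `requests.exceptions.ConnectionError`; exhausting retries on the status codes themselves (as in the two new tests) surfaces as `requests.exceptions.RetryError` instead. A sketch of the distinction, assuming a local test server as above:

```python
import requests

try:
    AlephAlphaClient(
        host="http://localhost:8000/", token="AA_TOKEN", request_timeout_seconds=0.1
    )
except requests.exceptions.RetryError:
    # the server kept answering with retryable status codes (408/429/5xx)
    ...
except requests.exceptions.ConnectionError:
    # the server did not answer in time; the read timeout, wrapped in
    # urllib3's MaxRetryError, surfaces as a ConnectionError
    ...
```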
1 change: 0 additions & 1 deletion tests/test_qa.py
@@ -53,7 +53,6 @@ def test_qa_with_client(client: AlephAlphaClient):
# when posting a QA request with explicit parameters
response = client.qa(
model_name,
hosting="cloud",
query="Who likes pizza?",
documents=[Document.from_prompt(["Andreas likes pizza."])],
)
