diff --git a/Changelog.md b/Changelog.md index ec18054..600f7f7 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,5 +1,9 @@ # Changelog +## Unreleased + +* Add `Error::Unavailable` to decouple service unavailability from 'queue full' 503 responses. + ## 0.7.1 * Add `Client::tokenize` and `Client::detokenize`. Thanks to @andreaskoepf diff --git a/src/http.rs b/src/http.rs index 1f91928..adefc02 100644 --- a/src/http.rs +++ b/src/http.rs @@ -160,9 +160,19 @@ async fn translate_http_error(response: reqwest::Response) -> Result = serde_json::from_str(&body); let translated_error = match status { StatusCode::TOO_MANY_REQUESTS => Error::TooManyRequests, - StatusCode::SERVICE_UNAVAILABLE => Error::Busy, + StatusCode::SERVICE_UNAVAILABLE => { + // Presence of `api_error` implies the error originated from the API itself (rather + // than the intermediate proxy) and so we can decode it as such. + if api_error.is_ok_and(|error| error.code == "QUEUE_FULL") { + Error::Busy + } else { + Error::Unavailable + } + } _ => Error::Http { status: status.as_u16(), body, @@ -175,14 +185,14 @@ async fn translate_http_error(response: reqwest::Response) -> Result { /// Unique string in capital letters emitted by the API to signal different kinds of errors in a /// finer granularity then the HTTP status codes alone would allow for. /// /// E.g. Differentiating between request rate limiting and parallel tasks limiting which both /// are 429 (the former is emitted by NGinx though). - _code: Cow<'a, str>, + code: Cow<'a, str>, } /// Errors returned by the Aleph Alpha Client @@ -202,6 +212,12 @@ pub enum Error { welcome to retry your request any time." )] Busy, + /// The API itself is unavailable, most likely due to restart. + #[error( + "The service is currently unavailable. This is likely due to restart. Please try again \ + later." + )] + Unavailable, #[error("No response received within given timeout: {0:?}")] ClientTimeout(Duration), /// An error on the Http Protocol level. 
diff --git a/tests/unit.rs b/tests/unit.rs index 28a9725..e6b2726 100644 --- a/tests/unit.rs +++ b/tests/unit.rs @@ -93,12 +93,12 @@ async fn detect_queue_full() { // Start a background HTTP server on a random local part let mock_server = MockServer::start().await; - let answer = r#"{ - "error":"Sorry we had to reject your request because we could not guarantee to finish it in - a reasonable timeframe. This specific model is very busy at this moment. Try again later - or use another model.", - "code":"QUEUE_FULL" - }"#; + let answer = "{ + \"error\":\"Sorry we had to reject your request because we could not guarantee to finish \ + it in a reasonable timeframe. This specific model is very busy at this moment. Try \ + again later or use another model.\", + \"code\":\"QUEUE_FULL\" + }"; let body = r#"{ "model": "luminous-base", "prompt": [{"type": "text", "data": "Hello,"}], @@ -124,9 +124,47 @@ async fn detect_queue_full() { .await .unwrap_err(); + // Then assert!(matches!(error, Error::Busy)); } +/// If the API is down, we want to detect this scenario and inform the user. 
+#[tokio::test] +async fn detect_service_unavailable() { + // Given + + // Start a background HTTP server on a random local port + let mock_server = MockServer::start().await; + + let answer = "No server is available to handle this request."; + let body = r#"{ + "model": "luminous-base", + "prompt": [{"type": "text", "data": "Hello,"}], + "maximum_tokens": 1 + }"#; + + Mock::given(method("POST")) + .and(path("/complete")) + .and(header("Authorization", "Bearer dummy-token")) + .and(header("Content-Type", "application/json")) + .and(body_json_string(body)) + .respond_with(ResponseTemplate::new(503).set_body_string(answer)) + .mount(&mock_server) + .await; + + // When + let task = TaskCompletion::from_text("Hello,", 1); + let model = "luminous-base"; + let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap(); + let error = client + .output_of(&task.with_model(model), &How::default()) + .await + .unwrap_err(); + + // Then + assert!(matches!(error, Error::Unavailable)); +} + /// Should set `nice=true` in query URL in order to tell the server we do not need our result right /// now in a high stress situation. #[tokio::test]