diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java index 28ed1bc200..4acf7a2733 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java @@ -16,6 +16,7 @@ import java.util.List; import java.util.Map; +import org.apache.commons.collections.MapUtils; import org.apache.http.HttpStatus; import org.apache.logging.log4j.util.Strings; import org.opensearch.OpenSearchStatusException; @@ -80,6 +81,23 @@ public void onHeaders(SdkHttpResponse response) { SdkHttpFullResponse sdkResponse = (SdkHttpFullResponse) response; log.debug("received response headers: " + sdkResponse.headers()); this.statusCode = sdkResponse.statusCode(); + if (MapUtils.isEmpty(sdkResponse.headers())) { + return; + } + List<String> errorsInHeader = sdkResponse.headers().get("x-amzn-ErrorType"); + if (errorsInHeader == null || errorsInHeader.isEmpty()) { + return; + } + boolean containsThrottlingException = errorsInHeader.stream().anyMatch(str -> str.startsWith("ThrottlingException")); + if (containsThrottlingException) { + actionListener + .onFailure( + new OpenSearchStatusException( + REMOTE_SERVICE_ERROR + "The request was denied due to request throttling.", + RestStatus.fromCode(statusCode) + ) + ); + } } @Override