From 78193f5f32ea0b3d9995de7b89e8d3ec92718909 Mon Sep 17 00:00:00 2001
From: Chen Dai
Date: Wed, 22 Nov 2023 15:59:34 -0800
Subject: [PATCH] Add backoff retry capability in rest client (#170)

* Add retry http client, builder and future

Signed-off-by: Chen Dai

* Add UT

Signed-off-by: Chen Dai

* Add more UT

Signed-off-by: Chen Dai

* Delete IT

Signed-off-by: Chen Dai

* Refactor UT

Signed-off-by: Chen Dai

* Replace class name check with instanceOf check

Signed-off-by: Chen Dai

* Make retryable exception list optional

Signed-off-by: Chen Dai

* Add retryable status code option and handler

Signed-off-by: Chen Dai

* Add retry enabled check

Signed-off-by: Chen Dai

* Add Spark conf and user manual

Signed-off-by: Chen Dai

* Add exception class name to Spark conf

Signed-off-by: Chen Dai

* Separate failure and result handler class

Signed-off-by: Chen Dai

* Refactor failure and result predicate

Signed-off-by: Chen Dai

* Add more UT

Signed-off-by: Chen Dai

* Reword user manual

Signed-off-by: Chen Dai

---------

Signed-off-by: Chen Dai
---
 build.sbt                                     |   1 +
 docs/index.md                                 |   3 +
 .../opensearch/flint/core/FlintOptions.java   |  12 +
 .../flint/core/http/FlintRetryOptions.java    | 110 ++++++++
 .../core/http/RetryableHttpAsyncClient.java   | 143 ++++++++++++
 .../ErrorStacktraceFailurePredicate.java      |  47 ++++
 .../ExceptionClassNameFailurePredicate.java   |  70 ++++++
 .../HttpStatusCodeResultPredicate.java        |  46 ++++
 .../core/storage/FlintOpenSearchClient.java   |  22 +-
 .../http/RetryableHttpAsyncClientSuite.scala  | 207 ++++++++++++++++++
 .../sql/flint/config/FlintSparkConf.scala     |  62 ++++--
 .../flint/config/FlintSparkConfSuite.scala    |  26 +++
 12 files changed, 727 insertions(+), 22 deletions(-)
 create mode 100644 flint-core/src/main/scala/org/opensearch/flint/core/http/FlintRetryOptions.java
 create mode 100644 flint-core/src/main/scala/org/opensearch/flint/core/http/RetryableHttpAsyncClient.java
 create mode 100644 flint-core/src/main/scala/org/opensearch/flint/core/http/handler/ErrorStacktraceFailurePredicate.java
 create mode 100644 flint-core/src/main/scala/org/opensearch/flint/core/http/handler/ExceptionClassNameFailurePredicate.java
 create mode 100644 flint-core/src/main/scala/org/opensearch/flint/core/http/handler/HttpStatusCodeResultPredicate.java
 create mode 100644 flint-core/src/test/scala/org/opensearch/flint/core/http/RetryableHttpAsyncClientSuite.scala

diff --git a/build.sbt b/build.sbt
index ccb735ae0..0798092af 100644
--- a/build.sbt
+++ b/build.sbt
@@ -57,6 +57,7 @@ lazy val flintCore = (project in file("flint-core"))
       "org.opensearch.client" % "opensearch-rest-client" % opensearchVersion,
       "org.opensearch.client" % "opensearch-rest-high-level-client" % opensearchVersion
         exclude ("org.apache.logging.log4j", "log4j-api"),
+      "dev.failsafe" % "failsafe" % "3.3.2",
       "com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided"
         exclude ("com.fasterxml.jackson.core", "jackson-databind"),
       "org.scalactic" %% "scalactic" % "3.2.15" % "test",

diff --git a/docs/index.md b/docs/index.md
index d0228cceb..03164e942 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -358,6 +358,9 @@ In the index mapping, the `_meta` and `properties` field stores meta and schema info
 (false), IMMEDIATE(true), WAIT_UNTIL(wait_for)]
 - `spark.datasource.flint.read.scroll_size`: default value is 100.
 - `spark.datasource.flint.read.scroll_duration`: default value is 5 minutes. scroll context keep alive duration.
+- `spark.datasource.flint.retry.max_retries`: max retries on failed HTTP request. default value is 3. Use 0 to disable retry.
+- `spark.datasource.flint.retry.http_status_codes`: retryable HTTP response status code list. default value is "429,502" (429 Too Many Requests and 502 Bad Gateway).
+- `spark.datasource.flint.retry.exception_class_names`: retryable exception class name list. by default no retry on any exception thrown.
 - `spark.flint.optimizer.enabled`: default is true.
 - `spark.flint.index.hybridscan.enabled`: default is false.
 - `spark.flint.index.checkpoint.mandatory`: default is true.
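For illustration, the new options ride on the standard Spark conf mechanism like any other Flint datasource option. A minimal sketch, assuming a vanilla SparkSession; the values here are arbitrary examples, not the defaults:

```java
import org.apache.spark.sql.SparkSession;

public class RetryConfExample {
  public static void main(String[] args) {
    // Hypothetical setup: retry up to 5 times on 429/502/503 responses
    // and on connection failures. Values are illustrative only.
    SparkSession spark = SparkSession.builder()
        .appName("flint-retry-example")
        .config("spark.datasource.flint.retry.max_retries", "5")
        .config("spark.datasource.flint.retry.http_status_codes", "429,502,503")
        .config("spark.datasource.flint.retry.exception_class_names",
            "java.net.ConnectException")
        .getOrCreate();
    spark.stop();
  }
}
```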
diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java
index 52ba61192..8ce3054d9 100644
--- a/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/FlintOptions.java
@@ -5,8 +5,10 @@
 
 package org.opensearch.flint.core;
 
+import dev.failsafe.RetryPolicy;
 import java.io.Serializable;
 import java.util.Map;
+import org.opensearch.flint.core.http.FlintRetryOptions;
 
 /**
  * Flint Options include all the flint related configuration.
@@ -15,6 +17,11 @@ public class FlintOptions implements Serializable {
 
   private final Map<String, String> options;
 
+  /**
+   * Flint options related to HTTP retry policy.
+   */
+  private final FlintRetryOptions retryOptions;
+
   public static final String HOST = "host";
 
   public static final String PORT = "port";
@@ -68,6 +75,7 @@ public class FlintOptions implements Serializable {
 
   public FlintOptions(Map<String, String> options) {
     this.options = options;
+    this.retryOptions = new FlintRetryOptions(options);
   }
 
   public String getHost() {
@@ -88,6 +96,10 @@ public int getScrollDuration() {
 
   public String getRefreshPolicy() {return options.getOrDefault(REFRESH_POLICY, DEFAULT_REFRESH_POLICY);}
 
+  public FlintRetryOptions getRetryOptions() {
+    return retryOptions;
+  }
+
   public String getRegion() {
     return options.getOrDefault(REGION, DEFAULT_REGION);
   }

diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/http/FlintRetryOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/http/FlintRetryOptions.java
new file mode 100644
index 000000000..7b6139014
--- /dev/null
+++ b/flint-core/src/main/scala/org/opensearch/flint/core/http/FlintRetryOptions.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.core.http;
+
+import static java.time.temporal.ChronoUnit.SECONDS;
+
+import dev.failsafe.RetryPolicy;
+import java.time.Duration;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Logger;
+import org.opensearch.flint.core.http.handler.ExceptionClassNameFailurePredicate;
+import org.opensearch.flint.core.http.handler.HttpStatusCodeResultPredicate;
+
+/**
+ * Flint options related to HTTP request retry.
+ */
+public class FlintRetryOptions {
+
+  private static final Logger LOG = Logger.getLogger(FlintRetryOptions.class.getName());
+
+  /**
+   * All Flint options.
+   */
+  private final Map<String, String> options;
+
+  /**
+   * Maximum retry attempt
+   */
+  public static final int DEFAULT_MAX_RETRIES = 3;
+  public static final String MAX_RETRIES = "retry.max_retries";
+
+  public static final String DEFAULT_RETRYABLE_HTTP_STATUS_CODES = "429,502";
+  public static final String RETRYABLE_HTTP_STATUS_CODES = "retry.http_status_codes";
+
+  /**
+   * Retryable exception class name
+   */
+  public static final String RETRYABLE_EXCEPTION_CLASS_NAMES = "retry.exception_class_names";
+
+  public FlintRetryOptions(Map<String, String> options) {
+    this.options = options;
+  }
+
+  /**
+   * Is auto retry capability enabled.
+   *
+   * @return true if enabled, otherwise false.
+   */
+  public boolean isRetryEnabled() {
+    return getMaxRetries() > 0;
+  }
+
+  /**
+   * Build retry policy based on the given Flint options.
+   *
+   * @param <T> success execution result type
+   * @return Failsafe retry policy
+   */
+  public <T> RetryPolicy<T> getRetryPolicy() {
+    return RetryPolicy.<T>builder()
+        // Backoff strategy config (can be configurable as needed in future)
+        .withBackoff(1, 30, SECONDS)
+        .withJitter(Duration.ofMillis(100))
+        // Failure handling config from Flint options
+        .withMaxRetries(getMaxRetries())
+        .handleIf(ExceptionClassNameFailurePredicate.create(getRetryableExceptionClassNames()))
+        .handleResultIf(new HttpStatusCodeResultPredicate<>(getRetryableHttpStatusCodes()))
+        // Logging listener
+        .onFailedAttempt(event ->
+            LOG.severe("Attempt to execute request failed: " + event))
+        .onRetry(ex ->
+            LOG.warning("Retrying failed request at #" + ex.getAttemptCount()))
+        .build();
+  }
+
+  /**
+   * @return maximum retry option value
+   */
+  public int getMaxRetries() {
+    return Integer.parseInt(
+        options.getOrDefault(MAX_RETRIES, String.valueOf(DEFAULT_MAX_RETRIES)));
+  }
+
+  /**
+   * @return retryable HTTP status code list
+   */
+  public String getRetryableHttpStatusCodes() {
+    return options.getOrDefault(RETRYABLE_HTTP_STATUS_CODES, DEFAULT_RETRYABLE_HTTP_STATUS_CODES);
+  }
+
+  /**
+   * @return retryable exception class name list
+   */
+  public Optional<String> getRetryableExceptionClassNames() {
+    return Optional.ofNullable(options.get(RETRYABLE_EXCEPTION_CLASS_NAMES));
+  }
+
+  @Override
+  public String toString() {
+    return "FlintRetryOptions{" +
+        "maxRetries=" + getMaxRetries() +
+        ", retryableStatusCodes=" + getRetryableHttpStatusCodes() +
+        ", retryableExceptionClassNames=" + getRetryableExceptionClassNames() +
+        '}';
+  }
+}
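To make the backoff configuration above concrete, here is a self-contained sketch of the same policy shape in isolation; the predicate and values are stand-ins for the Flint handlers, not part of this patch:

```java
import dev.failsafe.Failsafe;
import dev.failsafe.RetryPolicy;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class RetryPolicyDemo {
  public static void main(String[] args) {
    // Same shape as getRetryPolicy() above: delays grow 1s, 2s, 4s, ...
    // capped at 30s, each with up to 100ms of jitter.
    RetryPolicy<String> policy = RetryPolicy.<String>builder()
        .withBackoff(1, 30, ChronoUnit.SECONDS)
        .withJitter(Duration.ofMillis(100))
        .withMaxRetries(3)
        .handleIf(ex -> ex instanceof IllegalStateException)
        .build();

    AtomicInteger attempts = new AtomicInteger();
    String result = Failsafe.with(policy).get(() -> {
      // Fail the first two attempts; maxRetries=3 allows up to 4 attempts total
      if (attempts.incrementAndGet() < 3) {
        throw new IllegalStateException("transient failure");
      }
      return "success on attempt " + attempts.get();
    });
    System.out.println(result); // success on attempt 3 (after ~1s + ~2s delays)
  }
}
```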
+ */ + private final Map options; + + /** + * Maximum retry attempt + */ + public static final int DEFAULT_MAX_RETRIES = 3; + public static final String MAX_RETRIES = "retry.max_retries"; + + public static final String DEFAULT_RETRYABLE_HTTP_STATUS_CODES = "429,502"; + public static final String RETRYABLE_HTTP_STATUS_CODES = "retry.http_status_codes"; + + /** + * Retryable exception class name + */ + public static final String RETRYABLE_EXCEPTION_CLASS_NAMES = "retry.exception_class_names"; + + public FlintRetryOptions(Map options) { + this.options = options; + } + + /** + * Is auto retry capability enabled. + * + * @return true if enabled, otherwise false. + */ + public boolean isRetryEnabled() { + return getMaxRetries() > 0; + } + + /** + * Build retry policy based on the given Flint options. + * + * @param success execution result type + * @return Failsafe retry policy + */ + public RetryPolicy getRetryPolicy() { + return RetryPolicy.builder() + // Backoff strategy config (can be configurable as needed in future) + .withBackoff(1, 30, SECONDS) + .withJitter(Duration.ofMillis(100)) + // Failure handling config from Flint options + .withMaxRetries(getMaxRetries()) + .handleIf(ExceptionClassNameFailurePredicate.create(getRetryableExceptionClassNames())) + .handleResultIf(new HttpStatusCodeResultPredicate<>(getRetryableHttpStatusCodes())) + // Logging listener + .onFailedAttempt(event -> + LOG.severe("Attempt to execute request failed: " + event)) + .onRetry(ex -> + LOG.warning("Retrying failed request at #" + ex.getAttemptCount())) + .build(); + } + + /** + * @return maximum retry option value + */ + public int getMaxRetries() { + return Integer.parseInt( + options.getOrDefault(MAX_RETRIES, String.valueOf(DEFAULT_MAX_RETRIES))); + } + + /** + * @return retryable HTTP status code list + */ + public String getRetryableHttpStatusCodes() { + return options.getOrDefault(RETRYABLE_HTTP_STATUS_CODES, DEFAULT_RETRYABLE_HTTP_STATUS_CODES); + } + + /** + * @return retryable exception class name list + */ + public Optional getRetryableExceptionClassNames() { + return Optional.ofNullable(options.get(RETRYABLE_EXCEPTION_CLASS_NAMES)); + } + + @Override + public String toString() { + return "FlintRetryOptions{" + + "maxRetries=" + getMaxRetries() + + ", retryableStatusCodes=" + getRetryableHttpStatusCodes() + + ", retryableExceptionClassNames=" + getRetryableExceptionClassNames() + + '}'; + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/http/RetryableHttpAsyncClient.java b/flint-core/src/main/scala/org/opensearch/flint/core/http/RetryableHttpAsyncClient.java new file mode 100644 index 000000000..c7c9258ec --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/http/RetryableHttpAsyncClient.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.http; + +import dev.failsafe.Failsafe; +import dev.failsafe.FailsafeException; +import java.io.IOException; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; +import org.apache.http.concurrent.FutureCallback; +import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; +import org.apache.http.impl.nio.client.HttpAsyncClientBuilder; +import org.apache.http.nio.protocol.HttpAsyncRequestProducer; +import org.apache.http.nio.protocol.HttpAsyncResponseConsumer; +import 
+  @Override
+  public <T> Future<T> execute(HttpAsyncRequestProducer requestProducer,
+                               HttpAsyncResponseConsumer<T> responseConsumer,
+                               HttpContext context,
+                               FutureCallback<T> callback) {
+    return new Future<>() {
+      /**
+       * Delegated future object created per doExecuteAndFutureGetWithRetry() call,
+       * which creates the initial object too. In this way, we avoid duplicating the
+       * logic of the first call and subsequent retry calls. The assumption here is
+       * that cancel, isCancelled and isDone are never called before get().
+       * (OpenSearch RestClient seems to only call the get() API)
+       */
+      private Future<T> delegate;
+
+      @Override
+      public boolean cancel(boolean mayInterruptIfRunning) {
+        return delegate.cancel(mayInterruptIfRunning);
+      }
+
+      @Override
+      public boolean isCancelled() {
+        return delegate.isCancelled();
+      }
+
+      @Override
+      public boolean isDone() {
+        return delegate.isDone();
+      }
+
+      @Override
+      public T get() throws InterruptedException, ExecutionException {
+        return doExecuteAndFutureGetWithRetry(() -> delegate.get());
+      }
+
+      @Override
+      public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException {
+        return doExecuteAndFutureGetWithRetry(() -> delegate.get(timeout, unit));
+      }
+
+      private T doExecuteAndFutureGetWithRetry(Callable<T> futureGet) throws InterruptedException, ExecutionException {
+        try {
+          // Retry by creating a new Future object (as new delegate) and getting its result again
+          return Failsafe
+              .with(options.getRetryPolicy())
+              .get(() -> {
+                this.delegate = internalClient.execute(requestProducer, responseConsumer, context, callback);
+                return futureGet.call();
+              });
+        } catch (FailsafeException ex) {
+          LOG.severe("Request failed permanently. Re-throwing original exception.");
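The unwrap logic at the end of doExecuteAndFutureGetWithRetry() exists because Failsafe's get() cannot propagate checked exceptions directly; it wraps them. A minimal sketch demonstrating that behavior (illustrative only, not part of the patch):

```java
import dev.failsafe.Failsafe;
import dev.failsafe.FailsafeException;
import dev.failsafe.RetryPolicy;
import java.util.concurrent.ExecutionException;

public class UnwrapDemo {
  public static void main(String[] args) {
    RetryPolicy<Object> policy = RetryPolicy.builder()
        .withMaxRetries(0) // no retries needed for this demo
        .build();
    try {
      Failsafe.with(policy).get(() -> {
        // Checked exception, as thrown by Future.get() in the client above
        throw new ExecutionException(new RuntimeException("boom"));
      });
    } catch (FailsafeException ex) {
      // Failsafe wrapped the checked ExecutionException; callers must unwrap
      System.out.println(ex.getCause() instanceof ExecutionException); // true
    }
  }
}
```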
+
+          // Failsafe wraps checked exceptions, such as ExecutionException,
+          // so here we have to unwrap the Failsafe exception and rethrow the cause
+          Throwable cause = ex.getCause();
+          if (cause instanceof InterruptedException) {
+            throw (InterruptedException) cause;
+          } else if (cause instanceof ExecutionException) {
+            throw (ExecutionException) cause;
+          } else {
+            throw ex;
+          }
+        }
+      }
+    };
+  }
+
+  public static HttpAsyncClientBuilder builder(HttpAsyncClientBuilder delegate, FlintOptions options) {
+    FlintRetryOptions retryOptions = options.getRetryOptions();
+    if (!retryOptions.isRetryEnabled()) {
+      return delegate;
+    }
+
+    // Wrap original builder so the created client will be wrapped by the retryable client too
+    return new HttpAsyncClientBuilder() {
+      @Override
+      public CloseableHttpAsyncClient build() {
+        LOG.info("Building retryable http async client with options: " + retryOptions);
+        return new RetryableHttpAsyncClient(delegate.build(), retryOptions);
+      }
+    };
+  }
+}
+ */ + @Override + public boolean test(Throwable throwable) throws Throwable { + // Use extra set to Handle nested exception to avoid dead loop + Set seen = new HashSet<>(); + + while (throwable != null && seen.add(throwable)) { + LOG.info("Checking if exception retryable: " + throwable); + + if (isRetryable(throwable)) { + LOG.info("Exception is retryable: " + throwable); + return true; + } + throwable = throwable.getCause(); + } + + LOG.info("No retryable exception found on the stacktrace"); + return false; + } + + /** + * Is exception retryable decided by subclass implementation + */ + protected abstract boolean isRetryable(Throwable throwable); +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/http/handler/ExceptionClassNameFailurePredicate.java b/flint-core/src/main/scala/org/opensearch/flint/core/http/handler/ExceptionClassNameFailurePredicate.java new file mode 100644 index 000000000..1c93e5c19 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/http/handler/ExceptionClassNameFailurePredicate.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.http.handler; + +import static java.util.Collections.newSetFromMap; +import static java.util.logging.Level.SEVERE; + +import dev.failsafe.function.CheckedPredicate; +import java.util.Arrays; +import java.util.Optional; +import java.util.Set; +import java.util.WeakHashMap; +import java.util.logging.Logger; + +/** + * Failure handler based on exception class type check. + */ +public class ExceptionClassNameFailurePredicate extends ErrorStacktraceFailurePredicate { + + private static final Logger LOG = Logger.getLogger(ErrorStacktraceFailurePredicate.class.getName()); + + /** + * Retryable exception class types. 
+ */ + private final Set> retryableExceptions; + + /** + * @return exception class handler or empty handler (treat any exception non-retryable) + */ + public static CheckedPredicate create(Optional exceptionClassNames) { + if (exceptionClassNames.isEmpty()) { + // This is required because Failsafe treats any Exception retryable by default + return ex -> false; + } + return new ExceptionClassNameFailurePredicate(exceptionClassNames.get()); + } + + public ExceptionClassNameFailurePredicate(String exceptionClassNames) { + // Use weak collection avoids blocking class unloading + this.retryableExceptions = newSetFromMap(new WeakHashMap<>()); + Arrays.stream(exceptionClassNames.split(",")) + .map(String::trim) + .map(this::loadClass) + .forEach(retryableExceptions::add); + } + + @Override + protected boolean isRetryable(Throwable throwable) { + for (Class retryable : retryableExceptions) { + if (retryable.isInstance(throwable)) { + return true; + } + } + return false; + } + + private Class loadClass(String className) { + try { + //noinspection unchecked + return (Class) Class.forName(className); + } catch (ClassNotFoundException e) { + String errorMsg = "Failed to load class " + className; + LOG.log(SEVERE, errorMsg, e); + throw new IllegalStateException(errorMsg); + } + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/http/handler/HttpStatusCodeResultPredicate.java b/flint-core/src/main/scala/org/opensearch/flint/core/http/handler/HttpStatusCodeResultPredicate.java new file mode 100644 index 000000000..fa82e3655 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/http/handler/HttpStatusCodeResultPredicate.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.http.handler; + +import dev.failsafe.function.CheckedPredicate; +import java.util.Arrays; +import java.util.Set; +import java.util.logging.Logger; +import java.util.stream.Collectors; +import org.apache.http.HttpResponse; + +/** + * Failure handler based on status code in HTTP response. 
+ *
+ * @param <T> result type (supposed to be HttpResponse for OS client)
+ */
+public class HttpStatusCodeResultPredicate<T> implements CheckedPredicate<T> {
+
+  private static final Logger LOG = Logger.getLogger(HttpStatusCodeResultPredicate.class.getName());
+
+  /**
+   * Retryable HTTP status code list
+   */
+  private final Set<Integer> retryableStatusCodes;
+
+  public HttpStatusCodeResultPredicate(String httpStatusCodes) {
+    this.retryableStatusCodes =
+        Arrays.stream(httpStatusCodes.split(","))
+            .map(String::trim)
+            .map(Integer::valueOf)
+            .collect(Collectors.toSet());
+  }
+
+  @Override
+  public boolean test(T result) throws Throwable {
+    int statusCode = ((HttpResponse) result).getStatusLine().getStatusCode();
+    LOG.info("Checking if status code is retryable: " + statusCode);
+
+    boolean isRetryable = retryableStatusCodes.contains(statusCode);
+    LOG.info("Status code " + statusCode + " check result: " + isRetryable);
+    return isRetryable;
+  }
+}
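Outside of FlintOpenSearchClient, the same wrapping pattern applies to any OpenSearch low-level REST client; a minimal sketch, where the host and options are placeholders:

```java
import java.util.Map;
import org.apache.http.HttpHost;
import org.opensearch.client.RestClient;
import org.opensearch.client.RestClientBuilder;
import org.opensearch.flint.core.FlintOptions;
import org.opensearch.flint.core.http.RetryableHttpAsyncClient;

public class RetryClientWiring {
  public static RestClientBuilder wire(FlintOptions options) {
    // Hypothetical endpoint; RetryableHttpAsyncClient.builder() is a pass-through
    // (returns the delegate unchanged) when retry.max_retries is 0
    return RestClient.builder(new HttpHost("localhost", 9200, "http"))
        .setHttpClientConfigCallback(builder ->
            RetryableHttpAsyncClient.builder(builder, options));
  }

  public static void main(String[] args) {
    wire(new FlintOptions(Map.of("retry.max_retries", "3")));
  }
}
```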
+              RetryableHttpAsyncClient.builder(delegate, options));
     }
     return new RestHighLevelClient(restClientBuilder);
   }

diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/http/RetryableHttpAsyncClientSuite.scala b/flint-core/src/test/scala/org/opensearch/flint/core/http/RetryableHttpAsyncClientSuite.scala
new file mode 100644
index 000000000..7d3b79a9e
--- /dev/null
+++ b/flint-core/src/test/scala/org/opensearch/flint/core/http/RetryableHttpAsyncClientSuite.scala
@@ -0,0 +1,207 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.core.http
+
+import java.net.{ConnectException, SocketTimeoutException}
+import java.util
+import java.util.Collections.emptyMap
+import java.util.concurrent.{ExecutionException, Future}
+
+import scala.collection.JavaConverters.mapAsJavaMapConverter
+
+import org.apache.http.HttpResponse
+import org.apache.http.concurrent.FutureCallback
+import org.apache.http.impl.nio.client.{CloseableHttpAsyncClient, HttpAsyncClientBuilder}
+import org.apache.http.nio.protocol.{HttpAsyncRequestProducer, HttpAsyncResponseConsumer}
+import org.apache.http.protocol.HttpContext
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+import org.mockito.verification.VerificationMode
+import org.opensearch.flint.core.FlintOptions
+import org.opensearch.flint.core.http.FlintRetryOptions.DEFAULT_MAX_RETRIES
+import org.scalatest.BeforeAndAfter
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+import org.scalatestplus.mockito.MockitoSugar.mock
+
+class RetryableHttpAsyncClientSuite extends AnyFlatSpec with BeforeAndAfter with Matchers {
+
+  /** Mocked internal client and future callback */
+  val internalClient: CloseableHttpAsyncClient = mock[CloseableHttpAsyncClient]
+  val future: Future[HttpResponse] = mock[Future[HttpResponse]]
+
+  behavior of "Retryable HTTP async client"
+
+  before {
+    when(
+      internalClient.execute(
+        any[HttpAsyncRequestProducer],
+        any[HttpAsyncResponseConsumer[HttpResponse]],
+        any[HttpContext],
+        any[FutureCallback[HttpResponse]])).thenReturn(future)
+  }
+
+  after {
+    reset(internalClient, future)
+  }
+
+  it should "return retry client builder by default" in {
+    val builder = mock[HttpAsyncClientBuilder]
+    val finalBuilder = RetryableHttpAsyncClient.builder(builder, new FlintOptions(emptyMap()))
+
+    finalBuilder should not be builder
+  }
+
+  it should "return retry client builder if retry enabled (max_retries > 0)" in {
+    val builder = mock[HttpAsyncClientBuilder]
+    val finalBuilder = RetryableHttpAsyncClient.builder(
+      builder,
+      new FlintOptions(Map("retry.max_retries" -> "5").asJava))
+
+    finalBuilder should not be builder
+  }
+
+  it should "return original client builder if retry disabled (max_retries = 0)" in {
+    val builder = mock[HttpAsyncClientBuilder]
+    val finalBuilder = RetryableHttpAsyncClient.builder(
+      builder,
+      new FlintOptions(Map("retry.max_retries" -> "0").asJava))
+
+    finalBuilder shouldBe builder
+  }
+
+  it should "retry if response code is on the retryable status code list" in {
+    Seq(429, 502).foreach { statusCode =>
+      retryableClient
+        .whenStatusCode(statusCode)
+        .shouldExecute(times(DEFAULT_MAX_RETRIES + 1))
+    }
+  }
+
+  it should "not retry if response code is not on the retryable status code list" in {
+    retryableClient
+      .whenStatusCode(400)
+      .shouldExecute(times(1))
+  }
+
+  it should "not retry any exception by default" in {
+    retryableClient
+      .whenThrow(new ConnectException)
+      .shouldExecute(times(1))
+  }
+
+  it should "retry if exception is on the retryable exception list" in {
+    Seq(new ConnectException, new SocketTimeoutException).foreach { ex =>
+      retryableClient
+        .withOption(
+          "retry.exception_class_names",
+          "java.net.ConnectException,java.net.SocketTimeoutException")
+        .whenThrow(ex)
+        .shouldExecute(times(DEFAULT_MAX_RETRIES + 1))
+    }
+  }
+
+  it should "retry if exception's root cause is on the retryable exception list" in {
+    retryableClient
+      .withOption("retry.exception_class_names", "java.net.ConnectException")
+      .whenThrow(new IllegalStateException(new ConnectException))
+      .shouldExecute(times(DEFAULT_MAX_RETRIES + 1))
+  }
+
+  it should "not retry if exception is not on the retryable exception list" in {
+    retryableClient
+      .whenThrow(new SocketTimeoutException)
+      .shouldExecute(times(1))
+  }
+
+  it should "retry with configured max attempt count" in {
+    retryableClient
+      .withOption("retry.max_retries", "1")
+      .whenStatusCode(429)
+      .shouldExecute(times(2))
+  }
+
+  it should "return if retry successfully" in {
+    val response = mock[HttpResponse](RETURNS_DEEP_STUBS)
+    when(future.get()).thenReturn(response)
+    when(response.getStatusLine.getStatusCode)
+      .thenReturn(429)
+      .thenReturn(429)
+      .thenReturn(200)
+
+    retryableClient
+      .shouldExecute(times(3))
+  }
+
+  // Exceptions like AmazonServiceException are thrown from the interceptor in execute() directly
+  it should "retry too if exception thrown from execute instead of future get" in {
+    reset(internalClient)
+    when(
+      internalClient.execute(
+        any[HttpAsyncRequestProducer],
+        any[HttpAsyncResponseConsumer[HttpResponse]],
+        any[HttpContext],
+        any[FutureCallback[HttpResponse]])).thenThrow(new IllegalStateException)
+
+    retryableClient
+      .withOption("retry.exception_class_names", "java.lang.IllegalStateException")
+      .shouldExecute(
+        expectExecuteTimes = times(DEFAULT_MAX_RETRIES + 1),
+        expectFutureGetTimes = times(0))
+  }
+
+  private def retryableClient: AssertionHelper = new AssertionHelper
+
+  class AssertionHelper {
+    private val options: util.Map[String, String] = new util.HashMap[String, String]()
+
+    def withOption(key: String, value: String): AssertionHelper = {
+      options.put(key, value)
+      this
+    }
+
+    def whenThrow(throwable: Throwable): AssertionHelper = {
+      when(future.get()).thenThrow(new ExecutionException(throwable))
+      this
+    }
+
+    def whenStatusCode(statusCode: Int): AssertionHelper = {
+      val response = mock[HttpResponse](RETURNS_DEEP_STUBS)
+      when(response.getStatusLine.getStatusCode).thenReturn(statusCode)
+      when(future.get()).thenReturn(response)
+      this
+    }
+
+    def shouldExecute(expectExecuteTimes: VerificationMode): Unit = {
+      shouldExecute(expectExecuteTimes, expectExecuteTimes)
+    }
+
+    def shouldExecute(
+        expectExecuteTimes: VerificationMode,
+        expectFutureGetTimes: VerificationMode): Unit = {
+      val client =
+        new RetryableHttpAsyncClient(internalClient, new FlintOptions(options).getRetryOptions)
+
+      try {
+        client.execute(null, null, null, null).get()
+      } catch {
+        case _: Throwable => // Ignore because we're testing the error case
+      } finally {
+        // Verify `execute(...).get()` was called the expected number of times
+        verify(internalClient, expectExecuteTimes)
+          .execute(
+            any[HttpAsyncRequestProducer],
+            any[HttpAsyncResponseConsumer[HttpResponse]],
+            any[HttpContext],
+            any[FutureCallback[HttpResponse]])
+        verify(future, expectFutureGetTimes).get()
+
+        reset(future)
+        clearInvocations(internalClient)
+      }
+    }
+  }
+}
diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala
index 2c42f9f20..c220b4b01 100644
--- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala
+++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala
@@ -8,9 +8,10 @@ package org.apache.spark.sql.flint.config
 import java.util
 import java.util.{Map => JMap, NoSuchElementException}
 
-import scala.collection.JavaConverters._
+import scala.collection.JavaConverters.mapAsJavaMapConverter
 
 import org.opensearch.flint.core.FlintOptions
+import org.opensearch.flint.core.http.FlintRetryOptions
 
 import org.apache.spark.internal.config.ConfigReader
 import org.apache.spark.sql.flint.config.FlintSparkConf._
@@ -111,6 +112,23 @@
       .doc("scroll duration in minutes")
       .createWithDefault(String.valueOf(FlintOptions.DEFAULT_SCROLL_DURATION))
 
+  val MAX_RETRIES = FlintConfig(s"spark.datasource.flint.${FlintRetryOptions.MAX_RETRIES}")
+    .datasourceOption()
+    .doc("max retries on failed HTTP request, 0 means retry is disabled, default is 3")
+    .createWithDefault(String.valueOf(FlintRetryOptions.DEFAULT_MAX_RETRIES))
+
+  val RETRYABLE_HTTP_STATUS_CODES =
+    FlintConfig(s"spark.datasource.flint.${FlintRetryOptions.RETRYABLE_HTTP_STATUS_CODES}")
+      .datasourceOption()
+      .doc("retryable HTTP response status code list")
+      .createWithDefault(FlintRetryOptions.DEFAULT_RETRYABLE_HTTP_STATUS_CODES)
+
+  val RETRYABLE_EXCEPTION_CLASS_NAMES =
+    FlintConfig(s"spark.datasource.flint.${FlintRetryOptions.RETRYABLE_EXCEPTION_CLASS_NAMES}")
+      .datasourceOption()
+      .doc("retryable exception class name list, by default no retry on any exception thrown")
+      .createOptional()
+
   val OPTIMIZER_RULE_ENABLED = FlintConfig("spark.flint.optimizer.enabled")
     .doc("Enable Flint optimizer rule for query rewrite with Flint index")
     .createWithDefault("true")
@@ -157,21 +175,31 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
 
   /**
    * Helper class, create {@link FlintOptions}.
   */
   def flintOptions(): FlintOptions = {
-    new FlintOptions(
-      Seq(
-        HOST_ENDPOINT,
-        HOST_PORT,
-        REFRESH_POLICY,
-        SCROLL_SIZE,
-        SCROLL_DURATION,
-        SCHEME,
-        AUTH,
-        REGION,
-        CUSTOM_AWS_CREDENTIALS_PROVIDER,
-        USERNAME,
-        PASSWORD)
-        .map(conf => (conf.optionKey, conf.readFrom(reader)))
-        .toMap
-        .asJava)
+    val optionsWithDefault = Seq(
+      HOST_ENDPOINT,
+      HOST_PORT,
+      REFRESH_POLICY,
+      SCROLL_SIZE,
+      SCROLL_DURATION,
+      SCHEME,
+      AUTH,
+      MAX_RETRIES,
+      RETRYABLE_HTTP_STATUS_CODES,
+      REGION,
+      CUSTOM_AWS_CREDENTIALS_PROVIDER,
+      USERNAME,
+      PASSWORD)
+      .map(conf => (conf.optionKey, conf.readFrom(reader)))
+      .toMap
+
+    val optionsWithoutDefault = Seq(RETRYABLE_EXCEPTION_CLASS_NAMES)
+      .map(conf => (conf.optionKey, conf.readFrom(reader)))
+      .flatMap {
+        case (_, None) => None
+        case (key, value) => Some(key -> value.get)
+      }
+      .toMap
+
+    new FlintOptions((optionsWithDefault ++ optionsWithoutDefault).asJava)
   }
 }

diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala
index c15cf1073..149e8128b 100644
--- a/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/apache/spark/sql/flint/config/FlintSparkConfSuite.scala
@@ -5,8 +5,13 @@
 
 package org.apache.spark.sql.flint.config
 
+import java.util.Optional
+
 import scala.collection.JavaConverters._
 
+import org.opensearch.flint.core.http.FlintRetryOptions._
+import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
+
 import org.apache.spark.FlintSuite
 
 class FlintSparkConfSuite extends FlintSuite {
@@ -36,6 +41,27 @@ class FlintSparkConfSuite extends FlintSuite {
     assert(options.flintOptions().getPort == 9200)
   }
 
+  test("test retry options default values") {
+    val retryOptions = FlintSparkConf().flintOptions().getRetryOptions
+    retryOptions.getMaxRetries shouldBe DEFAULT_MAX_RETRIES
+    retryOptions.getRetryableHttpStatusCodes shouldBe DEFAULT_RETRYABLE_HTTP_STATUS_CODES
+    retryOptions.getRetryableExceptionClassNames shouldBe Optional.empty
+  }
+
+  test("test specified retry options") {
+    val retryOptions = FlintSparkConf(
+      Map(
+        "retry.max_retries" -> "5",
+        "retry.http_status_codes" -> "429,502,503,504",
+        "retry.exception_class_names" -> "java.net.ConnectException").asJava)
+      .flintOptions()
+      .getRetryOptions
+
+    retryOptions.getMaxRetries shouldBe 5
+    retryOptions.getRetryableHttpStatusCodes shouldBe "429,502,503,504"
+    retryOptions.getRetryableExceptionClassNames.get() shouldBe "java.net.ConnectException"
+  }
+
   /**
    * Delete index `indexNames` after calling `f`.
    */