Skip to content

Commit

Permalink
Implementing retry for remote connector to mitigate throttling issue (o…
Browse files Browse the repository at this point in the history
…pensearch-project#2462)

* use retryable action; execution context

Signed-off-by: zhichao-aws <[email protected]>

* change to groupedActionListener

Signed-off-by: zhichao-aws <[email protected]>

* fix group

Signed-off-by: zhichao-aws <[email protected]>

* retry policy

Signed-off-by: zhichao-aws <[email protected]>

* base time

Signed-off-by: zhichao-aws <[email protected]>

* retry option, cluster settings

Signed-off-by: zhichao-aws <[email protected]>

* nit

Signed-off-by: zhichao-aws <[email protected]>

* lint

Signed-off-by: zhichao-aws <[email protected]>

* change interface to class

Signed-off-by: zhichao-aws <[email protected]>

* fix ut due to code change

Signed-off-by: zhichao-aws <[email protected]>

* license header

Signed-off-by: zhichao-aws <[email protected]>

* add ut

Signed-off-by: zhichao-aws <[email protected]>

* add test

Signed-off-by: zhichao-aws <[email protected]>

* fix core interface

Signed-off-by: zhichao-aws <[email protected]>

* test

Signed-off-by: zhichao-aws <[email protected]>

* license header

Signed-off-by: zhichao-aws <[email protected]>

* use exception holder

Signed-off-by: zhichao-aws <[email protected]>

* add max retry times settings

Signed-off-by: zhichao-aws <[email protected]>

* fix typo

Signed-off-by: zhichao-aws <[email protected]>

* nit

Signed-off-by: zhichao-aws <[email protected]>

* change the order to avoid misleading log

Signed-off-by: zhichao-aws <[email protected]>

* license header

Signed-off-by: zhichao-aws <[email protected]>

* move settings to connector

Signed-off-by: zhichao-aws <[email protected]>

* remove settings

Signed-off-by: zhichao-aws <[email protected]>

* add test

Signed-off-by: zhichao-aws <[email protected]>

* add retry_backoff_policy setting

Signed-off-by: zhichao-aws <[email protected]>

* changes for comments

Signed-off-by: zhichao-aws <[email protected]>

* fix retry times

Signed-off-by: zhichao-aws <[email protected]>

* make the error handling more neat in MLSdkAsyncHttpResponseHandler

Signed-off-by: zhichao-aws <[email protected]>

* change to SageMakerThrottlingException

Signed-off-by: zhichao-aws <[email protected]>

* use enum for retry backoff policy

Signed-off-by: zhichao-aws <[email protected]>

* fix seconds to milliseconds in equal jitter policy

Signed-off-by: zhichao-aws <[email protected]>

* disable retry by default

Signed-off-by: zhichao-aws <[email protected]>

---------

Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws authored Jun 6, 2024
1 parent 865a424 commit 399825f
Show file tree
Hide file tree
Showing 19 changed files with 940 additions and 441 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand All @@ -16,6 +17,9 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

Expand All @@ -26,45 +30,88 @@ public class ConnectorClientConfig implements ToXContentObject, Writeable {
public static final String MAX_CONNECTION_FIELD = "max_connection";
public static final String CONNECTION_TIMEOUT_FIELD = "connection_timeout";
public static final String READ_TIMEOUT_FIELD = "read_timeout";
public static final String RETRY_BACKOFF_MILLIS_FIELD = "retry_backoff_millis";
public static final String RETRY_TIMEOUT_SECONDS_FIELD = "retry_timeout_seconds";
public static final String MAX_RETRY_TIMES_FIELD = "max_retry_times";
public static final String RETRY_BACKOFF_POLICY_FIELD = "retry_backoff_policy";

public static final Integer MAX_CONNECTION_DEFAULT_VALUE = Integer.valueOf(30);
public static final Integer CONNECTION_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000);
public static final Integer READ_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000);

public static final Integer RETRY_BACKOFF_MILLIS_DEFAULT_VALUE = 200;
public static final Integer RETRY_TIMEOUT_SECONDS_DEFAULT_VALUE = 30;
public static final Integer MAX_RETRY_TIMES_DEFAULT_VALUE = 0;
public static final RetryBackoffPolicy RETRY_BACKOFF_POLICY_DEFAULT_VALUE = RetryBackoffPolicy.CONSTANT;
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_RETRY = Version.V_2_15_0;
private Integer maxConnections;
private Integer connectionTimeout;
private Integer readTimeout;
private Integer retryBackoffMillis;
private Integer retryTimeoutSeconds;
private Integer maxRetryTimes;
private RetryBackoffPolicy retryBackoffPolicy;

@Builder(toBuilder = true)
public ConnectorClientConfig(
Integer maxConnections,
Integer connectionTimeout,
Integer readTimeout
Integer readTimeout,
Integer retryBackoffMillis,
Integer retryTimeoutSeconds,
Integer maxRetryTimes,
RetryBackoffPolicy retryBackoffPolicy
) {
this.maxConnections = maxConnections;
this.connectionTimeout = connectionTimeout;
this.readTimeout = readTimeout;

this.retryBackoffMillis = retryBackoffMillis;
this.retryTimeoutSeconds = retryTimeoutSeconds;
this.maxRetryTimes = maxRetryTimes;
this.retryBackoffPolicy = retryBackoffPolicy;
}

public ConnectorClientConfig(StreamInput input) throws IOException {
Version streamInputVersion = input.getVersion();
this.maxConnections = input.readOptionalInt();
this.connectionTimeout = input.readOptionalInt();
this.readTimeout = input.readOptionalInt();
if(streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_RETRY)) {
this.retryBackoffMillis = input.readOptionalInt();
this.retryTimeoutSeconds = input.readOptionalInt();
this.maxRetryTimes = input.readOptionalInt();
if (input.readBoolean()) {
this.retryBackoffPolicy = RetryBackoffPolicy.from(input.readString());
}
}
}

public ConnectorClientConfig() {
this.maxConnections = MAX_CONNECTION_DEFAULT_VALUE;
this.connectionTimeout = CONNECTION_TIMEOUT_DEFAULT_VALUE;
this.readTimeout = READ_TIMEOUT_DEFAULT_VALUE;
this.retryBackoffMillis = RETRY_BACKOFF_MILLIS_DEFAULT_VALUE;
this.retryTimeoutSeconds = RETRY_TIMEOUT_SECONDS_DEFAULT_VALUE;
this.maxRetryTimes = MAX_RETRY_TIMES_DEFAULT_VALUE;
this.retryBackoffPolicy = RETRY_BACKOFF_POLICY_DEFAULT_VALUE;
}

@Override
public void writeTo(StreamOutput out) throws IOException {

Version streamOutputVersion = out.getVersion();
out.writeOptionalInt(maxConnections);
out.writeOptionalInt(connectionTimeout);
out.writeOptionalInt(readTimeout);
if(streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_RETRY)){
out.writeOptionalInt(retryBackoffMillis);
out.writeOptionalInt(retryTimeoutSeconds);
out.writeOptionalInt(maxRetryTimes);
if (Objects.nonNull(retryBackoffPolicy)) {
out.writeBoolean(true);
out.writeString(retryBackoffPolicy.name());
} else {
out.writeBoolean(false);
}
}
}

@Override
Expand All @@ -79,6 +126,18 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
if (readTimeout != null) {
builder.field(READ_TIMEOUT_FIELD, readTimeout);
}
if (retryBackoffMillis != null) {
builder.field(RETRY_BACKOFF_MILLIS_FIELD, retryBackoffMillis);
}
if (retryTimeoutSeconds != null) {
builder.field(RETRY_TIMEOUT_SECONDS_FIELD, retryTimeoutSeconds);
}
if (maxRetryTimes != null) {
builder.field(MAX_RETRY_TIMES_FIELD, maxRetryTimes);
}
if (retryBackoffPolicy != null) {
builder.field(RETRY_BACKOFF_POLICY_FIELD, retryBackoffPolicy.name().toLowerCase(Locale.ROOT));
}
return builder.endObject();
}

Expand All @@ -88,9 +147,13 @@ public static ConnectorClientConfig fromStream(StreamInput in) throws IOExceptio
}

public static ConnectorClientConfig parse(XContentParser parser) throws IOException {
Integer maxConnections = null;
Integer connectionTimeout = null;
Integer readTimeout = null;
Integer maxConnections = MAX_CONNECTION_DEFAULT_VALUE;
Integer connectionTimeout = CONNECTION_TIMEOUT_DEFAULT_VALUE;
Integer readTimeout = READ_TIMEOUT_DEFAULT_VALUE;
Integer retryBackoffMillis = RETRY_BACKOFF_MILLIS_DEFAULT_VALUE;
Integer retryTimeoutSeconds = RETRY_TIMEOUT_SECONDS_DEFAULT_VALUE;
Integer maxRetryTimes = MAX_RETRY_TIMES_DEFAULT_VALUE;
RetryBackoffPolicy retryBackoffPolicy = RETRY_BACKOFF_POLICY_DEFAULT_VALUE;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -107,6 +170,18 @@ public static ConnectorClientConfig parse(XContentParser parser) throws IOExcept
case READ_TIMEOUT_FIELD:
readTimeout = parser.intValue();
break;
case RETRY_BACKOFF_MILLIS_FIELD:
retryBackoffMillis = parser.intValue();
break;
case RETRY_TIMEOUT_SECONDS_FIELD:
retryTimeoutSeconds = parser.intValue();
break;
case MAX_RETRY_TIMES_FIELD:
maxRetryTimes = parser.intValue();
break;
case RETRY_BACKOFF_POLICY_FIELD:
retryBackoffPolicy = RetryBackoffPolicy.from(parser.text());
break;
default:
parser.skipChildren();
break;
Expand All @@ -116,6 +191,10 @@ public static ConnectorClientConfig parse(XContentParser parser) throws IOExcept
.maxConnections(maxConnections)
.connectionTimeout(connectionTimeout)
.readTimeout(readTimeout)
.retryBackoffMillis(retryBackoffMillis)
.retryTimeoutSeconds(retryTimeoutSeconds)
.maxRetryTimes(maxRetryTimes)
.retryBackoffPolicy(retryBackoffPolicy)
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector;

import java.util.Locale;

public enum RetryBackoffPolicy {
CONSTANT,
EXPONENTIAL_EQUAL_JITTER,
EXPONENTIAL_FULL_JITTER;

public static RetryBackoffPolicy from(String value) {
try {
return RetryBackoffPolicy.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Unsupported retry backoff policy");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public static Map<String, String> convertScriptStringToJsonString(Map<String, Ob
Map<String, String> parameterStringMap = new HashMap<>();
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
Map<String, Object> parametersMap = (Map<String, Object>) processedInput.get("parameters");
Map<String, Object> parametersMap = (Map<String, Object>) processedInput.getOrDefault("parameters", Map.of());
for (String key : parametersMap.keySet()) {
if (parametersMap.get(key) instanceof String) {
parameterStringMap.put(key, (String) parametersMap.get(key));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ public void toXContent_InternalConnector() throws IOException {
"\"pre_process_function\":\"connector.pre_process.openai.embedding\"," +
"\"post_process_function\":\"connector.post_process.openai.embedding\"}]," +
"\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," +
"\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000}}}",
"\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," +
"\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}",
mlModelContent);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import org.junit.Assert;
import org.junit.Test;
import org.opensearch.Version;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
Expand All @@ -23,6 +25,10 @@ public void writeTo_ReadFromStream() throws IOException {
.maxConnections(10)
.connectionTimeout(5000)
.readTimeout(3000)
.retryBackoffMillis(123)
.retryTimeoutSeconds(456)
.maxRetryTimes(789)
.retryBackoffPolicy(RetryBackoffPolicy.CONSTANT)
.build();

BytesStreamOutput output = new BytesStreamOutput();
Expand All @@ -32,25 +38,71 @@ public void writeTo_ReadFromStream() throws IOException {
Assert.assertEquals(config, readConfig);
}

@Test
public void writeTo_ReadFromStream_nullValues() throws IOException {
ConnectorClientConfig config = ConnectorClientConfig.builder()
.build();

BytesStreamOutput output = new BytesStreamOutput();
config.writeTo(output);
ConnectorClientConfig readConfig = new ConnectorClientConfig(output.bytes().streamInput());

Assert.assertEquals(config, readConfig);
}

@Test
public void writeTo_ReadFromStream_diffVersionThenNotProcessRetryOptions() throws IOException {
ConnectorClientConfig config = ConnectorClientConfig.builder()
.maxConnections(10)
.connectionTimeout(5000)
.readTimeout(3000)
.retryBackoffMillis(123)
.retryTimeoutSeconds(456)
.maxRetryTimes(789)
.retryBackoffPolicy(RetryBackoffPolicy.CONSTANT)
.build();

BytesStreamOutput output = new BytesStreamOutput();
output.setVersion(Version.V_2_14_0);
config.writeTo(output);
StreamInput input = output.bytes().streamInput();
input.setVersion(Version.V_2_14_0);
ConnectorClientConfig readConfig = ConnectorClientConfig.fromStream(input);

Assert.assertEquals(Integer.valueOf(10),readConfig.getMaxConnections());
Assert.assertEquals(Integer.valueOf(5000),readConfig.getConnectionTimeout());
Assert.assertEquals(Integer.valueOf(3000),readConfig.getReadTimeout());
Assert.assertNull(readConfig.getRetryBackoffMillis());
Assert.assertNull(readConfig.getRetryTimeoutSeconds());
Assert.assertNull(readConfig.getMaxRetryTimes());
Assert.assertNull(readConfig.getRetryBackoffPolicy());
}

@Test
public void toXContent() throws IOException {
ConnectorClientConfig config = ConnectorClientConfig.builder()
.maxConnections(10)
.connectionTimeout(5000)
.readTimeout(3000)
.retryBackoffMillis(123)
.retryTimeoutSeconds(456)
.maxRetryTimes(789)
.retryBackoffPolicy(RetryBackoffPolicy.CONSTANT)
.build();

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
config.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);

String expectedJson = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000}";
String expectedJson = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000," +
"\"retry_backoff_millis\":123,\"retry_timeout_seconds\":456,\"max_retry_times\":789,\"retry_backoff_policy\":\"constant\"}";
Assert.assertEquals(expectedJson, content);
}

@Test
public void parse() throws IOException {
String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000}";
String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000," +
"\"retry_backoff_millis\":123,\"retry_timeout_seconds\":456,\"max_retry_times\":789,\"retry_backoff_policy\":\"constant\"}";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();
Expand All @@ -60,6 +112,22 @@ public void parse() throws IOException {
Assert.assertEquals(Integer.valueOf(10), config.getMaxConnections());
Assert.assertEquals(Integer.valueOf(5000), config.getConnectionTimeout());
Assert.assertEquals(Integer.valueOf(3000), config.getReadTimeout());
Assert.assertEquals(Integer.valueOf(123), config.getRetryBackoffMillis());
Assert.assertEquals(Integer.valueOf(456), config.getRetryTimeoutSeconds());
Assert.assertEquals(Integer.valueOf(789), config.getMaxRetryTimes());
Assert.assertEquals(RetryBackoffPolicy.CONSTANT, config.getRetryBackoffPolicy());
}

@Test
public void parse_whenMalformedBackoffPolicy_thenFail() throws IOException {
String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000," +
"\"retry_backoff_millis\":123,\"retry_timeout_seconds\":456,\"max_retry_times\":789,\"retry_backoff_policy\":\"test\"}";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();

Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> ConnectorClientConfig.parse(parser));
Assert.assertEquals("Unsupported retry backoff policy", exception.getMessage());
}

@Test
Expand All @@ -69,6 +137,23 @@ public void testDefaultValues() {
Assert.assertNull(config.getMaxConnections());
Assert.assertNull(config.getConnectionTimeout());
Assert.assertNull(config.getReadTimeout());
Assert.assertNull(config.getRetryBackoffMillis());
Assert.assertNull(config.getRetryTimeoutSeconds());
Assert.assertNull(config.getMaxRetryTimes());
Assert.assertNull(config.getRetryBackoffPolicy());
}

@Test
public void testDefaultValuesInitByNewInstance() {
ConnectorClientConfig config = new ConnectorClientConfig();

Assert.assertEquals(Integer.valueOf(30),config.getMaxConnections());
Assert.assertEquals(Integer.valueOf(30000),config.getConnectionTimeout());
Assert.assertEquals(Integer.valueOf(30000),config.getReadTimeout());
Assert.assertEquals(Integer.valueOf(200),config.getRetryBackoffMillis());
Assert.assertEquals(Integer.valueOf(30),config.getRetryTimeoutSeconds());
Assert.assertEquals(Integer.valueOf(0),config.getMaxRetryTimes());
Assert.assertEquals(RetryBackoffPolicy.CONSTANT, config.getRetryBackoffPolicy());
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ public class HttpConnectorTest {
"\"pre_process_function\":\"connector.pre_process.openai.embedding\"," +
"\"post_process_function\":\"connector.post_process.openai.embedding\"}]," +
"\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," +
"\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000}}";
"\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," +
"\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}";

@Before
public void setUp() {
Expand Down Expand Up @@ -293,7 +294,7 @@ public static HttpConnector createHttpConnectorWithRequestBody(String requestBod
Map<String, String> credential = new HashMap<>();
credential.put("key", "test_key_value");

ConnectorClientConfig httpClientConfig = new ConnectorClientConfig(30, 30000, 30000);
ConnectorClientConfig httpClientConfig = new ConnectorClientConfig(30, 30000, 30000, 10, 10, -1, RetryBackoffPolicy.CONSTANT);

HttpConnector connector = HttpConnector.builder()
.name("test_connector_name")
Expand Down
Loading

0 comments on commit 399825f

Please sign in to comment.