Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 0.5] Add rate limiter for bulk request #571

Merged
merged 1 commit into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,13 @@ In the index mapping, the `_meta` and `properties`field stores meta and schema i
- `spark.datasource.flint.auth.username`: basic auth username.
- `spark.datasource.flint.auth.password`: basic auth password.
- `spark.datasource.flint.region`: default is us-west-2. only been used when auth=sigv4
- `spark.datasource.flint.customAWSCredentialsProvider`: default is empty.
- `spark.datasource.flint.customAWSCredentialsProvider`: default is empty.
- `spark.datasource.flint.write.id_name`: no default value.
- `spark.datasource.flint.ignore.id_column` : default value is true.
- `spark.datasource.flint.write.batch_size`: "The number of documents written to Flint in a single batch request. Default value is Integer.MAX_VALUE.
- `spark.datasource.flint.write.batch_bytes`: The approximately amount of data in bytes written to Flint in a single batch request. The actual data write to OpenSearch may more than it. Default value is 1mb. The writing process checks after each document whether the total number of documents (docCount) has reached batch_size or the buffer size has surpassed batch_bytes. If either condition is met, the current batch is flushed and the document count resets to zero.
- `spark.datasource.flint.write.refresh_policy`: default value is false. valid values [NONE(false), IMMEDIATE(true), WAIT_UNTIL(wait_for)]
- `spark.datasource.flint.write.bulkRequestRateLimitPerNode`: [Experimental] Rate limit(request/sec) for bulk request per worker node. Only accept integer value. To reduce the traffic less than 1 req/sec, batch_bytes or batch_size should be reduced. Default value is 0, which disables rate limit.
- `spark.datasource.flint.read.scroll_size`: default value is 100.
- `spark.datasource.flint.read.scroll_duration`: default value is 5 minutes. scroll context keep alive duration.
- `spark.datasource.flint.retry.max_retries`: max retries on failed HTTP request. default value is 3. Use 0 to disable retry.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.opensearch.client.transport.rest_client.RestClientTransport;

import java.io.IOException;
import org.opensearch.flint.core.storage.BulkRequestRateLimiter;

import static org.opensearch.flint.core.metrics.MetricConstants.OS_READ_OP_METRIC_PREFIX;
import static org.opensearch.flint.core.metrics.MetricConstants.OS_WRITE_OP_METRIC_PREFIX;
Expand All @@ -47,6 +48,7 @@
*/
public class RestHighLevelClientWrapper implements IRestHighLevelClient {
private final RestHighLevelClient client;
private final BulkRequestRateLimiter rateLimiter;

private final static JacksonJsonpMapper JACKSON_MAPPER = new JacksonJsonpMapper();

Expand All @@ -55,13 +57,21 @@ public class RestHighLevelClientWrapper implements IRestHighLevelClient {
*
* @param client the RestHighLevelClient instance to wrap
*/
public RestHighLevelClientWrapper(RestHighLevelClient client) {
public RestHighLevelClientWrapper(RestHighLevelClient client, BulkRequestRateLimiter rateLimiter) {
this.client = client;
this.rateLimiter = rateLimiter;
}

@Override
public BulkResponse bulk(BulkRequest bulkRequest, RequestOptions options) throws IOException {
return execute(OS_WRITE_OP_METRIC_PREFIX, () -> client.bulk(bulkRequest, options));
return execute(OS_WRITE_OP_METRIC_PREFIX, () -> {
try {
rateLimiter.acquirePermit();
return client.bulk(bulkRequest, options);
} catch (InterruptedException e) {
throw new RuntimeException("rateLimiter.acquirePermit was interrupted.", e);
}
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ public class FlintOptions implements Serializable {

public static final String DEFAULT_SUPPORT_SHARD = "true";

public static final String BULK_REQUEST_RATE_LIMIT_PER_NODE = "bulkRequestRateLimitPerNode";
public static final String DEFAULT_BULK_REQUEST_RATE_LIMIT_PER_NODE = "0";

public FlintOptions(Map<String, String> options) {
this.options = options;
this.retryOptions = new FlintRetryOptions(options);
Expand Down Expand Up @@ -197,4 +200,8 @@ public boolean supportShard() {
return options.getOrDefault(SUPPORT_SHARD, DEFAULT_SUPPORT_SHARD).equalsIgnoreCase(
DEFAULT_SUPPORT_SHARD);
}

public long getBulkRequestRateLimitPerNode() {
return Long.parseLong(options.getOrDefault(BULK_REQUEST_RATE_LIMIT_PER_NODE, DEFAULT_BULK_REQUEST_RATE_LIMIT_PER_NODE));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.opensearch.flint.core.storage;

import dev.failsafe.RateLimiter;
import java.time.Duration;
import java.util.logging.Logger;
import org.opensearch.flint.core.FlintOptions;

public class BulkRequestRateLimiter {
private static final Logger LOG = Logger.getLogger(BulkRequestRateLimiter.class.getName());
private RateLimiter<Void> rateLimiter;

public BulkRequestRateLimiter(FlintOptions flintOptions) {
long bulkRequestRateLimitPerNode = flintOptions.getBulkRequestRateLimitPerNode();
if (bulkRequestRateLimitPerNode > 0) {
LOG.info("Setting rate limit for bulk request to " + bulkRequestRateLimitPerNode + "/sec");
this.rateLimiter = RateLimiter.<Void>smoothBuilder(
flintOptions.getBulkRequestRateLimitPerNode(),
Duration.ofSeconds(1)).build();
} else {
LOG.info("Rate limit for bulk request was not set.");
}
}

// Wait so it won't exceed rate limit. Does nothing if rate limit is not set.
public void acquirePermit() throws InterruptedException {
if (rateLimiter != null) {
this.rateLimiter.acquirePermit();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package org.opensearch.flint.core.storage;

import org.opensearch.flint.core.FlintOptions;

/**
* Hold shared instance of BulkRequestRateLimiter. This class is introduced to make
* BulkRequestRateLimiter testable and share single instance.
*/
public class BulkRequestRateLimiterHolder {

private static BulkRequestRateLimiter instance;

private BulkRequestRateLimiterHolder() {}

public synchronized static BulkRequestRateLimiter getBulkRequestRateLimiter(
FlintOptions flintOptions) {
if (instance == null) {
instance = new BulkRequestRateLimiter(flintOptions);
}
return instance;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ public static RestHighLevelClient createRestHighLevelClient(FlintOptions options
}

public static IRestHighLevelClient createClient(FlintOptions options) {
return new RestHighLevelClientWrapper(createRestHighLevelClient(options));
return new RestHighLevelClientWrapper(createRestHighLevelClient(options),
BulkRequestRateLimiterHolder.getBulkRequestRateLimiter(options));
}

private static RestClientBuilder configureSigV4Auth(RestClientBuilder restClientBuilder, FlintOptions options) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package org.opensearch.flint.core.storage;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import com.google.common.collect.ImmutableMap;
import org.junit.jupiter.api.Test;
import org.opensearch.flint.core.FlintOptions;

class BulkRequestRateLimiterHolderTest {
FlintOptions flintOptions = new FlintOptions(ImmutableMap.of());
@Test
public void getBulkRequestRateLimiter() {
BulkRequestRateLimiter instance0 = BulkRequestRateLimiterHolder.getBulkRequestRateLimiter(flintOptions);
BulkRequestRateLimiter instance1 = BulkRequestRateLimiterHolder.getBulkRequestRateLimiter(flintOptions);

assertNotNull(instance0);
assertEquals(instance0, instance1);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package org.opensearch.flint.core.storage;


import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.google.common.collect.ImmutableMap;
import org.junit.jupiter.api.Test;
import org.opensearch.flint.core.FlintOptions;

class BulkRequestRateLimiterTest {
FlintOptions flintOptionsWithRateLimit = new FlintOptions(ImmutableMap.of(FlintOptions.BULK_REQUEST_RATE_LIMIT_PER_NODE, "1"));
FlintOptions flintOptionsWithoutRateLimit = new FlintOptions(ImmutableMap.of(FlintOptions.BULK_REQUEST_RATE_LIMIT_PER_NODE, "0"));

@Test
void acquirePermitWithRateConfig() throws Exception {
BulkRequestRateLimiter limiter = new BulkRequestRateLimiter(flintOptionsWithRateLimit);

assertTrue(timer(() -> {
limiter.acquirePermit();
limiter.acquirePermit();
}) >= 1000);
}

@Test
void acquirePermitWithoutRateConfig() throws Exception {
BulkRequestRateLimiter limiter = new BulkRequestRateLimiter(flintOptionsWithoutRateLimit);

assertTrue(timer(() -> {
limiter.acquirePermit();
limiter.acquirePermit();
}) < 100);
}

private interface Procedure {
void run() throws Exception;
}

private long timer(Procedure procedure) throws Exception {
long start = System.currentTimeMillis();
procedure.run();
long end = System.currentTimeMillis();
return end - start;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ object FlintSparkConf {
.doc("max retries on failed HTTP request, 0 means retry is disabled, default is 3")
.createWithDefault(String.valueOf(FlintRetryOptions.DEFAULT_MAX_RETRIES))

val BULK_REQUEST_RATE_LIMIT_PER_NODE =
FlintConfig(s"spark.datasource.flint.${FlintOptions.BULK_REQUEST_RATE_LIMIT_PER_NODE}")
.datasourceOption()
.doc("[Experimental] Rate limit (requests/sec) for bulk request per worker node. Rate won't be limited by default")
.createWithDefault(FlintOptions.DEFAULT_BULK_REQUEST_RATE_LIMIT_PER_NODE)

val RETRYABLE_HTTP_STATUS_CODES =
FlintConfig(s"spark.datasource.flint.${FlintRetryOptions.RETRYABLE_HTTP_STATUS_CODES}")
.datasourceOption()
Expand Down Expand Up @@ -275,6 +281,7 @@ case class FlintSparkConf(properties: JMap[String, String]) extends Serializable
AUTH,
MAX_RETRIES,
RETRYABLE_HTTP_STATUS_CODES,
BULK_REQUEST_RATE_LIMIT_PER_NODE,
REGION,
CUSTOM_AWS_CREDENTIALS_PROVIDER,
SERVICE_NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import java.util.Optional

import scala.collection.JavaConverters._

import org.opensearch.flint.core.FlintOptions
import org.opensearch.flint.core.http.FlintRetryOptions._
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

Expand Down Expand Up @@ -63,6 +64,16 @@ class FlintSparkConfSuite extends FlintSuite {
retryOptions.getRetryableExceptionClassNames.get() shouldBe "java.net.ConnectException"
}

test("test bulkRequestRateLimitPerNode default value") {
val options = FlintSparkConf().flintOptions()
options.getBulkRequestRateLimitPerNode shouldBe 0
}

test("test specified bulkRequestRateLimitPerNode") {
val options = FlintSparkConf(Map("bulkRequestRateLimitPerNode" -> "5").asJava).flintOptions()
options.getBulkRequestRateLimitPerNode shouldBe 5
}

test("test metadata access AWS credentials provider option") {
withSparkConf("spark.metadata.accessAWSCredentialsProvider") {
spark.conf.set(
Expand Down
Loading