Refactor to use interface so default behavior is unaffected. Add tests
Signed-off-by: Chase Engelbrecht <[email protected]>
engechas committed Mar 7, 2024
1 parent b4b8279 commit afd6527
Showing 8 changed files with 319 additions and 38 deletions.
ConnectionConfiguration.java
@@ -91,7 +91,7 @@ public class ConnectionConfiguration {
public static final String NETWORK_POLICY_NAME = "network_policy_name";
public static final String VPCE_ID = "vpce_id";
public static final String REQUEST_COMPRESSION_ENABLED = "enable_request_compression";
public static final String CLIENTS = "clients";
public static final String CONCURRENT_REQUESTS = "concurrent_requests";

/**
* The valid port range per https://tools.ietf.org/html/rfc6335.
@@ -118,7 +118,7 @@ public class ConnectionConfiguration {
private final String serverlessCollectionName;
private final String serverlessVpceId;
private final boolean requestCompressionEnabled;
private final Integer clients;
private final Integer concurrentRequests;

List<String> getHosts() {
return hosts;
@@ -180,8 +180,8 @@ boolean isRequestCompressionEnabled() {
return requestCompressionEnabled;
}

Integer getClients() {
return clients;
Integer getConcurrentRequests() {
return concurrentRequests;
}

private ConnectionConfiguration(final Builder builder) {
@@ -204,7 +204,7 @@ private ConnectionConfiguration(final Builder builder) {
this.serverlessVpceId = builder.serverlessVpceId;
this.requestCompressionEnabled = builder.requestCompressionEnabled;
this.pipelineName = builder.pipelineName;
this.clients = builder.clients;
this.concurrentRequests = builder.concurrentRequests;
}

public static ConnectionConfiguration readConnectionConfiguration(final PluginSetting pluginSetting){
@@ -282,8 +282,8 @@ public static ConnectionConfiguration readConnectionConfiguration(final PluginSetting pluginSetting){
REQUEST_COMPRESSION_ENABLED, !DistributionVersion.ES6.equals(distributionVersion));
builder = builder.withRequestCompressionEnabled(requestCompressionEnabled);

final Integer clients = pluginSetting.getIntegerOrDefault(CLIENTS, 1);
builder = builder.withClients(clients);
final Integer concurrentRequests = pluginSetting.getIntegerOrDefault(CONCURRENT_REQUESTS, -1);
builder = builder.withConcurrentRequests(concurrentRequests);

return builder.build();
}
@@ -518,7 +518,7 @@ public static class Builder {
private String serverlessCollectionName;
private String serverlessVpceId;
private boolean requestCompressionEnabled;
private Integer clients;
private Integer concurrentRequests;

private void validateStsRoleArn(final String awsStsRoleArn) {
final Arn arn = getArn(awsStsRoleArn);
@@ -648,8 +648,8 @@ public Builder withRequestCompressionEnabled(final boolean requestCompressionEnabled) {
return this;
}

public Builder withClients(final Integer clients) {
this.clients = clients;
public Builder withConcurrentRequests(final Integer concurrentRequests) {
this.concurrentRequests = concurrentRequests;
return this;
}

OpenSearchSink.java
@@ -49,8 +49,11 @@
import org.opensearch.dataprepper.plugins.sink.opensearch.bulk.BulkApiWrapper;
import org.opensearch.dataprepper.plugins.sink.opensearch.bulk.BulkApiWrapperFactory;
import org.opensearch.dataprepper.plugins.sink.opensearch.bulk.BulkOperationWriter;
import org.opensearch.dataprepper.plugins.sink.opensearch.bulk.InlineRequestSender;
import org.opensearch.dataprepper.plugins.sink.opensearch.bulk.JavaClientAccumulatingCompressedBulkRequest;
import org.opensearch.dataprepper.plugins.sink.opensearch.bulk.JavaClientAccumulatingUncompressedBulkRequest;
import org.opensearch.dataprepper.plugins.sink.opensearch.bulk.ConcurrentRequestSender;
import org.opensearch.dataprepper.plugins.sink.opensearch.bulk.RequestSender;
import org.opensearch.dataprepper.plugins.sink.opensearch.bulk.SerializedJson;
import org.opensearch.dataprepper.plugins.sink.opensearch.dlq.FailedBulkOperation;
import org.opensearch.dataprepper.plugins.sink.opensearch.dlq.FailedBulkOperationConverter;
@@ -80,11 +83,6 @@
import java.util.Optional;
import java.util.StringJoiner;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Supplier;
import java.util.stream.Collectors;
@@ -192,7 +190,12 @@ public OpenSearchSink(final PluginSetting pluginSetting,
dlqProvider = pluginFactory.loadPlugin(DlqProvider.class, dlqPluginSetting);
}

this.requestSender = new RequestSender(openSearchSinkConfig.getConnectionConfiguration().getClients());
final int concurrentRequests = openSearchSinkConfig.getConnectionConfiguration().getConcurrentRequests();
if (concurrentRequests > 0) {
this.requestSender = new ConcurrentRequestSender(concurrentRequests);
} else {
this.requestSender = new InlineRequestSender();
}
}

@Override
@@ -499,10 +502,10 @@ SerializedJson getDocument(final Event event) {
}

private void flushBatch(AccumulatingBulkRequest accumulatingBulkRequest) {
requestSender.sendRequest(() -> doFlushBatch(accumulatingBulkRequest));
requestSender.sendRequest(this::doFlushBatch, accumulatingBulkRequest);
}

private Void doFlushBatch(AccumulatingBulkRequest accumulatingBulkRequest) {
private void doFlushBatch(AccumulatingBulkRequest accumulatingBulkRequest) {
bulkRequestTimer.record(() -> {
try {
LOG.debug("Sending data to OpenSearch");
@@ -514,8 +517,6 @@ private Void doFlushBatch(AccumulatingBulkRequest accumulatingBulkRequest) {
Thread.currentThread().interrupt();
}
});

return null;
}

private void logFailureForBulkRequests(final List<FailedBulkOperation> failedBulkOperations, final Throwable failure) {
ConcurrentRequestSender.java (previously RequestSender.java)
@@ -1,5 +1,6 @@
package org.opensearch.dataprepper.plugins.sink.opensearch;
package org.opensearch.dataprepper.plugins.sink.opensearch.bulk;

import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -9,49 +10,63 @@
import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;

public class RequestSender {
private static final Logger LOG = LoggerFactory.getLogger(RequestSender.class);
public class ConcurrentRequestSender implements RequestSender {
private static final Logger LOG = LoggerFactory.getLogger(ConcurrentRequestSender.class);

private final List<Future<Void>> pendingRequestFutures;
private final ExecutorService requestExecutor;
private final CompletionService<Void> completionService;
private final int concurrentRequestCount;
private final ReentrantLock reentrantLock;

public RequestSender(final int concurrentRequestCount) {
public ConcurrentRequestSender(final int concurrentRequestCount) {
this.concurrentRequestCount = concurrentRequestCount;
pendingRequestFutures = new ArrayList<>();
requestExecutor = Executors.newFixedThreadPool(concurrentRequestCount);
completionService = new ExecutorCompletionService(requestExecutor);
completionService = new ExecutorCompletionService(Executors.newFixedThreadPool(concurrentRequestCount));
reentrantLock = new ReentrantLock();
}

public void sendRequest(final Callable<Void> requestRunnable) {
@VisibleForTesting
ConcurrentRequestSender(final int concurrentRequestCount, final CompletionService<Void> completionService) {
this.concurrentRequestCount = concurrentRequestCount;
pendingRequestFutures = new ArrayList<>();
this.completionService = completionService;
reentrantLock = new ReentrantLock();
}

@Override
public void sendRequest(final Consumer<AccumulatingBulkRequest> requestConsumer, final AccumulatingBulkRequest request) {
reentrantLock.lock();

if (pendingRequestFutures.size() >= concurrentRequestCount) {
if (isRequestQueueFull()) {
waitForRequestSlot();
}

final Future<Void> future = completionService.submit(requestRunnable);
final Future<Void> future = completionService.submit(convertConsumerIntoCallable(requestConsumer, request));
pendingRequestFutures.add(future);

reentrantLock.unlock();
}

private Callable<Void> convertConsumerIntoCallable(final Consumer<AccumulatingBulkRequest> requestConsumer, final AccumulatingBulkRequest request) {
return () -> {
requestConsumer.accept(request);
return null;
};
}

private void waitForRequestSlot() {
do {
checkFutureCompletion();
if (isRequestQueueFull()) {
try {
LOG.info("Request queue is full, waiting for slot to free up");
LOG.debug("Request queue is full, waiting for slot to free up");
completionService.take();
} catch (InterruptedException e) {
} catch (final Exception e) {
LOG.error("Interrupted while waiting for future completion");
}
}
@@ -70,12 +85,9 @@ private void checkFutureCompletion() {
future.get();
} catch (final Exception e) {
LOG.error("Indexing future was cancelled", e);
iterator.remove();
return;
}
}

if (future.isDone()) {
iterator.remove();
} else if (future.isDone()) {
iterator.remove();
}
}
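
The package-private constructor marked @VisibleForTesting takes a CompletionService directly, which is what lets this class be unit tested without starting a real thread pool. The commit's own test files are not loaded on this page, so the following is only a sketch of how that constructor might be exercised. It assumes JUnit 5 and Mockito are on the test classpath and lives in the same package so it can reach the package-private constructor; the class, method, and variable names are hypothetical.

package org.opensearch.dataprepper.plugins.sink.opensearch.bulk;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.concurrent.Callable;
import java.util.concurrent.CompletionService;
import java.util.concurrent.Future;

import org.junit.jupiter.api.Test;

class ConcurrentRequestSenderSketchTest {

    @Test
    void sendRequestSubmitsWrappedConsumerToCompletionService() {
        // Hypothetical mocks; the real tests added by this commit may look different.
        final CompletionService<Void> completionService = mock(CompletionService.class);
        final Future<Void> future = mock(Future.class);
        when(completionService.submit(any(Callable.class))).thenReturn(future);

        final AccumulatingBulkRequest request = mock(AccumulatingBulkRequest.class);
        final ConcurrentRequestSender objectUnderTest = new ConcurrentRequestSender(1, completionService);

        // The consumer is wrapped into a Callable<Void> and handed to the injected completion service.
        objectUnderTest.sendRequest(ignored -> { }, request);

        verify(completionService).submit(any(Callable.class));
    }
}
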
InlineRequestSender.java (new file)
@@ -0,0 +1,11 @@
package org.opensearch.dataprepper.plugins.sink.opensearch.bulk;

import java.util.function.Consumer;

public class InlineRequestSender implements RequestSender {

@Override
public void sendRequest(final Consumer<AccumulatingBulkRequest> requestConsumer, final AccumulatingBulkRequest request) {
requestConsumer.accept(request);
}
}
RequestSender.java (new file)
@@ -0,0 +1,13 @@
package org.opensearch.dataprepper.plugins.sink.opensearch.bulk;

import java.util.function.Consumer;

public interface RequestSender {
/**
* Executes the provided request with the provided consumer
*
* @param requestConsumer - the consumer function of the request that performs the work to execute the request
* @param request - the request to be consumed
*/
void sendRequest(Consumer<AccumulatingBulkRequest> requestConsumer, AccumulatingBulkRequest request);
}
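
Taken together, the two implementations keep the calling code identical; only construction differs, driven by the concurrent_requests setting. The sketch below is illustration only, not part of the commit: it mirrors the selection OpenSearchSink performs, with hypothetical class and method names, to make the contract concrete. With the default of -1 the flush runs inline on the calling thread, exactly as before this change; a positive value hands it to a bounded pool.

package org.opensearch.dataprepper.plugins.sink.opensearch.bulk;

import java.util.function.Consumer;

// Illustration only: hypothetical helper mirroring the selection done in OpenSearchSink.
class RequestSenderUsageSketch {

    // A non-positive value (the default is -1) preserves the previous synchronous behavior.
    static RequestSender fromConcurrentRequests(final int concurrentRequests) {
        return concurrentRequests > 0
                ? new ConcurrentRequestSender(concurrentRequests)
                : new InlineRequestSender();
    }

    static void flush(final RequestSender sender,
                      final Consumer<AccumulatingBulkRequest> doFlushBatch,
                      final AccumulatingBulkRequest bulkRequest) {
        // InlineRequestSender runs doFlushBatch on the calling thread;
        // ConcurrentRequestSender submits it to a fixed-size pool and blocks the caller
        // only when all of its slots are occupied by in-flight requests.
        sender.sendRequest(doFlushBatch, bulkRequest);
    }
}
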
ConnectionConfiguration test class
@@ -86,6 +86,7 @@ void testReadConnectionConfigurationDefault() {
assertNull(connectionConfiguration.getSocketTimeout());
assertEquals(TEST_PIPELINE_NAME, connectionConfiguration.getPipelineName());
assertTrue(connectionConfiguration.isRequestCompressionEnabled());
assertEquals(-1, connectionConfiguration.getConcurrentRequests());
}

@Test
@@ -648,6 +649,21 @@ void testCreateClient_WithConnectionConfigurationBuilder_ProxyOptionalObjectShou
client.close();
}

@Test
void testConcurrentRequestsSetting() {
final Map<String, Object> metadata = new HashMap<>();
metadata.put("hosts", TEST_HOSTS);
metadata.put("username", UUID.randomUUID().toString());
metadata.put("password", UUID.randomUUID().toString());
metadata.put("connect_timeout", 1);
metadata.put("socket_timeout", 1);
metadata.put("concurrent_requests", 32);
final PluginSetting pluginSetting = getPluginSettingByConfigurationMetadata(metadata);
final ConnectionConfiguration connectionConfiguration = ConnectionConfiguration.readConnectionConfiguration(pluginSetting);

assertThat(connectionConfiguration.getConcurrentRequests(), equalTo(32));
}

private PluginSetting generatePluginSetting(
final List<String> hosts, final String username, final String password,
final Integer connectTimeout, final Integer socketTimeout, final boolean awsSigv4, final String awsRegion,
(The remaining changed files were not loaded on this page.)