From 28590236243c0db04d8f8e4498743de3566dff79 Mon Sep 17 00:00:00 2001 From: Taylor Gray Date: Thu, 18 Apr 2024 11:09:26 -0500 Subject: [PATCH] Change s3 sink client to async client (#4425) Signed-off-by: Taylor Gray --- .../plugins/sink/s3/S3SinkServiceIT.java | 8 +- .../plugins/codec/parquet/S3OutputStream.java | 120 ++++--- .../plugins/sink/s3/ClientFactory.java | 13 +- .../dataprepper/plugins/sink/s3/S3Sink.java | 6 +- .../plugins/sink/s3/S3SinkService.java | 117 +++--- .../plugins/sink/s3/accumulator/Buffer.java | 6 +- .../sink/s3/accumulator/BufferFactory.java | 4 +- .../sink/s3/accumulator/BufferUtilities.java | 65 ++-- .../sink/s3/accumulator/CodecBuffer.java | 7 +- .../s3/accumulator/CodecBufferFactory.java | 4 +- .../s3/accumulator/CompressionBuffer.java | 7 +- .../accumulator/CompressionBufferFactory.java | 4 +- .../sink/s3/accumulator/InMemoryBuffer.java | 17 +- .../s3/accumulator/InMemoryBufferFactory.java | 4 +- .../sink/s3/accumulator/LocalFileBuffer.java | 22 +- .../accumulator/LocalFileBufferFactory.java | 4 +- .../sink/s3/accumulator/MultipartBuffer.java | 7 +- .../accumulator/MultipartBufferFactory.java | 4 +- .../sink/s3/grouping/S3GroupManager.java | 6 +- .../codec/parquet/S3OutputStreamTest.java | 107 +++++- .../plugins/sink/s3/S3SinkServiceTest.java | 337 ++++++++---------- .../s3/accumulator/BufferUtilitiesTest.java | 95 ++++- .../sink/s3/accumulator/CodecBufferTest.java | 11 +- .../CompressionBufferFactoryTest.java | 4 +- .../s3/accumulator/CompressionBufferTest.java | 11 +- .../s3/accumulator/InMemoryBufferTest.java | 72 ++-- .../s3/accumulator/LocalFileBufferTest.java | 72 ++-- .../sink/s3/grouping/S3GroupManagerTest.java | 4 +- 28 files changed, 665 insertions(+), 473 deletions(-) diff --git a/data-prepper-plugins/s3-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkServiceIT.java b/data-prepper-plugins/s3-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkServiceIT.java index 9a958e035d..a9ab424eee 100644 --- a/data-prepper-plugins/s3-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkServiceIT.java +++ b/data-prepper-plugins/s3-sink/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkServiceIT.java @@ -64,6 +64,7 @@ import org.opensearch.dataprepper.plugins.sink.s3.grouping.S3GroupManager; import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; @@ -104,6 +105,8 @@ class S3SinkServiceIT { private static final String PATH_PREFIX = UUID.randomUUID() + "/%{yyyy}/%{MM}/%{dd}/"; private static final int numberOfRecords = 2; private S3Client s3Client; + + private S3AsyncClient s3AsyncClient; private String bucketName; private String s3region; private ParquetOutputCodecConfig parquetOutputCodecConfig; @@ -152,6 +155,7 @@ public void setUp() { s3region = System.getProperty("tests.s3sink.region"); s3Client = S3Client.builder().region(Region.of(s3region)).build(); + s3AsyncClient = S3AsyncClient.builder().region(Region.of(s3region)).build(); bucketName = System.getProperty("tests.s3sink.bucket"); bufferFactory = new InMemoryBufferFactory(); @@ -266,9 +270,9 @@ void verify_flushed_records_into_s3_bucketNewLine_with_compression() throws IOEx private S3SinkService createObjectUnderTest() { OutputCodecContext codecContext = new OutputCodecContext("Tag", Collections.emptyList(), Collections.emptyList()); final S3GroupIdentifierFactory groupIdentifierFactory = new S3GroupIdentifierFactory(keyGenerator, expressionEvaluator, s3SinkConfig); - s3GroupManager = new S3GroupManager(s3SinkConfig, groupIdentifierFactory, bufferFactory, codecFactory, s3Client); + s3GroupManager = new S3GroupManager(s3SinkConfig, groupIdentifierFactory, bufferFactory, codecFactory, s3AsyncClient); - return new S3SinkService(s3SinkConfig, codecContext, s3Client, keyGenerator, Duration.ofSeconds(5), pluginMetrics, s3GroupManager); + return new S3SinkService(s3SinkConfig, codecContext, Duration.ofSeconds(5), pluginMetrics, s3GroupManager); } private int gets3ObjectCount() { diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/codec/parquet/S3OutputStream.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/codec/parquet/S3OutputStream.java index 8e2d2aa68f..1c71b9e6db 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/codec/parquet/S3OutputStream.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/codec/parquet/S3OutputStream.java @@ -8,22 +8,29 @@ import org.apache.parquet.io.PositionOutputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.core.sync.RequestBody; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.NoSuchBucketException; -import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.services.s3.model.UploadPartResponse; import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStream; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.Consumer; import java.util.function.Supplier; public class S3OutputStream extends PositionOutputStream { @@ -51,7 +58,7 @@ public class S3OutputStream extends PositionOutputStream { */ private final byte[] buf; - private final S3Client s3Client; + private final S3AsyncClient s3Client; /** * Collection of the etags for the parts that have been uploaded */ @@ -74,6 +81,8 @@ public class S3OutputStream extends PositionOutputStream { */ private final String defaultBucket; + private final ExecutorService executorService; + /** * Creates a new S3 OutputStream * @@ -81,7 +90,7 @@ public class S3OutputStream extends PositionOutputStream { * @param bucketSupplier name of the bucket * @param keySupplier path within the bucket */ - public S3OutputStream(final S3Client s3Client, + public S3OutputStream(final S3AsyncClient s3Client, final Supplier bucketSupplier, final Supplier keySupplier, final String defaultBucket) { @@ -93,13 +102,18 @@ public S3OutputStream(final S3Client s3Client, etags = new ArrayList<>(); open = true; this.defaultBucket = defaultBucket; + this.executorService = Executors.newSingleThreadExecutor(); } @Override public void write(int b) { assertOpen(); if (position >= buf.length) { - flushBufferAndRewind(); + try { + flushBufferAndRewind(); + } catch (ExecutionException | InterruptedException e) { + throw new RuntimeException(e); + } } buf[position++] = (byte) b; } @@ -132,7 +146,12 @@ public void write(byte[] byteArray, int o, int l) { while (len > (size = buf.length - position)) { System.arraycopy(byteArray, ofs, buf, position, size); position += size; - flushBufferAndRewind(); + try { + flushBufferAndRewind(); + } catch (ExecutionException | InterruptedException e) { + throw new RuntimeException(e); + } + ofs += size; len -= size; } @@ -147,36 +166,48 @@ public void write(byte[] byteArray, int o, int l) { public void flush() { } - @Override - public void close() { + public CompletableFuture close(final Consumer runOnCompletion, final Consumer runOnError) { if (open) { open = false; - possiblyStartMultipartUpload(); - if (position > 0) { - uploadPart(); - } + try { + possiblyStartMultipartUpload(); + + if (position > 0) { + uploadPart(); + } + + CompletedPart[] completedParts = new CompletedPart[etags.size()]; + for (int i = 0; i < etags.size(); i++) { + completedParts[i] = CompletedPart.builder() + .eTag(etags.get(i)) + .partNumber(i + 1) + .build(); + } - CompletedPart[] completedParts = new CompletedPart[etags.size()]; - for (int i = 0; i < etags.size(); i++) { - completedParts[i] = CompletedPart.builder() - .eTag(etags.get(i)) - .partNumber(i + 1) + LOG.debug("Completing S3 multipart upload with {} parts.", completedParts.length); + + CompletedMultipartUpload completedMultipartUpload = CompletedMultipartUpload.builder() + .parts(completedParts) .build(); - } + CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .multipartUpload(completedMultipartUpload) + .build(); + CompletableFuture multipartUploadResponseCompletableFuture = s3Client.completeMultipartUpload(completeMultipartUploadRequest); - LOG.debug("Completing S3 multipart upload with {} parts.", completedParts.length); - - CompletedMultipartUpload completedMultipartUpload = CompletedMultipartUpload.builder() - .parts(completedParts) - .build(); - CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder() - .bucket(bucket) - .key(key) - .uploadId(uploadId) - .multipartUpload(completedMultipartUpload) - .build(); - s3Client.completeMultipartUpload(completeMultipartUploadRequest); + multipartUploadResponseCompletableFuture.join(); + + runOnCompletion.accept(true); + return multipartUploadResponseCompletableFuture; + } catch (final Exception e) { + runOnError.accept(e); + runOnCompletion.accept(false); + } } + + return null; } public String getKey() { @@ -189,7 +220,7 @@ private void assertOpen() { } } - private void flushBufferAndRewind() { + private void flushBufferAndRewind() throws ExecutionException, InterruptedException { possiblyStartMultipartUpload(); uploadPart(); position = 0; @@ -200,10 +231,11 @@ private void possiblyStartMultipartUpload() { try { createMultipartUpload(); - } catch (final S3Exception e) { - if (defaultBucket != null && (e instanceof NoSuchBucketException || e.getMessage().contains(ACCESS_DENIED))) { + } catch (final CompletionException e) { + if (defaultBucket != null && (e.getCause() != null && + (e.getCause() instanceof NoSuchBucketException || (e.getCause().getMessage() != null && e.getCause().getMessage().contains(ACCESS_DENIED))))) { bucket = defaultBucket; - LOG.warn("Bucket {} could not be accessed to create multi-part upload, attempting to create multi-part upload to default_bucket {}", bucket, defaultBucket); + LOG.warn("Bucket {} could not be accessed to create multi-part upload, attempting to create multi-part upload to default_bucket {}. Error: {}", bucket, defaultBucket, e.getCause().getMessage()); createMultipartUpload(); } else { throw e; @@ -223,12 +255,17 @@ private void uploadPart() { .partNumber(partNumber) .contentLength((long) position) .build(); - RequestBody requestBody = RequestBody.fromInputStream(new ByteArrayInputStream(buf, 0, position), - position); + + final InputStream inputStream = new ByteArrayInputStream(buf, 0, position); + + AsyncRequestBody asyncRequestBody = AsyncRequestBody.fromInputStream(inputStream, (long) position, executorService); LOG.debug("Writing {} bytes to S3 multipart part number {}.", buf.length, partNumber); - UploadPartResponse uploadPartResponse = s3Client.uploadPart(uploadRequest, requestBody); + CompletableFuture uploadPartResponseFuture = s3Client.uploadPart(uploadRequest, asyncRequestBody); + + final UploadPartResponse uploadPartResponse = uploadPartResponseFuture.join(); + etags.add(uploadPartResponse.eTag()); } @@ -242,8 +279,11 @@ private void createMultipartUpload() { .bucket(bucket) .key(key) .build(); - CreateMultipartUploadResponse multipartUpload = s3Client.createMultipartUpload(uploadRequest); - uploadId = multipartUpload.uploadId(); + CompletableFuture multipartUpload = s3Client.createMultipartUpload(uploadRequest); + + final CreateMultipartUploadResponse response = multipartUpload.join(); + + uploadId = response.uploadId(); } } diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/ClientFactory.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/ClientFactory.java index 3d6b8fc12b..910f3966cc 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/ClientFactory.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/ClientFactory.java @@ -11,6 +11,7 @@ import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.retry.RetryPolicy; +import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3Client; public final class ClientFactory { @@ -26,8 +27,18 @@ static S3Client createS3Client(final S3SinkConfig s3SinkConfig, final AwsCredent .overrideConfiguration(createOverrideConfiguration(s3SinkConfig)).build(); } + static S3AsyncClient createS3AsyncClient(final S3SinkConfig s3SinkConfig, final AwsCredentialsSupplier awsCredentialsSupplier) { + final AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(s3SinkConfig.getAwsAuthenticationOptions()); + final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions); + + return S3AsyncClient.builder() + .region(s3SinkConfig.getAwsAuthenticationOptions().getAwsRegion()) + .credentialsProvider(awsCredentialsProvider) + .overrideConfiguration(createOverrideConfiguration(s3SinkConfig)).build(); + } + private static ClientOverrideConfiguration createOverrideConfiguration(final S3SinkConfig s3SinkConfig) { - final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(s3SinkConfig.getMaxConnectionRetries()).build(); + final RetryPolicy retryPolicy = RetryPolicy.builder().numRetries(s3SinkConfig.getMaxConnectionRetries() * s3SinkConfig.getMaxUploadRetries()).build(); return ClientOverrideConfiguration.builder() .retryPolicy(retryPolicy) .build(); diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/S3Sink.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/S3Sink.java index 18fee25c93..e1dd406eb1 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/S3Sink.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/S3Sink.java @@ -33,7 +33,7 @@ import org.opensearch.dataprepper.plugins.sink.s3.grouping.S3GroupManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3AsyncClient; import java.time.Duration; import java.util.Collection; @@ -77,7 +77,7 @@ public S3Sink(final PluginSetting pluginSetting, final OutputCodec testCodec = pluginFactory.loadPlugin(OutputCodec.class, codecPluginSettings); sinkInitialized = Boolean.FALSE; - final S3Client s3Client = ClientFactory.createS3Client(s3SinkConfig, awsCredentialsSupplier); + final S3AsyncClient s3Client = ClientFactory.createS3AsyncClient(s3SinkConfig, awsCredentialsSupplier); BufferFactory innerBufferFactory = s3SinkConfig.getBufferType().getBufferFactory(); if(testCodec instanceof ParquetOutputCodec && s3SinkConfig.getBufferType() != BufferTypeOptions.INMEMORY) { throw new InvalidPluginConfigurationException("The Parquet sink codec is an in_memory buffer only."); @@ -115,7 +115,7 @@ public S3Sink(final PluginSetting pluginSetting, final S3GroupManager s3GroupManager = new S3GroupManager(s3SinkConfig, s3GroupIdentifierFactory, bufferFactory, codecFactory, s3Client); - s3SinkService = new S3SinkService(s3SinkConfig, s3OutputCodecContext, s3Client, keyGenerator, RETRY_FLUSH_BACKOFF, pluginMetrics, s3GroupManager); + s3SinkService = new S3SinkService(s3SinkConfig, s3OutputCodecContext, RETRY_FLUSH_BACKOFF, pluginMetrics, s3GroupManager); } @Override diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkService.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkService.java index 15c8d71177..c0b7c18db5 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkService.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkService.java @@ -18,8 +18,6 @@ import org.opensearch.dataprepper.plugins.sink.s3.grouping.S3GroupManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.awscore.exception.AwsServiceException; -import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3Client; import java.io.IOException; @@ -27,8 +25,11 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; /** * Class responsible for create {@link S3Client} object, check thresholds, @@ -48,7 +49,6 @@ public class S3SinkService { static final String S3_OBJECTS_SIZE = "s3SinkObjectSizeBytes"; private final S3SinkConfig s3SinkConfig; private final Lock reentrantLock; - private final S3Client s3Client; private final int maxEvents; private final ByteCount maxBytes; private final Duration maxCollectionDuration; @@ -62,23 +62,21 @@ public class S3SinkService { private final Counter numberOfObjectsForceFlushed; private final OutputCodecContext codecContext; - private final KeyGenerator keyGenerator; private final Duration retrySleepTime; private final S3GroupManager s3GroupManager; /** * @param s3SinkConfig s3 sink related configuration. - * @param s3Client * @param pluginMetrics metrics. */ public S3SinkService(final S3SinkConfig s3SinkConfig, - final OutputCodecContext codecContext, final S3Client s3Client, final KeyGenerator keyGenerator, - final Duration retrySleepTime, final PluginMetrics pluginMetrics, final S3GroupManager s3GroupManager) { + final OutputCodecContext codecContext, + final Duration retrySleepTime, + final PluginMetrics pluginMetrics, + final S3GroupManager s3GroupManager) { this.s3SinkConfig = s3SinkConfig; - this.s3Client = s3Client; this.codecContext = codecContext; - this.keyGenerator = keyGenerator; this.retrySleepTime = retrySleepTime; reentrantLock = new ReentrantLock(); @@ -114,6 +112,7 @@ void output(Collection> records) { Exception sampleException = null; reentrantLock.lock(); try { + final List> completableFutures = new ArrayList<>(); for (Record record : records) { final Event event = record.getData(); try { @@ -130,11 +129,7 @@ void output(Collection> records) { currentBuffer.setEventCount(count); s3Group.addEventHandle(event.getEventHandle()); - final boolean flushed = flushToS3IfNeeded(s3Group, false); - - if (flushed) { - s3GroupManager.removeGroup(s3Group); - } + flushToS3IfNeeded(completableFutures, s3Group, false); } catch (Exception ex) { if(sampleException == null) { sampleException = ex; @@ -145,15 +140,22 @@ void output(Collection> records) { } for (final S3Group s3Group : s3GroupManager.getS3GroupEntries()) { - final boolean flushed = flushToS3IfNeeded(s3Group, false); - - if (flushed) { - s3GroupManager.removeGroup(s3Group); - } + flushToS3IfNeeded(completableFutures, s3Group, false); } if (s3SinkConfig.getAggregateThresholdOptions() != null) { - checkAggregateThresholdsAndFlushIfNeeded(); + checkAggregateThresholdsAndFlushIfNeeded(completableFutures); + } + + if (!completableFutures.isEmpty()) { + try { + CompletableFuture.allOf(completableFutures.toArray(new CompletableFuture[0])) + .thenRun(() -> LOG.debug("All {} requests to S3 have completed", completableFutures.size())) + .join(); + } catch (final Exception e) { + LOG.warn("There was an exception while waiting for all requests to complete", e); + } + } } finally { reentrantLock.unlock(); @@ -171,28 +173,37 @@ void output(Collection> records) { /** * @return whether the flush was attempted */ - private boolean flushToS3IfNeeded(final S3Group s3Group, final boolean forceFlush) { + private boolean flushToS3IfNeeded(final List> completableFutures, final S3Group s3Group, final boolean forceFlush) { LOG.trace("Flush to S3 check: currentBuffer.size={}, currentBuffer.events={}, currentBuffer.duration={}", s3Group.getBuffer().getSize(), s3Group.getBuffer().getEventCount(), s3Group.getBuffer().getDuration()); if (forceFlush || ThresholdCheck.checkThresholdExceed(s3Group.getBuffer(), maxEvents, maxBytes, maxCollectionDuration)) { + + s3GroupManager.removeGroup(s3Group); try { + s3Group.getOutputCodec().complete(s3Group.getBuffer().getOutputStream()); String s3Key = s3Group.getBuffer().getKey(); LOG.info("Writing {} to S3 with {} events and size of {} bytes.", s3Key, s3Group.getBuffer().getEventCount(), s3Group.getBuffer().getSize()); - final boolean isFlushToS3 = retryFlushToS3(s3Group.getBuffer(), s3Key); - if (isFlushToS3) { - LOG.info("Successfully saved {} to S3.", s3Key); - numberOfRecordsSuccessCounter.increment(s3Group.getBuffer().getEventCount()); - objectsSucceededCounter.increment(); - s3ObjectSizeSummary.record(s3Group.getBuffer().getSize()); - s3Group.releaseEventHandles(true); - } else { - LOG.error("Failed to save {} to S3.", s3Key); - numberOfRecordsFailedCounter.increment(s3Group.getBuffer().getEventCount()); - objectsFailedCounter.increment(); - s3Group.releaseEventHandles(false); - } + + final Consumer consumeOnGroupCompletion = (success) -> { + if (success) { + + LOG.info("Successfully saved {} to S3.", s3Key); + numberOfRecordsSuccessCounter.increment(s3Group.getBuffer().getEventCount()); + objectsSucceededCounter.increment(); + s3ObjectSizeSummary.record(s3Group.getBuffer().getSize()); + s3Group.releaseEventHandles(true); + } else { + LOG.error("Failed to save {} to S3.", s3Key); + numberOfRecordsFailedCounter.increment(s3Group.getBuffer().getEventCount()); + objectsFailedCounter.increment(); + s3Group.releaseEventHandles(false); + } + }; + + final Optional> completableFuture = s3Group.getBuffer().flushToS3(consumeOnGroupCompletion, this::handleFailures); + completableFuture.ifPresent(completableFutures::add); return true; } catch (final IOException e) { @@ -203,40 +214,11 @@ private boolean flushToS3IfNeeded(final S3Group s3Group, final boolean forceFlus return false; } - /** - * perform retry in-case any issue occurred, based on max_upload_retries configuration. - * - * @param currentBuffer current buffer. - * @param s3Key - * @return boolean based on object upload status. - */ - protected boolean retryFlushToS3(final Buffer currentBuffer, final String s3Key) { - boolean isUploadedToS3 = Boolean.FALSE; - int retryCount = maxRetries; - do { - try { - currentBuffer.flushToS3(); - isUploadedToS3 = Boolean.TRUE; - } catch (AwsServiceException | SdkClientException e) { - LOG.error("Exception occurred while uploading records to s3 bucket. Retry countdown : {} | exception:", - retryCount, e); - LOG.info("Error Message {}", e.getMessage()); - --retryCount; - if (retryCount == 0) { - return isUploadedToS3; - } - - try { - Thread.sleep(retrySleepTime.toMillis()); - } catch (final InterruptedException ex) { - LOG.warn("Interrupted while backing off before retrying S3 upload", ex); - } - } - } while (!isUploadedToS3); - return isUploadedToS3; + private void handleFailures(final Throwable e) { + LOG.error("Exception occurred while uploading records to s3 bucket: {}", e.getMessage()); } - private void checkAggregateThresholdsAndFlushIfNeeded() { + private void checkAggregateThresholdsAndFlushIfNeeded(final List> completableFutures) { long currentTotalGroupSize = s3GroupManager.recalculateAndGetGroupSize(); LOG.debug("Total groups size is {} bytes", currentTotalGroupSize); @@ -249,12 +231,11 @@ private void checkAggregateThresholdsAndFlushIfNeeded() { for (final S3Group s3Group : s3GroupManager.getS3GroupsSortedBySize()) { LOG.info("Forcing a flush of object with key {} due to aggregate_threshold of {} bytes being reached", s3Group.getBuffer().getKey(), aggregateThresholdBytes); - boolean flushed = flushToS3IfNeeded(s3Group, true); + final boolean flushed = flushToS3IfNeeded(completableFutures, s3Group, true); numberOfObjectsForceFlushed.increment(); if (flushed) { currentTotalGroupSize -= s3Group.getBuffer().getSize(); - s3GroupManager.removeGroup(s3Group); } if (currentTotalGroupSize <= aggregateThresholdBytes * aggregateThresholdFlushRatio) { diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/Buffer.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/Buffer.java index f6031cd3a4..90dd60fd91 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/Buffer.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/Buffer.java @@ -5,8 +5,12 @@ package org.opensearch.dataprepper.plugins.sink.s3.accumulator; + import java.io.OutputStream; import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; /** * A buffer can hold data before flushing it to S3. @@ -21,7 +25,7 @@ public interface Buffer { Duration getDuration(); - void flushToS3(); + Optional> flushToS3(final Consumer consumeOnGroupCompletion, final Consumer runOnFailure); OutputStream getOutputStream(); diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/BufferFactory.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/BufferFactory.java index ae8e503c08..7447182383 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/BufferFactory.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/BufferFactory.java @@ -5,10 +5,10 @@ package org.opensearch.dataprepper.plugins.sink.s3.accumulator; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3AsyncClient; import java.util.function.Supplier; public interface BufferFactory { - Buffer getBuffer(S3Client s3Client, Supplier bucketSupplier, Supplier keySupplier, String defaultBucket); + Buffer getBuffer(S3AsyncClient s3Client, Supplier bucketSupplier, Supplier keySupplier, String defaultBucket); } diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/BufferUtilities.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/BufferUtilities.java index ae02fd01bc..0f075906ae 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/BufferUtilities.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/BufferUtilities.java @@ -7,11 +7,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.core.sync.RequestBody; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.NoSuchBucketException; import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import java.util.function.Function; class BufferUtilities { @@ -20,25 +24,46 @@ class BufferUtilities { static final String ACCESS_DENIED = "Access Denied"; static final String INVALID_BUCKET = "The specified bucket is not valid"; - static void putObjectOrSendToDefaultBucket(final S3Client s3Client, - final RequestBody requestBody, + static CompletableFuture putObjectOrSendToDefaultBucket(final S3AsyncClient s3Client, + final AsyncRequestBody requestBody, + final Consumer runOnCompletion, + final Consumer runOnFailure, final String objectKey, final String targetBucket, final String defaultBucket) { - try { - s3Client.putObject( - PutObjectRequest.builder().bucket(targetBucket).key(objectKey).build(), - requestBody); - } catch (final S3Exception e) { - if (defaultBucket != null && - (e instanceof NoSuchBucketException || e.getMessage().contains(ACCESS_DENIED) || e.getMessage().contains(INVALID_BUCKET))) { - LOG.warn("Bucket {} could not be accessed, attempting to send to default_bucket {}", targetBucket, defaultBucket); - s3Client.putObject( - PutObjectRequest.builder().bucket(defaultBucket).key(objectKey).build(), - requestBody); - } else { - throw e; - } - } + + final boolean[] defaultBucketAttempted = new boolean[1]; + return s3Client.putObject( + PutObjectRequest.builder().bucket(targetBucket).key(objectKey).build(), requestBody) + .handle((result, ex) -> { + if (ex != null) { + runOnFailure.accept(ex); + + if (defaultBucket != null && + (ex instanceof NoSuchBucketException || ex.getMessage().contains(ACCESS_DENIED) || ex.getMessage().contains(INVALID_BUCKET))) { + LOG.warn("Bucket {} could not be accessed, attempting to send to default_bucket {}", targetBucket, defaultBucket); + defaultBucketAttempted[0] = true; + return s3Client.putObject( + PutObjectRequest.builder().bucket(defaultBucket).key(objectKey).build(), + requestBody); + } else { + runOnCompletion.accept(false); + return CompletableFuture.completedFuture(result); + } + } + + runOnCompletion.accept(true); + return CompletableFuture.completedFuture(result); + }) + .thenCompose(Function.identity()) + .whenComplete((res, ex) -> { + if (ex != null) { + runOnFailure.accept(ex); + } + + if (defaultBucketAttempted[0]) { + runOnCompletion.accept(ex == null); + } + }); } } diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CodecBuffer.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CodecBuffer.java index 452bcf5599..fbf3b2d8c3 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CodecBuffer.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CodecBuffer.java @@ -4,6 +4,9 @@ import java.io.OutputStream; import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; public class CodecBuffer implements Buffer { private final Buffer innerBuffer; @@ -31,8 +34,8 @@ public Duration getDuration() { } @Override - public void flushToS3() { - innerBuffer.flushToS3(); + public Optional> flushToS3(final Consumer runOnCompletion, final Consumer runOnException) { + return innerBuffer.flushToS3(runOnCompletion, runOnException); } @Override diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CodecBufferFactory.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CodecBufferFactory.java index fb8c51fa86..f27f8a2642 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CodecBufferFactory.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CodecBufferFactory.java @@ -1,7 +1,7 @@ package org.opensearch.dataprepper.plugins.sink.s3.accumulator; import org.opensearch.dataprepper.plugins.sink.s3.codec.BufferedCodec; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3AsyncClient; import java.util.function.Supplier; @@ -15,7 +15,7 @@ public CodecBufferFactory(BufferFactory innerBufferFactory, BufferedCodec codec) } @Override - public Buffer getBuffer(final S3Client s3Client, + public Buffer getBuffer(final S3AsyncClient s3Client, final Supplier bucketSupplier, final Supplier keySupplier, final String defaultBucket) { diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBuffer.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBuffer.java index a0db9d8d38..f5e9b402f0 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBuffer.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBuffer.java @@ -11,6 +11,9 @@ import java.io.OutputStream; import java.time.Duration; import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; class CompressionBuffer implements Buffer { private final Buffer innerBuffer; @@ -38,8 +41,8 @@ public Duration getDuration() { } @Override - public void flushToS3() { - innerBuffer.flushToS3(); + public Optional> flushToS3(final Consumer runOnCompletion, final Consumer runOnException) { + return innerBuffer.flushToS3(runOnCompletion, runOnException); } @Override diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBufferFactory.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBufferFactory.java index fe7eb55f3d..b0341f5bc8 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBufferFactory.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBufferFactory.java @@ -7,7 +7,7 @@ import org.opensearch.dataprepper.model.codec.OutputCodec; import org.opensearch.dataprepper.plugins.sink.s3.compression.CompressionEngine; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3AsyncClient; import java.util.Objects; import java.util.function.Supplier; @@ -24,7 +24,7 @@ public CompressionBufferFactory(final BufferFactory innerBufferFactory, final Co } @Override - public Buffer getBuffer(final S3Client s3Client, + public Buffer getBuffer(final S3AsyncClient s3Client, final Supplier bucketSupplier, final Supplier keySupplier, final String defaultBucket) { diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/InMemoryBuffer.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/InMemoryBuffer.java index 3adef41731..122c6b7e0c 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/InMemoryBuffer.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/InMemoryBuffer.java @@ -6,13 +6,16 @@ package org.opensearch.dataprepper.plugins.sink.s3.accumulator; import org.apache.commons.lang3.time.StopWatch; -import software.amazon.awssdk.core.sync.RequestBody; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; import java.io.ByteArrayOutputStream; import java.io.OutputStream; import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; import java.util.function.Supplier; /** @@ -22,7 +25,7 @@ public class InMemoryBuffer implements Buffer { private final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); private final ByteArrayPositionOutputStream byteArrayPositionOutputStream = new ByteArrayPositionOutputStream(byteArrayOutputStream); - private final S3Client s3Client; + private final S3AsyncClient s3Client; private final Supplier bucketSupplier; private final Supplier keySupplier; private int eventCount; @@ -33,7 +36,7 @@ public class InMemoryBuffer implements Buffer { private String defaultBucket; - InMemoryBuffer(final S3Client s3Client, + InMemoryBuffer(final S3AsyncClient s3Client, final Supplier bucketSupplier, final Supplier keySupplier, final String defaultBucket) { @@ -66,9 +69,11 @@ public Duration getDuration() { * Upload accumulated data to s3 bucket. */ @Override - public void flushToS3() { + public Optional> flushToS3(final Consumer consumeOnCompletion, final Consumer consumeOnException) { final byte[] byteArray = byteArrayOutputStream.toByteArray(); - BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, RequestBody.fromBytes(byteArray), getKey(), getBucket(), defaultBucket); + return Optional.ofNullable(BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, AsyncRequestBody.fromBytes(byteArray), + consumeOnCompletion, consumeOnException, + getKey(), getBucket(), defaultBucket)); } private String getBucket() { diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/InMemoryBufferFactory.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/InMemoryBufferFactory.java index bc15664289..32d56fd7c9 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/InMemoryBufferFactory.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/InMemoryBufferFactory.java @@ -5,13 +5,13 @@ package org.opensearch.dataprepper.plugins.sink.s3.accumulator; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3AsyncClient; import java.util.function.Supplier; public class InMemoryBufferFactory implements BufferFactory { @Override - public Buffer getBuffer(final S3Client s3Client, + public Buffer getBuffer(final S3AsyncClient s3Client, final Supplier bucketSupplier, final Supplier keySupplier, final String defaultBucket) { diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/LocalFileBuffer.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/LocalFileBuffer.java index 570946b008..550ee4702e 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/LocalFileBuffer.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/LocalFileBuffer.java @@ -8,8 +8,9 @@ import org.apache.commons.lang3.time.StopWatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.core.sync.RequestBody; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import java.io.BufferedOutputStream; import java.io.File; @@ -20,7 +21,10 @@ import java.nio.file.Files; import java.nio.file.Paths; import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; import java.util.function.Supplier; /** @@ -30,7 +34,7 @@ public class LocalFileBuffer implements Buffer { private static final Logger LOG = LoggerFactory.getLogger(LocalFileBuffer.class); private final OutputStream outputStream; - private final S3Client s3Client; + private final S3AsyncClient s3Client; private final Supplier bucketSupplier; private final Supplier keySupplier; private int eventCount; @@ -44,7 +48,7 @@ public class LocalFileBuffer implements Buffer { LocalFileBuffer(final File tempFile, - final S3Client s3Client, + final S3AsyncClient s3Client, final Supplier bucketSupplier, final Supplier keySupplier, final String defaultBucket) throws FileNotFoundException { @@ -84,10 +88,14 @@ public Duration getDuration(){ * Upload accumulated data to amazon s3. */ @Override - public void flushToS3() { + public Optional> flushToS3(final Consumer consumeOnCompletion, final Consumer consumeOnException) { flushAndCloseStream(); - BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, RequestBody.fromFile(localFile), getKey(), getBucket(), defaultBucket); - removeTemporaryFile(); + final CompletableFuture putObjectResponseCompletableFuture = BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, + AsyncRequestBody.fromFile(localFile), + consumeOnCompletion, consumeOnException, + getKey(), getBucket(), defaultBucket) + .whenComplete(((response, throwable) -> removeTemporaryFile())); + return Optional.of(putObjectResponseCompletableFuture); } /** diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/LocalFileBufferFactory.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/LocalFileBufferFactory.java index 68cc65b087..da787a9794 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/LocalFileBufferFactory.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/LocalFileBufferFactory.java @@ -7,7 +7,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3AsyncClient; import java.io.File; import java.io.IOException; @@ -20,7 +20,7 @@ public class LocalFileBufferFactory implements BufferFactory { public static final String SUFFIX = ".log"; @Override - public Buffer getBuffer(final S3Client s3Client, + public Buffer getBuffer(final S3AsyncClient s3Client, final Supplier bucketSupplier, final Supplier keySupplier, final String defaultBucket) { diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/MultipartBuffer.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/MultipartBuffer.java index a76a3f926a..cfe4ded9ec 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/MultipartBuffer.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/MultipartBuffer.java @@ -10,7 +10,10 @@ import java.io.IOException; import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; public class MultipartBuffer implements Buffer { @@ -49,8 +52,8 @@ public Duration getDuration() { * Upload accumulated data to s3 bucket. */ @Override - public void flushToS3() { - s3OutputStream.close(); + public Optional> flushToS3(final Consumer runOnCompletion, final Consumer runOnFailure) { + return Optional.ofNullable(s3OutputStream.close(runOnCompletion, runOnFailure)); } @Override diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/MultipartBufferFactory.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/MultipartBufferFactory.java index 2f801060aa..321e294a38 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/MultipartBufferFactory.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/MultipartBufferFactory.java @@ -6,13 +6,13 @@ package org.opensearch.dataprepper.plugins.sink.s3.accumulator; import org.opensearch.dataprepper.plugins.codec.parquet.S3OutputStream; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3AsyncClient; import java.util.function.Supplier; public class MultipartBufferFactory implements BufferFactory { @Override - public Buffer getBuffer(final S3Client s3Client, + public Buffer getBuffer(final S3AsyncClient s3Client, final Supplier bucketSupplier, final Supplier keySupplier, final String defaultBucket) { diff --git a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/grouping/S3GroupManager.java b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/grouping/S3GroupManager.java index 5945b7f648..1b2f08ca9f 100644 --- a/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/grouping/S3GroupManager.java +++ b/data-prepper-plugins/s3-sink/src/main/java/org/opensearch/dataprepper/plugins/sink/s3/grouping/S3GroupManager.java @@ -14,7 +14,7 @@ import org.opensearch.dataprepper.plugins.sink.s3.codec.CodecFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3AsyncClient; import java.util.Collection; import java.util.Collections; @@ -31,7 +31,7 @@ public class S3GroupManager { private final CodecFactory codecFactory; - private final S3Client s3Client; + private final S3AsyncClient s3Client; private long totalGroupSize; @@ -40,7 +40,7 @@ public S3GroupManager(final S3SinkConfig s3SinkConfig, final S3GroupIdentifierFactory s3GroupIdentifierFactory, final BufferFactory bufferFactory, final CodecFactory codecFactory, - final S3Client s3Client) { + final S3AsyncClient s3Client) { this.s3SinkConfig = s3SinkConfig; this.s3GroupIdentifierFactory = s3GroupIdentifierFactory; this.bufferFactory = bufferFactory; diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/codec/parquet/S3OutputStreamTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/codec/parquet/S3OutputStreamTest.java index 7050347761..3709a507fc 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/codec/parquet/S3OutputStreamTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/codec/parquet/S3OutputStreamTest.java @@ -11,8 +11,8 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import software.amazon.awssdk.core.sync.RequestBody; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; @@ -21,11 +21,15 @@ import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.services.s3.model.UploadPartResponse; -import java.util.UUID; import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.function.Consumer; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -33,11 +37,18 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; + @ExtendWith(MockitoExtension.class) public class S3OutputStreamTest { @Mock - private S3Client s3Client; + private S3AsyncClient s3Client; + + @Mock + private Consumer runOnCompletion; + + @Mock + private Consumer runOnError; private String bucket; @@ -63,20 +74,25 @@ void close_creates_and_completes_multi_part_upload() { final String uploadId = UUID.randomUUID().toString(); final CreateMultipartUploadResponse createMultipartUploadResponse = mock(CreateMultipartUploadResponse.class); when(createMultipartUploadResponse.uploadId()).thenReturn(uploadId); - when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn(createMultipartUploadResponse); + final CompletableFuture createMultipartUploadResponseCompletableFuture = CompletableFuture.completedFuture(createMultipartUploadResponse); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn(createMultipartUploadResponseCompletableFuture); final UploadPartResponse uploadPartResponse = mock(UploadPartResponse.class); + final CompletableFuture uploadPartResponseCompletableFuture = CompletableFuture.completedFuture(uploadPartResponse); when(uploadPartResponse.eTag()).thenReturn(UUID.randomUUID().toString()); - when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenReturn(uploadPartResponse); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))).thenReturn(uploadPartResponseCompletableFuture); - when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenReturn(mock(CompleteMultipartUploadResponse.class)); + when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenReturn(CompletableFuture.completedFuture(mock(CompleteMultipartUploadResponse.class))); final S3OutputStream s3OutputStream = createObjectUnderTest(); s3OutputStream.write(bytes); - s3OutputStream.close(); + final CompletableFuture completableFuture = s3OutputStream.close(runOnCompletion, runOnError); + assertThat(completableFuture, notNullValue()); + assertThat(completableFuture.isDone(), equalTo(true)); + assertThat(completableFuture.isCompletedExceptionally(), equalTo(false)); final ArgumentCaptor createMultipartUploadRequestArgumentCaptor = ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); verify(s3Client).createMultipartUpload(createMultipartUploadRequestArgumentCaptor.capture()); @@ -87,7 +103,7 @@ void close_creates_and_completes_multi_part_upload() { assertThat(createMultipartUploadRequest.key(), equalTo(objectKey)); final ArgumentCaptor uploadPartRequestArgumentCaptor = ArgumentCaptor.forClass(UploadPartRequest.class); - verify(s3Client).uploadPart(uploadPartRequestArgumentCaptor.capture(), any(RequestBody.class)); + verify(s3Client).uploadPart(uploadPartRequestArgumentCaptor.capture(), any(AsyncRequestBody.class)); final UploadPartRequest uploadPartRequest = uploadPartRequestArgumentCaptor.getValue(); assertThat(uploadPartRequest, notNullValue()); @@ -102,6 +118,8 @@ void close_creates_and_completes_multi_part_upload() { assertThat(completeMultipartUploadRequest, notNullValue()); assertThat(completeMultipartUploadRequest.bucket(), equalTo(bucket)); assertThat(completeMultipartUploadRequest.key(), equalTo(objectKey)); + + verify(runOnCompletion).accept(true); } @Test @@ -109,24 +127,33 @@ void close_with_no_such_bucket_exception_creates_and_completes_multi_part_upload final byte[] bytes = new byte[25]; final String uploadId = UUID.randomUUID().toString(); + final CompletableFuture failedFuture = CompletableFuture.failedFuture(NoSuchBucketException.builder().build()); + final CreateMultipartUploadResponse createMultipartUploadResponse = mock(CreateMultipartUploadResponse.class); when(createMultipartUploadResponse.uploadId()).thenReturn(uploadId); + final CompletableFuture successfulFuture = CompletableFuture.completedFuture(createMultipartUploadResponse); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) - .thenThrow(NoSuchBucketException.class) - .thenReturn(createMultipartUploadResponse); + .thenReturn(failedFuture) + .thenReturn(successfulFuture); final UploadPartResponse uploadPartResponse = mock(UploadPartResponse.class); + final CompletableFuture uploadPartResponseCompletableFuture = CompletableFuture.completedFuture(uploadPartResponse); when(uploadPartResponse.eTag()).thenReturn(UUID.randomUUID().toString()); - when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenReturn(uploadPartResponse); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))).thenReturn(uploadPartResponseCompletableFuture); + - when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenReturn(mock(CompleteMultipartUploadResponse.class)); + when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenReturn(CompletableFuture.completedFuture(mock(CompleteMultipartUploadResponse.class))); final S3OutputStream s3OutputStream = createObjectUnderTest(); s3OutputStream.write(bytes); - s3OutputStream.close(); + final CompletableFuture completableFuture = s3OutputStream.close(runOnCompletion, runOnError); + assertThat(completableFuture, notNullValue()); + assertThat(completableFuture.isDone(), equalTo(true)); + assertThat(completableFuture.isCompletedExceptionally(), equalTo(false)); final ArgumentCaptor createMultipartUploadRequestArgumentCaptor = ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); verify(s3Client, times(2)).createMultipartUpload(createMultipartUploadRequestArgumentCaptor.capture()); @@ -145,7 +172,7 @@ void close_with_no_such_bucket_exception_creates_and_completes_multi_part_upload assertThat(defaultBucketCreateMultiPartUploadRequest.key(), equalTo(objectKey)); final ArgumentCaptor uploadPartRequestArgumentCaptor = ArgumentCaptor.forClass(UploadPartRequest.class); - verify(s3Client).uploadPart(uploadPartRequestArgumentCaptor.capture(), any(RequestBody.class)); + verify(s3Client).uploadPart(uploadPartRequestArgumentCaptor.capture(), any(AsyncRequestBody.class)); final UploadPartRequest uploadPartRequest = uploadPartRequestArgumentCaptor.getValue(); assertThat(uploadPartRequest, notNullValue()); @@ -160,5 +187,55 @@ void close_with_no_such_bucket_exception_creates_and_completes_multi_part_upload assertThat(completeMultipartUploadRequest, notNullValue()); assertThat(completeMultipartUploadRequest.bucket(), equalTo(defaultBucket)); assertThat(completeMultipartUploadRequest.key(), equalTo(objectKey)); + + verify(runOnCompletion).accept(true); + } + + @Test + void close_with_upload_part_exception_completes_with_failure_and_returns_null() { + final byte[] bytes = new byte[25]; + final String uploadId = UUID.randomUUID().toString(); + final CreateMultipartUploadResponse createMultipartUploadResponse = mock(CreateMultipartUploadResponse.class); + when(createMultipartUploadResponse.uploadId()).thenReturn(uploadId); + final CompletableFuture createMultipartUploadResponseCompletableFuture = CompletableFuture.completedFuture(createMultipartUploadResponse); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn(createMultipartUploadResponseCompletableFuture); + + final RuntimeException mockException = mock(RuntimeException.class); + final CompletableFuture uploadPartResponseCompletableFuture = CompletableFuture.failedFuture(mockException); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))).thenReturn(uploadPartResponseCompletableFuture); + + final S3OutputStream s3OutputStream = createObjectUnderTest(); + + s3OutputStream.write(bytes); + + final CompletableFuture completableFuture = s3OutputStream.close(runOnCompletion, runOnError); + assertThat(completableFuture, equalTo(null)); + + final ArgumentCaptor createMultipartUploadRequestArgumentCaptor = ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); + verify(s3Client).createMultipartUpload(createMultipartUploadRequestArgumentCaptor.capture()); + + final CreateMultipartUploadRequest createMultipartUploadRequest = createMultipartUploadRequestArgumentCaptor.getValue(); + assertThat(createMultipartUploadRequest, notNullValue()); + assertThat(createMultipartUploadRequest.bucket(), equalTo(bucket)); + assertThat(createMultipartUploadRequest.key(), equalTo(objectKey)); + + final ArgumentCaptor uploadPartRequestArgumentCaptor = ArgumentCaptor.forClass(UploadPartRequest.class); + verify(s3Client).uploadPart(uploadPartRequestArgumentCaptor.capture(), any(AsyncRequestBody.class)); + + final UploadPartRequest uploadPartRequest = uploadPartRequestArgumentCaptor.getValue(); + assertThat(uploadPartRequest, notNullValue()); + assertThat(uploadPartRequest.bucket(), equalTo(bucket)); + assertThat(uploadPartRequest.uploadId(), equalTo(uploadId)); + assertThat(uploadPartRequest.key(), equalTo(objectKey)); + + verify(runOnCompletion).accept(false); + + final ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Throwable.class); + verify(runOnError).accept(argumentCaptor.capture()); + + final Throwable exception = argumentCaptor.getValue(); + assertThat(exception, notNullValue()); + assertThat(exception, instanceOf(CompletionException.class)); + assertThat(exception.getCause(), equalTo(mockException)); } } diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkServiceTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkServiceTest.java index 4235b86473..88c4df5202 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkServiceTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/S3SinkServiceTest.java @@ -10,7 +10,9 @@ import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.mockito.InOrder; +import org.mockito.MockedStatic; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.codec.OutputCodec; @@ -33,7 +35,6 @@ import org.opensearch.dataprepper.plugins.sink.s3.configuration.ThresholdOptions; import org.opensearch.dataprepper.plugins.sink.s3.grouping.S3Group; import org.opensearch.dataprepper.plugins.sink.s3.grouping.S3GroupManager; -import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3Client; @@ -44,15 +45,17 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.Random; import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; import java.util.stream.Collectors; +import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; @@ -61,6 +64,7 @@ import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -85,6 +89,8 @@ class S3SinkServiceTest { private PluginMetrics pluginMetrics; private Counter snapshotSuccessCounter; + private Counter numberOfRecordsSuccessCounter; + private Counter s3ObjectsForceFlushedCounter; private DistributionSummary s3ObjectSizeSummary; private Random random; @@ -111,8 +117,8 @@ void setUp() { PluginFactory pluginFactory = mock(PluginFactory.class); codec = mock(OutputCodec.class); snapshotSuccessCounter = mock(Counter.class); + numberOfRecordsSuccessCounter = mock(Counter.class); Counter snapshotFailedCounter = mock(Counter.class); - Counter numberOfRecordsSuccessCounter = mock(Counter.class); Counter numberOfRecordsFailedCounter = mock(Counter.class); s3ObjectSizeSummary = mock(DistributionSummary.class); s3ObjectsForceFlushedCounter = mock(Counter.class); @@ -156,7 +162,7 @@ private DefaultEventHandle castToDefaultHandle(EventHandle eventHandle) { } private S3SinkService createObjectUnderTest() { - return new S3SinkService(s3SinkConfig, codecContext, s3Client, keyGenerator, Duration.ofMillis(100), pluginMetrics, s3GroupManager); + return new S3SinkService(s3SinkConfig, codecContext, Duration.ofMillis(100), pluginMetrics, s3GroupManager); } @Test @@ -170,7 +176,10 @@ void test_s3SinkService_notNull() { void test_output_with_threshold_set_as_more_then_zero_event_count() throws IOException { InMemoryBuffer buffer = mock(InMemoryBuffer.class); when(buffer.getEventCount()).thenReturn(10); - doNothing().when(buffer).flushToS3(); + when(buffer.getKey()).thenReturn(UUID.randomUUID().toString()); + + final CompletableFuture completableFuture = mock(CompletableFuture.class); + when(buffer.flushToS3(any(Consumer.class), any(Consumer.class))).thenReturn(Optional.of(completableFuture)); when(s3SinkConfig.getThresholdOptions().getEventCount()).thenReturn(5); final OutputStream outputStream = mock(OutputStream.class); @@ -185,8 +194,29 @@ void test_output_with_threshold_set_as_more_then_zero_event_count() throws IOExc doNothing().when(codec).writeEvent(event, outputStream); S3SinkService s3SinkService = createObjectUnderTest(); assertNotNull(s3SinkService); - s3SinkService.output(generateRandomStringEventRecord()); - verify(snapshotSuccessCounter, times(51)).increment(); + + try (final MockedStatic completableFutureMockedStatic = mockStatic(CompletableFuture.class)) { + final CompletableFuture mockCompletableFuture = mock(CompletableFuture.class); + when(mockCompletableFuture.thenRun(any(Runnable.class))).thenReturn(mockCompletableFuture); + when(mockCompletableFuture.join()).thenReturn(null); + completableFutureMockedStatic.when(() -> CompletableFuture.allOf(any())).thenReturn(mockCompletableFuture); + s3SinkService.output(generateRandomStringEventRecord()); + } + + final ArgumentCaptor> argumentCaptorForCompletion = ArgumentCaptor.forClass(Consumer.class); + verify(buffer, times(51)).flushToS3(argumentCaptorForCompletion.capture(), any(Consumer.class)); + + final List> completionConsumers = argumentCaptorForCompletion.getAllValues(); + assertThat(completionConsumers.size(), equalTo(51)); + + final Consumer completionConsumer = completionConsumers.get(0); + completionConsumer.accept(true); + + + // only ran one of the completion consumers + verify(snapshotSuccessCounter, times(1)).increment(); + verify(numberOfRecordsSuccessCounter).increment(s3Group.getBuffer().getEventCount()); + verify(s3Group).releaseEventHandles(true); } @@ -197,7 +227,9 @@ void test_output_with_threshold_set_as_zero_event_count() throws IOException { InMemoryBuffer buffer = mock(InMemoryBuffer.class); when(buffer.getSize()).thenReturn(25500L); - doNothing().when(buffer).flushToS3(); + + final CompletableFuture completableFuture = mock(CompletableFuture.class); + when(buffer.flushToS3(any(Consumer.class), any(Consumer.class))).thenReturn(Optional.of(completableFuture)); when(s3SinkConfig.getThresholdOptions().getEventCount()).thenReturn(0); when(s3SinkConfig.getThresholdOptions().getMaximumSize()).thenReturn(ByteCount.parse("2kb")); @@ -212,8 +244,13 @@ void test_output_with_threshold_set_as_zero_event_count() throws IOException { doNothing().when(codec).writeEvent(event, outputStream); S3SinkService s3SinkService = createObjectUnderTest(); assertNotNull(s3SinkService); - s3SinkService.output(generateRandomStringEventRecord()); - verify(snapshotSuccessCounter, times(51)).increment(); + try (final MockedStatic completableFutureMockedStatic = mockStatic(CompletableFuture.class)) { + final CompletableFuture mockCompletableFuture = mock(CompletableFuture.class); + when(mockCompletableFuture.thenRun(any(Runnable.class))).thenReturn(mockCompletableFuture); + when(mockCompletableFuture.join()).thenReturn(null); + completableFutureMockedStatic.when(() -> CompletableFuture.allOf(any())).thenReturn(mockCompletableFuture); + s3SinkService.output(generateRandomStringEventRecord()); + } } @Test @@ -221,7 +258,8 @@ void test_output_with_uploadedToS3_success() throws IOException { InMemoryBuffer buffer = mock(InMemoryBuffer.class); when(buffer.getEventCount()).thenReturn(10); - doNothing().when(buffer).flushToS3(); + final CompletableFuture completableFuture = mock(CompletableFuture.class); + when(buffer.flushToS3(any(Consumer.class), any(Consumer.class))).thenReturn(Optional.of(completableFuture)); final OutputStream outputStream = mock(OutputStream.class); final Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString()); @@ -236,8 +274,13 @@ void test_output_with_uploadedToS3_success() throws IOException { S3SinkService s3SinkService = createObjectUnderTest(); assertNotNull(s3SinkService); assertThat(s3SinkService, instanceOf(S3SinkService.class)); - s3SinkService.output(generateRandomStringEventRecord()); - verify(snapshotSuccessCounter, times(51)).increment(); + try (final MockedStatic completableFutureMockedStatic = mockStatic(CompletableFuture.class)) { + final CompletableFuture mockCompletableFuture = mock(CompletableFuture.class); + when(mockCompletableFuture.thenRun(any(Runnable.class))).thenReturn(mockCompletableFuture); + when(mockCompletableFuture.join()).thenReturn(null); + completableFutureMockedStatic.when(() -> CompletableFuture.allOf(any())).thenReturn(mockCompletableFuture); + s3SinkService.output(generateRandomStringEventRecord()); + } } @Test @@ -248,6 +291,9 @@ void test_output_with_uploadedToS3_success_records_byte_count() throws IOExcepti final long objectSize = random.nextInt(1_000_000) + 10_000; when(buffer.getSize()).thenReturn(objectSize); + final CompletableFuture completableFuture = mock(CompletableFuture.class); + when(buffer.flushToS3(any(Consumer.class), any(Consumer.class))).thenReturn(Optional.of(completableFuture)); + final OutputStream outputStream = mock(OutputStream.class); final Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString()); final S3Group s3Group = mock(S3Group.class); @@ -259,17 +305,21 @@ void test_output_with_uploadedToS3_success_records_byte_count() throws IOExcepti doNothing().when(codec).writeEvent(event, outputStream); final S3SinkService s3SinkService = createObjectUnderTest(); - s3SinkService.output(generateRandomStringEventRecord()); - - verify(s3ObjectSizeSummary, times(51)).record(objectSize); + try (final MockedStatic completableFutureMockedStatic = mockStatic(CompletableFuture.class)) { + final CompletableFuture mockCompletableFuture = mock(CompletableFuture.class); + when(mockCompletableFuture.thenRun(any(Runnable.class))).thenReturn(mockCompletableFuture); + when(mockCompletableFuture.join()).thenReturn(null); + completableFutureMockedStatic.when(() -> CompletableFuture.allOf(any())).thenReturn(mockCompletableFuture); + s3SinkService.output(generateRandomStringEventRecord()); + } } @Test void test_output_with_uploadedToS3_midBatch_generatesNewOutputStream() throws IOException { - InMemoryBuffer buffer = mock(InMemoryBuffer.class); when(buffer.getEventCount()).thenReturn(10); - doNothing().when(buffer).flushToS3(); + final CompletableFuture completableFuture = mock(CompletableFuture.class); + when(buffer.flushToS3(any(Consumer.class), any(Consumer.class))).thenReturn(Optional.of(completableFuture)); final OutputStream outputStream1 = mock(OutputStream.class); final OutputStream outputStream2 = mock(OutputStream.class); when(buffer.getOutputStream()) @@ -290,9 +340,14 @@ void test_output_with_uploadedToS3_midBatch_generatesNewOutputStream() throws IO assertNotNull(s3SinkService); assertThat(s3SinkService, instanceOf(S3SinkService.class)); - s3SinkService.output(generateEventRecords(2)); - verify(snapshotSuccessCounter, times(3)).increment(); + try (final MockedStatic completableFutureMockedStatic = mockStatic(CompletableFuture.class)) { + final CompletableFuture mockCompletableFuture = mock(CompletableFuture.class); + when(mockCompletableFuture.thenRun(any(Runnable.class))).thenReturn(mockCompletableFuture); + when(mockCompletableFuture.join()).thenReturn(null); + completableFutureMockedStatic.when(() -> CompletableFuture.allOf(any())).thenReturn(mockCompletableFuture); + s3SinkService.output(generateEventRecords(2)); + } verify(codec).writeEvent(any(), eq(outputStream1)); verify(codec).writeEvent(any(), eq(outputStream2)); } @@ -325,7 +380,8 @@ void test_output_with_uploadedToS3_failure_does_not_record_byte_count() throws I Buffer buffer = mock(Buffer.class); - doThrow(AwsServiceException.class).when(buffer).flushToS3(); + final CompletableFuture completableFuture = mock(CompletableFuture.class); + when(buffer.flushToS3(any(Consumer.class), any(Consumer.class))).thenReturn(Optional.of(completableFuture)); final long objectSize = random.nextInt(1_000_000) + 10_000; when(buffer.getSize()).thenReturn(objectSize); @@ -337,14 +393,27 @@ void test_output_with_uploadedToS3_failure_does_not_record_byte_count() throws I when(s3Group.getOutputCodec()).thenReturn(codec); when(s3GroupManager.getOrCreateGroupForEvent(any(Event.class))).thenReturn(s3Group); - when(s3GroupManager.getS3GroupEntries()).thenReturn(Collections.singletonList(s3Group)); + when(s3GroupManager.getS3GroupEntries()).thenReturn(Collections.emptyList()); final OutputStream outputStream = mock(OutputStream.class); doNothing().when(codec).writeEvent(event, outputStream); - s3SinkService.output(Collections.singletonList(new Record<>(event))); + + + try (final MockedStatic completableFutureMockedStatic = mockStatic(CompletableFuture.class)) { + final CompletableFuture mockCompletableFuture = mock(CompletableFuture.class); + when(mockCompletableFuture.thenRun(any(Runnable.class))).thenReturn(mockCompletableFuture); + when(mockCompletableFuture.join()).thenReturn(null); + completableFutureMockedStatic.when(() -> CompletableFuture.allOf(any())).thenReturn(mockCompletableFuture); + s3SinkService.output(Collections.singletonList(new Record<>(event))); + } + final ArgumentCaptor> argumentCaptorForCompletion = ArgumentCaptor.forClass(Consumer.class); + verify(buffer, times(1)).flushToS3(argumentCaptorForCompletion.capture(), any(Consumer.class)); + + final Consumer completionConsumer = argumentCaptorForCompletion.getValue(); + completionConsumer.accept(false); verify(s3ObjectSizeSummary, never()).record(anyLong()); - verify(buffer, times(6)).flushToS3(); + verify(s3Group).releaseEventHandles(false); } @Test @@ -365,10 +434,15 @@ void test_output_with_no_incoming_records_flushes_batch() throws IOException { doNothing().when(codec).writeEvent(event, outputStream); final S3SinkService s3SinkService = createObjectUnderTest(); - s3SinkService.output(Collections.emptyList()); - verify(snapshotSuccessCounter, times(1)).increment(); - verify(buffer, times(1)).flushToS3(); + try (final MockedStatic completableFutureMockedStatic = mockStatic(CompletableFuture.class)) { + final CompletableFuture mockCompletableFuture = mock(CompletableFuture.class); + when(mockCompletableFuture.thenRun(any(Runnable.class))).thenReturn(mockCompletableFuture); + when(mockCompletableFuture.join()).thenReturn(null); + completableFutureMockedStatic.when(() -> CompletableFuture.allOf(any())).thenReturn(mockCompletableFuture); + s3SinkService.output(Collections.emptyList()); + } + verify(buffer, times(1)).flushToS3(any(Consumer.class), any(Consumer.class)); } @Test @@ -380,52 +454,6 @@ void test_output_with_no_incoming_records_or_buffered_records_short_circuits() t verify(snapshotSuccessCounter, times(0)).increment(); } - @Test - void test_retryFlushToS3_positive() throws InterruptedException, IOException { - InMemoryBuffer buffer = mock(InMemoryBuffer.class); - doNothing().when(buffer).flushToS3(); - - S3SinkService s3SinkService = createObjectUnderTest(); - assertNotNull(s3SinkService); - assertNotNull(buffer); - OutputStream outputStream = buffer.getOutputStream(); - final Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString()); - final S3Group s3Group = mock(S3Group.class); - when(s3Group.getBuffer()).thenReturn(buffer); - when(s3Group.getOutputCodec()).thenReturn(codec); - - when(s3GroupManager.getOrCreateGroupForEvent(event)).thenReturn(s3Group); - when(s3GroupManager.getS3GroupEntries()).thenReturn(Collections.singletonList(s3Group)); - - codec.writeEvent(event, outputStream); - final String s3Key = UUID.randomUUID().toString(); - boolean isUploadedToS3 = s3SinkService.retryFlushToS3(buffer, s3Key); - assertTrue(isUploadedToS3); - } - - @Test - void test_retryFlushToS3_negative() throws InterruptedException, IOException { - InMemoryBuffer buffer = mock(InMemoryBuffer.class); - when(s3SinkConfig.getBucketName()).thenReturn(""); - S3SinkService s3SinkService = createObjectUnderTest(); - assertNotNull(s3SinkService); - OutputStream outputStream = buffer.getOutputStream(); - final Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString()); - final S3Group s3Group = mock(S3Group.class); - when(s3Group.getBuffer()).thenReturn(buffer); - when(s3Group.getOutputCodec()).thenReturn(codec); - - when(s3GroupManager.getOrCreateGroupForEvent(event)).thenReturn(s3Group); - when(s3GroupManager.getS3GroupEntries()).thenReturn(Collections.singletonList(s3Group)); - - codec.writeEvent(event, outputStream); - final String s3Key = UUID.randomUUID().toString(); - doThrow(AwsServiceException.class).when(buffer).flushToS3(); - boolean isUploadedToS3 = s3SinkService.retryFlushToS3(buffer, s3Key); - assertFalse(isUploadedToS3); - } - - @Test void output_will_release_all_handles_since_a_flush() throws IOException { final Buffer buffer = mock(Buffer.class); @@ -446,7 +474,20 @@ void output_will_release_all_handles_since_a_flush() throws IOException { final S3SinkService s3SinkService = createObjectUnderTest(); final Collection> records = generateRandomStringEventRecord(); final List eventHandles = records.stream().map(Record::getData).map(Event::getEventHandle).map(this::castToDefaultHandle).collect(Collectors.toList()); - s3SinkService.output(records); + + try (final MockedStatic completableFutureMockedStatic = mockStatic(CompletableFuture.class)) { + final CompletableFuture mockCompletableFuture = mock(CompletableFuture.class); + when(mockCompletableFuture.thenRun(any(Runnable.class))).thenReturn(mockCompletableFuture); + when(mockCompletableFuture.join()).thenReturn(null); + completableFutureMockedStatic.when(() -> CompletableFuture.allOf(any())).thenReturn(mockCompletableFuture); + s3SinkService.output(records); + } + + final ArgumentCaptor> argumentCaptorForCompletion = ArgumentCaptor.forClass(Consumer.class); + verify(buffer, times(51)).flushToS3(argumentCaptorForCompletion.capture(), any(Consumer.class)); + + final Consumer completionConsumer = argumentCaptorForCompletion.getValue(); + completionConsumer.accept(true); InOrder inOrder = inOrder(s3Group); for (final EventHandle eventHandle : eventHandles) { @@ -475,89 +516,30 @@ void output_will_skip_releasing_events_without_EventHandle_objects() throws IOEx final Collection> records = generateRandomStringEventRecord(); final List eventHandles = records.stream().map(Record::getData).map(Event::getEventHandle).map(this::castToDefaultHandle).collect(Collectors.toList()); - s3SinkService.output(records); - final Collection> records2 = generateRandomStringEventRecord(); final List eventHandles2 = records2.stream().map(Record::getData).map(Event::getEventHandle).map(this::castToDefaultHandle).collect(Collectors.toList()); - s3SinkService.output(records2); - - InOrder inOrder = inOrder(s3Group); - for (final EventHandle eventHandle : eventHandles) { - inOrder.verify(s3Group).addEventHandle(eventHandle); - } - inOrder.verify(s3Group).releaseEventHandles(true); - for (final EventHandle eventHandle : eventHandles2) { - inOrder.verify(s3Group).addEventHandle(eventHandle); + try (final MockedStatic completableFutureMockedStatic = mockStatic(CompletableFuture.class)) { + final CompletableFuture mockCompletableFuture = mock(CompletableFuture.class); + when(mockCompletableFuture.thenRun(any(Runnable.class))).thenReturn(mockCompletableFuture); + when(mockCompletableFuture.join()).thenReturn(null); + completableFutureMockedStatic.when(() -> CompletableFuture.allOf(any())).thenReturn(mockCompletableFuture); + s3SinkService.output(records); + s3SinkService.output(records2); } - inOrder.verify(s3Group).releaseEventHandles(true); - - } - @Test - void output_will_release_all_handles_since_a_flush_when_S3_fails() throws IOException { - final Buffer buffer = mock(Buffer.class); + final ArgumentCaptor> argumentCaptorForCompletion = ArgumentCaptor.forClass(Consumer.class); + verify(buffer, times(100)).flushToS3(argumentCaptorForCompletion.capture(), any(Consumer.class)); - doThrow(AwsServiceException.class).when(buffer).flushToS3(); - - final long objectSize = random.nextInt(1_000_000) + 10_000; - when(buffer.getSize()).thenReturn(objectSize); - - final OutputStream outputStream = mock(OutputStream.class); - final Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString()); - final S3Group s3Group = mock(S3Group.class); - when(s3Group.getBuffer()).thenReturn(buffer); - when(s3Group.getOutputCodec()).thenReturn(codec); - - when(s3GroupManager.getOrCreateGroupForEvent(any(Event.class))).thenReturn(s3Group); - - doNothing().when(codec).writeEvent(event, outputStream); - final S3SinkService s3SinkService = createObjectUnderTest(); - final List> records = generateEventRecords(1); - final List eventHandles = records.stream().map(Record::getData).map(Event::getEventHandle).map(this::castToDefaultHandle).collect(Collectors.toList()); - - s3SinkService.output(records); - - InOrder inOrder = inOrder(s3Group); - for (final EventHandle eventHandle : eventHandles) { - inOrder.verify(s3Group).addEventHandle(eventHandle); - } - inOrder.verify(s3Group).releaseEventHandles(false); - } - - @Test - void output_will_release_only_new_handles_since_a_flush() throws IOException { - final Buffer buffer = mock(Buffer.class); - - final long objectSize = random.nextInt(1_000_000) + 10_000; - when(buffer.getSize()).thenReturn(objectSize); - - final OutputStream outputStream = mock(OutputStream.class); - final Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString()); - final S3Group s3Group = mock(S3Group.class); - when(s3Group.getBuffer()).thenReturn(buffer); - when(s3Group.getOutputCodec()).thenReturn(codec); - - when(s3GroupManager.getOrCreateGroupForEvent(any(Event.class))).thenReturn(s3Group); - - doNothing().when(codec).writeEvent(event, outputStream); - final S3SinkService s3SinkService = createObjectUnderTest(); - final Collection> records = generateRandomStringEventRecord(); - final List eventHandles = records.stream().map(Record::getData).map(Event::getEventHandle).map(this::castToDefaultHandle).collect(Collectors.toList()); - s3SinkService.output(records); - final Collection> records2 = generateRandomStringEventRecord(); - final List eventHandles2 = records2.stream().map(Record::getData).map(Event::getEventHandle).map(this::castToDefaultHandle).collect(Collectors.toList()); - s3SinkService.output(records2); + final Consumer completionConsumer = argumentCaptorForCompletion.getValue(); + completionConsumer.accept(true); InOrder inOrder = inOrder(s3Group); for (final EventHandle eventHandle : eventHandles) { inOrder.verify(s3Group).addEventHandle(eventHandle); } inOrder.verify(s3Group).releaseEventHandles(true); - for (final EventHandle eventHandle : eventHandles2) { - inOrder.verify(s3Group).addEventHandle(eventHandle); - } - inOrder.verify(s3Group).releaseEventHandles(true); + } @Test @@ -588,7 +570,14 @@ void output_will_skip_and_drop_failed_records() throws IOException { doThrow(RuntimeException.class).when(codec).writeEvent(event1, outputStream); - createObjectUnderTest().output(records); + + try (final MockedStatic completableFutureMockedStatic = mockStatic(CompletableFuture.class)) { + final CompletableFuture mockCompletableFuture = mock(CompletableFuture.class); + when(mockCompletableFuture.thenRun(any(Runnable.class))).thenReturn(mockCompletableFuture); + when(mockCompletableFuture.join()).thenReturn(null); + completableFutureMockedStatic.when(() -> CompletableFuture.allOf(any())).thenReturn(mockCompletableFuture); + createObjectUnderTest().output(records); + } InOrder inOrder = inOrder(codec, s3Group); inOrder.verify(codec).start(eq(outputStream), eq(event1), any()); @@ -596,51 +585,12 @@ void output_will_skip_and_drop_failed_records() throws IOException { inOrder.verify(s3Group, never()).addEventHandle(eventHandle1); inOrder.verify(codec).writeEvent(event2, outputStream); inOrder.verify(s3Group).addEventHandle(eventHandle2); - inOrder.verify(s3Group).releaseEventHandles(true); verify(acknowledgementSet).release(eventHandle1, false); verify(acknowledgementSet, never()).release(eventHandle1, true); verify(acknowledgementSet, never()).release(eventHandle2, false); } - @Test - void output_will_release_only_new_handles_since_a_flush_when_S3_fails() throws IOException { - final Buffer buffer = mock(Buffer.class); - - doThrow(AwsServiceException.class).when(buffer).flushToS3(); - - final long objectSize = random.nextInt(1_000_000) + 10_000; - when(buffer.getSize()).thenReturn(objectSize); - - final OutputStream outputStream = mock(OutputStream.class); - final S3Group s3Group = mock(S3Group.class); - when(s3Group.getBuffer()).thenReturn(buffer); - when(s3Group.getOutputCodec()).thenReturn(codec); - - when(s3GroupManager.getOrCreateGroupForEvent(any(Event.class))).thenReturn(s3Group); - - doNothing().when(codec).writeEvent(any(Event.class), eq(outputStream)); - final S3SinkService s3SinkService = createObjectUnderTest(); - final List> records = generateEventRecords(1); - final List eventHandles = records.stream().map(Record::getData).map(Event::getEventHandle).map(this::castToDefaultHandle).collect(Collectors.toList()); - s3SinkService.output(records); - - final List> records2 = generateEventRecords(1); - final List eventHandles2 = records2.stream().map(Record::getData).map(Event::getEventHandle).map(this::castToDefaultHandle).collect(Collectors.toList()); - - s3SinkService.output(records2); - InOrder inOrder = inOrder(s3Group); - for (final EventHandle eventHandle : eventHandles) { - inOrder.verify(s3Group).addEventHandle(eventHandle); - } - inOrder.verify(s3Group).releaseEventHandles(false); - for (final EventHandle eventHandle : eventHandles2) { - inOrder.verify(s3Group).addEventHandle(eventHandle); - } - inOrder.verify(s3Group).releaseEventHandles(false); - - } - @Test void output_will_flush_the_largest_group_until_below_aggregate_threshold_when_aggregate_threshold_is_reached() throws IOException { final long bytesThreshold = 100_000L; @@ -692,10 +642,17 @@ void output_will_flush_the_largest_group_until_below_aggregate_threshold_when_ag doNothing().when(codec).writeEvent(any(Event.class), any(OutputStream.class)); final S3SinkService s3SinkService = createObjectUnderTest(); - s3SinkService.output(List.of(new Record<>(firstGroupEvent), new Record<>(secondGroupEvent), new Record<>(thirdGroupEvent))); - verify(thirdGroupBuffer).flushToS3(); - verify(firstGroupBuffer).flushToS3(); + try (final MockedStatic completableFutureMockedStatic = mockStatic(CompletableFuture.class)) { + final CompletableFuture mockCompletableFuture = mock(CompletableFuture.class); + when(mockCompletableFuture.thenRun(any(Runnable.class))).thenReturn(mockCompletableFuture); + when(mockCompletableFuture.join()).thenReturn(null); + completableFutureMockedStatic.when(() -> CompletableFuture.allOf(any())).thenReturn(mockCompletableFuture); + s3SinkService.output(List.of(new Record<>(firstGroupEvent), new Record<>(secondGroupEvent), new Record<>(thirdGroupEvent))); + } + + verify(thirdGroupBuffer).flushToS3(any(Consumer.class), any(Consumer.class)); + verify(firstGroupBuffer).flushToS3(any(Consumer.class), any(Consumer.class)); verify(codec, times(2)).complete(any(OutputStream.class)); @@ -703,7 +660,7 @@ void output_will_flush_the_largest_group_until_below_aggregate_threshold_when_ag verify(s3GroupManager).removeGroup(firstGroup); verify(s3GroupManager, never()).removeGroup(secondGroup); - verify(secondGroupBuffer, never()).flushToS3(); + verify(secondGroupBuffer, never()).flushToS3(any(Consumer.class), any(Consumer.class)); verify(s3ObjectsForceFlushedCounter, times(2)).increment(); } diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/BufferUtilitiesTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/BufferUtilitiesTest.java index 23de985b58..f438631b4c 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/BufferUtilitiesTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/BufferUtilitiesTest.java @@ -10,10 +10,12 @@ import org.junit.jupiter.params.provider.ArgumentsSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; -import software.amazon.awssdk.core.sync.RequestBody; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.NoSuchBucketException; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; @@ -21,14 +23,17 @@ import java.util.List; import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.function.Consumer; import java.util.stream.Stream; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -44,10 +49,16 @@ public class BufferUtilitiesTest { private String objectKey; @Mock - private RequestBody requestBody; + private AsyncRequestBody requestBody; @Mock - private S3Client s3Client; + private Consumer mockRunOnCompletion; + + @Mock + private Consumer mockRunOnFailure; + + @Mock + private S3AsyncClient s3Client; @BeforeEach void setup() { @@ -59,9 +70,11 @@ void setup() { @Test void putObjectOrSendToDefaultBucket_with_no_exception_sends_to_target_bucket() { - when(s3Client.putObject(any(PutObjectRequest.class), eq(requestBody))).thenReturn(mock(PutObjectResponse.class)); + final CompletableFuture successfulFuture = CompletableFuture.completedFuture(mock(PutObjectResponse.class)); - BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, requestBody, objectKey, targetBucket, defaultBucket); + when(s3Client.putObject(any(PutObjectRequest.class), eq(requestBody))).thenReturn(successfulFuture); + + BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, requestBody, mockRunOnCompletion, mockRunOnFailure, objectKey, targetBucket, defaultBucket).join(); final ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); verify(s3Client, times(1)).putObject(argumentCaptor.capture(), eq(requestBody)); @@ -72,36 +85,45 @@ void putObjectOrSendToDefaultBucket_with_no_exception_sends_to_target_bucket() { assertThat(putObjectRequest.bucket(), equalTo(targetBucket)); assertThat(putObjectRequest.key(), equalTo(objectKey)); + verify(mockRunOnFailure, never()).accept(any()); + verify(mockRunOnCompletion).accept(true); } @Test - void putObjectOrSendToDefaultBucket_with_no_such_bucket_exception_and_null_default_bucket_throws_exception() { - when(s3Client.putObject(any(PutObjectRequest.class), eq(requestBody))).thenThrow(NoSuchBucketException.class); + void putObjectOrSendToDefaultBucket_with_no_such_bucket_exception_and_null_default_bucket_completes_with_exception() { + + final CompletableFuture failedFuture = CompletableFuture.failedFuture(NoSuchBucketException.builder().build()); + when(s3Client.putObject(any(PutObjectRequest.class), eq(requestBody))).thenReturn(failedFuture); - assertThrows(NoSuchBucketException.class, () -> BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, requestBody, objectKey, targetBucket, null)); + BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, requestBody, mockRunOnCompletion, mockRunOnFailure, objectKey, targetBucket, null).join(); verify(s3Client, times(1)).putObject(any(PutObjectRequest.class), eq(requestBody)); + verify(mockRunOnCompletion).accept(false); + verify(mockRunOnFailure).accept(any()); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void putObjectOrSendToDefaultBucket_with_S3Exception_that_is_not_access_denied_or_no_such_bucket_throws_exception(final boolean defaultBucketEnabled) { - when(s3Client.putObject(any(PutObjectRequest.class), eq(requestBody))).thenThrow(RuntimeException.class); + void putObjectOrSendToDefaultBucket_with_S3Exception_that_is_not_access_denied_or_no_such_bucket_completes_with_exception(final boolean defaultBucketEnabled) { + final CompletableFuture failedFuture = CompletableFuture.failedFuture(new RuntimeException(UUID.randomUUID().toString())); + when(s3Client.putObject(any(PutObjectRequest.class), eq(requestBody))).thenReturn(failedFuture); - assertThrows(RuntimeException.class, () -> BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, requestBody, objectKey, targetBucket, - defaultBucketEnabled ? defaultBucket : null)); + BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, requestBody, mockRunOnCompletion, mockRunOnFailure, objectKey, targetBucket, + defaultBucketEnabled ? defaultBucket : null); verify(s3Client, times(1)).putObject(any(PutObjectRequest.class), eq(requestBody)); + verify(mockRunOnCompletion).accept(false); + verify(mockRunOnFailure).accept(any()); } @ParameterizedTest @ArgumentsSource(ExceptionsProvider.class) void putObjectOrSendToDefaultBucket_with_NoSuchBucketException_or_access_denied_sends_to_default_bucket(final Exception exception) { - when(s3Client.putObject(any(PutObjectRequest.class), eq(requestBody))) - .thenThrow(exception) - .thenReturn(mock(PutObjectResponse.class)); + final CompletableFuture successfulFuture = CompletableFuture.completedFuture(mock(PutObjectResponse.class)); + final CompletableFuture failedFuture = CompletableFuture.failedFuture(exception); + when(s3Client.putObject(any(PutObjectRequest.class), eq(requestBody))).thenReturn(failedFuture).thenReturn(successfulFuture); - BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, requestBody, objectKey, targetBucket, defaultBucket); + BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, requestBody, mockRunOnCompletion, mockRunOnFailure, objectKey, targetBucket, defaultBucket); final ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); verify(s3Client, times(2)).putObject(argumentCaptor.capture(), eq(requestBody)); @@ -116,6 +138,43 @@ void putObjectOrSendToDefaultBucket_with_NoSuchBucketException_or_access_denied_ final PutObjectRequest defaultBucketPutObjectRequest = putObjectRequestList.get(1); assertThat(defaultBucketPutObjectRequest.bucket(), equalTo(defaultBucket)); assertThat(defaultBucketPutObjectRequest.key(), equalTo(objectKey)); + + final InOrder inOrder = Mockito.inOrder(mockRunOnCompletion, mockRunOnFailure); + + inOrder.verify(mockRunOnFailure).accept(exception); + inOrder.verify(mockRunOnCompletion).accept(true); + } + + @Test + void putObject_failing_to_send_to_bucket_and_default_bucket_completes_as_expected() { + + final NoSuchBucketException noSuchBucketException = NoSuchBucketException.builder().build(); + final RuntimeException runtimeException = new RuntimeException(); + final CompletableFuture failedDefaultBucket = CompletableFuture.failedFuture(runtimeException); + final CompletableFuture failedFuture = CompletableFuture.failedFuture(noSuchBucketException); + when(s3Client.putObject(any(PutObjectRequest.class), eq(requestBody))).thenReturn(failedFuture).thenReturn(failedDefaultBucket); + + BufferUtilities.putObjectOrSendToDefaultBucket(s3Client, requestBody, mockRunOnCompletion, mockRunOnFailure, objectKey, targetBucket, defaultBucket); + + final ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); + verify(s3Client, times(2)).putObject(argumentCaptor.capture(), eq(requestBody)); + + assertThat(argumentCaptor.getAllValues().size(), equalTo(2)); + + final List putObjectRequestList = argumentCaptor.getAllValues(); + final PutObjectRequest putObjectRequest = putObjectRequestList.get(0); + assertThat(putObjectRequest.bucket(), equalTo(targetBucket)); + assertThat(putObjectRequest.key(), equalTo(objectKey)); + + final PutObjectRequest defaultBucketPutObjectRequest = putObjectRequestList.get(1); + assertThat(defaultBucketPutObjectRequest.bucket(), equalTo(defaultBucket)); + assertThat(defaultBucketPutObjectRequest.key(), equalTo(objectKey)); + + final InOrder inOrder = Mockito.inOrder(mockRunOnCompletion, mockRunOnFailure); + + inOrder.verify(mockRunOnFailure).accept(noSuchBucketException); + inOrder.verify(mockRunOnFailure).accept(any(CompletionException.class)); + inOrder.verify(mockRunOnCompletion).accept(false); } private static class ExceptionsProvider implements ArgumentsProvider { diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CodecBufferTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CodecBufferTest.java index 8813c84823..25e541ebbf 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CodecBufferTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CodecBufferTest.java @@ -12,6 +12,7 @@ import java.util.Optional; import java.util.Random; import java.util.UUID; +import java.util.function.Consumer; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; @@ -27,6 +28,12 @@ class CodecBufferTest { @Mock private BufferedCodec bufferedCodec; + + @Mock + private Consumer mockRunOnCompletion; + + @Mock + private Consumer mockRunOnFailure; private Random random; @BeforeEach @@ -100,8 +107,8 @@ void setEventCount_calls_inner_setEventCount() { @Test void flushToS3_calls_inner_flushToS3() { - createObjectUnderTest().flushToS3(); + createObjectUnderTest().flushToS3(mockRunOnCompletion, mockRunOnFailure); - verify(innerBuffer).flushToS3(); + verify(innerBuffer).flushToS3(mockRunOnCompletion, mockRunOnFailure); } } \ No newline at end of file diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBufferFactoryTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBufferFactoryTest.java index 90a365c846..0f27b69de5 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBufferFactoryTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBufferFactoryTest.java @@ -13,7 +13,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.model.codec.OutputCodec; import org.opensearch.dataprepper.plugins.sink.s3.compression.CompressionEngine; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3AsyncClient; import java.util.UUID; import java.util.function.Supplier; @@ -37,7 +37,7 @@ class CompressionBufferFactoryTest { private CompressionEngine compressionEngine; @Mock - private S3Client s3Client; + private S3AsyncClient s3Client; @Mock private Supplier bucketSupplier; diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBufferTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBufferTest.java index 428223783a..3d12dcd928 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBufferTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/CompressionBufferTest.java @@ -18,6 +18,7 @@ import java.time.Duration; import java.util.Random; import java.util.UUID; +import java.util.function.Consumer; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.sameInstance; @@ -37,6 +38,12 @@ class CompressionBufferTest { @Mock private CompressionEngine compressionEngine; + @Mock + private Consumer mockRunOnCompletion; + + @Mock + private Consumer mockRunOnFailure; + private Random random; @BeforeEach @@ -96,9 +103,9 @@ void flushToS3_calls_inner_flushToS3() { final String bucket = UUID.randomUUID().toString(); final String key = UUID.randomUUID().toString(); - createObjectUnderTest().flushToS3(); + createObjectUnderTest().flushToS3(mockRunOnCompletion, mockRunOnFailure); - verify(innerBuffer).flushToS3(); + verify(innerBuffer).flushToS3(mockRunOnCompletion, mockRunOnFailure); } @Test diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/InMemoryBufferTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/InMemoryBufferTest.java index 27c470c5e3..5bd0539152 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/InMemoryBufferTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/InMemoryBufferTest.java @@ -6,23 +6,26 @@ package org.opensearch.dataprepper.plugins.sink.s3.accumulator; import org.apache.parquet.io.PositionOutputStream; -import org.hamcrest.Matchers; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; -import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.core.sync.RequestBody; -import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import java.io.IOException; import java.io.OutputStream; import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; import java.util.function.Supplier; import static org.hamcrest.CoreMatchers.equalTo; @@ -31,10 +34,10 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThanOrEqualTo; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -42,11 +45,18 @@ class InMemoryBufferTest { public static final int MAX_EVENTS = 55; @Mock - private S3Client s3Client; + private S3AsyncClient s3Client; @Mock private Supplier bucketSupplier; @Mock private Supplier keySupplier; + + @Mock + private Consumer mockRunOnCompletion; + + @Mock + private Consumer mockRunOnFailure; + private InMemoryBuffer inMemoryBuffer; @Test @@ -102,39 +112,29 @@ void getDuration_provides_duration_within_expected_range() throws IOException, I } @Test - void test_with_write_event_into_buffer_and_flush_toS3() throws IOException { - inMemoryBuffer = new InMemoryBuffer(s3Client, bucketSupplier, keySupplier, null); + void test_flush_to_s3_success() { + final String key = UUID.randomUUID().toString(); + final String bucket = UUID.randomUUID().toString(); + when(keySupplier.get()).thenReturn(key); + when(bucketSupplier.get()).thenReturn(bucket); - while (inMemoryBuffer.getEventCount() < MAX_EVENTS) { - OutputStream outputStream = inMemoryBuffer.getOutputStream(); - outputStream.write(generateByteArray()); - int eventCount = inMemoryBuffer.getEventCount() +1; - inMemoryBuffer.setEventCount(eventCount); - } - assertDoesNotThrow(() -> { - inMemoryBuffer.flushToS3(); - }); - } - - @Test - void test_uploadedToS3_success() { inMemoryBuffer = new InMemoryBuffer(s3Client, bucketSupplier, keySupplier, null); Assertions.assertNotNull(inMemoryBuffer); - assertDoesNotThrow(() -> { - inMemoryBuffer.flushToS3(); - }); - } - @Test - void test_uploadedToS3_fails() { - inMemoryBuffer = new InMemoryBuffer(s3Client, bucketSupplier, keySupplier, null); - Assertions.assertNotNull(inMemoryBuffer); - SdkClientException sdkClientException = mock(SdkClientException.class); - when(s3Client.putObject(any(PutObjectRequest.class), any(RequestBody.class))) - .thenThrow(sdkClientException); - SdkClientException actualException = assertThrows(SdkClientException.class, () -> inMemoryBuffer.flushToS3()); + final CompletableFuture expectedFuture = mock(CompletableFuture.class); + + try (final MockedStatic bufferUtilitiesMockedStatic = mockStatic(BufferUtilities.class)) { + bufferUtilitiesMockedStatic.when(() -> + BufferUtilities.putObjectOrSendToDefaultBucket(eq(s3Client), any(AsyncRequestBody.class), + eq(mockRunOnCompletion), eq(mockRunOnFailure), eq(key), eq(bucket), eq(null))) + .thenReturn(expectedFuture); + + final Optional> result = inMemoryBuffer.flushToS3(mockRunOnCompletion, mockRunOnFailure); + assertThat(result.isPresent(), equalTo(true)); + assertThat(result.get(), equalTo(expectedFuture)); + + } - assertThat(actualException, Matchers.equalTo(sdkClientException)); } @Test diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/LocalFileBufferTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/LocalFileBufferTest.java index bcc608e1ec..9a839a85f5 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/LocalFileBufferTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/accumulator/LocalFileBufferTest.java @@ -1,33 +1,38 @@ package org.opensearch.dataprepper.plugins.sink.s3.accumulator; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; -import software.amazon.awssdk.core.sync.RequestBody; -import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import java.io.File; import java.io.IOException; import java.io.OutputStream; import java.time.Duration; +import java.util.Optional; import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.function.Supplier; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.notNullValue; -import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -39,11 +44,18 @@ class LocalFileBufferTest { public static final String PREFIX = "local"; public static final String SUFFIX = ".log"; @Mock - private S3Client s3Client; + private S3AsyncClient s3Client; @Mock private Supplier bucketSupplier; @Mock private Supplier keySupplier; + + @Mock + private Consumer mockRunOnCompletion; + + @Mock + private Consumer mockRunOnFailure; + private LocalFileBuffer localFileBuffer; private File tempFile; @@ -100,40 +112,26 @@ void test_with_write_events_into_buffer_and_flush_toS3() throws IOException { when(keySupplier.get()).thenReturn(KEY); when(bucketSupplier.get()).thenReturn(BUCKET_NAME); - assertDoesNotThrow(() -> { - localFileBuffer.flushToS3(); - }); - - ArgumentCaptor putObjectRequestArgumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); - verify(s3Client).putObject(putObjectRequestArgumentCaptor.capture(), any(RequestBody.class)); - PutObjectRequest actualRequest = putObjectRequestArgumentCaptor.getValue(); - - assertThat(actualRequest, notNullValue()); - assertThat(actualRequest.bucket(), equalTo(BUCKET_NAME)); - assertThat(actualRequest.key(), equalTo(KEY)); - assertThat(actualRequest.expectedBucketOwner(), nullValue()); + final CompletableFuture expectedFuture = mock(CompletableFuture.class); + when(expectedFuture.whenComplete(any(BiConsumer.class))) + .thenReturn(expectedFuture); - assertFalse(tempFile.exists(), "The temp file has not been deleted."); - } + try (final MockedStatic bufferUtilitiesMockedStatic = mockStatic(BufferUtilities.class)) { + bufferUtilitiesMockedStatic.when(() -> + BufferUtilities.putObjectOrSendToDefaultBucket(eq(s3Client), any(AsyncRequestBody.class), + eq(mockRunOnCompletion), eq(mockRunOnFailure), eq(KEY), eq(BUCKET_NAME), eq(defaultBucket))) + .thenReturn(expectedFuture); - @Test - void test_uploadedToS3_success() { - when(keySupplier.get()).thenReturn(KEY); - when(bucketSupplier.get()).thenReturn(BUCKET_NAME); - - Assertions.assertNotNull(localFileBuffer); - assertDoesNotThrow(() -> { - localFileBuffer.flushToS3(); - }); + final Optional> result = localFileBuffer.flushToS3(mockRunOnCompletion, mockRunOnFailure); + assertThat(result.isPresent(), equalTo(true)); + assertThat(result.get(), equalTo(expectedFuture)); - ArgumentCaptor putObjectRequestArgumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); - verify(s3Client).putObject(putObjectRequestArgumentCaptor.capture(), any(RequestBody.class)); - PutObjectRequest actualRequest = putObjectRequestArgumentCaptor.getValue(); + final ArgumentCaptor biConsumer = ArgumentCaptor.forClass(BiConsumer.class); + verify(expectedFuture).whenComplete(biConsumer.capture()); - assertThat(actualRequest, notNullValue()); - assertThat(actualRequest.bucket(), equalTo(BUCKET_NAME)); - assertThat(actualRequest.key(), equalTo(KEY)); - assertThat(actualRequest.expectedBucketOwner(), nullValue()); + final BiConsumer actualBiConsumer = biConsumer.getValue(); + actualBiConsumer.accept(mock(PutObjectResponse.class), mock(RuntimeException.class)); + } assertFalse(tempFile.exists(), "The temp file has not been deleted."); } diff --git a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/grouping/S3GroupManagerTest.java b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/grouping/S3GroupManagerTest.java index 72afddb1ed..4764c8431f 100644 --- a/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/grouping/S3GroupManagerTest.java +++ b/data-prepper-plugins/s3-sink/src/test/java/org/opensearch/dataprepper/plugins/sink/s3/grouping/S3GroupManagerTest.java @@ -15,7 +15,7 @@ import org.opensearch.dataprepper.plugins.sink.s3.accumulator.Buffer; import org.opensearch.dataprepper.plugins.sink.s3.accumulator.BufferFactory; import org.opensearch.dataprepper.plugins.sink.s3.codec.CodecFactory; -import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3AsyncClient; import java.util.Collection; import java.util.UUID; @@ -48,7 +48,7 @@ public class S3GroupManagerTest { private CodecFactory codecFactory; @Mock - private S3Client s3Client; + private S3AsyncClient s3Client; private S3GroupManager createObjectUnderTest() { return new S3GroupManager(s3SinkConfig, s3GroupIdentifierFactory, bufferFactory, codecFactory, s3Client);