Skip to content

Commit

Permalink
Add async blob read support for encrypted containers
Browse files Browse the repository at this point in the history
Signed-off-by: Kunal Kotwani <[email protected]>
  • Loading branch information
kotwanikunal committed Sep 20, 2023
1 parent b3d4698 commit 421d3c5
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 66 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Added
- Add metrics for thread_pool task wait time ([#9681](https://github.com/opensearch-project/OpenSearch/pull/9681))
- Async blob read support for S3 plugin ([#9694](https://github.com/opensearch-project/OpenSearch/pull/9694))
- Async blob read support for encrypted containers ([#10131](https://github.com/opensearch-project/OpenSearch/pull/10131))

### Dependencies
- Bump `peter-evans/create-or-update-comment` from 2 to 3 ([#9575](https://github.com/opensearch-project/OpenSearch/pull/9575))
Expand Down Expand Up @@ -108,4 +109,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Security

[Unreleased 3.0]: https://github.com/opensearch-project/OpenSearch/compare/2.x...HEAD
[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.11...2.x
[Unreleased 2.x]: https://github.com/opensearch-project/OpenSearch/compare/2.11...2.x
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -47,24 +48,16 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp

@Override
public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
DecryptingReadContextListener<T, U> decryptingReadContextListener = new DecryptingReadContextListener<>(
listener,
cryptoHandler,
getEncryptedHeaderContentSupplier(blobName)
);
blobContainer.readBlobAsync(blobName, decryptingReadContextListener);
}
try {
ActionListener<ReadContext> decryptingCompletionListener = ActionListener.map(
listener,
readContext -> new DecryptedReadContext<>(readContext, cryptoHandler, getEncryptedHeaderContentSupplier(blobName))
);

private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) {
return (start, end) -> {
byte[] buffer;
int length = (int) (end - start + 1);
try (InputStream inputStream = blobContainer.readBlob(blobName, start, length)) {
buffer = new byte[length];
inputStream.readNBytes(buffer, (int) start, buffer.length);
}
return buffer;
};
blobContainer.readBlobAsync(blobName, decryptingCompletionListener);
} catch (Exception e) {
listener.onFailure(e);
}
}

@Override
Expand Down Expand Up @@ -124,47 +117,17 @@ public InputStreamContainer provideStream(int partNumber) throws IOException {

}

static class DecryptingReadContextListener<T, U> implements ActionListener<ReadContext> {

private final ActionListener<ReadContext> completionListener;
private final CryptoHandler<T, U> cryptoHandler;
private final EncryptedHeaderContentSupplier encryptedHeaderContentSupplier;

public DecryptingReadContextListener(
ActionListener<ReadContext> completionListener,
CryptoHandler<T, U> cryptoHandler,
EncryptedHeaderContentSupplier headerContentSupplier
) {
this.completionListener = completionListener;
this.cryptoHandler = cryptoHandler;
this.encryptedHeaderContentSupplier = headerContentSupplier;
}

@Override
public void onResponse(ReadContext readContext) {
try {
DecryptedReadContext<T, U> decryptedReadContext = new DecryptedReadContext<>(
readContext,
cryptoHandler,
encryptedHeaderContentSupplier
);
completionListener.onResponse(decryptedReadContext);
} catch (Exception e) {
onFailure(e);
}
}

@Override
public void onFailure(Exception e) {
completionListener.onFailure(e);
}
}

/**
* DecryptedReadContext decrypts the encrypted {@link ReadContext} by acting as a transformation wrapper around
* the encrypted object
* @param <T> Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance
* @param <U> Parsed Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance
*/
static class DecryptedReadContext<T, U> extends ReadContext {

private final U cryptoContext;
private final CryptoHandler<T, U> cryptoHandler;
private final long fileSize;
private final U cryptoContext;
private Long blobSize;

public DecryptedReadContext(
ReadContext readContext,
Expand All @@ -175,22 +138,30 @@ public DecryptedReadContext(
try {
this.cryptoHandler = cryptoHandler;
this.cryptoContext = this.cryptoHandler.loadEncryptionMetadata(headerContentSupplier);
this.fileSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, readContext.getBlobSize());
} catch (IOException e) {
throw new RuntimeException(e);
throw new UncheckedIOException(e);
}
}

@Override
public long getBlobSize() {
return fileSize;
// initializes the value lazily
if (blobSize == null) {
this.blobSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, super.getBlobSize());
}
return this.blobSize;
}

@Override
public List<InputStreamContainer> getPartStreams() {
return super.getPartStreams().stream().map(this::decrpytInputStreamContainer).collect(Collectors.toList());
}

/**
* Transforms an encrypted {@link InputStreamContainer} to a decrypted instance
* @param inputStreamContainer encrypted input stream container instance
* @return decrypted input stream container instance
*/
private InputStreamContainer decrpytInputStreamContainer(InputStreamContainer inputStreamContainer) {
long startOfStream = inputStreamContainer.getOffset();
long endOfStream = startOfStream + inputStreamContainer.getContentLength() - 1;
Expand All @@ -202,11 +173,9 @@ private InputStreamContainer decrpytInputStreamContainer(InputStreamContainer in

long adjustedPos = decryptedStreamProvider.getAdjustedRange()[0];
long adjustedLength = decryptedStreamProvider.getAdjustedRange()[1] - adjustedPos + 1;
return new InputStreamContainer(
decryptedStreamProvider.getDecryptedStreamProvider().apply(inputStreamContainer.getInputStream()),
adjustedPos,
adjustedLength
);
final InputStream decryptedStream = decryptedStreamProvider.getDecryptedStreamProvider()
.apply(inputStreamContainer.getInputStream());
return new InputStreamContainer(decryptedStream, adjustedLength, adjustedPos);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public InputStream readBlob(String blobName) throws IOException {
return cryptoHandler.createDecryptingStream(inputStream);
}

private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) {
EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) {
return (start, end) -> {
byte[] buffer;
int length = (int) (end - start + 1);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.blobstore;

import org.opensearch.common.Randomness;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils;
import org.opensearch.common.crypto.CryptoHandler;
import org.opensearch.common.crypto.DecryptedRangedStreamProvider;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;
import org.opensearch.test.OpenSearchTestCase;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.function.UnaryOperator;

import org.mockito.Mockito;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class AsyncMultiStreamEncryptedBlobContainerTests extends OpenSearchTestCase {

// Tests the happy path scenario for decrypting a read context
@SuppressWarnings("unchecked")
public void testReadBlobAsync() throws Exception {
String testBlobName = "testBlobName";
int size = 100;

// Mock objects needed for the test
AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class);
CryptoHandler<Object, Object> cryptoHandler = mock(CryptoHandler.class);
Object cryptoContext = mock(Object.class);
when(cryptoHandler.loadEncryptionMetadata(any())).thenReturn(cryptoContext);
when(cryptoHandler.estimateDecryptedLength(any(), anyLong())).thenReturn((long) size);
long[] adjustedRanges = { 0, size - 1 };
DecryptedRangedStreamProvider rangedStreamProvider = new DecryptedRangedStreamProvider(adjustedRanges, UnaryOperator.identity());
when(cryptoHandler.createDecryptingStreamOfRange(eq(cryptoContext), anyLong(), anyLong())).thenReturn(rangedStreamProvider);

// Objects needed for API call
final byte[] data = new byte[size];
Randomness.get().nextBytes(data);
final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0);
final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener =
new ListenerTestUtils.CountingCompletionListener<>();
final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null);

Mockito.doAnswer(invocation -> {
ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1);
readContextActionListener.onResponse(readContext);
return null;
}).when(blobContainer).readBlobAsync(eq(testBlobName), any());

AsyncMultiStreamEncryptedBlobContainer<Object, Object> asyncMultiStreamEncryptedBlobContainer =
new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler);
asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener);

// Assert results
ReadContext response = completionListener.getResponse();
assertEquals(0, completionListener.getFailureCount());
assertEquals(1, completionListener.getResponseCount());
assertNull(completionListener.getException());

assertTrue(response instanceof AsyncMultiStreamEncryptedBlobContainer.DecryptedReadContext);
assertEquals(1, response.getNumberOfParts());
assertEquals(size, response.getBlobSize());

InputStreamContainer responseContainer = response.getPartStreams().get(0);
assertEquals(0, responseContainer.getOffset());
assertEquals(size, responseContainer.getContentLength());
assertEquals(100, responseContainer.getInputStream().available());
}

// Tests the exception scenario for decrypting a read context
@SuppressWarnings("unchecked")
public void testReadBlobAsyncException() throws Exception {
String testBlobName = "testBlobName";
int size = 100;

// Mock objects needed for the test
AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class);
CryptoHandler<Object, Object> cryptoHandler = mock(CryptoHandler.class);
when(cryptoHandler.loadEncryptionMetadata(any())).thenThrow(new IOException());

// Objects needed for API call
final byte[] data = new byte[size];
Randomness.get().nextBytes(data);
final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0);
final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener =
new ListenerTestUtils.CountingCompletionListener<>();
final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null);

Mockito.doAnswer(invocation -> {
ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1);
readContextActionListener.onResponse(readContext);
return null;
}).when(blobContainer).readBlobAsync(eq(testBlobName), any());

AsyncMultiStreamEncryptedBlobContainer<Object, Object> asyncMultiStreamEncryptedBlobContainer =
new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler);
asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener);

// Assert results
assertEquals(1, completionListener.getFailureCount());
assertEquals(0, completionListener.getResponseCount());
assertNull(completionListener.getResponse());
assertTrue(completionListener.getException() instanceof UncheckedIOException);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class ListenerTestUtils {
* CountingCompletionListener acts as a verification instance for wrapping listener based calls.
* Keeps track of the last response, failure and count of response and failure invocations.
*/
static class CountingCompletionListener<T> implements ActionListener<T> {
public static class CountingCompletionListener<T> implements ActionListener<T> {
private int responseCount;
private int failureCount;
private T response;
Expand Down

0 comments on commit 421d3c5

Please sign in to comment.