Skip to content

Commit

Permalink
Added decryption pool for bounded decryption
Browse files Browse the repository at this point in the history
  • Loading branch information
vikasvb90 committed Sep 15, 2023
1 parent 91a5748 commit 46738d0
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.encryption;

import org.opensearch.test.OpenSearchIntegTestCase;

public abstract class CryptoModuleIntegTest extends OpenSearchIntegTestCase {

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,40 @@

package org.opensearch.encryption;

import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.crypto.CryptoHandler;
import org.opensearch.common.crypto.MasterKeyProvider;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.encryption.frame.AwsCrypto;
import org.opensearch.encryption.frame.EncryptionMetadata;
import org.opensearch.encryption.frame.FrameCryptoHandler;
import org.opensearch.encryption.keyprovider.CryptoMasterKey;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.plugins.CryptoPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import com.amazonaws.encryptionsdk.CryptoAlgorithm;
import com.amazonaws.encryptionsdk.ParsedCiphertext;
Expand All @@ -28,13 +50,26 @@

public class CryptoModulePlugin extends Plugin implements CryptoPlugin<EncryptionMetadata, ParsedCiphertext> {

static final Setting<Boolean> BOUNDED_DECRYPTION_SETTING = Setting.boolSetting(
"crypto.bounded_decryption",
true,
Setting.Property.NodeScope
);

private final int dataKeyCacheSize = 500;
private final String algorithm = "ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256";
private static final String DECRYPTION_POOL = "decryption";

// - Cache TTL and Jitter is used to decide the Crypto Cache TTL.
// - Random number between: (TTL Jitter, TTL - Jitter)
private final long dataKeyCacheTTL = TimeValue.timeValueDays(2).getMillis();
private static final long dataKeyCacheJitter = TimeUnit.MINUTES.toMillis(30); // - 30 minutes
private ExecutorService decryptionExecutor;
private final boolean boundedDecryptionEnabled;

public CryptoModulePlugin(Settings settings) {
boundedDecryptionEnabled = BOUNDED_DECRYPTION_SETTING.get(settings);
}

public CryptoHandler<EncryptionMetadata, ParsedCiphertext> getOrCreateCryptoHandler(
MasterKeyProvider keyProvider,
Expand All @@ -50,6 +85,48 @@ public CryptoHandler<EncryptionMetadata, ParsedCiphertext> getOrCreateCryptoHand
return createCryptoHandler(algorithm, materialsManager, keyProvider, onClose);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
if (boundedDecryptionEnabled == false) {
return new ArrayList<>();
}
List<ExecutorBuilder<?>> executorBuilders = new ArrayList<>();
executorBuilders.add(new FixedExecutorBuilder(settings, DECRYPTION_POOL, capacity(settings), 10_000, DECRYPTION_POOL));
return executorBuilders;
}

private static int capacity(Settings settings) {
return boundedBy((allocatedProcessors(settings) + 7) / 8, 1, 2);
}

private static int boundedBy(int value, int min, int max) {
return Math.min(max, Math.max(min, value));
}

private static int allocatedProcessors(Settings settings) {
return OpenSearchExecutors.allocatedProcessors(settings);
}

@Override
public Collection<Object> createComponents(
final Client client,
final ClusterService clusterService,
final ThreadPool threadPool,
final ResourceWatcherService resourceWatcherService,
final ScriptService scriptService,
final NamedXContentRegistry xContentRegistry,
final Environment environment,
final NodeEnvironment nodeEnvironment,
final NamedWriteableRegistry namedWriteableRegistry,
final IndexNameExpressionResolver expressionResolver,
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
if (boundedDecryptionEnabled == true) {
this.decryptionExecutor = threadPool.executor(DECRYPTION_POOL);
}
return Collections.emptyList();
}

private String getDataKeyAlgorithm(String algorithm) {
if ("ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256".equals(algorithm)) {
return CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256.getDataKeyAlgo();
Expand Down Expand Up @@ -77,7 +154,8 @@ CryptoHandler<EncryptionMetadata, ParsedCiphertext> createCryptoHandler(
return new FrameCryptoHandler(
new AwsCrypto(materialsManager, CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256),
masterKeyProvider.getEncryptionContext(),
onClose
onClose,
decryptionExecutor
);
}
throw new IllegalArgumentException("Unsupported algorithm: " + algorithm);
Expand All @@ -97,4 +175,9 @@ CachingCryptoMaterialsManager createMaterialsManager(MasterKeyProvider masterKey
.withMaxAge(masterKeyCacheTTL, TimeUnit.MILLISECONDS)
.build();
}

@Override
public List<Setting<?>> getSettings() {
return List.of(BOUNDED_DECRYPTION_SETTING);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import java.io.InputStream;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import com.amazonaws.encryptionsdk.CommitmentPolicy;
import com.amazonaws.encryptionsdk.CryptoAlgorithm;
Expand Down Expand Up @@ -117,22 +118,23 @@ public int getTrailingSignatureSize(CryptoAlgorithm cryptoAlgorithm) {
return EncryptionHandler.getAlgoTrailingLength(cryptoAlgorithm);
}

public CryptoInputStream<?> createDecryptingStream(final InputStream inputStream) {
public CryptoInputStream<?> createDecryptingStream(final InputStream inputStream, ExecutorService decryptionExecutor) {

final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(materialsManager);
return new CryptoInputStream<>(inputStream, cryptoHandler, true);
return new CryptoInputStream<>(inputStream, cryptoHandler, true, decryptionExecutor);
}

public CryptoInputStream<?> createDecryptingStream(
final InputStream inputStream,
final long size,
final ParsedCiphertext parsedCiphertext,
final int frameStartNum,
boolean isLastPart
boolean isLastPart,
ExecutorService decryptionExecutor
) {

final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(materialsManager, parsedCiphertext, frameStartNum);
CryptoInputStream<?> cryptoInputStream = new CryptoInputStream<>(inputStream, cryptoHandler, isLastPart);
CryptoInputStream<?> cryptoInputStream = new CryptoInputStream<>(inputStream, cryptoHandler, isLastPart, decryptionExecutor);
cryptoInputStream.setMaxInputLength(size);
return cryptoInputStream;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.encryption.frame;

import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousByteChannel;
import java.nio.channels.CompletionHandler;

import com.amazonaws.encryptionsdk.internal.MessageCryptoHandler;

/**
* Performs async
*/
public class CryptoAsyncByteChannelWriter {

private final MessageCryptoHandler cryptoHandler;
private final AsynchronousByteChannel asyncByteChannel;

public CryptoAsyncByteChannelWriter(AsynchronousByteChannel asyncByteChannel, MessageCryptoHandler cryptoHandler) {
this.asyncByteChannel = asyncByteChannel;
this.cryptoHandler = cryptoHandler;
}

/**
* Create a ByteBuffer and publishes response to handler. If sufficient bytes were received for decryption then
* response with decrypted bytes is published otherwise 0 bytes are published in result on handler.
*/
private void write(final byte[] b, final int off, final int len, CompletionHandler<Integer, ByteBuffer> handler) {
try {
if (b == null) {
throw new IllegalArgumentException("b cannot be null");
}

if (len < 0 || off < 0) {
throw new IllegalArgumentException(String.format("Invalid values for offset: %d and length: %d", off, len));
}

final int outLen = cryptoHandler.estimatePartialOutputSize(len);
final byte[] outBytes = new byte[outLen];

int bytesWritten = cryptoHandler.processBytes(b, off, len, outBytes, 0).getBytesWritten();
ByteBuffer byteBuffer = ByteBuffer.wrap(outBytes, 0, bytesWritten);
if (bytesWritten > 0) {
asyncByteChannel.write(byteBuffer, byteBuffer, handler);
} else {
handler.completed(0, byteBuffer);
}
} catch (Exception ex) {
handler.failed(ex, null);
}
}

public void setMaxInputLength(long size) {
cryptoHandler.setMaxInputLength(size);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,14 @@

package org.opensearch.encryption.frame;

import org.opensearch.ExceptionsHelper;

import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import com.amazonaws.encryptionsdk.AwsCrypto;
import com.amazonaws.encryptionsdk.MasterKey;
Expand Down Expand Up @@ -72,6 +78,7 @@ public class CryptoInputStream<K extends MasterKey<K>> extends InputStream {
private boolean hasFinalCalled_;
private boolean hasProcessBytesCalled_;
private final boolean isLastPart_;
private final ExecutorService cryptoExecutor;

/**
* Constructs a CryptoInputStream that wraps the provided InputStream object. It performs
Expand All @@ -88,6 +95,19 @@ public class CryptoInputStream<K extends MasterKey<K>> extends InputStream {
inputStream_ = Utils.assertNonNull(inputStream, "inputStream");
cryptoHandler_ = Utils.assertNonNull(cryptoHandler, "cryptoHandler");
isLastPart_ = isLastPart;
cryptoExecutor = null;
}

CryptoInputStream(
final InputStream inputStream,
final MessageCryptoHandler cryptoHandler,
boolean isLastPart,
ExecutorService cryptoExecutor
) {
inputStream_ = Utils.assertNonNull(inputStream, "inputStream");
cryptoHandler_ = Utils.assertNonNull(cryptoHandler, "cryptoHandler");
isLastPart_ = isLastPart;
this.cryptoExecutor = cryptoExecutor;
}

/**
Expand Down Expand Up @@ -166,7 +186,26 @@ public int read(final byte[] b, final int off, final int len) throws IllegalArgu
// Block until a byte is read or end of stream in the underlying
// stream is reached.
while (newBytesLen == 0) {
newBytesLen = fillOutBytes();
if (cryptoExecutor != null) {
Callable<Integer> cryptoCallable = this::fillOutBytes;
Future<Integer> cryptoFuture = cryptoExecutor.submit(cryptoCallable);
try {
newBytesLen = cryptoFuture.get();
} catch (ExecutionException e) {
Throwable t = ExceptionsHelper.unwrap(e, BadCiphertextException.class, IllegalArgumentException.class);
if (t instanceof BadCiphertextException) {
throw (BadCiphertextException) t;
} else if (t instanceof IllegalArgumentException) {
throw (IllegalArgumentException) t;
} else {
throw new RuntimeException(e);
}
} catch (Exception ex) {
throw new RuntimeException(ex);
}
} else {
newBytesLen = fillOutBytes();
}
}
if (newBytesLen < 0) {
return -1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,29 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import com.amazonaws.encryptionsdk.ParsedCiphertext;

public class FrameCryptoHandler implements CryptoHandler<EncryptionMetadata, ParsedCiphertext> {
private final AwsCrypto awsCrypto;
private final Map<String, String> encryptionContext;
private final Runnable onClose;
private final ExecutorService decryptionExecutor;

// package private for tests
private final int FRAME_SIZE = 8 * 1024;

public FrameCryptoHandler(AwsCrypto awsCrypto, Map<String, String> encryptionContext, Runnable onClose) {
public FrameCryptoHandler(
AwsCrypto awsCrypto,
Map<String, String> encryptionContext,
Runnable onClose,
ExecutorService decryptionExecutor
) {
this.awsCrypto = awsCrypto;
this.encryptionContext = encryptionContext;
this.onClose = onClose;
this.decryptionExecutor = decryptionExecutor;
}

public int getFrameSize() {
Expand Down Expand Up @@ -148,7 +156,7 @@ public ParsedCiphertext loadEncryptionMetadata(EncryptedHeaderContentSupplier en
* @return Decrypting wrapper stream
*/
public InputStream createDecryptingStream(InputStream encryptedStream) {
return awsCrypto.createDecryptingStream(encryptedStream);
return awsCrypto.createDecryptingStream(encryptedStream, decryptionExecutor);
}

/**
Expand All @@ -173,7 +181,7 @@ private InputStream createBlockDecryptionStream(
}
int frameStartNumber = (int) (startPosOfRawContent / parsedCiphertext.getFrameLength()) + 1;
long encryptedSize = encryptedRange[1] - encryptedRange[0] + 1;
return awsCrypto.createDecryptingStream(inputStream, encryptedSize, parsedCiphertext, frameStartNumber, false);
return awsCrypto.createDecryptingStream(inputStream, encryptedSize, parsedCiphertext, frameStartNumber, false, decryptionExecutor);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.common.crypto.CryptoHandler;
import org.opensearch.common.crypto.MasterKeyProvider;
import org.opensearch.common.settings.Settings;
import org.opensearch.test.OpenSearchTestCase;

import java.util.Collections;
Expand All @@ -21,7 +22,7 @@

public class CryptoModulePluginTests extends OpenSearchTestCase {

private final CryptoModulePlugin cryptoModulePlugin = new CryptoModulePlugin();
private final CryptoModulePlugin cryptoModulePlugin = new CryptoModulePlugin(Settings.EMPTY);

public void testGetOrCreateCryptoHandler() {
MasterKeyProvider mockKeyProvider = mock(MasterKeyProvider.class);
Expand Down
Loading

0 comments on commit 46738d0

Please sign in to comment.