From 3af626380a5f42baaba21d8ba2c05d2e2b1d39d4 Mon Sep 17 00:00:00 2001 From: Vikas Bansal <43470111+vikasvb90@users.noreply.github.com> Date: Fri, 15 Sep 2023 18:55:00 +0530 Subject: [PATCH] Added decryption pool for bounded decryption Signed-off-by: Vikas Bansal <43470111+vikasvb90@users.noreply.github.com> --- .../encryption/CryptoModulePlugin.java | 85 ++++++++++++++++++- .../encryption/frame/AwsCrypto.java | 10 ++- .../encryption/frame/CryptoInputStream.java | 41 ++++++++- .../encryption/frame/FrameCryptoHandler.java | 14 ++- .../encryption/CryptoModulePluginTests.java | 3 +- .../encryption/frame/CryptoTests.java | 11 ++- 6 files changed, 153 insertions(+), 11 deletions(-) diff --git a/modules/crypto/src/main/java/org/opensearch/encryption/CryptoModulePlugin.java b/modules/crypto/src/main/java/org/opensearch/encryption/CryptoModulePlugin.java index 9c4a760eca618..ed29da9d31d8a 100644 --- a/modules/crypto/src/main/java/org/opensearch/encryption/CryptoModulePlugin.java +++ b/modules/crypto/src/main/java/org/opensearch/encryption/CryptoModulePlugin.java @@ -8,18 +8,40 @@ package org.opensearch.encryption; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.crypto.CryptoHandler; import org.opensearch.common.crypto.MasterKeyProvider; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.encryption.frame.AwsCrypto; import org.opensearch.encryption.frame.EncryptionMetadata; import org.opensearch.encryption.frame.FrameCryptoHandler; import org.opensearch.encryption.keyprovider.CryptoMasterKey; +import org.opensearch.env.Environment; +import org.opensearch.env.NodeEnvironment; import org.opensearch.plugins.CryptoPlugin; import org.opensearch.plugins.Plugin; +import org.opensearch.repositories.RepositoriesService; +import org.opensearch.script.ScriptService; +import org.opensearch.threadpool.ExecutorBuilder; +import org.opensearch.threadpool.FixedExecutorBuilder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.watcher.ResourceWatcherService; import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.ParsedCiphertext; @@ -28,13 +50,26 @@ public class CryptoModulePlugin extends Plugin implements CryptoPlugin { + static final Setting BOUNDED_DECRYPTION_SETTING = Setting.boolSetting( + "crypto.bounded_decryption", + true, + Setting.Property.NodeScope + ); + private final int dataKeyCacheSize = 500; private final String algorithm = "ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256"; + private static final String DECRYPTION_POOL = "decryption"; // - Cache TTL and Jitter is used to decide the Crypto Cache TTL. // - Random number between: (TTL Jitter, TTL - Jitter) private final long dataKeyCacheTTL = TimeValue.timeValueDays(2).getMillis(); private static final long dataKeyCacheJitter = TimeUnit.MINUTES.toMillis(30); // - 30 minutes + private ExecutorService decryptionExecutor; + private final boolean boundedDecryptionEnabled; + + public CryptoModulePlugin(Settings settings) { + boundedDecryptionEnabled = BOUNDED_DECRYPTION_SETTING.get(settings); + } public CryptoHandler getOrCreateCryptoHandler( MasterKeyProvider keyProvider, @@ -50,6 +85,48 @@ public CryptoHandler getOrCreateCryptoHand return createCryptoHandler(algorithm, materialsManager, keyProvider, onClose); } + @Override + public List> getExecutorBuilders(Settings settings) { + if (boundedDecryptionEnabled == false) { + return new ArrayList<>(); + } + List> executorBuilders = new ArrayList<>(); + executorBuilders.add(new FixedExecutorBuilder(settings, DECRYPTION_POOL, capacity(settings), 10_000, DECRYPTION_POOL)); + return executorBuilders; + } + + private static int capacity(Settings settings) { + return boundedBy((allocatedProcessors(settings) + 7) / 8, 1, 2); + } + + private static int boundedBy(int value, int min, int max) { + return Math.min(max, Math.max(min, value)); + } + + private static int allocatedProcessors(Settings settings) { + return OpenSearchExecutors.allocatedProcessors(settings); + } + + @Override + public Collection createComponents( + final Client client, + final ClusterService clusterService, + final ThreadPool threadPool, + final ResourceWatcherService resourceWatcherService, + final ScriptService scriptService, + final NamedXContentRegistry xContentRegistry, + final Environment environment, + final NodeEnvironment nodeEnvironment, + final NamedWriteableRegistry namedWriteableRegistry, + final IndexNameExpressionResolver expressionResolver, + final Supplier repositoriesServiceSupplier + ) { + if (boundedDecryptionEnabled == true) { + this.decryptionExecutor = threadPool.executor(DECRYPTION_POOL); + } + return Collections.emptyList(); + } + private String getDataKeyAlgorithm(String algorithm) { if ("ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256".equals(algorithm)) { return CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256.getDataKeyAlgo(); @@ -77,7 +154,8 @@ CryptoHandler createCryptoHandler( return new FrameCryptoHandler( new AwsCrypto(materialsManager, CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256), masterKeyProvider.getEncryptionContext(), - onClose + onClose, + decryptionExecutor ); } throw new IllegalArgumentException("Unsupported algorithm: " + algorithm); @@ -97,4 +175,9 @@ CachingCryptoMaterialsManager createMaterialsManager(MasterKeyProvider masterKey .withMaxAge(masterKeyCacheTTL, TimeUnit.MILLISECONDS) .build(); } + + @Override + public List> getSettings() { + return List.of(BOUNDED_DECRYPTION_SETTING); + } } diff --git a/modules/crypto/src/main/java/org/opensearch/encryption/frame/AwsCrypto.java b/modules/crypto/src/main/java/org/opensearch/encryption/frame/AwsCrypto.java index 241b82db5273c..4d3cc2bdd7572 100644 --- a/modules/crypto/src/main/java/org/opensearch/encryption/frame/AwsCrypto.java +++ b/modules/crypto/src/main/java/org/opensearch/encryption/frame/AwsCrypto.java @@ -12,6 +12,7 @@ import java.io.InputStream; import java.util.Map; +import java.util.concurrent.ExecutorService; import com.amazonaws.encryptionsdk.CommitmentPolicy; import com.amazonaws.encryptionsdk.CryptoAlgorithm; @@ -117,10 +118,10 @@ public int getTrailingSignatureSize(CryptoAlgorithm cryptoAlgorithm) { return EncryptionHandler.getAlgoTrailingLength(cryptoAlgorithm); } - public CryptoInputStream createDecryptingStream(final InputStream inputStream) { + public CryptoInputStream createDecryptingStream(final InputStream inputStream, ExecutorService decryptionExecutor) { final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(materialsManager); - return new CryptoInputStream<>(inputStream, cryptoHandler, true); + return new CryptoInputStream<>(inputStream, cryptoHandler, true, decryptionExecutor); } public CryptoInputStream createDecryptingStream( @@ -128,11 +129,12 @@ public CryptoInputStream createDecryptingStream( final long size, final ParsedCiphertext parsedCiphertext, final int frameStartNum, - boolean isLastPart + boolean isLastPart, + ExecutorService decryptionExecutor ) { final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(materialsManager, parsedCiphertext, frameStartNum); - CryptoInputStream cryptoInputStream = new CryptoInputStream<>(inputStream, cryptoHandler, isLastPart); + CryptoInputStream cryptoInputStream = new CryptoInputStream<>(inputStream, cryptoHandler, isLastPart, decryptionExecutor); cryptoInputStream.setMaxInputLength(size); return cryptoInputStream; } diff --git a/modules/crypto/src/main/java/org/opensearch/encryption/frame/CryptoInputStream.java b/modules/crypto/src/main/java/org/opensearch/encryption/frame/CryptoInputStream.java index e8d51fb2440d5..d71d501f972ac 100644 --- a/modules/crypto/src/main/java/org/opensearch/encryption/frame/CryptoInputStream.java +++ b/modules/crypto/src/main/java/org/opensearch/encryption/frame/CryptoInputStream.java @@ -21,8 +21,14 @@ package org.opensearch.encryption.frame; +import org.opensearch.ExceptionsHelper; + import java.io.IOException; import java.io.InputStream; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.MasterKey; @@ -72,6 +78,7 @@ public class CryptoInputStream> extends InputStream { private boolean hasFinalCalled_; private boolean hasProcessBytesCalled_; private final boolean isLastPart_; + private final ExecutorService cryptoExecutor; /** * Constructs a CryptoInputStream that wraps the provided InputStream object. It performs @@ -88,6 +95,19 @@ public class CryptoInputStream> extends InputStream { inputStream_ = Utils.assertNonNull(inputStream, "inputStream"); cryptoHandler_ = Utils.assertNonNull(cryptoHandler, "cryptoHandler"); isLastPart_ = isLastPart; + cryptoExecutor = null; + } + + CryptoInputStream( + final InputStream inputStream, + final MessageCryptoHandler cryptoHandler, + boolean isLastPart, + ExecutorService cryptoExecutor + ) { + inputStream_ = Utils.assertNonNull(inputStream, "inputStream"); + cryptoHandler_ = Utils.assertNonNull(cryptoHandler, "cryptoHandler"); + isLastPart_ = isLastPart; + this.cryptoExecutor = cryptoExecutor; } /** @@ -166,7 +186,26 @@ public int read(final byte[] b, final int off, final int len) throws IllegalArgu // Block until a byte is read or end of stream in the underlying // stream is reached. while (newBytesLen == 0) { - newBytesLen = fillOutBytes(); + if (cryptoExecutor != null) { + Callable cryptoCallable = this::fillOutBytes; + Future cryptoFuture = cryptoExecutor.submit(cryptoCallable); + try { + newBytesLen = cryptoFuture.get(); + } catch (ExecutionException e) { + Throwable t = ExceptionsHelper.unwrap(e, BadCiphertextException.class, IllegalArgumentException.class); + if (t instanceof BadCiphertextException) { + throw (BadCiphertextException) t; + } else if (t instanceof IllegalArgumentException) { + throw (IllegalArgumentException) t; + } else { + throw new RuntimeException(e); + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } else { + newBytesLen = fillOutBytes(); + } } if (newBytesLen < 0) { return -1; diff --git a/modules/crypto/src/main/java/org/opensearch/encryption/frame/FrameCryptoHandler.java b/modules/crypto/src/main/java/org/opensearch/encryption/frame/FrameCryptoHandler.java index dee821e5cdf2d..3149bf5294ba3 100644 --- a/modules/crypto/src/main/java/org/opensearch/encryption/frame/FrameCryptoHandler.java +++ b/modules/crypto/src/main/java/org/opensearch/encryption/frame/FrameCryptoHandler.java @@ -17,6 +17,7 @@ import java.io.IOException; import java.io.InputStream; import java.util.Map; +import java.util.concurrent.ExecutorService; import com.amazonaws.encryptionsdk.ParsedCiphertext; @@ -24,14 +25,21 @@ public class FrameCryptoHandler implements CryptoHandler encryptionContext; private final Runnable onClose; + private final ExecutorService decryptionExecutor; // package private for tests private final int FRAME_SIZE = 8 * 1024; - public FrameCryptoHandler(AwsCrypto awsCrypto, Map encryptionContext, Runnable onClose) { + public FrameCryptoHandler( + AwsCrypto awsCrypto, + Map encryptionContext, + Runnable onClose, + ExecutorService decryptionExecutor + ) { this.awsCrypto = awsCrypto; this.encryptionContext = encryptionContext; this.onClose = onClose; + this.decryptionExecutor = decryptionExecutor; } public int getFrameSize() { @@ -148,7 +156,7 @@ public ParsedCiphertext loadEncryptionMetadata(EncryptedHeaderContentSupplier en * @return Decrypting wrapper stream */ public InputStream createDecryptingStream(InputStream encryptedStream) { - return awsCrypto.createDecryptingStream(encryptedStream); + return awsCrypto.createDecryptingStream(encryptedStream, decryptionExecutor); } /** @@ -173,7 +181,7 @@ private InputStream createBlockDecryptionStream( } int frameStartNumber = (int) (startPosOfRawContent / parsedCiphertext.getFrameLength()) + 1; long encryptedSize = encryptedRange[1] - encryptedRange[0] + 1; - return awsCrypto.createDecryptingStream(inputStream, encryptedSize, parsedCiphertext, frameStartNumber, false); + return awsCrypto.createDecryptingStream(inputStream, encryptedSize, parsedCiphertext, frameStartNumber, false, decryptionExecutor); } /** diff --git a/modules/crypto/src/test/java/org/opensearch/encryption/CryptoModulePluginTests.java b/modules/crypto/src/test/java/org/opensearch/encryption/CryptoModulePluginTests.java index dd83acd8feb3a..d77ce257ec359 100644 --- a/modules/crypto/src/test/java/org/opensearch/encryption/CryptoModulePluginTests.java +++ b/modules/crypto/src/test/java/org/opensearch/encryption/CryptoModulePluginTests.java @@ -10,6 +10,7 @@ import org.opensearch.common.crypto.CryptoHandler; import org.opensearch.common.crypto.MasterKeyProvider; +import org.opensearch.common.settings.Settings; import org.opensearch.test.OpenSearchTestCase; import java.util.Collections; @@ -21,7 +22,7 @@ public class CryptoModulePluginTests extends OpenSearchTestCase { - private final CryptoModulePlugin cryptoModulePlugin = new CryptoModulePlugin(); + private final CryptoModulePlugin cryptoModulePlugin = new CryptoModulePlugin(Settings.EMPTY); public void testGetOrCreateCryptoHandler() { MasterKeyProvider mockKeyProvider = mock(MasterKeyProvider.class); diff --git a/modules/crypto/src/test/java/org/opensearch/encryption/frame/CryptoTests.java b/modules/crypto/src/test/java/org/opensearch/encryption/frame/CryptoTests.java index c0277f29f527c..985dda465094a 100644 --- a/modules/crypto/src/test/java/org/opensearch/encryption/frame/CryptoTests.java +++ b/modules/crypto/src/test/java/org/opensearch/encryption/frame/CryptoTests.java @@ -13,6 +13,7 @@ import org.opensearch.common.io.InputStreamContainer; import org.opensearch.encryption.MockKeyProvider; import org.opensearch.test.OpenSearchTestCase; +import org.junit.AfterClass; import org.junit.Assert; import org.junit.Before; @@ -26,6 +27,8 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.HashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.zip.CRC32; @@ -44,12 +47,13 @@ public class CryptoTests extends OpenSearchTestCase { private static FrameCryptoHandler frameCryptoHandler; private static FrameCryptoHandler frameCryptoHandlerTrailingAlgo; + private static final ExecutorService executorService = Executors.newFixedThreadPool(2); static class CustomFrameCryptoHandlerTest extends FrameCryptoHandler { private final int frameSize; CustomFrameCryptoHandlerTest(AwsCrypto awsCrypto, HashMap config, int frameSize) { - super(awsCrypto, config, () -> {}); + super(awsCrypto, config, () -> {}, executorService); this.frameSize = frameSize; } @@ -59,6 +63,11 @@ public int getFrameSize() { } } + @AfterClass + public static void close() { + executorService.shutdown(); + } + @Before public void setupResources() { frameCryptoHandler = new CustomFrameCryptoHandlerTest(