From 9d7990bfe1f5d3509de4d7066a6f8c8a73c22814 Mon Sep 17 00:00:00 2001 From: Matteo Gazzetta Date: Fri, 11 Sep 2020 12:26:40 +0200 Subject: [PATCH] Encryption client side (#14) --- docker/docker-compose.yaml | 6 +- perf-test.sh | 2 +- plot_results.py | 81 ++--- pom.xml | 59 +++- rules.xml | 15 + .../DecryptingConsumerInterceptor.java | 115 +++---- .../EncryptingProducerInterceptor.java | 108 ++++--- .../EncryptorAesGcm.java | 75 +++++ .../SelfExpiringHashMap.java | 304 ++++++++++++++++++ .../SelfExpiringMap.java | 57 ++++ .../VaultFactory.java | 30 +- .../TransitInterceptorTest.java | 2 +- .../util/KafkaHelper.java | 10 +- .../util/VaultContainer.java | 77 +---- 14 files changed, 703 insertions(+), 238 deletions(-) create mode 100644 rules.xml create mode 100644 src/main/java/it/bitrock/kafkavaulttransitinterceptor/EncryptorAesGcm.java create mode 100644 src/main/java/it/bitrock/kafkavaulttransitinterceptor/SelfExpiringHashMap.java create mode 100644 src/main/java/it/bitrock/kafkavaulttransitinterceptor/SelfExpiringMap.java diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index e441500..e3043e8 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: zookeeper: - image: confluentinc/cp-zookeeper:5.5.0 + image: confluentinc/cp-zookeeper:5.5.1 ports: - 2181:2181 environment: @@ -10,7 +10,7 @@ services: ZOOKEEPER_TICK_TIME: 2000 kafka: - image: confluentinc/cp-kafka:5.5.0 + image: confluentinc/cp-kafka:5.5.1 ports: - 9092:9092 - 29092:29092 @@ -25,7 +25,7 @@ services: - zookeeper vault: - image: vault:1.4.2 + image: vault:1.5.3 restart: always volumes: - ./vault/data:/vault/file diff --git a/perf-test.sh b/perf-test.sh index 5088f95..e241b22 100755 --- a/perf-test.sh +++ b/perf-test.sh @@ -11,7 +11,7 @@ docker-compose -f docker/docker-compose.yaml up -d echo "Enable Vault Transit" docker exec -e VAULT_TOKEN="${VAULT_TOKEN}" docker_vault_1 vault secrets enable transit || true -SIZE_IN_BYTES=(10 100 1000 10000 100000) +SIZE_IN_BYTES=(10 100 500 1000 10000 100000) NUM_RECORDS=50000 TEST_RUN=$((1 + RANDOM % 10)) diff --git a/plot_results.py b/plot_results.py index 0df8e73..45f8fe6 100644 --- a/plot_results.py +++ b/plot_results.py @@ -1,43 +1,52 @@ -import numpy as np -import matplotlib.pyplot as plt -import os import glob +import matplotlib.pyplot as plt +import numpy as np + + def plot_kafka_output(directory, kind): - numMsg = "" - for TYPE in ["baseline", "interceptor"]: - X = np.empty(0, dtype=float) - Y = np.empty(0, dtype=float) - print(TYPE) - for filename in sorted(glob.iglob(f"{directory}/{kind}-{TYPE}*.txt")): - size = filename.split('-')[-1].split('.')[0] - numMsg = filename.split('-')[-2] - plt.title(f"{kind} perf {numMsg} msgs") - plt.xlabel("Message Size [byte]") - X = np.append(X, size) - print(filename) - with open(filename, 'r') as f: - lines = f.read().splitlines() - last_line = lines[-1] - throughput = "0" - if kind == "producer": - throughput = last_line.split(',')[1].split(' ')[1] - plt.ylabel("records/sec") - else: - throughput = last_line.split(',')[3] - plt.ylabel("MB/sec") - Y = np.append(Y, round(float(throughput), 2)) - print(X) - print(Y) - plt.scatter(X, Y, label=f"{TYPE}") - - plt.legend() - plt.savefig(f"{directory}/{kind}-{numMsg}.png") - plt.clf() + numMsg = "" + for TYPE in ["baseline", "interceptor"]: + X = np.empty(0, dtype=float) + Y = np.empty(0, dtype=float) + print(TYPE) + #grab last 4 characters of the file name: + def message_size(x): + print(x) + print(x.split("-")[-1].rsplit( ".", 1 )[ 0 ]) + return(int(x.split("-")[-1].rsplit( ".", 1 )[ 0 ])) + file_list = glob.iglob(f"{directory}/{kind}-{TYPE}*.txt") + for filename in sorted(file_list, key = message_size): + size = filename.split('-')[-1].split('.')[0] + numMsg = filename.split('-')[-2] + plt.title(f"{kind} perf {numMsg} msgs") + plt.xlabel("Message Size [byte]") + X = np.append(X, size) + print(filename) + with open(filename, 'r') as f: + lines = f.read().splitlines() + last_line = lines[-1] + throughput = "0" + if kind == "producer": + throughput = last_line.split(',')[1].split(' ')[1] + plt.ylabel("records/sec") + else: + throughput = last_line.split(',')[3] + plt.ylabel("MB/sec") + Y = np.append(Y, round(float(throughput), 2)) + print(X) + print(Y) + plt.scatter(X, Y, label=f"{TYPE}") + + plt.legend() + plt.savefig(f"{directory}/{kind}-{numMsg}.png") + plt.clf() + def main(): - plot_kafka_output("results", "producer") - plot_kafka_output("results", "consumer") + plot_kafka_output("results", "producer") + plot_kafka_output("results", "consumer") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/pom.xml b/pom.xml index 7c2bd83..7b91d7a 100644 --- a/pom.xml +++ b/pom.xml @@ -15,9 +15,9 @@ ${project.basedir} - 2.5.0 + 2.5.1 - 5.5.2 + 5.6.2 1.14.2 3.8.1 @@ -31,10 +31,20 @@ ${kafka.version} provided - + + + org.springframework.vault + spring-vault-core + 2.2.2.RELEASE + + + org.springframework.security + spring-security-crypto + 5.3.4.RELEASE org.junit.jupiter @@ -66,6 +76,12 @@ 1.19.0 test + + org.json + json + 20171018 + test + @@ -100,6 +116,43 @@ + + org.apache.maven.plugins + maven-enforcer-plugin + 3.0.0-M3 + + + enforce-maven + + enforce + + + + + 3.1.0 + + + + + + + + org.codehaus.mojo + versions-maven-plugin + 2.8.1 + + file:///${project.basedir}/rules.xml + + + + compile + + display-dependency-updates + display-plugin-updates + + + + diff --git a/rules.xml b/rules.xml new file mode 100644 index 0000000..af9e5ef --- /dev/null +++ b/rules.xml @@ -0,0 +1,15 @@ + + + + + (?i).*Alpha(?:-?\d+)? + (?i).*a(?:-?\d+)? + (?i).*Beta(?:-?\d+)? + (?i).*-B(?:-?\d+)? + (?i).*RC(?:-?\d+)? + (?i).*CR(?:-?\d+)? + (?i).*M(?:-?\d+)? + + + + diff --git a/src/main/java/it/bitrock/kafkavaulttransitinterceptor/DecryptingConsumerInterceptor.java b/src/main/java/it/bitrock/kafkavaulttransitinterceptor/DecryptingConsumerInterceptor.java index 7e9736e..676011c 100644 --- a/src/main/java/it/bitrock/kafkavaulttransitinterceptor/DecryptingConsumerInterceptor.java +++ b/src/main/java/it/bitrock/kafkavaulttransitinterceptor/DecryptingConsumerInterceptor.java @@ -1,20 +1,20 @@ package it.bitrock.kafkavaulttransitinterceptor; -import com.bettercloud.vault.Vault; -import com.bettercloud.vault.VaultException; -import com.bettercloud.vault.json.JsonArray; -import com.bettercloud.vault.json.JsonObject; -import com.bettercloud.vault.json.JsonValue; -import com.bettercloud.vault.response.LogicalResponse; import org.apache.kafka.clients.consumer.ConsumerInterceptor; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.clients.consumer.OffsetAndMetadata; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.serialization.Deserializer; - +import org.springframework.vault.core.VaultTransitOperations; +import org.springframework.vault.support.RawTransitKey; +import org.springframework.vault.support.TransitKeyType; +import org.springframework.vault.support.VaultTransitKey; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.nio.ByteBuffer; import java.util.*; -import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import static it.bitrock.kafkavaulttransitinterceptor.TransitConfiguration.*; @@ -23,10 +23,12 @@ public class DecryptingConsumerInterceptor implements ConsumerInterceptor { TransitConfiguration configuration; - Vault vault; + VaultTransitOperations transit; String mount; String key; Deserializer valueDeserializer; + SelfExpiringMap map = new SelfExpiringHashMap(); + public ConsumerRecords onConsume(ConsumerRecords records) { if (records.isEmpty()) return records; @@ -35,59 +37,56 @@ public ConsumerRecords onConsume(ConsumerRecords records) { for (TopicPartition partition : records.partitions()) { List> decryptedRecordsPartition = records.records(partition).stream() - .collect(groupingBy(record -> getEncryptionKey((ConsumerRecord) record))).values() + .collect(groupingBy(record -> getEncryptionKeyName(record))).values() .stream().flatMap(recordsPerKey -> processBulkDecrypt(recordsPerKey).stream()) .collect(Collectors.toList()); decryptedRecordsMap.put(partition, decryptedRecordsPartition); } - ; return new ConsumerRecords(decryptedRecordsMap); } private List> processBulkDecrypt(List> records) { - JsonArray batch = new JsonArray(); - String key = getEncryptionKey(records.get(0)); - for (Object text : records.stream().map(ConsumerRecord::value).toArray()) { - if (text instanceof byte[]) { - batch.add(new JsonObject().add("ciphertext", new String((byte[]) text))); - } else { - batch.add(new JsonObject().add("ciphertext", (String) text)); - } - } - LogicalResponse response = null; - try { - response = vault.logical().write(String.format("%s/decrypt/%s", mount, key), - Collections.singletonMap("batch_input", batch)); - if (response.getRestResponse().getStatus() == 200) { - List plainTexts = getBatchResults(response) - .stream().map(this::getPlaintextData) - .collect(Collectors.toList()); - AtomicInteger index = new AtomicInteger(0); - return records.stream() - .map(record -> - new ConsumerRecord(record.topic(), - record.partition(), - record.offset(), - record.timestamp(), - record.timestampType(), - record.checksum(), - record.serializedKeySize(), - record.serializedValueSize(), - record.key(), - valueDeserializer.deserialize(record.topic(), plainTexts.get(index.getAndIncrement())), - record.headers(), - record.leaderEpoch())) - .collect(Collectors.toList()); + String keyName = getEncryptionKeyName(records.get(0)); + return records.stream().map( + record -> doTheMagic(record, keyName) + ).filter(Objects::nonNull).collect(Collectors.toList()); + } + + private ConsumerRecord doTheMagic(ConsumerRecord record, String keyName) { + byte[] ciphertext = (byte[]) record.value(); + int encryptionVersion = getEncryptionKeyVersion(record); + String keyCacheKey = keyName.concat("-").concat(String.valueOf(encryptionVersion)); + byte[] decodedKey = map.get(keyCacheKey); + if (decodedKey == null) { + VaultTransitKey vaultTransitKey = transit.getKey(keyName); + int minDecryptionVersion = vaultTransitKey.getMinDecryptionVersion(); + if (minDecryptionVersion <= encryptionVersion) { + RawTransitKey vaultKey = transit.exportKey(keyName, TransitKeyType.ENCRYPTION_KEY); + // decode the base64 encoded string + decodedKey = Base64.getDecoder().decode(vaultKey.getKeys().get(String.valueOf(encryptionVersion))); + map.put(keyCacheKey, decodedKey, 5 * 60000); } else { - LOGGER.error(String.format("Decryption failed with status code: %d body: %s", response.getRestResponse().getStatus(), new String(response.getRestResponse().getBody()))); - throw new RuntimeException("Decryption failed"); + return null; } - } catch (VaultException e) { - LOGGER.error("Failed to decrypt bulk records Vault", e); - throw new RuntimeException("Failed to decrypt bulk records Vault"); } + + // rebuild keyName using SecretKeySpec + SecretKey originalKey = new SecretKeySpec(decodedKey, 0, decodedKey.length, "AES"); + + return new ConsumerRecord(record.topic(), + record.partition(), + record.offset(), + record.timestamp(), + record.timestampType(), + record.checksum(), + record.serializedKeySize(), + record.serializedValueSize(), + record.key(), + valueDeserializer.deserialize(record.topic(), EncryptorAesGcm.decryptWithPrefixIV(ciphertext, originalKey)), + record.headers(), + record.leaderEpoch()); } public void close() { @@ -105,20 +104,24 @@ public void configure(Map configs) { } catch (InstantiationException | IllegalAccessException | ClassNotFoundException e) { LOGGER.error("Failed to create instance of interceptor.value.deserializer", e); } - vault = new VaultFactory(configuration).vault; + try { + transit = new VaultFactory(configuration).transit; + } catch (Exception ignored) { + LOGGER.error("Failed to create Vault Client"); + } mount = configuration.getStringOrDefault(TRANSIT_MOUNT_CONFIG, TRANSIT_MOUNT_DEFAULT); key = configuration.getStringOrDefault(TRANSIT_KEY_CONFIG, TRANSIT_KEY_DEFAULT); } - private String getEncryptionKey(ConsumerRecord record) { - return new String(record.headers().headers("x-vault-encryption-key").iterator().next().value()); + private String getEncryptionKeyName(ConsumerRecord record) { + return new String(record.headers().headers("x-vault-encryption-key-name").iterator().next().value()); } - private byte[] getPlaintextData(JsonValue it) { - return Base64.getDecoder().decode(it.asObject().get("plaintext").asString()); + private int getEncryptionKeyVersion(ConsumerRecord record) { + return fromByteArray(record.headers().headers("x-vault-encryption-key-version").iterator().next().value()); } - private List getBatchResults(LogicalResponse response) { - return response.getDataObject().get("batch_results").asArray().values(); + int fromByteArray(byte[] bytes) { + return ByteBuffer.wrap(bytes).getInt(); } } diff --git a/src/main/java/it/bitrock/kafkavaulttransitinterceptor/EncryptingProducerInterceptor.java b/src/main/java/it/bitrock/kafkavaulttransitinterceptor/EncryptingProducerInterceptor.java index 94786f9..6444b3c 100644 --- a/src/main/java/it/bitrock/kafkavaulttransitinterceptor/EncryptingProducerInterceptor.java +++ b/src/main/java/it/bitrock/kafkavaulttransitinterceptor/EncryptingProducerInterceptor.java @@ -1,17 +1,20 @@ package it.bitrock.kafkavaulttransitinterceptor; -import com.bettercloud.vault.Vault; -import com.bettercloud.vault.VaultException; -import com.bettercloud.vault.api.Logical; -import com.bettercloud.vault.response.LogicalResponse; import org.apache.kafka.clients.producer.ProducerInterceptor; import org.apache.kafka.clients.producer.ProducerRecord; import org.apache.kafka.clients.producer.RecordMetadata; import org.apache.kafka.common.header.Headers; import org.apache.kafka.common.serialization.Serializer; +import org.springframework.vault.core.VaultTransitOperations; +import org.springframework.vault.support.RawTransitKey; +import org.springframework.vault.support.TransitKeyType; +import org.springframework.vault.support.VaultTransitKey; +import org.springframework.vault.support.VaultTransitKeyCreationRequest; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.nio.ByteBuffer; import java.util.Base64; -import java.util.Collections; import java.util.Map; import static it.bitrock.kafkavaulttransitinterceptor.TransitConfiguration.*; @@ -19,63 +22,66 @@ public class EncryptingProducerInterceptor implements ProducerInterceptor { TransitConfiguration configuration; - Vault vault; + VaultTransitOperations transit; String mount; String defaultKey; Serializer valueSerializer; + SelfExpiringMap map = new SelfExpiringHashMap(); public ProducerRecord onSend(ProducerRecord record) { if (record.value() == null) return record; - LogicalResponse vaultResponse = null; - String encryptionKey = extractKeyOrElse(record.key(), defaultKey); - String encryptPath = String.format("%s/encrypt/%s", mount, encryptionKey); - try { - String base64value = getBase64value(record); - vaultResponse = vault.logical().write( - encryptPath, - Collections.singletonMap("plaintext", base64value)); - if (vaultResponse.getRestResponse().getStatus() == 200) { - String encryptedData = vaultResponse.getData().get("ciphertext"); - Headers headers = record.headers(); - headers.add("x-vault-encryption-key", encryptionKey.getBytes()); - if (record.value() instanceof byte[]) { - return new ProducerRecord( - record.topic(), - record.partition(), - record.timestamp(), - record.key(), - encryptedData.getBytes(), - headers - ); - } - else { - return new ProducerRecord( - record.topic(), - record.partition(), - record.timestamp(), - record.key(), - encryptedData, - headers - ); - } + String encryptionKeyName = extractKeyOrElse(record.key(), defaultKey); + int encryptionKeyVersion = 1; + // VaultBytesEncryptor encryptor = new VaultBytesEncryptor(this.transit, encryptionKeyName); - } else { - LOGGER.error(String.format("Encryption failed with status code: %d body: %s", vaultResponse.getRestResponse().getStatus(), new String(vaultResponse.getRestResponse().getBody()))); - throw new RuntimeException("Encryption failed"); + byte[] decodedKey = map.get(encryptionKeyName); + if (decodedKey == null) { + VaultTransitKey vaultTransitKey = transit.getKey(encryptionKeyName); + if (vaultTransitKey == null) { + transit.createKey(encryptionKeyName, VaultTransitKeyCreationRequest.builder().exportable(true).build()); + vaultTransitKey = transit.getKey(encryptionKeyName); + encryptionKeyVersion = vaultTransitKey.getLatestVersion(); } - } catch (VaultException e) { - LOGGER.error("Failed to encrypt records Vault", e); - throw new RuntimeException("Failed to encrypt records Vault"); + RawTransitKey vaultKey = transit.exportKey(encryptionKeyName, TransitKeyType.ENCRYPTION_KEY); + // decode the base64 encoded string + decodedKey = Base64.getDecoder().decode(vaultKey.getKeys().get(String.valueOf(encryptionKeyVersion))); + map.put(encryptionKeyName, decodedKey, 5 * 60000); + } + + // encrypt and decrypt need the same IV. + // AES-GCM needs IV 96-bit (12 bytes) + byte[] iv = EncryptorAesGcm.getRandomNonce(EncryptorAesGcm.IV_LENGTH_BYTE); + + // rebuild key using SecretKeySpec + SecretKey originalKey = new SecretKeySpec(decodedKey, 0, decodedKey.length, "AES"); + + //byte[] ciphertext = encryptor.encrypt(valueSerializer.serialize(record.topic(), record.value())); + byte[] ciphertext = null; + try { + ciphertext = EncryptorAesGcm.encryptWithPrefixIV(valueSerializer.serialize(record.topic(), record.value()), originalKey, iv); + } catch (Exception e) { + LOGGER.error("Failed to encrypt"); } + Headers headers = record.headers(); + headers.add("x-vault-encryption-key-name", encryptionKeyName.getBytes()); + headers.add("x-vault-encryption-key-version", intToByteArray(encryptionKeyVersion)); + return new ProducerRecord( + record.topic(), + record.partition(), + record.timestamp(), + record.key(), + ciphertext, + headers + ); } - private String getBase64value(ProducerRecord record) { - return Base64.getEncoder().encodeToString(valueSerializer.serialize(record.topic(), record.value())); + private byte[] intToByteArray(final int i) { + return ByteBuffer.allocate(4).putInt(i).array(); } private String extractKeyOrElse(K key, String defaultKey) { - if(key instanceof String) { - return (String) key; + if (key instanceof String) { + return (String) key; } else return defaultKey; } @@ -94,7 +100,11 @@ public void configure(Map configs) { } catch (InstantiationException | IllegalAccessException | ClassNotFoundException e) { LOGGER.error("Failed to create instance of interceptor.value.serializer", e); } - vault = new VaultFactory(configuration).vault; + try { + transit = new VaultFactory(configuration).transit; + } catch (Exception ignored) { + LOGGER.error("Failed to create Vault Client"); + } mount = configuration.getStringOrDefault(TRANSIT_MOUNT_CONFIG, TRANSIT_MOUNT_DEFAULT); defaultKey = configuration.getStringOrDefault(TRANSIT_KEY_CONFIG, TRANSIT_KEY_DEFAULT); } diff --git a/src/main/java/it/bitrock/kafkavaulttransitinterceptor/EncryptorAesGcm.java b/src/main/java/it/bitrock/kafkavaulttransitinterceptor/EncryptorAesGcm.java new file mode 100644 index 0000000..0acde9d --- /dev/null +++ b/src/main/java/it/bitrock/kafkavaulttransitinterceptor/EncryptorAesGcm.java @@ -0,0 +1,75 @@ +package it.bitrock.kafkavaulttransitinterceptor; + +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.spec.GCMParameterSpec; +import java.nio.ByteBuffer; +import java.security.SecureRandom; + +public class EncryptorAesGcm { + + static final int IV_LENGTH_BYTE = 12; + private static final String ENCRYPT_ALGO = "AES/GCM/NoPadding"; + private static final int TAG_LENGTH_BIT = 128; + + public static byte[] getRandomNonce(int numBytes) { + byte[] nonce = new byte[numBytes]; + new SecureRandom().nextBytes(nonce); + return nonce; + } + + // AES-GCM needs GCMParameterSpec + public static byte[] encrypt(byte[] pText, SecretKey secret, byte[] iv) throws Exception { + + Cipher cipher = Cipher.getInstance(ENCRYPT_ALGO); + cipher.init(Cipher.ENCRYPT_MODE, secret, new GCMParameterSpec(TAG_LENGTH_BIT, iv)); + byte[] encryptedText = cipher.doFinal(pText); + return encryptedText; + + } + + // prefix IV length + IV bytes to cipher text + public static byte[] encryptWithPrefixIV(byte[] pText, SecretKey secret, byte[] iv) throws Exception { + + byte[] cipherText = encrypt(pText, secret, iv); + + byte[] cipherTextWithIv = ByteBuffer.allocate(iv.length + cipherText.length) + .put(iv) + .put(cipherText) + .array(); + return cipherTextWithIv; + + } + + + public static byte[] decrypt(byte[] cText, SecretKey secret, byte[] iv) throws Exception { + + Cipher cipher = Cipher.getInstance(ENCRYPT_ALGO); + cipher.init(Cipher.DECRYPT_MODE, secret, new GCMParameterSpec(TAG_LENGTH_BIT, iv)); + byte[] plainText = cipher.doFinal(cText); + return plainText; + + } + + public static byte[] decryptWithPrefixIV(byte[] cText, SecretKey secret) { + + ByteBuffer bb = ByteBuffer.wrap(cText); + + byte[] iv = new byte[IV_LENGTH_BYTE]; + bb.get(iv); + //bb.get(iv, 0, iv.length); + + byte[] cipherText = new byte[bb.remaining()]; + bb.get(cipherText); + + byte[] plainText = null; + try { + plainText = decrypt(cipherText, secret, iv); + } catch (Exception e) { + e.printStackTrace(); + } + return plainText; + + } + +} diff --git a/src/main/java/it/bitrock/kafkavaulttransitinterceptor/SelfExpiringHashMap.java b/src/main/java/it/bitrock/kafkavaulttransitinterceptor/SelfExpiringHashMap.java new file mode 100644 index 0000000..c63b546 --- /dev/null +++ b/src/main/java/it/bitrock/kafkavaulttransitinterceptor/SelfExpiringHashMap.java @@ -0,0 +1,304 @@ +package it.bitrock.kafkavaulttransitinterceptor; + +/* + * Copyright (c) 2019 Pierantonio Cangianiello + * + * MIT License + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +import java.util.Collection; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.WeakHashMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.DelayQueue; +import java.util.concurrent.Delayed; +import java.util.concurrent.TimeUnit; + +/** + * A HashMap which entries expires after the specified life time. + * The life-time can be defined on a per-key basis, or using a default one, that is passed to the + * constructor. + * + * @author Pierantonio Cangianiello + * @param the Key type + * @param the Value type + */ +public class SelfExpiringHashMap implements SelfExpiringMap { + + private final Map internalMap; + + private final Map> expiringKeys; + + /** + * Holds the map keys using the given life time for expiration. + */ + private final DelayQueue delayQueue = new DelayQueue(); + + /** + * The default max life time in milliseconds. + */ + private final long maxLifeTimeMillis; + + public SelfExpiringHashMap() { + internalMap = new ConcurrentHashMap(); + expiringKeys = new WeakHashMap>(); + this.maxLifeTimeMillis = Long.MAX_VALUE; + } + + public SelfExpiringHashMap(long defaultMaxLifeTimeMillis) { + internalMap = new ConcurrentHashMap(); + expiringKeys = new WeakHashMap>(); + this.maxLifeTimeMillis = defaultMaxLifeTimeMillis; + } + + public SelfExpiringHashMap(long defaultMaxLifeTimeMillis, int initialCapacity) { + internalMap = new ConcurrentHashMap(initialCapacity); + expiringKeys = new WeakHashMap>(initialCapacity); + this.maxLifeTimeMillis = defaultMaxLifeTimeMillis; + } + + public SelfExpiringHashMap(long defaultMaxLifeTimeMillis, int initialCapacity, float loadFactor) { + internalMap = new ConcurrentHashMap(initialCapacity, loadFactor); + expiringKeys = new WeakHashMap>(initialCapacity, loadFactor); + this.maxLifeTimeMillis = defaultMaxLifeTimeMillis; + } + + /** + * {@inheritDoc} + */ + @Override + public int size() { + cleanup(); + return internalMap.size(); + } + + /** + * {@inheritDoc} + */ + @Override + public boolean isEmpty() { + cleanup(); + return internalMap.isEmpty(); + } + + /** + * {@inheritDoc} + */ + @Override + public boolean containsKey(Object key) { + cleanup(); + return internalMap.containsKey((K) key); + } + + /** + * {@inheritDoc} + */ + @Override + public boolean containsValue(Object value) { + cleanup(); + return internalMap.containsValue((V) value); + } + + @Override + public V get(Object key) { + cleanup(); + renewKey((K) key); + return internalMap.get((K) key); + } + + /** + * {@inheritDoc} + */ + @Override + public V put(K key, V value) { + return this.put(key, value, maxLifeTimeMillis); + } + + /** + * {@inheritDoc} + */ + @Override + public V put(K key, V value, long lifeTimeMillis) { + cleanup(); + ExpiringKey delayedKey = new ExpiringKey(key, lifeTimeMillis); + ExpiringKey oldKey = expiringKeys.put((K) key, delayedKey); + if(oldKey != null) { + expireKey(oldKey); + expiringKeys.put((K) key, delayedKey); + } + delayQueue.offer(delayedKey); + return internalMap.put(key, value); + } + + /** + * {@inheritDoc} + */ + @Override + public V remove(Object key) { + V removedValue = internalMap.remove((K) key); + expireKey(expiringKeys.remove((K) key)); + return removedValue; + } + + /** + * Not supported. + */ + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} + */ + @Override + public boolean renewKey(K key) { + ExpiringKey delayedKey = expiringKeys.get((K) key); + if (delayedKey != null) { + delayedKey.renew(); + return true; + } + return false; + } + + private void expireKey(ExpiringKey delayedKey) { + if (delayedKey != null) { + delayedKey.expire(); + cleanup(); + } + } + + /** + * {@inheritDoc} + */ + @Override + public void clear() { + delayQueue.clear(); + expiringKeys.clear(); + internalMap.clear(); + } + + /** + * Not supported. + */ + @Override + public Set keySet() { + throw new UnsupportedOperationException(); + } + + /** + * Not supported. + */ + @Override + public Collection values() { + throw new UnsupportedOperationException(); + } + + /** + * Not supported. + */ + @Override + public Set> entrySet() { + throw new UnsupportedOperationException(); + } + + private void cleanup() { + ExpiringKey delayedKey = delayQueue.poll(); + while (delayedKey != null) { + internalMap.remove(delayedKey.getKey()); + expiringKeys.remove(delayedKey.getKey()); + delayedKey = delayQueue.poll(); + } + } + + private class ExpiringKey implements Delayed { + + private long startTime = System.currentTimeMillis(); + private final long maxLifeTimeMillis; + private final K key; + + public ExpiringKey(K key, long maxLifeTimeMillis) { + this.maxLifeTimeMillis = maxLifeTimeMillis; + this.key = key; + } + + public K getKey() { + return key; + } + + /** + * {@inheritDoc} + */ + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + final ExpiringKey other = (ExpiringKey) obj; + if (this.key != other.key && (this.key == null || !this.key.equals(other.key))) { + return false; + } + return true; + } + + /** + * {@inheritDoc} + */ + @Override + public int hashCode() { + int hash = 7; + hash = 31 * hash + (this.key != null ? this.key.hashCode() : 0); + return hash; + } + + /** + * {@inheritDoc} + */ + @Override + public long getDelay(TimeUnit unit) { + return unit.convert(getDelayMillis(), TimeUnit.MILLISECONDS); + } + + private long getDelayMillis() { + return (startTime + maxLifeTimeMillis) - System.currentTimeMillis(); + } + + public void renew() { + startTime = System.currentTimeMillis(); + } + + public void expire() { + startTime = Long.MIN_VALUE; + } + + /** + * {@inheritDoc} + */ + @Override + public int compareTo(Delayed that) { + return Long.compare(this.getDelayMillis(), ((ExpiringKey) that).getDelayMillis()); + } + } +} diff --git a/src/main/java/it/bitrock/kafkavaulttransitinterceptor/SelfExpiringMap.java b/src/main/java/it/bitrock/kafkavaulttransitinterceptor/SelfExpiringMap.java new file mode 100644 index 0000000..e8c17bd --- /dev/null +++ b/src/main/java/it/bitrock/kafkavaulttransitinterceptor/SelfExpiringMap.java @@ -0,0 +1,57 @@ +package it.bitrock.kafkavaulttransitinterceptor; + +/* + * Copyright (c) 2019 Pierantonio Cangianiello + * + * MIT License + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +import java.util.Map; + +/** + * + * @author Pierantonio Cangianiello + * @param the Key type + * @param the Value type + */ +public interface SelfExpiringMap extends Map { + + /** + * Renews the specified key, setting the life time to the initial value. + * + * @param key + * @return true if the key is found, false otherwise + */ + public boolean renewKey(K key); + + /** + * Associates the given key to the given value in this map, with the specified life + * times in milliseconds. + * + * @param key + * @param value + * @param lifeTimeMillis + * @return a previously associated object for the given key (if exists). + */ + public V put(K key, V value, long lifeTimeMillis); + +} + diff --git a/src/main/java/it/bitrock/kafkavaulttransitinterceptor/VaultFactory.java b/src/main/java/it/bitrock/kafkavaulttransitinterceptor/VaultFactory.java index 43ea764..2453443 100644 --- a/src/main/java/it/bitrock/kafkavaulttransitinterceptor/VaultFactory.java +++ b/src/main/java/it/bitrock/kafkavaulttransitinterceptor/VaultFactory.java @@ -1,27 +1,29 @@ package it.bitrock.kafkavaulttransitinterceptor; -import com.bettercloud.vault.SslConfig; -import com.bettercloud.vault.Vault; -import com.bettercloud.vault.VaultConfig; -import com.bettercloud.vault.VaultException; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.vault.authentication.TokenAuthentication; +import org.springframework.vault.client.VaultEndpoint; +import org.springframework.vault.core.VaultTemplate; +import org.springframework.vault.core.VaultTransitOperations; + +import java.net.URI; +import java.net.URISyntaxException; + +import static it.bitrock.kafkavaulttransitinterceptor.TransitConfiguration.TRANSIT_MOUNT_CONFIG; +import static it.bitrock.kafkavaulttransitinterceptor.TransitConfiguration.TRANSIT_MOUNT_DEFAULT; class VaultFactory { static final Logger LOGGER = LoggerFactory.getLogger(VaultFactory.class); - final Vault vault; final TransitConfiguration configuration; + VaultTemplate template; + VaultTransitOperations transit; - - VaultFactory(TransitConfiguration configuration) { + VaultFactory(TransitConfiguration configuration) throws URISyntaxException { this.configuration = configuration; - try { - VaultConfig config = new VaultConfig().sslConfig(new SslConfig().build()).build(); - this.vault = new Vault(config, 1); - } catch (VaultException e) { - LOGGER.error("Failed to initialize Vault", e); - throw new RuntimeException("Failed to initialize Vault"); - } + this.template = new VaultTemplate(VaultEndpoint.from(new URI(System.getenv("VAULT_ADDR"))), new TokenAuthentication(System.getenv("VAULT_TOKEN"))); + this.transit = template.opsForTransit(configuration.getStringOrDefault(TRANSIT_MOUNT_CONFIG, TRANSIT_MOUNT_DEFAULT)); } } diff --git a/src/test-integration/java/it/bitrock/kafkavaulttransitinterceptor/TransitInterceptorTest.java b/src/test-integration/java/it/bitrock/kafkavaulttransitinterceptor/TransitInterceptorTest.java index 7188a44..c8871cf 100644 --- a/src/test-integration/java/it/bitrock/kafkavaulttransitinterceptor/TransitInterceptorTest.java +++ b/src/test-integration/java/it/bitrock/kafkavaulttransitinterceptor/TransitInterceptorTest.java @@ -25,7 +25,7 @@ public class TransitInterceptorTest { public static final VaultContainer container = new VaultContainer(); @ClassRule - public static KafkaContainer kafka = new KafkaContainer("5.4.2"); + public static KafkaContainer kafka = new KafkaContainer("5.5.1"); @ClassRule public static final EnvironmentVariables environmentVariables = new EnvironmentVariables(); diff --git a/src/test-integration/java/it/bitrock/kafkavaulttransitinterceptor/util/KafkaHelper.java b/src/test-integration/java/it/bitrock/kafkavaulttransitinterceptor/util/KafkaHelper.java index 8a62a6d..a70f0df 100644 --- a/src/test-integration/java/it/bitrock/kafkavaulttransitinterceptor/util/KafkaHelper.java +++ b/src/test-integration/java/it/bitrock/kafkavaulttransitinterceptor/util/KafkaHelper.java @@ -4,6 +4,8 @@ import org.apache.kafka.clients.producer.KafkaProducer; import org.apache.kafka.clients.producer.ProducerConfig; import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.ByteArraySerializer; import org.apache.kafka.common.serialization.LongDeserializer; import org.apache.kafka.common.serialization.StringDeserializer; @@ -18,7 +20,7 @@ public static KafkaProducer longKafkaProducerInterceptor(String bo Properties properties = new Properties(); properties.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); properties.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringSerializer"); - properties.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringSerializer"); + properties.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer"); properties.put("interceptor.value.serializer", "org.apache.kafka.common.serialization.LongSerializer"); properties.put(ProducerConfig.INTERCEPTOR_CLASSES_CONFIG, "it.bitrock.kafkavaulttransitinterceptor.EncryptingProducerInterceptor"); @@ -30,7 +32,7 @@ public static KafkaProducer stringKafkaProducerInterceptor(Stri Properties properties = new Properties(); properties.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); properties.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringSerializer"); - properties.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringSerializer"); + properties.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer"); properties.put("interceptor.value.serializer", "org.apache.kafka.common.serialization.StringSerializer"); properties.put(ProducerConfig.INTERCEPTOR_CLASSES_CONFIG, "it.bitrock.kafkavaulttransitinterceptor.EncryptingProducerInterceptor"); @@ -44,7 +46,7 @@ private static Properties defaultStringConsumerProperties(String bootstrapServer props.put(ConsumerConfig.GROUP_ID_CONFIG, String.format("%s", UUID.randomUUID().toString())); props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName()); - props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName()); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class.getName()); return props; } @@ -69,7 +71,7 @@ private static Properties defaultLongConsumerProperties(String bootstrapServers) props.put(ConsumerConfig.GROUP_ID_CONFIG, String.format("%s", UUID.randomUUID().toString())); props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName()); - props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName()); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class.getName()); return props; } diff --git a/src/test-integration/java/it/bitrock/kafkavaulttransitinterceptor/util/VaultContainer.java b/src/test-integration/java/it/bitrock/kafkavaulttransitinterceptor/util/VaultContainer.java index af7441d..1caaa09 100644 --- a/src/test-integration/java/it/bitrock/kafkavaulttransitinterceptor/util/VaultContainer.java +++ b/src/test-integration/java/it/bitrock/kafkavaulttransitinterceptor/util/VaultContainer.java @@ -1,14 +1,10 @@ package it.bitrock.kafkavaulttransitinterceptor.util; -import com.bettercloud.vault.SslConfig; -import com.bettercloud.vault.Vault; -import com.bettercloud.vault.VaultConfig; -import com.bettercloud.vault.VaultException; -import com.bettercloud.vault.json.Json; -import com.bettercloud.vault.json.JsonObject; import com.github.dockerjava.api.model.Capability; import java.io.IOException; import java.net.HttpURLConnection; + +import org.json.JSONObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testcontainers.containers.BindMode; @@ -25,7 +21,7 @@ public class VaultContainer extends GenericContainer implements private static final Logger LOGGER = LoggerFactory.getLogger(VaultContainer.class); - public static final String DEFAULT_IMAGE_AND_TAG = "vault:1.4.2"; + public static final String DEFAULT_IMAGE_AND_TAG = "vault:1.5.3"; public static String rootToken; private String unsealKey; @@ -59,9 +55,9 @@ public void initAndUnsealVault() throws IOException, InterruptedException { // Initialize the Vault server final Container.ExecResult initResult = runCommand("vault", "operator", "init", "-key-shares=1", "-key-threshold=1", "-format=json"); final String stdout = initResult.getStdout().replaceAll("\\r?\\n", ""); - JsonObject initJson = Json.parse(stdout).asObject(); - this.unsealKey = initJson.get("unseal_keys_b64").asArray().get(0).asString(); - rootToken = initJson.get("root_token").asString(); + JSONObject initJson = new JSONObject(stdout); + this.unsealKey = initJson.getJSONArray("unseal_keys_b64").get(0).toString(); + rootToken = initJson.get("root_token").toString(); System.out.println("Root token: " + rootToken); @@ -76,67 +72,6 @@ public void setupBackendTransit() throws IOException, InterruptedException { runCommand("vault", "secrets", "enable", "transit"); } - - public Vault getVault(final VaultConfig config, final Integer maxRetries, final Integer retryMillis) { - Vault vault = new Vault(config); - if (maxRetries != null && retryMillis != null) { - vault = vault.withRetries(maxRetries, retryMillis); - } else if (maxRetries != null) { - vault = vault.withRetries(maxRetries, RETRY_MILLIS); - } else if (retryMillis != null) { - vault = vault.withRetries(MAX_RETRIES, retryMillis); - } - return vault; - } - - public Vault getVault() throws VaultException { - final VaultConfig config = - new VaultConfig() - .address(getAddress()) - .openTimeout(5) - .readTimeout(30) - .sslConfig(new SslConfig().build()) - .build(); - return getVault(config, MAX_RETRIES, RETRY_MILLIS); - } - - public VaultConfig getVaultConfig() throws VaultException { - return new VaultConfig() - .address(getAddress()) - .openTimeout(5) - .readTimeout(30) - .sslConfig(new SslConfig().build()) - .build(); - } - - public Vault getVault(final String token) throws VaultException { - final VaultConfig config = - new VaultConfig() - .address(getAddress()) - .token(token) - .openTimeout(5) - .readTimeout(30) - .sslConfig(new SslConfig().build()) - .build(); - return new Vault(config).withRetries(MAX_RETRIES, RETRY_MILLIS); - } - - public Vault getRootVaultWithCustomVaultConfig(VaultConfig vaultConfig) throws VaultException { - final VaultConfig config = - vaultConfig - .address(getAddress()) - .token(rootToken) - .openTimeout(5) - .readTimeout(30) - .sslConfig(new SslConfig().build()) - .build(); - return new Vault(config).withRetries(MAX_RETRIES, RETRY_MILLIS); - } - - public Vault getRootVault() throws VaultException { - return getVault(rootToken).withRetries(MAX_RETRIES, RETRY_MILLIS); - } - public String getAddress() { return String.format("http://%s:%d", getContainerIpAddress(), getMappedPort(8200)); }