diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java index 1e2e25e64..636eae206 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java @@ -3,9 +3,19 @@ */ package net.snowflake.client.jdbc.cloud.storage; -import static java.nio.file.StandardOpenOption.CREATE; -import static java.nio.file.StandardOpenOption.READ; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.MatDesc; +import net.snowflake.client.jdbc.cloud.storage.floe.AeadProvider; +import net.snowflake.common.core.RemoteStoreFileEncryptionMaterial; +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.CipherInputStream; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; +import javax.crypto.spec.GCMParameterSpec; +import javax.crypto.spec.SecretKeySpec; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; @@ -13,28 +23,22 @@ import java.io.OutputStream; import java.nio.channels.FileChannel; import java.nio.file.Files; +import java.security.GeneralSecurityException; import java.security.InvalidAlgorithmParameterException; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.util.Base64; -import javax.crypto.BadPaddingException; -import javax.crypto.Cipher; -import javax.crypto.CipherInputStream; -import javax.crypto.IllegalBlockSizeException; -import javax.crypto.NoSuchPaddingException; -import javax.crypto.SecretKey; -import javax.crypto.spec.GCMParameterSpec; -import javax.crypto.spec.SecretKeySpec; -import net.snowflake.client.jdbc.MatDesc; -import net.snowflake.common.core.RemoteStoreFileEncryptionMaterial; -class GcmEncryptionProvider { +import static java.nio.file.StandardOpenOption.CREATE; +import static java.nio.file.StandardOpenOption.READ; + +@SnowflakeJdbcInternalApi +public class GcmEncryptionProvider implements AeadProvider { private static final int TAG_LENGTH_IN_BITS = 128; private static final int IV_LENGTH_IN_BYTES = 12; private static final String AES = "AES"; - private static final String FILE_CIPHER = "AES/GCM/NoPadding"; - private static final String KEY_CIPHER = "AES/GCM/NoPadding"; + private static final String JCE_CIPHER_NAME = "AES/GCM/NoPadding"; private static final int BUFFER_SIZE = 8 * 1024 * 1024; // 2 MB private static final ThreadLocal random = ThreadLocal.withInitial(SecureRandom::new); @@ -85,7 +89,7 @@ private static byte[] encryptKey(byte[] kekBytes, byte[] keyBytes, byte[] keyIvD BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException { SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, keyIvData); - Cipher keyCipher = Cipher.getInstance(KEY_CIPHER); + Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME); keyCipher.init(Cipher.ENCRYPT_MODE, kek, gcmParameterSpec); if (aad != null) { keyCipher.updateAAD(aad); @@ -99,7 +103,7 @@ private static CipherInputStream encryptContent( NoSuchAlgorithmException { SecretKey fileKey = new SecretKeySpec(keyBytes, 0, keyBytes.length, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, dataIvBytes); - Cipher fileCipher = Cipher.getInstance(FILE_CIPHER); + Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME); fileCipher.init(Cipher.ENCRYPT_MODE, fileKey, gcmParameterSpec); if (aad != null) { fileCipher.updateAAD(aad); @@ -172,7 +176,7 @@ private static CipherInputStream decryptContentFromStream( NoSuchAlgorithmException { GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes); SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES); - Cipher fileCipher = Cipher.getInstance(FILE_CIPHER); + Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME); fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec); if (aad != null) { fileCipher.updateAAD(aad); @@ -187,7 +191,7 @@ private static void decryptContentFromFile( SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, cekIvBytes); byte[] buffer = new byte[BUFFER_SIZE]; - Cipher fileCipher = Cipher.getInstance(FILE_CIPHER); + Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME); fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec); if (aad != null) { fileCipher.updateAAD(aad); @@ -215,11 +219,34 @@ private static byte[] decryptKey(byte[] kekBytes, byte[] ivBytes, byte[] keyByte BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException { SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes); - Cipher keyCipher = Cipher.getInstance(KEY_CIPHER); + Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME); keyCipher.init(Cipher.DECRYPT_MODE, kek, gcmParameterSpec); if (aad != null) { keyCipher.updateAAD(aad); } return keyCipher.doFinal(keyBytes); } + + // TODO refactor to reuse cipher (consider thread safety vs performance) + @Override + public byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext) throws GeneralSecurityException { + GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, iv); + Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME); + keyCipher.init(Cipher.ENCRYPT_MODE, key, gcmParameterSpec); + if (aad != null) { + keyCipher.updateAAD(aad); + } + return keyCipher.doFinal(plaintext); + } + + @Override + public byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext) throws GeneralSecurityException { + GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, iv); + Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME); + keyCipher.init(Cipher.DECRYPT_MODE, key, gcmParameterSpec); + if (aad != null) { + keyCipher.updateAAD(aad); + } + return keyCipher.doFinal(ciphertext); + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java index 861343163..1a77b04a6 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java @@ -1,16 +1,49 @@ package net.snowflake.client.jdbc.cloud.storage.floe; +import net.snowflake.client.jdbc.cloud.storage.GcmEncryptionProvider; + public enum Aead { - AES_GCM_128((byte) 0), - AES_GCM_256((byte) 1); + // TODO confirm id + AES_GCM_256((byte) 0, "AES/GCM/NoPadding", 32, 12, 16, new GcmEncryptionProvider()), + AES_GCM_128((byte) 1, "AES/GCM/NoPadding", 16, 12, 16, new GcmEncryptionProvider()); private byte id; + private String jceName; + private int keyLength; + private int ivLength; + private int authTagLength; + private AeadProvider aeadProvider; - Aead(byte id) { + Aead(byte id, String jceName, int keyLength, int ivLength, int authTagLength, AeadProvider aeadProvider) { + this.jceName = jceName; + this.keyLength = keyLength; this.id = id; + this.ivLength = ivLength; + this.authTagLength = authTagLength; + this.aeadProvider = aeadProvider; } byte getId() { return id; } + + String getJceName() { + return jceName; + } + + int getKeyLength() { + return keyLength; + } + + int getIvLength() { + return ivLength; + } + + int getAuthTagLength() { + return authTagLength; + } + + AeadProvider getAeadProvider() { + return aeadProvider; + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java new file mode 100644 index 000000000..6586c28bd --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java @@ -0,0 +1,22 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; + +class AeadAad { + private final byte[] bytes; + + private AeadAad(long segmentCounter, byte terminalityByte) { + ByteBuffer buf = ByteBuffer.allocate(9); + buf.putLong(segmentCounter); + buf.put(terminalityByte); + this.bytes = buf.array(); + } + + static AeadAad nonTerminal(long segmentCounter) { + return new AeadAad(segmentCounter, (byte) 0); + } + + byte[] getBytes() { + return bytes; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java new file mode 100644 index 000000000..c2a559b47 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java @@ -0,0 +1,25 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; + +class AeadIv { + private final byte[] bytes; + + AeadIv(byte[] bytes) { + this.bytes = bytes; + } + + public static AeadIv generateRandom(FloeRandom floeRandom, int ivLength) { + return new AeadIv(floeRandom.ofLength(ivLength)); + } + + public static AeadIv from(ByteBuffer buffer, int ivLength) { + byte[] bytes = new byte[ivLength]; + buffer.get(bytes); + return new AeadIv(bytes); + } + + byte[] getBytes() { + return bytes; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadKey.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadKey.java new file mode 100644 index 000000000..bfbd01976 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadKey.java @@ -0,0 +1,15 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import javax.crypto.SecretKey; + +class AeadKey { + private final SecretKey key; + + AeadKey(SecretKey key) { + this.key = key; + } + + SecretKey getKey() { + return key; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadProvider.java new file mode 100644 index 000000000..f3e36f512 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadProvider.java @@ -0,0 +1,9 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import javax.crypto.SecretKey; +import java.security.GeneralSecurityException; + +public interface AeadProvider { + byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext) throws GeneralSecurityException; + byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext) throws GeneralSecurityException; +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java new file mode 100644 index 000000000..4b340eba9 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java @@ -0,0 +1,37 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; + +abstract class BaseSegmentProcessor { + protected static final int NON_TERMINAL_SEGMENT_SIZE_MARKER = -1; + protected static final int headerTagLength = 32; + + protected final FloeParameterSpec parameterSpec; + protected final FloeKey floeKey; + protected final FloeAad floeAad; + + protected final KeyDerivator floeKdf; + + private AeadKey currentAeadKey; + + BaseSegmentProcessor(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) { + this.parameterSpec = parameterSpec; + this.floeKey = floeKey; + this.floeAad = floeAad; + this.floeKdf = new KeyDerivator(parameterSpec); + } + + protected AeadKey getKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) { + if (currentAeadKey == null || segmentCounter % parameterSpec.getKeyRotationModulo() == 0) { + currentAeadKey = deriveKey(floeKey, floeIv, floeAad, segmentCounter); + } + return currentAeadKey; + } + + private AeadKey deriveKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) { + byte[] keyBytes = floeKdf.hkdfExpand(floeKey, floeIv, floeAad, new DekTagFloePurpose(segmentCounter), parameterSpec.getAead().getKeyLength()); + SecretKey key = new SecretKeySpec(keyBytes, "AES"); // for now it is safe as we use only AES as AEAD + return new AeadKey(key); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java deleted file mode 100644 index 7328d480c..000000000 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java +++ /dev/null @@ -1,18 +0,0 @@ -package net.snowflake.client.jdbc.cloud.storage.floe; - -abstract class FloeBase { - protected static final int headerTagLength = 32; - - protected final FloeParameterSpec parameterSpec; - protected final FloeKey floeKey; - protected final FloeAad floeAad; - - protected final FloeKdf floeKdf; - - FloeBase(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) { - this.parameterSpec = parameterSpec; - this.floeKey = floeKey; - this.floeAad = floeAad; - this.floeKdf = new FloeKdf(parameterSpec); - } -} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java index 87e2463fd..69cebb708 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java @@ -1,3 +1,5 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -public interface FloeDecryptor {} +public interface FloeDecryptor extends SegmentProcessor { + +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java index 7139d9e73..5e94d1298 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java @@ -1,19 +1,19 @@ package net.snowflake.client.jdbc.cloud.storage.floe; import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; import java.util.Arrays; -public class FloeDecryptorImpl extends FloeBase implements FloeDecryptor { +public class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecryptor { + private final FloeIv floeIv; + private long segmentCounter; + FloeDecryptorImpl( FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad, byte[] floeHeaderAsBytes) { super(parameterSpec, floeKey, floeAad); - validate(floeHeaderAsBytes); - } - - public void validate(byte[] floeHeaderAsBytes) { - byte[] encodedParams = parameterSpec.paramEncode(); + byte[] encodedParams = this.parameterSpec.paramEncode(); if (floeHeaderAsBytes.length - != encodedParams.length + parameterSpec.getFloeIvLength().getLength() + headerTagLength) { + != encodedParams.length + this.parameterSpec.getFloeIvLength().getLength() + headerTagLength) { throw new IllegalArgumentException("invalid header length"); } ByteBuffer floeHeader = ByteBuffer.wrap(floeHeaderAsBytes); @@ -24,17 +24,49 @@ public void validate(byte[] floeHeaderAsBytes) { throw new IllegalArgumentException("invalid parameters header"); } - byte[] floeIvBytes = new byte[parameterSpec.getFloeIvLength().getLength()]; + byte[] floeIvBytes = new byte[this.parameterSpec.getFloeIvLength().getLength()]; floeHeader.get(floeIvBytes, 0, floeIvBytes.length); - FloeIv floeIv = new FloeIv(floeIvBytes); + this.floeIv = new FloeIv(floeIvBytes); byte[] headerTagFromHeader = new byte[headerTagLength]; floeHeader.get(headerTagFromHeader, 0, headerTagFromHeader.length); byte[] headerTag = - floeKdf.hkdfExpand(floeKey, floeIv, floeAad, FloePurpose.HEADER_TAG, headerTagLength); + floeKdf.hkdfExpand(this.floeKey, floeIv, this.floeAad, HeaderTagFloePurpose.INSTANCE, headerTagLength); if (!Arrays.equals(headerTag, headerTagFromHeader)) { throw new IllegalArgumentException("invalid header tag"); } } + + @Override + public byte[] processSegment(byte[] input) { + try { + verifySegmentLength(input); + ByteBuffer inputBuf = ByteBuffer.wrap(input); + verifySegmentSizeMarker(inputBuf); + // TODO handle mask + AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); + AeadIv aeadIv = AeadIv.from(inputBuf, parameterSpec.getAead().getIvLength()); + AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter++); + AeadProvider aeadProvider = parameterSpec.getAead().getAeadProvider(); + byte[] ciphertext = new byte[inputBuf.remaining()]; + inputBuf.get(ciphertext); + return aeadProvider.decrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), ciphertext); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + private void verifySegmentLength(byte[] input) { + if (input.length != parameterSpec.getEncryptedSegmentLength()) { + throw new IllegalArgumentException(String.format("segment length mismatch, expected %d, got %d", parameterSpec.getEncryptedSegmentLength(), input.length)); + } + } + + private void verifySegmentSizeMarker(ByteBuffer inputBuf) { + int segmentSizeMarker = inputBuf.getInt(); + if (segmentSizeMarker != NON_TERMINAL_SEGMENT_SIZE_MARKER) { + throw new IllegalStateException(String.format("segment length marker mismatch, expected: %d, got :%d", NON_TERMINAL_SEGMENT_SIZE_MARKER, segmentSizeMarker)); + } + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java index b629869f8..f1ab85496 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java @@ -1,5 +1,5 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -public interface FloeEncryptor { +public interface FloeEncryptor extends SegmentProcessor { byte[] getHeader(); } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java index ed993962f..371321469 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java @@ -1,9 +1,14 @@ package net.snowflake.client.jdbc.cloud.storage.floe; import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; + +class FloeEncryptorImpl extends BaseSegmentProcessor implements FloeEncryptor { -class FloeEncryptorImpl extends FloeBase implements FloeEncryptor { private final FloeIv floeIv; + private AeadKey currentAeadKey; + + private long segmentCounter; private final byte[] header; @@ -18,7 +23,7 @@ private byte[] buildHeader() { byte[] parametersEncoded = parameterSpec.paramEncode(); byte[] floeIvBytes = floeIv.getBytes(); byte[] headerTag = - floeKdf.hkdfExpand(floeKey, floeIv, floeAad, FloePurpose.HEADER_TAG, headerTagLength); + floeKdf.hkdfExpand(floeKey, floeIv, floeAad, HeaderTagFloePurpose.INSTANCE, headerTagLength); ByteBuffer result = ByteBuffer.allocate(parametersEncoded.length + floeIvBytes.length + headerTag.length); @@ -35,4 +40,36 @@ private byte[] buildHeader() { public byte[] getHeader() { return header; } + + @Override + public byte[] processSegment(byte[] input) { + verifySegmentLength(input); + // TODO assert State.Counter != 2^40-1 # Prevent overflow + try { + AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); + AeadIv aeadIv = AeadIv.generateRandom(parameterSpec.getFloeRandom(), parameterSpec.getAead().getIvLength()); + AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter++); + AeadProvider aeadProvider = parameterSpec.getAead().getAeadProvider(); + // it works as long as AEAD returns auth tag as a part of the ciphertext + // TODO reuse cipher? + byte[] ciphertextWithAuthTag = aeadProvider.encrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), input); + return segmentToBytes(aeadIv, ciphertextWithAuthTag); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + private byte[] segmentToBytes(AeadIv aeadIv, byte[] ciphertextWithAuthTag) { + ByteBuffer output = ByteBuffer.allocate(parameterSpec.getEncryptedSegmentLength()); + output.putInt(NON_TERMINAL_SEGMENT_SIZE_MARKER); + output.put(aeadIv.getBytes()); + output.put(ciphertextWithAuthTag); + return output.array(); + } + + private void verifySegmentLength(byte[] input) { + if (input.length != parameterSpec.getPlainTextSegmentLength()) { + throw new IllegalArgumentException(String.format("segment length mismatch, expected %d, got %d", parameterSpec.getPlainTextSegmentLength(), input.length)); + } + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java index 53b5db779..4f956f628 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java @@ -9,10 +9,11 @@ public class FloeParameterSpec { private final int encryptedSegmentLength; private final FloeIvLength floeIvLength; private final FloeRandom floeRandom; + private final int keyRotationModulo; public FloeParameterSpec(Aead aead, Hash hash, int encryptedSegmentLength, int floeIvLength) { this( - aead, hash, encryptedSegmentLength, new FloeIvLength(floeIvLength), new SecureFloeRandom()); + aead, hash, encryptedSegmentLength, new FloeIvLength(floeIvLength), new SecureFloeRandom(), 1 << 20); } FloeParameterSpec( @@ -20,12 +21,14 @@ public FloeParameterSpec(Aead aead, Hash hash, int encryptedSegmentLength, int f Hash hash, int encryptedSegmentLength, FloeIvLength floeIvLength, - FloeRandom floeRandom) { + FloeRandom floeRandom, + int keyRotationModulo) { this.aead = aead; this.hash = hash; this.encryptedSegmentLength = encryptedSegmentLength; this.floeIvLength = floeIvLength; this.floeRandom = floeRandom; + this.keyRotationModulo = keyRotationModulo; } byte[] paramEncode() { @@ -37,6 +40,10 @@ byte[] paramEncode() { return result.array(); } + public Aead getAead() { + return aead; + } + public Hash getHash() { return hash; } @@ -48,4 +55,17 @@ public FloeIvLength getFloeIvLength() { FloeRandom getFloeRandom() { return floeRandom; } + + int getEncryptedSegmentLength() { + return encryptedSegmentLength; + } + + int getPlainTextSegmentLength() { + // sizeof(int) == 4, file size is a part of the segment ciphertext + return encryptedSegmentLength - aead.getIvLength() - aead.getAuthTagLength() - 4; + } + + int getKeyRotationModulo() { + return keyRotationModulo; + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java index ad4627035..b0f0c3a4c 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java @@ -1,17 +1,41 @@ package net.snowflake.client.jdbc.cloud.storage.floe; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -public enum FloePurpose { - HEADER_TAG("HEADER_TAG:".getBytes(StandardCharsets.UTF_8)); +interface FloePurpose { + byte[] generate(); +} + +class HeaderTagFloePurpose implements FloePurpose { + private static final byte[] bytes = "HEADER_TAG:".getBytes(StandardCharsets.UTF_8); + + static final HeaderTagFloePurpose INSTANCE = new HeaderTagFloePurpose(); + + private HeaderTagFloePurpose() { + + } + + @Override + public byte[] generate() { + return bytes; + } +} + +class DekTagFloePurpose implements FloePurpose { + private static final byte[] prefix = "DEK:".getBytes(StandardCharsets.UTF_8); private final byte[] bytes; - FloePurpose(byte[] bytes) { - this.bytes = bytes; + DekTagFloePurpose(long segmentCount) { + ByteBuffer buffer = ByteBuffer.allocate(prefix.length + 8 /*size of long*/); + buffer.put(prefix); + buffer.putLong(segmentCount); + this.bytes = buffer.array(); } - public byte[] getBytes() { + @Override + public byte[] generate() { return bytes; } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKdf.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/KeyDerivator.java similarity index 74% rename from src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKdf.java rename to src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/KeyDerivator.java index 0d39e0a52..14c977903 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKdf.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/KeyDerivator.java @@ -1,35 +1,37 @@ package net.snowflake.client.jdbc.cloud.storage.floe; +import javax.crypto.Mac; import java.nio.ByteBuffer; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.util.Arrays; -import javax.crypto.Mac; -class FloeKdf { +class KeyDerivator { private final FloeParameterSpec parameterSpec; - FloeKdf(FloeParameterSpec parameterSpec) { + KeyDerivator(FloeParameterSpec parameterSpec) { this.parameterSpec = parameterSpec; } + // TODO should we return SecretKey instead? byte[] hkdfExpand( FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, FloePurpose purpose, int length) { byte[] encodedParams = parameterSpec.paramEncode(); + byte[] purposeBytes = purpose.generate(); ByteBuffer info = ByteBuffer.allocate( encodedParams.length + floeIv.getBytes().length - + purpose.getBytes().length + + purposeBytes.length + floeAad.getBytes().length); info.put(encodedParams); info.put(floeIv.getBytes()); - info.put(purpose.getBytes()); + info.put(purposeBytes); info.put(floeAad.getBytes()); - return jceHkdfExpand(parameterSpec.getHash(), floeKey, info.array(), length); + return hkdfExpandInternal(parameterSpec.getHash(), floeKey, info.array(), length); } - private byte[] jceHkdfExpand(Hash hash, FloeKey prk, byte[] info, int len) { + private byte[] hkdfExpandInternal(Hash hash, FloeKey prk, byte[] info, int len) { try { Mac mac = Mac.getInstance(hash.getJceName()); mac.init(prk.getKey()); diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java new file mode 100644 index 000000000..45e4f2872 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java @@ -0,0 +1,5 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +interface SegmentProcessor { + byte[] processSegment(byte[] input); +} diff --git a/src/test/java/net/snowflake/client/AbstractDriverIT.java b/src/test/java/net/snowflake/client/AbstractDriverIT.java index 3104ce7e9..c75b81818 100644 --- a/src/test/java/net/snowflake/client/AbstractDriverIT.java +++ b/src/test/java/net/snowflake/client/AbstractDriverIT.java @@ -324,6 +324,10 @@ public static Connection getConnection( properties.put("internal", Boolean.TRUE.toString()); // TODO: do we need this? properties.put("insecureMode", false); // use OCSP for all tests. + properties.put("useProxy", "true"); + properties.put("proxyHost", "localhost"); + properties.put("proxyPort", "8080"); + if (injectSocketTimeout > 0) { properties.put("injectSocketTimeout", String.valueOf(injectSocketTimeout)); } diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FixedFloeRandom.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FixedFloeRandom.java deleted file mode 100644 index 1a790d7ca..000000000 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FixedFloeRandom.java +++ /dev/null @@ -1,17 +0,0 @@ -package net.snowflake.client.jdbc.cloud.storage.floe; - -public class FixedFloeRandom implements FloeRandom { - private final byte[] bytes; - - public FixedFloeRandom(byte[] bytes) { - this.bytes = bytes; - } - - @Override - public byte[] ofLength(int length) { - if (bytes.length != length) { - throw new IllegalArgumentException("allowed only " + bytes.length + " bytes"); - } - return bytes; - } -} diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java index b50f5d280..f353fe15d 100644 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java @@ -1,12 +1,19 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.Test; -import java.nio.charset.StandardCharsets; +import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; -import org.junit.jupiter.api.Test; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; class FloeEncryptorImplTest { + byte[] aad = "This is AAD".getBytes(StandardCharsets.UTF_8); + SecretKey secretKey = new SecretKeySpec(new byte[32], "FLOE"); + @Test void shouldCreateCorrectHeader() { FloeParameterSpec parameterSpec = @@ -15,7 +22,9 @@ void shouldCreateCorrectHeader() { Hash.SHA384, 12345678, new FloeIvLength(4), - new FixedFloeRandom(new byte[] {11, 22, 33, 44})); + new IncrementingFloeRandom(), + 4); + parameterSpec.getFloeRandom().ofLength(4); // just to trigger incrementation FloeKey floeKey = new FloeKey(new SecretKeySpec(new byte[32], "FLOE")); FloeAad floeAad = new FloeAad("test aad".getBytes(StandardCharsets.UTF_8)); FloeEncryptorImpl floeEncryptor = new FloeEncryptorImpl(parameterSpec, floeKey, floeAad); @@ -37,9 +46,43 @@ void shouldCreateCorrectHeader() { assertEquals(0, header[8]); assertEquals(4, header[9]); // FLOE IV - assertEquals(11, header[10]); - assertEquals(22, header[11]); - assertEquals(33, header[12]); - assertEquals(44, header[13]); + assertEquals(0, header[10]); + assertEquals(0, header[11]); + assertEquals(0, header[12]); + assertEquals(1, header[13]); + } + + @Test + void testEncryptionMatchesReference() { + List referenceCiphertextSegments = Arrays.asList( + "ffffffff0000000100000000000000000100007f5713b9827bb806318311fcde197146a144c6b485", // pragma: allowlist secret + "ffffffff000000020000000000000000f926dfc0a0bac6263d1634ad9a72f86900872033a271a037", // pragma: allowlist secret + "ffffffff00000003000000000000000080df8fdee872febe574c2b8df0bb34b3fb25bfc5802703a2", // pragma: allowlist secret + "ffffffff000000040000000000000000f4d81083e57451dbfa538827942245019b8bc3354ecc31e0", // pragma: allowlist secret + "ffffffff000000050000000000000000d91b774b5b460bd665910114e155f1cbc55a9a262a54f65e", // pragma: allowlist secret + "ffffffff000000060000000000000000ec723f3807eb71ea42ff03f5420daf34e1a8f4fb58931db1", // pragma: allowlist secret + "ffffffff00000007000000000000000072960c06ec19ce94c27c9fc72d79164f187f37e86325d849", // pragma: allowlist secret + "ffffffff000000080000000000000000c00a40fb140d797da818ab57399cb986bddddd174b8d3d6a", // pragma: allowlist secret + "ffffffff000000090000000000000000065e959cd1ffa521896fb54949a57ad1c1f8291a531c6d60", // pragma: allowlist secret + "ffffffff0000000a0000000000000000dfde3da3f67a081fb31229ac11e43a629ed120fbf9942513" // pragma: allowlist secret + ); + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, new FloeIvLength(32), new IncrementingFloeRandom(), 4); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + byte[] testData = new byte[8]; + for (int i = 0; i < referenceCiphertextSegments.size(); i++) { + byte[] ciphertextBytes = encryptor.processSegment(testData); + String ciphertextHex = toHex(ciphertextBytes); + assertEquals(referenceCiphertextSegments.get(i), ciphertextHex); + } + } + + private String toHex(byte[] input) { + StringBuilder result = new StringBuilder(); + for (byte b : input) { + result.append(String.format("%02x", b)); + } + return result.toString(); } } diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java index 8395f9f8d..3f7a21839 100644 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java @@ -1,66 +1,113 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; -import java.nio.charset.StandardCharsets; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; -import org.junit.jupiter.api.Test; +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; class FloeTest { - @Test - void validateHeaderMatchesForEncryptionAndDecryption() { - byte[] aad = "test aad".getBytes(StandardCharsets.UTF_8); - FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_128, Hash.SHA384, 1024, 4); - Floe floe = Floe.getInstance(parameterSpec); - SecretKey secretKey = new SecretKeySpec(new byte[16], "FLOE"); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - assertDoesNotThrow(() -> floe.createDecryptor(secretKey, aad, header)); + byte[] aad = "This is AAD".getBytes(StandardCharsets.UTF_8); + SecretKey secretKey = new SecretKeySpec(new byte[32], "FLOE"); + + @Nested + class HeaderTests { + @Test + void validateHeaderMatchesForEncryptionAndDecryption() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + assertDoesNotThrow(() -> floe.createDecryptor(secretKey, aad, header)); + } + + @Test + void validateHeaderDoesNotMatchInParams() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + header[0] = 12; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid parameters header"); + } + + @Test + void validateHeaderDoesNotMatchInIV() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + header[11]++; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid header tag"); + } + + @Test + void validateHeaderDoesNotMatchInHeaderTag() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 4096, 4); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + header[header.length - 3]++; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid header tag"); + } } @Test - void validateHeaderDoesNotMatchInParams() { - byte[] aad = "test aad".getBytes(StandardCharsets.UTF_8); - FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_128, Hash.SHA384, 1024, 4); + void testSegmentEncryptedAndDecrypted() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, new FloeIvLength(32), new IncrementingFloeRandom(), 4); Floe floe = Floe.getInstance(parameterSpec); - SecretKey secretKey = new SecretKeySpec(new byte[16], "FLOE"); FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); byte[] header = encryptor.getHeader(); - header[0] = 12; - IllegalArgumentException e = - assertThrows( - IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); - assertEquals(e.getMessage(), "invalid parameters header"); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); + byte[] testData = new byte[8]; + byte[] ciphertext = encryptor.processSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); } @Test - void validateHeaderDoesNotMatchInIV() { - byte[] aad = "test aad".getBytes(StandardCharsets.UTF_8); - FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_128, Hash.SHA384, 1024, 4); + void testSegmentEncryptedAndDecryptedWithRandomData() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, new FloeIvLength(32), new IncrementingFloeRandom(), 4); Floe floe = Floe.getInstance(parameterSpec); - SecretKey secretKey = new SecretKeySpec(new byte[16], "FLOE"); FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); byte[] header = encryptor.getHeader(); - header[11]++; - IllegalArgumentException e = - assertThrows( - IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); - assertEquals(e.getMessage(), "invalid header tag"); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); + byte[] testData = new byte[8]; + new SecureRandom().nextBytes(testData); + byte[] ciphertext = encryptor.processSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); } @Test - void validateHeaderDoesNotMatchInHeaderTag() { - byte[] aad = "test aad".getBytes(StandardCharsets.UTF_8); - FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_128, Hash.SHA384, 1024, 4); + void testSegmentEncryptedAndDecryptedWithDerivedKeyRotation() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, new FloeIvLength(32), new IncrementingFloeRandom(), 4); Floe floe = Floe.getInstance(parameterSpec); - SecretKey secretKey = new SecretKeySpec(new byte[16], "FLOE"); FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); byte[] header = encryptor.getHeader(); - header[header.length - 3]++; - IllegalArgumentException e = - assertThrows( - IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); - assertEquals(e.getMessage(), "invalid header tag"); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); + byte[] testData = new byte[8]; + for (int i = 0; i < 10; i++) { + byte[] ciphertext = encryptor.processSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); + } } } diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java new file mode 100644 index 000000000..0d954a7fe --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java @@ -0,0 +1,14 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; + +public class IncrementingFloeRandom implements FloeRandom { + private int seed; + + @Override + public byte[] ofLength(int length) { + ByteBuffer buffer = ByteBuffer.allocate(length); + buffer.putInt(seed++); + return buffer.array(); + } +}