Skip to content

Commit

Permalink
Implement processing segments
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Dec 20, 2024
1 parent 2634bdd commit dd4ff7d
Show file tree
Hide file tree
Showing 21 changed files with 496 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,42 @@
*/
package net.snowflake.client.jdbc.cloud.storage;

import static java.nio.file.StandardOpenOption.CREATE;
import static java.nio.file.StandardOpenOption.READ;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.MatDesc;
import net.snowflake.client.jdbc.cloud.storage.floe.AeadProvider;
import net.snowflake.common.core.RemoteStoreFileEncryptionMaterial;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.security.GeneralSecurityException;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Base64;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import net.snowflake.client.jdbc.MatDesc;
import net.snowflake.common.core.RemoteStoreFileEncryptionMaterial;

class GcmEncryptionProvider {
import static java.nio.file.StandardOpenOption.CREATE;
import static java.nio.file.StandardOpenOption.READ;

@SnowflakeJdbcInternalApi
public class GcmEncryptionProvider implements AeadProvider {
private static final int TAG_LENGTH_IN_BITS = 128;
private static final int IV_LENGTH_IN_BYTES = 12;
private static final String AES = "AES";
private static final String FILE_CIPHER = "AES/GCM/NoPadding";
private static final String KEY_CIPHER = "AES/GCM/NoPadding";
private static final String JCE_CIPHER_NAME = "AES/GCM/NoPadding";
private static final int BUFFER_SIZE = 8 * 1024 * 1024; // 2 MB
private static final ThreadLocal<SecureRandom> random =
ThreadLocal.withInitial(SecureRandom::new);
Expand Down Expand Up @@ -85,7 +89,7 @@ private static byte[] encryptKey(byte[] kekBytes, byte[] keyBytes, byte[] keyIvD
BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException {
SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, keyIvData);
Cipher keyCipher = Cipher.getInstance(KEY_CIPHER);
Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME);
keyCipher.init(Cipher.ENCRYPT_MODE, kek, gcmParameterSpec);
if (aad != null) {
keyCipher.updateAAD(aad);
Expand All @@ -99,7 +103,7 @@ private static CipherInputStream encryptContent(
NoSuchAlgorithmException {
SecretKey fileKey = new SecretKeySpec(keyBytes, 0, keyBytes.length, AES);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, dataIvBytes);
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
fileCipher.init(Cipher.ENCRYPT_MODE, fileKey, gcmParameterSpec);
if (aad != null) {
fileCipher.updateAAD(aad);
Expand Down Expand Up @@ -172,7 +176,7 @@ private static CipherInputStream decryptContentFromStream(
NoSuchAlgorithmException {
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes);
SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES);
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec);
if (aad != null) {
fileCipher.updateAAD(aad);
Expand All @@ -187,7 +191,7 @@ private static void decryptContentFromFile(
SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, cekIvBytes);
byte[] buffer = new byte[BUFFER_SIZE];
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec);
if (aad != null) {
fileCipher.updateAAD(aad);
Expand Down Expand Up @@ -215,11 +219,34 @@ private static byte[] decryptKey(byte[] kekBytes, byte[] ivBytes, byte[] keyByte
BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException {
SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes);
Cipher keyCipher = Cipher.getInstance(KEY_CIPHER);
Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME);
keyCipher.init(Cipher.DECRYPT_MODE, kek, gcmParameterSpec);
if (aad != null) {
keyCipher.updateAAD(aad);
}
return keyCipher.doFinal(keyBytes);
}

// TODO refactor to reuse cipher (consider thread safety vs performance)
@Override
public byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext) throws GeneralSecurityException {
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, iv);
Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME);
keyCipher.init(Cipher.ENCRYPT_MODE, key, gcmParameterSpec);
if (aad != null) {
keyCipher.updateAAD(aad);
}
return keyCipher.doFinal(plaintext);
}

@Override
public byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext) throws GeneralSecurityException {
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, iv);
Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME);
keyCipher.init(Cipher.DECRYPT_MODE, key, gcmParameterSpec);
if (aad != null) {
keyCipher.updateAAD(aad);
}
return keyCipher.doFinal(ciphertext);
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,49 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import net.snowflake.client.jdbc.cloud.storage.GcmEncryptionProvider;

public enum Aead {
AES_GCM_128((byte) 0),
AES_GCM_256((byte) 1);
// TODO confirm id
AES_GCM_256((byte) 0, "AES/GCM/NoPadding", 32, 12, 16, new GcmEncryptionProvider()),
AES_GCM_128((byte) 1, "AES/GCM/NoPadding", 16, 12, 16, new GcmEncryptionProvider());

private byte id;
private String jceName;
private int keyLength;
private int ivLength;
private int authTagLength;
private AeadProvider aeadProvider;

Aead(byte id) {
Aead(byte id, String jceName, int keyLength, int ivLength, int authTagLength, AeadProvider aeadProvider) {
this.jceName = jceName;
this.keyLength = keyLength;
this.id = id;
this.ivLength = ivLength;
this.authTagLength = authTagLength;
this.aeadProvider = aeadProvider;
}

byte getId() {
return id;
}

String getJceName() {
return jceName;
}

int getKeyLength() {
return keyLength;
}

int getIvLength() {
return ivLength;
}

int getAuthTagLength() {
return authTagLength;
}

AeadProvider getAeadProvider() {
return aeadProvider;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import java.nio.ByteBuffer;

class AeadAad {
private final byte[] bytes;

private AeadAad(long segmentCounter, byte terminalityByte) {
ByteBuffer buf = ByteBuffer.allocate(9);
buf.putLong(segmentCounter);
buf.put(terminalityByte);
this.bytes = buf.array();
}

static AeadAad nonTerminal(long segmentCounter) {
return new AeadAad(segmentCounter, (byte) 0);
}

byte[] getBytes() {
return bytes;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import java.nio.ByteBuffer;

class AeadIv {
private final byte[] bytes;

AeadIv(byte[] bytes) {
this.bytes = bytes;
}

public static AeadIv generateRandom(FloeRandom floeRandom, int ivLength) {
return new AeadIv(floeRandom.ofLength(ivLength));
}

public static AeadIv from(ByteBuffer buffer, int ivLength) {
byte[] bytes = new byte[ivLength];
buffer.get(bytes);
return new AeadIv(bytes);
}

byte[] getBytes() {
return bytes;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import javax.crypto.SecretKey;

class AeadKey {
private final SecretKey key;

AeadKey(SecretKey key) {
this.key = key;
}

SecretKey getKey() {
return key;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import javax.crypto.SecretKey;
import java.security.GeneralSecurityException;

public interface AeadProvider {
byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext) throws GeneralSecurityException;
byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext) throws GeneralSecurityException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

abstract class BaseSegmentProcessor {
protected static final int NON_TERMINAL_SEGMENT_SIZE_MARKER = -1;
protected static final int headerTagLength = 32;

protected final FloeParameterSpec parameterSpec;
protected final FloeKey floeKey;
protected final FloeAad floeAad;

protected final KeyDerivator floeKdf;

private AeadKey currentAeadKey;

BaseSegmentProcessor(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) {
this.parameterSpec = parameterSpec;
this.floeKey = floeKey;
this.floeAad = floeAad;
this.floeKdf = new KeyDerivator(parameterSpec);
}

protected AeadKey getKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) {
if (currentAeadKey == null || segmentCounter % parameterSpec.getKeyRotationModulo() == 0) {
currentAeadKey = deriveKey(floeKey, floeIv, floeAad, segmentCounter);
}
return currentAeadKey;
}

private AeadKey deriveKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) {
byte[] keyBytes = floeKdf.hkdfExpand(floeKey, floeIv, floeAad, new DekTagFloePurpose(segmentCounter), parameterSpec.getAead().getKeyLength());
SecretKey key = new SecretKeySpec(keyBytes, "AES"); // for now it is safe as we use only AES as AEAD
return new AeadKey(key);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

public interface FloeDecryptor {}
public interface FloeDecryptor extends SegmentProcessor {

}
Loading

0 comments on commit dd4ff7d

Please sign in to comment.