SNOW-1492090 Snowpipe streaming file master key id rotation #786

Merged
Commits (27)
24bd907 SNOW-1492090: Add encryption key fields to RegisterBLobResponse and C… (sfc-gh-bmikaili, Jun 27, 2024)
77e4da5 Fix test (sfc-gh-bmikaili, Aug 15, 2024)
200034b false comment (sfc-gh-bmikaili, Aug 15, 2024)
eee09b8 Trigger Build (sfc-gh-bmikaili, Aug 15, 2024)
979baa7 add back the old logic (sfc-gh-bmikaili, Sep 3, 2024)
fc864f1 format (sfc-gh-bmikaili, Sep 3, 2024)
34968af test both code paths (sfc-gh-bmikaili, Sep 4, 2024)
f93b50a Merge remote-tracking branch 'origin' into bmikaili-SNOW-1492090-snow… (sfc-gh-bmikaili, Sep 4, 2024)
a438a43 Fix tests (sfc-gh-bmikaili, Sep 5, 2024)
5ceda9c Merge remote-tracking branch 'origin' into bmikaili-SNOW-1492090-snow… (sfc-gh-bmikaili, Sep 12, 2024)
59a5f80 ignore dep poml (sfc-gh-bmikaili, Sep 18, 2024)
a6b4663 merge (sfc-gh-bmikaili, Sep 18, 2024)
9562df3 Merge remote-tracking branch 'origin' into bmikaili-SNOW-1492090-snow… (sfc-gh-bmikaili, Sep 19, 2024)
61a0340 fix tests (sfc-gh-bmikaili, Sep 26, 2024)
edcf313 format (sfc-gh-bmikaili, Sep 26, 2024)
3ff93ca Merge remote-tracking branch 'origin' into bmikaili-SNOW-1492090-snow… (sfc-gh-bmikaili, Oct 1, 2024)
db35d5c Merge remote-tracking branch 'origin/master' into bmikaili-SNOW-14920… (sfc-gh-lsembera, Nov 5, 2024)
686ffc8 Fix compilation issues caused by last merge (sfc-gh-lsembera, Nov 5, 2024)
b918ee4 Fix BlobBuilderTest (sfc-gh-lsembera, Nov 5, 2024)
5863839 Merge remote-tracking branch 'origin/master' into bmikaili-SNOW-14920… (sfc-gh-lsembera, Nov 5, 2024)
a58b397 Fix star import (sfc-gh-lsembera, Nov 5, 2024)
796fce1 Temporarily enable long running tests (sfc-gh-lsembera, Nov 5, 2024)
4807227 Revert "Temporarily enable long running tests" (sfc-gh-lsembera, Nov 5, 2024)
c252660 Merge remote-tracking branch 'origin/master' into bmikaili-SNOW-14920… (sfc-gh-lsembera, Nov 8, 2024)
5dbb2c7 Utils.getFullyQualifiedTableName (sfc-gh-lsembera, Nov 8, 2024)
4cc8810 Don't populate encryptionKeysPerTable on channel open (sfc-gh-lsembera, Nov 8, 2024)
27504b8 Remove irrelevant test (sfc-gh-lsembera, Nov 8, 2024)
.gitignore (1 change: 1 addition & 0 deletions)
@@ -9,3 +9,4 @@ src/main/resources/log4j.properties
src/test/resources/log4j.properties
testOutput/
.cache/
/dependency-reduced-pom.xml
BlobBuilder.java
@@ -24,6 +24,7 @@
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.zip.CRC32;
import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
@@ -70,7 +71,8 @@ static <T> Blob constructBlobAndMetadata(
String filePath,
List<List<ChannelData<T>>> blobData,
Constants.BdecVersion bdecVersion,
InternalParameterProvider internalParameterProvider)
InternalParameterProvider internalParameterProvider,
Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable)
throws IOException, NoSuchPaddingException, NoSuchAlgorithmException,
InvalidAlgorithmParameterException, InvalidKeyException, IllegalBlockSizeException,
BadPaddingException {
@@ -84,6 +86,19 @@ static <T> Blob constructBlobAndMetadata(
ChannelFlushContext firstChannelFlushContext =
channelsDataPerTable.get(0).getChannelContext();

final EncryptionKey encryptionKey =
encryptionKeysPerTable.getOrDefault(
new FullyQualifiedTableName(
firstChannelFlushContext.getDbName(),
firstChannelFlushContext.getSchemaName(),
firstChannelFlushContext.getTableName()),
new EncryptionKey(
firstChannelFlushContext.getDbName(),
firstChannelFlushContext.getSchemaName(),
firstChannelFlushContext.getTableName(),
firstChannelFlushContext.getEncryptionKey(),
firstChannelFlushContext.getEncryptionKeyId()));

Flusher<T> flusher = channelsDataPerTable.get(0).createFlusher();
Flusher.SerializationResult serializedChunk =
flusher.serialize(channelsDataPerTable, filePath, curDataSize);
@@ -105,9 +120,10 @@ static <T> Blob constructBlobAndMetadata(
// to align with decryption on the Snowflake query path.
// TODO: address alignment for the header SNOW-557866
long iv = curDataSize / Constants.ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES;

compressedChunkData =
Cryptor.encrypt(
paddedChunkData, firstChannelFlushContext.getEncryptionKey(), filePath, iv);
Cryptor.encrypt(paddedChunkData, encryptionKey.getEncryptionKey(), filePath, iv);

compressedChunkDataSize = compressedChunkData.length;
} else {
compressedChunkData = serializedChunk.chunkData.toByteArray();
@@ -132,7 +148,7 @@ static <T> Blob constructBlobAndMetadata(
.setUncompressedChunkLength((int) serializedChunk.chunkEstimatedUncompressedSize)
.setChannelList(serializedChunk.channelsMetadataList)
.setChunkMD5(md5)
.setEncryptionKeyId(firstChannelFlushContext.getEncryptionKeyId())
.setEncryptionKeyId(encryptionKey.getEncryptionKeyId())
.setEpInfo(
AbstractRowBuffer.buildEpInfoFromStats(
serializedChunk.rowCount,
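The hunk above chooses the chunk encryption key: it prefers a key the server has registered for the table and falls back to the key carried by the channel's flush context. A minimal sketch of that selection in isolation (not part of the diff; the helper name is made up, the types and getters are the ones shown above):

```java
// Sketch: the key-selection step of constructBlobAndMetadata, isolated.
// Prefer a rotated key pushed by the server for this table; otherwise fall back
// to the key the channel was opened with.
static EncryptionKey selectEncryptionKey(
    Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable,
    ChannelFlushContext ctx) {
  return encryptionKeysPerTable.getOrDefault(
      new FullyQualifiedTableName(ctx.getDbName(), ctx.getSchemaName(), ctx.getTableName()),
      new EncryptionKey(
          ctx.getDbName(),
          ctx.getSchemaName(),
          ctx.getTableName(),
          ctx.getEncryptionKey(),
          ctx.getEncryptionKeyId()));
}
```

The chunk is then encrypted with the selected key's material and its key id is written into the chunk metadata, so a rotated key takes effect without reopening the channel.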
EncryptionKey.java (new file)
@@ -0,0 +1,70 @@
package net.snowflake.ingest.streaming.internal;

import com.fasterxml.jackson.annotation.JsonProperty;
import net.snowflake.ingest.utils.Utils;

/** Represents an encryption key for a table */
public class EncryptionKey {
// Database name
private final String databaseName;

// Schema name
private final String schemaName;

// Table Name
private final String tableName;

String blobTableMasterKey;

long encryptionKeyId;

public EncryptionKey(
@JsonProperty("database") String databaseName,
@JsonProperty("schema") String schemaName,
@JsonProperty("table") String tableName,
@JsonProperty("encryption_key") String blobTableMasterKey,
@JsonProperty("encryption_key_id") long encryptionKeyId) {
this.databaseName = databaseName;
this.schemaName = schemaName;
this.tableName = tableName;
this.blobTableMasterKey = blobTableMasterKey;
this.encryptionKeyId = encryptionKeyId;
}

public EncryptionKey(EncryptionKey encryptionKey) {
this.databaseName = encryptionKey.databaseName;
this.schemaName = encryptionKey.schemaName;
this.tableName = encryptionKey.tableName;
this.blobTableMasterKey = encryptionKey.blobTableMasterKey;
this.encryptionKeyId = encryptionKey.encryptionKeyId;
}

public String getFullyQualifiedTableName() {
return Utils.getFullyQualifiedTableName(databaseName, schemaName, tableName);
}

@JsonProperty("database")
public String getDatabaseName() {
return databaseName;
}

@JsonProperty("schema")
public String getSchemaName() {
return schemaName;
}

@JsonProperty("table")
public String getTableName() {
return tableName;
}

@JsonProperty("encryption_key")
public String getEncryptionKey() {
return blobTableMasterKey;
}

@JsonProperty("encryption_key_id")
public long getEncryptionKeyId() {
return encryptionKeyId;
}
}
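As a quick illustration of how the @JsonProperty names above map onto the wire fields, here is a small sketch (not part of the PR, values made up, assumed to live in the same package as EncryptionKey):

```java
import com.fasterxml.jackson.databind.ObjectMapper;

public class EncryptionKeyJsonSketch {
  public static void main(String[] args) throws Exception {
    EncryptionKey key =
        new EncryptionKey("MY_DB", "MY_SCHEMA", "MY_TABLE", "<base64 master key>", 1234L);
    // Serializing shows the wire names declared above: "database", "schema", "table",
    // "encryption_key" and "encryption_key_id" (Jackson may also pick up other public
    // getters, such as the fully qualified table name, under its default settings).
    System.out.println(new ObjectMapper().writeValueAsString(key));
  }
}
```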
FlushService.java
@@ -495,11 +495,19 @@ && shouldStopProcessing(
blobPath.fileRegistrationPath, this.owningClient.flushLatency.time());
}

// Copy encryptionKeysPerTable from owning client
Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable =
new ConcurrentHashMap<>();
this.owningClient
.getEncryptionKeysPerTable()
.forEach((k, v) -> encryptionKeysPerTable.put(k, new EncryptionKey(v)));

Supplier<BlobMetadata> supplier =
() -> {
try {
BlobMetadata blobMetadata =
buildAndUpload(blobPath, blobData, fullyQualifiedTableName);
buildAndUpload(
blobPath, blobData, fullyQualifiedTableName, encryptionKeysPerTable);
blobMetadata.getBlobStats().setFlushStartMs(flushStartMs);
return blobMetadata;
} catch (Throwable e) {
@@ -562,8 +570,6 @@ && shouldStopProcessing(
*
* <p>When the chunk size is larger than a certain threshold
*
* <p>When the encryption key ids are not the same
*
* <p>When the schemas are not the same
*/
private boolean shouldStopProcessing(
@@ -591,7 +597,10 @@ private boolean shouldStopProcessing(
* @return BlobMetadata for FlushService.upload
*/
BlobMetadata buildAndUpload(
BlobPath blobPath, List<List<ChannelData<T>>> blobData, String fullyQualifiedTableName)
BlobPath blobPath,
List<List<ChannelData<T>>> blobData,
String fullyQualifiedTableName,
Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable)
throws IOException, NoSuchAlgorithmException, InvalidAlgorithmParameterException,
NoSuchPaddingException, IllegalBlockSizeException, BadPaddingException,
InvalidKeyException {
@@ -603,7 +612,8 @@ BlobMetadata buildAndUpload(
blobPath.fileRegistrationPath,
blobData,
bdecVersion,
this.owningClient.getInternalParameterProvider());
this.owningClient.getInternalParameterProvider(),
encryptionKeysPerTable);

blob.blobStats.setBuildDurationMs(buildContext);

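Note the copy taken before the upload supplier runs: FlushService snapshots the client's key map and deep-copies each EncryptionKey, so the keys a given blob is built with stay fixed even if a later register-blob response rotates them. Roughly (a sketch, not part of the diff; `client` stands in for the owning client instance):

```java
// Sketch: per-flush snapshot of the rotation cache before the async build/upload.
Map<FullyQualifiedTableName, EncryptionKey> snapshot = new ConcurrentHashMap<>();
client.getEncryptionKeysPerTable()
    .forEach((table, key) -> snapshot.put(table, new EncryptionKey(key)));
// The snapshot is then passed through buildAndUpload into constructBlobAndMetadata.
```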
FullyQualifiedTableName.java (new file)
@@ -0,0 +1,69 @@
package net.snowflake.ingest.streaming.internal;

import java.util.Objects;
import net.snowflake.ingest.utils.Utils;

/**
* FullyQualifiedTableName is a class that represents a fully qualified table name. It is used to
* store the fully qualified table name in the Snowflake format.
*/
public class FullyQualifiedTableName {
public FullyQualifiedTableName(String databaseName, String schemaName, String tableName) {
this.databaseName = databaseName;
this.schemaName = schemaName;
this.tableName = tableName;
}

// Database name
private final String databaseName;

// Schema name
private final String schemaName;

// Table Name
private final String tableName;

public String getTableName() {
return tableName;
}

public String getSchemaName() {
return schemaName;
}

public String getDatabaseName() {
return databaseName;
}

public String getFullyQualifiedName() {
return Utils.getFullyQualifiedTableName(databaseName, schemaName, tableName);
}

private int hashCode;

@Override
public int hashCode() {
int result = hashCode;
if (result == 0) {
result = 31 + ((databaseName == null) ? 0 : databaseName.hashCode());
result = 31 * result + ((schemaName == null) ? 0 : schemaName.hashCode());
result = 31 * result + ((tableName == null) ? 0 : tableName.hashCode());
hashCode = result;
}

return result;
}

@Override
public boolean equals(Object obj) {
if (this == obj) return true;

if (!(obj instanceof FullyQualifiedTableName)) return false;

FullyQualifiedTableName other = (FullyQualifiedTableName) obj;

if (!Objects.equals(databaseName, other.databaseName)) return false;
if (!Objects.equals(schemaName, other.schemaName)) return false;
return Objects.equals(tableName, other.tableName);
}
}
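Because equals and hashCode are defined over the three name parts (exact, case-sensitive string comparison), a FullyQualifiedTableName built from a server-provided key and one built from a channel's flush context address the same map entry. A small sketch with illustrative values, assumed to live in the same package:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class FullyQualifiedTableNameKeySketch {
  public static void main(String[] args) {
    Map<FullyQualifiedTableName, EncryptionKey> keys = new ConcurrentHashMap<>();
    keys.put(
        new FullyQualifiedTableName("MY_DB", "MY_SCHEMA", "MY_TABLE"),
        new EncryptionKey("MY_DB", "MY_SCHEMA", "MY_TABLE", "<key>", 1L));

    // A separately constructed name with the same parts is equal and hashes the same,
    // so it finds the entry stored above.
    EncryptionKey found = keys.get(new FullyQualifiedTableName("MY_DB", "MY_SCHEMA", "MY_TABLE"));
    System.out.println(found.getEncryptionKeyId()); // prints 1
  }
}
```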
RegisterBlobResponse.java
@@ -12,6 +12,7 @@ class RegisterBlobResponse extends StreamingIngestResponse {
private Long statusCode;
private String message;
private List<BlobRegisterStatus> blobsStatus;
private List<EncryptionKey> encryptionKeys;

@JsonProperty("status_code")
void setStatusCode(Long statusCode) {
@@ -39,4 +40,13 @@ void setBlobsStatus(List<BlobRegisterStatus> blobsStatus) {
List<BlobRegisterStatus> getBlobsStatus() {
return this.blobsStatus;
}

@JsonProperty("encryption_keys")
void setEncryptionKeys(List<EncryptionKey> encryptionKeys) {
this.encryptionKeys = encryptionKeys;
}

List<EncryptionKey> getEncryptionKeys() {
return this.encryptionKeys;
}
}
SnowflakeStreamingIngestClientInternal.java
@@ -43,6 +43,7 @@
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
@@ -110,14 +111,17 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
private final FlushService<T> flushService;

// Reference to storage manager
private final IStorageManager storageManager;
private IStorageManager storageManager;

// Indicates whether the client has closed
private volatile boolean isClosed;

// Indicates whether the client is under test mode
private final boolean isTestMode;

// Stores encryption key per table: FullyQualifiedTableName -> EncryptionKey
private final Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable;

// Performance testing related metrics
MetricRegistry metrics;
Histogram blobSizeHistogram; // Histogram for blob size after compression
@@ -172,6 +176,7 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
this.channelCache = new ChannelCache<>();
this.isClosed = false;
this.requestBuilder = requestBuilder;
this.encryptionKeysPerTable = new ConcurrentHashMap<>();

if (!isTestMode) {
// Setup request builder for communication with the server side
@@ -600,6 +605,18 @@ void registerBlobs(List<BlobMetadata> blobs, final int executionCount) {
this.name,
executionCount);

// Update encryption keys for the table given the response
if (response.getEncryptionKeys() == null) {
this.encryptionKeysPerTable.clear();
} else {
for (EncryptionKey key : response.getEncryptionKeys()) {
this.encryptionKeysPerTable.put(
new FullyQualifiedTableName(
key.getDatabaseName(), key.getSchemaName(), key.getTableName()),
key);
}
}

// We will retry any blob chunks that were rejected because internal Snowflake queues are full
Set<ChunkRegisterStatus> queueFullChunks = new HashSet<>();
response
@@ -1063,4 +1080,13 @@ private void cleanUpResources() {
HttpUtil.shutdownHttpConnectionManagerDaemonThread();
}
}

public Map<FullyQualifiedTableName, EncryptionKey> getEncryptionKeysPerTable() {
return encryptionKeysPerTable;
}

// TESTING ONLY - inject the storage manager
public void setStorageManager(IStorageManager storageManager) {
this.storageManager = storageManager;
}
}
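Taken together, the rotation flow is: a flush builds blobs with a snapshot of encryptionKeysPerTable, registerBlobs uploads the blob metadata, the response may carry encryption_keys, and the loop above refreshes (or clears) the cache so the next flush encrypts with the new keys. A sketch of that update in isolation (it mirrors the loop above; the helper name is made up):

```java
// Sketch: apply the encryption keys returned by a register-blob response.
// An absent list clears the cache, which makes BlobBuilder fall back to the key
// each channel was opened with; otherwise entries are replaced per table.
static void applyEncryptionKeys(
    Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable,
    List<EncryptionKey> encryptionKeysFromResponse) {
  if (encryptionKeysFromResponse == null) {
    encryptionKeysPerTable.clear();
    return;
  }
  for (EncryptionKey key : encryptionKeysFromResponse) {
    encryptionKeysPerTable.put(
        new FullyQualifiedTableName(
            key.getDatabaseName(), key.getSchemaName(), key.getTableName()),
        key);
  }
}
```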
BlobBuilderTest.java
@@ -10,7 +10,9 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import net.snowflake.ingest.utils.Constants;
import net.snowflake.ingest.utils.ErrorCode;
import net.snowflake.ingest.utils.Pair;
@@ -47,20 +49,27 @@ public static Object[] enableIcebergStreaming() {

@Test
public void testSerializationErrors() throws Exception {
Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable = new ConcurrentHashMap<>();
encryptionKeysPerTable.put(
new FullyQualifiedTableName("DB", "SCHEMA", "TABLE"),
new EncryptionKey("DB", "SCHEMA", "TABLE", "KEY", 1234L));

// Construction succeeds if both data and metadata contain 1 row
BlobBuilder.constructBlobAndMetadata(
"a.bdec",
Collections.singletonList(createChannelDataPerTable(1)),
Constants.BdecVersion.THREE,
new InternalParameterProvider(enableIcebergStreaming));
new InternalParameterProvider(enableIcebergStreaming),
encryptionKeysPerTable);

// Construction fails if metadata contains 0 rows and data 1 row
try {
BlobBuilder.constructBlobAndMetadata(
"a.bdec",
Collections.singletonList(createChannelDataPerTable(0)),
Constants.BdecVersion.THREE,
new InternalParameterProvider(enableIcebergStreaming));
new InternalParameterProvider(enableIcebergStreaming),
encryptionKeysPerTable);
} catch (SFException e) {
Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode());
Assert.assertTrue(e.getMessage().contains("parquetTotalRowsInFooter=1"));
@@ -84,7 +93,8 @@ public void testMetadataAndExtendedMetadataSize() throws Exception {
"a.parquet",
Collections.singletonList(createChannelDataPerTable(1)),
Constants.BdecVersion.THREE,
new InternalParameterProvider(enableIcebergStreaming));
new InternalParameterProvider(enableIcebergStreaming),
new ConcurrentHashMap<>());

InputFile blobInputFile = new InMemoryInputFile(blob.blobBytes);
ParquetFileReader reader = ParquetFileReader.open(blobInputFile);