
Commit

SNOW-1483230 Parameter support & disable blob encryption for new table format (#801)
sfc-gh-alhuang authored Aug 10, 2024
1 parent 6632c84 commit b4b84b8
Showing 16 changed files with 507 additions and 162 deletions.
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021 Snowflake Computing Inc. All rights reserved.
* Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.ingest.streaming;
@@ -58,7 +58,7 @@ public SnowflakeStreamingIngestClient build() {
SnowflakeURL accountURL = new SnowflakeURL(prop.getProperty(Constants.ACCOUNT_URL));

return new SnowflakeStreamingIngestClientInternal<>(
this.name, accountURL, prop, this.parameterOverrides, this.isTestMode);
this.name, accountURL, prop, this.parameterOverrides, false, this.isTestMode);
}
}
}
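
The public builder above keeps its existing surface and simply hard-codes the new flag to false, so external callers are unaffected. A minimal usage sketch, assuming the documented factory API; property keys and values are illustrative placeholders:

import java.util.Properties;
import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient;
import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClientFactory;

public class ClientBuildSketch {
  public static void main(String[] args) throws Exception {
    Properties props = new Properties();
    props.setProperty("url", "https://<account>.snowflakecomputing.com:443"); // placeholder
    props.setProperty("user", "<user>"); // placeholder
    props.setProperty("private_key", "<pkcs8_private_key>"); // placeholder

    // The build() method shown above now forwards isIcebergMode = false to the internal client.
    try (SnowflakeStreamingIngestClient client =
        SnowflakeStreamingIngestClientFactory.builder("MY_CLIENT").setProperties(props).build()) {
      // open channels and insert rows here
    }
  }
}
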
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021 Snowflake Computing Inc. All rights reserved.
* Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.ingest.streaming.internal;
@@ -61,10 +61,14 @@ class BlobBuilder {
* @param blobData All the data for one blob. Assumes that all ChannelData in the inner List
* belongs to the same table. Will error if this is not the case
* @param bdecVersion version of blob
* @param encrypt Whether to encrypt the output chunk
* @return {@link Blob} data
*/
static <T> Blob constructBlobAndMetadata(
String filePath, List<List<ChannelData<T>>> blobData, Constants.BdecVersion bdecVersion)
String filePath,
List<List<ChannelData<T>>> blobData,
Constants.BdecVersion bdecVersion,
boolean encrypt)
throws IOException, NoSuchPaddingException, NoSuchAlgorithmException,
InvalidAlgorithmParameterException, InvalidKeyException, IllegalBlockSizeException,
BadPaddingException {
@@ -83,25 +87,34 @@ static <T> Blob constructBlobAndMetadata(
flusher.serialize(channelsDataPerTable, filePath);

if (!serializedChunk.channelsMetadataList.isEmpty()) {
ByteArrayOutputStream chunkData = serializedChunk.chunkData;
Pair<byte[], Integer> paddedChunk =
padChunk(chunkData, Constants.ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES);
byte[] paddedChunkData = paddedChunk.getFirst();
int paddedChunkLength = paddedChunk.getSecond();
final byte[] compressedChunkData;
final int chunkLength;
final int compressedChunkDataSize;

// Encrypt the compressed chunk data, the encryption key is derived using the key from
// server with the full blob path.
// We need to maintain IV as a block counter for the whole file, even interleaved,
// to align with decryption on the Snowflake query path.
// TODO: address alignment for the header SNOW-557866
long iv = curDataSize / Constants.ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES;
byte[] encryptedCompressedChunkData =
Cryptor.encrypt(
paddedChunkData, firstChannelFlushContext.getEncryptionKey(), filePath, iv);
if (encrypt) {
Pair<byte[], Integer> paddedChunk =
padChunk(serializedChunk.chunkData, Constants.ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES);
byte[] paddedChunkData = paddedChunk.getFirst();
chunkLength = paddedChunk.getSecond();

// Encrypt the compressed chunk data; the encryption key is derived using the key from the
// server with the full blob path.
// We need to maintain IV as a block counter for the whole file, even interleaved,
// to align with decryption on the Snowflake query path.
// TODO: address alignment for the header SNOW-557866
long iv = curDataSize / Constants.ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES;
compressedChunkData =
Cryptor.encrypt(
paddedChunkData, firstChannelFlushContext.getEncryptionKey(), filePath, iv);
compressedChunkDataSize = compressedChunkData.length;
} else {
compressedChunkData = serializedChunk.chunkData.toByteArray();
chunkLength = compressedChunkData.length;
compressedChunkDataSize = chunkLength;
}

// Compute the md5 of the chunk data
String md5 = computeMD5(encryptedCompressedChunkData, paddedChunkLength);
int encryptedCompressedChunkDataSize = encryptedCompressedChunkData.length;
String md5 = computeMD5(compressedChunkData, chunkLength);

// Create chunk metadata
long startOffset = curDataSize;
@@ -111,9 +124,9 @@ static <T> Blob constructBlobAndMetadata(
// The start offset will be updated later in BlobBuilder#build to include the blob
// header
.setChunkStartOffset(startOffset)
// The paddedChunkLength is used because it is the actual data size used for
// The chunkLength is used because it is the actual data size used for
// decompression and md5 calculation on server side.
.setChunkLength(paddedChunkLength)
.setChunkLength(chunkLength)
.setUncompressedChunkLength((int) serializedChunk.chunkEstimatedUncompressedSize)
.setChannelList(serializedChunk.channelsMetadataList)
.setChunkMD5(md5)
@@ -127,21 +140,22 @@ static <T> Blob constructBlobAndMetadata(

// Add chunk metadata and data to the list
chunksMetadataList.add(chunkMetadata);
chunksDataList.add(encryptedCompressedChunkData);
curDataSize += encryptedCompressedChunkDataSize;
crc.update(encryptedCompressedChunkData, 0, encryptedCompressedChunkDataSize);
chunksDataList.add(compressedChunkData);
curDataSize += compressedChunkDataSize;
crc.update(compressedChunkData, 0, compressedChunkDataSize);

logger.logInfo(
"Finish building chunk in blob={}, table={}, rowCount={}, startOffset={},"
+ " estimatedUncompressedSize={}, paddedChunkLength={}, encryptedCompressedSize={},"
+ " bdecVersion={}",
+ " estimatedUncompressedSize={}, chunkLength={}, compressedSize={},"
+ " encryption={}, bdecVersion={}",
filePath,
firstChannelFlushContext.getFullyQualifiedTableName(),
serializedChunk.rowCount,
startOffset,
serializedChunk.chunkEstimatedUncompressedSize,
paddedChunkLength,
encryptedCompressedChunkDataSize,
chunkLength,
compressedChunkDataSize,
encrypt,
bdecVersion);
}
}
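
Condensed, the new branch boils down to choosing between padded-and-encrypted bytes and the raw compressed bytes. A toy, runnable sketch of the length bookkeeping only; the real code additionally calls Cryptor.encrypt with an IV derived from the running blob offset:

/**
 * Toy condensation of the encrypt/no-encrypt branch added to constructBlobAndMetadata above.
 * Only the length bookkeeping is modeled; real encryption and MD5 computation are omitted.
 */
public class ChunkLengthSketch {
  // Stand-in for Constants.ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES (assumed 16 here).
  private static final int BLOCK_SIZE_BYTES = 16;

  static int chunkLength(byte[] compressedChunk, boolean encrypt) {
    if (encrypt) {
      // Encrypted path: the chunk is padded to the cipher block size, and the padded
      // length is what the server uses for decompression and MD5 verification.
      return ((compressedChunk.length + BLOCK_SIZE_BYTES - 1) / BLOCK_SIZE_BYTES) * BLOCK_SIZE_BYTES;
    }
    // Iceberg (new table format) path: raw compressed size, no padding, no client-side encryption.
    return compressedChunk.length;
  }

  public static void main(String[] args) {
    byte[] chunk = new byte[100];
    System.out.println(chunkLength(chunk, true));  // 112 (padded up to a 16-byte boundary)
    System.out.println(chunkLength(chunk, false)); // 100 (raw compressed size)
  }
}
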
@@ -192,12 +192,12 @@ private CompletableFuture<Void> distributeFlush(
/** If tracing is enabled, print always else, check if it needs flush or is forceful. */
private void logFlushTask(boolean isForce, Set<String> tablesToFlush, long flushStartTime) {
boolean isNeedFlush =
this.owningClient.getParameterProvider().getMaxChunksInBlobAndRegistrationRequest() == 1
this.owningClient.getParameterProvider().getMaxChunksInBlob() == 1
? tablesToFlush.stream().anyMatch(channelCache::getNeedFlush)
: this.isNeedFlush;
long currentTime = System.currentTimeMillis();
final String logInfo;
if (this.owningClient.getParameterProvider().getMaxChunksInBlobAndRegistrationRequest() == 1) {
if (this.owningClient.getParameterProvider().getMaxChunksInBlob() == 1) {
logInfo =
String.format(
"Tables=[%s]",
@@ -272,7 +272,7 @@ CompletableFuture<Void> flush(boolean isForce) {
this.owningClient.getParameterProvider().getCachedMaxClientLagInMs();

final Set<String> tablesToFlush;
if (this.owningClient.getParameterProvider().getMaxChunksInBlobAndRegistrationRequest() == 1) {
if (this.owningClient.getParameterProvider().getMaxChunksInBlob() == 1) {
tablesToFlush =
this.channelCache.keySet().stream()
.filter(
@@ -412,15 +412,13 @@ void distributeFlushTasks(Set<String> tablesToFlush) {
channelsDataPerTable.addAll(leftoverChannelsDataPerTable);
leftoverChannelsDataPerTable.clear();
} else if (blobData.size()
>= this.owningClient
.getParameterProvider()
.getMaxChunksInBlobAndRegistrationRequest()) {
>= this.owningClient.getParameterProvider().getMaxChunksInBlob()) {
// Create a new blob if the current one already contains max allowed number of chunks
logger.logInfo(
"Max allowed number of chunks in the current blob reached. chunkCount={}"
+ " maxChunkCount={} currentBlobPath={}",
blobData.size(),
this.owningClient.getParameterProvider().getMaxChunksInBlobAndRegistrationRequest(),
this.owningClient.getParameterProvider().getMaxChunksInBlob(),
blobPath);
break;
} else {
@@ -599,7 +597,12 @@ BlobMetadata buildAndUpload(
Timer.Context buildContext = Utils.createTimerContext(this.owningClient.buildLatency);

// Construct the blob along with the metadata of the blob
BlobBuilder.Blob blob = BlobBuilder.constructBlobAndMetadata(blobPath, blobData, bdecVersion);
BlobBuilder.Blob blob =
BlobBuilder.constructBlobAndMetadata(
blobPath,
blobData,
bdecVersion,
this.owningClient.getInternalParameterProvider().getEnableChunkEncryption());

blob.blobStats.setBuildDurationMs(buildContext);

@@ -691,7 +694,7 @@ void shutdown() throws InterruptedException {
*/
void setNeedFlush(String fullyQualifiedTableName) {
this.isNeedFlush = true;
if (this.owningClient.getParameterProvider().getMaxChunksInBlobAndRegistrationRequest() == 1) {
if (this.owningClient.getParameterProvider().getMaxChunksInBlob() == 1) {
this.channelCache.setNeedFlush(fullyQualifiedTableName, true);
}
}
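
With the renamed getMaxChunksInBlob() parameter, flush scoping changes as well: when blobs are limited to a single chunk, need-flush state is tracked per table (via the channel cache) instead of a single client-wide flag. A simplified model of that decision, not the actual FlushService code:

import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/** Simplified model of the per-table vs. client-wide need-flush tracking shown above. */
public class NeedFlushSketch {
  private final int maxChunksInBlob; // from ParameterProvider.getMaxChunksInBlob()
  private volatile boolean clientNeedFlush = false; // client-wide flag used for multi-chunk blobs
  private final ConcurrentHashMap<String, Boolean> tableNeedFlush = new ConcurrentHashMap<>();

  NeedFlushSketch(int maxChunksInBlob) {
    this.maxChunksInBlob = maxChunksInBlob;
  }

  void setNeedFlush(String fullyQualifiedTableName) {
    clientNeedFlush = true;
    if (maxChunksInBlob == 1) {
      // Single-chunk blobs: remember exactly which table asked for a flush.
      tableNeedFlush.put(fullyQualifiedTableName, true);
    }
  }

  Set<String> tablesToFlush(Set<String> allTables) {
    if (maxChunksInBlob == 1) {
      Set<String> result = new HashSet<>();
      for (String table : allTables) {
        if (tableNeedFlush.getOrDefault(table, false)) {
          result.add(table);
        }
      }
      return result;
    }
    // Multi-chunk blobs: the flush decision is client-wide, so all tables flush together.
    return clientNeedFlush ? allTables : new HashSet<>();
  }
}
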
@@ -0,0 +1,20 @@
/*
* Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.ingest.streaming.internal;

/** A class to provide non-configurable constants that depend on Iceberg or non-Iceberg mode */
class InternalParameterProvider {
private final boolean isIcebergMode;

InternalParameterProvider(boolean isIcebergMode) {
this.isIcebergMode = isIcebergMode;
}

boolean getEnableChunkEncryption() {
// In Iceberg mode, chunk encryption is disabled because Iceberg mode does not need
// client-side encryption; otherwise, it is enabled.
return !isIcebergMode;
}
}
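
A quick test-style illustration of the new provider's behavior; this snippet is hypothetical and assumes it runs in the same package, since the class is package-private:

// Hypothetical check inside net.snowflake.ingest.streaming.internal.
InternalParameterProvider icebergParams = new InternalParameterProvider(true);
InternalParameterProvider regularParams = new InternalParameterProvider(false);

assert !icebergParams.getEnableChunkEncryption(); // Iceberg mode: no client-side chunk encryption
assert regularParams.getEnableChunkEncryption();  // classic BDEC path: chunks stay encrypted
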
@@ -96,6 +96,9 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
// Snowflake role for the client to use
private String role;

// Provides constant values which are determined by the Iceberg or non-Iceberg mode
private final InternalParameterProvider internalParameterProvider;

// Http client to send HTTP requests to Snowflake
private final CloseableHttpClient httpClient;

@@ -111,6 +114,9 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
// Indicates whether the client has closed
private volatile boolean isClosed;

// Indicates whether the client is streaming to Iceberg tables
private final boolean isIcebergMode;

// Indicates whether the client is under test mode
private final boolean isTestMode;

@@ -146,6 +152,7 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
* @param prop connection properties
* @param httpClient http client for sending request
* @param isTestMode whether we're under test mode
* @param isIcebergMode whether we're streaming to Iceberg tables
* @param requestBuilder http request builder
* @param parameterOverrides parameters we override in case we want to set different values
*/
@@ -154,13 +161,16 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
SnowflakeURL accountURL,
Properties prop,
CloseableHttpClient httpClient,
boolean isIcebergMode,
boolean isTestMode,
RequestBuilder requestBuilder,
Map<String, Object> parameterOverrides) {
this.parameterProvider = new ParameterProvider(parameterOverrides, prop);
this.parameterProvider = new ParameterProvider(parameterOverrides, prop, isIcebergMode);
this.internalParameterProvider = new InternalParameterProvider(isIcebergMode);

this.name = name;
String accountName = accountURL == null ? null : accountURL.getAccount();
this.isIcebergMode = isIcebergMode;
this.isTestMode = isTestMode;
this.httpClient = httpClient == null ? HttpUtil.getHttpClient(accountName) : httpClient;
this.channelCache = new ChannelCache<>();
@@ -250,23 +260,25 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
* @param accountURL Snowflake account url
* @param prop connection properties
* @param parameterOverrides map of parameters to override for this client
* @param isIcebergMode whether we're streaming to Iceberg tables
* @param isTestMode indicates whether it's under test mode
*/
public SnowflakeStreamingIngestClientInternal(
String name,
SnowflakeURL accountURL,
Properties prop,
Map<String, Object> parameterOverrides,
boolean isIcebergMode,
boolean isTestMode) {
this(name, accountURL, prop, null, isTestMode, null, parameterOverrides);
this(name, accountURL, prop, null, isIcebergMode, isTestMode, null, parameterOverrides);
}

/*** Constructor for TEST ONLY
*
* @param name the name of the client
*/
SnowflakeStreamingIngestClientInternal(String name) {
this(name, null, null, null, true, null, new HashMap<>());
SnowflakeStreamingIngestClientInternal(String name, boolean isIcebergMode) {
this(name, null, null, null, isIcebergMode, true, null, new HashMap<>());
}

// TESTING ONLY - inject the request builder
@@ -495,21 +507,20 @@ List<List<BlobMetadata>> partitionBlobListForRegistrationRequest(List<BlobMetada
List<List<BlobMetadata>> result = new ArrayList<>();
List<BlobMetadata> currentBatch = new ArrayList<>();
int chunksInCurrentBatch = 0;
int maxChunksInBlobAndRegistrationRequest =
parameterProvider.getMaxChunksInBlobAndRegistrationRequest();
int maxChunksInRegistrationRequest = parameterProvider.getMaxChunksInRegistrationRequest();

for (BlobMetadata blob : blobs) {
if (blob.getChunks().size() > maxChunksInBlobAndRegistrationRequest) {
if (blob.getChunks().size() > maxChunksInRegistrationRequest) {
throw new SFException(
ErrorCode.INTERNAL_ERROR,
String.format(
"Incorrectly generated blob detected - number of chunks in the blob is larger than"
+ " the max allowed number of chunks. Please report this bug to Snowflake."
+ " bdec=%s chunkCount=%d maxAllowedChunkCount=%d",
blob.getPath(), blob.getChunks().size(), maxChunksInBlobAndRegistrationRequest));
blob.getPath(), blob.getChunks().size(), maxChunksInRegistrationRequest));
}

if (chunksInCurrentBatch + blob.getChunks().size() > maxChunksInBlobAndRegistrationRequest) {
if (chunksInCurrentBatch + blob.getChunks().size() > maxChunksInRegistrationRequest) {
// Newly added BDEC file would exceed the max number of chunks in a single registration
// request. We put chunks collected so far into the result list and create a new batch with
// the current blob
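
The registration-side limit now comes from its own getMaxChunksInRegistrationRequest() getter rather than the combined blob-and-registration parameter. A simplified, runnable sketch of the partitioning loop above, with each blob represented only by its chunk count:

import java.util.ArrayList;
import java.util.List;

/** Simplified version of partitionBlobListForRegistrationRequest: each int is a blob's chunk count. */
public class RegistrationBatchSketch {
  static List<List<Integer>> partition(List<Integer> blobChunkCounts, int maxChunksPerRequest) {
    List<List<Integer>> result = new ArrayList<>();
    List<Integer> currentBatch = new ArrayList<>();
    int chunksInCurrentBatch = 0;
    for (int chunks : blobChunkCounts) {
      if (chunks > maxChunksPerRequest) {
        // Mirrors the SFException above: a single blob may never exceed the request limit.
        throw new IllegalStateException("Blob has more chunks than a registration request allows");
      }
      if (chunksInCurrentBatch + chunks > maxChunksPerRequest) {
        // Adding this blob would overflow the request; close the batch and start a new one with it.
        result.add(currentBatch);
        currentBatch = new ArrayList<>();
        chunksInCurrentBatch = 0;
      }
      currentBatch.add(chunks);
      chunksInCurrentBatch += chunks;
    }
    if (!currentBatch.isEmpty()) {
      result.add(currentBatch);
    }
    return result;
  }

  public static void main(String[] args) {
    // With a limit of 3 chunks per request, blobs of sizes [2, 2, 1, 3] split into [[2], [2, 1], [3]].
    System.out.println(partition(List.of(2, 2, 1, 3), 3));
  }
}
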
@@ -875,6 +886,15 @@ ParameterProvider getParameterProvider() {
return parameterProvider;
}

/**
* Get InternalParameterProvider with internal parameters
*
* @return {@link InternalParameterProvider} used by the client
*/
InternalParameterProvider getInternalParameterProvider() {
return internalParameterProvider;
}

/**
* Set refresh token; this method is for refresh token renewal without requiring a client
* restart. This method only works when the authorization type is OAuth