test both code paths
sfc-gh-bmikaili committed Sep 4, 2024
1 parent fc864f1 commit 34968af
Showing 4 changed files with 97 additions and 37 deletions.

BlobBuilder.java
@@ -113,13 +113,17 @@ static <T> Blob constructBlobAndMetadata(
         // TODO: address alignment for the header SNOW-557866
         long iv = curDataSize / Constants.ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES;
 
-        if (encryptionKey != null)
-          compressedChunkData =
-              Cryptor.encrypt(paddedChunkData, encryptionKey.getEncryptionKey(), filePath, iv);
-        else
-          compressedChunkData =
-              Cryptor.encrypt(
-                  paddedChunkData, firstChannelFlushContext.getEncryptionKey(), filePath, iv);
+        if (encryptionKey == null)
+          encryptionKey =
+              new EncryptionKey(
+                  firstChannelFlushContext.getDbName(),
+                  firstChannelFlushContext.getSchemaName(),
+                  firstChannelFlushContext.getTableName(),
+                  firstChannelFlushContext.getEncryptionKey(),
+                  firstChannelFlushContext.getEncryptionKeyId());
+
+        compressedChunkData =
+            Cryptor.encrypt(paddedChunkData, encryptionKey.getEncryptionKey(), filePath, iv);
 
         compressedChunkDataSize = compressedChunkData.length;
       } else {
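
The change above collapses the two encrypt branches into one: when no EncryptionKey was passed in (for example, none was cached for the table), a fallback key is synthesized from the first channel's flush context, and a single Cryptor.encrypt call then covers both code paths. A minimal sketch of that fallback pattern, with hypothetical stand-in types (the SDK's real EncryptionKey and ChannelFlushContext carry more state):

```java
// Hedged sketch only: Key and Ctx are stand-ins for the SDK's EncryptionKey
// and ChannelFlushContext, not its real API.
public class FallbackKeySketch {
  static class Key {
    final String db, schema, table, secret;
    final Long id;
    Key(String db, String schema, String table, String secret, Long id) {
      this.db = db; this.schema = schema; this.table = table;
      this.secret = secret; this.id = id;
    }
  }

  static class Ctx { // what the first channel's flush context would supply
    String db() { return "db1"; }
    String schema() { return "schema1"; }
    String table() { return "table1"; }
    String secret() { return "key1"; }
    Long id() { return 1234L; }
  }

  // Prefer the key handed to us; otherwise build one from the channel
  // context, so the caller needs only a single encryption call site.
  static Key resolve(Key cached, Ctx ctx) {
    if (cached == null) {
      cached = new Key(ctx.db(), ctx.schema(), ctx.table(), ctx.secret(), ctx.id());
    }
    return cached;
  }

  public static void main(String[] args) {
    System.out.println(resolve(null, new Ctx()).secret); // falls back to "key1"
  }
}
```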

FlushService.java
@@ -17,7 +17,13 @@
 import java.security.InvalidAlgorithmParameterException;
 import java.security.InvalidKeyException;
 import java.security.NoSuchAlgorithmException;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
@@ -456,6 +462,8 @@ && shouldStopProcessing(
                     totalBufferSizeInBytes,
                     totalBufferSizePerTableInBytes,
                     channelData.getBufferSize(),
+                    channelData.getChannelContext().getEncryptionKeyId(),
+                    channelsDataPerTable.get(idx - 1).getChannelContext().getEncryptionKeyId(),
                     channelData.getColumnEps().keySet(),
                     channelsDataPerTable.get(idx - 1).getColumnEps().keySet());
                 break;
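
The two added arguments feed the current and previous channels' encryption key IDs into the shouldStopProcessing decision, so a chunk boundary can be forced where adjacent channel data disagree on the key. A hedged sketch of that grouping rule (illustrative names only; judging by the other arguments, the real check also weighs buffer sizes and column sets):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Illustrative only: close the current chunk whenever the encryption key id
// changes between adjacent channels, mirroring one criterion of the real check.
public class ChunkBoundarySketch {
  static List<List<Long>> chunkByKeyId(List<Long> channelKeyIds) {
    List<List<Long>> chunks = new ArrayList<>();
    List<Long> current = new ArrayList<>();
    for (Long keyId : channelKeyIds) {
      boolean keyChanged =
          !current.isEmpty() && !current.get(current.size() - 1).equals(keyId);
      if (keyChanged) { // key id differs from the previous channel: cut here
        chunks.add(current);
        current = new ArrayList<>();
      }
      current.add(keyId);
    }
    if (!current.isEmpty()) {
      chunks.add(current);
    }
    return chunks;
  }

  public static void main(String[] args) {
    // Channels keyed 1,1,2 split into two chunks: [1, 1] and [2].
    System.out.println(chunkByKeyId(Arrays.asList(1L, 1L, 2L)));
  }
}
```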

SnowflakeStreamingIngestClientInternal.java
@@ -382,17 +382,15 @@ public SnowflakeStreamingIngestChannelInternal<?> openChannel(OpenChannelRequest
       this.channelCache.addChannel(channel);
 
       // Add encryption key to the client map for the table
-      if (response.getEncryptionKey() != null) {
-        this.encryptionKeysPerTable.put(
-            new FullyQualifiedTableName(
-                request.getDBName(), request.getSchemaName(), request.getTableName()),
-            new EncryptionKey(
-                response.getDBName(),
-                response.getSchemaName(),
-                response.getTableName(),
-                response.getEncryptionKey(),
-                response.getEncryptionKeyId()));
-      }
+      this.encryptionKeysPerTable.put(
+          new FullyQualifiedTableName(
+              request.getDBName(), request.getSchemaName(), request.getTableName()),
+          new EncryptionKey(
+              response.getDBName(),
+              response.getSchemaName(),
+              response.getTableName(),
+              response.getEncryptionKey(),
+              response.getEncryptionKeyId()));
 
       return channel;
     } catch (IOException | IngestResponseException e) {
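
With the null guard gone, openChannel now caches an EncryptionKey for every opened table, keyed by FullyQualifiedTableName, even when the open-channel response carried no key material; that pairs with the BlobBuilder fallback above. For the map to behave, the key type needs value semantics, which this hypothetical stand-in illustrates (the SDK's FullyQualifiedTableName defines its own equals/hashCode):

```java
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

// Stand-in for FullyQualifiedTableName: equals/hashCode over (db, schema, table)
// so structurally equal keys hit the same ConcurrentHashMap entry.
public class TableKeySketch {
  static class Fqtn {
    final String db, schema, table;
    Fqtn(String db, String schema, String table) {
      this.db = db; this.schema = schema; this.table = table;
    }
    @Override
    public boolean equals(Object o) {
      if (!(o instanceof Fqtn)) return false;
      Fqtn other = (Fqtn) o;
      return db.equals(other.db)
          && schema.equals(other.schema)
          && table.equals(other.table);
    }
    @Override
    public int hashCode() {
      return Objects.hash(db, schema, table);
    }
  }

  public static void main(String[] args) {
    ConcurrentHashMap<Fqtn, String> keysPerTable = new ConcurrentHashMap<>();
    keysPerTable.put(new Fqtn("db1", "schema1", "table1"), "key1");
    // A fresh but structurally equal key finds the cached entry.
    System.out.println(keysPerTable.get(new Fqtn("db1", "schema1", "table1"))); // key1
  }
}
```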
@@ -592,11 +590,15 @@ void registerBlobs(List<BlobMetadata> blobs, final int executionCount) {
           executionCount);
 
       // Update encryption keys for the table given the response
-      for (EncryptionKey key : response.getEncryptionKeys()) {
-        this.encryptionKeysPerTable.put(
-            new FullyQualifiedTableName(
-                key.getDatabaseName(), key.getSchemaName(), key.getTableName()),
-            key);
+      if (response.getEncryptionKeys() == null) {
+        this.encryptionKeysPerTable.clear();
+      } else {
+        for (EncryptionKey key : response.getEncryptionKeys()) {
+          this.encryptionKeysPerTable.put(
+              new FullyQualifiedTableName(
+                  key.getDatabaseName(), key.getSchemaName(), key.getTableName()),
+              key);
+        }
       }
 
       // We will retry any blob chunks that were rejected because internal Snowflake queues are full
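
The registerBlobs change gives the key cache two refresh modes: a null key list from the server wipes the cache, while a non-null list upserts entry by entry so the newest key per table wins. A minimal sketch of those semantics, with plain String keys standing in for FullyQualifiedTableName and EncryptionKey:

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

// Hedged sketch of the cache-refresh semantics: null => clear, non-null => upsert.
public class KeyCacheRefreshSketch {
  static void refresh(Map<String, String> cache, Map<String, String> serverKeys) {
    if (serverKeys == null) {
      cache.clear(); // nothing usable from the server: drop stale keys
    } else {
      cache.putAll(serverKeys); // upsert: newest key per table wins
    }
  }

  public static void main(String[] args) {
    Map<String, String> cache = new HashMap<>();
    refresh(cache, Collections.singletonMap("db1.schema1.table1", "key2"));
    System.out.println(cache); // {db1.schema1.table1=key2}
    refresh(cache, null);
    System.out.println(cache.isEmpty()); // true
  }
}
```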

FlushServiceTest.java
@@ -119,20 +119,23 @@ private abstract static class TestContext<T> implements AutoCloseable {
           .thenAnswer((Answer<ParameterProvider>) (i) -> parameterProvider);
 
       encryptionKeysPerTable = new ConcurrentHashMap<>();
-      encryptionKeysPerTable.put(
-          new FullyQualifiedTableName("db1", "schema1", "table1"),
-          new EncryptionKey("db1", "schema1", "table1", "key1", 1234L));
-      encryptionKeysPerTable.put(
-          new FullyQualifiedTableName("db2", "schema1", "table2"),
-          new EncryptionKey("db2", "schema1", "table2", "key1", 1234L));
-
-      for (int i = 0; i <= 9999; i++) {
+      if (isIcebergMode) {
         encryptionKeysPerTable.put(
-            new FullyQualifiedTableName("db1", "PUBLIC", String.format("table%d", i)),
-            new EncryptionKey("db1", "PUBLIC", String.format("table%d", i), "key1", 1234L));
+            new FullyQualifiedTableName("db1", "schema1", "table1"),
+            new EncryptionKey("db1", "schema1", "table1", "key1", 1234L));
+        encryptionKeysPerTable.put(
+            new FullyQualifiedTableName("db2", "schema1", "table2"),
+            new EncryptionKey("db2", "schema1", "table2", "key1", 1234L));
+
+        for (int i = 0; i <= 9999; i++) {
+          encryptionKeysPerTable.put(
+              new FullyQualifiedTableName("db1", "PUBLIC", String.format("table%d", i)),
+              new EncryptionKey("db1", "PUBLIC", String.format("table%d", i), "key1", 1234L));
+        }
+
+        Mockito.when(client.getEncryptionKeysPerTable()).thenReturn(encryptionKeysPerTable);
       }
 
-      Mockito.when(client.getEncryptionKeysPerTable()).thenReturn(encryptionKeysPerTable);
       channelCache = new ChannelCache<>();
       Mockito.when(client.getChannelCache()).thenReturn(channelCache);
       registerService = Mockito.spy(new RegisterService(client, client.isTestMode()));
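
In the updated test context, the per-table keys are seeded, and the client's getEncryptionKeysPerTable() getter stubbed, only when isIcebergMode is set. A hedged sketch of that Mockito stubbing pattern, with a hypothetical KeyClient interface in place of the SDK's client class:

```java
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hedged sketch: a mocked client getter returns the locally seeded map,
// so production code under test reads the same keys the test prepared.
public class StubSketch {
  interface KeyClient { // hypothetical stand-in, not the SDK interface
    Map<String, String> getEncryptionKeysPerTable();
  }

  public static void main(String[] args) {
    Map<String, String> keys = new ConcurrentHashMap<>();
    keys.put("db1.schema1.table1", "key1");

    KeyClient client = mock(KeyClient.class);
    when(client.getEncryptionKeysPerTable()).thenReturn(keys);

    System.out.println(client.getEncryptionKeysPerTable().get("db1.schema1.table1")); // key1
  }
}
```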
@@ -629,13 +632,13 @@ public void testBlobCreation() throws Exception {
 
     channel1.insertRows(rows1, "offset1");
     channel2.insertRows(rows1, "offset2");
-    channel4.insertRows(rows1, "offset3");
+    channel4.insertRows(rows1, "offset4");
 
     FlushService<?> flushService = testContext.flushService;
 
     // Force = true flushes
     flushService.flush(true).get();
-    Mockito.verify(flushService, Mockito.atLeast(1))
+    Mockito.verify(flushService, Mockito.atLeast(2))
         .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any());
   }

@@ -772,6 +775,49 @@ public void runTestBlobSplitDueToNumberOfChunks(int numberOfRows) throws Excepti
     Assert.assertEquals(numberOfRows, getRows(allUploadedBlobs).size());
   }
 
+  @Test
+  public void testBlobSplitDueToNumberOfChunksWithLeftoverChannels() throws Exception {
+    final TestContext<List<List<Object>>> testContext = testContextFactory.create();
+
+    for (int i = 0; i < 99; i++) { // 19 simple chunks
+      SnowflakeStreamingIngestChannelInternal<List<List<Object>>> channel =
+          addChannel(testContext, i, 1);
+      channel.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1")));
+      channel.insertRow(Collections.singletonMap("C1", i), "");
+    }
+
+    // 20th chunk would contain multiple channels, but there are some with different encryption key
+    // ID, so they spill to a new blob
+    SnowflakeStreamingIngestChannelInternal<List<List<Object>>> channel1 =
+        addChannel(testContext, 99, 1);
+    channel1.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1")));
+    channel1.insertRow(Collections.singletonMap("C1", 0), "");
+
+    SnowflakeStreamingIngestChannelInternal<List<List<Object>>> channel2 =
+        addChannel(testContext, 99, 2);
+    channel2.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1")));
+    channel2.insertRow(Collections.singletonMap("C1", 0), "");
+
+    SnowflakeStreamingIngestChannelInternal<List<List<Object>>> channel3 =
+        addChannel(testContext, 99, 2);
+    channel3.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1")));
+    channel3.insertRow(Collections.singletonMap("C1", 0), "");
+
+    FlushService<List<List<Object>>> flushService = testContext.flushService;
+    flushService.flush(true).get();
+
+    ArgumentCaptor<List<List<ChannelData<List<List<Object>>>>>> blobDataCaptor =
+        ArgumentCaptor.forClass(List.class);
+    Mockito.verify(flushService, Mockito.atLeast(2))
+        .buildAndUpload(Mockito.any(), blobDataCaptor.capture(), Mockito.any(), Mockito.any());
+
+    // 1. list => blobs; 2. list => chunks; 3. list => channels; 4. list => rows; 5. list => columns
+    List<List<List<ChannelData<List<List<Object>>>>>> allUploadedBlobs =
+        blobDataCaptor.getAllValues();
+
+    Assert.assertEquals(102, getRows(allUploadedBlobs).size());
+  }
+
   private List<List<Object>> getRows(List<List<List<ChannelData<List<List<Object>>>>>> blobs) {
     List<List<Object>> result = new ArrayList<>();
     blobs.forEach(
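
The new test's asserted total of 102 rows follows directly from its setup: 99 single-row channels from the loop plus one row each from channel1, channel2, and channel3. Assuming addChannel's last argument is the encryption key id (as the test's comment suggests), channel2 and channel3 carry key id 2 while channel1 carries 1, so the mismatched channels spill into an additional blob, which is what the Mockito.atLeast(2) verification checks. The arithmetic, spelled out:

```java
// Row count the new test asserts, spelled out. The extra blob exists because
// channel2/channel3 use encryption key id 2 while channel1 uses 1 (assumption
// based on the addChannel arguments and the test's own comment).
public class LeftoverChannelArithmetic {
  public static void main(String[] args) {
    int loopRows = 99;  // one insertRow per channel in the for-loop
    int extraRows = 3;  // channel1, channel2, channel3
    System.out.println(loopRows + extraRows); // 102, the asserted total
  }
}
```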
