Skip to content

Commit

Permalink
Merge branch 'master' into PRODSEC-3611
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jfan authored Oct 18, 2023
2 parents e1d9cb3 + 5aa45d3 commit 976eed2
Show file tree
Hide file tree
Showing 7 changed files with 439 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,18 @@ void distributeFlushTasks() {
if (!leftoverChannelsDataPerTable.isEmpty()) {
channelsDataPerTable.addAll(leftoverChannelsDataPerTable);
leftoverChannelsDataPerTable.clear();
} else if (blobData.size()
>= this.owningClient
.getParameterProvider()
.getMaxChunksInBlobAndRegistrationRequest()) {
// Create a new blob if the current one already contains max allowed number of chunks
logger.logInfo(
"Max allowed number of chunks in the current blob reached. chunkCount={}"
+ " maxChunkCount={} currentBlobPath={}",
blobData.size(),
this.owningClient.getParameterProvider().getMaxChunksInBlobAndRegistrationRequest(),
blobPath);
break;
} else {
ConcurrentHashMap<String, SnowflakeStreamingIngestChannelInternal<T>> table =
itr.next().getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,53 @@ ChannelsStatusResponse getChannelsStatus(
* @param blobs list of uploaded blobs
*/
void registerBlobs(List<BlobMetadata> blobs) {
this.registerBlobs(blobs, 0);
for (List<BlobMetadata> blobBatch : partitionBlobListForRegistrationRequest(blobs)) {
this.registerBlobs(blobBatch, 0);
}
}

/**
* Partition the collection of blobs into sub-lists, so that the total number of chunks in each
* sublist does not exceed the max allowed number of chunks in one registration request.
*/
List<List<BlobMetadata>> partitionBlobListForRegistrationRequest(List<BlobMetadata> blobs) {
List<List<BlobMetadata>> result = new ArrayList<>();
List<BlobMetadata> currentBatch = new ArrayList<>();
int chunksInCurrentBatch = 0;
int maxChunksInBlobAndRegistrationRequest =
parameterProvider.getMaxChunksInBlobAndRegistrationRequest();

for (BlobMetadata blob : blobs) {
if (blob.getChunks().size() > maxChunksInBlobAndRegistrationRequest) {
throw new SFException(
ErrorCode.INTERNAL_ERROR,
String.format(
"Incorrectly generated blob detected - number of chunks in the blob is larger than"
+ " the max allowed number of chunks. Please report this bug to Snowflake."
+ " bdec=%s chunkCount=%d maxAllowedChunkCount=%d",
blob.getPath(), blob.getChunks().size(), maxChunksInBlobAndRegistrationRequest));
}

if (chunksInCurrentBatch + blob.getChunks().size() > maxChunksInBlobAndRegistrationRequest) {
// Newly added BDEC file would exceed the max number of chunks in a single registration
// request. We put chunks collected so far into the result list and create a new batch with
// the current blob
result.add(currentBatch);
currentBatch = new ArrayList<>();
currentBatch.add(blob);
chunksInCurrentBatch = blob.getChunks().size();
} else {
// Newly added BDEC can be added to the current batch because it does not exceed the max
// number of chunks in a single registration request, yet.
currentBatch.add(blob);
chunksInCurrentBatch += blob.getChunks().size();
}
}

if (!currentBatch.isEmpty()) {
result.add(currentBatch);
}
return result;
}

/**
Expand Down
21 changes: 21 additions & 0 deletions src/main/java/net/snowflake/ingest/utils/ParameterProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public class ParameterProvider {
public static final String MAX_CHUNK_SIZE_IN_BYTES = "MAX_CHUNK_SIZE_IN_BYTES".toLowerCase();
public static final String MAX_ALLOWED_ROW_SIZE_IN_BYTES =
"MAX_ALLOWED_ROW_SIZE_IN_BYTES".toLowerCase();
public static final String MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST =
"MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST".toLowerCase();

public static final String MAX_CLIENT_LAG = "MAX_CLIENT_LAG".toLowerCase();

Expand Down Expand Up @@ -59,6 +61,7 @@ public class ParameterProvider {

static final long MAX_CLIENT_LAG_MS_MAX = TimeUnit.MINUTES.toMillis(10);
public static final long MAX_ALLOWED_ROW_SIZE_IN_BYTES_DEFAULT = 64 * 1024 * 1024; // 64 MB
public static final int MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST_DEFAULT = 100;

/* Parameter that enables using internal Parquet buffers for buffering of rows before serializing.
It reduces memory consumption compared to using Java Objects for buffering.*/
Expand Down Expand Up @@ -170,6 +173,11 @@ private void setParameterMap(Map<String, Object> parameterOverrides, Properties
this.updateValue(MAX_CLIENT_LAG, MAX_CLIENT_LAG_DEFAULT, parameterOverrides, props);
this.updateValue(
MAX_CLIENT_LAG_ENABLED, MAX_CLIENT_LAG_ENABLED_DEFAULT, parameterOverrides, props);
this.updateValue(
MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST,
MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST_DEFAULT,
parameterOverrides,
props);
}

/** @return Longest interval in milliseconds between buffer flushes */
Expand Down Expand Up @@ -369,13 +377,26 @@ public long getMaxChunkSizeInBytes() {
return (val instanceof String) ? Long.parseLong(val.toString()) : (long) val;
}

/** @return The max allow row size (in bytes) */
public long getMaxAllowedRowSizeInBytes() {
Object val =
this.parameterMap.getOrDefault(
MAX_ALLOWED_ROW_SIZE_IN_BYTES, MAX_ALLOWED_ROW_SIZE_IN_BYTES_DEFAULT);
return (val instanceof String) ? Long.parseLong(val.toString()) : (long) val;
}

/**
* @return The max number of chunks that can be put into a single BDEC or blob registration
* request.
*/
public int getMaxChunksInBlobAndRegistrationRequest() {
Object val =
this.parameterMap.getOrDefault(
MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST,
MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST_DEFAULT);
return (val instanceof String) ? Integer.parseInt(val.toString()) : (int) val;
}

@Override
public String toString() {
return "ParameterProvider{" + "parameterMap=" + parameterMap + '}';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
Expand Down Expand Up @@ -273,7 +274,22 @@ TestContext<List<List<Object>>> create() {
}
}

TestContextFactory<?> testContextFactory;
TestContextFactory<List<List<Object>>> testContextFactory;

private SnowflakeStreamingIngestChannelInternal<List<List<Object>>> addChannel(
TestContext<List<List<Object>>> testContext, int tableId, long encryptionKeyId) {
return testContext
.channelBuilder("channel" + UUID.randomUUID())
.setDBName("db1")
.setSchemaName("PUBLIC")
.setTableName("table" + tableId)
.setOffsetToken("offset1")
.setChannelSequencer(0L)
.setRowSequencer(0L)
.setEncryptionKey("key")
.setEncryptionKeyId(encryptionKeyId)
.buildAndAdd();
}

private SnowflakeStreamingIngestChannelInternal<?> addChannel1(TestContext<?> testContext) {
return testContext
Expand Down Expand Up @@ -546,6 +562,107 @@ public void testBlobSplitDueToChunkSizeLimit() throws Exception {
Mockito.verify(flushService, Mockito.times(2)).buildAndUpload(Mockito.any(), Mockito.any());
}

@Test
public void testBlobSplitDueToNumberOfChunks() throws Exception {
for (int rowCount : Arrays.asList(0, 1, 30, 111, 159, 287, 1287, 1599, 4496)) {
runTestBlobSplitDueToNumberOfChunks(rowCount);
}
}

/**
* Insert rows in batches of 3 into each table and assert that the expected number of blobs is
* generated.
*
* @param numberOfRows How many rows to insert
*/
public void runTestBlobSplitDueToNumberOfChunks(int numberOfRows) throws Exception {
int channelsPerTable = 3;
int expectedBlobs =
(int)
Math.ceil(
(double) numberOfRows
/ channelsPerTable
/ ParameterProvider.MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST_DEFAULT);

final TestContext<List<List<Object>>> testContext = testContextFactory.create();

for (int i = 0; i < numberOfRows; i++) {
SnowflakeStreamingIngestChannelInternal<List<List<Object>>> channel =
addChannel(testContext, i / channelsPerTable, 1);
channel.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1")));
channel.insertRow(Collections.singletonMap("C1", i), "");
}

FlushService<List<List<Object>>> flushService = testContext.flushService;
flushService.flush(true).get();

ArgumentCaptor<List<List<ChannelData<List<List<Object>>>>>> blobDataCaptor =
ArgumentCaptor.forClass(List.class);
Mockito.verify(flushService, Mockito.times(expectedBlobs))
.buildAndUpload(Mockito.any(), blobDataCaptor.capture());

// 1. list => blobs; 2. list => chunks; 3. list => channels; 4. list => rows, 5. list => columns
List<List<List<ChannelData<List<List<Object>>>>>> allUploadedBlobs =
blobDataCaptor.getAllValues();

Assert.assertEquals(numberOfRows, getRows(allUploadedBlobs).size());
}

@Test
public void testBlobSplitDueToNumberOfChunksWithLeftoverChannels() throws Exception {
final TestContext<List<List<Object>>> testContext = testContextFactory.create();

for (int i = 0; i < 99; i++) { // 19 simple chunks
SnowflakeStreamingIngestChannelInternal<List<List<Object>>> channel =
addChannel(testContext, i, 1);
channel.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1")));
channel.insertRow(Collections.singletonMap("C1", i), "");
}

// 20th chunk would contain multiple channels, but there are some with different encryption key
// ID, so they spill to a new blob
SnowflakeStreamingIngestChannelInternal<List<List<Object>>> channel1 =
addChannel(testContext, 99, 1);
channel1.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1")));
channel1.insertRow(Collections.singletonMap("C1", 0), "");

SnowflakeStreamingIngestChannelInternal<List<List<Object>>> channel2 =
addChannel(testContext, 99, 2);
channel2.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1")));
channel2.insertRow(Collections.singletonMap("C1", 0), "");

SnowflakeStreamingIngestChannelInternal<List<List<Object>>> channel3 =
addChannel(testContext, 99, 2);
channel3.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1")));
channel3.insertRow(Collections.singletonMap("C1", 0), "");

FlushService<List<List<Object>>> flushService = testContext.flushService;
flushService.flush(true).get();

ArgumentCaptor<List<List<ChannelData<List<List<Object>>>>>> blobDataCaptor =
ArgumentCaptor.forClass(List.class);
Mockito.verify(flushService, Mockito.atLeast(2))
.buildAndUpload(Mockito.any(), blobDataCaptor.capture());

// 1. list => blobs; 2. list => chunks; 3. list => channels; 4. list => rows, 5. list => columns
List<List<List<ChannelData<List<List<Object>>>>>> allUploadedBlobs =
blobDataCaptor.getAllValues();

Assert.assertEquals(102, getRows(allUploadedBlobs).size());
}

private List<List<Object>> getRows(List<List<List<ChannelData<List<List<Object>>>>>> blobs) {
List<List<Object>> result = new ArrayList<>();
blobs.forEach(
chunks ->
chunks.forEach(
channels ->
channels.forEach(
chunkData ->
result.addAll(((ParquetChunkData) chunkData.getVectors()).rows))));
return result;
}

@Test
public void testBuildAndUpload() throws Exception {
long expectedBuildLatencyMs = 100;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package net.snowflake.ingest.streaming.internal;

import static net.snowflake.ingest.utils.Constants.ROLE;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Collections;
import java.util.Map;
import java.util.Properties;
import net.snowflake.ingest.TestUtils;
import net.snowflake.ingest.streaming.OpenChannelRequest;
import net.snowflake.ingest.streaming.SnowflakeStreamingIngestChannel;
import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient;
import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClientFactory;
import net.snowflake.ingest.utils.Constants;
import net.snowflake.ingest.utils.ParameterProvider;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
* Verified that ingestion work when we ingest into large number of tables from the same client and
* blobs and registration requests have to be cut, so they don't contain large number of chunks
*/
public class ManyTablesIT {

private static final int TABLES_COUNT = 20;
private static final int TOTAL_ROWS_COUNT = 200_000;
private String dbName;
private SnowflakeStreamingIngestClient client;
private Connection connection;
private SnowflakeStreamingIngestChannel[] channels;
private String[] offsetTokensPerChannel;

@Before
public void setUp() throws Exception {
Properties props = TestUtils.getProperties(Constants.BdecVersion.THREE, false);
props.put(ParameterProvider.MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST, 2);
if (props.getProperty(ROLE).equals("DEFAULT_ROLE")) {
props.setProperty(ROLE, "ACCOUNTADMIN");
}
client = SnowflakeStreamingIngestClientFactory.builder("client1").setProperties(props).build();
connection = TestUtils.getConnection(true);
dbName = String.format("sdk_it_many_tables_db_%d", System.nanoTime());

channels = new SnowflakeStreamingIngestChannel[TABLES_COUNT];
offsetTokensPerChannel = new String[TABLES_COUNT];
connection.createStatement().execute(String.format("create database %s;", dbName));

String[] tableNames = new String[TABLES_COUNT];
for (int i = 0; i < tableNames.length; i++) {
tableNames[i] = String.format("table_%d", i);
connection.createStatement().execute(String.format("create table table_%d(c int);", i));
channels[i] =
client.openChannel(
OpenChannelRequest.builder(String.format("channel-%d", i))
.setDBName(dbName)
.setSchemaName("public")
.setTableName(tableNames[i])
.setOnErrorOption(OpenChannelRequest.OnErrorOption.ABORT)
.build());
}
}

@After
public void tearDown() throws Exception {
connection.createStatement().execute(String.format("drop database %s;", dbName));
client.close();
connection.close();
}

@Test
public void testIngestionIntoManyTables() throws InterruptedException, SQLException {
for (int i = 0; i < TOTAL_ROWS_COUNT; i++) {
Map<String, Object> row = Collections.singletonMap("c", i);
String offset = String.valueOf(i);
int channelId = i % channels.length;
channels[channelId].insertRow(row, offset);
offsetTokensPerChannel[channelId] = offset;
}

for (int i = 0; i < channels.length; i++) {
TestUtils.waitForOffset(channels[i], offsetTokensPerChannel[i]);
}

int totalRowsCount = 0;
ResultSet rs =
connection
.createStatement()
.executeQuery(String.format("show tables in database %s;", dbName));
while (rs.next()) {
totalRowsCount += rs.getInt("rows");
}
Assert.assertEquals(TOTAL_ROWS_COUNT, totalRowsCount);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -272,4 +272,13 @@ public void testMaxClientLagEnabledThresholdAbove() {
Assert.assertTrue(e.getMessage().startsWith("Lag falls outside"));
}
}

@Test
public void testMaxChunksInBlobAndRegistrationRequest() {
Properties prop = new Properties();
Map<String, Object> parameterMap = getStartingParameterMap();
parameterMap.put("max_chunks_in_blob_and_registration_request", 1);
ParameterProvider parameterProvider = new ParameterProvider(parameterMap, prop);
Assert.assertEquals(1, parameterProvider.getMaxChunksInBlobAndRegistrationRequest());
}
}
Loading

0 comments on commit 976eed2

Please sign in to comment.