From 5aa45d3c9e8b702a2f63e570e307c00bb238e7b7 Mon Sep 17 00:00:00 2001 From: Lukas Sembera Date: Wed, 18 Oct 2023 13:22:58 +0200 Subject: [PATCH] SNOW-902709 Limit the max allowed number of chunks in blob (#580) --- .../streaming/internal/FlushService.java | 12 ++ ...nowflakeStreamingIngestClientInternal.java | 48 ++++- .../ingest/utils/ParameterProvider.java | 21 ++ .../streaming/internal/FlushServiceTest.java | 119 +++++++++++- .../streaming/internal/ManyTablesIT.java | 98 ++++++++++ .../internal/ParameterProviderTest.java | 9 + .../SnowflakeStreamingIngestClientTest.java | 180 +++++++++++++----- 7 files changed, 439 insertions(+), 48 deletions(-) create mode 100644 src/test/java/net/snowflake/ingest/streaming/internal/ManyTablesIT.java diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java index ff879b8d6..cb1ef7810 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java @@ -350,6 +350,18 @@ void distributeFlushTasks() { if (!leftoverChannelsDataPerTable.isEmpty()) { channelsDataPerTable.addAll(leftoverChannelsDataPerTable); leftoverChannelsDataPerTable.clear(); + } else if (blobData.size() + >= this.owningClient + .getParameterProvider() + .getMaxChunksInBlobAndRegistrationRequest()) { + // Create a new blob if the current one already contains max allowed number of chunks + logger.logInfo( + "Max allowed number of chunks in the current blob reached. chunkCount={}" + + " maxChunkCount={} currentBlobPath={}", + blobData.size(), + this.owningClient.getParameterProvider().getMaxChunksInBlobAndRegistrationRequest(), + blobPath); + break; } else { ConcurrentHashMap> table = itr.next().getValue(); diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java index 5916e472d..27ff9407e 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java @@ -452,7 +452,53 @@ ChannelsStatusResponse getChannelsStatus( * @param blobs list of uploaded blobs */ void registerBlobs(List blobs) { - this.registerBlobs(blobs, 0); + for (List blobBatch : partitionBlobListForRegistrationRequest(blobs)) { + this.registerBlobs(blobBatch, 0); + } + } + + /** + * Partition the collection of blobs into sub-lists, so that the total number of chunks in each + * sublist does not exceed the max allowed number of chunks in one registration request. + */ + List> partitionBlobListForRegistrationRequest(List blobs) { + List> result = new ArrayList<>(); + List currentBatch = new ArrayList<>(); + int chunksInCurrentBatch = 0; + int maxChunksInBlobAndRegistrationRequest = + parameterProvider.getMaxChunksInBlobAndRegistrationRequest(); + + for (BlobMetadata blob : blobs) { + if (blob.getChunks().size() > maxChunksInBlobAndRegistrationRequest) { + throw new SFException( + ErrorCode.INTERNAL_ERROR, + String.format( + "Incorrectly generated blob detected - number of chunks in the blob is larger than" + + " the max allowed number of chunks. Please report this bug to Snowflake." + + " bdec=%s chunkCount=%d maxAllowedChunkCount=%d", + blob.getPath(), blob.getChunks().size(), maxChunksInBlobAndRegistrationRequest)); + } + + if (chunksInCurrentBatch + blob.getChunks().size() > maxChunksInBlobAndRegistrationRequest) { + // Newly added BDEC file would exceed the max number of chunks in a single registration + // request. We put chunks collected so far into the result list and create a new batch with + // the current blob + result.add(currentBatch); + currentBatch = new ArrayList<>(); + currentBatch.add(blob); + chunksInCurrentBatch = blob.getChunks().size(); + } else { + // Newly added BDEC can be added to the current batch because it does not exceed the max + // number of chunks in a single registration request, yet. + currentBatch.add(blob); + chunksInCurrentBatch += blob.getChunks().size(); + } + } + + if (!currentBatch.isEmpty()) { + result.add(currentBatch); + } + return result; } /** diff --git a/src/main/java/net/snowflake/ingest/utils/ParameterProvider.java b/src/main/java/net/snowflake/ingest/utils/ParameterProvider.java index 5c6f81f66..3a2221697 100644 --- a/src/main/java/net/snowflake/ingest/utils/ParameterProvider.java +++ b/src/main/java/net/snowflake/ingest/utils/ParameterProvider.java @@ -31,6 +31,8 @@ public class ParameterProvider { public static final String MAX_CHUNK_SIZE_IN_BYTES = "MAX_CHUNK_SIZE_IN_BYTES".toLowerCase(); public static final String MAX_ALLOWED_ROW_SIZE_IN_BYTES = "MAX_ALLOWED_ROW_SIZE_IN_BYTES".toLowerCase(); + public static final String MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST = + "MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST".toLowerCase(); public static final String MAX_CLIENT_LAG = "MAX_CLIENT_LAG".toLowerCase(); @@ -59,6 +61,7 @@ public class ParameterProvider { static final long MAX_CLIENT_LAG_MS_MAX = TimeUnit.MINUTES.toMillis(10); public static final long MAX_ALLOWED_ROW_SIZE_IN_BYTES_DEFAULT = 64 * 1024 * 1024; // 64 MB + public static final int MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST_DEFAULT = 100; /* Parameter that enables using internal Parquet buffers for buffering of rows before serializing. It reduces memory consumption compared to using Java Objects for buffering.*/ @@ -170,6 +173,11 @@ private void setParameterMap(Map parameterOverrides, Properties this.updateValue(MAX_CLIENT_LAG, MAX_CLIENT_LAG_DEFAULT, parameterOverrides, props); this.updateValue( MAX_CLIENT_LAG_ENABLED, MAX_CLIENT_LAG_ENABLED_DEFAULT, parameterOverrides, props); + this.updateValue( + MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST, + MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST_DEFAULT, + parameterOverrides, + props); } /** @return Longest interval in milliseconds between buffer flushes */ @@ -369,6 +377,7 @@ public long getMaxChunkSizeInBytes() { return (val instanceof String) ? Long.parseLong(val.toString()) : (long) val; } + /** @return The max allow row size (in bytes) */ public long getMaxAllowedRowSizeInBytes() { Object val = this.parameterMap.getOrDefault( @@ -376,6 +385,18 @@ public long getMaxAllowedRowSizeInBytes() { return (val instanceof String) ? Long.parseLong(val.toString()) : (long) val; } + /** + * @return The max number of chunks that can be put into a single BDEC or blob registration + * request. + */ + public int getMaxChunksInBlobAndRegistrationRequest() { + Object val = + this.parameterMap.getOrDefault( + MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST, + MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST_DEFAULT); + return (val instanceof String) ? Integer.parseInt(val.toString()) : (int) val; + } + @Override public String toString() { return "ParameterProvider{" + "parameterMap=" + parameterMap + '}'; diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java index 5a77a51db..17f929206 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java @@ -35,6 +35,7 @@ import java.util.List; import java.util.Map; import java.util.TimeZone; +import java.util.UUID; import java.util.concurrent.TimeUnit; import javax.crypto.BadPaddingException; import javax.crypto.IllegalBlockSizeException; @@ -273,7 +274,22 @@ TestContext>> create() { } } - TestContextFactory testContextFactory; + TestContextFactory>> testContextFactory; + + private SnowflakeStreamingIngestChannelInternal>> addChannel( + TestContext>> testContext, int tableId, long encryptionKeyId) { + return testContext + .channelBuilder("channel" + UUID.randomUUID()) + .setDBName("db1") + .setSchemaName("PUBLIC") + .setTableName("table" + tableId) + .setOffsetToken("offset1") + .setChannelSequencer(0L) + .setRowSequencer(0L) + .setEncryptionKey("key") + .setEncryptionKeyId(encryptionKeyId) + .buildAndAdd(); + } private SnowflakeStreamingIngestChannelInternal addChannel1(TestContext testContext) { return testContext @@ -546,6 +562,107 @@ public void testBlobSplitDueToChunkSizeLimit() throws Exception { Mockito.verify(flushService, Mockito.times(2)).buildAndUpload(Mockito.any(), Mockito.any()); } + @Test + public void testBlobSplitDueToNumberOfChunks() throws Exception { + for (int rowCount : Arrays.asList(0, 1, 30, 111, 159, 287, 1287, 1599, 4496)) { + runTestBlobSplitDueToNumberOfChunks(rowCount); + } + } + + /** + * Insert rows in batches of 3 into each table and assert that the expected number of blobs is + * generated. + * + * @param numberOfRows How many rows to insert + */ + public void runTestBlobSplitDueToNumberOfChunks(int numberOfRows) throws Exception { + int channelsPerTable = 3; + int expectedBlobs = + (int) + Math.ceil( + (double) numberOfRows + / channelsPerTable + / ParameterProvider.MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST_DEFAULT); + + final TestContext>> testContext = testContextFactory.create(); + + for (int i = 0; i < numberOfRows; i++) { + SnowflakeStreamingIngestChannelInternal>> channel = + addChannel(testContext, i / channelsPerTable, 1); + channel.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1"))); + channel.insertRow(Collections.singletonMap("C1", i), ""); + } + + FlushService>> flushService = testContext.flushService; + flushService.flush(true).get(); + + ArgumentCaptor>>>>> blobDataCaptor = + ArgumentCaptor.forClass(List.class); + Mockito.verify(flushService, Mockito.times(expectedBlobs)) + .buildAndUpload(Mockito.any(), blobDataCaptor.capture()); + + // 1. list => blobs; 2. list => chunks; 3. list => channels; 4. list => rows, 5. list => columns + List>>>>> allUploadedBlobs = + blobDataCaptor.getAllValues(); + + Assert.assertEquals(numberOfRows, getRows(allUploadedBlobs).size()); + } + + @Test + public void testBlobSplitDueToNumberOfChunksWithLeftoverChannels() throws Exception { + final TestContext>> testContext = testContextFactory.create(); + + for (int i = 0; i < 99; i++) { // 19 simple chunks + SnowflakeStreamingIngestChannelInternal>> channel = + addChannel(testContext, i, 1); + channel.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1"))); + channel.insertRow(Collections.singletonMap("C1", i), ""); + } + + // 20th chunk would contain multiple channels, but there are some with different encryption key + // ID, so they spill to a new blob + SnowflakeStreamingIngestChannelInternal>> channel1 = + addChannel(testContext, 99, 1); + channel1.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1"))); + channel1.insertRow(Collections.singletonMap("C1", 0), ""); + + SnowflakeStreamingIngestChannelInternal>> channel2 = + addChannel(testContext, 99, 2); + channel2.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1"))); + channel2.insertRow(Collections.singletonMap("C1", 0), ""); + + SnowflakeStreamingIngestChannelInternal>> channel3 = + addChannel(testContext, 99, 2); + channel3.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1"))); + channel3.insertRow(Collections.singletonMap("C1", 0), ""); + + FlushService>> flushService = testContext.flushService; + flushService.flush(true).get(); + + ArgumentCaptor>>>>> blobDataCaptor = + ArgumentCaptor.forClass(List.class); + Mockito.verify(flushService, Mockito.atLeast(2)) + .buildAndUpload(Mockito.any(), blobDataCaptor.capture()); + + // 1. list => blobs; 2. list => chunks; 3. list => channels; 4. list => rows, 5. list => columns + List>>>>> allUploadedBlobs = + blobDataCaptor.getAllValues(); + + Assert.assertEquals(102, getRows(allUploadedBlobs).size()); + } + + private List> getRows(List>>>>> blobs) { + List> result = new ArrayList<>(); + blobs.forEach( + chunks -> + chunks.forEach( + channels -> + channels.forEach( + chunkData -> + result.addAll(((ParquetChunkData) chunkData.getVectors()).rows)))); + return result; + } + @Test public void testBuildAndUpload() throws Exception { long expectedBuildLatencyMs = 100; diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/ManyTablesIT.java b/src/test/java/net/snowflake/ingest/streaming/internal/ManyTablesIT.java new file mode 100644 index 000000000..d32adfe3f --- /dev/null +++ b/src/test/java/net/snowflake/ingest/streaming/internal/ManyTablesIT.java @@ -0,0 +1,98 @@ +package net.snowflake.ingest.streaming.internal; + +import static net.snowflake.ingest.utils.Constants.ROLE; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Collections; +import java.util.Map; +import java.util.Properties; +import net.snowflake.ingest.TestUtils; +import net.snowflake.ingest.streaming.OpenChannelRequest; +import net.snowflake.ingest.streaming.SnowflakeStreamingIngestChannel; +import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient; +import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClientFactory; +import net.snowflake.ingest.utils.Constants; +import net.snowflake.ingest.utils.ParameterProvider; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Verified that ingestion work when we ingest into large number of tables from the same client and + * blobs and registration requests have to be cut, so they don't contain large number of chunks + */ +public class ManyTablesIT { + + private static final int TABLES_COUNT = 20; + private static final int TOTAL_ROWS_COUNT = 200_000; + private String dbName; + private SnowflakeStreamingIngestClient client; + private Connection connection; + private SnowflakeStreamingIngestChannel[] channels; + private String[] offsetTokensPerChannel; + + @Before + public void setUp() throws Exception { + Properties props = TestUtils.getProperties(Constants.BdecVersion.THREE, false); + props.put(ParameterProvider.MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST, 2); + if (props.getProperty(ROLE).equals("DEFAULT_ROLE")) { + props.setProperty(ROLE, "ACCOUNTADMIN"); + } + client = SnowflakeStreamingIngestClientFactory.builder("client1").setProperties(props).build(); + connection = TestUtils.getConnection(true); + dbName = String.format("sdk_it_many_tables_db_%d", System.nanoTime()); + + channels = new SnowflakeStreamingIngestChannel[TABLES_COUNT]; + offsetTokensPerChannel = new String[TABLES_COUNT]; + connection.createStatement().execute(String.format("create database %s;", dbName)); + + String[] tableNames = new String[TABLES_COUNT]; + for (int i = 0; i < tableNames.length; i++) { + tableNames[i] = String.format("table_%d", i); + connection.createStatement().execute(String.format("create table table_%d(c int);", i)); + channels[i] = + client.openChannel( + OpenChannelRequest.builder(String.format("channel-%d", i)) + .setDBName(dbName) + .setSchemaName("public") + .setTableName(tableNames[i]) + .setOnErrorOption(OpenChannelRequest.OnErrorOption.ABORT) + .build()); + } + } + + @After + public void tearDown() throws Exception { + connection.createStatement().execute(String.format("drop database %s;", dbName)); + client.close(); + connection.close(); + } + + @Test + public void testIngestionIntoManyTables() throws InterruptedException, SQLException { + for (int i = 0; i < TOTAL_ROWS_COUNT; i++) { + Map row = Collections.singletonMap("c", i); + String offset = String.valueOf(i); + int channelId = i % channels.length; + channels[channelId].insertRow(row, offset); + offsetTokensPerChannel[channelId] = offset; + } + + for (int i = 0; i < channels.length; i++) { + TestUtils.waitForOffset(channels[i], offsetTokensPerChannel[i]); + } + + int totalRowsCount = 0; + ResultSet rs = + connection + .createStatement() + .executeQuery(String.format("show tables in database %s;", dbName)); + while (rs.next()) { + totalRowsCount += rs.getInt("rows"); + } + Assert.assertEquals(TOTAL_ROWS_COUNT, totalRowsCount); + } +} diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/ParameterProviderTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/ParameterProviderTest.java index def5f7ecf..7838763de 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/ParameterProviderTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/ParameterProviderTest.java @@ -272,4 +272,13 @@ public void testMaxClientLagEnabledThresholdAbove() { Assert.assertTrue(e.getMessage().startsWith("Lag falls outside")); } } + + @Test + public void testMaxChunksInBlobAndRegistrationRequest() { + Properties prop = new Properties(); + Map parameterMap = getStartingParameterMap(); + parameterMap.put("max_chunks_in_blob_and_registration_request", 1); + ParameterProvider parameterProvider = new ParameterProvider(parameterMap, prop); + Assert.assertEquals(1, parameterProvider.getMaxChunksInBlobAndRegistrationRequest()); + } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java index a99054c9f..11fc0b93b 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java @@ -11,6 +11,8 @@ import static net.snowflake.ingest.utils.Constants.ROLE; import static net.snowflake.ingest.utils.Constants.USER; import static net.snowflake.ingest.utils.ParameterProvider.ENABLE_SNOWPIPE_STREAMING_METRICS; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.when; import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.PropertyAccessor; @@ -319,11 +321,11 @@ public void testGetChannelsStatusWithRequest() throws Exception { CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); StatusLine statusLine = Mockito.mock(StatusLine.class); HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - Mockito.when(statusLine.getStatusCode()).thenReturn(200); - Mockito.when(httpResponse.getStatusLine()).thenReturn(statusLine); - Mockito.when(httpResponse.getEntity()).thenReturn(httpEntity); - Mockito.when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(responseString)); - Mockito.when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); + when(statusLine.getStatusCode()).thenReturn(200); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + when(httpResponse.getEntity()).thenReturn(httpEntity); + when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(responseString)); + when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); RequestBuilder requestBuilder = Mockito.spy( @@ -378,11 +380,11 @@ public void testGetChannelsStatusWithRequestError() throws Exception { CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); StatusLine statusLine = Mockito.mock(StatusLine.class); HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - Mockito.when(statusLine.getStatusCode()).thenReturn(500); - Mockito.when(httpResponse.getStatusLine()).thenReturn(statusLine); - Mockito.when(httpResponse.getEntity()).thenReturn(httpEntity); - Mockito.when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(responseString)); - Mockito.when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); + when(statusLine.getStatusCode()).thenReturn(500); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + when(httpResponse.getEntity()).thenReturn(httpEntity); + when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(responseString)); + when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); RequestBuilder requestBuilder = Mockito.spy( @@ -645,12 +647,12 @@ public void testRegisterBlobErrorResponse() throws Exception { CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); StatusLine statusLine = Mockito.mock(StatusLine.class); HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - Mockito.when(statusLine.getStatusCode()).thenReturn(500); - Mockito.when(httpResponse.getStatusLine()).thenReturn(statusLine); - Mockito.when(httpResponse.getEntity()).thenReturn(httpEntity); + when(statusLine.getStatusCode()).thenReturn(500); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + when(httpResponse.getEntity()).thenReturn(httpEntity); String response = "testRegisterBlobErrorResponse"; - Mockito.when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); - Mockito.when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); + when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); + when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); RequestBuilder requestBuilder = new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); @@ -693,11 +695,11 @@ public void testRegisterBlobSnowflakeInternalErrorResponse() throws Exception { CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); StatusLine statusLine = Mockito.mock(StatusLine.class); HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - Mockito.when(statusLine.getStatusCode()).thenReturn(200); - Mockito.when(httpResponse.getStatusLine()).thenReturn(statusLine); - Mockito.when(httpResponse.getEntity()).thenReturn(httpEntity); - Mockito.when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); - Mockito.when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); + when(statusLine.getStatusCode()).thenReturn(200); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + when(httpResponse.getEntity()).thenReturn(httpEntity); + when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); + when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); RequestBuilder requestBuilder = new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); @@ -749,11 +751,11 @@ public void testRegisterBlobSuccessResponse() throws Exception { CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); StatusLine statusLine = Mockito.mock(StatusLine.class); HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - Mockito.when(statusLine.getStatusCode()).thenReturn(200); - Mockito.when(httpResponse.getStatusLine()).thenReturn(statusLine); - Mockito.when(httpResponse.getEntity()).thenReturn(httpEntity); - Mockito.when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); - Mockito.when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); + when(statusLine.getStatusCode()).thenReturn(200); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + when(httpResponse.getEntity()).thenReturn(httpEntity); + when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); + when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); RequestBuilder requestBuilder = new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); @@ -827,16 +829,16 @@ public void testRegisterBlobsRetries() throws Exception { CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); StatusLine statusLine = Mockito.mock(StatusLine.class); HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - Mockito.when(statusLine.getStatusCode()).thenReturn(200); - Mockito.when(httpResponse.getStatusLine()).thenReturn(statusLine); - Mockito.when(httpResponse.getEntity()).thenReturn(httpEntity); - Mockito.when(httpEntity.getContent()) + when(statusLine.getStatusCode()).thenReturn(200); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + when(httpResponse.getEntity()).thenReturn(httpEntity); + when(httpEntity.getContent()) .thenReturn( IOUtils.toInputStream(responseString), IOUtils.toInputStream(retryResponseString), IOUtils.toInputStream(retryResponseString), IOUtils.toInputStream(retryResponseString)); - Mockito.when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); + when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); RequestBuilder requestBuilder = Mockito.spy( @@ -862,6 +864,92 @@ public void testRegisterBlobsRetries() throws Exception { Assert.assertFalse(channel2.isValid()); } + @Test + public void testRegisterBlobChunkLimit() throws Exception { + CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); + RequestBuilder requestBuilder = + Mockito.spy( + new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair())); + + SnowflakeStreamingIngestClientInternal client = + Mockito.spy( + new SnowflakeStreamingIngestClientInternal<>( + "client", + new SnowflakeURL("snowflake.dev.local:8082"), + null, + httpClient, + true, + requestBuilder, + null)); + + assertEquals(0, client.partitionBlobListForRegistrationRequest(new ArrayList<>()).size()); + assertEquals( + 1, client.partitionBlobListForRegistrationRequest(createTestBlobMetadata(1)).size()); + assertEquals( + 1, client.partitionBlobListForRegistrationRequest(createTestBlobMetadata(99)).size()); + assertEquals( + 1, client.partitionBlobListForRegistrationRequest(createTestBlobMetadata(100)).size()); + + assertEquals( + 1, client.partitionBlobListForRegistrationRequest(createTestBlobMetadata(3, 95, 2)).size()); + assertEquals( + 2, + client.partitionBlobListForRegistrationRequest(createTestBlobMetadata(3, 95, 2, 1)).size()); + assertEquals( + 3, + client + .partitionBlobListForRegistrationRequest(createTestBlobMetadata(3, 95, 2, 1, 100)) + .size()); + assertEquals( + 2, client.partitionBlobListForRegistrationRequest(createTestBlobMetadata(99, 2)).size()); + assertEquals( + 2, + client + .partitionBlobListForRegistrationRequest(createTestBlobMetadata(55, 44, 2, 98)) + .size()); + assertEquals( + 3, + client + .partitionBlobListForRegistrationRequest(createTestBlobMetadata(55, 44, 2, 99)) + .size()); + assertEquals( + 3, + client + .partitionBlobListForRegistrationRequest(createTestBlobMetadata(55, 44, 2, 99, 1)) + .size()); + } + + /** + * Generate blob metadata with specified number of chunks per blob + * + * @param numbersOfChunks Array of chunk numbers per blob + * @return List of blob metadata + */ + private List createTestBlobMetadata(int... numbersOfChunks) { + List result = new ArrayList<>(); + for (int n : numbersOfChunks) { + List chunkMetadata = new ArrayList<>(); + for (int i = 0; i < n; i++) { + ChunkMetadata chunk = + ChunkMetadata.builder() + .setOwningTableFromChannelContext(channel1.getChannelContext()) + .setChunkStartOffset(0L) + .setChunkLength(1) + .setEncryptionKeyId(0L) + .setChunkMD5("") + .setEpInfo(new EpInfo()) + .setChannelList(new ArrayList<>()) + .setFirstInsertTimeInMs(0L) + .setLastInsertTimeInMs(0L) + .build(); + chunkMetadata.add(chunk); + } + + result.add(new BlobMetadata("", "", chunkMetadata, new BlobStats())); + } + return result; + } + @Test public void testRegisterBlobsRetriesSucceeds() throws Exception { Pair, Set> testData = getRetryBlobMetadata(); @@ -945,13 +1033,13 @@ public void testRegisterBlobsRetriesSucceeds() throws Exception { CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); StatusLine statusLine = Mockito.mock(StatusLine.class); HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - Mockito.when(statusLine.getStatusCode()).thenReturn(200); - Mockito.when(httpResponse.getStatusLine()).thenReturn(statusLine); - Mockito.when(httpResponse.getEntity()).thenReturn(httpEntity); - Mockito.when(httpEntity.getContent()) + when(statusLine.getStatusCode()).thenReturn(200); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + when(httpResponse.getEntity()).thenReturn(httpEntity); + when(httpEntity.getContent()) .thenReturn( IOUtils.toInputStream(responseString), IOUtils.toInputStream(retryResponseString)); - Mockito.when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); + when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); RequestBuilder requestBuilder = Mockito.spy( @@ -1022,11 +1110,11 @@ public void testRegisterBlobResponseWithInvalidChannel() throws Exception { CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); StatusLine statusLine = Mockito.mock(StatusLine.class); HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - Mockito.when(statusLine.getStatusCode()).thenReturn(200); - Mockito.when(httpResponse.getStatusLine()).thenReturn(statusLine); - Mockito.when(httpResponse.getEntity()).thenReturn(httpEntity); - Mockito.when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); - Mockito.when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); + when(statusLine.getStatusCode()).thenReturn(200); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + when(httpResponse.getEntity()).thenReturn(httpEntity); + when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); + when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); RequestBuilder requestBuilder = new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); @@ -1146,7 +1234,7 @@ public void testCloseWithError() throws Exception { CompletableFuture future = new CompletableFuture<>(); future.completeExceptionally(new Exception("Simulating Error")); - Mockito.when(client.flush(true)).thenReturn(future); + when(client.flush(true)).thenReturn(future); Assert.assertFalse(client.isClosed()); try { @@ -1249,11 +1337,11 @@ public void testGetLatestCommittedOffsetTokens() throws Exception { CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); StatusLine statusLine = Mockito.mock(StatusLine.class); HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - Mockito.when(statusLine.getStatusCode()).thenReturn(200); - Mockito.when(httpResponse.getStatusLine()).thenReturn(statusLine); - Mockito.when(httpResponse.getEntity()).thenReturn(httpEntity); - Mockito.when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(responseString)); - Mockito.when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); + when(statusLine.getStatusCode()).thenReturn(200); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + when(httpResponse.getEntity()).thenReturn(httpEntity); + when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(responseString)); + when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); RequestBuilder requestBuilder = Mockito.spy(