diff --git a/pom.xml b/pom.xml index 4e6907047..4c6665eec 100644 --- a/pom.xml +++ b/pom.xml @@ -358,6 +358,18 @@ 3.7.7 test + + org.openjdk.jmh + jmh-core + 1.34 + test + + + org.openjdk.jmh + jmh-generator-annprocess + 1.34 + test + @@ -470,6 +482,13 @@ org.apache.parquet parquet-common + + + + javax.annotation + javax.annotation-api + + org.apache.parquet @@ -527,6 +546,16 @@ mockito-core test + + org.openjdk.jmh + jmh-core + test + + + org.openjdk.jmh + jmh-generator-annprocess + test + org.powermock powermock-api-mockito2 @@ -723,8 +752,8 @@ true + to workaround https://issues.apache.org/jira/browse/MNG-7982. Now the dependency analyzer complains that + the dependency is unused, so we ignore it here--> org.apache.commons:commons-compress org.apache.commons:commons-configuration2 @@ -818,10 +847,8 @@ 2.0.1 failFast - + Apache License 2.0 BSD 2-Clause License @@ -1133,10 +1160,8 @@ - + org.codehaus.mojo exec-maven-plugin diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java b/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java index 6d5dce17f..71a9d501e 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java @@ -16,7 +16,6 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; -import java.util.stream.Collectors; import net.snowflake.ingest.connection.TelemetryService; import net.snowflake.ingest.streaming.InsertValidationResponse; import net.snowflake.ingest.streaming.OffsetTokenVerificationFunction; @@ -400,10 +399,10 @@ public float getSize() { Set verifyInputColumns( Map row, InsertValidationResponse.InsertError error, int rowIndex) { // Map of unquoted column name -> original column name - Map inputColNamesMap = - row.keySet().stream() - .collect(Collectors.toMap(LiteralQuoteUtils::unquoteColumnName, value -> value)); - + Set originalKeys = row.keySet(); + Map inputColNamesMap = new HashMap<>(); + originalKeys.forEach( + key -> inputColNamesMap.put(LiteralQuoteUtils.unquoteColumnName(key), key)); // Check for extra columns in the row List extraCols = new ArrayList<>(); for (String columnName : inputColNamesMap.keySet()) { diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java b/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java index 814423c28..162e56145 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java @@ -86,6 +86,18 @@ class DataValidationUtil { objectMapper.registerModule(module); } + // Caching the powers of 10 that are used for checking the range of numbers because computing them + // on-demand is expensive. + private static final BigDecimal[] POWER_10 = makePower10Table(); + + private static BigDecimal[] makePower10Table() { + BigDecimal[] power10 = new BigDecimal[Power10.sb16Size]; + for (int i = 0; i < Power10.sb16Size; i++) { + power10[i] = new BigDecimal(Power10.sb16Table[i]); + } + return power10; + } + /** * Validates and parses input as JSON. All types in the object tree must be valid variant types, * see {@link DataValidationUtil#isAllowedSemiStructuredType}. @@ -823,7 +835,11 @@ static int validateAndParseBoolean(String columnName, Object input, long insertR static void checkValueInRange( BigDecimal bigDecimalValue, int scale, int precision, final long insertRowIndex) { - if (bigDecimalValue.abs().compareTo(BigDecimal.TEN.pow(precision - scale)) >= 0) { + BigDecimal comparand = + (precision >= scale) && (precision - scale) < POWER_10.length + ? POWER_10[precision - scale] + : BigDecimal.TEN.pow(precision - scale); + if (bigDecimalValue.abs().compareTo(comparand) >= 0) { throw new SFException( ErrorCode.INVALID_FORMAT_ROW, String.format( diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java index f08196477..76e43ff4d 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java @@ -122,6 +122,7 @@ List>> getData() { // blob encoding version private final Constants.BdecVersion bdecVersion; + private volatile int numProcessors = Runtime.getRuntime().availableProcessors(); /** * Constructor for TESTING that takes (usually mocked) StreamingIngestStage @@ -360,6 +361,9 @@ void distributeFlushTasks() { List, CompletableFuture>> blobs = new ArrayList<>(); List> leftoverChannelsDataPerTable = new ArrayList<>(); + // The API states that the number of available processors reported can change and therefore, we + // should poll it occasionally. + numProcessors = Runtime.getRuntime().availableProcessors(); while (itr.hasNext() || !leftoverChannelsDataPerTable.isEmpty()) { List>> blobData = new ArrayList<>(); float totalBufferSizeInBytes = 0F; @@ -704,8 +708,7 @@ String getClientPrefix() { */ boolean throttleDueToQueuedFlushTasks() { ThreadPoolExecutor buildAndUpload = (ThreadPoolExecutor) this.buildUploadWorkers; - boolean throttleOnQueuedTasks = - buildAndUpload.getQueue().size() > Runtime.getRuntime().availableProcessors(); + boolean throttleOnQueuedTasks = buildAndUpload.getQueue().size() > numProcessors; if (throttleOnQueuedTasks) { logger.logWarn( "Throttled due too many queue flush tasks (probably because of slow uploading speed)," diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProvider.java b/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProvider.java index f426e898d..777ae4fdc 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProvider.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProvider.java @@ -9,9 +9,6 @@ public interface MemoryInfoProvider { /** @return Max memory the JVM can allocate */ long getMaxMemory(); - /** @return Total allocated JVM memory so far */ - long getTotalMemory(); - /** @return Free JVM memory */ long getFreeMemory(); } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProviderFromRuntime.java b/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProviderFromRuntime.java index 3a957f225..d248ddfd9 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProviderFromRuntime.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProviderFromRuntime.java @@ -4,20 +4,51 @@ package net.snowflake.ingest.streaming.internal; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + /** Reads memory information from JVM runtime */ public class MemoryInfoProviderFromRuntime implements MemoryInfoProvider { - @Override - public long getMaxMemory() { - return Runtime.getRuntime().maxMemory(); + private final long maxMemory; + private volatile long totalFreeMemory; + private final ScheduledExecutorService executorService; + private static final long FREE_MEMORY_UPDATE_INTERVAL_MS = 100; + private static final MemoryInfoProviderFromRuntime INSTANCE = + new MemoryInfoProviderFromRuntime(FREE_MEMORY_UPDATE_INTERVAL_MS); + + private MemoryInfoProviderFromRuntime(long freeMemoryUpdateIntervalMs) { + maxMemory = Runtime.getRuntime().maxMemory(); + totalFreeMemory = + Runtime.getRuntime().freeMemory() + (maxMemory - Runtime.getRuntime().totalMemory()); + executorService = + new ScheduledThreadPoolExecutor( + 1, + r -> { + Thread th = new Thread(r, "MemoryInfoProviderFromRuntime"); + th.setDaemon(true); + return th; + }); + executorService.scheduleAtFixedRate( + this::updateFreeMemory, 0, freeMemoryUpdateIntervalMs, TimeUnit.MILLISECONDS); + } + + private void updateFreeMemory() { + totalFreeMemory = + Runtime.getRuntime().freeMemory() + (maxMemory - Runtime.getRuntime().totalMemory()); + } + + public static MemoryInfoProviderFromRuntime getInstance() { + return INSTANCE; } @Override - public long getTotalMemory() { - return Runtime.getRuntime().totalMemory(); + public long getMaxMemory() { + return maxMemory; } @Override public long getFreeMemory() { - return Runtime.getRuntime().freeMemory(); + return totalFreeMemory; } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java index 39ec66dbb..3ad3db5f4 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java @@ -9,6 +9,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.Logging; @@ -124,6 +125,12 @@ private SerializationResult serializeFromParquetWriteBuffers( if (mergedChannelWriter != null) { mergedChannelWriter.close(); + this.verifyRowCounts( + "serializeFromParquetWriteBuffers", + mergedChannelWriter, + rowCount, + channelsDataPerTable, + -1); } return new SerializationResult( channelsMetadataList, @@ -216,6 +223,9 @@ private SerializationResult serializeFromJavaObjects( rows.forEach(parquetWriter::writeRow); parquetWriter.close(); + this.verifyRowCounts( + "serializeFromJavaObjects", parquetWriter, rowCount, channelsDataPerTable, rows.size()); + return new SerializationResult( channelsMetadataList, columnEpStatsMapCombined, @@ -224,4 +234,71 @@ private SerializationResult serializeFromJavaObjects( mergedData, chunkMinMaxInsertTimeInMs); } + + /** + * Validates that rows count in metadata matches the row count in Parquet footer and the row count + * written by the parquet writer + * + * @param serializationType Serialization type, used for logging purposes only + * @param writer Parquet writer writing the data + * @param channelsDataPerTable Channel data + * @param totalMetadataRowCount Row count calculated during metadata collection + * @param javaSerializationTotalRowCount Total row count when java object serialization is used. + * Used only for logging purposes if there is a mismatch. + */ + private void verifyRowCounts( + String serializationType, + BdecParquetWriter writer, + long totalMetadataRowCount, + List> channelsDataPerTable, + long javaSerializationTotalRowCount) { + long parquetTotalRowsWritten = writer.getRowsWritten(); + + List parquetFooterRowsPerBlock = writer.getRowCountsFromFooter(); + long parquetTotalRowsInFooter = 0; + for (long perBlockCount : parquetFooterRowsPerBlock) { + parquetTotalRowsInFooter += perBlockCount; + } + + if (parquetTotalRowsInFooter != totalMetadataRowCount + || parquetTotalRowsWritten != totalMetadataRowCount) { + + final String perChannelRowCountsInMetadata = + channelsDataPerTable.stream() + .map(x -> String.valueOf(x.getRowCount())) + .collect(Collectors.joining(",")); + + final String channelNames = + channelsDataPerTable.stream() + .map(x -> String.valueOf(x.getChannelContext().getName())) + .collect(Collectors.joining(",")); + + final String perBlockRowCountsInFooter = + parquetFooterRowsPerBlock.stream().map(String::valueOf).collect(Collectors.joining(",")); + + final long channelsCountInMetadata = channelsDataPerTable.size(); + + throw new SFException( + ErrorCode.INTERNAL_ERROR, + String.format( + "[%s]The number of rows in Parquet does not match the number of rows in metadata. " + + "parquetTotalRowsInFooter=%d " + + "totalMetadataRowCount=%d " + + "parquetTotalRowsWritten=%d " + + "perChannelRowCountsInMetadata=%s " + + "perBlockRowCountsInFooter=%s " + + "channelsCountInMetadata=%d " + + "countOfSerializedJavaObjects=%d " + + "channelNames=%s", + serializationType, + parquetTotalRowsInFooter, + totalMetadataRowCount, + parquetTotalRowsWritten, + perChannelRowCountsInMetadata, + perBlockRowCountsInFooter, + channelsCountInMetadata, + javaSerializationTotalRowCount, + channelNames)); + } + } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java index 58e81d116..8ebc23ca1 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java @@ -45,6 +45,10 @@ class SnowflakeStreamingIngestChannelInternal implements SnowflakeStreamingIn // Reference to the row buffer private final RowBuffer rowBuffer; + private final long insertThrottleIntervalInMs; + private final int insertThrottleThresholdInBytes; + private final int insertThrottleThresholdInPercentage; + private final long maxMemoryLimitInBytes; // Indicates whether the channel is closed private volatile boolean isClosed; @@ -61,6 +65,9 @@ class SnowflakeStreamingIngestChannelInternal implements SnowflakeStreamingIn // The latest cause of channel invalidation private String invalidationCause; + private final MemoryInfoProvider memoryInfoProvider; + private volatile long freeMemoryInBytes = 0; + /** * Constructor for TESTING ONLY which allows us to set the test mode * @@ -121,6 +128,17 @@ class SnowflakeStreamingIngestChannelInternal implements SnowflakeStreamingIn OffsetTokenVerificationFunction offsetTokenVerificationFunction) { this.isClosed = false; this.owningClient = client; + + this.insertThrottleIntervalInMs = + this.owningClient.getParameterProvider().getInsertThrottleIntervalInMs(); + this.insertThrottleThresholdInBytes = + this.owningClient.getParameterProvider().getInsertThrottleThresholdInBytes(); + this.insertThrottleThresholdInPercentage = + this.owningClient.getParameterProvider().getInsertThrottleThresholdInPercentage(); + this.maxMemoryLimitInBytes = + this.owningClient.getParameterProvider().getMaxMemoryLimitInBytes(); + + this.memoryInfoProvider = MemoryInfoProviderFromRuntime.getInstance(); this.channelFlushContext = new ChannelFlushContext( name, dbName, schemaName, tableName, channelSequencer, encryptionKey, encryptionKeyId); @@ -373,7 +391,7 @@ public InsertValidationResponse insertRows( Iterable> rows, @Nullable String startOffsetToken, @Nullable String endOffsetToken) { - throttleInsertIfNeeded(new MemoryInfoProviderFromRuntime()); + throttleInsertIfNeeded(memoryInfoProvider); checkValidation(); if (isClosed()) { @@ -448,8 +466,6 @@ public Map getTableSchema() { /** Check whether we need to throttle the insertRows API */ void throttleInsertIfNeeded(MemoryInfoProvider memoryInfoProvider) { int retry = 0; - long insertThrottleIntervalInMs = - this.owningClient.getParameterProvider().getInsertThrottleIntervalInMs(); while ((hasLowRuntimeMemory(memoryInfoProvider) || (this.owningClient.getFlushService() != null && this.owningClient.getFlushService().throttleDueToQueuedFlushTasks())) @@ -473,22 +489,14 @@ void throttleInsertIfNeeded(MemoryInfoProvider memoryInfoProvider) { /** Check whether we have a low runtime memory condition */ private boolean hasLowRuntimeMemory(MemoryInfoProvider memoryInfoProvider) { - int insertThrottleThresholdInBytes = - this.owningClient.getParameterProvider().getInsertThrottleThresholdInBytes(); - int insertThrottleThresholdInPercentage = - this.owningClient.getParameterProvider().getInsertThrottleThresholdInPercentage(); - long maxMemoryLimitInBytes = - this.owningClient.getParameterProvider().getMaxMemoryLimitInBytes(); long maxMemory = maxMemoryLimitInBytes == MAX_MEMORY_LIMIT_IN_BYTES_DEFAULT ? memoryInfoProvider.getMaxMemory() : maxMemoryLimitInBytes; - long freeMemory = - memoryInfoProvider.getFreeMemory() - + (memoryInfoProvider.getMaxMemory() - memoryInfoProvider.getTotalMemory()); + freeMemoryInBytes = memoryInfoProvider.getFreeMemory(); boolean hasLowRuntimeMemory = - freeMemory < insertThrottleThresholdInBytes - && freeMemory * 100 / maxMemory < insertThrottleThresholdInPercentage; + freeMemoryInBytes < insertThrottleThresholdInBytes + && freeMemoryInBytes * 100 / maxMemory < insertThrottleThresholdInPercentage; if (hasLowRuntimeMemory) { logger.logWarn( "Throttled due to memory pressure, client={}, channel={}.", diff --git a/src/main/java/net/snowflake/ingest/utils/HttpUtil.java b/src/main/java/net/snowflake/ingest/utils/HttpUtil.java index 1ff65a095..1be382797 100644 --- a/src/main/java/net/snowflake/ingest/utils/HttpUtil.java +++ b/src/main/java/net/snowflake/ingest/utils/HttpUtil.java @@ -294,7 +294,8 @@ static HttpRequestRetryHandler getHttpRequestRetryHandler() { if (exception instanceof NoHttpResponseException || exception instanceof javax.net.ssl.SSLException || exception instanceof java.net.SocketException - || exception instanceof java.net.UnknownHostException) { + || exception instanceof java.net.UnknownHostException + || exception instanceof java.net.SocketTimeoutException) { LOGGER.info( "Retrying request which caused {} with " + "URI:{}, retryCount:{} and maxRetryCount:{}", exception.getClass().getName(), diff --git a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java index 8b71cfd0e..58e7df4f3 100644 --- a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java +++ b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java @@ -6,6 +6,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; import net.snowflake.ingest.utils.Constants; @@ -17,6 +18,7 @@ import org.apache.parquet.column.values.factory.DefaultV1ValuesWriterFactory; import org.apache.parquet.crypto.FileEncryptionProperties; import org.apache.parquet.hadoop.api.WriteSupport; +import org.apache.parquet.hadoop.metadata.BlockMetaData; import org.apache.parquet.io.DelegatingPositionOutputStream; import org.apache.parquet.io.OutputFile; import org.apache.parquet.io.ParquetEncodingException; @@ -35,6 +37,7 @@ public class BdecParquetWriter implements AutoCloseable { private final InternalParquetRecordWriter> writer; private final CodecFactory codecFactory; + private long rowsWritten = 0; /** * Creates a BDEC specific parquet writer. @@ -100,14 +103,28 @@ public BdecParquetWriter( encodingProps); } + /** @return List of row counts per block stored in the parquet footer */ + public List getRowCountsFromFooter() { + final List blockRowCounts = new ArrayList<>(); + for (BlockMetaData metadata : writer.getFooter().getBlocks()) { + blockRowCounts.add(metadata.getRowCount()); + } + return blockRowCounts; + } + public void writeRow(List row) { try { writer.write(row); + rowsWritten++; } catch (InterruptedException | IOException e) { throw new SFException(ErrorCode.INTERNAL_ERROR, "parquet row write failed", e); } } + public long getRowsWritten() { + return rowsWritten; + } + @Override public void close() throws IOException { try { diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java new file mode 100644 index 000000000..e220aec79 --- /dev/null +++ b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java @@ -0,0 +1,133 @@ +package net.snowflake.ingest.streaming.internal; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import net.snowflake.ingest.utils.Constants; +import net.snowflake.ingest.utils.ErrorCode; +import net.snowflake.ingest.utils.Pair; +import net.snowflake.ingest.utils.SFException; +import org.apache.parquet.hadoop.BdecParquetWriter; +import org.apache.parquet.schema.MessageType; +import org.junit.Assert; +import org.junit.Test; +import org.mockito.Mockito; + +public class BlobBuilderTest { + + @Test + public void testSerializationErrors() throws Exception { + // Construction succeeds if both data and metadata contain 1 row + BlobBuilder.constructBlobAndMetadata( + "a.bdec", + Collections.singletonList(createChannelDataPerTable(1, false)), + Constants.BdecVersion.THREE); + BlobBuilder.constructBlobAndMetadata( + "a.bdec", + Collections.singletonList(createChannelDataPerTable(1, true)), + Constants.BdecVersion.THREE); + + // Construction fails if metadata contains 0 rows and data 1 row + try { + BlobBuilder.constructBlobAndMetadata( + "a.bdec", + Collections.singletonList(createChannelDataPerTable(0, false)), + Constants.BdecVersion.THREE); + Assert.fail("Should not pass enableParquetInternalBuffering=false"); + } catch (SFException e) { + Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode()); + Assert.assertTrue(e.getMessage().contains("serializeFromJavaObjects")); + Assert.assertTrue(e.getMessage().contains("parquetTotalRowsInFooter=1")); + Assert.assertTrue(e.getMessage().contains("totalMetadataRowCount=0")); + Assert.assertTrue(e.getMessage().contains("parquetTotalRowsWritten=1")); + Assert.assertTrue(e.getMessage().contains("perChannelRowCountsInMetadata=0")); + Assert.assertTrue(e.getMessage().contains("perBlockRowCountsInFooter=1")); + Assert.assertTrue(e.getMessage().contains("channelsCountInMetadata=1")); + Assert.assertTrue(e.getMessage().contains("countOfSerializedJavaObjects=1")); + } + + try { + BlobBuilder.constructBlobAndMetadata( + "a.bdec", + Collections.singletonList(createChannelDataPerTable(0, true)), + Constants.BdecVersion.THREE); + Assert.fail("Should not pass enableParquetInternalBuffering=true"); + } catch (SFException e) { + Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode()); + Assert.assertTrue(e.getMessage().contains("serializeFromParquetWriteBuffers")); + Assert.assertTrue(e.getMessage().contains("parquetTotalRowsInFooter=1")); + Assert.assertTrue(e.getMessage().contains("totalMetadataRowCount=0")); + Assert.assertTrue(e.getMessage().contains("parquetTotalRowsWritten=1")); + Assert.assertTrue(e.getMessage().contains("perChannelRowCountsInMetadata=0")); + Assert.assertTrue(e.getMessage().contains("perBlockRowCountsInFooter=1")); + Assert.assertTrue(e.getMessage().contains("channelsCountInMetadata=1")); + Assert.assertTrue(e.getMessage().contains("countOfSerializedJavaObjects=-1")); + } + } + + /** + * Creates a channel data configurable number of rows in metadata and 1 physical row (using both + * with and without internal buffering optimization) + */ + private List> createChannelDataPerTable( + int metadataRowCount, boolean enableParquetInternalBuffering) throws IOException { + String columnName = "C1"; + ChannelData channelData = Mockito.spy(new ChannelData<>()); + MessageType schema = createSchema(columnName); + Mockito.doReturn( + new ParquetFlusher( + schema, + enableParquetInternalBuffering, + 100L, + Constants.BdecParquetCompression.GZIP)) + .when(channelData) + .createFlusher(); + + channelData.setRowSequencer(1L); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + BdecParquetWriter bdecParquetWriter = + new BdecParquetWriter( + stream, + schema, + new HashMap<>(), + "CHANNEL", + 1000, + Constants.BdecParquetCompression.GZIP); + bdecParquetWriter.writeRow(Collections.singletonList("1")); + channelData.setVectors( + new ParquetChunkData( + Collections.singletonList(Collections.singletonList("A")), + bdecParquetWriter, + stream, + new HashMap<>())); + channelData.setColumnEps(new HashMap<>()); + channelData.setRowCount(metadataRowCount); + channelData.setMinMaxInsertTimeInMs(new Pair<>(2L, 3L)); + + channelData.getColumnEps().putIfAbsent(columnName, new RowBufferStats(columnName, null, 1)); + channelData.setChannelContext( + new ChannelFlushContext("channel1", "DB", "SCHEMA", "TABLE", 1L, "enc", 1L)); + return Collections.singletonList(channelData); + } + + private static MessageType createSchema(String columnName) { + ParquetTypeGenerator.ParquetTypeInfo c1 = + ParquetTypeGenerator.generateColumnParquetTypeInfo(createTestTextColumn(columnName), 1); + return new MessageType("bdec", Collections.singletonList(c1.getParquetType())); + } + + private static ColumnMetadata createTestTextColumn(String name) { + ColumnMetadata colChar = new ColumnMetadata(); + colChar.setOrdinal(1); + colChar.setName(name); + colChar.setPhysicalType("LOB"); + colChar.setNullable(true); + colChar.setLogicalType("TEXT"); + colChar.setByteLength(14); + colChar.setLength(11); + colChar.setScale(0); + return colChar; + } +} diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java new file mode 100644 index 000000000..5b28e9c45 --- /dev/null +++ b/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java @@ -0,0 +1,122 @@ +package net.snowflake.ingest.streaming.internal; + +import static java.time.ZoneOffset.UTC; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import net.snowflake.ingest.streaming.InsertValidationResponse; +import net.snowflake.ingest.streaming.OpenChannelRequest; +import net.snowflake.ingest.utils.Utils; +import org.junit.Assert; +import org.junit.Test; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.TimeValue; + +@State(Scope.Thread) +public class InsertRowsBenchmarkTest { + + private SnowflakeStreamingIngestChannelInternal channel; + private SnowflakeStreamingIngestClientInternal client; + + @Param({"100000"}) + private int numRows; + + @Setup(Level.Trial) + public void setUpBeforeAll() { + client = new SnowflakeStreamingIngestClientInternal("client_PARQUET"); + channel = + new SnowflakeStreamingIngestChannelInternal<>( + "channel", + "db", + "schema", + "table", + "0", + 0L, + 0L, + client, + "key", + 1234L, + OpenChannelRequest.OnErrorOption.CONTINUE, + UTC); + // Setup column fields and vectors + ColumnMetadata col = new ColumnMetadata(); + col.setOrdinal(1); + col.setName("COL"); + col.setPhysicalType("SB16"); + col.setNullable(false); + col.setLogicalType("FIXED"); + col.setPrecision(38); + col.setScale(0); + + channel.setupSchema(Collections.singletonList(col)); + assert Utils.getProvider() != null; + } + + @TearDown(Level.Trial) + public void tearDownAfterAll() throws Exception { + channel.close(); + client.close(); + } + + @Benchmark + public void testInsertRow() { + Map row = new HashMap<>(); + row.put("col", 1); + + for (int i = 0; i < numRows; i++) { + InsertValidationResponse response = channel.insertRow(row, String.valueOf(i)); + Assert.assertFalse(response.hasErrors()); + } + } + + @Test + public void insertRow() throws Exception { + setUpBeforeAll(); + Map row = new HashMap<>(); + row.put("col", 1); + + for (int i = 0; i < 1000000; i++) { + InsertValidationResponse response = channel.insertRow(row, String.valueOf(i)); + Assert.assertFalse(response.hasErrors()); + } + tearDownAfterAll(); + } + + @Test + public void launchBenchmark() throws RunnerException { + Options opt = + new OptionsBuilder() + // Specify which benchmarks to run. + // You can be more specific if you'd like to run only one benchmark per test. + .include(this.getClass().getName() + ".*") + // Set the following options as needed + .mode(Mode.AverageTime) + .timeUnit(TimeUnit.MICROSECONDS) + .warmupTime(TimeValue.seconds(1)) + .warmupIterations(2) + .measurementTime(TimeValue.seconds(1)) + .measurementIterations(10) + .threads(2) + .forks(1) + .shouldFailOnError(true) + .shouldDoGC(true) + // .jvmArgs("-XX:+UnlockDiagnosticVMOptions", "-XX:+PrintInlining") + // .addProfiler(WinPerfAsmProfiler.class) + .build(); + + new Runner(opt).run(); + } +} diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java index 4ddc61ece..b4fa769a1 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java @@ -61,11 +61,6 @@ public long getMaxMemory() { return maxMemory; } - @Override - public long getTotalMemory() { - return maxMemory; - } - @Override public long getFreeMemory() { return freeMemory; diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java index 5b24bcc7f..553efbd31 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java @@ -80,9 +80,27 @@ public class SnowflakeStreamingIngestClientTest { SnowflakeStreamingIngestChannelInternal channel4; @Before - public void setup() { + public void setup() throws Exception { objectMapper.setVisibility(PropertyAccessor.GETTER, JsonAutoDetect.Visibility.ANY); objectMapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY); + Properties prop = new Properties(); + prop.put(USER, TestUtils.getUser()); + prop.put(ACCOUNT_URL, TestUtils.getHost()); + prop.put(PRIVATE_KEY, TestUtils.getPrivateKey()); + prop.put(ROLE, TestUtils.getRole()); + + CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); + RequestBuilder requestBuilder = + new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); + SnowflakeStreamingIngestClientInternal client = + new SnowflakeStreamingIngestClientInternal<>( + "client", + new SnowflakeURL("snowflake.dev.local:8082"), + null, + httpClient, + true, + requestBuilder, + null); channel1 = new SnowflakeStreamingIngestChannelInternal<>( "channel1", @@ -92,7 +110,7 @@ public void setup() { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -108,7 +126,7 @@ public void setup() { "0", 2L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -124,7 +142,7 @@ public void setup() { "0", 3L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -140,7 +158,7 @@ public void setup() { "0", 3L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -357,7 +375,7 @@ public void testGetChannelsStatusWithRequest() throws Exception { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -461,7 +479,7 @@ public void testGetChannelsStatusWithRequestError() throws Exception { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -494,6 +512,16 @@ public void testRegisterBlobRequestCreationSuccess() throws Exception { RequestBuilder requestBuilder = new RequestBuilder(url, prop.get(USER).toString(), keyPair, null, null); + CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); + SnowflakeStreamingIngestClientInternal client = + new SnowflakeStreamingIngestClientInternal<>( + "client", + new SnowflakeURL("snowflake.dev.local:8082"), + null, + httpClient, + true, + requestBuilder, + null); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", @@ -503,7 +531,7 @@ public void testRegisterBlobRequestCreationSuccess() throws Exception { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -1426,7 +1454,7 @@ public void testGetLatestCommittedOffsetTokens() throws Exception { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestBigFilesIT.java b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestBigFilesIT.java index ead26acd6..b7f9e6829 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestBigFilesIT.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/StreamingIngestBigFilesIT.java @@ -130,7 +130,7 @@ private void ingestRandomRowsToTable( boolean isNullable) throws ExecutionException, InterruptedException { - List> rows = new ArrayList<>(); + final List> rows = Collections.synchronizedList(new ArrayList<>()); for (int i = 0; i < batchSize; i++) { Random r = new Random(); rows.add(TestUtils.getRandomRow(r, isNullable)); @@ -138,7 +138,8 @@ private void ingestRandomRowsToTable( ExecutorService testThreadPool = Executors.newFixedThreadPool(numChannels); CompletableFuture[] futures = new CompletableFuture[numChannels]; - List channelList = new ArrayList<>(); + List channelList = + Collections.synchronizedList(new ArrayList<>()); for (int i = 0; i < numChannels; i++) { final String channelName = "CHANNEL" + i; int finalI = i;