From f8cad10fd0fc21c6dc0293072a5f32a4ddfa64b4 Mon Sep 17 00:00:00 2001 From: Purujit Saha Date: Wed, 26 Jun 2024 09:37:27 -0700 Subject: [PATCH] Various performance improvements in the `insertRows` path (#782) * Add Microbenchmark for Insert rows * Various performance improvements in the insertRows path * Fix tests and format * Make flush threads daemon to allow jvm to exit if these threads are active * Remove commented out line * Make memory info provider a singleton because we dont want multiple instances of it * Address review comments * Lower benchmark row count to make gh actions happy * Review comments * Mark these deps as test-only and check memory every 100ms --- pom.xml | 22 ++++ .../streaming/internal/AbstractRowBuffer.java | 9 +- .../internal/DataValidationUtil.java | 18 ++- .../streaming/internal/FlushService.java | 7 +- .../internal/MemoryInfoProvider.java | 3 - .../MemoryInfoProviderFromRuntime.java | 43 +++++- ...owflakeStreamingIngestChannelInternal.java | 36 ++++-- .../internal/InsertRowsBenchmarkTest.java | 122 ++++++++++++++++++ .../SnowflakeStreamingIngestChannelTest.java | 5 - .../SnowflakeStreamingIngestClientTest.java | 46 +++++-- 10 files changed, 266 insertions(+), 45 deletions(-) create mode 100644 src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java diff --git a/pom.xml b/pom.xml index 04062e6a8..7723321bb 100644 --- a/pom.xml +++ b/pom.xml @@ -364,6 +364,18 @@ 3.7.7 test + + org.openjdk.jmh + jmh-core + 1.34 + test + + + org.openjdk.jmh + jmh-generator-annprocess + 1.34 + test + @@ -537,6 +549,16 @@ mockito-core test + + org.openjdk.jmh + jmh-core + test + + + org.openjdk.jmh + jmh-generator-annprocess + test + org.powermock powermock-api-mockito2 diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java b/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java index 6d5dce17f..71a9d501e 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/AbstractRowBuffer.java @@ -16,7 +16,6 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; -import java.util.stream.Collectors; import net.snowflake.ingest.connection.TelemetryService; import net.snowflake.ingest.streaming.InsertValidationResponse; import net.snowflake.ingest.streaming.OffsetTokenVerificationFunction; @@ -400,10 +399,10 @@ public float getSize() { Set verifyInputColumns( Map row, InsertValidationResponse.InsertError error, int rowIndex) { // Map of unquoted column name -> original column name - Map inputColNamesMap = - row.keySet().stream() - .collect(Collectors.toMap(LiteralQuoteUtils::unquoteColumnName, value -> value)); - + Set originalKeys = row.keySet(); + Map inputColNamesMap = new HashMap<>(); + originalKeys.forEach( + key -> inputColNamesMap.put(LiteralQuoteUtils.unquoteColumnName(key), key)); // Check for extra columns in the row List extraCols = new ArrayList<>(); for (String columnName : inputColNamesMap.keySet()) { diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java b/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java index 814423c28..162e56145 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java @@ -86,6 +86,18 @@ class DataValidationUtil { objectMapper.registerModule(module); } + // Caching the powers of 10 that are used for checking the range of numbers because computing them + // on-demand is expensive. + private static final BigDecimal[] POWER_10 = makePower10Table(); + + private static BigDecimal[] makePower10Table() { + BigDecimal[] power10 = new BigDecimal[Power10.sb16Size]; + for (int i = 0; i < Power10.sb16Size; i++) { + power10[i] = new BigDecimal(Power10.sb16Table[i]); + } + return power10; + } + /** * Validates and parses input as JSON. All types in the object tree must be valid variant types, * see {@link DataValidationUtil#isAllowedSemiStructuredType}. @@ -823,7 +835,11 @@ static int validateAndParseBoolean(String columnName, Object input, long insertR static void checkValueInRange( BigDecimal bigDecimalValue, int scale, int precision, final long insertRowIndex) { - if (bigDecimalValue.abs().compareTo(BigDecimal.TEN.pow(precision - scale)) >= 0) { + BigDecimal comparand = + (precision >= scale) && (precision - scale) < POWER_10.length + ? POWER_10[precision - scale] + : BigDecimal.TEN.pow(precision - scale); + if (bigDecimalValue.abs().compareTo(comparand) >= 0) { throw new SFException( ErrorCode.INVALID_FORMAT_ROW, String.format( diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java index f08196477..76e43ff4d 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java @@ -122,6 +122,7 @@ List>> getData() { // blob encoding version private final Constants.BdecVersion bdecVersion; + private volatile int numProcessors = Runtime.getRuntime().availableProcessors(); /** * Constructor for TESTING that takes (usually mocked) StreamingIngestStage @@ -360,6 +361,9 @@ void distributeFlushTasks() { List, CompletableFuture>> blobs = new ArrayList<>(); List> leftoverChannelsDataPerTable = new ArrayList<>(); + // The API states that the number of available processors reported can change and therefore, we + // should poll it occasionally. + numProcessors = Runtime.getRuntime().availableProcessors(); while (itr.hasNext() || !leftoverChannelsDataPerTable.isEmpty()) { List>> blobData = new ArrayList<>(); float totalBufferSizeInBytes = 0F; @@ -704,8 +708,7 @@ String getClientPrefix() { */ boolean throttleDueToQueuedFlushTasks() { ThreadPoolExecutor buildAndUpload = (ThreadPoolExecutor) this.buildUploadWorkers; - boolean throttleOnQueuedTasks = - buildAndUpload.getQueue().size() > Runtime.getRuntime().availableProcessors(); + boolean throttleOnQueuedTasks = buildAndUpload.getQueue().size() > numProcessors; if (throttleOnQueuedTasks) { logger.logWarn( "Throttled due too many queue flush tasks (probably because of slow uploading speed)," diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProvider.java b/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProvider.java index f426e898d..777ae4fdc 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProvider.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProvider.java @@ -9,9 +9,6 @@ public interface MemoryInfoProvider { /** @return Max memory the JVM can allocate */ long getMaxMemory(); - /** @return Total allocated JVM memory so far */ - long getTotalMemory(); - /** @return Free JVM memory */ long getFreeMemory(); } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProviderFromRuntime.java b/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProviderFromRuntime.java index 3a957f225..d248ddfd9 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProviderFromRuntime.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/MemoryInfoProviderFromRuntime.java @@ -4,20 +4,51 @@ package net.snowflake.ingest.streaming.internal; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + /** Reads memory information from JVM runtime */ public class MemoryInfoProviderFromRuntime implements MemoryInfoProvider { - @Override - public long getMaxMemory() { - return Runtime.getRuntime().maxMemory(); + private final long maxMemory; + private volatile long totalFreeMemory; + private final ScheduledExecutorService executorService; + private static final long FREE_MEMORY_UPDATE_INTERVAL_MS = 100; + private static final MemoryInfoProviderFromRuntime INSTANCE = + new MemoryInfoProviderFromRuntime(FREE_MEMORY_UPDATE_INTERVAL_MS); + + private MemoryInfoProviderFromRuntime(long freeMemoryUpdateIntervalMs) { + maxMemory = Runtime.getRuntime().maxMemory(); + totalFreeMemory = + Runtime.getRuntime().freeMemory() + (maxMemory - Runtime.getRuntime().totalMemory()); + executorService = + new ScheduledThreadPoolExecutor( + 1, + r -> { + Thread th = new Thread(r, "MemoryInfoProviderFromRuntime"); + th.setDaemon(true); + return th; + }); + executorService.scheduleAtFixedRate( + this::updateFreeMemory, 0, freeMemoryUpdateIntervalMs, TimeUnit.MILLISECONDS); + } + + private void updateFreeMemory() { + totalFreeMemory = + Runtime.getRuntime().freeMemory() + (maxMemory - Runtime.getRuntime().totalMemory()); + } + + public static MemoryInfoProviderFromRuntime getInstance() { + return INSTANCE; } @Override - public long getTotalMemory() { - return Runtime.getRuntime().totalMemory(); + public long getMaxMemory() { + return maxMemory; } @Override public long getFreeMemory() { - return Runtime.getRuntime().freeMemory(); + return totalFreeMemory; } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java index 58e81d116..8ebc23ca1 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java @@ -45,6 +45,10 @@ class SnowflakeStreamingIngestChannelInternal implements SnowflakeStreamingIn // Reference to the row buffer private final RowBuffer rowBuffer; + private final long insertThrottleIntervalInMs; + private final int insertThrottleThresholdInBytes; + private final int insertThrottleThresholdInPercentage; + private final long maxMemoryLimitInBytes; // Indicates whether the channel is closed private volatile boolean isClosed; @@ -61,6 +65,9 @@ class SnowflakeStreamingIngestChannelInternal implements SnowflakeStreamingIn // The latest cause of channel invalidation private String invalidationCause; + private final MemoryInfoProvider memoryInfoProvider; + private volatile long freeMemoryInBytes = 0; + /** * Constructor for TESTING ONLY which allows us to set the test mode * @@ -121,6 +128,17 @@ class SnowflakeStreamingIngestChannelInternal implements SnowflakeStreamingIn OffsetTokenVerificationFunction offsetTokenVerificationFunction) { this.isClosed = false; this.owningClient = client; + + this.insertThrottleIntervalInMs = + this.owningClient.getParameterProvider().getInsertThrottleIntervalInMs(); + this.insertThrottleThresholdInBytes = + this.owningClient.getParameterProvider().getInsertThrottleThresholdInBytes(); + this.insertThrottleThresholdInPercentage = + this.owningClient.getParameterProvider().getInsertThrottleThresholdInPercentage(); + this.maxMemoryLimitInBytes = + this.owningClient.getParameterProvider().getMaxMemoryLimitInBytes(); + + this.memoryInfoProvider = MemoryInfoProviderFromRuntime.getInstance(); this.channelFlushContext = new ChannelFlushContext( name, dbName, schemaName, tableName, channelSequencer, encryptionKey, encryptionKeyId); @@ -373,7 +391,7 @@ public InsertValidationResponse insertRows( Iterable> rows, @Nullable String startOffsetToken, @Nullable String endOffsetToken) { - throttleInsertIfNeeded(new MemoryInfoProviderFromRuntime()); + throttleInsertIfNeeded(memoryInfoProvider); checkValidation(); if (isClosed()) { @@ -448,8 +466,6 @@ public Map getTableSchema() { /** Check whether we need to throttle the insertRows API */ void throttleInsertIfNeeded(MemoryInfoProvider memoryInfoProvider) { int retry = 0; - long insertThrottleIntervalInMs = - this.owningClient.getParameterProvider().getInsertThrottleIntervalInMs(); while ((hasLowRuntimeMemory(memoryInfoProvider) || (this.owningClient.getFlushService() != null && this.owningClient.getFlushService().throttleDueToQueuedFlushTasks())) @@ -473,22 +489,14 @@ void throttleInsertIfNeeded(MemoryInfoProvider memoryInfoProvider) { /** Check whether we have a low runtime memory condition */ private boolean hasLowRuntimeMemory(MemoryInfoProvider memoryInfoProvider) { - int insertThrottleThresholdInBytes = - this.owningClient.getParameterProvider().getInsertThrottleThresholdInBytes(); - int insertThrottleThresholdInPercentage = - this.owningClient.getParameterProvider().getInsertThrottleThresholdInPercentage(); - long maxMemoryLimitInBytes = - this.owningClient.getParameterProvider().getMaxMemoryLimitInBytes(); long maxMemory = maxMemoryLimitInBytes == MAX_MEMORY_LIMIT_IN_BYTES_DEFAULT ? memoryInfoProvider.getMaxMemory() : maxMemoryLimitInBytes; - long freeMemory = - memoryInfoProvider.getFreeMemory() - + (memoryInfoProvider.getMaxMemory() - memoryInfoProvider.getTotalMemory()); + freeMemoryInBytes = memoryInfoProvider.getFreeMemory(); boolean hasLowRuntimeMemory = - freeMemory < insertThrottleThresholdInBytes - && freeMemory * 100 / maxMemory < insertThrottleThresholdInPercentage; + freeMemoryInBytes < insertThrottleThresholdInBytes + && freeMemoryInBytes * 100 / maxMemory < insertThrottleThresholdInPercentage; if (hasLowRuntimeMemory) { logger.logWarn( "Throttled due to memory pressure, client={}, channel={}.", diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java new file mode 100644 index 000000000..5b28e9c45 --- /dev/null +++ b/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java @@ -0,0 +1,122 @@ +package net.snowflake.ingest.streaming.internal; + +import static java.time.ZoneOffset.UTC; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import net.snowflake.ingest.streaming.InsertValidationResponse; +import net.snowflake.ingest.streaming.OpenChannelRequest; +import net.snowflake.ingest.utils.Utils; +import org.junit.Assert; +import org.junit.Test; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.TimeValue; + +@State(Scope.Thread) +public class InsertRowsBenchmarkTest { + + private SnowflakeStreamingIngestChannelInternal channel; + private SnowflakeStreamingIngestClientInternal client; + + @Param({"100000"}) + private int numRows; + + @Setup(Level.Trial) + public void setUpBeforeAll() { + client = new SnowflakeStreamingIngestClientInternal("client_PARQUET"); + channel = + new SnowflakeStreamingIngestChannelInternal<>( + "channel", + "db", + "schema", + "table", + "0", + 0L, + 0L, + client, + "key", + 1234L, + OpenChannelRequest.OnErrorOption.CONTINUE, + UTC); + // Setup column fields and vectors + ColumnMetadata col = new ColumnMetadata(); + col.setOrdinal(1); + col.setName("COL"); + col.setPhysicalType("SB16"); + col.setNullable(false); + col.setLogicalType("FIXED"); + col.setPrecision(38); + col.setScale(0); + + channel.setupSchema(Collections.singletonList(col)); + assert Utils.getProvider() != null; + } + + @TearDown(Level.Trial) + public void tearDownAfterAll() throws Exception { + channel.close(); + client.close(); + } + + @Benchmark + public void testInsertRow() { + Map row = new HashMap<>(); + row.put("col", 1); + + for (int i = 0; i < numRows; i++) { + InsertValidationResponse response = channel.insertRow(row, String.valueOf(i)); + Assert.assertFalse(response.hasErrors()); + } + } + + @Test + public void insertRow() throws Exception { + setUpBeforeAll(); + Map row = new HashMap<>(); + row.put("col", 1); + + for (int i = 0; i < 1000000; i++) { + InsertValidationResponse response = channel.insertRow(row, String.valueOf(i)); + Assert.assertFalse(response.hasErrors()); + } + tearDownAfterAll(); + } + + @Test + public void launchBenchmark() throws RunnerException { + Options opt = + new OptionsBuilder() + // Specify which benchmarks to run. + // You can be more specific if you'd like to run only one benchmark per test. + .include(this.getClass().getName() + ".*") + // Set the following options as needed + .mode(Mode.AverageTime) + .timeUnit(TimeUnit.MICROSECONDS) + .warmupTime(TimeValue.seconds(1)) + .warmupIterations(2) + .measurementTime(TimeValue.seconds(1)) + .measurementIterations(10) + .threads(2) + .forks(1) + .shouldFailOnError(true) + .shouldDoGC(true) + // .jvmArgs("-XX:+UnlockDiagnosticVMOptions", "-XX:+PrintInlining") + // .addProfiler(WinPerfAsmProfiler.class) + .build(); + + new Runner(opt).run(); + } +} diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java index 4ddc61ece..b4fa769a1 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java @@ -61,11 +61,6 @@ public long getMaxMemory() { return maxMemory; } - @Override - public long getTotalMemory() { - return maxMemory; - } - @Override public long getFreeMemory() { return freeMemory; diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java index 5b24bcc7f..553efbd31 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java @@ -80,9 +80,27 @@ public class SnowflakeStreamingIngestClientTest { SnowflakeStreamingIngestChannelInternal channel4; @Before - public void setup() { + public void setup() throws Exception { objectMapper.setVisibility(PropertyAccessor.GETTER, JsonAutoDetect.Visibility.ANY); objectMapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY); + Properties prop = new Properties(); + prop.put(USER, TestUtils.getUser()); + prop.put(ACCOUNT_URL, TestUtils.getHost()); + prop.put(PRIVATE_KEY, TestUtils.getPrivateKey()); + prop.put(ROLE, TestUtils.getRole()); + + CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); + RequestBuilder requestBuilder = + new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); + SnowflakeStreamingIngestClientInternal client = + new SnowflakeStreamingIngestClientInternal<>( + "client", + new SnowflakeURL("snowflake.dev.local:8082"), + null, + httpClient, + true, + requestBuilder, + null); channel1 = new SnowflakeStreamingIngestChannelInternal<>( "channel1", @@ -92,7 +110,7 @@ public void setup() { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -108,7 +126,7 @@ public void setup() { "0", 2L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -124,7 +142,7 @@ public void setup() { "0", 3L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -140,7 +158,7 @@ public void setup() { "0", 3L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -357,7 +375,7 @@ public void testGetChannelsStatusWithRequest() throws Exception { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -461,7 +479,7 @@ public void testGetChannelsStatusWithRequestError() throws Exception { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -494,6 +512,16 @@ public void testRegisterBlobRequestCreationSuccess() throws Exception { RequestBuilder requestBuilder = new RequestBuilder(url, prop.get(USER).toString(), keyPair, null, null); + CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); + SnowflakeStreamingIngestClientInternal client = + new SnowflakeStreamingIngestClientInternal<>( + "client", + new SnowflakeURL("snowflake.dev.local:8082"), + null, + httpClient, + true, + requestBuilder, + null); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", @@ -503,7 +531,7 @@ public void testRegisterBlobRequestCreationSuccess() throws Exception { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE, @@ -1426,7 +1454,7 @@ public void testGetLatestCommittedOffsetTokens() throws Exception { "0", 0L, 0L, - null, + client, "key", 1234L, OpenChannelRequest.OnErrorOption.CONTINUE,