Skip to content

Commit

Permalink
Various performance improvements in the insertRows path (#782)
Browse files Browse the repository at this point in the history
* Add Microbenchmark for Insert rows

* Various performance improvements in the insertRows path

* Fix tests and format

* Make flush threads daemon threads so the JVM can exit even while they are still active

* Remove commented out line

* Make memory info provider a singleton because we don't want multiple instances of it

* Address review comments

* Lower benchmark row count to make GitHub Actions happy

* Review comments

* Mark these deps as test-only and check memory every 100ms
  • Loading branch information
sfc-gh-psaha authored Jun 26, 2024
1 parent 1357e74 commit f8cad10
Show file tree
Hide file tree
Showing 10 changed files with 266 additions and 45 deletions.
22 changes: 22 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,18 @@
<version>3.7.7</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-core</artifactId>
<version>1.34</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-generator-annprocess</artifactId>
<version>1.34</version>
<scope>test</scope>
</dependency>
</dependencies>
</dependencyManagement>

Expand Down Expand Up @@ -537,6 +549,16 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-generator-annprocess</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.powermock</groupId>
<artifactId>powermock-api-mockito2</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import net.snowflake.ingest.connection.TelemetryService;
import net.snowflake.ingest.streaming.InsertValidationResponse;
import net.snowflake.ingest.streaming.OffsetTokenVerificationFunction;
Expand Down Expand Up @@ -400,10 +399,10 @@ public float getSize() {
Set<String> verifyInputColumns(
Map<String, Object> row, InsertValidationResponse.InsertError error, int rowIndex) {
// Map of unquoted column name -> original column name
Map<String, String> inputColNamesMap =
row.keySet().stream()
.collect(Collectors.toMap(LiteralQuoteUtils::unquoteColumnName, value -> value));

Set<String> originalKeys = row.keySet();
Map<String, String> inputColNamesMap = new HashMap<>();
originalKeys.forEach(
key -> inputColNamesMap.put(LiteralQuoteUtils.unquoteColumnName(key), key));
// Check for extra columns in the row
List<String> extraCols = new ArrayList<>();
for (String columnName : inputColNamesMap.keySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ class DataValidationUtil {
objectMapper.registerModule(module);
}

// Caching the powers of 10 that are used for checking the range of numbers because computing them
// on-demand is expensive. The table is built once, at class initialization, by
// makePower10Table(): entry i is the BigDecimal form of Power10.sb16Table[i]. checkValueInRange
// indexes it by (precision - scale) and falls back to BigDecimal.TEN.pow(...) when that index is
// outside the table.
private static final BigDecimal[] POWER_10 = makePower10Table();

/**
 * Builds the lookup table stored in {@code POWER_10}: one {@link BigDecimal} for each of the first
 * {@code Power10.sb16Size} entries of {@code Power10.sb16Table}, kept in the same order.
 *
 * @return a freshly allocated array of length {@code Power10.sb16Size}
 */
private static BigDecimal[] makePower10Table() {
  final int tableSize = Power10.sb16Size;
  final BigDecimal[] table = new BigDecimal[tableSize];
  int idx = 0;
  while (idx < tableSize) {
    // Convert each source entry once so range checks can reuse the BigDecimal instance.
    table[idx] = new BigDecimal(Power10.sb16Table[idx]);
    idx++;
  }
  return table;
}

/**
* Validates and parses input as JSON. All types in the object tree must be valid variant types,
* see {@link DataValidationUtil#isAllowedSemiStructuredType}.
Expand Down Expand Up @@ -823,7 +835,11 @@ static int validateAndParseBoolean(String columnName, Object input, long insertR

static void checkValueInRange(
BigDecimal bigDecimalValue, int scale, int precision, final long insertRowIndex) {
if (bigDecimalValue.abs().compareTo(BigDecimal.TEN.pow(precision - scale)) >= 0) {
BigDecimal comparand =
(precision >= scale) && (precision - scale) < POWER_10.length
? POWER_10[precision - scale]
: BigDecimal.TEN.pow(precision - scale);
if (bigDecimalValue.abs().compareTo(comparand) >= 0) {
throw new SFException(
ErrorCode.INVALID_FORMAT_ROW,
String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ List<List<ChannelData<T>>> getData() {

// blob encoding version
private final Constants.BdecVersion bdecVersion;
private volatile int numProcessors = Runtime.getRuntime().availableProcessors();

/**
* Constructor for TESTING that takes (usually mocked) StreamingIngestStage
Expand Down Expand Up @@ -360,6 +361,9 @@ void distributeFlushTasks() {
List<Pair<BlobData<T>, CompletableFuture<BlobMetadata>>> blobs = new ArrayList<>();
List<ChannelData<T>> leftoverChannelsDataPerTable = new ArrayList<>();

// The API states that the number of available processors reported can change and therefore, we
// should poll it occasionally.
numProcessors = Runtime.getRuntime().availableProcessors();
while (itr.hasNext() || !leftoverChannelsDataPerTable.isEmpty()) {
List<List<ChannelData<T>>> blobData = new ArrayList<>();
float totalBufferSizeInBytes = 0F;
Expand Down Expand Up @@ -704,8 +708,7 @@ String getClientPrefix() {
*/
boolean throttleDueToQueuedFlushTasks() {
ThreadPoolExecutor buildAndUpload = (ThreadPoolExecutor) this.buildUploadWorkers;
boolean throttleOnQueuedTasks =
buildAndUpload.getQueue().size() > Runtime.getRuntime().availableProcessors();
boolean throttleOnQueuedTasks = buildAndUpload.getQueue().size() > numProcessors;
if (throttleOnQueuedTasks) {
logger.logWarn(
"Throttled due too many queue flush tasks (probably because of slow uploading speed),"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ public interface MemoryInfoProvider {
/** @return Max memory the JVM can allocate */
long getMaxMemory();

/** @return Total allocated JVM memory so far */
long getTotalMemory();

/** @return Free JVM memory */
long getFreeMemory();
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,51 @@

package net.snowflake.ingest.streaming.internal;

import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/** Reads memory information from JVM runtime */
public class MemoryInfoProviderFromRuntime implements MemoryInfoProvider {
@Override
public long getMaxMemory() {
return Runtime.getRuntime().maxMemory();
private final long maxMemory;
private volatile long totalFreeMemory;
private final ScheduledExecutorService executorService;
private static final long FREE_MEMORY_UPDATE_INTERVAL_MS = 100;
private static final MemoryInfoProviderFromRuntime INSTANCE =
new MemoryInfoProviderFromRuntime(FREE_MEMORY_UPDATE_INTERVAL_MS);

private MemoryInfoProviderFromRuntime(long freeMemoryUpdateIntervalMs) {
maxMemory = Runtime.getRuntime().maxMemory();
totalFreeMemory =
Runtime.getRuntime().freeMemory() + (maxMemory - Runtime.getRuntime().totalMemory());
executorService =
new ScheduledThreadPoolExecutor(
1,
r -> {
Thread th = new Thread(r, "MemoryInfoProviderFromRuntime");
th.setDaemon(true);
return th;
});
executorService.scheduleAtFixedRate(
this::updateFreeMemory, 0, freeMemoryUpdateIntervalMs, TimeUnit.MILLISECONDS);
}

private void updateFreeMemory() {
totalFreeMemory =
Runtime.getRuntime().freeMemory() + (maxMemory - Runtime.getRuntime().totalMemory());
}

public static MemoryInfoProviderFromRuntime getInstance() {
return INSTANCE;
}

@Override
public long getTotalMemory() {
return Runtime.getRuntime().totalMemory();
public long getMaxMemory() {
return maxMemory;
}

@Override
public long getFreeMemory() {
return Runtime.getRuntime().freeMemory();
return totalFreeMemory;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class SnowflakeStreamingIngestChannelInternal<T> implements SnowflakeStreamingIn

// Reference to the row buffer
private final RowBuffer<T> rowBuffer;
private final long insertThrottleIntervalInMs;
private final int insertThrottleThresholdInBytes;
private final int insertThrottleThresholdInPercentage;
private final long maxMemoryLimitInBytes;

// Indicates whether the channel is closed
private volatile boolean isClosed;
Expand All @@ -61,6 +65,9 @@ class SnowflakeStreamingIngestChannelInternal<T> implements SnowflakeStreamingIn
// The latest cause of channel invalidation
private String invalidationCause;

private final MemoryInfoProvider memoryInfoProvider;
private volatile long freeMemoryInBytes = 0;

/**
* Constructor for TESTING ONLY which allows us to set the test mode
*
Expand Down Expand Up @@ -121,6 +128,17 @@ class SnowflakeStreamingIngestChannelInternal<T> implements SnowflakeStreamingIn
OffsetTokenVerificationFunction offsetTokenVerificationFunction) {
this.isClosed = false;
this.owningClient = client;

this.insertThrottleIntervalInMs =
this.owningClient.getParameterProvider().getInsertThrottleIntervalInMs();
this.insertThrottleThresholdInBytes =
this.owningClient.getParameterProvider().getInsertThrottleThresholdInBytes();
this.insertThrottleThresholdInPercentage =
this.owningClient.getParameterProvider().getInsertThrottleThresholdInPercentage();
this.maxMemoryLimitInBytes =
this.owningClient.getParameterProvider().getMaxMemoryLimitInBytes();

this.memoryInfoProvider = MemoryInfoProviderFromRuntime.getInstance();
this.channelFlushContext =
new ChannelFlushContext(
name, dbName, schemaName, tableName, channelSequencer, encryptionKey, encryptionKeyId);
Expand Down Expand Up @@ -373,7 +391,7 @@ public InsertValidationResponse insertRows(
Iterable<Map<String, Object>> rows,
@Nullable String startOffsetToken,
@Nullable String endOffsetToken) {
throttleInsertIfNeeded(new MemoryInfoProviderFromRuntime());
throttleInsertIfNeeded(memoryInfoProvider);
checkValidation();

if (isClosed()) {
Expand Down Expand Up @@ -448,8 +466,6 @@ public Map<String, ColumnProperties> getTableSchema() {
/** Check whether we need to throttle the insertRows API */
void throttleInsertIfNeeded(MemoryInfoProvider memoryInfoProvider) {
int retry = 0;
long insertThrottleIntervalInMs =
this.owningClient.getParameterProvider().getInsertThrottleIntervalInMs();
while ((hasLowRuntimeMemory(memoryInfoProvider)
|| (this.owningClient.getFlushService() != null
&& this.owningClient.getFlushService().throttleDueToQueuedFlushTasks()))
Expand All @@ -473,22 +489,14 @@ void throttleInsertIfNeeded(MemoryInfoProvider memoryInfoProvider) {

/** Check whether we have a low runtime memory condition */
private boolean hasLowRuntimeMemory(MemoryInfoProvider memoryInfoProvider) {
int insertThrottleThresholdInBytes =
this.owningClient.getParameterProvider().getInsertThrottleThresholdInBytes();
int insertThrottleThresholdInPercentage =
this.owningClient.getParameterProvider().getInsertThrottleThresholdInPercentage();
long maxMemoryLimitInBytes =
this.owningClient.getParameterProvider().getMaxMemoryLimitInBytes();
long maxMemory =
maxMemoryLimitInBytes == MAX_MEMORY_LIMIT_IN_BYTES_DEFAULT
? memoryInfoProvider.getMaxMemory()
: maxMemoryLimitInBytes;
long freeMemory =
memoryInfoProvider.getFreeMemory()
+ (memoryInfoProvider.getMaxMemory() - memoryInfoProvider.getTotalMemory());
freeMemoryInBytes = memoryInfoProvider.getFreeMemory();
boolean hasLowRuntimeMemory =
freeMemory < insertThrottleThresholdInBytes
&& freeMemory * 100 / maxMemory < insertThrottleThresholdInPercentage;
freeMemoryInBytes < insertThrottleThresholdInBytes
&& freeMemoryInBytes * 100 / maxMemory < insertThrottleThresholdInPercentage;
if (hasLowRuntimeMemory) {
logger.logWarn(
"Throttled due to memory pressure, client={}, channel={}.",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package net.snowflake.ingest.streaming.internal;

import static java.time.ZoneOffset.UTC;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import net.snowflake.ingest.streaming.InsertValidationResponse;
import net.snowflake.ingest.streaming.OpenChannelRequest;
import net.snowflake.ingest.utils.Utils;
import org.junit.Assert;
import org.junit.Test;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.TimeValue;

/**
 * JMH micro-benchmark for the {@code insertRow} path of {@link
 * SnowflakeStreamingIngestChannelInternal}, together with JUnit entry points so that it can be
 * driven from the regular test run.
 *
 * <p>The benchmark state is declared with {@link Scope#Thread}, and the client/channel pair is
 * created once per trial in {@link #setUpBeforeAll()} and closed in {@link #tearDownAfterAll()}.
 */
@State(Scope.Thread)
public class InsertRowsBenchmarkTest {

  /** Number of rows inserted by the plain JUnit smoke test {@link #insertRow()}. */
  private static final int SMOKE_TEST_ROW_COUNT = 1_000_000;

  private SnowflakeStreamingIngestChannelInternal<?> channel;
  private SnowflakeStreamingIngestClientInternal<?> client;

  /** Number of rows inserted per invocation of the {@link #testInsertRow()} benchmark. */
  @Param({"100000"})
  private int numRows;

  /**
   * Creates the client and a channel with a single non-nullable {@code FIXED(38, 0)} column named
   * {@code COL}. Runs once per JMH trial; {@link #insertRow()} also calls it directly.
   */
  @Setup(Level.Trial)
  public void setUpBeforeAll() {
    client = new SnowflakeStreamingIngestClientInternal<ParquetChunkData>("client_PARQUET");
    channel =
        new SnowflakeStreamingIngestChannelInternal<>(
            "channel",
            "db",
            "schema",
            "table",
            "0",
            0L,
            0L,
            client,
            "key",
            1234L,
            OpenChannelRequest.OnErrorOption.CONTINUE,
            UTC);
    // Setup column fields and vectors
    ColumnMetadata col = new ColumnMetadata();
    col.setOrdinal(1);
    col.setName("COL");
    col.setPhysicalType("SB16");
    col.setNullable(false);
    col.setLogicalType("FIXED");
    col.setPrecision(38);
    col.setScale(0);

    channel.setupSchema(Collections.singletonList(col));
    // NOTE(review): plain Java assert, so it is only evaluated when the JVM runs with -ea.
    assert Utils.getProvider() != null;
  }

  /**
   * Closes the channel and then the client. The client is closed even if closing the channel
   * throws, so a failed channel close does not leak the client.
   *
   * @throws Exception if closing the channel or the client fails
   */
  @TearDown(Level.Trial)
  public void tearDownAfterAll() throws Exception {
    try {
      channel.close();
    } finally {
      client.close();
    }
  }

  /**
   * Benchmark body: inserts the same single-column row {@code numRows} times, using the loop index
   * as the offset token, and asserts that no insert reports errors.
   */
  @Benchmark
  public void testInsertRow() {
    Map<String, Object> row = new HashMap<>();
    row.put("col", 1);

    for (int i = 0; i < numRows; i++) {
      InsertValidationResponse response = channel.insertRow(row, String.valueOf(i));
      Assert.assertFalse(response.hasErrors());
    }
  }

  /**
   * Plain JUnit smoke test of the same insert loop, without the JMH harness.
   *
   * @throws Exception if tearing down the channel or client fails
   */
  @Test
  public void insertRow() throws Exception {
    setUpBeforeAll();
    try {
      Map<String, Object> row = new HashMap<>();
      row.put("col", 1);

      for (int i = 0; i < SMOKE_TEST_ROW_COUNT; i++) {
        InsertValidationResponse response = channel.insertRow(row, String.valueOf(i));
        Assert.assertFalse(response.hasErrors());
      }
    } finally {
      // Always release the channel and client, even when an insert or assertion fails.
      tearDownAfterAll();
    }
  }

  /**
   * Runs every benchmark in this class through the JMH {@link Runner}, so the benchmark can be
   * launched as an ordinary JUnit test.
   *
   * @throws RunnerException if the JMH run fails
   */
  @Test
  public void launchBenchmark() throws RunnerException {
    Options opt =
        new OptionsBuilder()
            // Specify which benchmarks to run.
            // You can be more specific if you'd like to run only one benchmark per test.
            .include(this.getClass().getName() + ".*")
            // Set the following options as needed
            .mode(Mode.AverageTime)
            .timeUnit(TimeUnit.MICROSECONDS)
            .warmupTime(TimeValue.seconds(1))
            .warmupIterations(2)
            .measurementTime(TimeValue.seconds(1))
            .measurementIterations(10)
            .threads(2)
            .forks(1)
            .shouldFailOnError(true)
            .shouldDoGC(true)
            // .jvmArgs("-XX:+UnlockDiagnosticVMOptions", "-XX:+PrintInlining")
            // .addProfiler(WinPerfAsmProfiler.class)
            .build();

    new Runner(opt).run();
  }
}
Loading

0 comments on commit f8cad10

Please sign in to comment.