From f3ef51c088ca738fef17474361482d63deb3b3b1 Mon Sep 17 00:00:00 2001 From: Alec Huang Date: Wed, 3 Jul 2024 15:29:25 -0700 Subject: [PATCH] Per table flush --- .../streaming/internal/ChannelCache.java | 75 ++++++++++- .../streaming/internal/FlushService.java | 110 ++++++++++----- ...owflakeStreamingIngestChannelInternal.java | 4 +- ...nowflakeStreamingIngestClientInternal.java | 6 +- .../streaming/internal/FlushServiceTest.java | 126 ++++++++++++++++-- 5 files changed, 269 insertions(+), 52 deletions(-) diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelCache.java b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelCache.java index 989be0fa1..782c5e7cf 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelCache.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelCache.java @@ -1,11 +1,12 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; import java.util.Iterator; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; /** @@ -23,6 +24,12 @@ class ChannelCache { String, ConcurrentHashMap>> cache = new ConcurrentHashMap<>(); + // Last flush time for each table, the key is FullyQualifiedTableName. + private final ConcurrentHashMap lastFlushTime = new ConcurrentHashMap<>(); + + // Need flush flag for each table, the key is FullyQualifiedTableName. 
+ private final ConcurrentHashMap needFlush = new ConcurrentHashMap<>(); + /** * Add a channel to the channel cache * @@ -33,6 +40,12 @@ void addChannel(SnowflakeStreamingIngestChannelInternal channel) { this.cache.computeIfAbsent( channel.getFullyQualifiedTableName(), v -> new ConcurrentHashMap<>()); + // Update the last flush time for the table, add jitter to avoid all channels flush at the same + // time when the blobs are not interleaved + this.lastFlushTime.putIfAbsent( + channel.getFullyQualifiedTableName(), + System.currentTimeMillis() + (long) (Math.random() * 1000)); + SnowflakeStreamingIngestChannelInternal oldChannel = channels.put(channel.getName(), channel); // Invalidate old channel if it exits to block new inserts and return error to users earlier @@ -43,6 +56,46 @@ void addChannel(SnowflakeStreamingIngestChannelInternal channel) { } } + /** + * Get the last flush time for a table + * + * @param fullyQualifiedTableName fully qualified table name + * @return last flush time in milliseconds + */ + Long getLastFlushTime(String fullyQualifiedTableName) { + return this.lastFlushTime.get(fullyQualifiedTableName); + } + + /** + * Set the last flush time for a table as the current time + * + * @param fullyQualifiedTableName fully qualified table name + * @param lastFlushTime last flush time in milliseconds + */ + void setLastFlushTime(String fullyQualifiedTableName, Long lastFlushTime) { + this.lastFlushTime.put(fullyQualifiedTableName, lastFlushTime); + } + + /** + * Get need flush flag for a table + * + * @param fullyQualifiedTableName fully qualified table name + * @return need flush flag + */ + Boolean getNeedFlush(String fullyQualifiedTableName) { + return this.needFlush.getOrDefault(fullyQualifiedTableName, false); + } + + /** + * Set need flush flag for a table + * + * @param fullyQualifiedTableName fully qualified table name + * @param needFlush need flush flag + */ + void setNeedFlush(String fullyQualifiedTableName, Boolean needFlush) { + 
this.needFlush.put(fullyQualifiedTableName, needFlush); + } + /** * Returns an iterator over the (table, channels) in this map. * @@ -53,6 +106,20 @@ void addChannel(SnowflakeStreamingIngestChannelInternal channel) { return this.cache.entrySet().iterator(); } + /** + * Returns an iterator over the (table, channels) in this map, filtered by the given table name + * set + * + * @param tableNames the set of table names to filter + * @return + */ + Iterator>>> + iterator(Set tableNames) { + return this.cache.entrySet().stream() + .filter(entry -> tableNames.contains(entry.getKey())) + .iterator(); + } + /** Close all channels in the channel cache */ void closeAllChannels() { this.cache @@ -101,4 +168,10 @@ void invalidateChannelIfSequencersMatch( int getSize() { return cache.size(); } + + public Set< + Map.Entry>>> + entrySet() { + return cache.entrySet(); + } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java index 76e43ff4d..7b3040a5f 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. 
*/ package net.snowflake.ingest.streaming.internal; @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.TimeZone; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -35,6 +36,7 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Collectors; import javax.crypto.BadPaddingException; import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; @@ -108,9 +110,6 @@ List>> getData() { // Reference to register service private final RegisterService registerService; - // Indicates whether we need to schedule a flush - @VisibleForTesting volatile boolean isNeedFlush; - // Latest flush time @VisibleForTesting volatile long lastFlushTime; @@ -141,7 +140,6 @@ List>> getData() { this.targetStage = targetStage; this.counter = new AtomicLong(0); this.registerService = new RegisterService<>(client, isTestMode); - this.isNeedFlush = false; this.lastFlushTime = System.currentTimeMillis(); this.isTestMode = isTestMode; this.latencyTimerContextMap = new ConcurrentHashMap<>(); @@ -175,7 +173,6 @@ List>> getData() { this.registerService = new RegisterService<>(client, isTestMode); this.counter = new AtomicLong(0); - this.isNeedFlush = false; this.lastFlushTime = System.currentTimeMillis(); this.isTestMode = isTestMode; this.latencyTimerContextMap = new ConcurrentHashMap<>(); @@ -204,36 +201,43 @@ private CompletableFuture statsFuture() { /** * @param isForce if true will flush regardless of other conditions - * @param timeDiffMillis Time in milliseconds since the last flush + * @param tablesToFlush list of tables to flush + * @param timeDiffMillis time difference in milliseconds * @return */ - private CompletableFuture distributeFlush(boolean isForce, long timeDiffMillis) { + private CompletableFuture distributeFlush( + boolean 
isForce, Set tablesToFlush, Long timeDiffMillis) { return CompletableFuture.runAsync( () -> { - logFlushTask(isForce, timeDiffMillis); - distributeFlushTasks(); - this.isNeedFlush = false; + logFlushTask(isForce, tablesToFlush, timeDiffMillis); + distributeFlushTasks(tablesToFlush); this.lastFlushTime = System.currentTimeMillis(); - return; + tablesToFlush.forEach( + table -> { + this.channelCache.setLastFlushTime(table, this.lastFlushTime); + this.channelCache.setNeedFlush(table, false); + }); }, this.flushWorker); } /** If tracing is enabled, print always else, check if it needs flush or is forceful. */ - private void logFlushTask(boolean isForce, long timeDiffMillis) { + private void logFlushTask(boolean isForce, Set tablesToFlush, long timeDiffMillis) { + boolean isNeedFlush = tablesToFlush.stream().anyMatch(channelCache::getNeedFlush); + final String flushTaskLogFormat = String.format( "Submit forced or ad-hoc flush task on client=%s, isForce=%s," + " isNeedFlush=%s, timeDiffMillis=%s, currentDiffMillis=%s", this.owningClient.getName(), isForce, - this.isNeedFlush, + isNeedFlush, timeDiffMillis, System.currentTimeMillis() - this.lastFlushTime); if (logger.isTraceEnabled()) { logger.logTrace(flushTaskLogFormat); } - if (!logger.isTraceEnabled() && (this.isNeedFlush || isForce)) { + if (!logger.isTraceEnabled() && (isNeedFlush || isForce)) { logger.logDebug(flushTaskLogFormat); } } @@ -249,27 +253,57 @@ private CompletableFuture registerFuture() { } /** - * Kick off a flush job and distribute the tasks if one of the following conditions is met: - *
  • Flush is forced by the users - *
  • One or more buffers have reached the flush size - *
  • Periodical background flush when a time interval has reached + * Kick off a flush job and distribute the tasks. The flush service behaves differently based on + * the max chunks in blob: + * + *
      + *
    • The max chunks in blob is not 1 (interleaving is allowed): every channel will be flushed + * together if one of the following conditions is met: + *
        + *
      • Flush is forced by the users + *
      • One or more buffers have reached the flush size + *
      • Periodic background flush when the time interval has been reached + *
      + *
    • The max chunks in blob is 1 (interleaving is not allowed): a channel will be flushed if + * one of the following conditions is met: + *
        + *
      • Flush is forced by the users + *
      • One or more buffers with the same target table as the channel have reached the + * flush size + *
      • Periodic background flush of the target table when the time interval has been reached + *
      + *
    * * @param isForce * @return Completable future that will return when the blobs are registered successfully, or null * if none of the conditions is met above */ CompletableFuture flush(boolean isForce) { - long timeDiffMillis = System.currentTimeMillis() - this.lastFlushTime; + long currentTime = System.currentTimeMillis(); + long timeDiffMillis = currentTime - this.lastFlushTime; + + Set tablesToFlush = + this.channelCache.entrySet().stream() + .filter( + entry -> + isForce + || currentTime - this.channelCache.getLastFlushTime(entry.getKey()) + >= this.owningClient.getParameterProvider().getCachedMaxClientLagInMs() + || this.channelCache.getNeedFlush(entry.getKey())) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + + // Flush every table together when it's interleaving chunk is allowed + if (this.owningClient.getParameterProvider().getMaxChunksInBlobAndRegistrationRequest() != 1 + && !tablesToFlush.isEmpty()) { + tablesToFlush.addAll( + this.channelCache.entrySet().stream().map(Map.Entry::getKey).collect(Collectors.toSet())); + } - if (isForce - || (!DISABLE_BACKGROUND_FLUSH - && !isTestMode() - && (this.isNeedFlush - || timeDiffMillis - >= this.owningClient.getParameterProvider().getCachedMaxClientLagInMs()))) { + if (isForce || (!DISABLE_BACKGROUND_FLUSH && !isTestMode() && !tablesToFlush.isEmpty())) { return this.statsFuture() - .thenCompose((v) -> this.distributeFlush(isForce, timeDiffMillis)) + .thenCompose((v) -> this.distributeFlush(isForce, tablesToFlush, timeDiffMillis)) .thenCompose((v) -> this.registerFuture()); } return this.statsFuture(); @@ -352,12 +386,14 @@ private void createWorkers() { /** * Distribute the flush tasks by iterating through all the channels in the channel cache and kick * off a build blob work when certain size has reached or we have reached the end + * + * @param tablesToFlush list of tables to flush */ - void distributeFlushTasks() { + void distributeFlushTasks(Set tablesToFlush) { Iterator< Map.Entry< String, 
ConcurrentHashMap>>> - itr = this.channelCache.iterator(); + itr = this.channelCache.iterator(tablesToFlush); List, CompletableFuture>> blobs = new ArrayList<>(); List> leftoverChannelsDataPerTable = new ArrayList<>(); @@ -389,11 +425,11 @@ void distributeFlushTasks() { blobPath); break; } else { - ConcurrentHashMap> table = - itr.next().getValue(); + Map.Entry>> + tableEntry = itr.next(); // Use parallel stream since getData could be the performance bottleneck when we have a // high number of channels - table.values().parallelStream() + tableEntry.getValue().values().parallelStream() .forEach( channel -> { if (channel.isValid()) { @@ -630,9 +666,13 @@ void shutdown() throws InterruptedException { } } - /** Set the flag to indicate that a flush is needed */ - void setNeedFlush() { - this.isNeedFlush = true; + /** + * Set the flag to indicate that a flush is needed + * + * @param fullyQualifiedTableName the fully qualified table name + */ + void setNeedFlush(String fullyQualifiedTableName) { + this.channelCache.setNeedFlush(fullyQualifiedTableName, true); } /** diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java index 8ebc23ca1..ca0bbe782 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelInternal.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. 
*/ package net.snowflake.ingest.streaming.internal; @@ -413,7 +413,7 @@ public InsertValidationResponse insertRows( // if a large number of rows are inserted if (this.rowBuffer.getSize() >= this.owningClient.getParameterProvider().getMaxChannelSizeInBytes()) { - this.owningClient.setNeedFlush(); + this.owningClient.setNeedFlush(this.channelFlushContext.getFullyQualifiedTableName()); } return response; diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java index 75eb4f717..b9d4c23d0 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; @@ -820,8 +820,8 @@ CompletableFuture flush(boolean closing) { } /** Set the flag to indicate that a flush is needed */ - void setNeedFlush() { - this.flushService.setNeedFlush(); + void setNeedFlush(String fullyQualifiedTableName) { + this.flushService.setNeedFlush(fullyQualifiedTableName); } /** Remove the channel in the channel cache if the channel sequencer matches */ diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java index f200c7177..eaefb8cb4 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
+ */ + package net.snowflake.ingest.streaming.internal; import static net.snowflake.ingest.utils.Constants.BLOB_CHECKSUM_SIZE_IN_BYTES; @@ -36,7 +40,9 @@ import java.util.Map; import java.util.TimeZone; import java.util.UUID; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; import javax.crypto.BadPaddingException; import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; @@ -51,6 +57,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; +import org.mockito.stubbing.Answer; public class FlushServiceTest { public FlushServiceTest() { @@ -88,13 +95,18 @@ private abstract static class TestContext implements AutoCloseable { Mockito.when(stage.getClientPrefix()).thenReturn("client_prefix"); parameterProvider = new ParameterProvider(); client = Mockito.mock(SnowflakeStreamingIngestClientInternal.class); - Mockito.when(client.getParameterProvider()).thenReturn(parameterProvider); + Mockito.when(client.getParameterProvider()) + .thenAnswer((Answer) (i) -> parameterProvider); channelCache = new ChannelCache<>(); Mockito.when(client.getChannelCache()).thenReturn(channelCache); registerService = Mockito.spy(new RegisterService(client, client.isTestMode())); flushService = Mockito.spy(new FlushService<>(client, channelCache, stage, true)); } + void setParameterOverride(Map parameterOverride) { + this.parameterProvider = new ParameterProvider(parameterOverride, null); + } + ChannelData flushChannel(String name) { SnowflakeStreamingIngestChannelInternal channel = channels.get(name); ChannelData channelData = channel.getRowBuffer().flush(name + "_snowpipe_streaming.bdec"); @@ -422,30 +434,118 @@ public void testGetFilePath() { @Test public void testFlush() throws Exception { - TestContext testContext = testContextFactory.create(); + int numChannels = 4; + TestContext>> testContext = testContextFactory.create(); + 
addChannel1(testContext); FlushService flushService = testContext.flushService; + ChannelCache channelCache = testContext.channelCache; Mockito.when(flushService.isTestMode()).thenReturn(false); // Nothing to flush flushService.flush(false).get(); - Mockito.verify(flushService, Mockito.times(0)).distributeFlushTasks(); + Mockito.verify(flushService, Mockito.times(0)).distributeFlushTasks(Mockito.any()); // Force = true flushes flushService.flush(true).get(); - Mockito.verify(flushService).distributeFlushTasks(); - Mockito.verify(flushService, Mockito.times(1)).distributeFlushTasks(); + Mockito.verify(flushService, Mockito.times(1)).distributeFlushTasks(Mockito.any()); + + IntStream.range(0, numChannels) + .forEach( + i -> { + addChannel(testContext, i, 1L); + channelCache.setLastFlushTime(getFullyQualifiedTableName(i), Long.MAX_VALUE); + }); // isNeedFlush = true flushes - flushService.isNeedFlush = true; + flushService.setNeedFlush(getFullyQualifiedTableName(0)); flushService.flush(false).get(); - Mockito.verify(flushService, Mockito.times(2)).distributeFlushTasks(); - Assert.assertFalse(flushService.isNeedFlush); + Mockito.verify(flushService, Mockito.times(2)).distributeFlushTasks(Mockito.any()); + Assert.assertNotEquals( + Long.MAX_VALUE, channelCache.getLastFlushTime(getFullyQualifiedTableName(0)).longValue()); + IntStream.range(0, numChannels) + .forEach( + i -> { + Assert.assertFalse(channelCache.getNeedFlush(getFullyQualifiedTableName(i))); + Assert.assertEquals( + channelCache.getLastFlushTime(getFullyQualifiedTableName(0)), + channelCache.getLastFlushTime(getFullyQualifiedTableName(i))); + }); // lastFlushTime causes flush - flushService.lastFlushTime = 0; + channelCache.setLastFlushTime(getFullyQualifiedTableName(0), 0L); + flushService.flush(false).get(); + Mockito.verify(flushService, Mockito.times(3)).distributeFlushTasks(Mockito.any()); + Assert.assertNotEquals( + Long.MAX_VALUE, 
channelCache.getLastFlushTime(getFullyQualifiedTableName(0)).longValue()); + IntStream.range(0, numChannels) + .forEach( + i -> { + Assert.assertFalse(channelCache.getNeedFlush(getFullyQualifiedTableName(i))); + Assert.assertEquals( + channelCache.getLastFlushTime(getFullyQualifiedTableName(0)), + channelCache.getLastFlushTime(getFullyQualifiedTableName(i))); + }); + } + + @Test + public void testNonInterleaveFlush() throws ExecutionException, InterruptedException { + int numChannels = 4; + TestContext>> testContext = testContextFactory.create(); + FlushService flushService = testContext.flushService; + ChannelCache channelCache = testContext.channelCache; + Mockito.when(flushService.isTestMode()).thenReturn(false); + testContext.setParameterOverride( + Collections.singletonMap(ParameterProvider.MAX_CHUNKS_IN_BLOB_AND_REGISTRATION_REQUEST, 1)); + + // Test need flush + IntStream.range(0, numChannels) + .forEach( + i -> { + addChannel(testContext, i, 1L); + channelCache.setLastFlushTime(getFullyQualifiedTableName(i), Long.MAX_VALUE); + if (i % 2 == 0) { + flushService.setNeedFlush(getFullyQualifiedTableName(i)); + } + }); flushService.flush(false).get(); - Mockito.verify(flushService, Mockito.times(3)).distributeFlushTasks(); - Assert.assertTrue(flushService.lastFlushTime > 0); + Mockito.verify(flushService, Mockito.times(1)).distributeFlushTasks(Mockito.any()); + IntStream.range(0, numChannels) + .forEach( + i -> { + Assert.assertFalse(channelCache.getNeedFlush(getFullyQualifiedTableName(i))); + if (i % 2 == 0) { + Assert.assertNotEquals( + Long.MAX_VALUE, + channelCache.getLastFlushTime(getFullyQualifiedTableName(i)).longValue()); + } else { + Assert.assertEquals( + Long.MAX_VALUE, + channelCache.getLastFlushTime(getFullyQualifiedTableName(i)).longValue()); + } + }); + + // Test time based flush + IntStream.range(0, numChannels) + .forEach( + i -> { + channelCache.setLastFlushTime( + getFullyQualifiedTableName(i), i % 2 == 0 ? 
0L : Long.MAX_VALUE); + }); + flushService.flush(false).get(); + Mockito.verify(flushService, Mockito.times(2)).distributeFlushTasks(Mockito.any()); + IntStream.range(0, numChannels) + .forEach( + i -> { + Assert.assertFalse(channelCache.getNeedFlush(getFullyQualifiedTableName(i))); + if (i % 2 == 0) { + Assert.assertNotEquals( + 0L, channelCache.getLastFlushTime(getFullyQualifiedTableName(i)).longValue()); + } else { + Assert.assertEquals( + Long.MAX_VALUE, + channelCache.getLastFlushTime(getFullyQualifiedTableName(i)).longValue()); + } + }); } @Test @@ -1063,4 +1163,8 @@ private Timer setupTimer(long expectedLatencyMs) { return timer; } + + private String getFullyQualifiedTableName(int tableId) { + return String.format("db1.PUBLIC.table%d", tableId); + } }