From b2a47d8259b0d6671c82962f2ec4d5f69876896c Mon Sep 17 00:00:00 2001 From: owenhalpert Date: Mon, 13 Jan 2025 11:20:16 -0800 Subject: [PATCH] Add thread to periodically perform pending cache maintenance (#2308) Signed-off-by: owenhalpert --- CHANGELOG.md | 1 + .../memory/NativeMemoryCacheManager.java | 42 ++++++++++++++++ .../org/opensearch/knn/plugin/KNNPlugin.java | 4 ++ .../QuantizationStateCache.java | 49 ++++++++++++++++++- .../QuantizationStateCacheManager.java | 8 ++- .../opensearch/knn/KNNSingleNodeTestCase.java | 2 + .../java/org/opensearch/knn/KNNTestCase.java | 5 +- .../memory/NativeMemoryCacheManagerTests.java | 36 ++++++++++++++ .../QuantizationStateCacheTests.java | 46 ++++++++++++++++- 9 files changed, 189 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e199e76f9..639555b48 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Documentation ### Maintenance * Select index settings based on cluster version[2236](https://github.com/opensearch-project/k-NN/pull/2236) +* Added periodic cache maintenance for QuantizationStateCache and NativeMemoryCache [#2308](https://github.com/opensearch-project/k-NN/pull/2308) * Added null checks for fieldInfo in ExactSearcher to avoid NPE while running exact search for segments with no vector field (#2278)[https://github.com/opensearch-project/k-NN/pull/2278] * Added Lucene BWC tests (#2313)[https://github.com/opensearch-project/k-NN/pull/2313] * Upgrade jsonpath from 2.8.0 to 2.9.0[2325](https://github.com/opensearch-project/k-NN/pull/2325) diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java index b8aecc5a5..76e94ee66 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java @@ -16,6 +16,8 @@ import com.google.common.cache.CacheStats; import com.google.common.cache.RemovalCause; import com.google.common.cache.RemovalNotification; +import lombok.Getter; +import lombok.Setter; import org.apache.commons.lang.Validate; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -24,6 +26,8 @@ import org.opensearch.knn.common.featureflags.KNNFeatureFlags; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.plugin.stats.StatNames; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.threadpool.Scheduler.Cancellable; import java.io.Closeable; import java.util.Deque; @@ -47,12 +51,16 @@ public class NativeMemoryCacheManager implements Closeable { private static final Logger logger = LogManager.getLogger(NativeMemoryCacheManager.class); private static NativeMemoryCacheManager INSTANCE; + @Setter + private static ThreadPool threadPool; private Cache cache; private Deque accessRecencyQueue; private final ExecutorService executor; private AtomicBoolean cacheCapacityReached; private long maxWeight; + @Getter + private Cancellable maintenanceTask; NativeMemoryCacheManager() { this.executor = Executors.newSingleThreadExecutor(); @@ -104,6 +112,12 @@ private void initialize(NativeMemoryCacheManagerDto nativeMemoryCacheDTO) { cacheCapacityReached = new AtomicBoolean(false); accessRecencyQueue = new ConcurrentLinkedDeque<>(); cache = cacheBuilder.build(); + + if (threadPool != null) { + startMaintenance(cache); + } else { + logger.warn("ThreadPool is null during NativeMemoryCacheManager initialization. Maintenance will not start."); + } } /** @@ -142,6 +156,9 @@ public synchronized void rebuildCache(NativeMemoryCacheManagerDto nativeMemoryCa @Override public void close() { executor.shutdown(); + if (maintenanceTask != null) { + maintenanceTask.cancel(); + } } /** @@ -449,4 +466,29 @@ private Float getSizeAsPercentage(long size) { } return 100 * size / (float) cbLimit; } + + /** + * Starts the scheduled maintenance for the cache. Without this thread calling cleanUp(), the Guava cache only + * performs maintenance operations (such as evicting expired entries) when the cache is accessed. This + * ensures that the cache is also cleaned up based on the configured expiry time. + * @see Guava Cache Guide + * @param cacheInstance cache on which to call cleanUp() + */ + private void startMaintenance(Cache cacheInstance) { + if (maintenanceTask != null) { + maintenanceTask.cancel(); + } + + Runnable cleanUp = () -> { + try { + cacheInstance.cleanUp(); + } catch (Exception e) { + logger.error("Error cleaning up cache", e); + } + }; + + TimeValue interval = KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES); + + maintenanceTask = threadPool.scheduleWithFixedDelay(cleanUp, interval, ThreadPool.Names.MANAGEMENT); + } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index d27f502e1..7fb880f19 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -13,6 +13,7 @@ import org.opensearch.index.engine.EngineFactory; import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.knn.index.KNNCircuitBreaker; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.plugin.search.KNNConcurrentSearchRequestDecider; import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.index.query.KNNQueryBuilder; @@ -79,6 +80,7 @@ import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; import org.opensearch.knn.training.TrainingJobClusterStateListener; import org.opensearch.knn.training.TrainingJobRunner; import org.opensearch.knn.training.VectorReader; @@ -201,6 +203,8 @@ public Collection createComponents( ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance()); TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); + QuantizationStateCache.setThreadPool(threadPool); + NativeMemoryCacheManager.setThreadPool(threadPool); KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client); KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java index f057026b9..d2b99fef0 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java @@ -11,11 +11,15 @@ import com.google.common.cache.RemovalCause; import com.google.common.cache.RemovalNotification; import lombok.Getter; +import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.threadpool.Scheduler.Cancellable; +import org.opensearch.threadpool.ThreadPool; +import java.io.Closeable; import java.io.IOException; import java.time.Instant; import java.util.concurrent.TimeUnit; @@ -27,14 +31,18 @@ * A thread-safe singleton cache that contains quantization states. */ @Log4j2 -public class QuantizationStateCache { +public class QuantizationStateCache implements Closeable { private static volatile QuantizationStateCache instance; + @Setter + private static ThreadPool threadPool; private Cache cache; @Getter private long maxCacheSizeInKB; @Getter private Instant evictedDueToSizeAt; + @Getter + private Cancellable maintenanceTask; @VisibleForTesting QuantizationStateCache() { @@ -71,6 +79,37 @@ private void buildCache() { ) .removalListener(this::onRemoval) .build(); + + if (threadPool != null) { + startMaintenance(cache); + } else { + log.warn("ThreadPool is null during QuantizationStateCache initialization. Maintenance will not start."); + } + } + + /** + * Starts the scheduled maintenance for the cache. Without this thread calling cleanUp(), the Guava cache only + * performs maintenance operations (such as evicting expired entries) when the cache is accessed. This + * ensures that the cache is also cleaned up based on the configured expiry time. + * @see Guava Cache Guide + * @param cacheInstance cache on which to call cleanUp() + */ + private void startMaintenance(Cache cacheInstance) { + if (maintenanceTask != null) { + maintenanceTask.cancel(); + } + + Runnable cleanUp = () -> { + try { + cacheInstance.cleanUp(); + } catch (Exception e) { + log.error("Error cleaning up cache", e); + } + }; + + TimeValue interval = KNNSettings.state().getSettingValue(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES); + + maintenanceTask = threadPool.scheduleWithFixedDelay(cleanUp, interval, ThreadPool.Names.MANAGEMENT); } synchronized void rebuildCache() { @@ -129,4 +168,12 @@ private void updateEvictedDueToSizeAt() { public void clear() { cache.invalidateAll(); } + + @Override + public void close() throws IOException { + if (maintenanceTask != null) { + maintenanceTask.cancel(); + } + } + } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java index 932d5cde0..63282029a 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java @@ -9,10 +9,11 @@ import lombok.NoArgsConstructor; import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateReader; +import java.io.Closeable; import java.io.IOException; @NoArgsConstructor(access = AccessLevel.PRIVATE) -public final class QuantizationStateCacheManager { +public final class QuantizationStateCacheManager implements Closeable { private static volatile QuantizationStateCacheManager instance; @@ -79,4 +80,9 @@ public void setMaxCacheSizeInKB(long maxCacheSizeInKB) { public void clear() { QuantizationStateCache.getInstance().clear(); } + + @Override + public void close() throws IOException { + QuantizationStateCache.getInstance().close(); + } } diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index 41cd4e8a5..2ec9ce6b5 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -35,6 +35,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.IndexService; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.plugins.Plugin; import org.opensearch.core.rest.RestStatus; import org.opensearch.test.OpenSearchSingleNodeTestCase; @@ -86,6 +87,7 @@ protected boolean resetNodeAfterTest() { public void tearDown() throws Exception { NativeMemoryCacheManager.getInstance().invalidateAll(); NativeMemoryCacheManager.getInstance().close(); + QuantizationStateCacheManager.getInstance().close(); NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance().close(); NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance().close(); diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java index 21b3298be..376692f26 100644 --- a/src/test/java/org/opensearch/knn/KNNTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNTestCase.java @@ -24,8 +24,10 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.Collections; import java.util.HashSet; import java.util.Map; @@ -73,7 +75,7 @@ protected boolean enableWarningsCheck() { return false; } - public void resetState() { + public void resetState() throws IOException { // Reset all of the counters for (KNNCounter knnCounter : KNNCounter.values()) { knnCounter.set(0L); @@ -83,6 +85,7 @@ public void resetState() { // Clean up the cache NativeMemoryCacheManager.getInstance().invalidateAll(); NativeMemoryCacheManager.getInstance().close(); + QuantizationStateCacheManager.getInstance().close(); } private void initKNNSettings() { diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java index 5fe41c88c..8a46a781e 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java @@ -12,6 +12,8 @@ package org.opensearch.knn.index.memory; import com.google.common.cache.CacheStats; +import org.junit.After; +import org.junit.Before; import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; import org.opensearch.common.settings.Settings; import org.opensearch.knn.common.exception.OutOfNativeMemoryException; @@ -20,6 +22,8 @@ import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.opensearch.threadpool.Scheduler.Cancellable; +import org.opensearch.threadpool.ThreadPool; import java.io.IOException; import java.util.Collection; @@ -34,6 +38,21 @@ public class NativeMemoryCacheManagerTests extends OpenSearchSingleNodeTestCase { + private ThreadPool threadPool; + + @Before + public void setUp() throws Exception { + super.setUp(); + threadPool = new ThreadPool(Settings.builder().put("node.name", "NativeMemoryCacheManagerTests").build()); + NativeMemoryCacheManager.setThreadPool(threadPool); + } + + @After + public void shutdown() throws Exception { + super.tearDown(); + terminate(threadPool); + } + @Override public void tearDown() throws Exception { // Clear out persistent metadata @@ -41,6 +60,7 @@ public void tearDown() throws Exception { Settings circuitBreakerSettings = Settings.builder().putNull(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED).build(); clusterUpdateSettingsRequest.persistentSettings(circuitBreakerSettings); client().admin().cluster().updateSettings(clusterUpdateSettingsRequest).get(); + NativeMemoryCacheManager.getInstance().close(); super.tearDown(); } @@ -51,6 +71,8 @@ protected Collection> getPlugins() { public void testRebuildCache() throws ExecutionException, InterruptedException { NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); + Cancellable task1 = nativeMemoryCacheManager.getMaintenanceTask(); + assertNotNull(task1); // Put entry in cache and check that the weight matches int size = 10; @@ -65,6 +87,9 @@ public void testRebuildCache() throws ExecutionException, InterruptedException { // Sleep for a second or two so that the executor can invalidate all entries Thread.sleep(2000); + assertTrue(task1.isCancelled()); + assertNotNull(nativeMemoryCacheManager.getMaintenanceTask()); + assertEquals(0, nativeMemoryCacheManager.getCacheSizeInKilobytes()); nativeMemoryCacheManager.close(); } @@ -378,6 +403,7 @@ public void testCacheCapacity() { nativeMemoryCacheManager.setCacheCapacityReached(false); assertFalse(nativeMemoryCacheManager.isCacheCapacityReached()); + nativeMemoryCacheManager.close(); } public void testGetIndicesCacheStats() throws IOException, ExecutionException { @@ -464,6 +490,16 @@ public void testGetIndicesCacheStats() throws IOException, ExecutionException { nativeMemoryCacheManager.close(); } + public void testMaintenanceScheduled() { + NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); + Cancellable maintenanceTask = nativeMemoryCacheManager.getMaintenanceTask(); + + assertNotNull(maintenanceTask); + + nativeMemoryCacheManager.close(); + assertTrue(maintenanceTask.isCancelled()); + } + private static class TestNativeMemoryAllocation implements NativeMemoryAllocation { int size; diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java index e5381aec7..1a3c56e9a 100644 --- a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java +++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java @@ -7,6 +7,8 @@ import com.google.common.collect.ImmutableSet; import lombok.SneakyThrows; +import org.junit.After; +import org.junit.Before; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -15,7 +17,10 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.ThreadPool; +import java.io.IOException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -28,6 +33,21 @@ public class QuantizationStateCacheTests extends KNNTestCase { + private ThreadPool threadPool; + + @Before + public void setUp() throws Exception { + super.setUp(); + threadPool = new ThreadPool(Settings.builder().put("node.name", "QuantizationStateCacheTests").build()); + QuantizationStateCache.setThreadPool(threadPool); + } + + @After + public void shutdown() throws Exception { + super.tearDown(); + terminate(threadPool); + } + @SneakyThrows public void testSingleThreadedAddAndRetrieve() { String fieldName = "singleThreadField"; @@ -417,7 +437,7 @@ public void testRebuildOnTimeExpirySettingsChange() { assertNull("State should be null", retrievedState); } - public void testCacheEvictionDueToSize() { + public void testCacheEvictionDueToSize() throws IOException { String fieldName = "evictionField"; // States have size of slightly over 500 bytes so that adding two will reach the max size of 1 kb for the cache int arrayLength = 112; @@ -445,6 +465,30 @@ public void testCacheEvictionDueToSize() { cache.addQuantizationState(fieldName, state); cache.addQuantizationState(fieldName, state2); cache.clear(); + cache.close(); assertNotNull(cache.getEvictedDueToSizeAt()); } + + public void testMaintenanceScheduled() throws Exception { + QuantizationStateCache quantizationStateCache = new QuantizationStateCache(); + Scheduler.Cancellable maintenanceTask = quantizationStateCache.getMaintenanceTask(); + + assertNotNull(maintenanceTask); + + quantizationStateCache.close(); + assertTrue(maintenanceTask.isCancelled()); + } + + public void testMaintenanceWithRebuild() throws Exception { + QuantizationStateCache quantizationStateCache = new QuantizationStateCache(); + Scheduler.Cancellable task1 = quantizationStateCache.getMaintenanceTask(); + assertNotNull(task1); + + quantizationStateCache.rebuildCache(); + + Scheduler.Cancellable task2 = quantizationStateCache.getMaintenanceTask(); + assertTrue(task1.isCancelled()); + assertNotNull(task2); + quantizationStateCache.close(); + } }