Skip to content

Commit

Permalink
Add thread to periodically perform pending cache maintenance (opensea…
Browse files Browse the repository at this point in the history
…rch-project#2308)

Signed-off-by: owenhalpert <[email protected]>
  • Loading branch information
owenhalpert authored Jan 13, 2025
1 parent e2cd03e commit b2a47d8
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
### Maintenance
* Select index settings based on cluster version[2236](https://github.com/opensearch-project/k-NN/pull/2236)
* Added periodic cache maintenance for QuantizationStateCache and NativeMemoryCache [#2308](https://github.com/opensearch-project/k-NN/pull/2308)
* Added null checks for fieldInfo in ExactSearcher to avoid NPE while running exact search for segments with no vector field (#2278)[https://github.com/opensearch-project/k-NN/pull/2278]
* Added Lucene BWC tests (#2313)[https://github.com/opensearch-project/k-NN/pull/2313]
* Upgrade jsonpath from 2.8.0 to 2.9.0[2325](https://github.com/opensearch-project/k-NN/pull/2325)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import com.google.common.cache.CacheStats;
import com.google.common.cache.RemovalCause;
import com.google.common.cache.RemovalNotification;
import lombok.Getter;
import lombok.Setter;
import org.apache.commons.lang.Validate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -24,6 +26,8 @@
import org.opensearch.knn.common.featureflags.KNNFeatureFlags;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.plugin.stats.StatNames;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.threadpool.Scheduler.Cancellable;

import java.io.Closeable;
import java.util.Deque;
Expand All @@ -47,12 +51,16 @@ public class NativeMemoryCacheManager implements Closeable {

private static final Logger logger = LogManager.getLogger(NativeMemoryCacheManager.class);
private static NativeMemoryCacheManager INSTANCE;
@Setter
private static ThreadPool threadPool;

private Cache<String, NativeMemoryAllocation> cache;
private Deque<String> accessRecencyQueue;
private final ExecutorService executor;
private AtomicBoolean cacheCapacityReached;
private long maxWeight;
@Getter
private Cancellable maintenanceTask;

NativeMemoryCacheManager() {
this.executor = Executors.newSingleThreadExecutor();
Expand Down Expand Up @@ -104,6 +112,12 @@ private void initialize(NativeMemoryCacheManagerDto nativeMemoryCacheDTO) {
cacheCapacityReached = new AtomicBoolean(false);
accessRecencyQueue = new ConcurrentLinkedDeque<>();
cache = cacheBuilder.build();

if (threadPool != null) {
startMaintenance(cache);
} else {
logger.warn("ThreadPool is null during NativeMemoryCacheManager initialization. Maintenance will not start.");
}
}

/**
Expand Down Expand Up @@ -142,6 +156,9 @@ public synchronized void rebuildCache(NativeMemoryCacheManagerDto nativeMemoryCa
@Override
public void close() {
executor.shutdown();
if (maintenanceTask != null) {
maintenanceTask.cancel();
}
}

/**
Expand Down Expand Up @@ -449,4 +466,29 @@ private Float getSizeAsPercentage(long size) {
}
return 100 * size / (float) cbLimit;
}

/**
* Starts the scheduled maintenance for the cache. Without this thread calling cleanUp(), the Guava cache only
* performs maintenance operations (such as evicting expired entries) when the cache is accessed. This
* ensures that the cache is also cleaned up based on the configured expiry time.
* @see <a href="https://github.com/google/guava/wiki/cachesexplained#timed-eviction"> Guava Cache Guide</a>
* @param cacheInstance cache on which to call cleanUp()
*/
private void startMaintenance(Cache<String, NativeMemoryAllocation> cacheInstance) {
if (maintenanceTask != null) {
maintenanceTask.cancel();
}

Runnable cleanUp = () -> {
try {
cacheInstance.cleanUp();
} catch (Exception e) {
logger.error("Error cleaning up cache", e);
}
};

TimeValue interval = KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES);

maintenanceTask = threadPool.scheduleWithFixedDelay(cleanUp, interval, ThreadPool.Names.MANAGEMENT);
}
}
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.index.engine.EngineFactory;
import org.opensearch.indices.SystemIndexDescriptor;
import org.opensearch.knn.index.KNNCircuitBreaker;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.plugin.search.KNNConcurrentSearchRequestDecider;
import org.opensearch.knn.index.util.KNNClusterUtil;
import org.opensearch.knn.index.query.KNNQueryBuilder;
Expand Down Expand Up @@ -79,6 +80,7 @@
import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache;
import org.opensearch.knn.training.TrainingJobClusterStateListener;
import org.opensearch.knn.training.TrainingJobRunner;
import org.opensearch.knn.training.VectorReader;
Expand Down Expand Up @@ -201,6 +203,8 @@ public Collection<Object> createComponents(
ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
QuantizationStateCache.setThreadPool(threadPool);
NativeMemoryCacheManager.setThreadPool(threadPool);
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@
import com.google.common.cache.RemovalCause;
import com.google.common.cache.RemovalNotification;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.threadpool.Scheduler.Cancellable;
import org.opensearch.threadpool.ThreadPool;

import java.io.Closeable;
import java.io.IOException;
import java.time.Instant;
import java.util.concurrent.TimeUnit;
Expand All @@ -27,14 +31,18 @@
* A thread-safe singleton cache that contains quantization states.
*/
@Log4j2
public class QuantizationStateCache {
public class QuantizationStateCache implements Closeable {

private static volatile QuantizationStateCache instance;
@Setter
private static ThreadPool threadPool;
private Cache<String, QuantizationState> cache;
@Getter
private long maxCacheSizeInKB;
@Getter
private Instant evictedDueToSizeAt;
@Getter
private Cancellable maintenanceTask;

@VisibleForTesting
QuantizationStateCache() {
Expand Down Expand Up @@ -71,6 +79,37 @@ private void buildCache() {
)
.removalListener(this::onRemoval)
.build();

if (threadPool != null) {
startMaintenance(cache);
} else {
log.warn("ThreadPool is null during QuantizationStateCache initialization. Maintenance will not start.");
}
}

/**
* Starts the scheduled maintenance for the cache. Without this thread calling cleanUp(), the Guava cache only
* performs maintenance operations (such as evicting expired entries) when the cache is accessed. This
* ensures that the cache is also cleaned up based on the configured expiry time.
* @see <a href="https://github.com/google/guava/wiki/cachesexplained#timed-eviction"> Guava Cache Guide</a>
* @param cacheInstance cache on which to call cleanUp()
*/
private void startMaintenance(Cache<String, QuantizationState> cacheInstance) {
if (maintenanceTask != null) {
maintenanceTask.cancel();
}

Runnable cleanUp = () -> {
try {
cacheInstance.cleanUp();
} catch (Exception e) {
log.error("Error cleaning up cache", e);
}
};

TimeValue interval = KNNSettings.state().getSettingValue(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES);

maintenanceTask = threadPool.scheduleWithFixedDelay(cleanUp, interval, ThreadPool.Names.MANAGEMENT);
}

synchronized void rebuildCache() {
Expand Down Expand Up @@ -129,4 +168,12 @@ private void updateEvictedDueToSizeAt() {
public void clear() {
cache.invalidateAll();
}

@Override
public void close() throws IOException {
if (maintenanceTask != null) {
maintenanceTask.cancel();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import lombok.NoArgsConstructor;
import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateReader;

import java.io.Closeable;
import java.io.IOException;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class QuantizationStateCacheManager {
public final class QuantizationStateCacheManager implements Closeable {

private static volatile QuantizationStateCacheManager instance;

Expand Down Expand Up @@ -79,4 +80,9 @@ public void setMaxCacheSizeInKB(long maxCacheSizeInKB) {
public void clear() {
QuantizationStateCache.getInstance().clear();
}

@Override
public void close() throws IOException {
QuantizationStateCache.getInstance().close();
}
}
2 changes: 2 additions & 0 deletions src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.index.IndexService;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.plugins.Plugin;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.test.OpenSearchSingleNodeTestCase;
Expand Down Expand Up @@ -86,6 +87,7 @@ protected boolean resetNodeAfterTest() {
public void tearDown() throws Exception {
NativeMemoryCacheManager.getInstance().invalidateAll();
NativeMemoryCacheManager.getInstance().close();
QuantizationStateCacheManager.getInstance().close();
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close();
NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance().close();
NativeMemoryLoadStrategy.AnonymousLoadStrategy.getInstance().close();
Expand Down
5 changes: 4 additions & 1 deletion src/test/java/org/opensearch/knn/KNNTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
Expand Down Expand Up @@ -73,7 +75,7 @@ protected boolean enableWarningsCheck() {
return false;
}

public void resetState() {
public void resetState() throws IOException {
// Reset all of the counters
for (KNNCounter knnCounter : KNNCounter.values()) {
knnCounter.set(0L);
Expand All @@ -83,6 +85,7 @@ public void resetState() {
// Clean up the cache
NativeMemoryCacheManager.getInstance().invalidateAll();
NativeMemoryCacheManager.getInstance().close();
QuantizationStateCacheManager.getInstance().close();
}

private void initKNNSettings() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
package org.opensearch.knn.index.memory;

import com.google.common.cache.CacheStats;
import org.junit.After;
import org.junit.Before;
import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest;
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.common.exception.OutOfNativeMemoryException;
Expand All @@ -20,6 +22,8 @@
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.test.OpenSearchSingleNodeTestCase;
import org.opensearch.threadpool.Scheduler.Cancellable;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.util.Collection;
Expand All @@ -34,13 +38,29 @@

public class NativeMemoryCacheManagerTests extends OpenSearchSingleNodeTestCase {

private ThreadPool threadPool;

@Before
public void setUp() throws Exception {
super.setUp();
threadPool = new ThreadPool(Settings.builder().put("node.name", "NativeMemoryCacheManagerTests").build());
NativeMemoryCacheManager.setThreadPool(threadPool);
}

@After
public void shutdown() throws Exception {
super.tearDown();
terminate(threadPool);
}

@Override
public void tearDown() throws Exception {
// Clear out persistent metadata
ClusterUpdateSettingsRequest clusterUpdateSettingsRequest = new ClusterUpdateSettingsRequest();
Settings circuitBreakerSettings = Settings.builder().putNull(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED).build();
clusterUpdateSettingsRequest.persistentSettings(circuitBreakerSettings);
client().admin().cluster().updateSettings(clusterUpdateSettingsRequest).get();
NativeMemoryCacheManager.getInstance().close();
super.tearDown();
}

Expand All @@ -51,6 +71,8 @@ protected Collection<Class<? extends Plugin>> getPlugins() {

public void testRebuildCache() throws ExecutionException, InterruptedException {
NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager();
Cancellable task1 = nativeMemoryCacheManager.getMaintenanceTask();
assertNotNull(task1);

// Put entry in cache and check that the weight matches
int size = 10;
Expand All @@ -65,6 +87,9 @@ public void testRebuildCache() throws ExecutionException, InterruptedException {
// Sleep for a second or two so that the executor can invalidate all entries
Thread.sleep(2000);

assertTrue(task1.isCancelled());
assertNotNull(nativeMemoryCacheManager.getMaintenanceTask());

assertEquals(0, nativeMemoryCacheManager.getCacheSizeInKilobytes());
nativeMemoryCacheManager.close();
}
Expand Down Expand Up @@ -378,6 +403,7 @@ public void testCacheCapacity() {

nativeMemoryCacheManager.setCacheCapacityReached(false);
assertFalse(nativeMemoryCacheManager.isCacheCapacityReached());
nativeMemoryCacheManager.close();
}

public void testGetIndicesCacheStats() throws IOException, ExecutionException {
Expand Down Expand Up @@ -464,6 +490,16 @@ public void testGetIndicesCacheStats() throws IOException, ExecutionException {
nativeMemoryCacheManager.close();
}

public void testMaintenanceScheduled() {
NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager();
Cancellable maintenanceTask = nativeMemoryCacheManager.getMaintenanceTask();

assertNotNull(maintenanceTask);

nativeMemoryCacheManager.close();
assertTrue(maintenanceTask.isCancelled());
}

private static class TestNativeMemoryAllocation implements NativeMemoryAllocation {

int size;
Expand Down
Loading

0 comments on commit b2a47d8

Please sign in to comment.