diff --git a/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCache.java b/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCache.java index f40c35dde83de..40c6acf1dc242 100644 --- a/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCache.java +++ b/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCache.java @@ -77,9 +77,13 @@ public class TieredSpilloverCache implements ICache { private final TieredSpilloverCacheStatsHolder statsHolder; private ToLongBiFunction, V> weigher; private final List dimensionNames; - ReadWriteLock readWriteLock = new ReentrantReadWriteLock(); - ReleasableLock readLock = new ReleasableLock(readWriteLock.readLock()); - ReleasableLock writeLock = new ReleasableLock(readWriteLock.writeLock()); + + // The locks ensure keys can't end up in both the heap and disk tier simultaneously. + // For performance, we have several locks, which effectively segment the cache by key.hashCode(). + private final int NUM_LOCKS = 256; + private final ReleasableLock[] readLocks; + private final ReleasableLock[] writeLocks; + /** * Maintains caching tiers in ascending order of cache latency. 
*/ @@ -139,6 +143,15 @@ public class TieredSpilloverCache implements ICache { this.policies = builder.policies; // Will never be null; builder initializes it to an empty list builder.cacheConfig.getClusterSettings() .addSettingsUpdateConsumer(DISK_CACHE_ENABLED_SETTING_MAP.get(builder.cacheType), this::enableDisableDiskCache); + + ReadWriteLock[] locks = new ReadWriteLock[NUM_LOCKS]; + this.readLocks = new ReleasableLock[NUM_LOCKS]; + this.writeLocks = new ReleasableLock[NUM_LOCKS]; + for (int i = 0; i < NUM_LOCKS; i++) { + locks[i] = new ReentrantReadWriteLock(); + readLocks[i] = new ReleasableLock(locks[i].readLock()); + writeLocks[i] = new ReleasableLock(locks[i].writeLock()); + } } // Package private for testing @@ -170,7 +183,7 @@ public V get(ICacheKey key) { @Override public void put(ICacheKey key, V value) { - try (ReleasableLock ignore = writeLock.acquire()) { + try (ReleasableLock ignore = getWriteLockForKey(key).acquire()) { onHeapCache.put(key, value); updateStatsOnPut(TIER_DIMENSION_VALUE_ON_HEAP, key, value); } @@ -191,7 +204,7 @@ public V computeIfAbsent(ICacheKey key, LoadAwareCacheLoader, V> // This is needed as there can be many requests for the same key at the same time and we only want to load // the value once. 
V value = null; - try (ReleasableLock ignore = writeLock.acquire()) { + try (ReleasableLock ignore = getWriteLockForKey(key).acquire()) { value = onHeapCache.computeIfAbsent(key, loader); } // Handle stats @@ -234,7 +247,7 @@ public void invalidate(ICacheKey key) { statsHolder.removeDimensions(dimensionValues); } if (key.key != null) { - try (ReleasableLock ignore = writeLock.acquire()) { + try (ReleasableLock ignore = getWriteLockForKey(key).acquire()) { cacheEntry.getKey().invalidate(key); } } @@ -243,10 +256,20 @@ public void invalidate(ICacheKey key) { @Override public void invalidateAll() { - try (ReleasableLock ignore = writeLock.acquire()) { + // For cache-wide operations like refresh() and invalidateAll(), all the segment locks must be acquired. + // To avoid possible deadlock if they run at the same time, they acquire the locks in index order and + // release them in reverse order. + try { + for (int i = 0; i < NUM_LOCKS; i++) { + writeLocks[i].acquire(); + } for (Map.Entry, TierInfo> cacheEntry : caches.entrySet()) { cacheEntry.getKey().invalidateAll(); } + } finally { + for (int i = NUM_LOCKS - 1; i >= 0; i--) { + writeLocks[i].close(); + } } statsHolder.reset(); } @@ -275,10 +298,18 @@ public long count() { @Override public void refresh() { - try (ReleasableLock ignore = writeLock.acquire()) { + try { + // Acquire the locks in index order + for (int i = 0; i < NUM_LOCKS; i++) { + writeLocks[i].acquire(); + } for (Map.Entry, TierInfo> cacheEntry : caches.entrySet()) { cacheEntry.getKey().refresh(); } + } finally { + for (int i = NUM_LOCKS - 1; i >= 0; i--) { + writeLocks[i].close(); + } } } @@ -302,7 +333,7 @@ public ImmutableCacheStatsHolder stats(String[] levels) { */ private Function, Tuple> getValueFromTieredCache(boolean captureStats) { return key -> { - try (ReleasableLock ignore = readLock.acquire()) { + try (ReleasableLock ignore = getReadLockForKey(key).acquire()) { for (Map.Entry, TierInfo> cacheEntry : caches.entrySet()) { if 
(cacheEntry.getValue().isEnabled()) { V value = cacheEntry.getKey().get(key); @@ -328,7 +359,7 @@ void handleRemovalFromHeapTier(RemovalNotification, V> notification ICacheKey key = notification.getKey(); boolean wasEvicted = SPILLOVER_REMOVAL_REASONS.contains(notification.getRemovalReason()); if (caches.get(diskCache).isEnabled() && wasEvicted && evaluatePolicies(notification.getValue())) { - try (ReleasableLock ignore = writeLock.acquire()) { + try (ReleasableLock ignore = getWriteLockForKey(key).acquire()) { diskCache.put(key, notification.getValue()); // spill over to the disk tier and increment its stats } updateStatsOnPut(TIER_DIMENSION_VALUE_DISK, key, notification.getValue()); @@ -371,6 +402,22 @@ boolean evaluatePolicies(V value) { return true; } + private int getLockIndexForKey(ICacheKey key) { + // Since OpenSearchOnHeapCache also uses segments based on the least significant byte of the key's hash code + // (key.hashCode() & 0xff), we use the second-least significant byte. This way, if two keys face lock contention + // in the TieredSpilloverCache's locks, they will be unlikely to also face lock contention in OpenSearchOnHeapCache. + // This should help worst-case (p100) latency. + return (key.hashCode() & 0xff00) >> 8; + } + + private ReleasableLock getReadLockForKey(ICacheKey key) { + return readLocks[getLockIndexForKey(key)]; + } + + private ReleasableLock getWriteLockForKey(ICacheKey key) { + return writeLocks[getLockIndexForKey(key)]; + } + /** * A class which receives removal events from the heap tier. 
*/ diff --git a/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheTests.java b/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheTests.java index 6c49341591589..9e437e956618a 100644 --- a/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheTests.java +++ b/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheTests.java @@ -1323,6 +1323,86 @@ public void testTierStatsAddCorrectly() throws Exception { } + public void testGlobalOperationsDoNotCauseDeadlock() throws Exception { + // Confirm refresh() and invalidateAll(), which both require all segment locks, don't cause deadlock if run concurrently + int numEntries = 250; + int onHeapCacheSize = randomIntBetween(10, 30); + int diskCacheSize = randomIntBetween(numEntries, numEntries + 100); + int keyValueSize = 50; + MockCacheRemovalListener removalListener = new MockCacheRemovalListener<>(); + TieredSpilloverCache tieredSpilloverCache = initializeTieredSpilloverCache( + keyValueSize, + diskCacheSize, + removalListener, + Settings.builder() + .put( + OpenSearchOnHeapCacheSettings.getSettingListForCacheType(CacheType.INDICES_REQUEST_CACHE) + .get(MAXIMUM_SIZE_IN_BYTES_KEY) + .getKey(), + onHeapCacheSize * keyValueSize + "b" + ) + .build(), + 0 + ); + + // First try refresh() and then invalidateAll() + // Put some values in the cache + for (int i = 0; i < numEntries; i++) { + tieredSpilloverCache.computeIfAbsent(getICacheKey(UUID.randomUUID().toString()), getLoadAwareCacheLoader()); + } + assertEquals(numEntries, tieredSpilloverCache.count()); + + Phaser phaser = new Phaser(3); + CountDownLatch countDownLatch = new CountDownLatch(2); + Thread refreshThread = new Thread(() -> { + phaser.arriveAndAwaitAdvance(); + tieredSpilloverCache.refresh(); + countDownLatch.countDown(); + }); + Thread invalidateThread = new Thread(() -> { + phaser.arriveAndAwaitAdvance(); + 
tieredSpilloverCache.invalidateAll(); + countDownLatch.countDown(); + }); + + refreshThread.start(); + invalidateThread.start(); + phaser.arriveAndAwaitAdvance(); + countDownLatch.await(); + + // This should terminate and we should see an empty cache + assertEquals(0, tieredSpilloverCache.count()); + + // Do it again, running invalidateAll() first and then refresh() + for (int i = 0; i < numEntries; i++) { + tieredSpilloverCache.computeIfAbsent(getICacheKey(UUID.randomUUID().toString()), getLoadAwareCacheLoader()); + } + // By successfully adding values back to the cache we show the locks were correctly released by the previous cache-wide operations + assertEquals(numEntries, tieredSpilloverCache.count()); + + Phaser secondPhaser = new Phaser(3); + CountDownLatch secondCountDownLatch = new CountDownLatch(2); + refreshThread = new Thread(() -> { + secondPhaser.arriveAndAwaitAdvance(); + tieredSpilloverCache.refresh(); + secondCountDownLatch.countDown(); + }); + invalidateThread = new Thread(() -> { + secondPhaser.arriveAndAwaitAdvance(); + tieredSpilloverCache.invalidateAll(); + secondCountDownLatch.countDown(); + }); + + invalidateThread.start(); + refreshThread.start(); + + secondPhaser.arriveAndAwaitAdvance(); + secondCountDownLatch.await(); + + // This should terminate and we should see an empty cache + assertEquals(0, tieredSpilloverCache.count()); + } + private List getMockDimensions() { List dims = new ArrayList<>(); for (String dimensionName : dimensionNames) {