From f2e2a85e948d486d9041b9642c85ebe4614d7bb4 Mon Sep 17 00:00:00 2001 From: Niyati Aggarwal <121826855+niyatiagg@users.noreply.github.com> Date: Fri, 12 Apr 2024 11:02:06 -0700 Subject: [PATCH 1/4] Refactoring globMatch using simpleMatchWithNormalizedStrings from Regex (#13104) * Refactoring globMatch using simpleMatchWithNormalizedStrings from Regex Signed-off-by: Niyati Aggarwal * Adding entry to CHANGELOG.md Signed-off-by: Niyati Aggarwal * Adding tests for GlobMatch Signed-off-by: Niyati Aggarwal * Moving entry to Changed section in CHANGELOG.md Signed-off-by: Niyati Aggarwal --------- Signed-off-by: Niyati Aggarwal --- CHANGELOG.md | 1 + .../main/java/org/opensearch/common/Glob.java | 53 ++++++++------- .../org/opensearch/common/regex/Regex.java | 35 +--------- .../java/org/opensearch/common/GlobTests.java | 67 +++++++++++++++++++ 4 files changed, 97 insertions(+), 59 deletions(-) create mode 100644 server/src/test/java/org/opensearch/common/GlobTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index f8b3e10a37a37..eef437da4e1fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [BWC and API enforcement] Enforcing the presence of API annotations at build time ([#12872](https://github.com/opensearch-project/OpenSearch/pull/12872)) - Improve built-in secure transports support ([#12907](https://github.com/opensearch-project/OpenSearch/pull/12907)) - Update links to documentation in rest-api-spec ([#13043](https://github.com/opensearch-project/OpenSearch/pull/13043)) +- Refactoring globMatch using simpleMatchWithNormalizedStrings from Regex ([#13104](https://github.com/opensearch-project/OpenSearch/pull/13104)) ### Deprecated diff --git a/libs/common/src/main/java/org/opensearch/common/Glob.java b/libs/common/src/main/java/org/opensearch/common/Glob.java index daf045dd49e3a..b390a3ca84182 100644 --- a/libs/common/src/main/java/org/opensearch/common/Glob.java +++ b/libs/common/src/main/java/org/opensearch/common/Glob.java @@ -52,34 +52,35 @@ public static boolean globMatch(String pattern, String str) { if (pattern == null || str == null) { return false; } - int firstIndex = pattern.indexOf('*'); - if (firstIndex == -1) { - return pattern.equals(str); - } - if (firstIndex == 0) { - if (pattern.length() == 1) { - return true; - } - int nextIndex = pattern.indexOf('*', firstIndex + 1); - if (nextIndex == -1) { - return str.endsWith(pattern.substring(1)); - } else if (nextIndex == 1) { - // Double wildcard "**" - skipping the first "*" - return globMatch(pattern.substring(1), str); + int sIdx = 0, pIdx = 0, match = 0, wildcardIdx = -1; + while (sIdx < str.length()) { + // both chars matching, incrementing both pointers + if (pIdx < pattern.length() && str.charAt(sIdx) == pattern.charAt(pIdx)) { + sIdx++; + pIdx++; + } else if (pIdx < pattern.length() && pattern.charAt(pIdx) == '*') { + // wildcard found, only incrementing pattern pointer + wildcardIdx = pIdx; + match = sIdx; + pIdx++; + } else if (wildcardIdx != -1) { + // last pattern pointer was a wildcard, incrementing string pointer + pIdx = wildcardIdx + 1; + match++; + sIdx = match; + } else { + // current pattern pointer is not a wildcard, last pattern pointer was also not a wildcard + // characters do not match + return false; } - String part = pattern.substring(1, nextIndex); - int partIndex = str.indexOf(part); - while (partIndex != -1) { - if (globMatch(pattern.substring(nextIndex), str.substring(partIndex + part.length()))) { - return true; - } - partIndex = str.indexOf(part, partIndex + 1); - } - return false; } - return (str.length() >= firstIndex - && pattern.substring(0, firstIndex).equals(str.substring(0, firstIndex)) - && globMatch(pattern.substring(firstIndex), str.substring(firstIndex))); + + // check for remaining characters in pattern + while (pIdx < pattern.length() && pattern.charAt(pIdx) == '*') { + pIdx++; + } + + return pIdx == pattern.length(); } } diff --git a/server/src/main/java/org/opensearch/common/regex/Regex.java b/server/src/main/java/org/opensearch/common/regex/Regex.java index 323b460af62df..6d8b5c3585c4c 100644 --- a/server/src/main/java/org/opensearch/common/regex/Regex.java +++ b/server/src/main/java/org/opensearch/common/regex/Regex.java @@ -35,6 +35,7 @@ import org.apache.lucene.util.automaton.Automata; import org.apache.lucene.util.automaton.Automaton; import org.apache.lucene.util.automaton.Operations; +import org.opensearch.common.Glob; import org.opensearch.core.common.Strings; import java.util.ArrayList; @@ -125,39 +126,7 @@ public static boolean simpleMatch(String pattern, String str, boolean caseInsens pattern = Strings.toLowercaseAscii(pattern); str = Strings.toLowercaseAscii(str); } - return simpleMatchWithNormalizedStrings(pattern, str); - } - - private static boolean simpleMatchWithNormalizedStrings(String pattern, String str) { - int sIdx = 0, pIdx = 0, match = 0, wildcardIdx = -1; - while (sIdx < str.length()) { - // both chars matching, incrementing both pointers - if (pIdx < pattern.length() && str.charAt(sIdx) == pattern.charAt(pIdx)) { - sIdx++; - pIdx++; - } else if (pIdx < pattern.length() && pattern.charAt(pIdx) == '*') { - // wildcard found, only incrementing pattern pointer - wildcardIdx = pIdx; - match = sIdx; - pIdx++; - } else if (wildcardIdx != -1) { - // last pattern pointer was a wildcard, incrementing string pointer - pIdx = wildcardIdx + 1; - match++; - sIdx = match; - } else { - // current pattern pointer is not a wildcard, last pattern pointer was also not a wildcard - // characters do not match - return false; - } - } - - // check for remaining characters in pattern - while (pIdx < pattern.length() && pattern.charAt(pIdx) == '*') { - pIdx++; - } - - return pIdx == pattern.length(); + return Glob.globMatch(pattern, str); } /** diff --git a/server/src/test/java/org/opensearch/common/GlobTests.java b/server/src/test/java/org/opensearch/common/GlobTests.java new file mode 100644 index 0000000000000..2bbe157be43cc --- /dev/null +++ b/server/src/test/java/org/opensearch/common/GlobTests.java @@ -0,0 +1,67 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common; + +import org.opensearch.test.OpenSearchTestCase; + +public class GlobTests extends OpenSearchTestCase { + + public void testGlobMatchForNull() { + assertFalse(Glob.globMatch(null, "test")); + assertFalse(Glob.globMatch("test", null)); + assertFalse(Glob.globMatch(null, null)); + } + + public void testGlobMatchNoWildcard() { + assertTrue(Glob.globMatch("abcd", "abcd")); + assertFalse(Glob.globMatch("abcd", "foobar")); + } + + public void testGlobMatchSingleWildcard() { + assertTrue(Glob.globMatch("*foo", "barfoo")); + assertFalse(Glob.globMatch("*foo", "foobar")); + assertTrue(Glob.globMatch("foo*", "foobarfoo")); + assertFalse(Glob.globMatch("foo*", "barfoobar")); + assertTrue(Glob.globMatch("foo*bar", "foobarnfoosbar")); + } + + public void testGlobMatchMultipleWildcards() { + assertTrue(Glob.globMatch("*foo*", "barfoobar")); + assertFalse(Glob.globMatch("*foo*", "baroofbar")); + assertTrue(Glob.globMatch("*foo*bar", "abcdfooefghbar")); + assertFalse(Glob.globMatch("*foo*bar", "foonotbars")); + } + + public void testGlobalMatchDoubleWildcard() { + assertTrue(Glob.globMatch("**foo", "barbarfoo")); + assertFalse(Glob.globMatch("**foo", "barbarfoowoof")); + assertTrue(Glob.globMatch("**bar**", "foobarfoo")); + assertFalse(Glob.globMatch("**bar**", "foobanfoo")); + } + + public void testGlobMatchMultipleCharactersWithSingleWildcard() { + assertTrue(Glob.globMatch("a*b", "acb")); + assertTrue(Glob.globMatch("f*oo", "foo")); + assertTrue(Glob.globMatch("a*b", "aab")); + assertTrue(Glob.globMatch("a*b", "aaab")); + } + + public void testGlobMatchWildcardWithEmptyString() { + assertTrue(Glob.globMatch("*", "")); + assertTrue(Glob.globMatch("a*", "a")); + assertFalse(Glob.globMatch("a*", "")); + } + + public void testGlobMatchMultipleWildcardsWithMultipleCharacters() { + assertTrue(Glob.globMatch("a*b*c", "abc")); + assertTrue(Glob.globMatch("a*b*c", "axxxbxbc")); + assertFalse(Glob.globMatch("a*b*c", "abca")); + assertFalse(Glob.globMatch("a*b*c", "ac")); + } +} From cc22310145309b3f3dfca377a84f4e26579611be Mon Sep 17 00:00:00 2001 From: peteralfonsi Date: Fri, 12 Apr 2024 16:09:06 -0700 Subject: [PATCH 2/4] [Tiered Caching] Stats rework (1/3): Interfaces and implementations for individual tiers (#12531) As part of tiered caching stats, changes the common ICache interface to use ICacheKey as its key. This key contains dimensions (for example, shard ID, index name, or tier) that can be used to aggregate stats. Also changes the CacheStats interface to store the necessary cache stats, and to support getting stats either as a total or aggregated by these dimensions. Integrates these changes with OpenSearchOnHeapCache and EhcacheDiskCache. The stats implementation for the TieredSpilloverCache will be in a followup PR. --------- Signed-off-by: Peter Alfonsi Co-authored-by: Peter Alfonsi --- CHANGELOG.md | 1 + .../common/tier/TieredSpilloverCache.java | 40 ++- .../cache/common/tier/MockDiskCache.java | 25 +- .../tier/TieredSpilloverCacheTests.java | 270 +++++++-------- .../cache/store/disk/EhcacheDiskCache.java | 171 +++++++--- .../store/disk/EhCacheDiskCacheTests.java | 314 +++++++++++++++--- .../org/opensearch/common/cache/Cache.java | 4 + .../org/opensearch/common/cache/ICache.java | 18 +- .../opensearch/common/cache/ICacheKey.java | 96 ++++++ .../cache/serializer/ICacheKeySerializer.java | 87 +++++ .../common/cache/stats/CacheStats.java | 132 ++++++++ .../common/cache/stats/CacheStatsHolder.java | 295 ++++++++++++++++ .../cache/stats/ImmutableCacheStats.java | 103 ++++++ .../stats/ImmutableCacheStatsHolder.java | 111 +++++++ .../common/cache/stats/package-info.java | 9 + .../cache/store/OpenSearchOnHeapCache.java | 85 ++++- .../cache/store/builders/ICacheBuilder.java | 13 +- .../cache/store/config/CacheConfig.java | 35 +- .../indices/IndicesRequestCache.java | 86 ++++- .../serializer/ICacheKeySerializerTests.java | 107 ++++++ .../cache/stats/CacheStatsHolderTests.java | 287 ++++++++++++++++ .../stats/ImmutableCacheStatsHolderTests.java | 88 +++++ .../cache/stats/ImmutableCacheStatsTests.java | 47 +++ .../store/OpenSearchOnHeapCacheTests.java | 181 ++++++++++ .../indices/IndicesRequestCacheTests.java | 137 +++++++- 25 files changed, 2411 insertions(+), 331 deletions(-) create mode 100644 server/src/main/java/org/opensearch/common/cache/ICacheKey.java create mode 100644 server/src/main/java/org/opensearch/common/cache/serializer/ICacheKeySerializer.java create mode 100644 server/src/main/java/org/opensearch/common/cache/stats/CacheStats.java create mode 100644 server/src/main/java/org/opensearch/common/cache/stats/CacheStatsHolder.java create mode 100644 server/src/main/java/org/opensearch/common/cache/stats/ImmutableCacheStats.java create mode 100644 server/src/main/java/org/opensearch/common/cache/stats/ImmutableCacheStatsHolder.java create mode 100644 server/src/main/java/org/opensearch/common/cache/stats/package-info.java create mode 100644 server/src/test/java/org/opensearch/common/cache/serializer/ICacheKeySerializerTests.java create mode 100644 server/src/test/java/org/opensearch/common/cache/stats/CacheStatsHolderTests.java create mode 100644 server/src/test/java/org/opensearch/common/cache/stats/ImmutableCacheStatsHolderTests.java create mode 100644 server/src/test/java/org/opensearch/common/cache/stats/ImmutableCacheStatsTests.java create mode 100644 server/src/test/java/org/opensearch/common/cache/store/OpenSearchOnHeapCacheTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index eef437da4e1fd..9b41f1cd4ca01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Remote Store] Make translog transfer timeout configurable ([#12704](https://github.com/opensearch-project/OpenSearch/pull/12704)) - Reject Resize index requests (i.e, split, shrink and clone), While DocRep to SegRep migration is in progress.([#12686](https://github.com/opensearch-project/OpenSearch/pull/12686)) - Add support for more than one protocol for transport ([#12967](https://github.com/opensearch-project/OpenSearch/pull/12967)) +- [Tiered Caching] Add dimension-based stats to ICache implementations. ([#12531](https://github.com/opensearch-project/OpenSearch/pull/12531)) - Add changes for overriding remote store and replication settings during snapshot restore. ([#11868](https://github.com/opensearch-project/OpenSearch/pull/11868)) ### Dependencies diff --git a/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCache.java b/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCache.java index cab05732ba1c4..ae3d9f1dbcf62 100644 --- a/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCache.java +++ b/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCache.java @@ -12,11 +12,13 @@ import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.common.cache.CacheType; import org.opensearch.common.cache.ICache; +import org.opensearch.common.cache.ICacheKey; import org.opensearch.common.cache.LoadAwareCacheLoader; import org.opensearch.common.cache.RemovalListener; import org.opensearch.common.cache.RemovalNotification; import org.opensearch.common.cache.RemovalReason; import org.opensearch.common.cache.policy.CachedQueryResult; +import org.opensearch.common.cache.stats.ImmutableCacheStatsHolder; import org.opensearch.common.cache.store.config.CacheConfig; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; @@ -54,7 +56,11 @@ public class TieredSpilloverCache implements ICache { private final ICache diskCache; private final ICache onHeapCache; - private final RemovalListener removalListener; + + // The listener for removals from the spillover cache as a whole + // TODO: In TSC stats PR, each tier will have its own separate removal listener. + private final RemovalListener, V> removalListener; + private final List dimensionNames; ReadWriteLock readWriteLock = new ReentrantReadWriteLock(); ReleasableLock readLock = new ReleasableLock(readWriteLock.readLock()); ReleasableLock writeLock = new ReleasableLock(readWriteLock.writeLock()); @@ -70,9 +76,9 @@ public class TieredSpilloverCache implements ICache { this.removalListener = Objects.requireNonNull(builder.removalListener, "Removal listener can't be null"); this.onHeapCache = builder.onHeapCacheFactory.create( - new CacheConfig.Builder().setRemovalListener(new RemovalListener() { + new CacheConfig.Builder().setRemovalListener(new RemovalListener, V>() { @Override - public void onRemoval(RemovalNotification notification) { + public void onRemoval(RemovalNotification, V> notification) { try (ReleasableLock ignore = writeLock.acquire()) { if (SPILLOVER_REMOVAL_REASONS.contains(notification.getRemovalReason()) && evaluatePolicies(notification.getValue())) { @@ -87,6 +93,7 @@ && evaluatePolicies(notification.getValue())) { .setValueType(builder.cacheConfig.getValueType()) .setSettings(builder.cacheConfig.getSettings()) .setWeigher(builder.cacheConfig.getWeigher()) + .setDimensionNames(builder.cacheConfig.getDimensionNames()) .setMaxSizeInBytes(builder.cacheConfig.getMaxSizeInBytes()) .setExpireAfterAccess(builder.cacheConfig.getExpireAfterAccess()) .setClusterSettings(builder.cacheConfig.getClusterSettings()) @@ -97,7 +104,7 @@ && evaluatePolicies(notification.getValue())) { ); this.diskCache = builder.diskCacheFactory.create(builder.cacheConfig, builder.cacheType, builder.cacheFactories); this.cacheList = Arrays.asList(onHeapCache, diskCache); - + this.dimensionNames = builder.cacheConfig.getDimensionNames(); this.policies = builder.policies; // Will never be null; builder initializes it to an empty list } @@ -112,19 +119,19 @@ ICache getDiskCache() { } @Override - public V get(K key) { + public V get(ICacheKey key) { return getValueFromTieredCache().apply(key); } @Override - public void put(K key, V value) { + public void put(ICacheKey key, V value) { try (ReleasableLock ignore = writeLock.acquire()) { onHeapCache.put(key, value); } } @Override - public V computeIfAbsent(K key, LoadAwareCacheLoader loader) throws Exception { + public V computeIfAbsent(ICacheKey key, LoadAwareCacheLoader, V> loader) throws Exception { V cacheValue = getValueFromTieredCache().apply(key); if (cacheValue == null) { @@ -141,7 +148,7 @@ public V computeIfAbsent(K key, LoadAwareCacheLoader loader) throws Except } @Override - public void invalidate(K key) { + public void invalidate(ICacheKey key) { // We are trying to invalidate the key from all caches though it would be present in only of them. // Doing this as we don't know where it is located. We could do a get from both and check that, but what will // also trigger a hit/miss listener event, so ignoring it for now. @@ -167,9 +174,9 @@ public void invalidateAll() { */ @SuppressWarnings({ "unchecked" }) @Override - public Iterable keys() { - Iterable[] iterables = (Iterable[]) new Iterable[] { onHeapCache.keys(), diskCache.keys() }; - return new ConcatenatedIterables(iterables); + public Iterable> keys() { + Iterable>[] iterables = (Iterable>[]) new Iterable[] { onHeapCache.keys(), diskCache.keys() }; + return new ConcatenatedIterables>(iterables); } @Override @@ -197,7 +204,12 @@ public void close() throws IOException { } } - private Function getValueFromTieredCache() { + @Override + public ImmutableCacheStatsHolder stats() { + return null; // TODO: in TSC stats PR + } + + private Function, V> getValueFromTieredCache() { return key -> { try (ReleasableLock ignore = readLock.acquire()) { for (ICache cache : cacheList) { @@ -354,7 +366,7 @@ public String getCacheName() { public static class Builder { private ICache.Factory onHeapCacheFactory; private ICache.Factory diskCacheFactory; - private RemovalListener removalListener; + private RemovalListener, V> removalListener; private CacheConfig cacheConfig; private CacheType cacheType; private Map cacheFactories; @@ -390,7 +402,7 @@ public Builder setDiskCacheFactory(ICache.Factory diskCacheFactory) { * @param removalListener Removal listener * @return builder */ - public Builder setRemovalListener(RemovalListener removalListener) { + public Builder setRemovalListener(RemovalListener, V> removalListener) { this.removalListener = removalListener; return this; } diff --git a/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/MockDiskCache.java b/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/MockDiskCache.java index 548c5d846dda5..0d98503af635f 100644 --- a/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/MockDiskCache.java +++ b/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/MockDiskCache.java @@ -10,11 +10,13 @@ import org.opensearch.common.cache.CacheType; import org.opensearch.common.cache.ICache; +import org.opensearch.common.cache.ICacheKey; import org.opensearch.common.cache.LoadAwareCacheLoader; import org.opensearch.common.cache.RemovalListener; import org.opensearch.common.cache.RemovalNotification; import org.opensearch.common.cache.RemovalReason; import org.opensearch.common.cache.serializer.Serializer; +import org.opensearch.common.cache.stats.ImmutableCacheStatsHolder; import org.opensearch.common.cache.store.builders.ICacheBuilder; import org.opensearch.common.cache.store.config.CacheConfig; @@ -25,27 +27,27 @@ public class MockDiskCache implements ICache { - Map cache; + Map, V> cache; int maxSize; long delay; - private final RemovalListener removalListener; + private final RemovalListener, V> removalListener; - public MockDiskCache(int maxSize, long delay, RemovalListener removalListener) { + public MockDiskCache(int maxSize, long delay, RemovalListener, V> removalListener) { this.maxSize = maxSize; this.delay = delay; this.removalListener = removalListener; - this.cache = new ConcurrentHashMap(); + this.cache = new ConcurrentHashMap, V>(); } @Override - public V get(K key) { + public V get(ICacheKey key) { V value = cache.get(key); return value; } @Override - public void put(K key, V value) { + public void put(ICacheKey key, V value) { if (this.cache.size() >= maxSize) { // For simplification this.removalListener.onRemoval(new RemovalNotification<>(key, value, RemovalReason.EVICTED)); } @@ -58,7 +60,7 @@ public void put(K key, V value) { } @Override - public V computeIfAbsent(K key, LoadAwareCacheLoader loader) { + public V computeIfAbsent(ICacheKey key, LoadAwareCacheLoader, V> loader) { V value = cache.computeIfAbsent(key, key1 -> { try { return loader.load(key); @@ -70,7 +72,7 @@ public V computeIfAbsent(K key, LoadAwareCacheLoader loader) { } @Override - public void invalidate(K key) { + public void invalidate(ICacheKey key) { this.cache.remove(key); } @@ -80,7 +82,7 @@ public void invalidateAll() { } @Override - public Iterable keys() { + public Iterable> keys() { return () -> new CacheKeyIterator<>(cache, removalListener); } @@ -92,6 +94,11 @@ public long count() { @Override public void refresh() {} + @Override + public ImmutableCacheStatsHolder stats() { + return null; + } + @Override public void close() { diff --git a/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheTests.java b/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheTests.java index 431aca51099a6..bf9f8fd22d793 100644 --- a/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheTests.java +++ b/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheTests.java @@ -10,6 +10,7 @@ import org.opensearch.common.cache.CacheType; import org.opensearch.common.cache.ICache; +import org.opensearch.common.cache.ICacheKey; import org.opensearch.common.cache.LoadAwareCacheLoader; import org.opensearch.common.cache.RemovalListener; import org.opensearch.common.cache.RemovalNotification; @@ -45,6 +46,8 @@ import static org.opensearch.common.cache.store.settings.OpenSearchOnHeapCacheSettings.MAXIMUM_SIZE_IN_BYTES_KEY; public class TieredSpilloverCacheTests extends OpenSearchTestCase { + // TODO: TSC stats impl is in a future PR. Parts of tests which use stats values are missing for now. + static final List dimensionNames = List.of("dim1", "dim2", "dim3"); private ClusterSettings clusterSettings; @@ -60,7 +63,7 @@ public void testComputeIfAbsentWithoutAnyOnHeapCacheEviction() throws Exception int keyValueSize = 50; MockCacheRemovalListener removalListener = new MockCacheRemovalListener<>(); - TieredSpilloverCache tieredSpilloverCache = intializeTieredSpilloverCache( + TieredSpilloverCache tieredSpilloverCache = initializeTieredSpilloverCache( keyValueSize, randomIntBetween(1, 4), removalListener, @@ -75,12 +78,12 @@ public void testComputeIfAbsentWithoutAnyOnHeapCacheEviction() throws Exception 0 ); int numOfItems1 = randomIntBetween(1, onHeapCacheSize / 2 - 1); - List keys = new ArrayList<>(); + List> keys = new ArrayList<>(); // Put values in cache. for (int iter = 0; iter < numOfItems1; iter++) { - String key = UUID.randomUUID().toString(); + ICacheKey key = getICacheKey(UUID.randomUUID().toString()); keys.add(key); - LoadAwareCacheLoader tieredCacheLoader = getLoadAwareCacheLoader(); + LoadAwareCacheLoader, String> tieredCacheLoader = getLoadAwareCacheLoader(); tieredSpilloverCache.computeIfAbsent(key, tieredCacheLoader); } assertEquals(0, removalListener.evictionsMetric.count()); @@ -97,7 +100,7 @@ public void testComputeIfAbsentWithoutAnyOnHeapCacheEviction() throws Exception tieredSpilloverCache.computeIfAbsent(keys.get(index), getLoadAwareCacheLoader()); } else { // Hit cache with randomized key which is expected to miss cache always. - tieredSpilloverCache.computeIfAbsent(UUID.randomUUID().toString(), getLoadAwareCacheLoader()); + tieredSpilloverCache.computeIfAbsent(getICacheKey(UUID.randomUUID().toString()), getLoadAwareCacheLoader()); cacheMiss++; } } @@ -145,6 +148,7 @@ public void testComputeIfAbsentWithFactoryBasedCacheCreation() throws Exception .setWeigher((k, v) -> keyValueSize) .setRemovalListener(removalListener) .setSettings(settings) + .setDimensionNames(dimensionNames) .setCachedResultParser(s -> new CachedQueryResult.PolicyValues(20_000_000L)) // Values will always appear to have taken // 20_000_000 ns = 20 ms to compute .setClusterSettings(clusterSettings) @@ -161,11 +165,15 @@ public void testComputeIfAbsentWithFactoryBasedCacheCreation() throws Exception TieredSpilloverCache tieredSpilloverCache = (TieredSpilloverCache) tieredSpilloverICache; int numOfItems1 = randomIntBetween(onHeapCacheSize + 1, totalSize); + List> onHeapKeys = new ArrayList<>(); + List> diskTierKeys = new ArrayList<>(); for (int iter = 0; iter < numOfItems1; iter++) { String key = UUID.randomUUID().toString(); - LoadAwareCacheLoader tieredCacheLoader = getLoadAwareCacheLoader(); - tieredSpilloverCache.computeIfAbsent(key, tieredCacheLoader); + LoadAwareCacheLoader, String> tieredCacheLoader = getLoadAwareCacheLoader(); + tieredSpilloverCache.computeIfAbsent(getICacheKey(key), tieredCacheLoader); } + tieredSpilloverCache.getOnHeapCache().keys().forEach(onHeapKeys::add); + tieredSpilloverCache.getDiskCache().keys().forEach(diskTierKeys::add); // Verify on heap cache size. assertEquals(onHeapCacheSize, tieredSpilloverCache.getOnHeapCache().count()); // Verify disk cache size. @@ -278,6 +286,7 @@ public void testComputeIfAbsentWithEvictionsFromOnHeapCache() throws Exception { .setKeyType(String.class) .setWeigher((k, v) -> keyValueSize) .setRemovalListener(removalListener) + .setDimensionNames(dimensionNames) .setSettings( Settings.builder() .put( @@ -307,20 +316,17 @@ public void testComputeIfAbsentWithEvictionsFromOnHeapCache() throws Exception { // Put values in cache more than it's size and cause evictions from onHeap. int numOfItems1 = randomIntBetween(onHeapCacheSize + 1, totalSize); - List onHeapKeys = new ArrayList<>(); - List diskTierKeys = new ArrayList<>(); + List> onHeapKeys = new ArrayList<>(); + List> diskTierKeys = new ArrayList<>(); for (int iter = 0; iter < numOfItems1; iter++) { - String key = UUID.randomUUID().toString(); - LoadAwareCacheLoader tieredCacheLoader = getLoadAwareCacheLoader(); + ICacheKey key = getICacheKey(UUID.randomUUID().toString()); + LoadAwareCacheLoader, String> tieredCacheLoader = getLoadAwareCacheLoader(); tieredSpilloverCache.computeIfAbsent(key, tieredCacheLoader); } tieredSpilloverCache.getOnHeapCache().keys().forEach(onHeapKeys::add); tieredSpilloverCache.getDiskCache().keys().forEach(diskTierKeys::add); - assertEquals(tieredSpilloverCache.getOnHeapCache().count(), onHeapKeys.size()); - assertEquals(tieredSpilloverCache.getDiskCache().count(), diskTierKeys.size()); - // Try to hit cache again with some randomization. int numOfItems2 = randomIntBetween(50, 200); int onHeapCacheHit = 0; @@ -330,21 +336,21 @@ public void testComputeIfAbsentWithEvictionsFromOnHeapCache() throws Exception { if (randomBoolean()) { // Hit cache with key stored in onHeap cache. onHeapCacheHit++; int index = randomIntBetween(0, onHeapKeys.size() - 1); - LoadAwareCacheLoader loadAwareCacheLoader = getLoadAwareCacheLoader(); + LoadAwareCacheLoader, String> loadAwareCacheLoader = getLoadAwareCacheLoader(); tieredSpilloverCache.computeIfAbsent(onHeapKeys.get(index), loadAwareCacheLoader); assertFalse(loadAwareCacheLoader.isLoaded()); } else { // Hit cache with key stored in disk cache. diskCacheHit++; int index = randomIntBetween(0, diskTierKeys.size() - 1); - LoadAwareCacheLoader loadAwareCacheLoader = getLoadAwareCacheLoader(); + LoadAwareCacheLoader, String> loadAwareCacheLoader = getLoadAwareCacheLoader(); tieredSpilloverCache.computeIfAbsent(diskTierKeys.get(index), loadAwareCacheLoader); assertFalse(loadAwareCacheLoader.isLoaded()); } } for (int iter = 0; iter < randomIntBetween(50, 200); iter++) { // Hit cache with randomized key which is expected to miss cache always. - LoadAwareCacheLoader tieredCacheLoader = getLoadAwareCacheLoader(); - tieredSpilloverCache.computeIfAbsent(UUID.randomUUID().toString(), tieredCacheLoader); + LoadAwareCacheLoader, String> tieredCacheLoader = getLoadAwareCacheLoader(); + tieredSpilloverCache.computeIfAbsent(getICacheKey(UUID.randomUUID().toString()), tieredCacheLoader); cacheMiss++; } } @@ -356,7 +362,7 @@ public void testComputeIfAbsentWithEvictionsFromTieredCache() throws Exception { int keyValueSize = 50; MockCacheRemovalListener removalListener = new MockCacheRemovalListener<>(); - TieredSpilloverCache tieredSpilloverCache = intializeTieredSpilloverCache( + TieredSpilloverCache tieredSpilloverCache = initializeTieredSpilloverCache( keyValueSize, diskCacheSize, removalListener, @@ -372,9 +378,10 @@ public void testComputeIfAbsentWithEvictionsFromTieredCache() throws Exception { ); int numOfItems = randomIntBetween(totalSize + 1, totalSize * 3); for (int iter = 0; iter < numOfItems; iter++) { - LoadAwareCacheLoader tieredCacheLoader = getLoadAwareCacheLoader(); - tieredSpilloverCache.computeIfAbsent(UUID.randomUUID().toString(), tieredCacheLoader); + LoadAwareCacheLoader, String> tieredCacheLoader = getLoadAwareCacheLoader(); + tieredSpilloverCache.computeIfAbsent(getICacheKey(UUID.randomUUID().toString()), tieredCacheLoader); } + int evictions = numOfItems - (totalSize); assertEquals(evictions, removalListener.evictionsMetric.count()); } @@ -386,7 +393,7 @@ public void testGetAndCount() throws Exception { int totalSize = onHeapCacheSize + diskCacheSize; MockCacheRemovalListener removalListener = new MockCacheRemovalListener<>(); - TieredSpilloverCache tieredSpilloverCache = intializeTieredSpilloverCache( + TieredSpilloverCache tieredSpilloverCache = initializeTieredSpilloverCache( keyValueSize, diskCacheSize, removalListener, @@ -402,17 +409,17 @@ public void testGetAndCount() throws Exception { ); int numOfItems1 = randomIntBetween(onHeapCacheSize + 1, totalSize); - List onHeapKeys = new ArrayList<>(); - List diskTierKeys = new ArrayList<>(); + List> onHeapKeys = new ArrayList<>(); + List> diskTierKeys = new ArrayList<>(); for (int iter = 0; iter < numOfItems1; iter++) { - String key = UUID.randomUUID().toString(); + ICacheKey key = getICacheKey(UUID.randomUUID().toString()); if (iter > (onHeapCacheSize - 1)) { // All these are bound to go to disk based cache. diskTierKeys.add(key); } else { onHeapKeys.add(key); } - LoadAwareCacheLoader loadAwareCacheLoader = getLoadAwareCacheLoader(); + LoadAwareCacheLoader, String> loadAwareCacheLoader = getLoadAwareCacheLoader(); tieredSpilloverCache.computeIfAbsent(key, loadAwareCacheLoader); } @@ -426,7 +433,7 @@ public void testGetAndCount() throws Exception { assertNotNull(tieredSpilloverCache.get(diskTierKeys.get(index))); } } else { - assertNull(tieredSpilloverCache.get(UUID.randomUUID().toString())); + assertNull(tieredSpilloverCache.get(getICacheKey(UUID.randomUUID().toString()))); } } assertEquals(numOfItems1, tieredSpilloverCache.count()); @@ -438,7 +445,7 @@ public void testPut() { int keyValueSize = 50; MockCacheRemovalListener removalListener = new MockCacheRemovalListener<>(); - TieredSpilloverCache tieredSpilloverCache = intializeTieredSpilloverCache( + TieredSpilloverCache tieredSpilloverCache = initializeTieredSpilloverCache( keyValueSize, diskCacheSize, removalListener, @@ -452,10 +459,9 @@ public void testPut() { .build(), 0 ); - String key = UUID.randomUUID().toString(); + ICacheKey key = getICacheKey(UUID.randomUUID().toString()); String value = UUID.randomUUID().toString(); tieredSpilloverCache.put(key, value); - assertEquals(1, tieredSpilloverCache.count()); } public void testPutAndVerifyNewItemsArePresentOnHeapCache() throws Exception { @@ -465,7 +471,7 @@ public void testPutAndVerifyNewItemsArePresentOnHeapCache() throws Exception { MockCacheRemovalListener removalListener = new MockCacheRemovalListener<>(); - TieredSpilloverCache tieredSpilloverCache = intializeTieredSpilloverCache( + TieredSpilloverCache tieredSpilloverCache = initializeTieredSpilloverCache( keyValueSize, diskCacheSize, removalListener, @@ -485,52 +491,27 @@ public void testPutAndVerifyNewItemsArePresentOnHeapCache() throws Exception { ); for (int i = 0; i < onHeapCacheSize; i++) { - tieredSpilloverCache.computeIfAbsent(UUID.randomUUID().toString(), new LoadAwareCacheLoader<>() { - @Override - public boolean isLoaded() { - return false; - } - - @Override - public String load(String key) { - return UUID.randomUUID().toString(); - } - }); + tieredSpilloverCache.computeIfAbsent(getICacheKey(UUID.randomUUID().toString()), getLoadAwareCacheLoader()); } - assertEquals(onHeapCacheSize, tieredSpilloverCache.getOnHeapCache().count()); - assertEquals(0, tieredSpilloverCache.getDiskCache().count()); - // Again try to put OnHeap cache capacity amount of new items. - List newKeyList = new ArrayList<>(); + List> newKeyList = new ArrayList<>(); for (int i = 0; i < onHeapCacheSize; i++) { - newKeyList.add(UUID.randomUUID().toString()); + newKeyList.add(getICacheKey(UUID.randomUUID().toString())); } for (int i = 0; i < newKeyList.size(); i++) { - tieredSpilloverCache.computeIfAbsent(newKeyList.get(i), new LoadAwareCacheLoader<>() { - @Override - public boolean isLoaded() { - return false; - } - - @Override - public String load(String key) { - return UUID.randomUUID().toString(); - } - }); + tieredSpilloverCache.computeIfAbsent(newKeyList.get(i), getLoadAwareCacheLoader()); } // Verify that new items are part of onHeap cache. - List actualOnHeapCacheKeys = new ArrayList<>(); + List> actualOnHeapCacheKeys = new ArrayList<>(); tieredSpilloverCache.getOnHeapCache().keys().forEach(actualOnHeapCacheKeys::add); assertEquals(newKeyList.size(), actualOnHeapCacheKeys.size()); for (int i = 0; i < actualOnHeapCacheKeys.size(); i++) { assertTrue(newKeyList.contains(actualOnHeapCacheKeys.get(i))); } - assertEquals(onHeapCacheSize, tieredSpilloverCache.getOnHeapCache().count()); - assertEquals(onHeapCacheSize, tieredSpilloverCache.getDiskCache().count()); } public void testInvalidate() { @@ -539,7 +520,7 @@ public void testInvalidate() { int keyValueSize = 20; MockCacheRemovalListener removalListener = new MockCacheRemovalListener<>(); - TieredSpilloverCache tieredSpilloverCache = intializeTieredSpilloverCache( + TieredSpilloverCache tieredSpilloverCache = initializeTieredSpilloverCache( keyValueSize, diskCacheSize, removalListener, @@ -553,24 +534,29 @@ public void testInvalidate() { .build(), 0 ); - String key = UUID.randomUUID().toString(); + ICacheKey key = getICacheKey(UUID.randomUUID().toString()); String value = UUID.randomUUID().toString(); // First try to invalidate without the key present in cache. tieredSpilloverCache.invalidate(key); + // assertEquals(0, tieredSpilloverCache.stats().getEvictionsByDimensions(HEAP_DIMS)); // Now try to invalidate with the key present in onHeap cache. tieredSpilloverCache.put(key, value); tieredSpilloverCache.invalidate(key); + // Evictions metric shouldn't increase for invalidations. assertEquals(0, tieredSpilloverCache.count()); tieredSpilloverCache.put(key, value); // Put another key/value so that one of the item is evicted to disk cache. - String key2 = UUID.randomUUID().toString(); + ICacheKey key2 = getICacheKey(UUID.randomUUID().toString()); tieredSpilloverCache.put(key2, UUID.randomUUID().toString()); + assertEquals(2, tieredSpilloverCache.count()); - // Again invalidate older key + + // Again invalidate older key, leaving one in heap tier and zero in disk tier tieredSpilloverCache.invalidate(key); assertEquals(1, tieredSpilloverCache.count()); + } public void testCacheKeys() throws Exception { @@ -579,7 +565,7 @@ public void testCacheKeys() throws Exception { int keyValueSize = 50; MockCacheRemovalListener removalListener = new MockCacheRemovalListener<>(); - TieredSpilloverCache tieredSpilloverCache = intializeTieredSpilloverCache( + TieredSpilloverCache tieredSpilloverCache = initializeTieredSpilloverCache( keyValueSize, diskCacheSize, removalListener, @@ -593,46 +579,46 @@ public void testCacheKeys() throws Exception { .build(), 0 ); - List onHeapKeys = new ArrayList<>(); - List diskTierKeys = new ArrayList<>(); + List> onHeapKeys = new ArrayList<>(); + List> diskTierKeys = new ArrayList<>(); // During first round add onHeapCacheSize entries. Will go to onHeap cache initially. for (int i = 0; i < onHeapCacheSize; i++) { - String key = UUID.randomUUID().toString(); + ICacheKey key = getICacheKey(UUID.randomUUID().toString()); diskTierKeys.add(key); tieredSpilloverCache.computeIfAbsent(key, getLoadAwareCacheLoader()); } // In another round, add another onHeapCacheSize entries. These will go to onHeap and above ones will be // evicted to onDisk cache. for (int i = 0; i < onHeapCacheSize; i++) { - String key = UUID.randomUUID().toString(); + ICacheKey key = getICacheKey(UUID.randomUUID().toString()); onHeapKeys.add(key); tieredSpilloverCache.computeIfAbsent(key, getLoadAwareCacheLoader()); } - List actualOnHeapKeys = new ArrayList<>(); - List actualOnDiskKeys = new ArrayList<>(); - Iterable onHeapiterable = tieredSpilloverCache.getOnHeapCache().keys(); - Iterable onDiskiterable = tieredSpilloverCache.getDiskCache().keys(); + List> actualOnHeapKeys = new ArrayList<>(); + List> actualOnDiskKeys = new ArrayList<>(); + Iterable> onHeapiterable = tieredSpilloverCache.getOnHeapCache().keys(); + Iterable> onDiskiterable = tieredSpilloverCache.getDiskCache().keys(); onHeapiterable.iterator().forEachRemaining(actualOnHeapKeys::add); onDiskiterable.iterator().forEachRemaining(actualOnDiskKeys::add); - for (String onHeapKey : onHeapKeys) { + for (ICacheKey onHeapKey : onHeapKeys) { assertTrue(actualOnHeapKeys.contains(onHeapKey)); } - for (String onDiskKey : actualOnDiskKeys) { + for (ICacheKey onDiskKey : actualOnDiskKeys) { assertTrue(actualOnDiskKeys.contains(onDiskKey)); } // Testing keys() which returns all keys. - List actualMergedKeys = new ArrayList<>(); - List expectedMergedKeys = new ArrayList<>(); + List> actualMergedKeys = new ArrayList<>(); + List> expectedMergedKeys = new ArrayList<>(); expectedMergedKeys.addAll(onHeapKeys); expectedMergedKeys.addAll(diskTierKeys); - Iterable mergedIterable = tieredSpilloverCache.keys(); + Iterable> mergedIterable = tieredSpilloverCache.keys(); mergedIterable.iterator().forEachRemaining(actualMergedKeys::add); assertEquals(expectedMergedKeys.size(), actualMergedKeys.size()); - for (String key : expectedMergedKeys) { + for (ICacheKey key : expectedMergedKeys) { assertTrue(actualMergedKeys.contains(key)); } } @@ -641,7 +627,7 @@ public void testRefresh() { int diskCacheSize = randomIntBetween(60, 100); MockCacheRemovalListener removalListener = new MockCacheRemovalListener<>(); - TieredSpilloverCache tieredSpilloverCache = intializeTieredSpilloverCache( + TieredSpilloverCache tieredSpilloverCache = initializeTieredSpilloverCache( 50, diskCacheSize, removalListener, @@ -658,7 +644,7 @@ public void testInvalidateAll() throws Exception { int totalSize = onHeapCacheSize + diskCacheSize; MockCacheRemovalListener removalListener = new MockCacheRemovalListener<>(); - TieredSpilloverCache tieredSpilloverCache = intializeTieredSpilloverCache( + TieredSpilloverCache tieredSpilloverCache = initializeTieredSpilloverCache( keyValueSize, diskCacheSize, removalListener, @@ -674,17 +660,17 @@ public void testInvalidateAll() throws Exception { ); // Put values in cache more than it's size and cause evictions from onHeap. int numOfItems1 = randomIntBetween(onHeapCacheSize + 1, totalSize); - List onHeapKeys = new ArrayList<>(); - List diskTierKeys = new ArrayList<>(); + List> onHeapKeys = new ArrayList<>(); + List> diskTierKeys = new ArrayList<>(); for (int iter = 0; iter < numOfItems1; iter++) { - String key = UUID.randomUUID().toString(); + ICacheKey key = getICacheKey(UUID.randomUUID().toString()); if (iter > (onHeapCacheSize - 1)) { // All these are bound to go to disk based cache. diskTierKeys.add(key); } else { onHeapKeys.add(key); } - LoadAwareCacheLoader tieredCacheLoader = getLoadAwareCacheLoader(); + LoadAwareCacheLoader, String> tieredCacheLoader = getLoadAwareCacheLoader(); tieredSpilloverCache.computeIfAbsent(key, tieredCacheLoader); } assertEquals(numOfItems1, tieredSpilloverCache.count()); @@ -707,7 +693,7 @@ public void testComputeIfAbsentConcurrently() throws Exception { ) .build(); - TieredSpilloverCache tieredSpilloverCache = intializeTieredSpilloverCache( + TieredSpilloverCache tieredSpilloverCache = initializeTieredSpilloverCache( keyValueSize, diskCacheSize, removalListener, @@ -716,19 +702,19 @@ public void testComputeIfAbsentConcurrently() throws Exception { ); int numberOfSameKeys = randomIntBetween(10, onHeapCacheSize - 1); - String key = UUID.randomUUID().toString(); + ICacheKey key = getICacheKey(UUID.randomUUID().toString()); String value = UUID.randomUUID().toString(); Thread[] threads = new Thread[numberOfSameKeys]; Phaser phaser = new Phaser(numberOfSameKeys + 1); CountDownLatch countDownLatch = new CountDownLatch(numberOfSameKeys); // To wait for all threads to finish. - List> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>(); + List, String>> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>(); for (int i = 0; i < numberOfSameKeys; i++) { threads[i] = new Thread(() -> { try { - LoadAwareCacheLoader loadAwareCacheLoader = new LoadAwareCacheLoader<>() { + LoadAwareCacheLoader, String> loadAwareCacheLoader = new LoadAwareCacheLoader<>() { boolean isLoaded = false; @Override @@ -737,7 +723,7 @@ public boolean isLoaded() { } @Override - public String load(String key) { + public String load(ICacheKey key) { isLoaded = true; return value; } @@ -757,7 +743,7 @@ public String load(String key) { int numberOfTimesKeyLoaded = 0; assertEquals(numberOfSameKeys, loadAwareCacheLoaderList.size()); for (int i = 0; i < loadAwareCacheLoaderList.size(); i++) { - LoadAwareCacheLoader loader = loadAwareCacheLoaderList.get(i); + LoadAwareCacheLoader, String> loader = loadAwareCacheLoaderList.get(i); if (loader.isLoaded()) { numberOfTimesKeyLoaded++; } @@ -791,6 +777,7 @@ public void testConcurrencyForEvictionFlowFromOnHeapToDiskTier() throws Exceptio ) .build() ) + .setDimensionNames(dimensionNames) .build(); TieredSpilloverCache tieredSpilloverCache = new TieredSpilloverCache.Builder() .setOnHeapCacheFactory(onHeapCacheFactory) @@ -800,26 +787,17 @@ public void testConcurrencyForEvictionFlowFromOnHeapToDiskTier() throws Exceptio .setCacheType(CacheType.INDICES_REQUEST_CACHE) .build(); - String keyToBeEvicted = "key1"; - String secondKey = "key2"; + ICacheKey keyToBeEvicted = getICacheKey("key1"); + ICacheKey secondKey = getICacheKey("key2"); // Put first key on tiered cache. Will go into onHeap cache. - tieredSpilloverCache.computeIfAbsent(keyToBeEvicted, new LoadAwareCacheLoader<>() { - @Override - public boolean isLoaded() { - return false; - } - - @Override - public String load(String key) { - return UUID.randomUUID().toString(); - } - }); + tieredSpilloverCache.computeIfAbsent(keyToBeEvicted, getLoadAwareCacheLoader()); + // assertEquals(1, tieredSpilloverCache.stats().getEntriesByDimensions(HEAP_DIMS)); CountDownLatch countDownLatch = new CountDownLatch(1); CountDownLatch countDownLatch1 = new CountDownLatch(1); // Put second key on tiered cache. Will cause eviction of first key from onHeap cache and should go into // disk cache. - LoadAwareCacheLoader loadAwareCacheLoader = getLoadAwareCacheLoader(); + LoadAwareCacheLoader, String> loadAwareCacheLoader = getLoadAwareCacheLoader(); Thread thread = new Thread(() -> { try { tieredSpilloverCache.computeIfAbsent(secondKey, loadAwareCacheLoader); @@ -830,7 +808,7 @@ public String load(String key) { }); thread.start(); assertBusy(() -> { assertTrue(loadAwareCacheLoader.isLoaded()); }, 100, TimeUnit.MILLISECONDS); // We wait for new key to be loaded - // after which it eviction flow is + // after which it eviction flow is // guaranteed to occur. ICache onDiskCache = tieredSpilloverCache.getDiskCache(); @@ -849,20 +827,12 @@ public String load(String key) { countDownLatch.await(); assertNotNull(actualValue.get()); countDownLatch1.await(); + assertEquals(1, tieredSpilloverCache.getOnHeapCache().count()); assertEquals(1, onDiskCache.count()); assertNotNull(onDiskCache.get(keyToBeEvicted)); } - class MockCacheRemovalListener implements RemovalListener { - final CounterMetric evictionsMetric = new CounterMetric(); - - @Override - public void onRemoval(RemovalNotification notification) { - evictionsMetric.inc(); - } - } - public void testDiskTierPolicies() throws Exception { // For policy function, allow if what it receives starts with "a" and string is even length ArrayList> policies = new ArrayList<>(); @@ -901,26 +871,14 @@ public void testDiskTierPolicies() throws Exception { keyValuePairs.put("key5", ""); expectedOutputs.put("key5", false); - LoadAwareCacheLoader loader = new LoadAwareCacheLoader() { - boolean isLoaded = false; - - @Override - public boolean isLoaded() { - return isLoaded; - } - - @Override - public String load(String key) throws Exception { - isLoaded = true; - return keyValuePairs.get(key); - } - }; + LoadAwareCacheLoader, String> loader = getLoadAwareCacheLoader(keyValuePairs); for (String key : keyValuePairs.keySet()) { + ICacheKey iCacheKey = getICacheKey(key); Boolean expectedOutput = expectedOutputs.get(key); - String value = tieredSpilloverCache.computeIfAbsent(key, loader); + String value = tieredSpilloverCache.computeIfAbsent(iCacheKey, loader); assertEquals(keyValuePairs.get(key), value); - String result = tieredSpilloverCache.get(key); + String result = tieredSpilloverCache.get(iCacheKey); if (expectedOutput) { // Should retrieve from disk tier if it was accepted assertEquals(keyValuePairs.get(key), result); @@ -985,6 +943,7 @@ public void testTookTimePolicyFromFactory() throws Exception { .setRemovalListener(removalListener) .setSettings(settings) .setMaxSizeInBytes(onHeapCacheSize * keyValueSize) + .setDimensionNames(dimensionNames) .setCachedResultParser(new Function() { @Override public CachedQueryResult.PolicyValues apply(String s) { @@ -1006,22 +965,22 @@ public CachedQueryResult.PolicyValues apply(String s) { // First add all our values to the on heap cache for (String key : tookTimeMap.keySet()) { - tieredSpilloverCache.computeIfAbsent(key, getLoadAwareCacheLoader(keyValueMap)); + tieredSpilloverCache.computeIfAbsent(getICacheKey(key), getLoadAwareCacheLoader(keyValueMap)); } assertEquals(tookTimeMap.size(), tieredSpilloverCache.count()); // Ensure all these keys get evicted from the on heap tier by adding > heap tier size worth of random keys for (int i = 0; i < onHeapCacheSize; i++) { - tieredSpilloverCache.computeIfAbsent(UUID.randomUUID().toString(), getLoadAwareCacheLoader(keyValueMap)); + tieredSpilloverCache.computeIfAbsent(getICacheKey(UUID.randomUUID().toString()), getLoadAwareCacheLoader(keyValueMap)); } ICache onHeapCache = tieredSpilloverCache.getOnHeapCache(); for (String key : tookTimeMap.keySet()) { - assertNull(onHeapCache.get(key)); + assertNull(onHeapCache.get(getICacheKey(key))); } // Now the original keys should be in the disk tier if the policy allows them, or misses if not for (String key : tookTimeMap.keySet()) { - String computedValue = tieredSpilloverCache.get(key); + String computedValue = tieredSpilloverCache.get(getICacheKey(key)); String mapValue = keyValueMap.get(key); Long tookTime = tookTimeMap.get(mapValue); if (tookTime != null && tookTime > timeValueThresholdNanos) { @@ -1049,6 +1008,27 @@ public void testMinimumThresholdSettingValue() throws Exception { assertEquals(validDuration, concreteSetting.get(validSettings)); } + private List getMockDimensions() { + List dims = new ArrayList<>(); + for (String dimensionName : dimensionNames) { + dims.add("0"); + } + return dims; + } + + private ICacheKey getICacheKey(String key) { + return new ICacheKey<>(key, getMockDimensions()); + } + + class MockCacheRemovalListener implements RemovalListener, V> { + final CounterMetric evictionsMetric = new CounterMetric(); + + @Override + public void onRemoval(RemovalNotification, V> notification) { + evictionsMetric.inc(); + } + } + private static class AllowFirstLetterA implements Predicate { @Override public boolean test(String data) { @@ -1067,12 +1047,12 @@ public boolean test(String data) { } } - private LoadAwareCacheLoader getLoadAwareCacheLoader() { + private LoadAwareCacheLoader, String> getLoadAwareCacheLoader() { return new LoadAwareCacheLoader<>() { boolean isLoaded = false; @Override - public String load(String key) { + public String load(ICacheKey key) { isLoaded = true; return UUID.randomUUID().toString(); } @@ -1084,14 +1064,14 @@ public boolean isLoaded() { }; } - private LoadAwareCacheLoader getLoadAwareCacheLoader(Map keyValueMap) { + private LoadAwareCacheLoader, String> getLoadAwareCacheLoader(Map keyValueMap) { return new LoadAwareCacheLoader<>() { boolean isLoaded = false; @Override - public String load(String key) { + public String load(ICacheKey key) { isLoaded = true; - String mapValue = keyValueMap.get(key); + String mapValue = keyValueMap.get(key.key); if (mapValue == null) { mapValue = UUID.randomUUID().toString(); } @@ -1105,10 +1085,10 @@ public boolean isLoaded() { }; } - private TieredSpilloverCache intializeTieredSpilloverCache( + private TieredSpilloverCache initializeTieredSpilloverCache( int keyValueSize, int diskCacheSize, - RemovalListener removalListener, + RemovalListener, String> removalListener, Settings settings, long diskDeliberateDelay @@ -1119,7 +1099,7 @@ private TieredSpilloverCache intializeTieredSpilloverCache( private TieredSpilloverCache intializeTieredSpilloverCache( int keyValueSize, int diskCacheSize, - RemovalListener removalListener, + RemovalListener, String> removalListener, Settings settings, long diskDeliberateDelay, List> policies @@ -1128,6 +1108,8 @@ private TieredSpilloverCache intializeTieredSpilloverCache( CacheConfig cacheConfig = new CacheConfig.Builder().setKeyType(String.class) .setKeyType(String.class) .setWeigher((k, v) -> keyValueSize) + .setSettings(settings) + .setDimensionNames(dimensionNames) .setRemovalListener(removalListener) .setSettings( Settings.builder() diff --git a/plugins/cache-ehcache/src/main/java/org/opensearch/cache/store/disk/EhcacheDiskCache.java b/plugins/cache-ehcache/src/main/java/org/opensearch/cache/store/disk/EhcacheDiskCache.java index edb2c900be46c..185d51732a116 100644 --- a/plugins/cache-ehcache/src/main/java/org/opensearch/cache/store/disk/EhcacheDiskCache.java +++ b/plugins/cache-ehcache/src/main/java/org/opensearch/cache/store/disk/EhcacheDiskCache.java @@ -17,15 +17,18 @@ import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.common.cache.CacheType; import org.opensearch.common.cache.ICache; +import org.opensearch.common.cache.ICacheKey; import org.opensearch.common.cache.LoadAwareCacheLoader; import org.opensearch.common.cache.RemovalListener; import org.opensearch.common.cache.RemovalNotification; import org.opensearch.common.cache.RemovalReason; +import org.opensearch.common.cache.serializer.ICacheKeySerializer; import org.opensearch.common.cache.serializer.Serializer; +import org.opensearch.common.cache.stats.CacheStatsHolder; +import org.opensearch.common.cache.stats.ImmutableCacheStatsHolder; import org.opensearch.common.cache.store.builders.ICacheBuilder; import org.opensearch.common.cache.store.config.CacheConfig; import org.opensearch.common.collect.Tuple; -import org.opensearch.common.metrics.CounterMetric; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; @@ -40,6 +43,7 @@ import java.time.Duration; import java.util.Arrays; import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Objects; @@ -49,6 +53,7 @@ import java.util.concurrent.ExecutionException; import java.util.function.BiFunction; import java.util.function.Supplier; +import java.util.function.ToLongBiFunction; import org.ehcache.Cache; import org.ehcache.CachePersistenceException; @@ -101,21 +106,20 @@ public class EhcacheDiskCache implements ICache { private final PersistentCacheManager cacheManager; // Disk cache. Using ByteArrayWrapper to compare two byte[] by values rather than the default reference checks - private Cache cache; + @SuppressWarnings({ "rawtypes" }) // We have to use the raw type as there's no way to pass the "generic class" to ehcache + private Cache cache; private final long maxWeightInBytes; private final String storagePath; private final Class keyType; private final Class valueType; private final TimeValue expireAfterAccess; + private final CacheStatsHolder cacheStatsHolder; private final EhCacheEventListener ehCacheEventListener; private final String threadPoolAlias; private final Settings settings; - private final RemovalListener removalListener; + private final RemovalListener, V> removalListener; private final CacheType cacheType; private final String diskCacheAlias; - // TODO: Move count to stats once those changes are ready. - private final CounterMetric entries = new CounterMetric(); - private final Serializer keySerializer; private final Serializer valueSerializer; @@ -123,7 +127,7 @@ public class EhcacheDiskCache implements ICache { * Used in computeIfAbsent to synchronize loading of a given key. This is needed as ehcache doesn't provide a * computeIfAbsent method. */ - Map>> completableFutureMap = new ConcurrentHashMap<>(); + Map, CompletableFuture, V>>> completableFutureMap = new ConcurrentHashMap<>(); private EhcacheDiskCache(Builder builder) { this.keyType = Objects.requireNonNull(builder.keyType, "Key type shouldn't be null"); @@ -154,31 +158,39 @@ private EhcacheDiskCache(Builder builder) { this.cacheManager = buildCacheManager(); Objects.requireNonNull(builder.getRemovalListener(), "Removal listener can't be null"); this.removalListener = builder.getRemovalListener(); - this.ehCacheEventListener = new EhCacheEventListener(builder.getRemovalListener()); + Objects.requireNonNull(builder.getWeigher(), "Weigher can't be null"); + this.ehCacheEventListener = new EhCacheEventListener(builder.getRemovalListener(), builder.getWeigher()); this.cache = buildCache(Duration.ofMillis(expireAfterAccess.getMillis()), builder); + List dimensionNames = Objects.requireNonNull(builder.dimensionNames, "Dimension names can't be null"); + this.cacheStatsHolder = new CacheStatsHolder(dimensionNames); } - private Cache buildCache(Duration expireAfterAccess, Builder builder) { + @SuppressWarnings({ "rawtypes" }) + private Cache buildCache(Duration expireAfterAccess, Builder builder) { try { return this.cacheManager.createCache( this.diskCacheAlias, CacheConfigurationBuilder.newCacheConfigurationBuilder( - this.keyType, + ICacheKey.class, ByteArrayWrapper.class, ResourcePoolsBuilder.newResourcePoolsBuilder().disk(maxWeightInBytes, MemoryUnit.B) ).withExpiry(new ExpiryPolicy<>() { @Override - public Duration getExpiryForCreation(K key, ByteArrayWrapper value) { + public Duration getExpiryForCreation(ICacheKey key, ByteArrayWrapper value) { return INFINITE; } @Override - public Duration getExpiryForAccess(K key, Supplier value) { + public Duration getExpiryForAccess(ICacheKey key, Supplier value) { return expireAfterAccess; } @Override - public Duration getExpiryForUpdate(K key, Supplier oldValue, ByteArrayWrapper newValue) { + public Duration getExpiryForUpdate( + ICacheKey key, + Supplier oldValue, + ByteArrayWrapper newValue + ) { return INFINITE; } }) @@ -192,7 +204,7 @@ public Duration getExpiryForUpdate(K key, Supplier o (Integer) EhcacheDiskCacheSettings.getSettingListForCacheType(cacheType).get(DISK_SEGMENT_KEY).get(settings) ) ) - .withKeySerializer(new KeySerializerWrapper(keySerializer)) + .withKeySerializer(new KeySerializerWrapper(keySerializer)) .withValueSerializer(new ByteArrayWrapperSerializer()) // We pass ByteArrayWrapperSerializer as ehcache's value serializer. If V is an interface, and we pass its // serializer directly to ehcache, ehcache requires the classes match exactly before/after serialization. @@ -225,7 +237,7 @@ private CacheEventListenerConfigurationBuilder getListenerConfiguration(Builder< } // Package private for testing - Map>> getCompletableFutureMap() { + Map, CompletableFuture, V>>> getCompletableFutureMap() { return completableFutureMap; } @@ -254,7 +266,7 @@ private PersistentCacheManager buildCacheManager() { } @Override - public V get(K key) { + public V get(ICacheKey key) { if (key == null) { throw new IllegalArgumentException("Key passed to ehcache disk cache was null."); } @@ -264,6 +276,11 @@ public V get(K key) { } catch (CacheLoadingException ex) { throw new OpenSearchException("Exception occurred while trying to fetch item from ehcache disk cache"); } + if (value != null) { + cacheStatsHolder.incrementHits(key.dimensions); + } else { + cacheStatsHolder.incrementMisses(key.dimensions); + } return value; } @@ -273,7 +290,7 @@ public V get(K key) { * @param value Type of value. */ @Override - public void put(K key, V value) { + public void put(ICacheKey key, V value) { try { cache.put(key, serializeValue(value)); } catch (CacheWritingException ex) { @@ -289,26 +306,31 @@ public void put(K key, V value) { * @throws Exception when either internal get or put calls fail. */ @Override - public V computeIfAbsent(K key, LoadAwareCacheLoader loader) throws Exception { - // Ehache doesn't provide any computeIfAbsent function. Exposes putIfAbsent but that works differently and is + public V computeIfAbsent(ICacheKey key, LoadAwareCacheLoader, V> loader) throws Exception { + // Ehcache doesn't provide any computeIfAbsent function. Exposes putIfAbsent but that works differently and is // not performant in case there are multiple concurrent request for same key. Below is our own custom // implementation of computeIfAbsent on top of ehcache. Inspired by OpenSearch Cache implementation. V value = deserializeValue(cache.get(key)); if (value == null) { value = compute(key, loader); } + if (!loader.isLoaded()) { + cacheStatsHolder.incrementHits(key.dimensions); + } else { + cacheStatsHolder.incrementMisses(key.dimensions); + } return value; } - private V compute(K key, LoadAwareCacheLoader loader) throws Exception { + private V compute(ICacheKey key, LoadAwareCacheLoader, V> loader) throws Exception { // A future that returns a pair of key/value. - CompletableFuture> completableFuture = new CompletableFuture<>(); + CompletableFuture, V>> completableFuture = new CompletableFuture<>(); // Only one of the threads will succeed putting a future into map for the same key. // Rest will fetch existing future. - CompletableFuture> future = completableFutureMap.putIfAbsent(key, completableFuture); + CompletableFuture, V>> future = completableFutureMap.putIfAbsent(key, completableFuture); // Handler to handle results post processing. Takes a tuple or exception as an input and returns // the value. Also before returning value, puts the value in cache. - BiFunction, Throwable, V> handler = (pair, ex) -> { + BiFunction, V>, Throwable, V> handler = (pair, ex) -> { V value = null; if (pair != null) { cache.put(pair.v1(), serializeValue(pair.v2())); @@ -358,9 +380,14 @@ private V compute(K key, LoadAwareCacheLoader loader) throws Exception { * @param key key to be invalidated. */ @Override - public void invalidate(K key) { + public void invalidate(ICacheKey key) { try { - cache.remove(key); + if (key.getDropStatsForDimensions()) { + cacheStatsHolder.removeDimensions(key.dimensions); + } + if (key.key != null) { + cache.remove(key); + } } catch (CacheWritingException ex) { // Handle throw new RuntimeException(ex); @@ -371,7 +398,7 @@ public void invalidate(K key) { @Override public void invalidateAll() { cache.clear(); - this.entries.dec(this.entries.count()); // reset to zero. + cacheStatsHolder.reset(); } /** @@ -379,7 +406,7 @@ public void invalidateAll() { * @return Iterable */ @Override - public Iterable keys() { + public Iterable> keys() { return () -> new EhCacheKeyIterator<>(cache.iterator()); } @@ -389,7 +416,7 @@ public Iterable keys() { */ @Override public long count() { - return entries.count(); + return cacheStatsHolder.count(); } @Override @@ -416,15 +443,25 @@ public void close() { } } + /** + * Relevant stats for this cache. + * @return CacheStats + */ + @Override + public ImmutableCacheStatsHolder stats() { + return cacheStatsHolder.getImmutableCacheStatsHolder(); + } + /** * This iterator wraps ehCache iterator and only iterates over its keys. * @param Type of key */ - class EhCacheKeyIterator implements Iterator { + @SuppressWarnings({ "rawtypes", "unchecked" }) + class EhCacheKeyIterator implements Iterator> { - Iterator> iterator; + Iterator> iterator; - EhCacheKeyIterator(Iterator> iterator) { + EhCacheKeyIterator(Iterator> iterator) { this.iterator = iterator; } @@ -434,7 +471,7 @@ public boolean hasNext() { } @Override - public K next() { + public ICacheKey next() { if (!hasNext()) { throw new NoSuchElementException(); } @@ -450,43 +487,60 @@ public void remove() { /** * Wrapper over Ehcache original listener to listen to desired events and notify desired subscribers. */ - class EhCacheEventListener implements CacheEventListener { + class EhCacheEventListener implements CacheEventListener, ByteArrayWrapper> { + private final RemovalListener, V> removalListener; + private ToLongBiFunction, V> weigher; - private final RemovalListener removalListener; - - EhCacheEventListener(RemovalListener removalListener) { + EhCacheEventListener(RemovalListener, V> removalListener, ToLongBiFunction, V> weigher) { this.removalListener = removalListener; + this.weigher = weigher; + } + + private long getOldValuePairSize(CacheEvent, ? extends ByteArrayWrapper> event) { + return weigher.applyAsLong(event.getKey(), deserializeValue(event.getOldValue())); + } + + private long getNewValuePairSize(CacheEvent, ? extends ByteArrayWrapper> event) { + return weigher.applyAsLong(event.getKey(), deserializeValue(event.getNewValue())); } @Override - public void onEvent(CacheEvent event) { + public void onEvent(CacheEvent, ? extends ByteArrayWrapper> event) { switch (event.getType()) { case CREATED: - entries.inc(); + cacheStatsHolder.incrementEntries(event.getKey().dimensions); + cacheStatsHolder.incrementSizeInBytes(event.getKey().dimensions, getNewValuePairSize(event)); assert event.getOldValue() == null; break; case EVICTED: this.removalListener.onRemoval( new RemovalNotification<>(event.getKey(), deserializeValue(event.getOldValue()), RemovalReason.EVICTED) ); - entries.dec(); + cacheStatsHolder.decrementEntries(event.getKey().dimensions); + cacheStatsHolder.decrementSizeInBytes(event.getKey().dimensions, getOldValuePairSize(event)); + cacheStatsHolder.incrementEvictions(event.getKey().dimensions); assert event.getNewValue() == null; break; case REMOVED: - entries.dec(); this.removalListener.onRemoval( new RemovalNotification<>(event.getKey(), deserializeValue(event.getOldValue()), RemovalReason.EXPLICIT) ); + cacheStatsHolder.decrementEntries(event.getKey().dimensions); + cacheStatsHolder.decrementSizeInBytes(event.getKey().dimensions, getOldValuePairSize(event)); assert event.getNewValue() == null; break; case EXPIRED: this.removalListener.onRemoval( new RemovalNotification<>(event.getKey(), deserializeValue(event.getOldValue()), RemovalReason.INVALIDATED) ); - entries.dec(); + cacheStatsHolder.decrementEntries(event.getKey().dimensions); + cacheStatsHolder.decrementSizeInBytes(event.getKey().dimensions, getOldValuePairSize(event)); assert event.getNewValue() == null; break; case UPDATED: + long newSize = getNewValuePairSize(event); + long oldSize = getOldValuePairSize(event); + cacheStatsHolder.incrementSizeInBytes(event.getKey().dimensions, newSize - oldSize); break; default: break; @@ -495,13 +549,14 @@ public void onEvent(CacheEvent event) { } /** - * Wrapper over Serializer which is compatible with ehcache's serializer requirements. + * Wrapper over ICacheKeySerializer which is compatible with ehcache's serializer requirements. */ - private class KeySerializerWrapper implements org.ehcache.spi.serialization.Serializer { - private Serializer serializer; + @SuppressWarnings({ "rawtypes", "unchecked" }) + private class KeySerializerWrapper implements org.ehcache.spi.serialization.Serializer { + private ICacheKeySerializer serializer; - public KeySerializerWrapper(Serializer keySerializer) { - this.serializer = keySerializer; + public KeySerializerWrapper(Serializer internalKeySerializer) { + this.serializer = new ICacheKeySerializer<>(internalKeySerializer); } // This constructor must be present, but does not have to work as we are not actually persisting the disk @@ -510,19 +565,19 @@ public KeySerializerWrapper(Serializer keySerializer) { public KeySerializerWrapper(ClassLoader classLoader, FileBasedPersistenceContext persistenceContext) {} @Override - public ByteBuffer serialize(T object) throws SerializerException { + public ByteBuffer serialize(ICacheKey object) throws SerializerException { return ByteBuffer.wrap(serializer.serialize(object)); } @Override - public T read(ByteBuffer binary) throws ClassNotFoundException, SerializerException { + public ICacheKey read(ByteBuffer binary) throws ClassNotFoundException, SerializerException { byte[] arr = new byte[binary.remaining()]; binary.get(arr); return serializer.deserialize(arr); } @Override - public boolean equals(T object, ByteBuffer binary) throws ClassNotFoundException, SerializerException { + public boolean equals(ICacheKey object, ByteBuffer binary) throws ClassNotFoundException, SerializerException { byte[] arr = new byte[binary.remaining()]; binary.get(arr); return serializer.equals(object, arr); @@ -566,8 +621,7 @@ public boolean equals(ByteArrayWrapper object, ByteBuffer binary) throws ClassNo * @return the serialized value */ private ByteArrayWrapper serializeValue(V value) { - ByteArrayWrapper result = new ByteArrayWrapper(valueSerializer.serialize(value)); - return result; + return new ByteArrayWrapper(valueSerializer.serialize(value)); } /** @@ -625,6 +679,8 @@ public ICache create(CacheConfig config, CacheType cacheType, .setValueType(config.getValueType()) .setKeySerializer(keySerializer) .setValueSerializer(valueSerializer) + .setDimensionNames(config.getDimensionNames()) + .setWeigher(config.getWeigher()) .setRemovalListener(config.getRemovalListener()) .setExpireAfterAccess((TimeValue) settingList.get(DISK_CACHE_EXPIRE_AFTER_ACCESS_KEY).get(settings)) .setMaximumWeightInBytes((Long) settingList.get(DISK_MAX_SIZE_IN_BYTES_KEY).get(settings)) @@ -658,6 +714,7 @@ public static class Builder extends ICacheBuilder { private Class keyType; private Class valueType; + private List dimensionNames; private Serializer keySerializer; private Serializer valueSerializer; @@ -736,6 +793,16 @@ public Builder setIsEventListenerModeSync(boolean isEventListenerModeSync) return this; } + /** + * Sets the allowed dimension names for keys that will enter this cache. + * @param dimensionNames A list of dimension names this cache will accept + * @return builder + */ + public Builder setDimensionNames(List dimensionNames) { + this.dimensionNames = dimensionNames; + return this; + } + /** * Sets the key serializer for this cache. * @param keySerializer the key serializer @@ -764,7 +831,7 @@ public EhcacheDiskCache build() { /** * A wrapper over byte[], with equals() that works using Arrays.equals(). - * Necessary due to a bug in Ehcache. + * Necessary due to a limitation in how Ehcache compares byte[]. */ static class ByteArrayWrapper { private final byte[] value; diff --git a/plugins/cache-ehcache/src/test/java/org/opensearch/cache/store/disk/EhCacheDiskCacheTests.java b/plugins/cache-ehcache/src/test/java/org/opensearch/cache/store/disk/EhCacheDiskCacheTests.java index 3a98ad2fef6b1..06ebed08d7525 100644 --- a/plugins/cache-ehcache/src/test/java/org/opensearch/cache/store/disk/EhCacheDiskCacheTests.java +++ b/plugins/cache-ehcache/src/test/java/org/opensearch/cache/store/disk/EhCacheDiskCacheTests.java @@ -14,11 +14,13 @@ import org.opensearch.common.Randomness; import org.opensearch.common.cache.CacheType; import org.opensearch.common.cache.ICache; +import org.opensearch.common.cache.ICacheKey; import org.opensearch.common.cache.LoadAwareCacheLoader; import org.opensearch.common.cache.RemovalListener; import org.opensearch.common.cache.RemovalNotification; import org.opensearch.common.cache.serializer.BytesReferenceSerializer; import org.opensearch.common.cache.serializer.Serializer; +import org.opensearch.common.cache.stats.ImmutableCacheStats; import org.opensearch.common.cache.store.config.CacheConfig; import org.opensearch.common.metrics.CounterMetric; import org.opensearch.common.settings.Settings; @@ -43,6 +45,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Phaser; +import java.util.function.ToLongBiFunction; import static org.opensearch.cache.EhcacheDiskCacheSettings.DISK_LISTENER_MODE_SYNC_KEY; import static org.opensearch.cache.EhcacheDiskCacheSettings.DISK_MAX_SIZE_IN_BYTES_KEY; @@ -53,10 +56,12 @@ public class EhCacheDiskCacheTests extends OpenSearchSingleNodeTestCase { private static final int CACHE_SIZE_IN_BYTES = 1024 * 101; + private final String dimensionName = "shardId"; public void testBasicGetAndPut() throws IOException { Settings settings = Settings.builder().build(); MockRemovalListener removalListener = new MockRemovalListener<>(); + ToLongBiFunction, String> weigher = getWeigher(); try (NodeEnvironment env = newNodeEnvironment(settings)) { ICache ehcacheTest = new EhcacheDiskCache.Builder().setThreadPoolAlias("ehcacheTest") .setStoragePath(env.nodePaths()[0].indicesPath.toString() + "/request_cache") @@ -65,32 +70,42 @@ public void testBasicGetAndPut() throws IOException { .setValueType(String.class) .setKeySerializer(new StringSerializer()) .setValueSerializer(new StringSerializer()) + .setDimensionNames(List.of(dimensionName)) .setCacheType(CacheType.INDICES_REQUEST_CACHE) .setSettings(settings) .setExpireAfterAccess(TimeValue.MAX_VALUE) .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES) .setRemovalListener(removalListener) + .setWeigher(weigher) .build(); int randomKeys = randomIntBetween(10, 100); + long expectedSize = 0; Map keyValueMap = new HashMap<>(); for (int i = 0; i < randomKeys; i++) { keyValueMap.put(UUID.randomUUID().toString(), UUID.randomUUID().toString()); } for (Map.Entry entry : keyValueMap.entrySet()) { - ehcacheTest.put(entry.getKey(), entry.getValue()); + ICacheKey iCacheKey = getICacheKey(entry.getKey()); + ehcacheTest.put(iCacheKey, entry.getValue()); + expectedSize += weigher.applyAsLong(iCacheKey, entry.getValue()); } for (Map.Entry entry : keyValueMap.entrySet()) { - String value = ehcacheTest.get(entry.getKey()); + String value = ehcacheTest.get(getICacheKey(entry.getKey())); assertEquals(entry.getValue(), value); } + assertEquals(randomKeys, ehcacheTest.stats().getTotalEntries()); + assertEquals(randomKeys, ehcacheTest.stats().getTotalHits()); + assertEquals(expectedSize, ehcacheTest.stats().getTotalSizeInBytes()); assertEquals(randomKeys, ehcacheTest.count()); // Validate misses int expectedNumberOfMisses = randomIntBetween(10, 200); for (int i = 0; i < expectedNumberOfMisses; i++) { - ehcacheTest.get(UUID.randomUUID().toString()); + ehcacheTest.get(getICacheKey(UUID.randomUUID().toString())); } + assertEquals(expectedNumberOfMisses, ehcacheTest.stats().getTotalMisses()); + ehcacheTest.close(); } } @@ -105,6 +120,8 @@ public void testBasicGetAndPutUsingFactory() throws IOException { .setRemovalListener(removalListener) .setKeySerializer(new StringSerializer()) .setValueSerializer(new StringSerializer()) + .setDimensionNames(List.of(dimensionName)) + .setWeigher(getWeigher()) .setSettings( Settings.builder() .put( @@ -132,14 +149,14 @@ public void testBasicGetAndPutUsingFactory() throws IOException { Map.of() ); int randomKeys = randomIntBetween(10, 100); - Map keyValueMap = new HashMap<>(); + Map, String> keyValueMap = new HashMap<>(); for (int i = 0; i < randomKeys; i++) { - keyValueMap.put(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + keyValueMap.put(getICacheKey(UUID.randomUUID().toString()), UUID.randomUUID().toString()); } - for (Map.Entry entry : keyValueMap.entrySet()) { + for (Map.Entry, String> entry : keyValueMap.entrySet()) { ehcacheTest.put(entry.getKey(), entry.getValue()); } - for (Map.Entry entry : keyValueMap.entrySet()) { + for (Map.Entry, String> entry : keyValueMap.entrySet()) { String value = ehcacheTest.get(entry.getKey()); assertEquals(entry.getValue(), value); } @@ -148,7 +165,7 @@ public void testBasicGetAndPutUsingFactory() throws IOException { // Validate misses int expectedNumberOfMisses = randomIntBetween(10, 200); for (int i = 0; i < expectedNumberOfMisses; i++) { - ehcacheTest.get(UUID.randomUUID().toString()); + ehcacheTest.get(getICacheKey(UUID.randomUUID().toString())); } ehcacheTest.close(); @@ -167,22 +184,24 @@ public void testConcurrentPut() throws Exception { .setValueType(String.class) .setKeySerializer(new StringSerializer()) .setValueSerializer(new StringSerializer()) + .setDimensionNames(List.of(dimensionName)) .setCacheType(CacheType.INDICES_REQUEST_CACHE) .setSettings(settings) .setExpireAfterAccess(TimeValue.MAX_VALUE) .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES) .setRemovalListener(removalListener) + .setWeigher(getWeigher()) .build(); int randomKeys = randomIntBetween(20, 100); Thread[] threads = new Thread[randomKeys]; Phaser phaser = new Phaser(randomKeys + 1); CountDownLatch countDownLatch = new CountDownLatch(randomKeys); - Map keyValueMap = new HashMap<>(); + Map, String> keyValueMap = new HashMap<>(); int j = 0; for (int i = 0; i < randomKeys; i++) { - keyValueMap.put(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + keyValueMap.put(getICacheKey(UUID.randomUUID().toString()), UUID.randomUUID().toString()); } - for (Map.Entry entry : keyValueMap.entrySet()) { + for (Map.Entry, String> entry : keyValueMap.entrySet()) { threads[j] = new Thread(() -> { phaser.arriveAndAwaitAdvance(); ehcacheTest.put(entry.getKey(), entry.getValue()); @@ -193,11 +212,12 @@ public void testConcurrentPut() throws Exception { } phaser.arriveAndAwaitAdvance(); // Will trigger parallel puts above. countDownLatch.await(); // Wait for all threads to finish - for (Map.Entry entry : keyValueMap.entrySet()) { + for (Map.Entry, String> entry : keyValueMap.entrySet()) { String value = ehcacheTest.get(entry.getKey()); assertEquals(entry.getValue(), value); } assertEquals(randomKeys, ehcacheTest.count()); + assertEquals(randomKeys, ehcacheTest.stats().getTotalEntries()); ehcacheTest.close(); } } @@ -214,11 +234,13 @@ public void testEhcacheParallelGets() throws Exception { .setValueType(String.class) .setKeySerializer(new StringSerializer()) .setValueSerializer(new StringSerializer()) + .setDimensionNames(List.of(dimensionName)) .setCacheType(CacheType.INDICES_REQUEST_CACHE) .setSettings(settings) .setExpireAfterAccess(TimeValue.MAX_VALUE) .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES) .setRemovalListener(removalListener) + .setWeigher(getWeigher()) .build(); int randomKeys = randomIntBetween(20, 100); Thread[] threads = new Thread[randomKeys]; @@ -230,13 +252,13 @@ public void testEhcacheParallelGets() throws Exception { keyValueMap.put(UUID.randomUUID().toString(), UUID.randomUUID().toString()); } for (Map.Entry entry : keyValueMap.entrySet()) { - ehcacheTest.put(entry.getKey(), entry.getValue()); + ehcacheTest.put(getICacheKey(entry.getKey()), entry.getValue()); } assertEquals(keyValueMap.size(), ehcacheTest.count()); for (Map.Entry entry : keyValueMap.entrySet()) { threads[j] = new Thread(() -> { phaser.arriveAndAwaitAdvance(); - assertEquals(entry.getValue(), ehcacheTest.get(entry.getKey())); + assertEquals(entry.getValue(), ehcacheTest.get(getICacheKey(entry.getKey()))); countDownLatch.countDown(); }); threads[j].start(); @@ -244,6 +266,7 @@ public void testEhcacheParallelGets() throws Exception { } phaser.arriveAndAwaitAdvance(); // Will trigger parallel puts above. countDownLatch.await(); // Wait for all threads to finish + assertEquals(randomKeys, ehcacheTest.stats().getTotalHits()); ehcacheTest.close(); } } @@ -259,11 +282,13 @@ public void testEhcacheKeyIterator() throws Exception { .setValueType(String.class) .setKeySerializer(new StringSerializer()) .setValueSerializer(new StringSerializer()) + .setDimensionNames(List.of(dimensionName)) .setCacheType(CacheType.INDICES_REQUEST_CACHE) .setSettings(settings) .setExpireAfterAccess(TimeValue.MAX_VALUE) .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES) .setRemovalListener(new MockRemovalListener<>()) + .setWeigher(getWeigher()) .build(); int randomKeys = randomIntBetween(2, 100); @@ -272,12 +297,12 @@ public void testEhcacheKeyIterator() throws Exception { keyValueMap.put(UUID.randomUUID().toString(), UUID.randomUUID().toString()); } for (Map.Entry entry : keyValueMap.entrySet()) { - ehcacheTest.put(entry.getKey(), entry.getValue()); + ehcacheTest.put(getICacheKey(entry.getKey()), entry.getValue()); } - Iterator keys = ehcacheTest.keys().iterator(); + Iterator> keys = ehcacheTest.keys().iterator(); int keysCount = 0; while (keys.hasNext()) { - String key = keys.next(); + ICacheKey key = keys.next(); keysCount++; assertNotNull(ehcacheTest.get(key)); } @@ -289,6 +314,7 @@ public void testEhcacheKeyIterator() throws Exception { public void testEvictions() throws Exception { Settings settings = Settings.builder().build(); MockRemovalListener removalListener = new MockRemovalListener<>(); + ToLongBiFunction, String> weigher = getWeigher(); try (NodeEnvironment env = newNodeEnvironment(settings)) { ICache ehcacheTest = new EhcacheDiskCache.Builder().setDiskCacheAlias("test1") .setStoragePath(env.nodePaths()[0].indicesPath.toString() + "/request_cache") @@ -298,11 +324,13 @@ public void testEvictions() throws Exception { .setValueType(String.class) .setKeySerializer(new StringSerializer()) .setValueSerializer(new StringSerializer()) + .setDimensionNames(List.of(dimensionName)) .setCacheType(CacheType.INDICES_REQUEST_CACHE) .setSettings(settings) .setExpireAfterAccess(TimeValue.MAX_VALUE) .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES) .setRemovalListener(removalListener) + .setWeigher(weigher) .build(); // Generate a string with 100 characters @@ -311,9 +339,10 @@ public void testEvictions() throws Exception { // Trying to generate more than 100kb to cause evictions. for (int i = 0; i < 1000; i++) { String key = "Key" + i; - ehcacheTest.put(key, value); + ehcacheTest.put(getICacheKey(key), value); } assertEquals(660, removalListener.evictionMetric.count()); + assertEquals(660, ehcacheTest.stats().getTotalEvictions()); ehcacheTest.close(); } } @@ -330,11 +359,13 @@ public void testComputeIfAbsentConcurrently() throws Exception { .setValueType(String.class) .setKeySerializer(new StringSerializer()) .setValueSerializer(new StringSerializer()) + .setDimensionNames(List.of(dimensionName)) .setCacheType(CacheType.INDICES_REQUEST_CACHE) .setSettings(settings) .setExpireAfterAccess(TimeValue.MAX_VALUE) .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES) .setRemovalListener(removalListener) + .setWeigher(getWeigher()) .build(); int numberOfRequest = 2;// randomIntBetween(200, 400); @@ -344,12 +375,12 @@ public void testComputeIfAbsentConcurrently() throws Exception { Phaser phaser = new Phaser(numberOfRequest + 1); CountDownLatch countDownLatch = new CountDownLatch(numberOfRequest); - List> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>(); + List, String>> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>(); // Try to hit different request with the same key concurrently. Verify value is only loaded once. for (int i = 0; i < numberOfRequest; i++) { threads[i] = new Thread(() -> { - LoadAwareCacheLoader loadAwareCacheLoader = new LoadAwareCacheLoader<>() { + LoadAwareCacheLoader, String> loadAwareCacheLoader = new LoadAwareCacheLoader<>() { boolean isLoaded; @Override @@ -358,7 +389,7 @@ public boolean isLoaded() { } @Override - public String load(String key) { + public String load(ICacheKey key) { isLoaded = true; return value; } @@ -366,7 +397,7 @@ public String load(String key) { loadAwareCacheLoaderList.add(loadAwareCacheLoader); phaser.arriveAndAwaitAdvance(); try { - assertEquals(value, ehcacheTest.computeIfAbsent(key, loadAwareCacheLoader)); + assertEquals(value, ehcacheTest.computeIfAbsent(getICacheKey(key), loadAwareCacheLoader)); } catch (Exception e) { throw new RuntimeException(e); } @@ -384,6 +415,9 @@ public String load(String key) { } assertEquals(1, numberOfTimesValueLoaded); assertEquals(0, ((EhcacheDiskCache) ehcacheTest).getCompletableFutureMap().size()); + assertEquals(1, ehcacheTest.stats().getTotalMisses()); + assertEquals(1, ehcacheTest.stats().getTotalEntries()); + assertEquals(numberOfRequest - 1, ehcacheTest.stats().getTotalHits()); assertEquals(1, ehcacheTest.count()); ehcacheTest.close(); } @@ -401,11 +435,13 @@ public void testComputeIfAbsentConcurrentlyAndThrowsException() throws Exception .setValueType(String.class) .setKeySerializer(new StringSerializer()) .setValueSerializer(new StringSerializer()) + .setDimensionNames(List.of(dimensionName)) .setCacheType(CacheType.INDICES_REQUEST_CACHE) .setSettings(settings) .setExpireAfterAccess(TimeValue.MAX_VALUE) .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES) .setRemovalListener(removalListener) + .setWeigher(getWeigher()) .build(); int numberOfRequest = randomIntBetween(200, 400); @@ -414,12 +450,12 @@ public void testComputeIfAbsentConcurrentlyAndThrowsException() throws Exception Phaser phaser = new Phaser(numberOfRequest + 1); CountDownLatch countDownLatch = new CountDownLatch(numberOfRequest); - List> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>(); + List, String>> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>(); // Try to hit different request with the same key concurrently. Loader throws exception. for (int i = 0; i < numberOfRequest; i++) { threads[i] = new Thread(() -> { - LoadAwareCacheLoader loadAwareCacheLoader = new LoadAwareCacheLoader<>() { + LoadAwareCacheLoader, String> loadAwareCacheLoader = new LoadAwareCacheLoader<>() { boolean isLoaded; @Override @@ -428,14 +464,14 @@ public boolean isLoaded() { } @Override - public String load(String key) throws Exception { + public String load(ICacheKey key) throws Exception { isLoaded = true; throw new RuntimeException("Exception"); } }; loadAwareCacheLoaderList.add(loadAwareCacheLoader); phaser.arriveAndAwaitAdvance(); - assertThrows(ExecutionException.class, () -> ehcacheTest.computeIfAbsent(key, loadAwareCacheLoader)); + assertThrows(ExecutionException.class, () -> ehcacheTest.computeIfAbsent(getICacheKey(key), loadAwareCacheLoader)); countDownLatch.countDown(); }); threads[i].start(); @@ -460,11 +496,13 @@ public void testComputeIfAbsentWithNullValueLoading() throws Exception { .setValueType(String.class) .setKeySerializer(new StringSerializer()) .setValueSerializer(new StringSerializer()) + .setDimensionNames(List.of(dimensionName)) .setCacheType(CacheType.INDICES_REQUEST_CACHE) .setSettings(settings) .setExpireAfterAccess(TimeValue.MAX_VALUE) .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES) .setRemovalListener(removalListener) + .setWeigher(getWeigher()) .build(); int numberOfRequest = randomIntBetween(200, 400); @@ -473,12 +511,12 @@ public void testComputeIfAbsentWithNullValueLoading() throws Exception { Phaser phaser = new Phaser(numberOfRequest + 1); CountDownLatch countDownLatch = new CountDownLatch(numberOfRequest); - List> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>(); + List, String>> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>(); // Try to hit different request with the same key concurrently. Loader throws exception. for (int i = 0; i < numberOfRequest; i++) { threads[i] = new Thread(() -> { - LoadAwareCacheLoader loadAwareCacheLoader = new LoadAwareCacheLoader<>() { + LoadAwareCacheLoader, String> loadAwareCacheLoader = new LoadAwareCacheLoader<>() { boolean isLoaded; @Override @@ -487,7 +525,7 @@ public boolean isLoaded() { } @Override - public String load(String key) throws Exception { + public String load(ICacheKey key) throws Exception { isLoaded = true; return null; } @@ -495,11 +533,11 @@ public String load(String key) throws Exception { loadAwareCacheLoaderList.add(loadAwareCacheLoader); phaser.arriveAndAwaitAdvance(); try { - ehcacheTest.computeIfAbsent(key, loadAwareCacheLoader); + ehcacheTest.computeIfAbsent(getICacheKey(key), loadAwareCacheLoader); } catch (Exception ex) { assertThat(ex.getCause(), instanceOf(NullPointerException.class)); } - assertThrows(ExecutionException.class, () -> ehcacheTest.computeIfAbsent(key, loadAwareCacheLoader)); + assertThrows(ExecutionException.class, () -> ehcacheTest.computeIfAbsent(getICacheKey(key), loadAwareCacheLoader)); countDownLatch.countDown(); }); threads[i].start(); @@ -512,42 +550,119 @@ public String load(String key) throws Exception { } } - public void testEhcacheKeyIteratorWithRemove() throws IOException { + public void testMemoryTracking() throws Exception { + // Test all cases for EhCacheEventListener.onEvent and check stats memory usage is updated correctly Settings settings = Settings.builder().build(); + ToLongBiFunction, String> weigher = getWeigher(); + int initialKeyLength = 40; + int initialValueLength = 40; + long sizeForOneInitialEntry = weigher.applyAsLong( + new ICacheKey<>(generateRandomString(initialKeyLength), getMockDimensions()), + generateRandomString(initialValueLength) + ); + int maxEntries = 2000; try (NodeEnvironment env = newNodeEnvironment(settings)) { ICache ehcacheTest = new EhcacheDiskCache.Builder().setDiskCacheAlias("test1") .setThreadPoolAlias("ehcacheTest") .setStoragePath(env.nodePaths()[0].indicesPath.toString() + "/request_cache") - .setIsEventListenerModeSync(true) .setKeyType(String.class) .setValueType(String.class) .setKeySerializer(new StringSerializer()) .setValueSerializer(new StringSerializer()) + .setDimensionNames(List.of(dimensionName)) + .setIsEventListenerModeSync(true) // Test fails if async; probably not all updates happen before checking stats .setCacheType(CacheType.INDICES_REQUEST_CACHE) .setSettings(settings) .setExpireAfterAccess(TimeValue.MAX_VALUE) + .setMaximumWeightInBytes(maxEntries * sizeForOneInitialEntry) + .setRemovalListener(new MockRemovalListener<>()) + .setWeigher(weigher) + .build(); + long expectedSize = 0; + + // Test CREATED case + int numInitialKeys = randomIntBetween(10, 100); + ArrayList> initialKeys = new ArrayList<>(); + for (int i = 0; i < numInitialKeys; i++) { + ICacheKey key = new ICacheKey<>(generateRandomString(initialKeyLength), getMockDimensions()); + String value = generateRandomString(initialValueLength); + ehcacheTest.put(key, value); + initialKeys.add(key); + expectedSize += weigher.applyAsLong(key, value); + assertEquals(expectedSize, ehcacheTest.stats().getTotalStats().getSizeInBytes()); + } + + // Test UPDATED case + HashMap, String> updatedValues = new HashMap<>(); + for (int i = 0; i < numInitialKeys * 0.5; i++) { + int newLengthDifference = randomIntBetween(-20, 20); + String newValue = generateRandomString(initialValueLength + newLengthDifference); + ehcacheTest.put(initialKeys.get(i), newValue); + updatedValues.put(initialKeys.get(i), newValue); + expectedSize += newLengthDifference; + assertEquals(expectedSize, ehcacheTest.stats().getTotalStats().getSizeInBytes()); + } + + // Test REMOVED case by removing all updated keys + for (int i = 0; i < numInitialKeys * 0.5; i++) { + ICacheKey removedKey = initialKeys.get(i); + ehcacheTest.invalidate(removedKey); + expectedSize -= weigher.applyAsLong(removedKey, updatedValues.get(removedKey)); + assertEquals(expectedSize, ehcacheTest.stats().getTotalStats().getSizeInBytes()); + } + + // Test EVICTED case by adding entries past the cap and ensuring memory size stays as what we expect + for (int i = 0; i < maxEntries - ehcacheTest.count(); i++) { + ICacheKey key = new ICacheKey<>(generateRandomString(initialKeyLength), getMockDimensions()); + String value = generateRandomString(initialValueLength); + ehcacheTest.put(key, value); + } + // TODO: Ehcache incorrectly evicts at 30-40% of max size. Fix this test once we figure out why. + // Since the EVICTED and EXPIRED cases use the same code as REMOVED, we should be ok on testing them for now. + // assertEquals(maxEntries * sizeForOneInitialEntry, ehcacheTest.stats().getTotalMemorySize()); + + ehcacheTest.close(); + } + } + + public void testEhcacheKeyIteratorWithRemove() throws IOException { + Settings settings = Settings.builder().build(); + try (NodeEnvironment env = newNodeEnvironment(settings)) { + ICache ehcacheTest = new EhcacheDiskCache.Builder().setDiskCacheAlias("test1") + .setThreadPoolAlias("ehcacheTest") + .setStoragePath(env.nodePaths()[0].indicesPath.toString() + "/request_cache") + .setIsEventListenerModeSync(true) + .setKeySerializer(new StringSerializer()) + .setValueSerializer(new StringSerializer()) + .setDimensionNames(List.of(dimensionName)) + .setCacheType(CacheType.INDICES_REQUEST_CACHE) + .setKeyType(String.class) + .setValueType(String.class) + .setSettings(settings) + .setExpireAfterAccess(TimeValue.MAX_VALUE) .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES) .setRemovalListener(new MockRemovalListener<>()) + .setWeigher(getWeigher()) .build(); int randomKeys = randomIntBetween(2, 100); for (int i = 0; i < randomKeys; i++) { - ehcacheTest.put(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + ehcacheTest.put(getICacheKey(UUID.randomUUID().toString()), UUID.randomUUID().toString()); } long originalSize = ehcacheTest.count(); assertEquals(randomKeys, originalSize); // Now try removing subset of keys and verify - List removedKeyList = new ArrayList<>(); - for (Iterator iterator = ehcacheTest.keys().iterator(); iterator.hasNext();) { - String key = iterator.next(); + List> removedKeyList = new ArrayList<>(); + for (Iterator> iterator = ehcacheTest.keys().iterator(); iterator.hasNext();) { + ICacheKey key = iterator.next(); if (randomBoolean()) { removedKeyList.add(key); iterator.remove(); } } // Verify the removed key doesn't exist anymore. - for (String ehcacheKey : removedKeyList) { + for (ICacheKey ehcacheKey : removedKeyList) { assertNull(ehcacheTest.get(ehcacheKey)); } // Verify ehcache entry size again. @@ -568,22 +683,24 @@ public void testInvalidateAll() throws Exception { .setValueType(String.class) .setKeySerializer(new StringSerializer()) .setValueSerializer(new StringSerializer()) + .setDimensionNames(List.of(dimensionName)) .setCacheType(CacheType.INDICES_REQUEST_CACHE) .setSettings(settings) .setExpireAfterAccess(TimeValue.MAX_VALUE) .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES) .setRemovalListener(removalListener) + .setWeigher(getWeigher()) .build(); int randomKeys = randomIntBetween(10, 100); - Map keyValueMap = new HashMap<>(); + Map, String> keyValueMap = new HashMap<>(); for (int i = 0; i < randomKeys; i++) { - keyValueMap.put(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + keyValueMap.put(getICacheKey(UUID.randomUUID().toString()), UUID.randomUUID().toString()); } - for (Map.Entry entry : keyValueMap.entrySet()) { + for (Map.Entry, String> entry : keyValueMap.entrySet()) { ehcacheTest.put(entry.getKey(), entry.getValue()); } ehcacheTest.invalidateAll(); // clear all the entries. - for (Map.Entry entry : keyValueMap.entrySet()) { + for (Map.Entry, String> entry : keyValueMap.entrySet()) { // Verify that value is null for a removed entry. assertNull(ehcacheTest.get(entry.getKey())); } @@ -600,6 +717,7 @@ public void testBasicGetAndPutBytesReference() throws Exception { .setStoragePath(env.nodePaths()[0].indicesPath.toString() + "/request_cache") .setKeySerializer(new StringSerializer()) .setValueSerializer(new BytesReferenceSerializer()) + .setDimensionNames(List.of(dimensionName)) .setKeyType(String.class) .setValueType(BytesReference.class) .setCacheType(CacheType.INDICES_REQUEST_CACHE) @@ -607,15 +725,16 @@ public void testBasicGetAndPutBytesReference() throws Exception { .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES * 20) // bigger so no evictions happen .setExpireAfterAccess(TimeValue.MAX_VALUE) .setRemovalListener(new MockRemovalListener<>()) + .setWeigher((key, value) -> 1) .build(); int randomKeys = randomIntBetween(10, 100); int valueLength = 100; Random rand = Randomness.get(); - Map keyValueMap = new HashMap<>(); + Map, BytesReference> keyValueMap = new HashMap<>(); for (int i = 0; i < randomKeys; i++) { byte[] valueBytes = new byte[valueLength]; rand.nextBytes(valueBytes); - keyValueMap.put(UUID.randomUUID().toString(), new BytesArray(valueBytes)); + keyValueMap.put(getICacheKey(UUID.randomUUID().toString()), new BytesArray(valueBytes)); // Test a non-BytesArray implementation of BytesReference. byte[] compositeBytes1 = new byte[valueLength]; @@ -623,12 +742,12 @@ public void testBasicGetAndPutBytesReference() throws Exception { rand.nextBytes(compositeBytes1); rand.nextBytes(compositeBytes2); BytesReference composite = CompositeBytesReference.of(new BytesArray(compositeBytes1), new BytesArray(compositeBytes2)); - keyValueMap.put(UUID.randomUUID().toString(), composite); + keyValueMap.put(getICacheKey(UUID.randomUUID().toString()), composite); } - for (Map.Entry entry : keyValueMap.entrySet()) { + for (Map.Entry, BytesReference> entry : keyValueMap.entrySet()) { ehCacheDiskCachingTier.put(entry.getKey(), entry.getValue()); } - for (Map.Entry entry : keyValueMap.entrySet()) { + for (Map.Entry, BytesReference> entry : keyValueMap.entrySet()) { BytesReference value = ehCacheDiskCachingTier.get(entry.getKey()); assertEquals(entry.getValue(), value); } @@ -647,29 +766,31 @@ public void testInvalidate() throws Exception { .setKeySerializer(new StringSerializer()) .setValueSerializer(new StringSerializer()) .setValueType(String.class) + .setDimensionNames(List.of(dimensionName)) .setCacheType(CacheType.INDICES_REQUEST_CACHE) .setSettings(settings) .setExpireAfterAccess(TimeValue.MAX_VALUE) .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES) .setRemovalListener(removalListener) + .setWeigher(getWeigher()) .build(); int randomKeys = randomIntBetween(10, 100); - Map keyValueMap = new HashMap<>(); + Map, String> keyValueMap = new HashMap<>(); for (int i = 0; i < randomKeys; i++) { - keyValueMap.put(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + keyValueMap.put(getICacheKey(UUID.randomUUID().toString()), UUID.randomUUID().toString()); } - for (Map.Entry entry : keyValueMap.entrySet()) { + for (Map.Entry, String> entry : keyValueMap.entrySet()) { ehcacheTest.put(entry.getKey(), entry.getValue()); } assertEquals(keyValueMap.size(), ehcacheTest.count()); - List removedKeyList = new ArrayList<>(); - for (Map.Entry entry : keyValueMap.entrySet()) { + List> removedKeyList = new ArrayList<>(); + for (Map.Entry, String> entry : keyValueMap.entrySet()) { if (randomBoolean()) { removedKeyList.add(entry.getKey()); ehcacheTest.invalidate(entry.getKey()); } } - for (String removedKey : removedKeyList) { + for (ICacheKey removedKey : removedKeyList) { assertNull(ehcacheTest.get(removedKey)); } assertEquals(keyValueMap.size() - removedKeyList.size(), ehcacheTest.count()); @@ -677,6 +798,67 @@ public void testInvalidate() throws Exception { } } + // Modified from OpenSearchOnHeapCacheTests.java + public void testInvalidateWithDropDimensions() throws Exception { + Settings settings = Settings.builder().build(); + List dimensionNames = List.of("dim1", "dim2"); + try (NodeEnvironment env = newNodeEnvironment(settings)) { + ICache ehCacheDiskCachingTier = new EhcacheDiskCache.Builder().setThreadPoolAlias("ehcacheTest") + .setStoragePath(env.nodePaths()[0].indicesPath.toString() + "/request_cache") + .setKeySerializer(new StringSerializer()) + .setValueSerializer(new StringSerializer()) + .setIsEventListenerModeSync(true) + .setDimensionNames(dimensionNames) + .setKeyType(String.class) + .setValueType(String.class) + .setCacheType(CacheType.INDICES_REQUEST_CACHE) + .setSettings(settings) + .setMaximumWeightInBytes(CACHE_SIZE_IN_BYTES * 20) // bigger so no evictions happen + .setExpireAfterAccess(TimeValue.MAX_VALUE) + .setRemovalListener(new MockRemovalListener<>()) + .setWeigher((key, value) -> 1) + .build(); + + List> keysAdded = new ArrayList<>(); + + for (int i = 0; i < 20; i++) { + ICacheKey key = new ICacheKey<>(UUID.randomUUID().toString(), getRandomDimensions(dimensionNames)); + keysAdded.add(key); + ehCacheDiskCachingTier.put(key, UUID.randomUUID().toString()); + } + + ICacheKey keyToDrop = keysAdded.get(0); + + ImmutableCacheStats snapshot = ehCacheDiskCachingTier.stats().getStatsForDimensionValues(keyToDrop.dimensions); + assertNotNull(snapshot); + + keyToDrop.setDropStatsForDimensions(true); + ehCacheDiskCachingTier.invalidate(keyToDrop); + + // Now assert the stats are gone for any key that has this combination of dimensions, but still there otherwise + for (ICacheKey keyAdded : keysAdded) { + snapshot = ehCacheDiskCachingTier.stats().getStatsForDimensionValues(keyAdded.dimensions); + if (keyAdded.dimensions.equals(keyToDrop.dimensions)) { + assertNull(snapshot); + } else { + assertNotNull(snapshot); + } + } + + ehCacheDiskCachingTier.close(); + } + } + + private List getRandomDimensions(List dimensionNames) { + Random rand = Randomness.get(); + int bound = 3; + List result = new ArrayList<>(); + for (String dimName : dimensionNames) { + result.add(String.valueOf(rand.nextInt(bound))); + } + return result; + } + private static String generateRandomString(int length) { String characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; StringBuilder randomString = new StringBuilder(length); @@ -689,12 +871,34 @@ private static String generateRandomString(int length) { return randomString.toString(); } - static class MockRemovalListener implements RemovalListener { + private List getMockDimensions() { + return List.of("0"); + } + + private ICacheKey getICacheKey(String key) { + return new ICacheKey<>(key, getMockDimensions()); + } + + private ToLongBiFunction, String> getWeigher() { + return (iCacheKey, value) -> { + // Size consumed by key + long totalSize = iCacheKey.key.length(); + for (String dim : iCacheKey.dimensions) { + totalSize += dim.length(); + } + totalSize += 10; // The ICacheKeySerializer writes 2 VInts to record array lengths, which can be 1-5 bytes each + // Size consumed by value + totalSize += value.length(); + return totalSize; + }; + } + + static class MockRemovalListener implements RemovalListener, V> { CounterMetric evictionMetric = new CounterMetric(); @Override - public void onRemoval(RemovalNotification notification) { + public void onRemoval(RemovalNotification, V> notification) { evictionMetric.inc(); } } diff --git a/server/src/main/java/org/opensearch/common/cache/Cache.java b/server/src/main/java/org/opensearch/common/cache/Cache.java index d8aa4e93735e6..6d346de25cadf 100644 --- a/server/src/main/java/org/opensearch/common/cache/Cache.java +++ b/server/src/main/java/org/opensearch/common/cache/Cache.java @@ -896,6 +896,10 @@ private void relinkAtHead(Entry entry) { } } + public ToLongBiFunction getWeigher() { + return weigher; + } + private CacheSegment getCacheSegment(K key) { return segments[key.hashCode() & 0xff]; } diff --git a/server/src/main/java/org/opensearch/common/cache/ICache.java b/server/src/main/java/org/opensearch/common/cache/ICache.java index f7be46a852631..8d8964abf0829 100644 --- a/server/src/main/java/org/opensearch/common/cache/ICache.java +++ b/server/src/main/java/org/opensearch/common/cache/ICache.java @@ -9,6 +9,7 @@ package org.opensearch.common.cache; import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.common.cache.stats.ImmutableCacheStatsHolder; import org.opensearch.common.cache.store.config.CacheConfig; import java.io.Closeable; @@ -23,22 +24,29 @@ */ @ExperimentalApi public interface ICache extends Closeable { - V get(K key); + V get(ICacheKey key); - void put(K key, V value); + void put(ICacheKey key, V value); - V computeIfAbsent(K key, LoadAwareCacheLoader loader) throws Exception; + V computeIfAbsent(ICacheKey key, LoadAwareCacheLoader, V> loader) throws Exception; - void invalidate(K key); + /** + * Invalidates the key. If a dimension in the key has dropStatsOnInvalidation set to true, the cache also completely + * resets stats for that dimension value. It's the caller's responsibility to make sure all keys with that dimension value are + * actually invalidated. + */ + void invalidate(ICacheKey key); void invalidateAll(); - Iterable keys(); + Iterable> keys(); long count(); void refresh(); + ImmutableCacheStatsHolder stats(); + /** * Factory to create objects. */ diff --git a/server/src/main/java/org/opensearch/common/cache/ICacheKey.java b/server/src/main/java/org/opensearch/common/cache/ICacheKey.java new file mode 100644 index 0000000000000..4d93aab933751 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/cache/ICacheKey.java @@ -0,0 +1,96 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.cache; + +import org.opensearch.common.annotation.ExperimentalApi; + +import java.util.List; + +/** + * A key wrapper used for ICache implementations, which carries dimensions with it. + * @param the type of the underlying key. K must implement equals(), or else ICacheKey.equals() + * won't work properly and cache behavior may be incorrect! + * + * @opensearch.experimental + */ +@ExperimentalApi +public class ICacheKey { + public final K key; // K must implement equals() + public final List dimensions; // Dimension values. The dimension names are implied. + /** + * If this key is invalidated and dropDimensions is true, the ICache implementation will also drop all stats, + * including hits/misses/evictions, with this combination of dimension values. + */ + private boolean dropStatsForDimensions; + + /** + * Constructor to use when specifying dimensions. + */ + public ICacheKey(K key, List dimensions) { + this.key = key; + this.dimensions = dimensions; + } + + /** + * Constructor to use when no dimensions are needed. + */ + public ICacheKey(K key) { + this.key = key; + this.dimensions = List.of(); + } + + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } + if (o == null) { + return false; + } + if (o.getClass() != ICacheKey.class) { + return false; + } + ICacheKey other = (ICacheKey) o; + if (!dimensions.equals(other.dimensions)) { + return false; + } + if (this.key == null && other.key == null) { + return true; + } + if (this.key == null || other.key == null) { + return false; + } + return this.key.equals(other.key); + } + + @Override + public int hashCode() { + if (key == null) { + return dimensions.hashCode(); + } + return 31 * key.hashCode() + dimensions.hashCode(); + } + + // As K might not be Accountable, directly pass in its memory usage to be added. + public long ramBytesUsed(long underlyingKeyRamBytes) { + long estimate = underlyingKeyRamBytes; + for (String dim : dimensions) { + estimate += dim.length(); + } + return estimate; + } + + public boolean getDropStatsForDimensions() { + return dropStatsForDimensions; + } + + public void setDropStatsForDimensions(boolean newValue) { + dropStatsForDimensions = newValue; + } +} diff --git a/server/src/main/java/org/opensearch/common/cache/serializer/ICacheKeySerializer.java b/server/src/main/java/org/opensearch/common/cache/serializer/ICacheKeySerializer.java new file mode 100644 index 0000000000000..7521e23091464 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/cache/serializer/ICacheKeySerializer.java @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.cache.serializer; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchException; +import org.opensearch.common.cache.ICacheKey; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * A serializer for ICacheKey. + * @param the type of the underlying key in ICacheKey + */ +public class ICacheKeySerializer implements Serializer, byte[]> { + + public final Serializer keySerializer; + private final Logger logger = LogManager.getLogger(ICacheKeySerializer.class); + + public ICacheKeySerializer(Serializer serializer) { + this.keySerializer = serializer; + } + + @Override + public byte[] serialize(ICacheKey object) { + if (object == null || object.key == null || object.dimensions == null) { + return null; + } + byte[] serializedKey = keySerializer.serialize(object.key); + try { + BytesStreamOutput os = new BytesStreamOutput(); + // First write the number of dimensions + os.writeVInt(object.dimensions.size()); + for (String dimValue : object.dimensions) { + os.writeString(dimValue); + } + os.writeVInt(serializedKey.length); // The read byte[] fn seems to not work as expected + os.writeBytes(serializedKey); + byte[] finalBytes = BytesReference.toBytes(os.bytes()); + return finalBytes; + } catch (IOException e) { + logger.debug("Could not write ICacheKey to byte[]"); + throw new OpenSearchException(e); + } + } + + @Override + public ICacheKey deserialize(byte[] bytes) { + if (bytes == null) { + return null; + } + List dimensionList = new ArrayList<>(); + try { + BytesStreamInput is = new BytesStreamInput(bytes, 0, bytes.length); + int numDimensions = is.readVInt(); + for (int i = 0; i < numDimensions; i++) { + dimensionList.add(is.readString()); + } + + int length = is.readVInt(); + byte[] serializedKey = new byte[length]; + is.readBytes(serializedKey, 0, length); + return new ICacheKey<>(keySerializer.deserialize(serializedKey), dimensionList); + } catch (IOException e) { + logger.debug("Could not write byte[] to ICacheKey"); + throw new OpenSearchException(e); + } + } + + @Override + public boolean equals(ICacheKey object, byte[] bytes) { + return Arrays.equals(serialize(object), bytes); + } +} diff --git a/server/src/main/java/org/opensearch/common/cache/stats/CacheStats.java b/server/src/main/java/org/opensearch/common/cache/stats/CacheStats.java new file mode 100644 index 0000000000000..b0cb66b56b70d --- /dev/null +++ b/server/src/main/java/org/opensearch/common/cache/stats/CacheStats.java @@ -0,0 +1,132 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.cache.stats; + +import org.opensearch.common.metrics.CounterMetric; + +import java.util.Objects; + +/** + * A mutable class containing the 5 live metrics tracked by a StatsHolder object. + */ +public class CacheStats { + CounterMetric hits; + CounterMetric misses; + CounterMetric evictions; + CounterMetric sizeInBytes; + CounterMetric entries; + + public CacheStats(long hits, long misses, long evictions, long sizeInBytes, long entries) { + this.hits = new CounterMetric(); + this.hits.inc(hits); + this.misses = new CounterMetric(); + this.misses.inc(misses); + this.evictions = new CounterMetric(); + this.evictions.inc(evictions); + this.sizeInBytes = new CounterMetric(); + this.sizeInBytes.inc(sizeInBytes); + this.entries = new CounterMetric(); + this.entries.inc(entries); + } + + public CacheStats() { + this(0, 0, 0, 0, 0); + } + + private void internalAdd(long otherHits, long otherMisses, long otherEvictions, long otherSizeInBytes, long otherEntries) { + this.hits.inc(otherHits); + this.misses.inc(otherMisses); + this.evictions.inc(otherEvictions); + this.sizeInBytes.inc(otherSizeInBytes); + this.entries.inc(otherEntries); + } + + public void add(CacheStats other) { + if (other == null) { + return; + } + internalAdd(other.getHits(), other.getMisses(), other.getEvictions(), other.getSizeInBytes(), other.getEntries()); + } + + public void add(ImmutableCacheStats snapshot) { + if (snapshot == null) { + return; + } + internalAdd(snapshot.getHits(), snapshot.getMisses(), snapshot.getEvictions(), snapshot.getSizeInBytes(), snapshot.getEntries()); + } + + public void subtract(ImmutableCacheStats other) { + if (other == null) { + return; + } + internalAdd(-other.getHits(), -other.getMisses(), -other.getEvictions(), -other.getSizeInBytes(), -other.getEntries()); + } + + @Override + public int hashCode() { + return Objects.hash(hits.count(), misses.count(), evictions.count(), sizeInBytes.count(), entries.count()); + } + + public void incrementHits() { + hits.inc(); + } + + public void incrementMisses() { + misses.inc(); + } + + public void incrementEvictions() { + evictions.inc(); + } + + public void incrementSizeInBytes(long amount) { + sizeInBytes.inc(amount); + } + + public void decrementSizeInBytes(long amount) { + sizeInBytes.dec(amount); + } + + public void incrementEntries() { + entries.inc(); + } + + public void decrementEntries() { + entries.dec(); + } + + public long getHits() { + return hits.count(); + } + + public long getMisses() { + return misses.count(); + } + + public long getEvictions() { + return evictions.count(); + } + + public long getSizeInBytes() { + return sizeInBytes.count(); + } + + public long getEntries() { + return entries.count(); + } + + public void resetSizeAndEntries() { + sizeInBytes = new CounterMetric(); + entries = new CounterMetric(); + } + + public ImmutableCacheStats immutableSnapshot() { + return new ImmutableCacheStats(hits.count(), misses.count(), evictions.count(), sizeInBytes.count(), entries.count()); + } +} diff --git a/server/src/main/java/org/opensearch/common/cache/stats/CacheStatsHolder.java b/server/src/main/java/org/opensearch/common/cache/stats/CacheStatsHolder.java new file mode 100644 index 0000000000000..a8b7c27ef9e79 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/cache/stats/CacheStatsHolder.java @@ -0,0 +1,295 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.cache.stats; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; + +/** + * A class ICache implementations use to internally keep track of their stats across multiple dimensions. + * Not intended to be exposed outside the cache; for this, caches use getImmutableCacheStatsHolder() to create an immutable + * copy of the current state of the stats. + * Currently, in the IRC, the stats tracked in a CacheStatsHolder will not appear for empty shards that have had no cache + * operations done on them yet. This might be changed in the future, by exposing a method to add empty nodes to the + * tree in CacheStatsHolder in the ICache interface. + * + * @opensearch.experimental + */ +public class CacheStatsHolder { + + // The list of permitted dimensions. Should be ordered from "outermost" to "innermost", as you would like to + // aggregate them in an API response. + private final List dimensionNames; + // A tree structure based on dimension values, which stores stats values in its leaf nodes. + // Non-leaf nodes have stats matching the sum of their children. + // We use a tree structure, rather than a map with concatenated keys, to save on memory usage. If there are many leaf + // nodes that share a parent, that parent's dimension value will only be stored once, not many times. + private final Node statsRoot; + // To avoid sync problems, obtain a lock before creating or removing nodes in the stats tree. + // No lock is needed to edit stats on existing nodes. + private final Lock lock = new ReentrantLock(); + + public CacheStatsHolder(List dimensionNames) { + this.dimensionNames = Collections.unmodifiableList(dimensionNames); + this.statsRoot = new Node("", true); // The root node has the empty string as its dimension value + } + + public List getDimensionNames() { + return dimensionNames; + } + + // For all these increment functions, the dimensions list comes from the key, and contains all dimensions present in dimensionNames. + // The order has to match the order given in dimensionNames. + public void incrementHits(List dimensionValues) { + internalIncrement(dimensionValues, Node::incrementHits, true); + } + + public void incrementMisses(List dimensionValues) { + internalIncrement(dimensionValues, Node::incrementMisses, true); + } + + public void incrementEvictions(List dimensionValues) { + internalIncrement(dimensionValues, Node::incrementEvictions, true); + } + + public void incrementSizeInBytes(List dimensionValues, long amountBytes) { + internalIncrement(dimensionValues, (node) -> node.incrementSizeInBytes(amountBytes), true); + } + + // For decrements, we should not create nodes if they are absent. This protects us from erroneously decrementing values for keys + // which have been entirely deleted, for example in an async removal listener. + public void decrementSizeInBytes(List dimensionValues, long amountBytes) { + internalIncrement(dimensionValues, (node) -> node.decrementSizeInBytes(amountBytes), false); + } + + public void incrementEntries(List dimensionValues) { + internalIncrement(dimensionValues, Node::incrementEntries, true); + } + + public void decrementEntries(List dimensionValues) { + internalIncrement(dimensionValues, Node::decrementEntries, false); + } + + /** + * Reset number of entries and memory size when all keys leave the cache, but don't reset hit/miss/eviction numbers. + * This is in line with the behavior of the existing API when caches are cleared. + */ + public void reset() { + resetHelper(statsRoot); + } + + private void resetHelper(Node current) { + current.resetSizeAndEntries(); + for (Node child : current.children.values()) { + resetHelper(child); + } + } + + public long count() { + // Include this here so caches don't have to create an entire CacheStats object to run count(). + return statsRoot.getEntries(); + } + + private void internalIncrement(List dimensionValues, Consumer adder, boolean createNodesIfAbsent) { + assert dimensionValues.size() == dimensionNames.size(); + // First try to increment without creating nodes + boolean didIncrement = internalIncrementHelper(dimensionValues, statsRoot, 0, adder, false); + // If we failed to increment, because nodes had to be created, obtain the lock and run again while creating nodes if needed + if (!didIncrement && createNodesIfAbsent) { + try { + lock.lock(); + internalIncrementHelper(dimensionValues, statsRoot, 0, adder, true); + } finally { + lock.unlock(); + } + } + } + + /** + * Use the incrementer function to increment/decrement a value in the stats for a set of dimensions. + * If createNodesIfAbsent is true, and there is no stats for this set of dimensions, create one. + * Returns true if the increment was applied, false if not. + */ + private boolean internalIncrementHelper( + List dimensionValues, + Node node, + int depth, // Pass in the depth to avoid having to slice the list for each node. + Consumer adder, + boolean createNodesIfAbsent + ) { + if (depth == dimensionValues.size()) { + // This is the leaf node we are trying to reach + adder.accept(node); + return true; + } + + Node child = node.getChild(dimensionValues.get(depth)); + if (child == null) { + if (createNodesIfAbsent) { + boolean createMapInChild = depth < dimensionValues.size() - 1; + child = node.createChild(dimensionValues.get(depth), createMapInChild); + } else { + return false; + } + } + if (internalIncrementHelper(dimensionValues, child, depth + 1, adder, createNodesIfAbsent)) { + // Function returns true if the next node down was incremented + adder.accept(node); + return true; + } + return false; + } + + /** + * Produce an immutable version of these stats. + */ + public ImmutableCacheStatsHolder getImmutableCacheStatsHolder() { + return new ImmutableCacheStatsHolder(statsRoot.snapshot(), dimensionNames); + } + + public void removeDimensions(List dimensionValues) { + assert dimensionValues.size() == dimensionNames.size() : "Must specify a value for every dimension when removing from StatsHolder"; + // As we are removing nodes from the tree, obtain the lock + lock.lock(); + try { + removeDimensionsHelper(dimensionValues, statsRoot, 0); + } finally { + lock.unlock(); + } + } + + // Returns a CacheStatsCounterSnapshot object for the stats to decrement if the removal happened, null otherwise. + private ImmutableCacheStats removeDimensionsHelper(List dimensionValues, Node node, int depth) { + if (depth == dimensionValues.size()) { + // Pass up a snapshot of the original stats to avoid issues when the original is decremented by other fn invocations + return node.getImmutableStats(); + } + Node child = node.getChild(dimensionValues.get(depth)); + if (child == null) { + return null; + } + ImmutableCacheStats statsToDecrement = removeDimensionsHelper(dimensionValues, child, depth + 1); + if (statsToDecrement != null) { + // The removal took place, decrement values and remove this node from its parent if it's now empty + node.decrementBySnapshot(statsToDecrement); + if (child.getChildren().isEmpty()) { + node.children.remove(child.getDimensionValue()); + } + } + return statsToDecrement; + } + + // pkg-private for testing + Node getStatsRoot() { + return statsRoot; + } + + static class Node { + private final String dimensionValue; + // Map from dimensionValue to the DimensionNode for that dimension value. + final Map children; + // The stats for this node. If a leaf node, corresponds to the stats for this combination of dimensions; if not, + // contains the sum of its children's stats. + private CacheStats stats; + + // Used for leaf nodes to avoid allocating many unnecessary maps + private static final Map EMPTY_CHILDREN_MAP = new HashMap<>(); + + Node(String dimensionValue, boolean createChildrenMap) { + this.dimensionValue = dimensionValue; + if (createChildrenMap) { + this.children = new ConcurrentHashMap<>(); + } else { + this.children = EMPTY_CHILDREN_MAP; + } + this.stats = new CacheStats(); + } + + public String getDimensionValue() { + return dimensionValue; + } + + protected Map getChildren() { + // We can safely iterate over ConcurrentHashMap without worrying about thread issues. + return children; + } + + // Functions for modifying internal CacheStatsCounter without callers having to be aware of CacheStatsCounter + + void incrementHits() { + this.stats.incrementHits(); + } + + void incrementMisses() { + this.stats.incrementMisses(); + } + + void incrementEvictions() { + this.stats.incrementEvictions(); + } + + void incrementSizeInBytes(long amountBytes) { + this.stats.incrementSizeInBytes(amountBytes); + } + + void decrementSizeInBytes(long amountBytes) { + this.stats.decrementSizeInBytes(amountBytes); + } + + void incrementEntries() { + this.stats.incrementEntries(); + } + + void decrementEntries() { + this.stats.decrementEntries(); + } + + long getEntries() { + return this.stats.getEntries(); + } + + ImmutableCacheStats getImmutableStats() { + return this.stats.immutableSnapshot(); + } + + void decrementBySnapshot(ImmutableCacheStats snapshot) { + this.stats.subtract(snapshot); + } + + void resetSizeAndEntries() { + this.stats.resetSizeAndEntries(); + } + + Node getChild(String dimensionValue) { + return children.get(dimensionValue); + } + + Node createChild(String dimensionValue, boolean createMapInChild) { + return children.computeIfAbsent(dimensionValue, (key) -> new Node(dimensionValue, createMapInChild)); + } + + ImmutableCacheStatsHolder.Node snapshot() { + TreeMap snapshotChildren = null; + if (!children.isEmpty()) { + snapshotChildren = new TreeMap<>(); + for (Node child : children.values()) { + snapshotChildren.put(child.getDimensionValue(), child.snapshot()); + } + } + return new ImmutableCacheStatsHolder.Node(dimensionValue, snapshotChildren, getImmutableStats()); + } + } +} diff --git a/server/src/main/java/org/opensearch/common/cache/stats/ImmutableCacheStats.java b/server/src/main/java/org/opensearch/common/cache/stats/ImmutableCacheStats.java new file mode 100644 index 0000000000000..7549490fd6b74 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/cache/stats/ImmutableCacheStats.java @@ -0,0 +1,103 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.cache.stats; + +import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Objects; + +/** + * An immutable snapshot of CacheStats. + * + * @opensearch.experimental + */ +@ExperimentalApi +public class ImmutableCacheStats implements Writeable { // TODO: Make this extend ToXContent (in API PR) + private final long hits; + private final long misses; + private final long evictions; + private final long sizeInBytes; + private final long entries; + + public ImmutableCacheStats(long hits, long misses, long evictions, long sizeInBytes, long entries) { + this.hits = hits; + this.misses = misses; + this.evictions = evictions; + this.sizeInBytes = sizeInBytes; + this.entries = entries; + } + + public ImmutableCacheStats(StreamInput in) throws IOException { + this(in.readVLong(), in.readVLong(), in.readVLong(), in.readVLong(), in.readVLong()); + } + + public static ImmutableCacheStats addSnapshots(ImmutableCacheStats s1, ImmutableCacheStats s2) { + return new ImmutableCacheStats( + s1.hits + s2.hits, + s1.misses + s2.misses, + s1.evictions + s2.evictions, + s1.sizeInBytes + s2.sizeInBytes, + s1.entries + s2.entries + ); + } + + public long getHits() { + return hits; + } + + public long getMisses() { + return misses; + } + + public long getEvictions() { + return evictions; + } + + public long getSizeInBytes() { + return sizeInBytes; + } + + public long getEntries() { + return entries; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(hits); + out.writeVLong(misses); + out.writeVLong(evictions); + out.writeVLong(sizeInBytes); + out.writeVLong(entries); + } + + @Override + public boolean equals(Object o) { + if (o == null) { + return false; + } + if (o.getClass() != ImmutableCacheStats.class) { + return false; + } + ImmutableCacheStats other = (ImmutableCacheStats) o; + return (hits == other.hits) + && (misses == other.misses) + && (evictions == other.evictions) + && (sizeInBytes == other.sizeInBytes) + && (entries == other.entries); + } + + @Override + public int hashCode() { + return Objects.hash(hits, misses, evictions, sizeInBytes, entries); + } +} diff --git a/server/src/main/java/org/opensearch/common/cache/stats/ImmutableCacheStatsHolder.java b/server/src/main/java/org/opensearch/common/cache/stats/ImmutableCacheStatsHolder.java new file mode 100644 index 0000000000000..12e325046d83b --- /dev/null +++ b/server/src/main/java/org/opensearch/common/cache/stats/ImmutableCacheStatsHolder.java @@ -0,0 +1,111 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.cache.stats; + +import org.opensearch.common.annotation.ExperimentalApi; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +/** + * An object storing an immutable snapshot of an entire cache's stats. Accessible outside the cache itself. + * + * @opensearch.experimental + */ + +@ExperimentalApi +public class ImmutableCacheStatsHolder { // TODO: extends Writeable, ToXContent + // An immutable snapshot of a stats within a CacheStatsHolder, containing all the stats maintained by the cache. + // Pkg-private for testing. + final Node statsRoot; + final List dimensionNames; + + public ImmutableCacheStatsHolder(Node statsRoot, List dimensionNames) { + this.statsRoot = statsRoot; + this.dimensionNames = dimensionNames; + } + + public ImmutableCacheStats getTotalStats() { + return statsRoot.getStats(); + } + + public long getTotalHits() { + return getTotalStats().getHits(); + } + + public long getTotalMisses() { + return getTotalStats().getMisses(); + } + + public long getTotalEvictions() { + return getTotalStats().getEvictions(); + } + + public long getTotalSizeInBytes() { + return getTotalStats().getSizeInBytes(); + } + + public long getTotalEntries() { + return getTotalStats().getEntries(); + } + + public ImmutableCacheStats getStatsForDimensionValues(List dimensionValues) { + Node current = statsRoot; + for (String dimensionValue : dimensionValues) { + current = current.children.get(dimensionValue); + if (current == null) { + return null; + } + } + return current.stats; + } + + // A similar class to CacheStatsHolder.Node, which uses an ordered TreeMap and holds immutable CacheStatsSnapshot as its stats. + static class Node { + private final String dimensionValue; + final Map children; // Map from dimensionValue to the Node for that dimension value + + // The stats for this node. If a leaf node, corresponds to the stats for this combination of dimensions; if not, + // contains the sum of its children's stats. + private final ImmutableCacheStats stats; + private static final Map EMPTY_CHILDREN_MAP = new HashMap<>(); + + Node(String dimensionValue, TreeMap snapshotChildren, ImmutableCacheStats stats) { + this.dimensionValue = dimensionValue; + this.stats = stats; + if (snapshotChildren == null) { + this.children = EMPTY_CHILDREN_MAP; + } else { + this.children = Collections.unmodifiableMap(snapshotChildren); + } + } + + Map getChildren() { + return children; + } + + public ImmutableCacheStats getStats() { + return stats; + } + + public String getDimensionValue() { + return dimensionValue; + } + } + + // pkg-private for testing + Node getStatsRoot() { + return statsRoot; + } + + // TODO (in API PR): Produce XContent based on aggregateByLevels() +} diff --git a/server/src/main/java/org/opensearch/common/cache/stats/package-info.java b/server/src/main/java/org/opensearch/common/cache/stats/package-info.java new file mode 100644 index 0000000000000..95b5bc8efb510 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/cache/stats/package-info.java @@ -0,0 +1,9 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +/** A package for cache stats. */ +package org.opensearch.common.cache.stats; diff --git a/server/src/main/java/org/opensearch/common/cache/store/OpenSearchOnHeapCache.java b/server/src/main/java/org/opensearch/common/cache/store/OpenSearchOnHeapCache.java index c9bec4ba47def..29e5667c9f27d 100644 --- a/server/src/main/java/org/opensearch/common/cache/store/OpenSearchOnHeapCache.java +++ b/server/src/main/java/org/opensearch/common/cache/store/OpenSearchOnHeapCache.java @@ -12,10 +12,14 @@ import org.opensearch.common.cache.CacheBuilder; import org.opensearch.common.cache.CacheType; import org.opensearch.common.cache.ICache; +import org.opensearch.common.cache.ICacheKey; import org.opensearch.common.cache.LoadAwareCacheLoader; import org.opensearch.common.cache.RemovalListener; import org.opensearch.common.cache.RemovalNotification; +import org.opensearch.common.cache.RemovalReason; import org.opensearch.common.cache.settings.CacheSettings; +import org.opensearch.common.cache.stats.CacheStatsHolder; +import org.opensearch.common.cache.stats.ImmutableCacheStatsHolder; import org.opensearch.common.cache.store.builders.ICacheBuilder; import org.opensearch.common.cache.store.config.CacheConfig; import org.opensearch.common.cache.store.settings.OpenSearchOnHeapCacheSettings; @@ -25,7 +29,10 @@ import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.unit.ByteSizeValue; +import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.function.ToLongBiFunction; import static org.opensearch.common.cache.store.settings.OpenSearchOnHeapCacheSettings.EXPIRE_AFTER_ACCESS_KEY; import static org.opensearch.common.cache.store.settings.OpenSearchOnHeapCacheSettings.MAXIMUM_SIZE_IN_BYTES_KEY; @@ -37,13 +44,16 @@ * * @opensearch.experimental */ -public class OpenSearchOnHeapCache implements ICache, RemovalListener { +public class OpenSearchOnHeapCache implements ICache, RemovalListener, V> { - private final Cache cache; - private final RemovalListener removalListener; + private final Cache, V> cache; + private final CacheStatsHolder cacheStatsHolder; + private final RemovalListener, V> removalListener; + private final List dimensionNames; + private final ToLongBiFunction, V> weigher; public OpenSearchOnHeapCache(Builder builder) { - CacheBuilder cacheBuilder = CacheBuilder.builder() + CacheBuilder, V> cacheBuilder = CacheBuilder., V>builder() .setMaximumWeight(builder.getMaxWeightInBytes()) .weigher(builder.getWeigher()) .removalListener(this); @@ -51,44 +61,67 @@ public OpenSearchOnHeapCache(Builder builder) { cacheBuilder.setExpireAfterAccess(builder.getExpireAfterAcess()); } cache = cacheBuilder.build(); + this.dimensionNames = Objects.requireNonNull(builder.dimensionNames, "Dimension names can't be null"); + this.cacheStatsHolder = new CacheStatsHolder(dimensionNames); this.removalListener = builder.getRemovalListener(); + this.weigher = builder.getWeigher(); } @Override - public V get(K key) { + public V get(ICacheKey key) { V value = cache.get(key); + if (value != null) { + cacheStatsHolder.incrementHits(key.dimensions); + } else { + cacheStatsHolder.incrementMisses(key.dimensions); + } return value; } @Override - public void put(K key, V value) { + public void put(ICacheKey key, V value) { cache.put(key, value); + cacheStatsHolder.incrementEntries(key.dimensions); + cacheStatsHolder.incrementSizeInBytes(key.dimensions, weigher.applyAsLong(key, value)); } @Override - public V computeIfAbsent(K key, LoadAwareCacheLoader loader) throws Exception { + public V computeIfAbsent(ICacheKey key, LoadAwareCacheLoader, V> loader) throws Exception { V value = cache.computeIfAbsent(key, key1 -> loader.load(key)); + if (!loader.isLoaded()) { + cacheStatsHolder.incrementHits(key.dimensions); + } else { + cacheStatsHolder.incrementMisses(key.dimensions); + cacheStatsHolder.incrementEntries(key.dimensions); + cacheStatsHolder.incrementSizeInBytes(key.dimensions, cache.getWeigher().applyAsLong(key, value)); + } return value; } @Override - public void invalidate(K key) { - cache.invalidate(key); + public void invalidate(ICacheKey key) { + if (key.getDropStatsForDimensions()) { + cacheStatsHolder.removeDimensions(key.dimensions); + } + if (key.key != null) { + cache.invalidate(key); + } } @Override public void invalidateAll() { cache.invalidateAll(); + cacheStatsHolder.reset(); } @Override - public Iterable keys() { + public Iterable> keys() { return cache.keys(); } @Override public long count() { - return cache.count(); + return cacheStatsHolder.count(); } @Override @@ -100,8 +133,23 @@ public void refresh() { public void close() {} @Override - public void onRemoval(RemovalNotification notification) { - this.removalListener.onRemoval(notification); + public ImmutableCacheStatsHolder stats() { + return cacheStatsHolder.getImmutableCacheStatsHolder(); + } + + @Override + public void onRemoval(RemovalNotification, V> notification) { + removalListener.onRemoval(notification); + cacheStatsHolder.decrementEntries(notification.getKey().dimensions); + cacheStatsHolder.decrementSizeInBytes( + notification.getKey().dimensions, + cache.getWeigher().applyAsLong(notification.getKey(), notification.getValue()) + ); + + if (RemovalReason.EVICTED.equals(notification.getRemovalReason()) + || RemovalReason.CAPACITY.equals(notification.getRemovalReason())) { + cacheStatsHolder.incrementEvictions(notification.getKey().dimensions); + } } /** @@ -115,9 +163,8 @@ public static class OpenSearchOnHeapCacheFactory implements Factory { public ICache create(CacheConfig config, CacheType cacheType, Map cacheFactories) { Map> settingList = OpenSearchOnHeapCacheSettings.getSettingListForCacheType(cacheType); Settings settings = config.getSettings(); - ICacheBuilder builder = new Builder().setMaximumWeightInBytes( - ((ByteSizeValue) settingList.get(MAXIMUM_SIZE_IN_BYTES_KEY).get(settings)).getBytes() - ) + ICacheBuilder builder = new Builder().setDimensionNames(config.getDimensionNames()) + .setMaximumWeightInBytes(((ByteSizeValue) settingList.get(MAXIMUM_SIZE_IN_BYTES_KEY).get(settings)).getBytes()) .setExpireAfterAccess(((TimeValue) settingList.get(EXPIRE_AFTER_ACCESS_KEY).get(settings))) .setWeigher(config.getWeigher()) .setRemovalListener(config.getRemovalListener()); @@ -145,6 +192,12 @@ public String getCacheName() { * @param Type of value */ public static class Builder extends ICacheBuilder { + private List dimensionNames; + + public Builder setDimensionNames(List dimensionNames) { + this.dimensionNames = dimensionNames; + return this; + } @Override public ICache build() { diff --git a/server/src/main/java/org/opensearch/common/cache/store/builders/ICacheBuilder.java b/server/src/main/java/org/opensearch/common/cache/store/builders/ICacheBuilder.java index 7ca9080ec1aa6..ac90fcc85ffef 100644 --- a/server/src/main/java/org/opensearch/common/cache/store/builders/ICacheBuilder.java +++ b/server/src/main/java/org/opensearch/common/cache/store/builders/ICacheBuilder.java @@ -10,6 +10,7 @@ import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.common.cache.ICache; +import org.opensearch.common.cache.ICacheKey; import org.opensearch.common.cache.RemovalListener; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; @@ -28,13 +29,13 @@ public abstract class ICacheBuilder { private long maxWeightInBytes; - private ToLongBiFunction weigher; + private ToLongBiFunction, V> weigher; private TimeValue expireAfterAcess; private Settings settings; - private RemovalListener removalListener; + private RemovalListener, V> removalListener; public ICacheBuilder() {} @@ -43,7 +44,7 @@ public ICacheBuilder setMaximumWeightInBytes(long sizeInBytes) { return this; } - public ICacheBuilder setWeigher(ToLongBiFunction weigher) { + public ICacheBuilder setWeigher(ToLongBiFunction, V> weigher) { this.weigher = weigher; return this; } @@ -58,7 +59,7 @@ public ICacheBuilder setSettings(Settings settings) { return this; } - public ICacheBuilder setRemovalListener(RemovalListener removalListener) { + public ICacheBuilder setRemovalListener(RemovalListener, V> removalListener) { this.removalListener = removalListener; return this; } @@ -71,11 +72,11 @@ public TimeValue getExpireAfterAcess() { return expireAfterAcess; } - public ToLongBiFunction getWeigher() { + public ToLongBiFunction, V> getWeigher() { return weigher; } - public RemovalListener getRemovalListener() { + public RemovalListener, V> getRemovalListener() { return this.removalListener; } diff --git a/server/src/main/java/org/opensearch/common/cache/store/config/CacheConfig.java b/server/src/main/java/org/opensearch/common/cache/store/config/CacheConfig.java index e537ece759e65..15cbdbd021d71 100644 --- a/server/src/main/java/org/opensearch/common/cache/store/config/CacheConfig.java +++ b/server/src/main/java/org/opensearch/common/cache/store/config/CacheConfig.java @@ -9,6 +9,7 @@ package org.opensearch.common.cache.store.config; import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.common.cache.ICacheKey; import org.opensearch.common.cache.RemovalListener; import org.opensearch.common.cache.policy.CachedQueryResult; import org.opensearch.common.cache.serializer.Serializer; @@ -16,6 +17,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import java.util.List; import java.util.function.Function; import java.util.function.ToLongBiFunction; @@ -42,9 +44,11 @@ public class CacheConfig { /** * Represents a function that calculates the size or weight of a key-value pair. */ - private final ToLongBiFunction weigher; + private final ToLongBiFunction, V> weigher; - private final RemovalListener removalListener; + private final RemovalListener, V> removalListener; + + private final List dimensionNames; // Serializers for keys and values. Not required for all caches. private final Serializer keySerializer; @@ -72,6 +76,7 @@ private CacheConfig(Builder builder) { this.weigher = builder.weigher; this.keySerializer = builder.keySerializer; this.valueSerializer = builder.valueSerializer; + this.dimensionNames = builder.dimensionNames; this.cachedResultParser = builder.cachedResultParser; this.maxSizeInBytes = builder.maxSizeInBytes; this.expireAfterAccess = builder.expireAfterAccess; @@ -90,7 +95,7 @@ public Settings getSettings() { return settings; } - public RemovalListener getRemovalListener() { + public RemovalListener, V> getRemovalListener() { return removalListener; } @@ -102,7 +107,7 @@ public RemovalListener getRemovalListener() { return valueSerializer; } - public ToLongBiFunction getWeigher() { + public ToLongBiFunction, V> getWeigher() { return weigher; } @@ -110,6 +115,10 @@ public Function getCachedResultParser() { return cachedResultParser; } + public List getDimensionNames() { + return dimensionNames; + } + public Long getMaxSizeInBytes() { return maxSizeInBytes; } @@ -135,12 +144,11 @@ public static class Builder { private Class valueType; - private RemovalListener removalListener; - + private RemovalListener, V> removalListener; + private List dimensionNames; private Serializer keySerializer; private Serializer valueSerializer; - - private ToLongBiFunction weigher; + private ToLongBiFunction, V> weigher; private Function cachedResultParser; private long maxSizeInBytes; @@ -165,11 +173,16 @@ public Builder setValueType(Class valueType) { return this; } - public Builder setRemovalListener(RemovalListener removalListener) { + public Builder setRemovalListener(RemovalListener, V> removalListener) { this.removalListener = removalListener; return this; } + public Builder setWeigher(ToLongBiFunction, V> weigher) { + this.weigher = weigher; + return this; + } + public Builder setKeySerializer(Serializer keySerializer) { this.keySerializer = keySerializer; return this; @@ -180,8 +193,8 @@ public Builder setValueSerializer(Serializer valueSerializer) { return this; } - public Builder setWeigher(ToLongBiFunction weigher) { - this.weigher = weigher; + public Builder setDimensionNames(List dimensionNames) { + this.dimensionNames = dimensionNames; return this; } diff --git a/server/src/main/java/org/opensearch/indices/IndicesRequestCache.java b/server/src/main/java/org/opensearch/indices/IndicesRequestCache.java index 607ff721bd357..eab772cda3213 100644 --- a/server/src/main/java/org/opensearch/indices/IndicesRequestCache.java +++ b/server/src/main/java/org/opensearch/indices/IndicesRequestCache.java @@ -43,12 +43,14 @@ import org.opensearch.common.CheckedSupplier; import org.opensearch.common.cache.CacheType; import org.opensearch.common.cache.ICache; +import org.opensearch.common.cache.ICacheKey; import org.opensearch.common.cache.LoadAwareCacheLoader; import org.opensearch.common.cache.RemovalListener; import org.opensearch.common.cache.RemovalNotification; import org.opensearch.common.cache.policy.CachedQueryResult; import org.opensearch.common.cache.serializer.BytesReferenceSerializer; import org.opensearch.common.cache.service.CacheService; +import org.opensearch.common.cache.stats.ImmutableCacheStatsHolder; import org.opensearch.common.cache.store.config.CacheConfig; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lucene.index.OpenSearchDirectoryReader; @@ -74,6 +76,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -100,7 +103,7 @@ * * @opensearch.internal */ -public final class IndicesRequestCache implements RemovalListener, Closeable { +public final class IndicesRequestCache implements RemovalListener, BytesReference>, Closeable { private static final Logger logger = LogManager.getLogger(IndicesRequestCache.class); @@ -146,6 +149,10 @@ public final class IndicesRequestCache implements RemovalListener> cacheEntityFunction, @@ -156,7 +163,7 @@ public final class IndicesRequestCache implements RemovalListener weigher = (k, v) -> k.ramBytesUsed() + v.ramBytesUsed(); + ToLongBiFunction, BytesReference> weigher = (k, v) -> k.ramBytesUsed(k.key.ramBytesUsed()) + v.ramBytesUsed(); this.cacheCleanupManager = new IndicesRequestCacheCleanupManager( threadPool, INDICES_REQUEST_CACHE_CLEAN_INTERVAL_SETTING.get(settings), @@ -171,6 +178,7 @@ public final class IndicesRequestCache implements RemovalListener { try { return CachedQueryResult.getPolicyValues(bytesReference); @@ -205,16 +213,39 @@ void clear(CacheEntity entity) { } @Override - public void onRemoval(RemovalNotification notification) { + public void onRemoval(RemovalNotification, BytesReference> notification) { // In case this event happens for an old shard, we can safely ignore this as we don't keep track for old // shards as part of request cache. - Key key = notification.getKey(); - cacheEntityLookup.apply(key.shardId).ifPresent(entity -> entity.onRemoval(notification)); + + // Pass a new removal notification containing Key rather than ICacheKey to the CacheEntity for backwards compatibility. + Key key = notification.getKey().key; + RemovalNotification newNotification = new RemovalNotification<>( + key, + notification.getValue(), + notification.getRemovalReason() + ); + + cacheEntityLookup.apply(key.shardId).ifPresent(entity -> entity.onRemoval(newNotification)); cacheCleanupManager.updateCleanupKeyToCountMapOnCacheEviction( new CleanupKey(cacheEntityLookup.apply(key.shardId).orElse(null), key.readerCacheKeyId) ); } + private ICacheKey getICacheKey(Key key) { + String indexDimensionValue = getIndexDimensionName(key); + String shardIdDimensionValue = getShardIdDimensionName(key); + List dimensions = List.of(indexDimensionValue, shardIdDimensionValue); + return new ICacheKey<>(key, dimensions); + } + + private String getShardIdDimensionName(Key key) { + return key.shardId.toString(); + } + + private String getIndexDimensionName(Key key) { + return key.shardId.getIndexName(); + } + BytesReference getOrCompute( IndicesService.IndexShardCacheEntity cacheEntity, CheckedSupplier loader, @@ -230,7 +261,7 @@ BytesReference getOrCompute( assert readerCacheKeyId != null; final Key key = new Key(((IndexShard) cacheEntity.getCacheIdentity()).shardId(), cacheKey, readerCacheKeyId); Loader cacheLoader = new Loader(cacheEntity, loader); - BytesReference value = cache.computeIfAbsent(key, cacheLoader); + BytesReference value = cache.computeIfAbsent(getICacheKey(key), cacheLoader); if (cacheLoader.isLoaded()) { cacheEntity.onMiss(); // see if it's the first time we see this reader, and make sure to register a cleanup key @@ -261,7 +292,7 @@ void invalidate(IndicesService.IndexShardCacheEntity cacheEntity, DirectoryReade IndexReader.CacheHelper cacheHelper = ((OpenSearchDirectoryReader) reader).getDelegatingCacheHelper(); readerCacheKeyId = ((OpenSearchDirectoryReader.DelegatingCacheHelper) cacheHelper).getDelegatingCacheKey().getId(); } - cache.invalidate(new Key(((IndexShard) cacheEntity.getCacheIdentity()).shardId(), cacheKey, readerCacheKeyId)); + cache.invalidate(getICacheKey(new Key(((IndexShard) cacheEntity.getCacheIdentity()).shardId(), cacheKey, readerCacheKeyId))); } /** @@ -269,7 +300,7 @@ void invalidate(IndicesService.IndexShardCacheEntity cacheEntity, DirectoryReade * * @opensearch.internal */ - private static class Loader implements LoadAwareCacheLoader { + private static class Loader implements LoadAwareCacheLoader, BytesReference> { private final CacheEntity entity; private final CheckedSupplier loader; @@ -285,9 +316,9 @@ public boolean isLoaded() { } @Override - public BytesReference load(Key key) throws Exception { + public BytesReference load(ICacheKey key) throws Exception { BytesReference value = loader.get(); - entity.onCached(key, value); + entity.onCached(key.key, value); loaded = true; return value; } @@ -603,7 +634,8 @@ private synchronized void cleanCache(double stalenessThreshold) { iterator.remove(); if (cleanupKey.readerCacheKeyId == null || !cleanupKey.entity.isOpen()) { // null indicates full cleanup, as does a closed shard - cleanupKeysFromClosedShards.add(((IndexShard) cleanupKey.entity.getCacheIdentity()).shardId()); + ShardId shardId = ((IndexShard) cleanupKey.entity.getCacheIdentity()).shardId(); + cleanupKeysFromClosedShards.add(shardId); } else { cleanupKeysFromOutdatedReaders.add(cleanupKey); } @@ -613,17 +645,27 @@ private synchronized void cleanCache(double stalenessThreshold) { return; } - for (Iterator iterator = cache.keys().iterator(); iterator.hasNext();) { - Key key = iterator.next(); - if (cleanupKeysFromClosedShards.contains(key.shardId)) { + Set> dimensionListsToDrop = new HashSet<>(); + + for (Iterator> iterator = cache.keys().iterator(); iterator.hasNext();) { + ICacheKey key = iterator.next(); + if (cleanupKeysFromClosedShards.contains(key.key.shardId)) { + // Since the shard is closed, the cache should drop stats for this shard. + dimensionListsToDrop.add(key.dimensions); iterator.remove(); } else { - CleanupKey cleanupKey = new CleanupKey(cacheEntityLookup.apply(key.shardId).orElse(null), key.readerCacheKeyId); + CleanupKey cleanupKey = new CleanupKey(cacheEntityLookup.apply(key.key.shardId).orElse(null), key.key.readerCacheKeyId); if (cleanupKeysFromOutdatedReaders.contains(cleanupKey)) { iterator.remove(); } } } + for (List closedDimensions : dimensionListsToDrop) { + // Invalidate a dummy key containing the dimensions we need to drop stats for + ICacheKey dummyKey = new ICacheKey<>(null, closedDimensions); + dummyKey.setDropStatsForDimensions(true); + cache.invalidate(dummyKey); + } cache.refresh(); } @@ -714,6 +756,20 @@ long count() { return cache.count(); } + /** + * Returns the current size in bytes of the cache + */ + long getSizeInBytes() { + return cache.stats().getTotalSizeInBytes(); + } + + /** + * Returns the current cache stats. Pkg-private for testing. + */ + ImmutableCacheStatsHolder stats() { + return cache.stats(); + } + int numRegisteredCloseListeners() { // for testing return registeredClosedListeners.size(); } diff --git a/server/src/test/java/org/opensearch/common/cache/serializer/ICacheKeySerializerTests.java b/server/src/test/java/org/opensearch/common/cache/serializer/ICacheKeySerializerTests.java new file mode 100644 index 0000000000000..7713fdf1d0adc --- /dev/null +++ b/server/src/test/java/org/opensearch/common/cache/serializer/ICacheKeySerializerTests.java @@ -0,0 +1,107 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.cache.serializer; + +import org.opensearch.OpenSearchException; +import org.opensearch.common.Randomness; +import org.opensearch.common.cache.ICacheKey; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.UUID; + +public class ICacheKeySerializerTests extends OpenSearchTestCase { + // For these tests, we use BytesReference as K, since we already have a Serializer implementation + public void testEquality() throws Exception { + BytesReferenceSerializer keySer = new BytesReferenceSerializer(); + ICacheKeySerializer serializer = new ICacheKeySerializer<>(keySer); + + int numDimensionsTested = 100; + for (int i = 0; i < numDimensionsTested; i++) { + String dim = getRandomDimValue(); + ICacheKey key = new ICacheKey<>(getRandomBytesReference(), List.of(dim)); + byte[] serialized = serializer.serialize(key); + assertTrue(serializer.equals(key, serialized)); + ICacheKey deserialized = serializer.deserialize(serialized); + assertEquals(key, deserialized); + assertTrue(serializer.equals(deserialized, serialized)); + } + } + + public void testInvalidInput() throws Exception { + BytesReferenceSerializer keySer = new BytesReferenceSerializer(); + ICacheKeySerializer serializer = new ICacheKeySerializer<>(keySer); + + Random rand = Randomness.get(); + byte[] randomInput = new byte[1000]; + rand.nextBytes(randomInput); + + assertThrows(OpenSearchException.class, () -> serializer.deserialize(randomInput)); + } + + public void testDimNumbers() throws Exception { + BytesReferenceSerializer keySer = new BytesReferenceSerializer(); + ICacheKeySerializer serializer = new ICacheKeySerializer<>(keySer); + + for (int numDims : new int[] { 0, 5, 1000 }) { + List dims = new ArrayList<>(); + for (int j = 0; j < numDims; j++) { + dims.add(getRandomDimValue()); + } + ICacheKey key = new ICacheKey<>(getRandomBytesReference(), dims); + byte[] serialized = serializer.serialize(key); + assertTrue(serializer.equals(key, serialized)); + ICacheKey deserialized = serializer.deserialize(serialized); + assertEquals(key, deserialized); + } + } + + public void testHashCodes() throws Exception { + ICacheKey key1 = new ICacheKey<>("key", List.of("dimension_value")); + ICacheKey key2 = new ICacheKey<>("key", List.of("dimension_value")); + + ICacheKey key3 = new ICacheKey<>(null, List.of("dimension_value")); + ICacheKey key4 = new ICacheKey<>(null, List.of("dimension_value")); + + assertEquals(key1, key2); + assertEquals(key1.hashCode(), key2.hashCode()); + + assertEquals(key3, key4); + assertEquals(key3.hashCode(), key4.hashCode()); + + assertNotEquals(key1, key3); + assertNotEquals("string", key3); + } + + public void testNullInputs() throws Exception { + BytesReferenceSerializer keySer = new BytesReferenceSerializer(); + ICacheKeySerializer serializer = new ICacheKeySerializer<>(keySer); + + assertNull(serializer.deserialize(null)); + ICacheKey nullKey = new ICacheKey<>(null, List.of(getRandomDimValue())); + assertNull(serializer.serialize(nullKey)); + assertNull(serializer.serialize(null)); + assertNull(serializer.serialize(new ICacheKey<>(getRandomBytesReference(), null))); + } + + private String getRandomDimValue() { + return UUID.randomUUID().toString(); + } + + private BytesReference getRandomBytesReference() { + byte[] bytesValue = new byte[1000]; + Random rand = Randomness.get(); + rand.nextBytes(bytesValue); + return new BytesArray(bytesValue); + } +} diff --git a/server/src/test/java/org/opensearch/common/cache/stats/CacheStatsHolderTests.java b/server/src/test/java/org/opensearch/common/cache/stats/CacheStatsHolderTests.java new file mode 100644 index 0000000000000..390cd4d601a4b --- /dev/null +++ b/server/src/test/java/org/opensearch/common/cache/stats/CacheStatsHolderTests.java @@ -0,0 +1,287 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.cache.stats; + +import org.opensearch.common.Randomness; +import org.opensearch.common.metrics.CounterMetric; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; + +public class CacheStatsHolderTests extends OpenSearchTestCase { + public void testAddAndGet() throws Exception { + List dimensionNames = List.of("dim1", "dim2", "dim3", "dim4"); + CacheStatsHolder cacheStatsHolder = new CacheStatsHolder(dimensionNames); + Map> usedDimensionValues = CacheStatsHolderTests.getUsedDimensionValues(cacheStatsHolder, 10); + Map, CacheStats> expected = CacheStatsHolderTests.populateStats(cacheStatsHolder, usedDimensionValues, 1000, 10); + + // test the value in the map is as expected for each distinct combination of values + for (List dimensionValues : expected.keySet()) { + CacheStats expectedCounter = expected.get(dimensionValues); + + ImmutableCacheStats actualStatsHolder = CacheStatsHolderTests.getNode(dimensionValues, cacheStatsHolder.getStatsRoot()) + .getImmutableStats(); + ImmutableCacheStats actualCacheStats = getNode(dimensionValues, cacheStatsHolder.getStatsRoot()).getImmutableStats(); + + assertEquals(expectedCounter.immutableSnapshot(), actualStatsHolder); + assertEquals(expectedCounter.immutableSnapshot(), actualCacheStats); + } + + // Check overall total matches + CacheStats expectedTotal = new CacheStats(); + for (List dims : expected.keySet()) { + expectedTotal.add(expected.get(dims)); + } + assertEquals(expectedTotal.immutableSnapshot(), cacheStatsHolder.getStatsRoot().getImmutableStats()); + + // Check sum of children stats are correct + assertSumOfChildrenStats(cacheStatsHolder.getStatsRoot()); + } + + public void testReset() throws Exception { + List dimensionNames = List.of("dim1", "dim2"); + CacheStatsHolder cacheStatsHolder = new CacheStatsHolder(dimensionNames); + Map> usedDimensionValues = getUsedDimensionValues(cacheStatsHolder, 10); + Map, CacheStats> expected = populateStats(cacheStatsHolder, usedDimensionValues, 100, 10); + + cacheStatsHolder.reset(); + + for (List dimensionValues : expected.keySet()) { + CacheStats originalCounter = expected.get(dimensionValues); + originalCounter.sizeInBytes = new CounterMetric(); + originalCounter.entries = new CounterMetric(); + + CacheStatsHolder.Node node = getNode(dimensionValues, cacheStatsHolder.getStatsRoot()); + ImmutableCacheStats actual = node.getImmutableStats(); + assertEquals(originalCounter.immutableSnapshot(), actual); + } + } + + public void testDropStatsForDimensions() throws Exception { + List dimensionNames = List.of("dim1", "dim2"); + CacheStatsHolder cacheStatsHolder = new CacheStatsHolder(dimensionNames); + + // Create stats for the following dimension sets + List> populatedStats = List.of(List.of("A1", "B1"), List.of("A2", "B2"), List.of("A2", "B3")); + for (List dims : populatedStats) { + cacheStatsHolder.incrementHits(dims); + } + + assertEquals(3, cacheStatsHolder.getStatsRoot().getImmutableStats().getHits()); + + // When we invalidate A2, B2, we should lose the node for B2, but not B3 or A2. + + cacheStatsHolder.removeDimensions(List.of("A2", "B2")); + + assertEquals(2, cacheStatsHolder.getStatsRoot().getImmutableStats().getHits()); + assertNull(getNode(List.of("A2", "B2"), cacheStatsHolder.getStatsRoot())); + assertNotNull(getNode(List.of("A2"), cacheStatsHolder.getStatsRoot())); + assertNotNull(getNode(List.of("A2", "B3"), cacheStatsHolder.getStatsRoot())); + + // When we invalidate A1, B1, we should lose the nodes for B1 and also A1, as it has no more children. + + cacheStatsHolder.removeDimensions(List.of("A1", "B1")); + + assertEquals(1, cacheStatsHolder.getStatsRoot().getImmutableStats().getHits()); + assertNull(getNode(List.of("A1", "B1"), cacheStatsHolder.getStatsRoot())); + assertNull(getNode(List.of("A1"), cacheStatsHolder.getStatsRoot())); + + // When we invalidate the last node, all nodes should be deleted except the root node + + cacheStatsHolder.removeDimensions(List.of("A2", "B3")); + assertEquals(0, cacheStatsHolder.getStatsRoot().getImmutableStats().getHits()); + assertEquals(0, cacheStatsHolder.getStatsRoot().children.size()); + } + + public void testCount() throws Exception { + List dimensionNames = List.of("dim1", "dim2"); + CacheStatsHolder cacheStatsHolder = new CacheStatsHolder(dimensionNames); + Map> usedDimensionValues = getUsedDimensionValues(cacheStatsHolder, 10); + Map, CacheStats> expected = populateStats(cacheStatsHolder, usedDimensionValues, 100, 10); + + long expectedCount = 0L; + for (CacheStats counter : expected.values()) { + expectedCount += counter.getEntries(); + } + assertEquals(expectedCount, cacheStatsHolder.count()); + } + + public void testConcurrentRemoval() throws Exception { + List dimensionNames = List.of("dim1", "dim2"); + CacheStatsHolder cacheStatsHolder = new CacheStatsHolder(dimensionNames); + + // Create stats for the following dimension sets + List> populatedStats = List.of(List.of("A1", "B1"), List.of("A2", "B2"), List.of("A2", "B3")); + for (List dims : populatedStats) { + cacheStatsHolder.incrementHits(dims); + } + + // Remove (A2, B2) and (A1, B1), before re-adding (A2, B2). At the end we should have stats for (A2, B2) but not (A1, B1). + + Thread[] threads = new Thread[3]; + CountDownLatch countDownLatch = new CountDownLatch(3); + threads[0] = new Thread(() -> { + cacheStatsHolder.removeDimensions(List.of("A2", "B2")); + countDownLatch.countDown(); + }); + threads[1] = new Thread(() -> { + cacheStatsHolder.removeDimensions(List.of("A1", "B1")); + countDownLatch.countDown(); + }); + threads[2] = new Thread(() -> { + cacheStatsHolder.incrementMisses(List.of("A2", "B2")); + cacheStatsHolder.incrementMisses(List.of("A2", "B3")); + countDownLatch.countDown(); + }); + for (Thread thread : threads) { + thread.start(); + // Add short sleep to ensure threads start their functions in order (so that incrementing doesn't happen before removal) + Thread.sleep(1); + } + countDownLatch.await(); + assertNull(getNode(List.of("A1", "B1"), cacheStatsHolder.getStatsRoot())); + assertNull(getNode(List.of("A1"), cacheStatsHolder.getStatsRoot())); + assertNotNull(getNode(List.of("A2", "B2"), cacheStatsHolder.getStatsRoot())); + assertEquals( + new ImmutableCacheStats(0, 1, 0, 0, 0), + getNode(List.of("A2", "B2"), cacheStatsHolder.getStatsRoot()).getImmutableStats() + ); + assertEquals( + new ImmutableCacheStats(1, 1, 0, 0, 0), + getNode(List.of("A2", "B3"), cacheStatsHolder.getStatsRoot()).getImmutableStats() + ); + } + + /** + * Returns the node found by following these dimension values down from the root node. + * Returns null if no such node exists. + */ + static CacheStatsHolder.Node getNode(List dimensionValues, CacheStatsHolder.Node root) { + CacheStatsHolder.Node current = root; + for (String dimensionValue : dimensionValues) { + current = current.getChildren().get(dimensionValue); + if (current == null) { + return null; + } + } + return current; + } + + static Map, CacheStats> populateStats( + CacheStatsHolder cacheStatsHolder, + Map> usedDimensionValues, + int numDistinctValuePairs, + int numRepetitionsPerValue + ) throws InterruptedException { + Map, CacheStats> expected = new ConcurrentHashMap<>(); + Thread[] threads = new Thread[numDistinctValuePairs]; + CountDownLatch countDownLatch = new CountDownLatch(numDistinctValuePairs); + Random rand = Randomness.get(); + List> dimensionsForThreads = new ArrayList<>(); + for (int i = 0; i < numDistinctValuePairs; i++) { + dimensionsForThreads.add(getRandomDimList(cacheStatsHolder.getDimensionNames(), usedDimensionValues, true, rand)); + int finalI = i; + threads[i] = new Thread(() -> { + Random threadRand = Randomness.get(); + List dimensions = dimensionsForThreads.get(finalI); + expected.computeIfAbsent(dimensions, (key) -> new CacheStats()); + for (int j = 0; j < numRepetitionsPerValue; j++) { + CacheStats statsToInc = new CacheStats( + threadRand.nextInt(10), + threadRand.nextInt(10), + threadRand.nextInt(10), + threadRand.nextInt(5000), + threadRand.nextInt(10) + ); + expected.get(dimensions).hits.inc(statsToInc.getHits()); + expected.get(dimensions).misses.inc(statsToInc.getMisses()); + expected.get(dimensions).evictions.inc(statsToInc.getEvictions()); + expected.get(dimensions).sizeInBytes.inc(statsToInc.getSizeInBytes()); + expected.get(dimensions).entries.inc(statsToInc.getEntries()); + CacheStatsHolderTests.populateStatsHolderFromStatsValueMap(cacheStatsHolder, Map.of(dimensions, statsToInc)); + } + countDownLatch.countDown(); + }); + } + for (Thread thread : threads) { + thread.start(); + } + countDownLatch.await(); + return expected; + } + + private static List getRandomDimList( + List dimensionNames, + Map> usedDimensionValues, + boolean pickValueForAllDims, + Random rand + ) { + List result = new ArrayList<>(); + for (String dimName : dimensionNames) { + if (pickValueForAllDims || rand.nextBoolean()) { // if pickValueForAllDims, always pick a value for each dimension, otherwise do + // so 50% of the time + int index = between(0, usedDimensionValues.get(dimName).size() - 1); + result.add(usedDimensionValues.get(dimName).get(index)); + } + } + return result; + } + + static Map> getUsedDimensionValues(CacheStatsHolder cacheStatsHolder, int numValuesPerDim) { + Map> usedDimensionValues = new HashMap<>(); + for (int i = 0; i < cacheStatsHolder.getDimensionNames().size(); i++) { + List values = new ArrayList<>(); + for (int j = 0; j < numValuesPerDim; j++) { + values.add(UUID.randomUUID().toString()); + } + usedDimensionValues.put(cacheStatsHolder.getDimensionNames().get(i), values); + } + return usedDimensionValues; + } + + private void assertSumOfChildrenStats(CacheStatsHolder.Node current) { + if (!current.children.isEmpty()) { + CacheStats expectedTotal = new CacheStats(); + for (CacheStatsHolder.Node child : current.children.values()) { + expectedTotal.add(child.getImmutableStats()); + } + assertEquals(expectedTotal.immutableSnapshot(), current.getImmutableStats()); + for (CacheStatsHolder.Node child : current.children.values()) { + assertSumOfChildrenStats(child); + } + } + } + + static void populateStatsHolderFromStatsValueMap(CacheStatsHolder cacheStatsHolder, Map, CacheStats> statsMap) { + for (Map.Entry, CacheStats> entry : statsMap.entrySet()) { + CacheStats stats = entry.getValue(); + List dims = entry.getKey(); + for (int i = 0; i < stats.getHits(); i++) { + cacheStatsHolder.incrementHits(dims); + } + for (int i = 0; i < stats.getMisses(); i++) { + cacheStatsHolder.incrementMisses(dims); + } + for (int i = 0; i < stats.getEvictions(); i++) { + cacheStatsHolder.incrementEvictions(dims); + } + cacheStatsHolder.incrementSizeInBytes(dims, stats.getSizeInBytes()); + for (int i = 0; i < stats.getEntries(); i++) { + cacheStatsHolder.incrementEntries(dims); + } + } + } +} diff --git a/server/src/test/java/org/opensearch/common/cache/stats/ImmutableCacheStatsHolderTests.java b/server/src/test/java/org/opensearch/common/cache/stats/ImmutableCacheStatsHolderTests.java new file mode 100644 index 0000000000000..933b8abd6e392 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/cache/stats/ImmutableCacheStatsHolderTests.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.cache.stats; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Map; + +public class ImmutableCacheStatsHolderTests extends OpenSearchTestCase { + + public void testGet() throws Exception { + List dimensionNames = List.of("dim1", "dim2", "dim3", "dim4"); + CacheStatsHolder cacheStatsHolder = new CacheStatsHolder(dimensionNames); + Map> usedDimensionValues = CacheStatsHolderTests.getUsedDimensionValues(cacheStatsHolder, 10); + Map, CacheStats> expected = CacheStatsHolderTests.populateStats(cacheStatsHolder, usedDimensionValues, 1000, 10); + ImmutableCacheStatsHolder stats = cacheStatsHolder.getImmutableCacheStatsHolder(); + + // test the value in the map is as expected for each distinct combination of values + for (List dimensionValues : expected.keySet()) { + CacheStats expectedCounter = expected.get(dimensionValues); + + ImmutableCacheStats actualCacheStatsHolder = CacheStatsHolderTests.getNode(dimensionValues, cacheStatsHolder.getStatsRoot()) + .getImmutableStats(); + ImmutableCacheStats actualImmutableCacheStatsHolder = getNode(dimensionValues, stats.getStatsRoot()).getStats(); + + assertEquals(expectedCounter.immutableSnapshot(), actualCacheStatsHolder); + assertEquals(expectedCounter.immutableSnapshot(), actualImmutableCacheStatsHolder); + } + + // test gets for total (this also checks sum-of-children logic) + CacheStats expectedTotal = new CacheStats(); + for (List dims : expected.keySet()) { + expectedTotal.add(expected.get(dims)); + } + assertEquals(expectedTotal.immutableSnapshot(), stats.getTotalStats()); + + assertEquals(expectedTotal.getHits(), stats.getTotalHits()); + assertEquals(expectedTotal.getMisses(), stats.getTotalMisses()); + assertEquals(expectedTotal.getEvictions(), stats.getTotalEvictions()); + assertEquals(expectedTotal.getSizeInBytes(), stats.getTotalSizeInBytes()); + assertEquals(expectedTotal.getEntries(), stats.getTotalEntries()); + + assertSumOfChildrenStats(stats.getStatsRoot()); + } + + public void testEmptyDimsList() throws Exception { + // If the dimension list is empty, the tree should have only the root node containing the total stats. + CacheStatsHolder cacheStatsHolder = new CacheStatsHolder(List.of()); + Map> usedDimensionValues = CacheStatsHolderTests.getUsedDimensionValues(cacheStatsHolder, 100); + CacheStatsHolderTests.populateStats(cacheStatsHolder, usedDimensionValues, 10, 100); + ImmutableCacheStatsHolder stats = cacheStatsHolder.getImmutableCacheStatsHolder(); + + ImmutableCacheStatsHolder.Node statsRoot = stats.getStatsRoot(); + assertEquals(0, statsRoot.children.size()); + assertEquals(stats.getTotalStats(), statsRoot.getStats()); + } + + private ImmutableCacheStatsHolder.Node getNode(List dimensionValues, ImmutableCacheStatsHolder.Node root) { + ImmutableCacheStatsHolder.Node current = root; + for (String dimensionValue : dimensionValues) { + current = current.getChildren().get(dimensionValue); + if (current == null) { + return null; + } + } + return current; + } + + private void assertSumOfChildrenStats(ImmutableCacheStatsHolder.Node current) { + if (!current.children.isEmpty()) { + CacheStats expectedTotal = new CacheStats(); + for (ImmutableCacheStatsHolder.Node child : current.children.values()) { + expectedTotal.add(child.getStats()); + } + assertEquals(expectedTotal.immutableSnapshot(), current.getStats()); + for (ImmutableCacheStatsHolder.Node child : current.children.values()) { + assertSumOfChildrenStats(child); + } + } + } +} diff --git a/server/src/test/java/org/opensearch/common/cache/stats/ImmutableCacheStatsTests.java b/server/src/test/java/org/opensearch/common/cache/stats/ImmutableCacheStatsTests.java new file mode 100644 index 0000000000000..50ddd81943c3b --- /dev/null +++ b/server/src/test/java/org/opensearch/common/cache/stats/ImmutableCacheStatsTests.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.cache.stats; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.test.OpenSearchTestCase; + +public class ImmutableCacheStatsTests extends OpenSearchTestCase { + public void testSerialization() throws Exception { + ImmutableCacheStats immutableCacheStats = new ImmutableCacheStats(1, 2, 3, 4, 5); + BytesStreamOutput os = new BytesStreamOutput(); + immutableCacheStats.writeTo(os); + BytesStreamInput is = new BytesStreamInput(BytesReference.toBytes(os.bytes())); + ImmutableCacheStats deserialized = new ImmutableCacheStats(is); + + assertEquals(immutableCacheStats, deserialized); + } + + public void testAddSnapshots() throws Exception { + ImmutableCacheStats ics1 = new ImmutableCacheStats(1, 2, 3, 4, 5); + ImmutableCacheStats ics2 = new ImmutableCacheStats(6, 7, 8, 9, 10); + ImmutableCacheStats expected = new ImmutableCacheStats(7, 9, 11, 13, 15); + assertEquals(expected, ImmutableCacheStats.addSnapshots(ics1, ics2)); + } + + public void testEqualsAndHash() throws Exception { + ImmutableCacheStats ics1 = new ImmutableCacheStats(1, 2, 3, 4, 5); + ImmutableCacheStats ics2 = new ImmutableCacheStats(1, 2, 3, 4, 5); + ImmutableCacheStats ics3 = new ImmutableCacheStats(0, 2, 3, 4, 5); + + assertEquals(ics1, ics2); + assertNotEquals(ics1, ics3); + assertNotEquals(ics1, null); + assertNotEquals(ics1, "string"); + + assertEquals(ics1.hashCode(), ics2.hashCode()); + assertNotEquals(ics1.hashCode(), ics3.hashCode()); + } +} diff --git a/server/src/test/java/org/opensearch/common/cache/store/OpenSearchOnHeapCacheTests.java b/server/src/test/java/org/opensearch/common/cache/store/OpenSearchOnHeapCacheTests.java new file mode 100644 index 0000000000000..008dc7c2e0902 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/cache/store/OpenSearchOnHeapCacheTests.java @@ -0,0 +1,181 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.cache.store; + +import org.opensearch.common.Randomness; +import org.opensearch.common.cache.CacheType; +import org.opensearch.common.cache.ICache; +import org.opensearch.common.cache.ICacheKey; +import org.opensearch.common.cache.LoadAwareCacheLoader; +import org.opensearch.common.cache.RemovalListener; +import org.opensearch.common.cache.RemovalNotification; +import org.opensearch.common.cache.stats.ImmutableCacheStats; +import org.opensearch.common.cache.store.config.CacheConfig; +import org.opensearch.common.cache.store.settings.OpenSearchOnHeapCacheSettings; +import org.opensearch.common.metrics.CounterMetric; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.UUID; + +import static org.opensearch.common.cache.store.settings.OpenSearchOnHeapCacheSettings.MAXIMUM_SIZE_IN_BYTES_KEY; + +public class OpenSearchOnHeapCacheTests extends OpenSearchTestCase { + private final static long keyValueSize = 50; + private final static List dimensionNames = List.of("dim1", "dim2", "dim3"); + + public void testStats() throws Exception { + MockRemovalListener listener = new MockRemovalListener<>(); + int maxKeys = between(10, 50); + int numEvicted = between(10, 20); + OpenSearchOnHeapCache cache = getCache(maxKeys, listener); + + List> keysAdded = new ArrayList<>(); + int numAdded = maxKeys + numEvicted; + for (int i = 0; i < numAdded; i++) { + ICacheKey key = getICacheKey(UUID.randomUUID().toString()); + keysAdded.add(key); + cache.computeIfAbsent(key, getLoadAwareCacheLoader()); + + assertEquals(i + 1, cache.stats().getTotalMisses()); + assertEquals(0, cache.stats().getTotalHits()); + assertEquals(Math.min(maxKeys, i + 1), cache.stats().getTotalEntries()); + assertEquals(Math.min(maxKeys, i + 1) * keyValueSize, cache.stats().getTotalSizeInBytes()); + assertEquals(Math.max(0, i + 1 - maxKeys), cache.stats().getTotalEvictions()); + } + // do gets from the last part of the list, which should be hits + for (int i = numAdded - maxKeys; i < numAdded; i++) { + cache.computeIfAbsent(keysAdded.get(i), getLoadAwareCacheLoader()); + int numHits = i + 1 - (numAdded - maxKeys); + + assertEquals(numAdded, cache.stats().getTotalMisses()); + assertEquals(numHits, cache.stats().getTotalHits()); + assertEquals(maxKeys, cache.stats().getTotalEntries()); + assertEquals(maxKeys * keyValueSize, cache.stats().getTotalSizeInBytes()); + assertEquals(numEvicted, cache.stats().getTotalEvictions()); + } + + // invalidate keys + for (int i = numAdded - maxKeys; i < numAdded; i++) { + cache.invalidate(keysAdded.get(i)); + int numInvalidated = i + 1 - (numAdded - maxKeys); + + assertEquals(numAdded, cache.stats().getTotalMisses()); + assertEquals(maxKeys, cache.stats().getTotalHits()); + assertEquals(maxKeys - numInvalidated, cache.stats().getTotalEntries()); + assertEquals((maxKeys - numInvalidated) * keyValueSize, cache.stats().getTotalSizeInBytes()); + assertEquals(numEvicted, cache.stats().getTotalEvictions()); + } + } + + private OpenSearchOnHeapCache getCache(int maxSizeKeys, MockRemovalListener listener) { + ICache.Factory onHeapCacheFactory = new OpenSearchOnHeapCache.OpenSearchOnHeapCacheFactory(); + Settings settings = Settings.builder() + .put( + OpenSearchOnHeapCacheSettings.getSettingListForCacheType(CacheType.INDICES_REQUEST_CACHE) + .get(MAXIMUM_SIZE_IN_BYTES_KEY) + .getKey(), + maxSizeKeys * keyValueSize + "b" + ) + .build(); + + CacheConfig cacheConfig = new CacheConfig.Builder().setKeyType(String.class) + .setValueType(String.class) + .setWeigher((k, v) -> keyValueSize) + .setRemovalListener(listener) + .setSettings(settings) + .setDimensionNames(dimensionNames) + .setMaxSizeInBytes(maxSizeKeys * keyValueSize) + .build(); + return (OpenSearchOnHeapCache) onHeapCacheFactory.create(cacheConfig, CacheType.INDICES_REQUEST_CACHE, null); + } + + public void testInvalidateWithDropDimensions() throws Exception { + MockRemovalListener listener = new MockRemovalListener<>(); + int maxKeys = 50; + OpenSearchOnHeapCache cache = getCache(maxKeys, listener); + + List> keysAdded = new ArrayList<>(); + + for (int i = 0; i < maxKeys - 5; i++) { + ICacheKey key = new ICacheKey<>(UUID.randomUUID().toString(), getRandomDimensions()); + keysAdded.add(key); + cache.computeIfAbsent(key, getLoadAwareCacheLoader()); + } + + ICacheKey keyToDrop = keysAdded.get(0); + + ImmutableCacheStats snapshot = cache.stats().getStatsForDimensionValues(keyToDrop.dimensions); + assertNotNull(snapshot); + + keyToDrop.setDropStatsForDimensions(true); + cache.invalidate(keyToDrop); + + // Now assert the stats are gone for any key that has this combination of dimensions, but still there otherwise + for (ICacheKey keyAdded : keysAdded) { + snapshot = cache.stats().getStatsForDimensionValues(keyAdded.dimensions); + if (keyAdded.dimensions.equals(keyToDrop.dimensions)) { + assertNull(snapshot); + } else { + assertNotNull(snapshot); + } + } + } + + private List getRandomDimensions() { + Random rand = Randomness.get(); + int bound = 3; + List result = new ArrayList<>(); + for (String dimName : dimensionNames) { + result.add(String.valueOf(rand.nextInt(bound))); + } + return result; + } + + private static class MockRemovalListener implements RemovalListener, V> { + CounterMetric numRemovals; + + MockRemovalListener() { + numRemovals = new CounterMetric(); + } + + @Override + public void onRemoval(RemovalNotification, V> notification) { + numRemovals.inc(); + } + } + + private ICacheKey getICacheKey(String key) { + List dims = new ArrayList<>(); + for (String dimName : dimensionNames) { + dims.add("0"); + } + return new ICacheKey<>(key, dims); + } + + private LoadAwareCacheLoader, String> getLoadAwareCacheLoader() { + return new LoadAwareCacheLoader<>() { + boolean isLoaded = false; + + @Override + public String load(ICacheKey key) { + isLoaded = true; + return UUID.randomUUID().toString(); + } + + @Override + public boolean isLoaded() { + return isLoaded; + } + }; + } +} diff --git a/server/src/test/java/org/opensearch/indices/IndicesRequestCacheTests.java b/server/src/test/java/org/opensearch/indices/IndicesRequestCacheTests.java index 68da79a7fda84..e3dca1b7bfda2 100644 --- a/server/src/test/java/org/opensearch/indices/IndicesRequestCacheTests.java +++ b/server/src/test/java/org/opensearch/indices/IndicesRequestCacheTests.java @@ -45,11 +45,14 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; +import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.CheckedSupplier; +import org.opensearch.common.cache.ICacheKey; import org.opensearch.common.cache.RemovalNotification; import org.opensearch.common.cache.RemovalReason; import org.opensearch.common.cache.module.CacheModule; import org.opensearch.common.cache.service.CacheService; +import org.opensearch.common.cache.stats.ImmutableCacheStats; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.lucene.index.OpenSearchDirectoryReader; import org.opensearch.common.settings.Settings; @@ -69,6 +72,7 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardState; +import org.opensearch.index.shard.ShardNotFoundException; import org.opensearch.node.Node; import org.opensearch.test.ClusterServiceUtils; import org.opensearch.test.OpenSearchSingleNodeTestCase; @@ -77,6 +81,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; @@ -553,7 +558,13 @@ public void testStaleCount_OnRemovalNotificationOfStaleKey_DecrementsStaleCount( readerCacheKeyId ); - cache.onRemoval(new RemovalNotification(key, termBytes, RemovalReason.EVICTED)); + cache.onRemoval( + new RemovalNotification, BytesReference>( + new ICacheKey<>(key), + termBytes, + RemovalReason.EVICTED + ) + ); staleKeysCount = cache.cacheCleanupManager.getStaleKeysCount(); // eviction of previous stale key from the cache should decrement staleKeysCount in iRC assertEquals(0, staleKeysCount.get()); @@ -630,7 +641,13 @@ public void testStaleCount_OnRemovalNotificationOfStaleKey_DoesNotDecrementsStal readerCacheKeyId ); - cache.onRemoval(new RemovalNotification(key, termBytes, RemovalReason.EVICTED)); + cache.onRemoval( + new RemovalNotification, BytesReference>( + new ICacheKey<>(key), + termBytes, + RemovalReason.EVICTED + ) + ); staleKeysCount = cache.cacheCleanupManager.getStaleKeysCount(); // eviction of NON-stale key from the cache should NOT decrement staleKeysCount in iRC assertEquals(1, staleKeysCount.get()); @@ -771,6 +788,117 @@ public void testCacheCleanupBasedOnStaleThreshold_StalenessLesserThanThreshold() terminate(threadPool); } + public void testClosingIndexWipesStats() throws Exception { + IndicesService indicesService = getInstanceFromNode(IndicesService.class); + // Create two indices each with multiple shards + int numShards = 3; + Settings indexSettings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards).build(); + String indexToKeepName = "test"; + String indexToCloseName = "test2"; + IndexService indexToKeep = createIndex(indexToKeepName, indexSettings); + IndexService indexToClose = createIndex(indexToCloseName, indexSettings); + for (int i = 0; i < numShards; i++) { + // Check we can get all the shards we expect + assertNotNull(indexToKeep.getShard(i)); + assertNotNull(indexToClose.getShard(i)); + } + ThreadPool threadPool = getThreadPool(); + Settings settings = Settings.builder().put(INDICES_REQUEST_CACHE_STALENESS_THRESHOLD_SETTING.getKey(), "0.001%").build(); + IndicesRequestCache cache = new IndicesRequestCache(settings, (shardId -> { + IndexService indexService = null; + try { + indexService = indicesService.indexServiceSafe(shardId.getIndex()); + } catch (IndexNotFoundException ex) { + return Optional.empty(); + } + try { + return Optional.of(new IndicesService.IndexShardCacheEntity(indexService.getShard(shardId.id()))); + } catch (ShardNotFoundException ex) { + return Optional.empty(); + } + }), + new CacheModule(new ArrayList<>(), Settings.EMPTY).getCacheService(), + threadPool, + ClusterServiceUtils.createClusterService(threadPool) + ); + Directory dir = newDirectory(); + IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig()); + + writer.addDocument(newDoc(0, "foo")); + TermQueryBuilder termQuery = new TermQueryBuilder("id", "0"); + BytesReference termBytes = XContentHelper.toXContent(termQuery, MediaTypeRegistry.JSON, false); + if (randomBoolean()) { + writer.flush(); + IOUtils.close(writer); + writer = new IndexWriter(dir, newIndexWriterConfig()); + } + writer.updateDocument(new Term("id", "0"), newDoc(0, "bar")); + DirectoryReader secondReader = OpenSearchDirectoryReader.wrap(DirectoryReader.open(writer), new ShardId("foo", "bar", 1)); + + List readersToClose = new ArrayList<>(); + List readersToKeep = new ArrayList<>(); + // Put entries into the cache for each shard + for (IndexService indexService : new IndexService[] { indexToKeep, indexToClose }) { + for (int i = 0; i < numShards; i++) { + IndexShard indexShard = indexService.getShard(i); + IndicesService.IndexShardCacheEntity entity = new IndicesService.IndexShardCacheEntity(indexShard); + DirectoryReader reader = OpenSearchDirectoryReader.wrap(DirectoryReader.open(writer), indexShard.shardId()); + if (indexService == indexToClose) { + readersToClose.add(reader); + } else { + readersToKeep.add(reader); + } + Loader loader = new Loader(reader, 0); + cache.getOrCompute(entity, loader, reader, termBytes); + } + } + + // Check resulting stats + List> initialDimensionValues = new ArrayList<>(); + for (IndexService indexService : new IndexService[] { indexToKeep, indexToClose }) { + for (int i = 0; i < numShards; i++) { + ShardId shardId = indexService.getShard(i).shardId(); + List dimensionValues = List.of(shardId.getIndexName(), shardId.toString()); + initialDimensionValues.add(dimensionValues); + ImmutableCacheStats snapshot = cache.stats().getStatsForDimensionValues(dimensionValues); + assertNotNull(snapshot); + // check the values are not empty by confirming entries != 0, this should always be true since the missed value is loaded + // into the cache + assertNotEquals(0, snapshot.getEntries()); + } + } + + // Delete an index + indexToClose.close("test_deletion", true); + // This actually closes the shards associated with the readers, which is necessary for cache cleanup logic + // In this UT, manually close the readers as well; could not figure out how to connect all this up in a UT so that + // we could get readers that were properly connected to an index's directory + for (DirectoryReader reader : readersToClose) { + IOUtils.close(reader); + } + // Trigger cache cleanup + cache.cacheCleanupManager.cleanCache(); + + // Now stats for the closed index should be gone + for (List dimensionValues : initialDimensionValues) { + ImmutableCacheStats snapshot = cache.stats().getStatsForDimensionValues(dimensionValues); + if (dimensionValues.get(0).equals(indexToCloseName)) { + assertNull(snapshot); + } else { + assertNotNull(snapshot); + // check the values are not empty by confirming entries != 0, this should always be true since the missed value is loaded + // into the cache + assertNotEquals(0, snapshot.getEntries()); + } + } + + for (DirectoryReader reader : readersToKeep) { + IOUtils.close(reader); + } + IOUtils.close(secondReader, writer, dir, cache); + terminate(threadPool); + } + public void testEviction() throws Exception { final ByteSizeValue size; { @@ -802,14 +930,15 @@ public void testEviction() throws Exception { assertEquals("foo", value1.streamInput().readString()); BytesReference value2 = cache.getOrCompute(secondEntity, secondLoader, secondReader, termBytes); assertEquals("bar", value2.streamInput().readString()); - size = indexShard.requestCache().stats().getMemorySize(); + size = new ByteSizeValue(cache.getSizeInBytes()); IOUtils.close(reader, secondReader, writer, dir, cache); terminate(threadPool); } IndexShard indexShard = createIndex("test1").getShard(0); ThreadPool threadPool = getThreadPool(); IndicesRequestCache cache = new IndicesRequestCache( - Settings.builder().put(IndicesRequestCache.INDICES_CACHE_QUERY_SIZE.getKey(), size.getBytes() + 1 + "b").build(), + // Add 5 instead of 1; the key size now depends on the length of dimension names and values so there's more variation + Settings.builder().put(IndicesRequestCache.INDICES_CACHE_QUERY_SIZE.getKey(), size.getBytes() + 5 + "b").build(), (shardId -> Optional.of(new IndicesService.IndexShardCacheEntity(indexShard))), new CacheModule(new ArrayList<>(), Settings.EMPTY).getCacheService(), threadPool, From e828c180b28fd96b3121e92613fc0879004e791f Mon Sep 17 00:00:00 2001 From: Marc Handalian Date: Sat, 13 Apr 2024 21:58:17 -0700 Subject: [PATCH 3/4] Fix flakiness with SegmentReplicationSuiteIT (#11977) * Fix SegmentReplicationSuiteIT This test fails because of a race during shard/node shutdown with node-node replication. Fixed by properly synchronizing creation of new replication events with cancellation and cancelling after shards are closed. Signed-off-by: Marc Handalian * Remove CopyState caching from OngoingSegmentReplications. This change removes the responsibility of caching CopyState inside of OngoingSegmentReplications. 1. CopyState was originally cached to prevent frequent disk reads while building segment metadata. This is now cached lower down in IndexShard and is not required here. 2. Change prepareForReplication method to return SegmentReplicationSourceHandler directly 3. Move responsibility of creating and clearing CopyState to the handler. Signed-off-by: Marc Handalian * Fix comment for afterIndexShardClosed method. Signed-off-by: Marc Handalian * Fix comment on beforeIndexShardClosed Signed-off-by: Marc Handalian * Remove unnecessary method from OngoingSegmentReplications Signed-off-by: Marc Handalian --------- Signed-off-by: Marc Handalian --- .../SegmentReplicationSuiteIT.java | 3 +- .../replication/CheckpointInfoResponse.java | 6 + .../OngoingSegmentReplications.java | 185 ++++-------------- .../SegmentReplicationSourceHandler.java | 44 +++-- .../SegmentReplicationSourceService.java | 12 +- .../indices/replication/common/CopyState.java | 33 ++-- .../SegmentReplicationIndexShardTests.java | 16 +- .../OngoingSegmentReplicationsTests.java | 73 ++----- .../SegmentReplicationSourceHandlerTests.java | 26 +-- .../SegmentReplicationTargetServiceTests.java | 5 +- .../replication/common/CopyStateTests.java | 13 +- .../index/shard/IndexShardTestCase.java | 4 +- 12 files changed, 125 insertions(+), 295 deletions(-) diff --git a/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationSuiteIT.java b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationSuiteIT.java index 8c045c1560dd3..27b65432e0bac 100644 --- a/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationSuiteIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationSuiteIT.java @@ -8,7 +8,6 @@ package org.opensearch.indices.replication; -import org.apache.lucene.tests.util.LuceneTestCase; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.settings.Settings; @@ -16,7 +15,6 @@ import org.opensearch.test.OpenSearchIntegTestCase; import org.junit.Before; -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/opensearch-project/OpenSearch/issues/9499") @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, minNumDataNodes = 2) public class SegmentReplicationSuiteIT extends SegmentReplicationBaseIT { @@ -64,6 +62,7 @@ public void testDropRandomNodeDuringReplication() throws Exception { ensureYellow(INDEX_NAME); client().prepareIndex(INDEX_NAME).setId(Integer.toString(docCount)).setSource("field", "value" + docCount).execute().get(); internalCluster().startDataOnlyNode(); + ensureGreen(INDEX_NAME); client().admin().indices().delete(new DeleteIndexRequest(INDEX_NAME)).actionGet(); } diff --git a/server/src/main/java/org/opensearch/indices/replication/CheckpointInfoResponse.java b/server/src/main/java/org/opensearch/indices/replication/CheckpointInfoResponse.java index 9fd3b7f3afb80..24b744bebc53d 100644 --- a/server/src/main/java/org/opensearch/indices/replication/CheckpointInfoResponse.java +++ b/server/src/main/java/org/opensearch/indices/replication/CheckpointInfoResponse.java @@ -40,6 +40,12 @@ public CheckpointInfoResponse( this.infosBytes = infosBytes; } + public CheckpointInfoResponse(final ReplicationCheckpoint checkpoint, final byte[] infosBytes) { + this.checkpoint = checkpoint; + this.infosBytes = infosBytes; + this.metadataMap = checkpoint.getMetadataMap(); + } + public CheckpointInfoResponse(StreamInput in) throws IOException { this.checkpoint = new ReplicationCheckpoint(in); this.metadataMap = in.readMap(StreamInput::readString, StoreFileMetadata::new); diff --git a/server/src/main/java/org/opensearch/indices/replication/OngoingSegmentReplications.java b/server/src/main/java/org/opensearch/indices/replication/OngoingSegmentReplications.java index 33967c0203516..6b99b3c0b0696 100644 --- a/server/src/main/java/org/opensearch/indices/replication/OngoingSegmentReplications.java +++ b/server/src/main/java/org/opensearch/indices/replication/OngoingSegmentReplications.java @@ -21,12 +21,10 @@ import org.opensearch.indices.IndicesService; import org.opensearch.indices.recovery.FileChunkWriter; import org.opensearch.indices.recovery.RecoverySettings; -import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; -import org.opensearch.indices.replication.common.CopyState; import java.io.IOException; +import java.io.UncheckedIOException; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -36,7 +34,6 @@ /** * Manages references to ongoing segrep events on a node. * Each replica will have a new {@link SegmentReplicationSourceHandler} created when starting replication. - * CopyStates will be cached for reuse between replicas and only released when all replicas have finished copying segments. * * @opensearch.internal */ @@ -45,7 +42,6 @@ class OngoingSegmentReplications { private static final Logger logger = LogManager.getLogger(OngoingSegmentReplications.class); private final RecoverySettings recoverySettings; private final IndicesService indicesService; - private final Map copyStateMap; private final Map allocationIdToHandlers; /** @@ -57,46 +53,9 @@ class OngoingSegmentReplications { OngoingSegmentReplications(IndicesService indicesService, RecoverySettings recoverySettings) { this.indicesService = indicesService; this.recoverySettings = recoverySettings; - this.copyStateMap = Collections.synchronizedMap(new HashMap<>()); this.allocationIdToHandlers = ConcurrentCollections.newConcurrentMap(); } - /* - Operations on the {@link #copyStateMap} member. - */ - - /** - * A synchronized method that checks {@link #copyStateMap} for the given {@link ReplicationCheckpoint} key - * and returns the cached value if one is present. If the key is not present, a {@link CopyState} - * object is constructed and stored in the map before being returned. - */ - synchronized CopyState getCachedCopyState(ReplicationCheckpoint checkpoint) throws IOException { - if (isInCopyStateMap(checkpoint)) { - final CopyState copyState = fetchFromCopyStateMap(checkpoint); - // we incref the copyState for every replica that is using this checkpoint. - // decref will happen when copy completes. - copyState.incRef(); - return copyState; - } else { - // From the checkpoint's shard ID, fetch the IndexShard - ShardId shardId = checkpoint.getShardId(); - final IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex()); - final IndexShard indexShard = indexService.getShard(shardId.id()); - // build the CopyState object and cache it before returning - final CopyState copyState = new CopyState(checkpoint, indexShard); - - /* - Use the checkpoint from the request as the key in the map, rather than - the checkpoint from the created CopyState. This maximizes cache hits - if replication targets make a request with an older checkpoint. - Replication targets are expected to fetch the checkpoint in the response - CopyState to bring themselves up to date. - */ - addToCopyStateMap(checkpoint, copyState); - return copyState; - } - } - /** * Start sending files to the replica. * @@ -114,12 +73,10 @@ void startSegmentCopy(GetSegmentFilesRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> { - final SegmentReplicationSourceHandler sourceHandler = allocationIdToHandlers.remove(request.getTargetAllocationId()); - if (sourceHandler != null) { - removeCopyState(sourceHandler.getCopyState()); - } - }); + final ActionListener wrappedListener = ActionListener.runBefore( + listener, + () -> allocationIdToHandlers.remove(request.getTargetAllocationId()) + ); handler.sendFiles(request, wrappedListener); } else { listener.onResponse(new GetSegmentFilesResponse(Collections.emptyList())); @@ -127,38 +84,32 @@ void startSegmentCopy(GetSegmentFilesRequest request, ActionListener handler.getAllocationId().equals(request.getTargetAllocationId()), "cancel due to retry"); - assert allocationIdToHandlers.containsKey(request.getTargetAllocationId()) == false; - allocationIdToHandlers.put(request.getTargetAllocationId(), newHandler); - } - assert allocationIdToHandlers.containsKey(request.getTargetAllocationId()); - return copyState; + SegmentReplicationSourceHandler prepareForReplication(CheckpointInfoRequest request, FileChunkWriter fileChunkWriter) { + return allocationIdToHandlers.computeIfAbsent(request.getTargetAllocationId(), aId -> { + try { + // From the checkpoint's shard ID, fetch the IndexShard + final ShardId shardId = request.getCheckpoint().getShardId(); + final IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex()); + final IndexShard indexShard = indexService.getShard(shardId.id()); + return new SegmentReplicationSourceHandler( + request.getTargetNode(), + fileChunkWriter, + indexShard, + request.getTargetAllocationId(), + Math.toIntExact(recoverySettings.getChunkSize().getBytes()), + recoverySettings.getMaxConcurrentFileChunks() + ); + } catch (IOException e) { + throw new UncheckedIOException("Error creating replication handler", e); + } + }); } /** @@ -167,8 +118,8 @@ CopyState prepareForReplication(CheckpointInfoRequest request, FileChunkWriter f * @param shard {@link IndexShard} * @param reason {@link String} - Reason for the cancel */ - synchronized void cancel(IndexShard shard, String reason) { - cancelHandlers(handler -> handler.getCopyState().getShard().shardId().equals(shard.shardId()), reason); + void cancel(IndexShard shard, String reason) { + cancelHandlers(handler -> handler.shardId().equals(shard.shardId()), reason); } /** @@ -177,11 +128,10 @@ synchronized void cancel(IndexShard shard, String reason) { * @param allocationId {@link String} - Allocation ID. * @param reason {@link String} - Reason for the cancel */ - synchronized void cancel(String allocationId, String reason) { + void cancel(String allocationId, String reason) { final SegmentReplicationSourceHandler handler = allocationIdToHandlers.remove(allocationId); if (handler != null) { handler.cancel(reason); - removeCopyState(handler.getCopyState()); } } @@ -194,14 +144,6 @@ void cancelReplication(DiscoveryNode node) { cancelHandlers(handler -> handler.getTargetNode().equals(node), "Node left"); } - /** - * Checks if the {@link #copyStateMap} has the input {@link ReplicationCheckpoint} - * as a key by invoking {@link Map#containsKey(Object)}. - */ - boolean isInCopyStateMap(ReplicationCheckpoint replicationCheckpoint) { - return copyStateMap.containsKey(replicationCheckpoint); - } - int size() { return allocationIdToHandlers.size(); } @@ -211,58 +153,20 @@ Map getHandlers() { return allocationIdToHandlers; } - int cachedCopyStateSize() { - return copyStateMap.size(); - } - - private SegmentReplicationSourceHandler createTargetHandler( - DiscoveryNode node, - CopyState copyState, - String allocationId, - FileChunkWriter fileChunkWriter - ) { - return new SegmentReplicationSourceHandler( - node, - fileChunkWriter, - copyState.getShard().getThreadPool(), - copyState, - allocationId, - Math.toIntExact(recoverySettings.getChunkSize().getBytes()), - recoverySettings.getMaxConcurrentFileChunks() - ); - } - /** - * Adds the input {@link CopyState} object to {@link #copyStateMap}. - * The key is the CopyState's {@link ReplicationCheckpoint} object. - */ - private void addToCopyStateMap(ReplicationCheckpoint checkpoint, CopyState copyState) { - copyStateMap.putIfAbsent(checkpoint, copyState); - } - - /** - * Given a {@link ReplicationCheckpoint}, return the corresponding - * {@link CopyState} object, if any, from {@link #copyStateMap}. - */ - private CopyState fetchFromCopyStateMap(ReplicationCheckpoint replicationCheckpoint) { - return copyStateMap.get(replicationCheckpoint); - } - - /** - * Remove a CopyState. Intended to be called after a replication event completes. - * This method will remove a copyState from the copyStateMap only if its refCount hits 0. - * - * @param copyState {@link CopyState} + * Clear handlers for any allocationIds not in sync. + * @param shardId {@link ShardId} + * @param inSyncAllocationIds {@link List} of in-sync allocation Ids. */ - private synchronized void removeCopyState(CopyState copyState) { - if (copyState.decRef() == true) { - copyStateMap.remove(copyState.getRequestedReplicationCheckpoint()); - } + void clearOutOfSyncIds(ShardId shardId, Set inSyncAllocationIds) { + cancelHandlers( + (handler) -> handler.shardId().equals(shardId) && inSyncAllocationIds.contains(handler.getAllocationId()) == false, + "Shard is no longer in-sync with the primary" + ); } /** * Remove handlers from allocationIdToHandlers map based on a filter predicate. - * This will also decref the handler's CopyState reference. */ private void cancelHandlers(Predicate predicate, String reason) { final List allocationIds = allocationIdToHandlers.values() @@ -278,17 +182,4 @@ private void cancelHandlers(Predicate p cancel(allocationId, reason); } } - - /** - * Clear copystate and target handlers for any non insync allocationIds. - * @param shardId {@link ShardId} - * @param inSyncAllocationIds {@link List} of in-sync allocation Ids. - */ - public void clearOutOfSyncIds(ShardId shardId, Set inSyncAllocationIds) { - cancelHandlers( - (handler) -> handler.getCopyState().getShard().shardId().equals(shardId) - && inSyncAllocationIds.contains(handler.getAllocationId()) == false, - "Shard is no longer in-sync with the primary" - ); - } } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceHandler.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceHandler.java index 674c09311c645..bb64d6b0c60b6 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceHandler.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceHandler.java @@ -18,16 +18,18 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.io.IOUtils; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.store.StoreFileMetadata; import org.opensearch.indices.recovery.FileChunkWriter; import org.opensearch.indices.recovery.MultiChunkTransfer; +import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; import org.opensearch.indices.replication.common.CopyState; import org.opensearch.indices.replication.common.ReplicationTimer; -import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transports; import java.io.Closeable; +import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -54,48 +56,46 @@ class SegmentReplicationSourceHandler { private final AtomicBoolean isReplicating = new AtomicBoolean(); private final DiscoveryNode targetNode; private final String allocationId; - private final FileChunkWriter writer; /** * Constructor. * - * @param targetNode - {@link DiscoveryNode} target node where files should be sent. + * @param targetNode {@link DiscoveryNode} target node where files should be sent. * @param writer {@link FileChunkWriter} implementation that sends file chunks over the transport layer. - * @param threadPool {@link ThreadPool} Thread pool. - * @param copyState {@link CopyState} CopyState holding segment file metadata. + * @param shard {@link IndexShard} The primary shard local to this node. * @param fileChunkSizeInBytes {@link Integer} * @param maxConcurrentFileChunks {@link Integer} */ SegmentReplicationSourceHandler( DiscoveryNode targetNode, FileChunkWriter writer, - ThreadPool threadPool, - CopyState copyState, + IndexShard shard, String allocationId, int fileChunkSizeInBytes, int maxConcurrentFileChunks - ) { + ) throws IOException { this.targetNode = targetNode; - this.shard = copyState.getShard(); + this.shard = shard; this.logger = Loggers.getLogger( SegmentReplicationSourceHandler.class, - copyState.getShard().shardId(), + shard.shardId(), "sending segments to " + targetNode.getName() ); this.segmentFileTransferHandler = new SegmentFileTransferHandler( - copyState.getShard(), + shard, targetNode, writer, logger, - threadPool, + shard.getThreadPool(), cancellableThreads, fileChunkSizeInBytes, maxConcurrentFileChunks ); this.allocationId = allocationId; - this.copyState = copyState; + this.copyState = new CopyState(shard); this.writer = writer; + resources.add(copyState); } /** @@ -109,6 +109,7 @@ public synchronized void sendFiles(GetSegmentFilesRequest request, ActionListene if (request.getFilesToFetch().isEmpty()) { // before completion, alert the primary of the replica's state. shard.updateVisibleCheckpointForShard(request.getTargetAllocationId(), copyState.getCheckpoint()); + IOUtils.closeWhileHandlingException(copyState); listener.onResponse(new GetSegmentFilesResponse(Collections.emptyList())); return; } @@ -183,10 +184,7 @@ public synchronized void sendFiles(GetSegmentFilesRequest request, ActionListene public void cancel(String reason) { writer.cancel(); cancellableThreads.cancel(reason); - } - - CopyState getCopyState() { - return copyState; + IOUtils.closeWhileHandlingException(copyState); } public boolean isReplicating() { @@ -200,4 +198,16 @@ public DiscoveryNode getTargetNode() { public String getAllocationId() { return allocationId; } + + public ReplicationCheckpoint getCheckpoint() { + return copyState.getCheckpoint(); + } + + public byte[] getInfosBytes() { + return copyState.getInfosBytes(); + } + + public ShardId shardId() { + return shard.shardId(); + } } diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java index a393faabae0ea..ca89741d5bb55 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java @@ -29,7 +29,6 @@ import org.opensearch.indices.IndicesService; import org.opensearch.indices.recovery.RecoverySettings; import org.opensearch.indices.recovery.RetryableTransportClient; -import org.opensearch.indices.replication.common.CopyState; import org.opensearch.indices.replication.common.ReplicationTimer; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; @@ -126,16 +125,17 @@ public void messageReceived(CheckpointInfoRequest request, TransportChannel chan new AtomicLong(0), (throttleTime) -> {} ); - final CopyState copyState = ongoingSegmentReplications.prepareForReplication(request, segmentSegmentFileChunkWriter); - channel.sendResponse( - new CheckpointInfoResponse(copyState.getCheckpoint(), copyState.getMetadataMap(), copyState.getInfosBytes()) + final SegmentReplicationSourceHandler handler = ongoingSegmentReplications.prepareForReplication( + request, + segmentSegmentFileChunkWriter ); + channel.sendResponse(new CheckpointInfoResponse(handler.getCheckpoint(), handler.getInfosBytes())); timer.stop(); logger.trace( new ParameterizedMessage( "[replication id {}] Source node sent checkpoint info [{}] to target node [{}], timing: {}", request.getReplicationId(), - copyState.getCheckpoint(), + handler.getCheckpoint(), request.getTargetNode().getId(), timer.time() ) @@ -217,7 +217,7 @@ protected void doClose() throws IOException { /** * - * Cancels any replications on this node to a replica shard that is about to be closed. + * Before a primary shard is closed, cancel any ongoing replications to release incref'd segments. */ @Override public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexShard, Settings indexSettings) { diff --git a/server/src/main/java/org/opensearch/indices/replication/common/CopyState.java b/server/src/main/java/org/opensearch/indices/replication/common/CopyState.java index 3b7ae2af80ca0..7d3eb9083208b 100644 --- a/server/src/main/java/org/opensearch/indices/replication/common/CopyState.java +++ b/server/src/main/java/org/opensearch/indices/replication/common/CopyState.java @@ -13,11 +13,11 @@ import org.apache.lucene.store.ByteBuffersIndexOutput; import org.opensearch.common.collect.Tuple; import org.opensearch.common.concurrent.GatedCloseable; -import org.opensearch.common.util.concurrent.AbstractRefCounted; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.store.StoreFileMetadata; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; +import java.io.Closeable; import java.io.IOException; import java.io.UncheckedIOException; import java.util.Map; @@ -28,28 +28,21 @@ * * @opensearch.internal */ -public class CopyState extends AbstractRefCounted { +public class CopyState implements Closeable { private final GatedCloseable segmentInfosRef; - /** ReplicationCheckpoint requested */ - private final ReplicationCheckpoint requestedReplicationCheckpoint; /** Actual ReplicationCheckpoint returned by the shard */ private final ReplicationCheckpoint replicationCheckpoint; - private final Map metadataMap; private final byte[] infosBytes; private final IndexShard shard; - public CopyState(ReplicationCheckpoint requestedReplicationCheckpoint, IndexShard shard) throws IOException { - super("CopyState-" + shard.shardId()); - this.requestedReplicationCheckpoint = requestedReplicationCheckpoint; + public CopyState(IndexShard shard) throws IOException { this.shard = shard; final Tuple, ReplicationCheckpoint> latestSegmentInfosAndCheckpoint = shard .getLatestSegmentInfosAndCheckpoint(); this.segmentInfosRef = latestSegmentInfosAndCheckpoint.v1(); this.replicationCheckpoint = latestSegmentInfosAndCheckpoint.v2(); SegmentInfos segmentInfos = this.segmentInfosRef.get(); - this.metadataMap = shard.store().getSegmentMetadataMap(segmentInfos); - ByteBuffersDataOutput buffer = new ByteBuffersDataOutput(); // resource description and name are not used, but resource description cannot be null try (ByteBuffersIndexOutput indexOutput = new ByteBuffersIndexOutput(buffer, "", null)) { @@ -58,21 +51,12 @@ public CopyState(ReplicationCheckpoint requestedReplicationCheckpoint, IndexShar this.infosBytes = buffer.toArrayCopy(); } - @Override - protected void closeInternal() { - try { - segmentInfosRef.close(); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - } - public ReplicationCheckpoint getCheckpoint() { return replicationCheckpoint; } public Map getMetadataMap() { - return metadataMap; + return replicationCheckpoint.getMetadataMap(); } public byte[] getInfosBytes() { @@ -83,7 +67,12 @@ public IndexShard getShard() { return shard; } - public ReplicationCheckpoint getRequestedReplicationCheckpoint() { - return requestedReplicationCheckpoint; + @Override + public void close() throws IOException { + try { + segmentInfosRef.close(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } } } diff --git a/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java b/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java index e93d266dcab4c..2311fc582616f 100644 --- a/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java +++ b/server/src/test/java/org/opensearch/index/shard/SegmentReplicationIndexShardTests.java @@ -1019,25 +1019,15 @@ protected void assertDocCounts(IndexShard indexShard, int expectedPersistedDocCo } protected void resolveCheckpointInfoResponseListener(ActionListener listener, IndexShard primary) { - final CopyState copyState; - try { - copyState = new CopyState( - ReplicationCheckpoint.empty(primary.shardId, primary.getLatestReplicationCheckpoint().getCodec()), - primary + try (final CopyState copyState = new CopyState(primary)) { + listener.onResponse( + new CheckpointInfoResponse(copyState.getCheckpoint(), copyState.getMetadataMap(), copyState.getInfosBytes()) ); } catch (IOException e) { logger.error("Unexpected error computing CopyState", e); Assert.fail("Failed to compute copyState"); throw new UncheckedIOException(e); } - - try { - listener.onResponse( - new CheckpointInfoResponse(copyState.getCheckpoint(), copyState.getMetadataMap(), copyState.getInfosBytes()) - ); - } finally { - copyState.decRef(); - } } protected void startReplicationAndAssertCancellation( diff --git a/server/src/test/java/org/opensearch/indices/replication/OngoingSegmentReplicationsTests.java b/server/src/test/java/org/opensearch/indices/replication/OngoingSegmentReplicationsTests.java index 44e2653cf01da..eb27850000bdd 100644 --- a/server/src/test/java/org/opensearch/indices/replication/OngoingSegmentReplicationsTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/OngoingSegmentReplicationsTests.java @@ -25,7 +25,6 @@ import org.opensearch.indices.recovery.FileChunkWriter; import org.opensearch.indices.recovery.RecoverySettings; import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint; -import org.opensearch.indices.replication.common.CopyState; import org.opensearch.indices.replication.common.ReplicationType; import org.opensearch.transport.TransportService; import org.junit.Assert; @@ -106,25 +105,21 @@ public void testPrepareAndSendSegments() throws IOException { final FileChunkWriter segmentSegmentFileChunkWriter = (fileMetadata, position, content, lastChunk, totalTranslogOps, listener) -> { listener.onResponse(null); }; - final CopyState copyState = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); - assertTrue(replications.isInCopyStateMap(request.getCheckpoint())); + final SegmentReplicationSourceHandler handler = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); assertEquals(1, replications.size()); - assertEquals(1, copyState.refCount()); getSegmentFilesRequest = new GetSegmentFilesRequest( 1L, replica.routingEntry().allocationId().getId(), replicaDiscoveryNode, - new ArrayList<>(copyState.getMetadataMap().values()), + new ArrayList<>(handler.getCheckpoint().getMetadataMap().values()), testCheckpoint ); replications.startSegmentCopy(getSegmentFilesRequest, new ActionListener<>() { @Override public void onResponse(GetSegmentFilesResponse getSegmentFilesResponse) { - assertEquals(copyState.getMetadataMap().size(), getSegmentFilesResponse.files.size()); - assertEquals(0, copyState.refCount()); - assertFalse(replications.isInCopyStateMap(request.getCheckpoint())); + assertEquals(handler.getCheckpoint().getMetadataMap().size(), getSegmentFilesResponse.files.size()); assertEquals(0, replications.size()); } @@ -148,14 +143,11 @@ public void testCancelReplication() throws IOException { // this shouldn't be called in this test. Assert.fail(); }; - final CopyState copyState = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); + final SegmentReplicationSourceHandler handler = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); assertEquals(1, replications.size()); - assertEquals(1, replications.cachedCopyStateSize()); replications.cancelReplication(primaryDiscoveryNode); - assertEquals(0, copyState.refCount()); assertEquals(0, replications.size()); - assertEquals(0, replications.cachedCopyStateSize()); } public void testCancelReplication_AfterSendFilesStarts() throws IOException, InterruptedException { @@ -174,14 +166,13 @@ public void testCancelReplication_AfterSendFilesStarts() throws IOException, Int // cancel the replication as soon as the writer starts sending files. replications.cancel(replica.routingEntry().allocationId().getId(), "Test"); }; - final CopyState copyState = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); + final SegmentReplicationSourceHandler handler = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); assertEquals(1, replications.size()); - assertEquals(1, replications.cachedCopyStateSize()); getSegmentFilesRequest = new GetSegmentFilesRequest( 1L, replica.routingEntry().allocationId().getId(), replicaDiscoveryNode, - new ArrayList<>(copyState.getMetadataMap().values()), + new ArrayList<>(handler.getCheckpoint().getMetadataMap().values()), testCheckpoint ); replications.startSegmentCopy(getSegmentFilesRequest, new ActionListener<>() { @@ -193,9 +184,7 @@ public void onResponse(GetSegmentFilesResponse getSegmentFilesResponse) { @Override public void onFailure(Exception e) { assertEquals(CancellableThreads.ExecutionCancelledException.class, e.getClass()); - assertEquals(0, copyState.refCount()); assertEquals(0, replications.size()); - assertEquals(0, replications.cachedCopyStateSize()); latch.countDown(); } }); @@ -219,8 +208,7 @@ public void testMultipleReplicasUseSameCheckpoint() throws IOException { Assert.fail(); }; - final CopyState copyState = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); - assertEquals(1, copyState.refCount()); + final SegmentReplicationSourceHandler handler = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); final CheckpointInfoRequest secondRequest = new CheckpointInfoRequest( 1L, @@ -230,15 +218,11 @@ public void testMultipleReplicasUseSameCheckpoint() throws IOException { ); replications.prepareForReplication(secondRequest, segmentSegmentFileChunkWriter); - assertEquals(2, copyState.refCount()); assertEquals(2, replications.size()); - assertEquals(1, replications.cachedCopyStateSize()); replications.cancelReplication(primaryDiscoveryNode); replications.cancelReplication(replicaDiscoveryNode); - assertEquals(0, copyState.refCount()); assertEquals(0, replications.size()); - assertEquals(0, replications.cachedCopyStateSize()); closeShards(secondReplica); } @@ -280,8 +264,7 @@ public void testShardAlreadyReplicatingToNode() throws IOException { listener.onResponse(null); }; replications.prepareForReplication(request, segmentSegmentFileChunkWriter); - CopyState copyState = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); - assertEquals(1, copyState.refCount()); + final SegmentReplicationSourceHandler handler = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); } public void testStartReplicationWithNoFilesToFetch() throws IOException { @@ -296,10 +279,8 @@ public void testStartReplicationWithNoFilesToFetch() throws IOException { // mock the FileChunkWriter so we can assert its ever called. final FileChunkWriter segmentSegmentFileChunkWriter = mock(FileChunkWriter.class); // Prepare for replication step - and ensure copyState is added to cache. - final CopyState copyState = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); - assertTrue(replications.isInCopyStateMap(request.getCheckpoint())); + final SegmentReplicationSourceHandler handler = replications.prepareForReplication(request, segmentSegmentFileChunkWriter); assertEquals(1, replications.size()); - assertEquals(1, copyState.refCount()); getSegmentFilesRequest = new GetSegmentFilesRequest( 1L, @@ -314,8 +295,6 @@ public void testStartReplicationWithNoFilesToFetch() throws IOException { @Override public void onResponse(GetSegmentFilesResponse getSegmentFilesResponse) { assertEquals(Collections.emptyList(), getSegmentFilesResponse.files); - assertEquals(0, copyState.refCount()); - assertFalse(replications.isInCopyStateMap(request.getCheckpoint())); verifyNoInteractions(segmentSegmentFileChunkWriter); } @@ -340,8 +319,7 @@ public void testCancelAllReplicationsForShard() throws IOException { testCheckpoint ); - final CopyState copyState = replications.prepareForReplication(request, mock(FileChunkWriter.class)); - assertEquals(1, copyState.refCount()); + final SegmentReplicationSourceHandler handler = replications.prepareForReplication(request, mock(FileChunkWriter.class)); final CheckpointInfoRequest secondRequest = new CheckpointInfoRequest( 1L, @@ -351,15 +329,11 @@ public void testCancelAllReplicationsForShard() throws IOException { ); replications.prepareForReplication(secondRequest, mock(FileChunkWriter.class)); - assertEquals(2, copyState.refCount()); assertEquals(2, replications.size()); - assertEquals(1, replications.cachedCopyStateSize()); // cancel the primary's ongoing replications. replications.cancel(primary, "Test"); - assertEquals(0, copyState.refCount()); assertEquals(0, replications.size()); - assertEquals(0, replications.cachedCopyStateSize()); closeShards(replica_2); } @@ -372,8 +346,7 @@ public void testCancelForMissingIds() throws IOException { final String replicaAllocationId = replica.routingEntry().allocationId().getId(); final CheckpointInfoRequest request = new CheckpointInfoRequest(1L, replicaAllocationId, primaryDiscoveryNode, testCheckpoint); - final CopyState copyState = replications.prepareForReplication(request, mock(FileChunkWriter.class)); - assertEquals(1, copyState.refCount()); + final SegmentReplicationSourceHandler handler = replications.prepareForReplication(request, mock(FileChunkWriter.class)); final String replica_2AllocationId = replica_2.routingEntry().allocationId().getId(); final CheckpointInfoRequest secondRequest = new CheckpointInfoRequest( @@ -384,23 +357,17 @@ public void testCancelForMissingIds() throws IOException { ); replications.prepareForReplication(secondRequest, mock(FileChunkWriter.class)); - assertEquals(2, copyState.refCount()); assertEquals(2, replications.size()); assertTrue(replications.getHandlers().containsKey(replicaAllocationId)); assertTrue(replications.getHandlers().containsKey(replica_2AllocationId)); - assertEquals(1, replications.cachedCopyStateSize()); replications.clearOutOfSyncIds(primary.shardId(), Set.of(replica_2AllocationId)); - assertEquals(1, copyState.refCount()); assertEquals(1, replications.size()); assertTrue(replications.getHandlers().containsKey(replica_2AllocationId)); - assertEquals(1, replications.cachedCopyStateSize()); // cancel the primary's ongoing replications. replications.clearOutOfSyncIds(primary.shardId(), Collections.emptySet()); - assertEquals(0, copyState.refCount()); assertEquals(0, replications.size()); - assertEquals(0, replications.cachedCopyStateSize()); closeShards(replica_2); } @@ -409,11 +376,8 @@ public void testPrepareForReplicationAlreadyReplicating() throws IOException { final String replicaAllocationId = replica.routingEntry().allocationId().getId(); final CheckpointInfoRequest request = new CheckpointInfoRequest(1L, replicaAllocationId, primaryDiscoveryNode, testCheckpoint); - final CopyState copyState = replications.prepareForReplication(request, mock(FileChunkWriter.class)); - - final SegmentReplicationSourceHandler handler = replications.getHandlers().get(replicaAllocationId); - assertEquals(handler.getCopyState(), copyState); - assertEquals(1, copyState.refCount()); + final SegmentReplicationSourceHandler handler = replications.prepareForReplication(request, mock(FileChunkWriter.class)); + assertEquals(handler, replications.getHandlers().get(replicaAllocationId)); ReplicationCheckpoint secondCheckpoint = new ReplicationCheckpoint( testCheckpoint.getShardId(), @@ -430,11 +394,10 @@ public void testPrepareForReplicationAlreadyReplicating() throws IOException { secondCheckpoint ); - final CopyState secondCopyState = replications.prepareForReplication(secondRequest, mock(FileChunkWriter.class)); - final SegmentReplicationSourceHandler secondHandler = replications.getHandlers().get(replicaAllocationId); - assertEquals(secondHandler.getCopyState(), secondCopyState); - assertEquals("New copy state is incref'd", 1, secondCopyState.refCount()); - assertEquals("Old copy state is cleaned up", 0, copyState.refCount()); - + final SegmentReplicationSourceHandler secondHandler = replications.prepareForReplication( + secondRequest, + mock(FileChunkWriter.class) + ); + assertEquals(secondHandler, replications.getHandlers().get(replicaAllocationId)); } } diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationSourceHandlerTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationSourceHandlerTests.java index d586767290797..901dc28794cfc 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationSourceHandlerTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationSourceHandlerTests.java @@ -68,18 +68,16 @@ public void testSendFiles() throws IOException { chunkWriter = (fileMetadata, position, content, lastChunk, totalTranslogOps, listener) -> listener.onResponse(null); final ReplicationCheckpoint latestReplicationCheckpoint = primary.getLatestReplicationCheckpoint(); - final CopyState copyState = new CopyState(latestReplicationCheckpoint, primary); SegmentReplicationSourceHandler handler = new SegmentReplicationSourceHandler( localNode, chunkWriter, - threadPool, - copyState, + primary, replica.routingEntry().allocationId().getId(), 5000, 1 ); - final List expectedFiles = List.copyOf(copyState.getMetadataMap().values()); + final List expectedFiles = List.copyOf(handler.getCheckpoint().getMetadataMap().values()); final GetSegmentFilesRequest getSegmentFilesRequest = new GetSegmentFilesRequest( 1L, @@ -106,12 +104,10 @@ public void testSendFiles_emptyRequest() throws IOException { chunkWriter = mock(FileChunkWriter.class); final ReplicationCheckpoint latestReplicationCheckpoint = primary.getLatestReplicationCheckpoint(); - final CopyState copyState = new CopyState(latestReplicationCheckpoint, primary); SegmentReplicationSourceHandler handler = new SegmentReplicationSourceHandler( localNode, chunkWriter, - threadPool, - copyState, + primary, replica.routingEntry().allocationId().getId(), 5000, 1 @@ -148,12 +144,11 @@ public void testSendFileFails() throws IOException { ); final ReplicationCheckpoint latestReplicationCheckpoint = primary.getLatestReplicationCheckpoint(); - final CopyState copyState = new CopyState(latestReplicationCheckpoint, primary); + final CopyState copyState = new CopyState(primary); SegmentReplicationSourceHandler handler = new SegmentReplicationSourceHandler( localNode, chunkWriter, - threadPool, - copyState, + primary, primary.routingEntry().allocationId().getId(), 5000, 1 @@ -180,19 +175,18 @@ public void onFailure(Exception e) { assertEquals(e.getClass(), OpenSearchException.class); } }); - copyState.decRef(); + copyState.close(); } public void testReplicationAlreadyRunning() throws IOException { chunkWriter = mock(FileChunkWriter.class); final ReplicationCheckpoint latestReplicationCheckpoint = primary.getLatestReplicationCheckpoint(); - final CopyState copyState = new CopyState(latestReplicationCheckpoint, primary); + final CopyState copyState = new CopyState(primary); SegmentReplicationSourceHandler handler = new SegmentReplicationSourceHandler( localNode, chunkWriter, - threadPool, - copyState, + primary, replica.routingEntry().allocationId().getId(), 5000, 1 @@ -217,12 +211,10 @@ public void testCancelReplication() throws IOException, InterruptedException { chunkWriter = mock(FileChunkWriter.class); final ReplicationCheckpoint latestReplicationCheckpoint = primary.getLatestReplicationCheckpoint(); - final CopyState copyState = new CopyState(latestReplicationCheckpoint, primary); SegmentReplicationSourceHandler handler = new SegmentReplicationSourceHandler( localNode, chunkWriter, - threadPool, - copyState, + primary, primary.routingEntry().allocationId().getId(), 5000, 1 diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java index 3c72dda2d8b5d..f06d5595afcd5 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java @@ -273,10 +273,7 @@ public void getCheckpointMetadata( ) { try { blockGetCheckpointMetadata.await(); - final CopyState copyState = new CopyState( - ReplicationCheckpoint.empty(primaryShard.shardId(), primaryShard.getLatestReplicationCheckpoint().getCodec()), - primaryShard - ); + final CopyState copyState = new CopyState(primaryShard); listener.onResponse( new CheckpointInfoResponse(copyState.getCheckpoint(), copyState.getMetadataMap(), copyState.getInfosBytes()) ); diff --git a/server/src/test/java/org/opensearch/indices/replication/common/CopyStateTests.java b/server/src/test/java/org/opensearch/indices/replication/common/CopyStateTests.java index df180a8ab1007..0b30486038e3a 100644 --- a/server/src/test/java/org/opensearch/indices/replication/common/CopyStateTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/common/CopyStateTests.java @@ -18,7 +18,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.index.shard.ShardId; import org.opensearch.env.Environment; -import org.opensearch.index.codec.CodecService; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.IndexShardTestCase; import org.opensearch.index.store.Store; @@ -54,13 +53,7 @@ public class CopyStateTests extends IndexShardTestCase { public void testCopyStateCreation() throws IOException { final IndexShard mockIndexShard = createMockIndexShard(); - CopyState copyState = new CopyState( - ReplicationCheckpoint.empty( - mockIndexShard.shardId(), - new CodecService(null, mockIndexShard.indexSettings(), null).codec("default").getName() - ), - mockIndexShard - ); + CopyState copyState = new CopyState(mockIndexShard); ReplicationCheckpoint checkpoint = copyState.getCheckpoint(); assertEquals(TEST_SHARD_ID, checkpoint.getShardId()); // version was never set so this should be zero @@ -86,7 +79,9 @@ public static IndexShard createMockIndexShard() throws IOException { mockShard.getOperationPrimaryTerm(), 0L, 0L, - Codec.getDefault().getName() + 0L, + Codec.getDefault().getName(), + SI_SNAPSHOT.asMap() ); final Tuple, ReplicationCheckpoint> gatedCloseableReplicationCheckpointTuple = new Tuple<>( new GatedCloseable<>(testSegmentInfos, () -> {}), diff --git a/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java b/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java index 4dd4c734a1701..6b609d8af62a1 100644 --- a/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java +++ b/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java @@ -1638,12 +1638,10 @@ public void getCheckpointMetadata( ReplicationCheckpoint checkpoint, ActionListener listener ) { - try { - final CopyState copyState = new CopyState(primaryShard.getLatestReplicationCheckpoint(), primaryShard); + try (final CopyState copyState = new CopyState(primaryShard)) { listener.onResponse( new CheckpointInfoResponse(copyState.getCheckpoint(), copyState.getMetadataMap(), copyState.getInfosBytes()) ); - copyState.decRef(); } catch (IOException e) { logger.error("Unexpected error computing CopyState", e); Assert.fail("Failed to compute copyState"); From 6bc04b49a0031e37466001469abcfd593e454813 Mon Sep 17 00:00:00 2001 From: maxliu <48641774+Ferrari248@users.noreply.github.com> Date: Sun, 14 Apr 2024 13:02:52 +0800 Subject: [PATCH 4/4] [segment replication] decouple the rateLimiter of segrep and recovery (#12959) * [segment replication] decouple the rateLimiter of segrep and recovery (12939) add setting "segrep.max_bytes_per_sec" Signed-off-by: maxliu * [segment replication] decouple the rateLimiter of segrep and recovery (12939) use setting "indices.replication.max_bytes_per_sec" if enable "indices.replication.use_individual_rate_limiter" Signed-off-by: maxliu * [segment replication] decouple the rateLimiter of segrep and recovery (12939) setting "indices.replication.max_bytes_per_sec" takes effect when not negative Signed-off-by: maxliu * [segment replication] decouple the rateLimiter of segrep and recovery (#12939) add setting "indices.replication.max_bytes_per_sec" which takes effect when not negative Signed-off-by: maxliu Adds change log Signed-off-by: maxliu --------- Signed-off-by: maxliu --- CHANGELOG.md | 1 + .../common/settings/ClusterSettings.java | 1 + .../recovery/PeerRecoveryTargetService.java | 8 +- .../indices/recovery/RecoverySettings.java | 79 +++++++++++++++---- .../recovery/RemoteRecoveryTargetHandler.java | 3 +- .../RemoteSegmentFileChunkWriter.java | 10 ++- .../SegmentReplicationSourceService.java | 3 +- .../SegmentReplicationTargetService.java | 2 +- .../blobstore/BlobStoreRepository.java | 4 +- .../RecoverySettingsDynamicUpdateTests.java | 24 +++++- 10 files changed, 107 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b41f1cd4ca01..d336b87197e1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add support for more than one protocol for transport ([#12967](https://github.com/opensearch-project/OpenSearch/pull/12967)) - [Tiered Caching] Add dimension-based stats to ICache implementations. ([#12531](https://github.com/opensearch-project/OpenSearch/pull/12531)) - Add changes for overriding remote store and replication settings during snapshot restore. ([#11868](https://github.com/opensearch-project/OpenSearch/pull/11868)) +- Add an individual setting of rate limiter for segment replication ([#12959](https://github.com/opensearch-project/OpenSearch/pull/12959)) ### Dependencies - Bump `org.apache.commons:commons-configuration2` from 2.10.0 to 2.10.1 ([#12896](https://github.com/opensearch-project/OpenSearch/pull/12896)) diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index 004973e50d43a..ef74a794a9975 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -292,6 +292,7 @@ public void apply(Settings value, Settings current, Settings previous) { ShardLimitValidator.SETTING_CLUSTER_MAX_SHARDS_PER_CLUSTER, ShardLimitValidator.SETTING_CLUSTER_IGNORE_DOT_INDEXES, RecoverySettings.INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING, + RecoverySettings.INDICES_REPLICATION_MAX_BYTES_PER_SEC_SETTING, RecoverySettings.INDICES_RECOVERY_RETRY_DELAY_STATE_SYNC_SETTING, RecoverySettings.INDICES_RECOVERY_RETRY_DELAY_NETWORK_SETTING, RecoverySettings.INDICES_RECOVERY_ACTIVITY_TIMEOUT_SETTING, diff --git a/server/src/main/java/org/opensearch/indices/recovery/PeerRecoveryTargetService.java b/server/src/main/java/org/opensearch/indices/recovery/PeerRecoveryTargetService.java index c24840d0c1333..6279a8ec3646c 100644 --- a/server/src/main/java/org/opensearch/indices/recovery/PeerRecoveryTargetService.java +++ b/server/src/main/java/org/opensearch/indices/recovery/PeerRecoveryTargetService.java @@ -575,7 +575,13 @@ public void messageReceived(final FileChunkRequest request, TransportChannel cha try (ReplicationRef recoveryRef = onGoingRecoveries.getSafe(request.recoveryId(), request.shardId())) { final RecoveryTarget recoveryTarget = recoveryRef.get(); final ActionListener listener = recoveryTarget.createOrFinishListener(channel, Actions.FILE_CHUNK, request); - recoveryTarget.handleFileChunk(request, recoveryTarget, bytesSinceLastPause, recoverySettings.rateLimiter(), listener); + recoveryTarget.handleFileChunk( + request, + recoveryTarget, + bytesSinceLastPause, + recoverySettings.recoveryRateLimiter(), + listener + ); } } } diff --git a/server/src/main/java/org/opensearch/indices/recovery/RecoverySettings.java b/server/src/main/java/org/opensearch/indices/recovery/RecoverySettings.java index 53b42347aa30d..8f9da6babdd99 100644 --- a/server/src/main/java/org/opensearch/indices/recovery/RecoverySettings.java +++ b/server/src/main/java/org/opensearch/indices/recovery/RecoverySettings.java @@ -65,6 +65,16 @@ public class RecoverySettings { Property.NodeScope ); + /** + * Individual speed setting for segment replication, default -1B to reuse the setting of recovery. + */ + public static final Setting INDICES_REPLICATION_MAX_BYTES_PER_SEC_SETTING = Setting.byteSizeSetting( + "indices.replication.max_bytes_per_sec", + new ByteSizeValue(-1), + Property.Dynamic, + Property.NodeScope + ); + /** * Controls the maximum number of file chunk requests that can be sent concurrently from the source node to the target node. */ @@ -169,11 +179,13 @@ public class RecoverySettings { // choose 512KB-16B to ensure that the resulting byte[] is not a humongous allocation in G1. public static final ByteSizeValue DEFAULT_CHUNK_SIZE = new ByteSizeValue(512 * 1024 - 16, ByteSizeUnit.BYTES); - private volatile ByteSizeValue maxBytesPerSec; + private volatile ByteSizeValue recoveryMaxBytesPerSec; + private volatile ByteSizeValue replicationMaxBytesPerSec; private volatile int maxConcurrentFileChunks; private volatile int maxConcurrentOperations; private volatile int maxConcurrentRemoteStoreStreams; - private volatile SimpleRateLimiter rateLimiter; + private volatile SimpleRateLimiter recoveryRateLimiter; + private volatile SimpleRateLimiter replicationRateLimiter; private volatile TimeValue retryDelayStateSync; private volatile TimeValue retryDelayNetwork; private volatile TimeValue activityTimeout; @@ -198,17 +210,20 @@ public RecoverySettings(Settings settings, ClusterSettings clusterSettings) { this.internalActionLongTimeout = INDICES_RECOVERY_INTERNAL_LONG_ACTION_TIMEOUT_SETTING.get(settings); this.activityTimeout = INDICES_RECOVERY_ACTIVITY_TIMEOUT_SETTING.get(settings); - this.maxBytesPerSec = INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING.get(settings); - if (maxBytesPerSec.getBytes() <= 0) { - rateLimiter = null; + this.recoveryMaxBytesPerSec = INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING.get(settings); + if (recoveryMaxBytesPerSec.getBytes() <= 0) { + recoveryRateLimiter = null; } else { - rateLimiter = new SimpleRateLimiter(maxBytesPerSec.getMbFrac()); + recoveryRateLimiter = new SimpleRateLimiter(recoveryMaxBytesPerSec.getMbFrac()); } + this.replicationMaxBytesPerSec = INDICES_REPLICATION_MAX_BYTES_PER_SEC_SETTING.get(settings); + updateReplicationRateLimiter(); - logger.debug("using max_bytes_per_sec[{}]", maxBytesPerSec); + logger.debug("using recovery max_bytes_per_sec[{}]", recoveryMaxBytesPerSec); this.internalRemoteUploadTimeout = INDICES_INTERNAL_REMOTE_UPLOAD_TIMEOUT.get(settings); - clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING, this::setMaxBytesPerSec); + clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING, this::setRecoveryMaxBytesPerSec); + clusterSettings.addSettingsUpdateConsumer(INDICES_REPLICATION_MAX_BYTES_PER_SEC_SETTING, this::setReplicationMaxBytesPerSec); clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_MAX_CONCURRENT_FILE_CHUNKS_SETTING, this::setMaxConcurrentFileChunks); clusterSettings.addSettingsUpdateConsumer(INDICES_RECOVERY_MAX_CONCURRENT_OPERATIONS_SETTING, this::setMaxConcurrentOperations); clusterSettings.addSettingsUpdateConsumer( @@ -227,8 +242,12 @@ public RecoverySettings(Settings settings, ClusterSettings clusterSettings) { } - public RateLimiter rateLimiter() { - return rateLimiter; + public RateLimiter recoveryRateLimiter() { + return recoveryRateLimiter; + } + + public RateLimiter replicationRateLimiter() { + return replicationRateLimiter; } public TimeValue retryDelayNetwork() { @@ -294,14 +313,40 @@ public void setInternalRemoteUploadTimeout(TimeValue internalRemoteUploadTimeout this.internalRemoteUploadTimeout = internalRemoteUploadTimeout; } - private void setMaxBytesPerSec(ByteSizeValue maxBytesPerSec) { - this.maxBytesPerSec = maxBytesPerSec; - if (maxBytesPerSec.getBytes() <= 0) { - rateLimiter = null; - } else if (rateLimiter != null) { - rateLimiter.setMBPerSec(maxBytesPerSec.getMbFrac()); + private void setRecoveryMaxBytesPerSec(ByteSizeValue recoveryMaxBytesPerSec) { + this.recoveryMaxBytesPerSec = recoveryMaxBytesPerSec; + if (recoveryMaxBytesPerSec.getBytes() <= 0) { + recoveryRateLimiter = null; + } else if (recoveryRateLimiter != null) { + recoveryRateLimiter.setMBPerSec(recoveryMaxBytesPerSec.getMbFrac()); } else { - rateLimiter = new SimpleRateLimiter(maxBytesPerSec.getMbFrac()); + recoveryRateLimiter = new SimpleRateLimiter(recoveryMaxBytesPerSec.getMbFrac()); + } + if (replicationMaxBytesPerSec.getBytes() < 0) updateReplicationRateLimiter(); + } + + private void setReplicationMaxBytesPerSec(ByteSizeValue replicationMaxBytesPerSec) { + this.replicationMaxBytesPerSec = replicationMaxBytesPerSec; + updateReplicationRateLimiter(); + } + + private void updateReplicationRateLimiter() { + if (replicationMaxBytesPerSec.getBytes() >= 0) { + if (replicationMaxBytesPerSec.getBytes() == 0) { + replicationRateLimiter = null; + } else if (replicationRateLimiter != null) { + replicationRateLimiter.setMBPerSec(replicationMaxBytesPerSec.getMbFrac()); + } else { + replicationRateLimiter = new SimpleRateLimiter(replicationMaxBytesPerSec.getMbFrac()); + } + } else { // when replicationMaxBytesPerSec = -1B, use setting of recovery + if (recoveryMaxBytesPerSec.getBytes() <= 0) { + replicationRateLimiter = null; + } else if (replicationRateLimiter != null) { + replicationRateLimiter.setMBPerSec(recoveryMaxBytesPerSec.getMbFrac()); + } else { + replicationRateLimiter = new SimpleRateLimiter(recoveryMaxBytesPerSec.getMbFrac()); + } } } diff --git a/server/src/main/java/org/opensearch/indices/recovery/RemoteRecoveryTargetHandler.java b/server/src/main/java/org/opensearch/indices/recovery/RemoteRecoveryTargetHandler.java index 37227596fdfe7..4d4c20f778ef3 100644 --- a/server/src/main/java/org/opensearch/indices/recovery/RemoteRecoveryTargetHandler.java +++ b/server/src/main/java/org/opensearch/indices/recovery/RemoteRecoveryTargetHandler.java @@ -111,7 +111,8 @@ public RemoteRecoveryTargetHandler( shardId, PeerRecoveryTargetService.Actions.FILE_CHUNK, requestSeqNoGenerator, - onSourceThrottle + onSourceThrottle, + recoverySettings::recoveryRateLimiter ); this.remoteStoreEnabled = remoteStoreEnabled; } diff --git a/server/src/main/java/org/opensearch/indices/replication/RemoteSegmentFileChunkWriter.java b/server/src/main/java/org/opensearch/indices/replication/RemoteSegmentFileChunkWriter.java index b52fe66816098..179e497565326 100644 --- a/server/src/main/java/org/opensearch/indices/replication/RemoteSegmentFileChunkWriter.java +++ b/server/src/main/java/org/opensearch/indices/replication/RemoteSegmentFileChunkWriter.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; +import java.util.function.Supplier; /** * This class handles sending file chunks over the transport layer to a target shard. @@ -36,11 +37,11 @@ public final class RemoteSegmentFileChunkWriter implements FileChunkWriter { private final AtomicLong requestSeqNoGenerator; private final RetryableTransportClient retryableTransportClient; private final ShardId shardId; - private final RecoverySettings recoverySettings; private final long replicationId; private final AtomicLong bytesSinceLastPause = new AtomicLong(); private final TransportRequestOptions fileChunkRequestOptions; private final Consumer onSourceThrottle; + private final Supplier rateLimiterSupplier; private final String action; public RemoteSegmentFileChunkWriter( @@ -50,14 +51,15 @@ public RemoteSegmentFileChunkWriter( ShardId shardId, String action, AtomicLong requestSeqNoGenerator, - Consumer onSourceThrottle + Consumer onSourceThrottle, + Supplier rateLimiterSupplier ) { this.replicationId = replicationId; - this.recoverySettings = recoverySettings; this.retryableTransportClient = retryableTransportClient; this.shardId = shardId; this.requestSeqNoGenerator = requestSeqNoGenerator; this.onSourceThrottle = onSourceThrottle; + this.rateLimiterSupplier = rateLimiterSupplier; this.fileChunkRequestOptions = TransportRequestOptions.builder() .withType(TransportRequestOptions.Type.RECOVERY) .withTimeout(recoverySettings.internalActionTimeout()) @@ -78,7 +80,7 @@ public void writeFileChunk( // Pause using the rate limiter, if desired, to throttle the recovery final long throttleTimeInNanos; // always fetch the ratelimiter - it might be updated in real-time on the recovery settings - final RateLimiter rl = recoverySettings.rateLimiter(); + final RateLimiter rl = rateLimiterSupplier.get(); if (rl != null) { long bytes = bytesSinceLastPause.addAndGet(content.length()); if (bytes > rl.getMinPauseCheckBytes()) { diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java index ca89741d5bb55..21fd066b8be2f 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationSourceService.java @@ -123,7 +123,8 @@ public void messageReceived(CheckpointInfoRequest request, TransportChannel chan request.getCheckpoint().getShardId(), SegmentReplicationTargetService.Actions.FILE_CHUNK, new AtomicLong(0), - (throttleTime) -> {} + (throttleTime) -> {}, + recoverySettings::replicationRateLimiter ); final SegmentReplicationSourceHandler handler = ongoingSegmentReplications.prepareForReplication( request, diff --git a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java index 4942d39cfa48a..fbd7ab7cea346 100644 --- a/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java +++ b/server/src/main/java/org/opensearch/indices/replication/SegmentReplicationTargetService.java @@ -635,7 +635,7 @@ public void messageReceived(final FileChunkRequest request, TransportChannel cha try (ReplicationRef ref = onGoingReplications.getSafe(request.recoveryId(), request.shardId())) { final SegmentReplicationTarget target = ref.get(); final ActionListener listener = target.createOrFinishListener(channel, Actions.FILE_CHUNK, request); - target.handleFileChunk(request, target, bytesSinceLastPause, recoverySettings.rateLimiter(), listener); + target.handleFileChunk(request, target, bytesSinceLastPause, recoverySettings.replicationRateLimiter(), listener); } } } diff --git a/server/src/main/java/org/opensearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/opensearch/repositories/blobstore/BlobStoreRepository.java index ce2ffd8bf3fb4..1a5701d9204ef 100644 --- a/server/src/main/java/org/opensearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/opensearch/repositories/blobstore/BlobStoreRepository.java @@ -3162,7 +3162,7 @@ private static OffsetRangeInputStream maybeRateLimitRemoteTransfers( public InputStream maybeRateLimitRestores(InputStream stream) { return maybeRateLimit( maybeRateLimit(stream, () -> restoreRateLimiter, restoreRateLimitingTimeInNanos, BlobStoreTransferContext.SNAPSHOT_RESTORE), - recoverySettings::rateLimiter, + recoverySettings::recoveryRateLimiter, restoreRateLimitingTimeInNanos, BlobStoreTransferContext.SNAPSHOT_RESTORE ); @@ -3185,7 +3185,7 @@ public InputStream maybeRateLimitRemoteDownloadTransfers(InputStream inputStream remoteDownloadRateLimitingTimeInNanos, BlobStoreTransferContext.REMOTE_DOWNLOAD ), - recoverySettings::rateLimiter, + recoverySettings::recoveryRateLimiter, remoteDownloadRateLimitingTimeInNanos, BlobStoreTransferContext.REMOTE_DOWNLOAD ); diff --git a/server/src/test/java/org/opensearch/indices/recovery/RecoverySettingsDynamicUpdateTests.java b/server/src/test/java/org/opensearch/indices/recovery/RecoverySettingsDynamicUpdateTests.java index 75639661f539d..2793d446d66c8 100644 --- a/server/src/test/java/org/opensearch/indices/recovery/RecoverySettingsDynamicUpdateTests.java +++ b/server/src/test/java/org/opensearch/indices/recovery/RecoverySettingsDynamicUpdateTests.java @@ -35,6 +35,8 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.unit.ByteSizeUnit; +import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.test.OpenSearchTestCase; import java.util.concurrent.TimeUnit; @@ -47,7 +49,27 @@ public void testZeroBytesPerSecondIsNoRateLimit() { clusterSettings.applySettings( Settings.builder().put(RecoverySettings.INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING.getKey(), 0).build() ); - assertEquals(null, recoverySettings.rateLimiter()); + assertNull(recoverySettings.recoveryRateLimiter()); + clusterSettings.applySettings( + Settings.builder().put(RecoverySettings.INDICES_REPLICATION_MAX_BYTES_PER_SEC_SETTING.getKey(), 0).build() + ); + assertNull(recoverySettings.replicationRateLimiter()); + } + + public void testSetReplicationMaxBytesPerSec() { + assertEquals(40, (int) recoverySettings.replicationRateLimiter().getMBPerSec()); + clusterSettings.applySettings( + Settings.builder() + .put(RecoverySettings.INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING.getKey(), new ByteSizeValue(60, ByteSizeUnit.MB)) + .build() + ); + assertEquals(60, (int) recoverySettings.replicationRateLimiter().getMBPerSec()); + clusterSettings.applySettings( + Settings.builder() + .put(RecoverySettings.INDICES_REPLICATION_MAX_BYTES_PER_SEC_SETTING.getKey(), new ByteSizeValue(80, ByteSizeUnit.MB)) + .build() + ); + assertEquals(80, (int) recoverySettings.replicationRateLimiter().getMBPerSec()); } public void testRetryDelayStateSync() {