From 791b489eba1e7076f83e261d1bec79314db27e74 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 28 Mar 2023 16:18:15 -0700 Subject: [PATCH] Split CBService into separate classes Splits the CBService into 2 separate classes. One contains the logic for querying information about the circuit breaker (i.e. whether or not it is tripped). The other contains the logic for running the periodic monitoring workflow. Signed-off-by: John Mazanec --- .../java/org/opensearch/knn/bwc/StatsIT.java | 4 +- .../org/opensearch/knn/index/KNNSettings.java | 2 +- .../index/mapper/KNNVectorFieldMapper.java | 45 ++-- .../knn/index/mapper/LegacyFieldMapper.java | 15 +- .../knn/index/mapper/LuceneFieldMapper.java | 6 +- .../knn/index/mapper/MethodFieldMapper.java | 6 +- .../knn/index/mapper/ModelFieldMapper.java | 6 +- .../memory/NativeMemoryCacheManager.java | 20 +- .../breaker/NativeMemoryCircuitBreaker.java | 71 ++++++ .../NativeMemoryCircuitBreakerMonitor.java | 128 ++++++++++ .../NativeMemoryCircuitBreakerMonitorDto.java | 23 ++ .../NativeMemoryCircuitBreakerService.java | 201 ---------------- .../org/opensearch/knn/plugin/KNNPlugin.java | 36 ++- .../opensearch/knn/plugin/stats/KNNStats.java | 16 +- .../NativeMemoryCircuitBreakerSupplier.java | 6 +- .../java/org/opensearch/knn/KNNTestCase.java | 12 +- .../mapper/KNNVectorFieldMapperTests.java | 32 +-- .../memory/NativeMemoryCacheManagerTests.java | 12 +- .../breaker/NativeMemoryCircuitBreakerIT.java | 2 +- ...ativeMemoryCircuitBreakerMonitorTests.java | 164 +++++++++++++ ...ativeMemoryCircuitBreakerServiceTests.java | 220 ------------------ .../NativeMemoryCircuitBreakerTests.java | 58 +++++ 22 files changed, 555 insertions(+), 530 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreaker.java create mode 100644 src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitor.java create mode 100644 src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorDto.java delete mode 100644 src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerService.java create mode 100644 src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorTests.java delete mode 100644 src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerServiceTests.java create mode 100644 src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerTests.java diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java index 7f3840a177..bf6cbcd9cf 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java @@ -9,7 +9,7 @@ import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.plugin.stats.KNNStats; import java.util.Collections; @@ -24,7 +24,7 @@ public class StatsIT extends AbstractRollingUpgradeTestCase { @Before public void setUp() throws Exception { super.setUp(); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = mock(NativeMemoryCircuitBreakerService.class); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreakerService = mock(NativeMemoryCircuitBreaker.class); this.knnStats = new KNNStats(nativeMemoryCircuitBreakerService); } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index f1dac09a4a..5aeb2e2ab8 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -383,7 +383,7 @@ public static ByteSizeValue parseknnMemoryCircuitBreakerValue(String sValue, Str */ public synchronized void updateBooleanSetting(String settingName, boolean value) { ClusterUpdateSettingsRequest clusterUpdateSettingsRequest = new ClusterUpdateSettingsRequest(); - Settings circuitBreakerSettings = Settings.builder().put("unregistered-setting-lets-see-what-happens", value).build(); + Settings circuitBreakerSettings = Settings.builder().put(settingName, value).build(); clusterUpdateSettingsRequest.persistentSettings(circuitBreakerSettings); client.admin().cluster().updateSettings(clusterUpdateSettingsRequest, new ActionListener<>() { @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 9c89f209ca..829624bad5 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -32,7 +32,7 @@ import org.opensearch.index.mapper.ValueFetcher; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KNNVectorIndexFieldData; @@ -145,12 +145,12 @@ public static class Builder extends ParametrizedFieldMapper.Builder { protected String efConstruction; protected ModelDao modelDao; - protected NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; + protected NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; - public Builder(String name, ModelDao modelDao, NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService) { + public Builder(String name, ModelDao modelDao, NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker) { super(name); this.modelDao = modelDao; - this.nativeMemoryCircuitBreakerService = nativeMemoryCircuitBreakerService; + this.nativeMemoryCircuitBreaker = nativeMemoryCircuitBreaker; } /** @@ -167,13 +167,13 @@ public Builder( String spaceType, String m, String efConstruction, - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker ) { super(name); this.spaceType = spaceType; this.m = m; this.efConstruction = efConstruction; - this.nativeMemoryCircuitBreakerService = nativeMemoryCircuitBreakerService; + this.nativeMemoryCircuitBreaker = nativeMemoryCircuitBreaker; } @Override @@ -227,7 +227,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { .stored(stored.get()) .hasDocValues(hasDocValues.get()) .knnMethodContext(knnMethodContext) - .nativeMemoryCircuitBreakerService(nativeMemoryCircuitBreakerService) + .nativeMemoryCircuitBreakerService(nativeMemoryCircuitBreaker) .build(); return new LuceneFieldMapper(createLuceneFieldMapperInput); } @@ -239,7 +239,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.get(), hasDocValues.get(), - nativeMemoryCircuitBreakerService, + nativeMemoryCircuitBreaker, knnMethodContext ); } @@ -259,7 +259,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.get(), hasDocValues.get(), - nativeMemoryCircuitBreakerService, + nativeMemoryCircuitBreaker, modelDao, modelIdAsString ); @@ -286,7 +286,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.get(), hasDocValues.get(), - nativeMemoryCircuitBreakerService, + nativeMemoryCircuitBreaker, spaceType, m, efConstruction @@ -318,23 +318,16 @@ public static class TypeParser implements Mapper.TypeParser { // Use a supplier here because in {@link org.opensearch.knn.KNNPlugin#getMappers()} the ModelDao has not yet // been initialized private final Supplier modelDaoSupplier; - private final Supplier nativeMemoryCircuitBreakerServiceSupplier; + private final Supplier nativeMemoryCircuitBreakerSupplier; - public TypeParser( - Supplier modelDaoSupplier, - Supplier knnCircuitBreakerServiceSupplier - ) { + public TypeParser(Supplier modelDaoSupplier, Supplier knnCircuitBreakerServiceSupplier) { this.modelDaoSupplier = modelDaoSupplier; - this.nativeMemoryCircuitBreakerServiceSupplier = knnCircuitBreakerServiceSupplier; + this.nativeMemoryCircuitBreakerSupplier = knnCircuitBreakerServiceSupplier; } @Override public Mapper.Builder parse(String name, Map node, ParserContext parserContext) throws MapperParsingException { - Builder builder = new KNNVectorFieldMapper.Builder( - name, - modelDaoSupplier.get(), - nativeMemoryCircuitBreakerServiceSupplier.get() - ); + Builder builder = new KNNVectorFieldMapper.Builder(name, modelDaoSupplier.get(), nativeMemoryCircuitBreakerSupplier.get()); builder.parse(name, parserContext, node); // All ignoreMalformed, boolean stored, boolean hasDocValues, - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker ) { super(simpleName, mappedFieldType, multiFields, copyTo); this.ignoreMalformed = ignoreMalformed; this.stored = stored; this.hasDocValues = hasDocValues; this.dimension = mappedFieldType.getDimension(); - this.nativeMemoryCircuitBreakerService = nativeMemoryCircuitBreakerService; + this.nativeMemoryCircuitBreaker = nativeMemoryCircuitBreaker; updateEngineStats(); } @@ -472,7 +465,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx } void validateIfCircuitBreakerIsNotTriggered() { - if (nativeMemoryCircuitBreakerService.isCircuitBreakerTriggered()) { + if (nativeMemoryCircuitBreaker.isTripped()) { throw new IllegalStateException( "Indexing knn vector fields is rejected as circuit breaker triggered. Check _opendistro/_knn/stats for detailed state" ); @@ -545,7 +538,7 @@ protected boolean docValuesByDefault() { @Override public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, nativeMemoryCircuitBreakerService).init(this); + return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, nativeMemoryCircuitBreaker).init(this); } @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java index 1490dfc16b..a88a6f4115 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java @@ -10,7 +10,7 @@ import org.opensearch.common.Explicit; import org.opensearch.common.settings.Settings; import org.opensearch.index.mapper.ParametrizedFieldMapper; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.util.KNNEngine; @@ -46,12 +46,12 @@ public class LegacyFieldMapper extends KNNVectorFieldMapper { Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService, + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker, String spaceType, String m, String efConstruction ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, nativeMemoryCircuitBreakerService); + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, nativeMemoryCircuitBreaker); this.spaceType = spaceType; this.m = m; @@ -72,13 +72,8 @@ public class LegacyFieldMapper extends KNNVectorFieldMapper { @Override public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new KNNVectorFieldMapper.Builder( - simpleName(), - this.spaceType, - this.m, - this.efConstruction, - this.nativeMemoryCircuitBreakerService - ).init(this); + return new KNNVectorFieldMapper.Builder(simpleName(), this.spaceType, this.m, this.efConstruction, this.nativeMemoryCircuitBreaker) + .init(this); } static String getSpaceType(Settings indexSettings) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 55d97db2a9..6c5b5e1b01 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -15,7 +15,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.util.KNNEngine; @@ -45,7 +45,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { input.getIgnoreMalformed(), input.isStored(), input.isHasDocValues(), - input.getNativeMemoryCircuitBreakerService() + input.getNativeMemoryCircuitBreaker() ); this.knnMethod = input.getKnnMethodContext(); @@ -131,6 +131,6 @@ static class CreateLuceneFieldMapperInput { @NonNull KNNMethodContext knnMethodContext; @NonNull - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index a896b2f93f..3d641b45a9 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -9,7 +9,7 @@ import org.opensearch.common.Explicit; import org.opensearch.common.Strings; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.util.KNNEngine; @@ -33,11 +33,11 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService, + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker, KNNMethodContext knnMethodContext ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, nativeMemoryCircuitBreakerService); + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, nativeMemoryCircuitBreaker); this.knnMethod = knnMethodContext; diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 1a865d6e3d..1d90269b09 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -8,7 +8,7 @@ import org.apache.lucene.document.FieldType; import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; @@ -29,11 +29,11 @@ public class ModelFieldMapper extends KNNVectorFieldMapper { Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService, + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker, ModelDao modelDao, String modelId ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, nativeMemoryCircuitBreakerService); + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, nativeMemoryCircuitBreaker); this.modelId = modelId; this.modelDao = modelDao; diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java index 03cba84b81..1e2b21606a 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java @@ -22,7 +22,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.knn.common.exception.OutOfNativeMemoryException; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.plugin.stats.StatNames; import java.io.Closeable; @@ -43,7 +43,7 @@ public class NativeMemoryCacheManager implements Closeable { private static final Logger logger = LogManager.getLogger(NativeMemoryCacheManager.class); private static NativeMemoryCacheManager INSTANCE; - private static NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; + private static NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; private Cache cache; private final ExecutorService executor; private AtomicBoolean cacheCapacityReached; @@ -71,8 +71,8 @@ public static synchronized NativeMemoryCacheManager getInstance() { private void initialize() { initialize( NativeMemoryCacheManagerDto.builder() - .isWeightLimited(nativeMemoryCircuitBreakerService.isCircuitBreakerEnabled()) - .maxWeight(nativeMemoryCircuitBreakerService.getCircuitBreakerLimit().getKb()) + .isWeightLimited(nativeMemoryCircuitBreaker.isEnabled()) + .maxWeight(nativeMemoryCircuitBreaker.getLimit().getKb()) .isExpirationLimited(KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED)) .expiryTimeInMin( ((TimeValue) KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES)).getMinutes() @@ -101,8 +101,8 @@ private void initialize(NativeMemoryCacheManagerDto nativeMemoryCacheDTO) { cache = cacheBuilder.build(); } - public static void initialize(NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService) { - NativeMemoryCacheManager.nativeMemoryCircuitBreakerService = nativeMemoryCircuitBreakerService; + public static void initialize(NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker) { + NativeMemoryCacheManager.nativeMemoryCircuitBreaker = nativeMemoryCircuitBreaker; } /** @@ -111,8 +111,8 @@ public static void initialize(NativeMemoryCircuitBreakerService nativeMemoryCirc public synchronized void rebuildCache() { rebuildCache( NativeMemoryCacheManagerDto.builder() - .isWeightLimited(nativeMemoryCircuitBreakerService.isCircuitBreakerEnabled()) - .maxWeight(nativeMemoryCircuitBreakerService.getCircuitBreakerLimit().getKb()) + .isWeightLimited(nativeMemoryCircuitBreaker.isEnabled()) + .maxWeight(nativeMemoryCircuitBreaker.getLimit().getKb()) .isExpirationLimited(KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED)) .expiryTimeInMin( ((TimeValue) KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES)).getMinutes() @@ -372,7 +372,7 @@ private void onRemoval(RemovalNotification remov nativeMemoryAllocation.close(); if (RemovalCause.SIZE == removalNotification.getCause()) { - nativeMemoryCircuitBreakerService.setCircuitBreaker(true); + nativeMemoryCircuitBreaker.set(true); setCacheCapacityReached(true); } @@ -380,7 +380,7 @@ private void onRemoval(RemovalNotification remov } private Float getSizeAsPercentage(long size) { - long cbLimit = nativeMemoryCircuitBreakerService.getCircuitBreakerLimit().getKb(); + long cbLimit = nativeMemoryCircuitBreaker.getLimit().getKb(); if (cbLimit == 0) { return 0.0F; } diff --git a/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreaker.java b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreaker.java new file mode 100644 index 0000000000..fa40e2dffe --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreaker.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory.breaker; + +import com.google.common.annotations.VisibleForTesting; +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.opensearch.common.unit.ByteSizeValue; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; + +/** + * The circuit breaker gets tripped based on memory demand tracked by the {@link NativeMemoryCacheManager}. + * When {@link NativeMemoryCacheManager}'s cache fills up, if the circuit breaking logic is enabled, it will trip the + * circuit breaker. Elsewhere in the code, the circuit breaker's value can be queried to prevent actions that should + * not happen during high memory pressure. + */ +@AllArgsConstructor +@Log4j2 +public class NativeMemoryCircuitBreaker { + private final KNNSettings knnSettings; + + /** + * Checks if the circuit breaker is triggered + * + * @return true if circuit breaker is triggered; false otherwise + */ + public boolean isTripped() { + return knnSettings.getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED); + } + + /** + * Sets circuit breaker to new value + * + * @param circuitBreaker value to update circuit breaker to + */ + public void set(boolean circuitBreaker) { + knnSettings.updateBooleanSetting(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED, circuitBreaker); + } + + /** + * Gets the limit of the circuit breaker + * + * @return limit as ByteSizeValue of native memory circuit breaker + */ + public ByteSizeValue getLimit() { + return knnSettings.getSettingValue(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT); + } + + /** + * Determine if the circuit breaker is enabled + * + * @return true if circuit breaker is enabled. False otherwise. + */ + public boolean isEnabled() { + return knnSettings.getSettingValue(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED); + } + + /** + * Returns the percentage as a double for when to unset the circuit breaker + * + * @return percentage as double for unsetting circuit breaker + */ + @VisibleForTesting + double getUnsetPercentage() { + return knnSettings.getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE); + } +} diff --git a/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitor.java b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitor.java new file mode 100644 index 0000000000..8003c84dea --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitor.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory.breaker; + +import com.google.common.annotations.VisibleForTesting; +import lombok.extern.log4j.Log4j2; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; +import org.opensearch.knn.plugin.stats.StatNames; +import org.opensearch.knn.plugin.transport.KNNStatsAction; +import org.opensearch.knn.plugin.transport.KNNStatsNodeResponse; +import org.opensearch.knn.plugin.transport.KNNStatsRequest; +import org.opensearch.knn.plugin.transport.KNNStatsResponse; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.ThreadPool; + +import java.io.Closeable; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Job that runs periodically to monitor native memory usage on the node/in the cluster and untrip the + * NativeMemoryCircuitBreaker if necessary. + */ +@Log4j2 +public class NativeMemoryCircuitBreakerMonitor implements Closeable { + private final NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; + private final NativeMemoryCacheManager nativeMemoryCacheManager; + private final ClusterService clusterService; + private final Client client; + private final ThreadPool threadPool; + private final AtomicReference isStarted; + private Scheduler.Cancellable monitorFuture; + public static final int CB_TIME_INTERVAL = 2 * 60; // seconds + private static final TimeValue STATS_REQUEST_TIMEOUT = new TimeValue(1000 * 10); // 10 second timeout + + /** + * Constructor + * + * @param nativeMemoryCircuitBreakerMonitorDto contains necessary initialization values + */ + public NativeMemoryCircuitBreakerMonitor(NativeMemoryCircuitBreakerMonitorDto nativeMemoryCircuitBreakerMonitorDto) { + this.nativeMemoryCircuitBreaker = nativeMemoryCircuitBreakerMonitorDto.getNativeMemoryCircuitBreaker(); + this.nativeMemoryCacheManager = nativeMemoryCircuitBreakerMonitorDto.getNativeMemoryCacheManager(); + this.clusterService = nativeMemoryCircuitBreakerMonitorDto.getClusterService(); + this.client = nativeMemoryCircuitBreakerMonitorDto.getClient(); + this.threadPool = nativeMemoryCircuitBreakerMonitorDto.getThreadPool(); + this.isStarted = new AtomicReference<>(false); + this.monitorFuture = null; + } + + /** + * Schedules monitor job to be run + */ + public synchronized void start() { + // Ensure monitor future is only scheduled once + boolean isAlreadyStarted = this.isStarted.getAndSet(true); + if (isAlreadyStarted == false) { + this.monitorFuture = threadPool.scheduleWithFixedDelay( + this::monitor, + TimeValue.timeValueSeconds(CB_TIME_INTERVAL), + ThreadPool.Names.GENERIC + ); + } + } + + @Override + public synchronized void close() { + if (this.monitorFuture != null && this.monitorFuture.isCancelled() == false) { + this.monitorFuture.cancel(); + } + } + + @VisibleForTesting + void monitor() { + if (nativeMemoryCacheManager.isCacheCapacityReached() && clusterService.localNode().isDataNode()) { + long currentSizeKiloBytes = nativeMemoryCacheManager.getCacheSizeInKilobytes(); + long circuitBreakerLimitSizeKiloBytes = nativeMemoryCircuitBreaker.getLimit().getKb(); + long circuitBreakerUnsetSizeKiloBytes = (long) ((nativeMemoryCircuitBreaker.getUnsetPercentage() / 100) + * circuitBreakerLimitSizeKiloBytes); + // Unset capacityReached flag if currentSizeBytes is less than circuitBreakerUnsetSizeBytes + if (currentSizeKiloBytes <= circuitBreakerUnsetSizeKiloBytes) { + nativeMemoryCacheManager.setCacheCapacityReached(false); + } + } + + // Leader node untriggers CB if all nodes have not reached their max capacity + if (nativeMemoryCircuitBreaker.isTripped() && clusterService.state().nodes().isLocalNodeElectedClusterManager()) { + KNNStatsRequest knnStatsRequest = new KNNStatsRequest(); + knnStatsRequest.addStat(StatNames.CACHE_CAPACITY_REACHED.getName()); + knnStatsRequest.timeout(STATS_REQUEST_TIMEOUT); + + try { + KNNStatsResponse knnStatsResponse = client.execute(KNNStatsAction.INSTANCE, knnStatsRequest).get(); + List nodeResponses = knnStatsResponse.getNodes(); + + List nodesAtMaxCapacity = new ArrayList<>(); + for (KNNStatsNodeResponse nodeResponse : nodeResponses) { + if ((Boolean) nodeResponse.getStatsMap().get(StatNames.CACHE_CAPACITY_REACHED.getName())) { + nodesAtMaxCapacity.add(nodeResponse.getNode().getId()); + } + } + + if (nodesAtMaxCapacity.isEmpty() == false) { + log.info( + "[KNN] knn.circuit_breaker.triggered stays set. Nodes at max cache capacity: " + + String.join(",", nodesAtMaxCapacity) + + "." + ); + } else { + log.info( + "[KNN] Cache capacity below {}% of the circuit breaker limit for all nodes. Unsetting knn.circuit_breaker.triggered flag.", + nativeMemoryCircuitBreaker.getUnsetPercentage() + ); + nativeMemoryCircuitBreaker.set(false); + } + } catch (Exception e) { + log.error("[KNN] Error when trying to update the circuit breaker setting", e); + } + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorDto.java b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorDto.java new file mode 100644 index 0000000000..59d46c9fb3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorDto.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory.breaker; + +import lombok.Builder; +import lombok.Value; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; +import org.opensearch.threadpool.ThreadPool; + +@Value +@Builder +public class NativeMemoryCircuitBreakerMonitorDto { + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; + NativeMemoryCacheManager nativeMemoryCacheManager; + ClusterService clusterService; + Client client; + ThreadPool threadPool; +} diff --git a/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerService.java b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerService.java deleted file mode 100644 index 1fa3af0d2a..0000000000 --- a/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerService.java +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.memory.breaker; - -import com.google.common.annotations.VisibleForTesting; -import lombok.Value; -import lombok.extern.log4j.Log4j2; -import org.opensearch.common.component.AbstractLifecycleComponent; -import org.opensearch.common.unit.ByteSizeValue; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.plugin.stats.StatNames; -import org.opensearch.knn.plugin.transport.KNNStatsAction; -import org.opensearch.knn.plugin.transport.KNNStatsNodeResponse; -import org.opensearch.knn.plugin.transport.KNNStatsRequest; -import org.opensearch.knn.plugin.transport.KNNStatsResponse; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.threadpool.Scheduler; -import org.opensearch.threadpool.ThreadPool; - -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicReference; - -/** - * Service handling native memory circuit breaking logic. The circuit breaker gets tripped based on memory demand - * tracked by the {@link NativeMemoryCacheManager}. When {@link NativeMemoryCacheManager}'s cache fills up, if the - * circuit breaking logic is enabled, it will trip the circuit. Elsewhere in the code, the circuit breaker's value can - * be queried to prevent actions that should not happen during high memory pressure. - */ -@Log4j2 -public class NativeMemoryCircuitBreakerService extends AbstractLifecycleComponent { - private final ThreadPool threadPool; - private final KNNSettings knnSettings; - private final ClusterService clusterService; - private final Client client; - // Cancellable task to track circuitBreakerRunnable. In order to schedule, doStart must be called. doStart will - // only start the future if the previous value is null. Therefore, to close this class, do NOT set the value of - // this variable to null. To stop this class, this variable should be set to null so that it may be restarted. - @VisibleForTesting - final AtomicReference circuitBreakerFuture; - - public static final int CB_TIME_INTERVAL = 2 * 60; // seconds - - /** - * Constructor for creation of circuit breaker service for KNN - * - * @param knnSettings Settings class for k-NN - * @param threadPool thread pool for circuit breaker monitor to run job - * @param clusterService cluster service used to retrieve information about the cluster - * @param client client used to make calls to the cluster - */ - public NativeMemoryCircuitBreakerService(KNNSettings knnSettings, ThreadPool threadPool, ClusterService clusterService, Client client) { - this.knnSettings = knnSettings; - this.threadPool = threadPool; - this.clusterService = clusterService; - this.client = client; - this.circuitBreakerFuture = new AtomicReference<>(null); - } - - /** - * Checks if the circuit breaker is triggered - * - * @return true if circuit breaker is triggered; false otherwise - */ - public boolean isCircuitBreakerTriggered() { - return knnSettings.getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED); - } - - /** - * Sets circuit breaker to new value - * - * @param circuitBreaker value to update circuit breaker to - */ - public void setCircuitBreaker(boolean circuitBreaker) { - knnSettings.updateBooleanSetting(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED, circuitBreaker); - } - - /** - * Gets the limit of the circuit breaker - * - * @return limit as ByteSizeValue of native memory circuit breaker - */ - public ByteSizeValue getCircuitBreakerLimit() { - return knnSettings.getSettingValue(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT); - } - - /** - * Determine if the circuit breaker is enabled - * - * @return true if circuit breaker is enabled. False otherwise. - */ - public boolean isCircuitBreakerEnabled() { - return knnSettings.getSettingValue(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED); - } - - /** - * Returns the percentage as a double for when to unset the circuit breaker - * - * @return percentage as double for unsetting circuit breaker - */ - @VisibleForTesting - double getCircuitBreakerUnsetPercentage() { - return knnSettings.getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE); - } - - @Override - protected void doStart() { - Monitor monitor = getMonitor(); - this.circuitBreakerFuture.compareAndSet( - null, - threadPool.scheduleWithFixedDelay(monitor, TimeValue.timeValueSeconds(CB_TIME_INTERVAL), ThreadPool.Names.GENERIC) - ); - } - - @VisibleForTesting - Monitor getMonitor() { - return new Monitor(this, NativeMemoryCacheManager.getInstance(), clusterService, client); - } - - @Override - protected void doStop() { - Scheduler.Cancellable cancellable = this.circuitBreakerFuture.getAndSet(null); - if (cancellable != null && !cancellable.isCancelled()) { - cancellable.cancel(); - } - } - - @Override - protected void doClose() { - Scheduler.Cancellable cancellable = this.circuitBreakerFuture.get(); - if (cancellable != null && !cancellable.isCancelled()) { - cancellable.cancel(); - } - } - - @VisibleForTesting - @Value - static class Monitor implements Runnable { - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; - NativeMemoryCacheManager nativeMemoryCacheManager; - ClusterService clusterService; - Client client; - private static final TimeValue STATS_REQUEST_TIMEOUT = new TimeValue(1000 * 10); // 10 second timeout - - @Override - public void run() { - if (nativeMemoryCacheManager.isCacheCapacityReached() && clusterService.localNode().isDataNode()) { - long currentSizeKiloBytes = nativeMemoryCacheManager.getCacheSizeInKilobytes(); - long circuitBreakerLimitSizeKiloBytes = nativeMemoryCircuitBreakerService.getCircuitBreakerLimit().getKb(); - long circuitBreakerUnsetSizeKiloBytes = (long) ((nativeMemoryCircuitBreakerService.getCircuitBreakerUnsetPercentage() / 100) - * circuitBreakerLimitSizeKiloBytes); - // Unset capacityReached flag if currentSizeBytes is less than circuitBreakerUnsetSizeBytes - if (currentSizeKiloBytes <= circuitBreakerUnsetSizeKiloBytes) { - nativeMemoryCacheManager.setCacheCapacityReached(false); - } - } - - // Leader node untriggers CB if all nodes have not reached their max capacity - if (nativeMemoryCircuitBreakerService.isCircuitBreakerTriggered() - && clusterService.state().nodes().isLocalNodeElectedClusterManager()) { - KNNStatsRequest knnStatsRequest = new KNNStatsRequest(); - knnStatsRequest.addStat(StatNames.CACHE_CAPACITY_REACHED.getName()); - knnStatsRequest.timeout(STATS_REQUEST_TIMEOUT); - - try { - KNNStatsResponse knnStatsResponse = client.execute(KNNStatsAction.INSTANCE, knnStatsRequest).get(); - List nodeResponses = knnStatsResponse.getNodes(); - - List nodesAtMaxCapacity = new ArrayList<>(); - for (KNNStatsNodeResponse nodeResponse : nodeResponses) { - if ((Boolean) nodeResponse.getStatsMap().get(StatNames.CACHE_CAPACITY_REACHED.getName())) { - nodesAtMaxCapacity.add(nodeResponse.getNode().getId()); - } - } - - if (nodesAtMaxCapacity.isEmpty() == false) { - log.info( - "[KNN] knn.circuit_breaker.triggered stays set. Nodes at max cache capacity: " - + String.join(",", nodesAtMaxCapacity) - + "." - ); - } else { - log.info( - "[KNN] Cache capacity below {}% of the circuit breaker limit for all nodes. Unsetting knn.circuit_breaker.triggered flag.", - nativeMemoryCircuitBreakerService.getCircuitBreakerUnsetPercentage() - ); - nativeMemoryCircuitBreakerService.setCircuitBreaker(false); - } - } catch (Exception e) { - log.error("[KNN] Error when trying to update the circuit breaker setting", e); - } - } - } - } -} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index d1dbf0185b..1f6cad6ca1 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -12,8 +12,10 @@ import org.opensearch.index.engine.EngineFactory; import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.KNNClusterUtil; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerMonitor; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerMonitorDto; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; @@ -158,13 +160,14 @@ public class KNNPlugin extends Plugin private KNNStats knnStats; private ClusterService clusterService; - private NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; + private NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; + private NativeMemoryCircuitBreakerMonitor nativeMemoryCircuitBreakerMonitor; @Override public Map getMappers() { return Collections.singletonMap( KNNVectorFieldMapper.CONTENT_TYPE, - new KNNVectorFieldMapper.TypeParser(ModelDao.OpenSearchKNNModelDao::getInstance, () -> nativeMemoryCircuitBreakerService) + new KNNVectorFieldMapper.TypeParser(ModelDao.OpenSearchKNNModelDao::getInstance, () -> nativeMemoryCircuitBreaker) ); } @@ -202,13 +205,21 @@ public Collection createComponents( KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); - nativeMemoryCircuitBreakerService = new NativeMemoryCircuitBreakerService(KNNSettings.state(), threadPool, clusterService, client); - NativeMemoryCacheManager.initialize(nativeMemoryCircuitBreakerService); - // Called after NativeMemoryCacheManager initialization. Start method requires access to an instance of the - // NativeMemoryCacheManager that needs to be initialized. - nativeMemoryCircuitBreakerService.start(); - knnStats = new KNNStats(nativeMemoryCircuitBreakerService); - return ImmutableList.of(knnStats, nativeMemoryCircuitBreakerService); + + nativeMemoryCircuitBreaker = new NativeMemoryCircuitBreaker(KNNSettings.state()); + NativeMemoryCacheManager.initialize(nativeMemoryCircuitBreaker); + nativeMemoryCircuitBreakerMonitor = new NativeMemoryCircuitBreakerMonitor( + NativeMemoryCircuitBreakerMonitorDto.builder() + .nativeMemoryCacheManager(NativeMemoryCacheManager.getInstance()) + .nativeMemoryCircuitBreaker(nativeMemoryCircuitBreaker) + .threadPool(threadPool) + .client(client) + .clusterService(clusterService) + .build() + ); + nativeMemoryCircuitBreakerMonitor.start(); + knnStats = new KNNStats(nativeMemoryCircuitBreaker); + return ImmutableList.of(knnStats, nativeMemoryCircuitBreaker); } @Override @@ -367,4 +378,9 @@ public Settings additionalSettings() { ).collect(Collectors.toList()); return Settings.builder().putList(IndexModule.INDEX_STORE_HYBRID_MMAP_EXTENSIONS.getKey(), combinedSettings).build(); } + + @Override + public void close() { + nativeMemoryCircuitBreakerMonitor.close(); + } } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java index 3c64c95019..461263c048 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java @@ -9,7 +9,7 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; @@ -36,10 +36,10 @@ public class KNNStats { /** * Constructor * - * @param nativeMemoryCircuitBreakerService k-NN circuit breaker service for native memory + * @param nativeMemoryCircuitBreaker k-NN circuit breaker service for native memory */ - public KNNStats(NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService) { - this.knnStats = buildStatsMap(nativeMemoryCircuitBreakerService); + public KNNStats(NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker) { + this.knnStats = buildStatsMap(nativeMemoryCircuitBreaker); } /** @@ -80,10 +80,10 @@ private Map> getClusterOrNodeStats(Boolean getClusterStats) { return statsMap; } - private Map> buildStatsMap(NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService) { + private Map> buildStatsMap(NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker) { ImmutableMap.Builder> builder = ImmutableMap.>builder(); addQueryStats(builder); - addNativeMemoryStats(builder, nativeMemoryCircuitBreakerService); + addNativeMemoryStats(builder, nativeMemoryCircuitBreaker); addEngineStats(builder); addScriptStats(builder); addModelStats(builder); @@ -101,7 +101,7 @@ private void addQueryStats(ImmutableMap.Builder> builder) { private void addNativeMemoryStats( ImmutableMap.Builder> builder, - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker ) { builder.put(StatNames.HIT_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::hitCount))) .put(StatNames.MISS_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::missCount))) @@ -134,7 +134,7 @@ private void addNativeMemoryStats( .put(StatNames.GRAPH_INDEX_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_INDEX_REQUESTS))) .put( StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), - new KNNStat<>(true, new NativeMemoryCircuitBreakerSupplier(nativeMemoryCircuitBreakerService)) + new KNNStat<>(true, new NativeMemoryCircuitBreakerSupplier(nativeMemoryCircuitBreaker)) ); } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/NativeMemoryCircuitBreakerSupplier.java b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/NativeMemoryCircuitBreakerSupplier.java index 76b8d785d0..cb3e73a627 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/NativeMemoryCircuitBreakerSupplier.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/NativeMemoryCircuitBreakerSupplier.java @@ -6,7 +6,7 @@ package org.opensearch.knn.plugin.stats.suppliers; import lombok.AllArgsConstructor; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import java.util.function.Supplier; @@ -16,10 +16,10 @@ @AllArgsConstructor public class NativeMemoryCircuitBreakerSupplier implements Supplier { - private final NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; + private final NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; @Override public Boolean get() { - return nativeMemoryCircuitBreakerService.isCircuitBreakerTriggered(); + return nativeMemoryCircuitBreaker.isTripped(); } } diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java index 9532f5965c..01a257208f 100644 --- a/src/test/java/org/opensearch/knn/KNNTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNTestCase.java @@ -19,7 +19,7 @@ import org.opensearch.common.unit.ByteSizeValue; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; @@ -40,10 +40,10 @@ */ public class KNNTestCase extends OpenSearchTestCase { - protected static final NativeMemoryCircuitBreakerService NEVER_TRIGGERED_CB_SERVICE = mock(NativeMemoryCircuitBreakerService.class); + protected static final NativeMemoryCircuitBreaker NEVER_TRIGGERED_CB = mock(NativeMemoryCircuitBreaker.class); static { - when(NEVER_TRIGGERED_CB_SERVICE.isCircuitBreakerTriggered()).thenReturn(false); - when(NEVER_TRIGGERED_CB_SERVICE.getCircuitBreakerLimit()).thenReturn(new ByteSizeValue(100, ByteSizeUnit.KB)); + when(NEVER_TRIGGERED_CB.isTripped()).thenReturn(false); + when(NEVER_TRIGGERED_CB.getLimit()).thenReturn(new ByteSizeValue(100, ByteSizeUnit.KB)); } @Mock @@ -57,7 +57,7 @@ public class KNNTestCase extends OpenSearchTestCase { @Mock protected NativeMemoryCacheManager nativeMemoryCacheManager; @Mock - protected NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; + protected NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; @Mock protected DiscoveryNode node; @Mock @@ -97,7 +97,7 @@ public void resetState() { KNNSettings.state().setClusterService(clusterService); // Clean up the cache - NativeMemoryCacheManager.initialize(NEVER_TRIGGERED_CB_SERVICE); + NativeMemoryCacheManager.initialize(NEVER_TRIGGERED_CB); NativeMemoryCacheManager.getInstance().invalidateAll(); NativeMemoryCacheManager.getInstance().close(); } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index ce2cfdb2c0..281979f446 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -30,7 +30,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; @@ -82,15 +82,15 @@ public class KNNVectorFieldMapperTests extends KNNTestCase { public void testBuilder_getParameters() { String fieldName = "test-field-name"; ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, NEVER_TRIGGERED_CB); assertEquals(6, builder.getParameters().size()); } public void testBuilder_build_fromKnnMethodContext() { // Check that knnMethodContext takes precedent over both model and legacy ModelDao modelDao = mock(ModelDao.class); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = mock(NativeMemoryCircuitBreakerService.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, NEVER_TRIGGERED_CB_SERVICE); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreakerService = mock(NativeMemoryCircuitBreaker.class); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, NEVER_TRIGGERED_CB); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -127,7 +127,7 @@ public void testBuilder_build_fromKnnMethodContext() { public void testBuilder_build_fromModel() { // Check that modelContext takes precedent over legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, NEVER_TRIGGERED_CB); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -164,7 +164,7 @@ public void testBuilder_build_fromModel() { public void testBuilder_build_fromLegacy() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, NEVER_TRIGGERED_CB); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -193,7 +193,7 @@ public void testBuilder_parse_fromKnnMethodContext_luceneEngine() throws IOExcep Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); KNNEngine.LUCENE.setInitialized(false); @@ -279,7 +279,7 @@ public void testTypeParser_parse_fromKnnMethodContext_invalidDimension() throws Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); int efConstruction = 321; @@ -343,7 +343,7 @@ public void testTypeParser_parse_fromKnnMethodContext_invalidSpaceType() throws Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); int efConstruction = 321; int dimension = 133; @@ -398,7 +398,7 @@ public void testTypeParser_parse_fromKnnMethodContext() throws IOException { Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); int efConstruction = 321; int dimension = 133; @@ -495,7 +495,7 @@ public void testTypeParser_parse_fromModel() throws IOException { Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); String modelId = "test-id"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() @@ -529,7 +529,7 @@ public void testTypeParser_parse_fromLegacy() throws IOException { .build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); int dimension = 122; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() @@ -555,7 +555,7 @@ public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOExcep Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); int dimension = 133; int efConstruction = 321; @@ -631,7 +631,7 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser( () -> mockModelDao, - () -> NEVER_TRIGGERED_CB_SERVICE + () -> NEVER_TRIGGERED_CB ); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() @@ -701,7 +701,7 @@ public void testLuceneFieldMapper_parseCreateField_docValues() throws IOExceptio .copyTo(FieldMapper.CopyTo.empty()) .hasDocValues(true) .ignoreMalformed(new Explicit<>(true, true)) - .nativeMemoryCircuitBreakerService(NEVER_TRIGGERED_CB_SERVICE) + .nativeMemoryCircuitBreakerService(NEVER_TRIGGERED_CB) .knnMethodContext(knnMethodContext); ParseContext.Document document = new ParseContext.Document(); @@ -796,7 +796,7 @@ public Mapper.TypeParser.ParserContext buildParserContext(String indexName, Sett return new Mapper.TypeParser.ParserContext( null, mapperService, - type -> new KNNVectorFieldMapper.TypeParser(() -> mockModelDao, () -> NEVER_TRIGGERED_CB_SERVICE), + type -> new KNNVectorFieldMapper.TypeParser(() -> mockModelDao, () -> NEVER_TRIGGERED_CB), CURRENT, null, null, diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java index 7ba3fc429f..ee4af5ab32 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java @@ -18,7 +18,7 @@ import org.opensearch.common.unit.ByteSizeValue; import org.opensearch.knn.common.exception.OutOfNativeMemoryException; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; @@ -29,7 +29,6 @@ import java.util.Map; import java.util.concurrent.ExecutionException; -import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.GRAPH_COUNT; @@ -284,11 +283,10 @@ public void testGetIndexGraphCount() throws ExecutionException, IOException { public void testGetMaxCacheSizeInKB() { long cbLimitInKB = 100; ByteSizeValue defaultCBLimit = new ByteSizeValue(cbLimitInKB, ByteSizeUnit.KB); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = mock(NativeMemoryCircuitBreakerService.class); - when(nativeMemoryCircuitBreakerService.isCircuitBreakerEnabled()).thenReturn(true); - when(nativeMemoryCircuitBreakerService.getCircuitBreakerLimit()).thenReturn(defaultCBLimit); - doNothing().when(nativeMemoryCircuitBreakerService).close(); - NativeMemoryCacheManager.initialize(nativeMemoryCircuitBreakerService); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker = mock(NativeMemoryCircuitBreaker.class); + when(nativeMemoryCircuitBreaker.isEnabled()).thenReturn(true); + when(nativeMemoryCircuitBreaker.getLimit()).thenReturn(defaultCBLimit); + NativeMemoryCacheManager.initialize(nativeMemoryCircuitBreaker); NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); assertEquals(cbLimitInKB, nativeMemoryCacheManager.getMaxCacheSizeInKilobytes()); nativeMemoryCacheManager.close(); diff --git a/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerIT.java b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerIT.java index d85201a09a..c16103a646 100644 --- a/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerIT.java +++ b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerIT.java @@ -20,7 +20,7 @@ import java.util.Collections; import java.util.Map; -import static org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService.CB_TIME_INTERVAL; +import static org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerMonitor.CB_TIME_INTERVAL; /** * Integration tests to test Circuit Breaker functionality diff --git a/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorTests.java b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorTests.java new file mode 100644 index 0000000000..28e346ab56 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorTests.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory.breaker; + +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.common.unit.ByteSizeUnit; +import org.opensearch.common.unit.ByteSizeValue; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.plugin.stats.StatNames; +import org.opensearch.knn.plugin.transport.KNNStatsAction; +import org.opensearch.knn.plugin.transport.KNNStatsNodeResponse; +import org.opensearch.knn.plugin.transport.KNNStatsRequest; +import org.opensearch.knn.plugin.transport.KNNStatsResponse; +import org.opensearch.threadpool.Scheduler; + +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class NativeMemoryCircuitBreakerMonitorTests extends KNNTestCase { + + public void testStart_whenCalledMultipleTimes_thenScheduleMonitorOnce() { + Scheduler.Cancellable cancellable = createCancellable(); + when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(cancellable); + NativeMemoryCircuitBreakerMonitor monitor = new NativeMemoryCircuitBreakerMonitor( + NativeMemoryCircuitBreakerMonitorDto.builder() + .nativeMemoryCacheManager(nativeMemoryCacheManager) + .nativeMemoryCircuitBreaker(nativeMemoryCircuitBreaker) + .threadPool(threadPool) + .client(client) + .clusterService(clusterService) + .build() + ); + + monitor.start(); + monitor.start(); + monitor.start(); + + verify(threadPool, times(1)).scheduleWithFixedDelay(any(), any(), any()); + } + + public void testClose_whenCalled_thenCancel() { + Scheduler.Cancellable cancellable = createCancellable(); + when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(cancellable); + NativeMemoryCircuitBreakerMonitor monitor = new NativeMemoryCircuitBreakerMonitor( + NativeMemoryCircuitBreakerMonitorDto.builder() + .nativeMemoryCacheManager(nativeMemoryCacheManager) + .nativeMemoryCircuitBreaker(nativeMemoryCircuitBreaker) + .threadPool(threadPool) + .client(client) + .clusterService(clusterService) + .build() + ); + + monitor.start(); + assertFalse(cancellable.isCancelled()); + monitor.close(); + assertTrue(cancellable.isCancelled()); + } + + public void testMonitor_whenDataNodeCacheSizeLowerThanThreshold_thenUnsetCacheCapacityReached() { + // Setup state so that cache capacity is marked as reached but the ratio in the cache is less than + // the unset ratio + when(nativeMemoryCacheManager.isCacheCapacityReached()).thenReturn(true); + when(node.isDataNode()).thenReturn(true); + when(clusterService.localNode()).thenReturn(node); + long cacheSizeInKb = 1; + long cbLimitInKb = 100; + double unsetSizeInKb = 99; + when(nativeMemoryCacheManager.getCacheSizeInKilobytes()).thenReturn(cacheSizeInKb); + when(nativeMemoryCircuitBreaker.getLimit()).thenReturn(new ByteSizeValue(cbLimitInKb, ByteSizeUnit.KB)); + when(nativeMemoryCircuitBreaker.getUnsetPercentage()).thenReturn(unsetSizeInKb); + + // Avoid duties of cluster manager + when(nativeMemoryCircuitBreaker.isTripped()).thenReturn(false); + + doNothing().when(nativeMemoryCacheManager).setCacheCapacityReached(false); + NativeMemoryCircuitBreakerMonitor monitor = new NativeMemoryCircuitBreakerMonitor( + NativeMemoryCircuitBreakerMonitorDto.builder() + .nativeMemoryCacheManager(nativeMemoryCacheManager) + .nativeMemoryCircuitBreaker(nativeMemoryCircuitBreaker) + .threadPool(threadPool) + .client(client) + .clusterService(clusterService) + .build() + ); + + monitor.monitor(); + verify(nativeMemoryCacheManager, times(1)).setCacheCapacityReached(false); + } + + @SneakyThrows + public void testMonitorRun_whenClusterManagerAndClusterHasCapacity_thenUnsetCircuitBreaker() { + // Setup state so that current node is cluster manager and should unset circuit breaker5 + when(nativeMemoryCircuitBreaker.isTripped()).thenReturn(true); + when(discoveryNodes.isLocalNodeElectedClusterManager()).thenReturn(true); + when(clusterState.nodes()).thenReturn(discoveryNodes); + when(clusterService.state()).thenReturn(clusterState); + + // Ensure all nodes have cache capacity as not reached + Map reachedMap = ImmutableMap.of(StatNames.CACHE_CAPACITY_REACHED.getName(), false); + List nodeResponses = List.of( + new KNNStatsNodeResponse(node, reachedMap), + new KNNStatsNodeResponse(node, reachedMap), + new KNNStatsNodeResponse(node, reachedMap) + ); + KNNStatsResponse knnStatsResponse = mock(KNNStatsResponse.class); + when(knnStatsResponse.getNodes()).thenReturn(nodeResponses); + + PlainActionFuture actionFuture = new PlainActionFuture<>() { + @Override + public KNNStatsResponse get() { + return knnStatsResponse; + } + }; + + when(client.execute(any(KNNStatsAction.class), any(KNNStatsRequest.class))).thenReturn(actionFuture); + + // Avoid duties of data node + when(nativeMemoryCacheManager.isCacheCapacityReached()).thenReturn(false); + + doNothing().when(nativeMemoryCircuitBreaker).set(false); + NativeMemoryCircuitBreakerMonitor monitor = new NativeMemoryCircuitBreakerMonitor( + NativeMemoryCircuitBreakerMonitorDto.builder() + .nativeMemoryCacheManager(nativeMemoryCacheManager) + .nativeMemoryCircuitBreaker(nativeMemoryCircuitBreaker) + .threadPool(threadPool) + .client(client) + .clusterService(clusterService) + .build() + ); + monitor.monitor(); + verify(nativeMemoryCircuitBreaker, times(1)).set(false); + } + + private Scheduler.Cancellable createCancellable() { + return new Scheduler.Cancellable() { + boolean isCancelled = false; + + @Override + public boolean cancel() { + isCancelled = true; + return true; + } + + @Override + public boolean isCancelled() { + return isCancelled; + } + }; + } + +} diff --git a/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerServiceTests.java b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerServiceTests.java deleted file mode 100644 index 0013aece84..0000000000 --- a/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerServiceTests.java +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.memory.breaker; - -import com.google.common.collect.ImmutableMap; -import lombok.SneakyThrows; -import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.unit.ByteSizeUnit; -import org.opensearch.common.unit.ByteSizeValue; -import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.plugin.stats.StatNames; -import org.opensearch.knn.plugin.transport.KNNStatsAction; -import org.opensearch.knn.plugin.transport.KNNStatsNodeResponse; -import org.opensearch.knn.plugin.transport.KNNStatsRequest; -import org.opensearch.knn.plugin.transport.KNNStatsResponse; -import org.opensearch.threadpool.Scheduler; -import org.opensearch.threadpool.ThreadPool; - -import java.util.List; -import java.util.Map; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.knn.index.KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED; -import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED; -import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT; - -public class NativeMemoryCircuitBreakerServiceTests extends KNNTestCase { - - public void testCircuitBreaker_whenSet_thenCircuitBreakerUpdated() { - boolean isTriggered = randomBoolean(); - doNothing().when(knnSettings).updateBooleanSetting(KNN_CIRCUIT_BREAKER_TRIGGERED, isTriggered); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = new NativeMemoryCircuitBreakerService( - knnSettings, - threadPool, - clusterService, - client - ); - nativeMemoryCircuitBreakerService.setCircuitBreaker(isTriggered); - verify(knnSettings, times(1)).updateBooleanSetting(KNN_CIRCUIT_BREAKER_TRIGGERED, isTriggered); - } - - public void testGetCircuitBreakerLimit() { - ByteSizeValue circuitBreakerLimit = new ByteSizeValue(randomIntBetween(10, 10000), ByteSizeUnit.KB); - when(knnSettings.getSettingValue(KNN_MEMORY_CIRCUIT_BREAKER_LIMIT)).thenReturn(circuitBreakerLimit); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = new NativeMemoryCircuitBreakerService( - knnSettings, - threadPool, - clusterService, - client - ); - assertEquals(circuitBreakerLimit, nativeMemoryCircuitBreakerService.getCircuitBreakerLimit()); - } - - public void testIsCircuitBreakerTriggered() { - boolean isTriggered = randomBoolean(); - when(knnSettings.getSettingValue(KNN_CIRCUIT_BREAKER_TRIGGERED)).thenReturn(isTriggered); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = new NativeMemoryCircuitBreakerService( - knnSettings, - threadPool, - clusterService, - client - ); - assertEquals(isTriggered, nativeMemoryCircuitBreakerService.isCircuitBreakerTriggered()); - } - - public void testIsCircuitBreakerEnabled() { - boolean isEnabled = randomBoolean(); - when(knnSettings.getSettingValue(KNN_MEMORY_CIRCUIT_BREAKER_ENABLED)).thenReturn(isEnabled); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = new NativeMemoryCircuitBreakerService( - knnSettings, - threadPool, - clusterService, - client - ); - assertEquals(isEnabled, nativeMemoryCircuitBreakerService.isCircuitBreakerEnabled()); - } - - public void testStartStopLifeCycle() { - Scheduler.Cancellable cancellable = createCancellable(); - when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(cancellable); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = createNativeMemoryCircuitBreakerServiceWithNoopMonitor( - knnSettings, - threadPool, - clusterService, - client - ); - nativeMemoryCircuitBreakerService.doStart(); - assertEquals(cancellable, nativeMemoryCircuitBreakerService.circuitBreakerFuture.get()); - assertFalse(nativeMemoryCircuitBreakerService.circuitBreakerFuture.get().isCancelled()); - nativeMemoryCircuitBreakerService.doStop(); - assertNull(nativeMemoryCircuitBreakerService.circuitBreakerFuture.get()); - } - - public void testStartCloseLifeCycle() { - Scheduler.Cancellable cancellable = createCancellable(); - when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(cancellable); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = createNativeMemoryCircuitBreakerServiceWithNoopMonitor( - knnSettings, - threadPool, - clusterService, - client - ); - nativeMemoryCircuitBreakerService.doStart(); - assertEquals(cancellable, nativeMemoryCircuitBreakerService.circuitBreakerFuture.get()); - assertFalse(nativeMemoryCircuitBreakerService.circuitBreakerFuture.get().isCancelled()); - nativeMemoryCircuitBreakerService.doClose(); - assertEquals(cancellable, nativeMemoryCircuitBreakerService.circuitBreakerFuture.get()); - assertTrue(nativeMemoryCircuitBreakerService.circuitBreakerFuture.get().isCancelled()); - } - - public void testMonitorRun_whenDataNodeCacheSizeLowerThanThreshold_thenUnsetCacheCapacityReached() { - // Setup state so that cache capacity is marked as reached but the ratio in the cache is less than - // the unset ratio - when(nativeMemoryCacheManager.isCacheCapacityReached()).thenReturn(true); - when(node.isDataNode()).thenReturn(true); - when(clusterService.localNode()).thenReturn(node); - long cacheSizeInKb = 1; - long cbLimitInKb = 100; - double unsetSizeInKb = 99; - when(nativeMemoryCacheManager.getCacheSizeInKilobytes()).thenReturn(cacheSizeInKb); - when(nativeMemoryCircuitBreakerService.getCircuitBreakerLimit()).thenReturn(new ByteSizeValue(cbLimitInKb, ByteSizeUnit.KB)); - when(nativeMemoryCircuitBreakerService.getCircuitBreakerUnsetPercentage()).thenReturn(unsetSizeInKb); - - // Avoid duties of cluster manager - when(nativeMemoryCircuitBreakerService.isCircuitBreakerTriggered()).thenReturn(false); - - doNothing().when(nativeMemoryCacheManager).setCacheCapacityReached(false); - NativeMemoryCircuitBreakerService.Monitor monitor = new NativeMemoryCircuitBreakerService.Monitor( - nativeMemoryCircuitBreakerService, - nativeMemoryCacheManager, - clusterService, - client - ); - monitor.run(); - verify(nativeMemoryCacheManager, times(1)).setCacheCapacityReached(false); - } - - @SneakyThrows - public void testMonitorRun_whenClusterManagerAndClusterHasCapacity_thenUnsetCircuitBreaker() { - // Setup state so that current node is cluster manager and should unset circuit breaker5 - when(nativeMemoryCircuitBreakerService.isCircuitBreakerTriggered()).thenReturn(true); - when(discoveryNodes.isLocalNodeElectedClusterManager()).thenReturn(true); - when(clusterState.nodes()).thenReturn(discoveryNodes); - when(clusterService.state()).thenReturn(clusterState); - - // Ensure all nodes have cache capacity as not reached - Map reachedMap = ImmutableMap.of(StatNames.CACHE_CAPACITY_REACHED.getName(), false); - List nodeResponses = List.of( - new KNNStatsNodeResponse(node, reachedMap), - new KNNStatsNodeResponse(node, reachedMap), - new KNNStatsNodeResponse(node, reachedMap) - ); - KNNStatsResponse knnStatsResponse = mock(KNNStatsResponse.class); - when(knnStatsResponse.getNodes()).thenReturn(nodeResponses); - - PlainActionFuture actionFuture = new PlainActionFuture<>() { - @Override - public KNNStatsResponse get() { - return knnStatsResponse; - } - }; - - when(client.execute(any(KNNStatsAction.class), any(KNNStatsRequest.class))).thenReturn(actionFuture); - - // Avoid duties of data node - when(nativeMemoryCacheManager.isCacheCapacityReached()).thenReturn(false); - - doNothing().when(nativeMemoryCircuitBreakerService).setCircuitBreaker(false); - NativeMemoryCircuitBreakerService.Monitor monitor = new NativeMemoryCircuitBreakerService.Monitor( - nativeMemoryCircuitBreakerService, - nativeMemoryCacheManager, - clusterService, - client - ); - monitor.run(); - verify(nativeMemoryCircuitBreakerService, times(1)).setCircuitBreaker(false); - } - - private Scheduler.Cancellable createCancellable() { - return new Scheduler.Cancellable() { - boolean isCancelled = false; - - @Override - public boolean cancel() { - isCancelled = true; - return true; - } - - @Override - public boolean isCancelled() { - return isCancelled; - } - }; - } - - private NativeMemoryCircuitBreakerService createNativeMemoryCircuitBreakerServiceWithNoopMonitor( - KNNSettings knnSettings, - ThreadPool threadPool, - ClusterService clusterService, - Client client - ) { - return new NativeMemoryCircuitBreakerService(knnSettings, threadPool, clusterService, client) { - @Override - protected Monitor getMonitor() { - return mock(Monitor.class); - } - }; - } -} diff --git a/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerTests.java b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerTests.java new file mode 100644 index 0000000000..785a3b8c15 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerTests.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory.breaker; + +import org.opensearch.common.unit.ByteSizeUnit; +import org.opensearch.common.unit.ByteSizeValue; +import org.opensearch.knn.KNNTestCase; + +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED; +import static org.opensearch.knn.index.KNNSettings.KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE; +import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED; +import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT; + +public class NativeMemoryCircuitBreakerTests extends KNNTestCase { + + public void testIsTripped() { + boolean isTripped = randomBoolean(); + when(knnSettings.getSettingValue(KNN_CIRCUIT_BREAKER_TRIGGERED)).thenReturn(isTripped); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker = new NativeMemoryCircuitBreaker(knnSettings); + assertEquals(isTripped, nativeMemoryCircuitBreaker.isTripped()); + } + + public void testSet() { + boolean isTripped = randomBoolean(); + doNothing().when(knnSettings).updateBooleanSetting(KNN_CIRCUIT_BREAKER_TRIGGERED, isTripped); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker = new NativeMemoryCircuitBreaker(knnSettings); + nativeMemoryCircuitBreaker.set(isTripped); + verify(knnSettings, times(1)).updateBooleanSetting(KNN_CIRCUIT_BREAKER_TRIGGERED, isTripped); + } + + public void testGetLimit() { + ByteSizeValue circuitBreakerLimit = new ByteSizeValue(randomIntBetween(10, 10000), ByteSizeUnit.KB); + when(knnSettings.getSettingValue(KNN_MEMORY_CIRCUIT_BREAKER_LIMIT)).thenReturn(circuitBreakerLimit); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker = new NativeMemoryCircuitBreaker(knnSettings); + assertEquals(circuitBreakerLimit, nativeMemoryCircuitBreaker.getLimit()); + } + + public void testIsEnabled() { + boolean isEnabled = randomBoolean(); + when(knnSettings.getSettingValue(KNN_MEMORY_CIRCUIT_BREAKER_ENABLED)).thenReturn(isEnabled); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker = new NativeMemoryCircuitBreaker(knnSettings); + assertEquals(isEnabled, nativeMemoryCircuitBreaker.isEnabled()); + } + + public void testGetUnsetPercentage() { + double unsetPercentage = 71; + when(knnSettings.getSettingValue(KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE)).thenReturn(unsetPercentage); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker = new NativeMemoryCircuitBreaker(knnSettings); + assertEquals(unsetPercentage, nativeMemoryCircuitBreaker.getUnsetPercentage(), 0.0001); + } +}