diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java index 7f3840a177..bf6cbcd9cf 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/StatsIT.java @@ -9,7 +9,7 @@ import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.plugin.stats.KNNStats; import java.util.Collections; @@ -24,7 +24,7 @@ public class StatsIT extends AbstractRollingUpgradeTestCase { @Before public void setUp() throws Exception { super.setUp(); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = mock(NativeMemoryCircuitBreakerService.class); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreakerService = mock(NativeMemoryCircuitBreaker.class); this.knnStats = new KNNStats(nativeMemoryCircuitBreakerService); } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index f1dac09a4a..5aeb2e2ab8 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -383,7 +383,7 @@ public static ByteSizeValue parseknnMemoryCircuitBreakerValue(String sValue, Str */ public synchronized void updateBooleanSetting(String settingName, boolean value) { ClusterUpdateSettingsRequest clusterUpdateSettingsRequest = new ClusterUpdateSettingsRequest(); - Settings circuitBreakerSettings = Settings.builder().put("unregistered-setting-lets-see-what-happens", value).build(); + Settings circuitBreakerSettings = Settings.builder().put(settingName, value).build(); clusterUpdateSettingsRequest.persistentSettings(circuitBreakerSettings); client.admin().cluster().updateSettings(clusterUpdateSettingsRequest, new ActionListener<>() { @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 9c89f209ca..829624bad5 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -32,7 +32,7 @@ import org.opensearch.index.mapper.ValueFetcher; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KNNVectorIndexFieldData; @@ -145,12 +145,12 @@ public static class Builder extends ParametrizedFieldMapper.Builder { protected String efConstruction; protected ModelDao modelDao; - protected NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; + protected NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; - public Builder(String name, ModelDao modelDao, NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService) { + public Builder(String name, ModelDao modelDao, NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker) { super(name); this.modelDao = modelDao; - this.nativeMemoryCircuitBreakerService = nativeMemoryCircuitBreakerService; + this.nativeMemoryCircuitBreaker = nativeMemoryCircuitBreaker; } /** @@ -167,13 +167,13 @@ public Builder( String spaceType, String m, String efConstruction, - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker ) { super(name); this.spaceType = spaceType; this.m = m; this.efConstruction = efConstruction; - this.nativeMemoryCircuitBreakerService = nativeMemoryCircuitBreakerService; + this.nativeMemoryCircuitBreaker = nativeMemoryCircuitBreaker; } @Override @@ -227,7 +227,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { .stored(stored.get()) .hasDocValues(hasDocValues.get()) .knnMethodContext(knnMethodContext) - .nativeMemoryCircuitBreakerService(nativeMemoryCircuitBreakerService) + .nativeMemoryCircuitBreakerService(nativeMemoryCircuitBreaker) .build(); return new LuceneFieldMapper(createLuceneFieldMapperInput); } @@ -239,7 +239,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.get(), hasDocValues.get(), - nativeMemoryCircuitBreakerService, + nativeMemoryCircuitBreaker, knnMethodContext ); } @@ -259,7 +259,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.get(), hasDocValues.get(), - nativeMemoryCircuitBreakerService, + nativeMemoryCircuitBreaker, modelDao, modelIdAsString ); @@ -286,7 +286,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.get(), hasDocValues.get(), - nativeMemoryCircuitBreakerService, + nativeMemoryCircuitBreaker, spaceType, m, efConstruction @@ -318,23 +318,16 @@ public static class TypeParser implements Mapper.TypeParser { // Use a supplier here because in {@link org.opensearch.knn.KNNPlugin#getMappers()} the ModelDao has not yet // been initialized private final Supplier modelDaoSupplier; - private final Supplier nativeMemoryCircuitBreakerServiceSupplier; + private final Supplier nativeMemoryCircuitBreakerSupplier; - public TypeParser( - Supplier modelDaoSupplier, - Supplier knnCircuitBreakerServiceSupplier - ) { + public TypeParser(Supplier modelDaoSupplier, Supplier knnCircuitBreakerServiceSupplier) { this.modelDaoSupplier = modelDaoSupplier; - this.nativeMemoryCircuitBreakerServiceSupplier = knnCircuitBreakerServiceSupplier; + this.nativeMemoryCircuitBreakerSupplier = knnCircuitBreakerServiceSupplier; } @Override public Mapper.Builder parse(String name, Map node, ParserContext parserContext) throws MapperParsingException { - Builder builder = new KNNVectorFieldMapper.Builder( - name, - modelDaoSupplier.get(), - nativeMemoryCircuitBreakerServiceSupplier.get() - ); + Builder builder = new KNNVectorFieldMapper.Builder(name, modelDaoSupplier.get(), nativeMemoryCircuitBreakerSupplier.get()); builder.parse(name, parserContext, node); // All ignoreMalformed, boolean stored, boolean hasDocValues, - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker ) { super(simpleName, mappedFieldType, multiFields, copyTo); this.ignoreMalformed = ignoreMalformed; this.stored = stored; this.hasDocValues = hasDocValues; this.dimension = mappedFieldType.getDimension(); - this.nativeMemoryCircuitBreakerService = nativeMemoryCircuitBreakerService; + this.nativeMemoryCircuitBreaker = nativeMemoryCircuitBreaker; updateEngineStats(); } @@ -472,7 +465,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx } void validateIfCircuitBreakerIsNotTriggered() { - if (nativeMemoryCircuitBreakerService.isCircuitBreakerTriggered()) { + if (nativeMemoryCircuitBreaker.isTripped()) { throw new IllegalStateException( "Indexing knn vector fields is rejected as circuit breaker triggered. Check _opendistro/_knn/stats for detailed state" ); @@ -545,7 +538,7 @@ protected boolean docValuesByDefault() { @Override public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, nativeMemoryCircuitBreakerService).init(this); + return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, nativeMemoryCircuitBreaker).init(this); } @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java index 1490dfc16b..a88a6f4115 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java @@ -10,7 +10,7 @@ import org.opensearch.common.Explicit; import org.opensearch.common.settings.Settings; import org.opensearch.index.mapper.ParametrizedFieldMapper; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.util.KNNEngine; @@ -46,12 +46,12 @@ public class LegacyFieldMapper extends KNNVectorFieldMapper { Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService, + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker, String spaceType, String m, String efConstruction ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, nativeMemoryCircuitBreakerService); + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, nativeMemoryCircuitBreaker); this.spaceType = spaceType; this.m = m; @@ -72,13 +72,8 @@ public class LegacyFieldMapper extends KNNVectorFieldMapper { @Override public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new KNNVectorFieldMapper.Builder( - simpleName(), - this.spaceType, - this.m, - this.efConstruction, - this.nativeMemoryCircuitBreakerService - ).init(this); + return new KNNVectorFieldMapper.Builder(simpleName(), this.spaceType, this.m, this.efConstruction, this.nativeMemoryCircuitBreaker) + .init(this); } static String getSpaceType(Settings indexSettings) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 55d97db2a9..6c5b5e1b01 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -15,7 +15,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.util.KNNEngine; @@ -45,7 +45,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { input.getIgnoreMalformed(), input.isStored(), input.isHasDocValues(), - input.getNativeMemoryCircuitBreakerService() + input.getNativeMemoryCircuitBreaker() ); this.knnMethod = input.getKnnMethodContext(); @@ -131,6 +131,6 @@ static class CreateLuceneFieldMapperInput { @NonNull KNNMethodContext knnMethodContext; @NonNull - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index a896b2f93f..3d641b45a9 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -9,7 +9,7 @@ import org.opensearch.common.Explicit; import org.opensearch.common.Strings; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.util.KNNEngine; @@ -33,11 +33,11 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService, + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker, KNNMethodContext knnMethodContext ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, nativeMemoryCircuitBreakerService); + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, nativeMemoryCircuitBreaker); this.knnMethod = knnMethodContext; diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 1a865d6e3d..1d90269b09 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -8,7 +8,7 @@ import org.apache.lucene.document.FieldType; import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; @@ -29,11 +29,11 @@ public class ModelFieldMapper extends KNNVectorFieldMapper { Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService, + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker, ModelDao modelDao, String modelId ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, nativeMemoryCircuitBreakerService); + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, nativeMemoryCircuitBreaker); this.modelId = modelId; this.modelDao = modelDao; diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java index 03cba84b81..1e2b21606a 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryCacheManager.java @@ -22,7 +22,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.knn.common.exception.OutOfNativeMemoryException; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.plugin.stats.StatNames; import java.io.Closeable; @@ -43,7 +43,7 @@ public class NativeMemoryCacheManager implements Closeable { private static final Logger logger = LogManager.getLogger(NativeMemoryCacheManager.class); private static NativeMemoryCacheManager INSTANCE; - private static NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; + private static NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; private Cache cache; private final ExecutorService executor; private AtomicBoolean cacheCapacityReached; @@ -71,8 +71,8 @@ public static synchronized NativeMemoryCacheManager getInstance() { private void initialize() { initialize( NativeMemoryCacheManagerDto.builder() - .isWeightLimited(nativeMemoryCircuitBreakerService.isCircuitBreakerEnabled()) - .maxWeight(nativeMemoryCircuitBreakerService.getCircuitBreakerLimit().getKb()) + .isWeightLimited(nativeMemoryCircuitBreaker.isEnabled()) + .maxWeight(nativeMemoryCircuitBreaker.getLimit().getKb()) .isExpirationLimited(KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED)) .expiryTimeInMin( ((TimeValue) KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES)).getMinutes() @@ -101,8 +101,8 @@ private void initialize(NativeMemoryCacheManagerDto nativeMemoryCacheDTO) { cache = cacheBuilder.build(); } - public static void initialize(NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService) { - NativeMemoryCacheManager.nativeMemoryCircuitBreakerService = nativeMemoryCircuitBreakerService; + public static void initialize(NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker) { + NativeMemoryCacheManager.nativeMemoryCircuitBreaker = nativeMemoryCircuitBreaker; } /** @@ -111,8 +111,8 @@ public static void initialize(NativeMemoryCircuitBreakerService nativeMemoryCirc public synchronized void rebuildCache() { rebuildCache( NativeMemoryCacheManagerDto.builder() - .isWeightLimited(nativeMemoryCircuitBreakerService.isCircuitBreakerEnabled()) - .maxWeight(nativeMemoryCircuitBreakerService.getCircuitBreakerLimit().getKb()) + .isWeightLimited(nativeMemoryCircuitBreaker.isEnabled()) + .maxWeight(nativeMemoryCircuitBreaker.getLimit().getKb()) .isExpirationLimited(KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED)) .expiryTimeInMin( ((TimeValue) KNNSettings.state().getSettingValue(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES)).getMinutes() @@ -372,7 +372,7 @@ private void onRemoval(RemovalNotification remov nativeMemoryAllocation.close(); if (RemovalCause.SIZE == removalNotification.getCause()) { - nativeMemoryCircuitBreakerService.setCircuitBreaker(true); + nativeMemoryCircuitBreaker.set(true); setCacheCapacityReached(true); } @@ -380,7 +380,7 @@ private void onRemoval(RemovalNotification remov } private Float getSizeAsPercentage(long size) { - long cbLimit = nativeMemoryCircuitBreakerService.getCircuitBreakerLimit().getKb(); + long cbLimit = nativeMemoryCircuitBreaker.getLimit().getKb(); if (cbLimit == 0) { return 0.0F; } diff --git a/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreaker.java b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreaker.java new file mode 100644 index 0000000000..fa40e2dffe --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreaker.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory.breaker; + +import com.google.common.annotations.VisibleForTesting; +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.opensearch.common.unit.ByteSizeValue; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; + +/** + * The circuit breaker gets tripped based on memory demand tracked by the {@link NativeMemoryCacheManager}. + * When {@link NativeMemoryCacheManager}'s cache fills up, if the circuit breaking logic is enabled, it will trip the + * circuit breaker. Elsewhere in the code, the circuit breaker's value can be queried to prevent actions that should + * not happen during high memory pressure. + */ +@AllArgsConstructor +@Log4j2 +public class NativeMemoryCircuitBreaker { + private final KNNSettings knnSettings; + + /** + * Checks if the circuit breaker is triggered + * + * @return true if circuit breaker is triggered; false otherwise + */ + public boolean isTripped() { + return knnSettings.getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED); + } + + /** + * Sets circuit breaker to new value + * + * @param circuitBreaker value to update circuit breaker to + */ + public void set(boolean circuitBreaker) { + knnSettings.updateBooleanSetting(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED, circuitBreaker); + } + + /** + * Gets the limit of the circuit breaker + * + * @return limit as ByteSizeValue of native memory circuit breaker + */ + public ByteSizeValue getLimit() { + return knnSettings.getSettingValue(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT); + } + + /** + * Determine if the circuit breaker is enabled + * + * @return true if circuit breaker is enabled. False otherwise. + */ + public boolean isEnabled() { + return knnSettings.getSettingValue(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED); + } + + /** + * Returns the percentage as a double for when to unset the circuit breaker + * + * @return percentage as double for unsetting circuit breaker + */ + @VisibleForTesting + double getUnsetPercentage() { + return knnSettings.getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE); + } +} diff --git a/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitor.java b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitor.java new file mode 100644 index 0000000000..8003c84dea --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitor.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory.breaker; + +import com.google.common.annotations.VisibleForTesting; +import lombok.extern.log4j.Log4j2; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; +import org.opensearch.knn.plugin.stats.StatNames; +import org.opensearch.knn.plugin.transport.KNNStatsAction; +import org.opensearch.knn.plugin.transport.KNNStatsNodeResponse; +import org.opensearch.knn.plugin.transport.KNNStatsRequest; +import org.opensearch.knn.plugin.transport.KNNStatsResponse; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.ThreadPool; + +import java.io.Closeable; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Job that runs periodically to monitor native memory usage on the node/in the cluster and untrip the + * NativeMemoryCircuitBreaker if necessary. + */ +@Log4j2 +public class NativeMemoryCircuitBreakerMonitor implements Closeable { + private final NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; + private final NativeMemoryCacheManager nativeMemoryCacheManager; + private final ClusterService clusterService; + private final Client client; + private final ThreadPool threadPool; + private final AtomicReference isStarted; + private Scheduler.Cancellable monitorFuture; + public static final int CB_TIME_INTERVAL = 2 * 60; // seconds + private static final TimeValue STATS_REQUEST_TIMEOUT = new TimeValue(1000 * 10); // 10 second timeout + + /** + * Constructor + * + * @param nativeMemoryCircuitBreakerMonitorDto contains necessary initialization values + */ + public NativeMemoryCircuitBreakerMonitor(NativeMemoryCircuitBreakerMonitorDto nativeMemoryCircuitBreakerMonitorDto) { + this.nativeMemoryCircuitBreaker = nativeMemoryCircuitBreakerMonitorDto.getNativeMemoryCircuitBreaker(); + this.nativeMemoryCacheManager = nativeMemoryCircuitBreakerMonitorDto.getNativeMemoryCacheManager(); + this.clusterService = nativeMemoryCircuitBreakerMonitorDto.getClusterService(); + this.client = nativeMemoryCircuitBreakerMonitorDto.getClient(); + this.threadPool = nativeMemoryCircuitBreakerMonitorDto.getThreadPool(); + this.isStarted = new AtomicReference<>(false); + this.monitorFuture = null; + } + + /** + * Schedules monitor job to be run + */ + public synchronized void start() { + // Ensure monitor future is only scheduled once + boolean isAlreadyStarted = this.isStarted.getAndSet(true); + if (isAlreadyStarted == false) { + this.monitorFuture = threadPool.scheduleWithFixedDelay( + this::monitor, + TimeValue.timeValueSeconds(CB_TIME_INTERVAL), + ThreadPool.Names.GENERIC + ); + } + } + + @Override + public synchronized void close() { + if (this.monitorFuture != null && this.monitorFuture.isCancelled() == false) { + this.monitorFuture.cancel(); + } + } + + @VisibleForTesting + void monitor() { + if (nativeMemoryCacheManager.isCacheCapacityReached() && clusterService.localNode().isDataNode()) { + long currentSizeKiloBytes = nativeMemoryCacheManager.getCacheSizeInKilobytes(); + long circuitBreakerLimitSizeKiloBytes = nativeMemoryCircuitBreaker.getLimit().getKb(); + long circuitBreakerUnsetSizeKiloBytes = (long) ((nativeMemoryCircuitBreaker.getUnsetPercentage() / 100) + * circuitBreakerLimitSizeKiloBytes); + // Unset capacityReached flag if currentSizeBytes is less than circuitBreakerUnsetSizeBytes + if (currentSizeKiloBytes <= circuitBreakerUnsetSizeKiloBytes) { + nativeMemoryCacheManager.setCacheCapacityReached(false); + } + } + + // Leader node untriggers CB if all nodes have not reached their max capacity + if (nativeMemoryCircuitBreaker.isTripped() && clusterService.state().nodes().isLocalNodeElectedClusterManager()) { + KNNStatsRequest knnStatsRequest = new KNNStatsRequest(); + knnStatsRequest.addStat(StatNames.CACHE_CAPACITY_REACHED.getName()); + knnStatsRequest.timeout(STATS_REQUEST_TIMEOUT); + + try { + KNNStatsResponse knnStatsResponse = client.execute(KNNStatsAction.INSTANCE, knnStatsRequest).get(); + List nodeResponses = knnStatsResponse.getNodes(); + + List nodesAtMaxCapacity = new ArrayList<>(); + for (KNNStatsNodeResponse nodeResponse : nodeResponses) { + if ((Boolean) nodeResponse.getStatsMap().get(StatNames.CACHE_CAPACITY_REACHED.getName())) { + nodesAtMaxCapacity.add(nodeResponse.getNode().getId()); + } + } + + if (nodesAtMaxCapacity.isEmpty() == false) { + log.info( + "[KNN] knn.circuit_breaker.triggered stays set. Nodes at max cache capacity: " + + String.join(",", nodesAtMaxCapacity) + + "." + ); + } else { + log.info( + "[KNN] Cache capacity below {}% of the circuit breaker limit for all nodes. Unsetting knn.circuit_breaker.triggered flag.", + nativeMemoryCircuitBreaker.getUnsetPercentage() + ); + nativeMemoryCircuitBreaker.set(false); + } + } catch (Exception e) { + log.error("[KNN] Error when trying to update the circuit breaker setting", e); + } + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorDto.java b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorDto.java new file mode 100644 index 0000000000..59d46c9fb3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorDto.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory.breaker; + +import lombok.Builder; +import lombok.Value; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; +import org.opensearch.threadpool.ThreadPool; + +@Value +@Builder +public class NativeMemoryCircuitBreakerMonitorDto { + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; + NativeMemoryCacheManager nativeMemoryCacheManager; + ClusterService clusterService; + Client client; + ThreadPool threadPool; +} diff --git a/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerService.java b/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerService.java deleted file mode 100644 index 1fa3af0d2a..0000000000 --- a/src/main/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerService.java +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.memory.breaker; - -import com.google.common.annotations.VisibleForTesting; -import lombok.Value; -import lombok.extern.log4j.Log4j2; -import org.opensearch.common.component.AbstractLifecycleComponent; -import org.opensearch.common.unit.ByteSizeValue; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.plugin.stats.StatNames; -import org.opensearch.knn.plugin.transport.KNNStatsAction; -import org.opensearch.knn.plugin.transport.KNNStatsNodeResponse; -import org.opensearch.knn.plugin.transport.KNNStatsRequest; -import org.opensearch.knn.plugin.transport.KNNStatsResponse; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.threadpool.Scheduler; -import org.opensearch.threadpool.ThreadPool; - -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.atomic.AtomicReference; - -/** - * Service handling native memory circuit breaking logic. The circuit breaker gets tripped based on memory demand - * tracked by the {@link NativeMemoryCacheManager}. When {@link NativeMemoryCacheManager}'s cache fills up, if the - * circuit breaking logic is enabled, it will trip the circuit. Elsewhere in the code, the circuit breaker's value can - * be queried to prevent actions that should not happen during high memory pressure. - */ -@Log4j2 -public class NativeMemoryCircuitBreakerService extends AbstractLifecycleComponent { - private final ThreadPool threadPool; - private final KNNSettings knnSettings; - private final ClusterService clusterService; - private final Client client; - // Cancellable task to track circuitBreakerRunnable. In order to schedule, doStart must be called. doStart will - // only start the future if the previous value is null. Therefore, to close this class, do NOT set the value of - // this variable to null. To stop this class, this variable should be set to null so that it may be restarted. - @VisibleForTesting - final AtomicReference circuitBreakerFuture; - - public static final int CB_TIME_INTERVAL = 2 * 60; // seconds - - /** - * Constructor for creation of circuit breaker service for KNN - * - * @param knnSettings Settings class for k-NN - * @param threadPool thread pool for circuit breaker monitor to run job - * @param clusterService cluster service used to retrieve information about the cluster - * @param client client used to make calls to the cluster - */ - public NativeMemoryCircuitBreakerService(KNNSettings knnSettings, ThreadPool threadPool, ClusterService clusterService, Client client) { - this.knnSettings = knnSettings; - this.threadPool = threadPool; - this.clusterService = clusterService; - this.client = client; - this.circuitBreakerFuture = new AtomicReference<>(null); - } - - /** - * Checks if the circuit breaker is triggered - * - * @return true if circuit breaker is triggered; false otherwise - */ - public boolean isCircuitBreakerTriggered() { - return knnSettings.getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED); - } - - /** - * Sets circuit breaker to new value - * - * @param circuitBreaker value to update circuit breaker to - */ - public void setCircuitBreaker(boolean circuitBreaker) { - knnSettings.updateBooleanSetting(KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED, circuitBreaker); - } - - /** - * Gets the limit of the circuit breaker - * - * @return limit as ByteSizeValue of native memory circuit breaker - */ - public ByteSizeValue getCircuitBreakerLimit() { - return knnSettings.getSettingValue(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT); - } - - /** - * Determine if the circuit breaker is enabled - * - * @return true if circuit breaker is enabled. False otherwise. - */ - public boolean isCircuitBreakerEnabled() { - return knnSettings.getSettingValue(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED); - } - - /** - * Returns the percentage as a double for when to unset the circuit breaker - * - * @return percentage as double for unsetting circuit breaker - */ - @VisibleForTesting - double getCircuitBreakerUnsetPercentage() { - return knnSettings.getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE); - } - - @Override - protected void doStart() { - Monitor monitor = getMonitor(); - this.circuitBreakerFuture.compareAndSet( - null, - threadPool.scheduleWithFixedDelay(monitor, TimeValue.timeValueSeconds(CB_TIME_INTERVAL), ThreadPool.Names.GENERIC) - ); - } - - @VisibleForTesting - Monitor getMonitor() { - return new Monitor(this, NativeMemoryCacheManager.getInstance(), clusterService, client); - } - - @Override - protected void doStop() { - Scheduler.Cancellable cancellable = this.circuitBreakerFuture.getAndSet(null); - if (cancellable != null && !cancellable.isCancelled()) { - cancellable.cancel(); - } - } - - @Override - protected void doClose() { - Scheduler.Cancellable cancellable = this.circuitBreakerFuture.get(); - if (cancellable != null && !cancellable.isCancelled()) { - cancellable.cancel(); - } - } - - @VisibleForTesting - @Value - static class Monitor implements Runnable { - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; - NativeMemoryCacheManager nativeMemoryCacheManager; - ClusterService clusterService; - Client client; - private static final TimeValue STATS_REQUEST_TIMEOUT = new TimeValue(1000 * 10); // 10 second timeout - - @Override - public void run() { - if (nativeMemoryCacheManager.isCacheCapacityReached() && clusterService.localNode().isDataNode()) { - long currentSizeKiloBytes = nativeMemoryCacheManager.getCacheSizeInKilobytes(); - long circuitBreakerLimitSizeKiloBytes = nativeMemoryCircuitBreakerService.getCircuitBreakerLimit().getKb(); - long circuitBreakerUnsetSizeKiloBytes = (long) ((nativeMemoryCircuitBreakerService.getCircuitBreakerUnsetPercentage() / 100) - * circuitBreakerLimitSizeKiloBytes); - // Unset capacityReached flag if currentSizeBytes is less than circuitBreakerUnsetSizeBytes - if (currentSizeKiloBytes <= circuitBreakerUnsetSizeKiloBytes) { - nativeMemoryCacheManager.setCacheCapacityReached(false); - } - } - - // Leader node untriggers CB if all nodes have not reached their max capacity - if (nativeMemoryCircuitBreakerService.isCircuitBreakerTriggered() - && clusterService.state().nodes().isLocalNodeElectedClusterManager()) { - KNNStatsRequest knnStatsRequest = new KNNStatsRequest(); - knnStatsRequest.addStat(StatNames.CACHE_CAPACITY_REACHED.getName()); - knnStatsRequest.timeout(STATS_REQUEST_TIMEOUT); - - try { - KNNStatsResponse knnStatsResponse = client.execute(KNNStatsAction.INSTANCE, knnStatsRequest).get(); - List nodeResponses = knnStatsResponse.getNodes(); - - List nodesAtMaxCapacity = new ArrayList<>(); - for (KNNStatsNodeResponse nodeResponse : nodeResponses) { - if ((Boolean) nodeResponse.getStatsMap().get(StatNames.CACHE_CAPACITY_REACHED.getName())) { - nodesAtMaxCapacity.add(nodeResponse.getNode().getId()); - } - } - - if (nodesAtMaxCapacity.isEmpty() == false) { - log.info( - "[KNN] knn.circuit_breaker.triggered stays set. Nodes at max cache capacity: " - + String.join(",", nodesAtMaxCapacity) - + "." - ); - } else { - log.info( - "[KNN] Cache capacity below {}% of the circuit breaker limit for all nodes. Unsetting knn.circuit_breaker.triggered flag.", - nativeMemoryCircuitBreakerService.getCircuitBreakerUnsetPercentage() - ); - nativeMemoryCircuitBreakerService.setCircuitBreaker(false); - } - } catch (Exception e) { - log.error("[KNN] Error when trying to update the circuit breaker setting", e); - } - } - } - } -} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index d1dbf0185b..1f6cad6ca1 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -12,8 +12,10 @@ import org.opensearch.index.engine.EngineFactory; import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.KNNClusterUtil; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerMonitor; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerMonitorDto; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; @@ -158,13 +160,14 @@ public class KNNPlugin extends Plugin private KNNStats knnStats; private ClusterService clusterService; - private NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; + private NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; + private NativeMemoryCircuitBreakerMonitor nativeMemoryCircuitBreakerMonitor; @Override public Map getMappers() { return Collections.singletonMap( KNNVectorFieldMapper.CONTENT_TYPE, - new KNNVectorFieldMapper.TypeParser(ModelDao.OpenSearchKNNModelDao::getInstance, () -> nativeMemoryCircuitBreakerService) + new KNNVectorFieldMapper.TypeParser(ModelDao.OpenSearchKNNModelDao::getInstance, () -> nativeMemoryCircuitBreaker) ); } @@ -202,13 +205,21 @@ public Collection createComponents( KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); - nativeMemoryCircuitBreakerService = new NativeMemoryCircuitBreakerService(KNNSettings.state(), threadPool, clusterService, client); - NativeMemoryCacheManager.initialize(nativeMemoryCircuitBreakerService); - // Called after NativeMemoryCacheManager initialization. Start method requires access to an instance of the - // NativeMemoryCacheManager that needs to be initialized. - nativeMemoryCircuitBreakerService.start(); - knnStats = new KNNStats(nativeMemoryCircuitBreakerService); - return ImmutableList.of(knnStats, nativeMemoryCircuitBreakerService); + + nativeMemoryCircuitBreaker = new NativeMemoryCircuitBreaker(KNNSettings.state()); + NativeMemoryCacheManager.initialize(nativeMemoryCircuitBreaker); + nativeMemoryCircuitBreakerMonitor = new NativeMemoryCircuitBreakerMonitor( + NativeMemoryCircuitBreakerMonitorDto.builder() + .nativeMemoryCacheManager(NativeMemoryCacheManager.getInstance()) + .nativeMemoryCircuitBreaker(nativeMemoryCircuitBreaker) + .threadPool(threadPool) + .client(client) + .clusterService(clusterService) + .build() + ); + nativeMemoryCircuitBreakerMonitor.start(); + knnStats = new KNNStats(nativeMemoryCircuitBreaker); + return ImmutableList.of(knnStats, nativeMemoryCircuitBreaker); } @Override @@ -367,4 +378,9 @@ public Settings additionalSettings() { ).collect(Collectors.toList()); return Settings.builder().putList(IndexModule.INDEX_STORE_HYBRID_MMAP_EXTENSIONS.getKey(), combinedSettings).build(); } + + @Override + public void close() { + nativeMemoryCircuitBreakerMonitor.close(); + } } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java index 3c64c95019..461263c048 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java @@ -9,7 +9,7 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; @@ -36,10 +36,10 @@ public class KNNStats { /** * Constructor * - * @param nativeMemoryCircuitBreakerService k-NN circuit breaker service for native memory + * @param nativeMemoryCircuitBreaker k-NN circuit breaker service for native memory */ - public KNNStats(NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService) { - this.knnStats = buildStatsMap(nativeMemoryCircuitBreakerService); + public KNNStats(NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker) { + this.knnStats = buildStatsMap(nativeMemoryCircuitBreaker); } /** @@ -80,10 +80,10 @@ private Map> getClusterOrNodeStats(Boolean getClusterStats) { return statsMap; } - private Map> buildStatsMap(NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService) { + private Map> buildStatsMap(NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker) { ImmutableMap.Builder> builder = ImmutableMap.>builder(); addQueryStats(builder); - addNativeMemoryStats(builder, nativeMemoryCircuitBreakerService); + addNativeMemoryStats(builder, nativeMemoryCircuitBreaker); addEngineStats(builder); addScriptStats(builder); addModelStats(builder); @@ -101,7 +101,7 @@ private void addQueryStats(ImmutableMap.Builder> builder) { private void addNativeMemoryStats( ImmutableMap.Builder> builder, - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker ) { builder.put(StatNames.HIT_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::hitCount))) .put(StatNames.MISS_COUNT.getName(), new KNNStat<>(false, new KNNInnerCacheStatsSupplier(CacheStats::missCount))) @@ -134,7 +134,7 @@ private void addNativeMemoryStats( .put(StatNames.GRAPH_INDEX_REQUESTS.getName(), new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.GRAPH_INDEX_REQUESTS))) .put( StatNames.CIRCUIT_BREAKER_TRIGGERED.getName(), - new KNNStat<>(true, new NativeMemoryCircuitBreakerSupplier(nativeMemoryCircuitBreakerService)) + new KNNStat<>(true, new NativeMemoryCircuitBreakerSupplier(nativeMemoryCircuitBreaker)) ); } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/NativeMemoryCircuitBreakerSupplier.java b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/NativeMemoryCircuitBreakerSupplier.java index 76b8d785d0..cb3e73a627 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/suppliers/NativeMemoryCircuitBreakerSupplier.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/suppliers/NativeMemoryCircuitBreakerSupplier.java @@ -6,7 +6,7 @@ package org.opensearch.knn.plugin.stats.suppliers; import lombok.AllArgsConstructor; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import java.util.function.Supplier; @@ -16,10 +16,10 @@ @AllArgsConstructor public class NativeMemoryCircuitBreakerSupplier implements Supplier { - private final NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; + private final NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; @Override public Boolean get() { - return nativeMemoryCircuitBreakerService.isCircuitBreakerTriggered(); + return nativeMemoryCircuitBreaker.isTripped(); } } diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java index 9532f5965c..01a257208f 100644 --- a/src/test/java/org/opensearch/knn/KNNTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNTestCase.java @@ -19,7 +19,7 @@ import org.opensearch.common.unit.ByteSizeValue; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; @@ -40,10 +40,10 @@ */ public class KNNTestCase extends OpenSearchTestCase { - protected static final NativeMemoryCircuitBreakerService NEVER_TRIGGERED_CB_SERVICE = mock(NativeMemoryCircuitBreakerService.class); + protected static final NativeMemoryCircuitBreaker NEVER_TRIGGERED_CB = mock(NativeMemoryCircuitBreaker.class); static { - when(NEVER_TRIGGERED_CB_SERVICE.isCircuitBreakerTriggered()).thenReturn(false); - when(NEVER_TRIGGERED_CB_SERVICE.getCircuitBreakerLimit()).thenReturn(new ByteSizeValue(100, ByteSizeUnit.KB)); + when(NEVER_TRIGGERED_CB.isTripped()).thenReturn(false); + when(NEVER_TRIGGERED_CB.getLimit()).thenReturn(new ByteSizeValue(100, ByteSizeUnit.KB)); } @Mock @@ -57,7 +57,7 @@ public class KNNTestCase extends OpenSearchTestCase { @Mock protected NativeMemoryCacheManager nativeMemoryCacheManager; @Mock - protected NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService; + protected NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker; @Mock protected DiscoveryNode node; @Mock @@ -97,7 +97,7 @@ public void resetState() { KNNSettings.state().setClusterService(clusterService); // Clean up the cache - NativeMemoryCacheManager.initialize(NEVER_TRIGGERED_CB_SERVICE); + NativeMemoryCacheManager.initialize(NEVER_TRIGGERED_CB); NativeMemoryCacheManager.getInstance().invalidateAll(); NativeMemoryCacheManager.getInstance().close(); } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index ce2cfdb2c0..281979f446 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -30,7 +30,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; @@ -82,15 +82,15 @@ public class KNNVectorFieldMapperTests extends KNNTestCase { public void testBuilder_getParameters() { String fieldName = "test-field-name"; ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, NEVER_TRIGGERED_CB); assertEquals(6, builder.getParameters().size()); } public void testBuilder_build_fromKnnMethodContext() { // Check that knnMethodContext takes precedent over both model and legacy ModelDao modelDao = mock(ModelDao.class); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = mock(NativeMemoryCircuitBreakerService.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, NEVER_TRIGGERED_CB_SERVICE); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreakerService = mock(NativeMemoryCircuitBreaker.class); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, NEVER_TRIGGERED_CB); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -127,7 +127,7 @@ public void testBuilder_build_fromKnnMethodContext() { public void testBuilder_build_fromModel() { // Check that modelContext takes precedent over legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, NEVER_TRIGGERED_CB); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -164,7 +164,7 @@ public void testBuilder_build_fromModel() { public void testBuilder_build_fromLegacy() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, NEVER_TRIGGERED_CB); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -193,7 +193,7 @@ public void testBuilder_parse_fromKnnMethodContext_luceneEngine() throws IOExcep Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); KNNEngine.LUCENE.setInitialized(false); @@ -279,7 +279,7 @@ public void testTypeParser_parse_fromKnnMethodContext_invalidDimension() throws Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); int efConstruction = 321; @@ -343,7 +343,7 @@ public void testTypeParser_parse_fromKnnMethodContext_invalidSpaceType() throws Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); int efConstruction = 321; int dimension = 133; @@ -398,7 +398,7 @@ public void testTypeParser_parse_fromKnnMethodContext() throws IOException { Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); int efConstruction = 321; int dimension = 133; @@ -495,7 +495,7 @@ public void testTypeParser_parse_fromModel() throws IOException { Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); String modelId = "test-id"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() @@ -529,7 +529,7 @@ public void testTypeParser_parse_fromLegacy() throws IOException { .build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); int dimension = 122; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() @@ -555,7 +555,7 @@ public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOExcep Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB_SERVICE); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao, () -> NEVER_TRIGGERED_CB); int dimension = 133; int efConstruction = 321; @@ -631,7 +631,7 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser( () -> mockModelDao, - () -> NEVER_TRIGGERED_CB_SERVICE + () -> NEVER_TRIGGERED_CB ); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() @@ -701,7 +701,7 @@ public void testLuceneFieldMapper_parseCreateField_docValues() throws IOExceptio .copyTo(FieldMapper.CopyTo.empty()) .hasDocValues(true) .ignoreMalformed(new Explicit<>(true, true)) - .nativeMemoryCircuitBreakerService(NEVER_TRIGGERED_CB_SERVICE) + .nativeMemoryCircuitBreakerService(NEVER_TRIGGERED_CB) .knnMethodContext(knnMethodContext); ParseContext.Document document = new ParseContext.Document(); @@ -796,7 +796,7 @@ public Mapper.TypeParser.ParserContext buildParserContext(String indexName, Sett return new Mapper.TypeParser.ParserContext( null, mapperService, - type -> new KNNVectorFieldMapper.TypeParser(() -> mockModelDao, () -> NEVER_TRIGGERED_CB_SERVICE), + type -> new KNNVectorFieldMapper.TypeParser(() -> mockModelDao, () -> NEVER_TRIGGERED_CB), CURRENT, null, null, diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java index 7ba3fc429f..ee4af5ab32 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryCacheManagerTests.java @@ -18,7 +18,7 @@ import org.opensearch.common.unit.ByteSizeValue; import org.opensearch.knn.common.exception.OutOfNativeMemoryException; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService; +import org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreaker; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchSingleNodeTestCase; @@ -29,7 +29,6 @@ import java.util.Map; import java.util.concurrent.ExecutionException; -import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.GRAPH_COUNT; @@ -284,11 +283,10 @@ public void testGetIndexGraphCount() throws ExecutionException, IOException { public void testGetMaxCacheSizeInKB() { long cbLimitInKB = 100; ByteSizeValue defaultCBLimit = new ByteSizeValue(cbLimitInKB, ByteSizeUnit.KB); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = mock(NativeMemoryCircuitBreakerService.class); - when(nativeMemoryCircuitBreakerService.isCircuitBreakerEnabled()).thenReturn(true); - when(nativeMemoryCircuitBreakerService.getCircuitBreakerLimit()).thenReturn(defaultCBLimit); - doNothing().when(nativeMemoryCircuitBreakerService).close(); - NativeMemoryCacheManager.initialize(nativeMemoryCircuitBreakerService); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker = mock(NativeMemoryCircuitBreaker.class); + when(nativeMemoryCircuitBreaker.isEnabled()).thenReturn(true); + when(nativeMemoryCircuitBreaker.getLimit()).thenReturn(defaultCBLimit); + NativeMemoryCacheManager.initialize(nativeMemoryCircuitBreaker); NativeMemoryCacheManager nativeMemoryCacheManager = new NativeMemoryCacheManager(); assertEquals(cbLimitInKB, nativeMemoryCacheManager.getMaxCacheSizeInKilobytes()); nativeMemoryCacheManager.close(); diff --git a/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerIT.java b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerIT.java index d85201a09a..c16103a646 100644 --- a/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerIT.java +++ b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerIT.java @@ -20,7 +20,7 @@ import java.util.Collections; import java.util.Map; -import static org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerService.CB_TIME_INTERVAL; +import static org.opensearch.knn.index.memory.breaker.NativeMemoryCircuitBreakerMonitor.CB_TIME_INTERVAL; /** * Integration tests to test Circuit Breaker functionality diff --git a/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorTests.java b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorTests.java new file mode 100644 index 0000000000..28e346ab56 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerMonitorTests.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory.breaker; + +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.common.unit.ByteSizeUnit; +import org.opensearch.common.unit.ByteSizeValue; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.plugin.stats.StatNames; +import org.opensearch.knn.plugin.transport.KNNStatsAction; +import org.opensearch.knn.plugin.transport.KNNStatsNodeResponse; +import org.opensearch.knn.plugin.transport.KNNStatsRequest; +import org.opensearch.knn.plugin.transport.KNNStatsResponse; +import org.opensearch.threadpool.Scheduler; + +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class NativeMemoryCircuitBreakerMonitorTests extends KNNTestCase { + + public void testStart_whenCalledMultipleTimes_thenScheduleMonitorOnce() { + Scheduler.Cancellable cancellable = createCancellable(); + when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(cancellable); + NativeMemoryCircuitBreakerMonitor monitor = new NativeMemoryCircuitBreakerMonitor( + NativeMemoryCircuitBreakerMonitorDto.builder() + .nativeMemoryCacheManager(nativeMemoryCacheManager) + .nativeMemoryCircuitBreaker(nativeMemoryCircuitBreaker) + .threadPool(threadPool) + .client(client) + .clusterService(clusterService) + .build() + ); + + monitor.start(); + monitor.start(); + monitor.start(); + + verify(threadPool, times(1)).scheduleWithFixedDelay(any(), any(), any()); + } + + public void testClose_whenCalled_thenCancel() { + Scheduler.Cancellable cancellable = createCancellable(); + when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(cancellable); + NativeMemoryCircuitBreakerMonitor monitor = new NativeMemoryCircuitBreakerMonitor( + NativeMemoryCircuitBreakerMonitorDto.builder() + .nativeMemoryCacheManager(nativeMemoryCacheManager) + .nativeMemoryCircuitBreaker(nativeMemoryCircuitBreaker) + .threadPool(threadPool) + .client(client) + .clusterService(clusterService) + .build() + ); + + monitor.start(); + assertFalse(cancellable.isCancelled()); + monitor.close(); + assertTrue(cancellable.isCancelled()); + } + + public void testMonitor_whenDataNodeCacheSizeLowerThanThreshold_thenUnsetCacheCapacityReached() { + // Setup state so that cache capacity is marked as reached but the ratio in the cache is less than + // the unset ratio + when(nativeMemoryCacheManager.isCacheCapacityReached()).thenReturn(true); + when(node.isDataNode()).thenReturn(true); + when(clusterService.localNode()).thenReturn(node); + long cacheSizeInKb = 1; + long cbLimitInKb = 100; + double unsetSizeInKb = 99; + when(nativeMemoryCacheManager.getCacheSizeInKilobytes()).thenReturn(cacheSizeInKb); + when(nativeMemoryCircuitBreaker.getLimit()).thenReturn(new ByteSizeValue(cbLimitInKb, ByteSizeUnit.KB)); + when(nativeMemoryCircuitBreaker.getUnsetPercentage()).thenReturn(unsetSizeInKb); + + // Avoid duties of cluster manager + when(nativeMemoryCircuitBreaker.isTripped()).thenReturn(false); + + doNothing().when(nativeMemoryCacheManager).setCacheCapacityReached(false); + NativeMemoryCircuitBreakerMonitor monitor = new NativeMemoryCircuitBreakerMonitor( + NativeMemoryCircuitBreakerMonitorDto.builder() + .nativeMemoryCacheManager(nativeMemoryCacheManager) + .nativeMemoryCircuitBreaker(nativeMemoryCircuitBreaker) + .threadPool(threadPool) + .client(client) + .clusterService(clusterService) + .build() + ); + + monitor.monitor(); + verify(nativeMemoryCacheManager, times(1)).setCacheCapacityReached(false); + } + + @SneakyThrows + public void testMonitorRun_whenClusterManagerAndClusterHasCapacity_thenUnsetCircuitBreaker() { + // Setup state so that current node is cluster manager and should unset circuit breaker5 + when(nativeMemoryCircuitBreaker.isTripped()).thenReturn(true); + when(discoveryNodes.isLocalNodeElectedClusterManager()).thenReturn(true); + when(clusterState.nodes()).thenReturn(discoveryNodes); + when(clusterService.state()).thenReturn(clusterState); + + // Ensure all nodes have cache capacity as not reached + Map reachedMap = ImmutableMap.of(StatNames.CACHE_CAPACITY_REACHED.getName(), false); + List nodeResponses = List.of( + new KNNStatsNodeResponse(node, reachedMap), + new KNNStatsNodeResponse(node, reachedMap), + new KNNStatsNodeResponse(node, reachedMap) + ); + KNNStatsResponse knnStatsResponse = mock(KNNStatsResponse.class); + when(knnStatsResponse.getNodes()).thenReturn(nodeResponses); + + PlainActionFuture actionFuture = new PlainActionFuture<>() { + @Override + public KNNStatsResponse get() { + return knnStatsResponse; + } + }; + + when(client.execute(any(KNNStatsAction.class), any(KNNStatsRequest.class))).thenReturn(actionFuture); + + // Avoid duties of data node + when(nativeMemoryCacheManager.isCacheCapacityReached()).thenReturn(false); + + doNothing().when(nativeMemoryCircuitBreaker).set(false); + NativeMemoryCircuitBreakerMonitor monitor = new NativeMemoryCircuitBreakerMonitor( + NativeMemoryCircuitBreakerMonitorDto.builder() + .nativeMemoryCacheManager(nativeMemoryCacheManager) + .nativeMemoryCircuitBreaker(nativeMemoryCircuitBreaker) + .threadPool(threadPool) + .client(client) + .clusterService(clusterService) + .build() + ); + monitor.monitor(); + verify(nativeMemoryCircuitBreaker, times(1)).set(false); + } + + private Scheduler.Cancellable createCancellable() { + return new Scheduler.Cancellable() { + boolean isCancelled = false; + + @Override + public boolean cancel() { + isCancelled = true; + return true; + } + + @Override + public boolean isCancelled() { + return isCancelled; + } + }; + } + +} diff --git a/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerServiceTests.java b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerServiceTests.java deleted file mode 100644 index 0013aece84..0000000000 --- a/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerServiceTests.java +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.memory.breaker; - -import com.google.common.collect.ImmutableMap; -import lombok.SneakyThrows; -import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.unit.ByteSizeUnit; -import org.opensearch.common.unit.ByteSizeValue; -import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.plugin.stats.StatNames; -import org.opensearch.knn.plugin.transport.KNNStatsAction; -import org.opensearch.knn.plugin.transport.KNNStatsNodeResponse; -import org.opensearch.knn.plugin.transport.KNNStatsRequest; -import org.opensearch.knn.plugin.transport.KNNStatsResponse; -import org.opensearch.threadpool.Scheduler; -import org.opensearch.threadpool.ThreadPool; - -import java.util.List; -import java.util.Map; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.knn.index.KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED; -import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED; -import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT; - -public class NativeMemoryCircuitBreakerServiceTests extends KNNTestCase { - - public void testCircuitBreaker_whenSet_thenCircuitBreakerUpdated() { - boolean isTriggered = randomBoolean(); - doNothing().when(knnSettings).updateBooleanSetting(KNN_CIRCUIT_BREAKER_TRIGGERED, isTriggered); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = new NativeMemoryCircuitBreakerService( - knnSettings, - threadPool, - clusterService, - client - ); - nativeMemoryCircuitBreakerService.setCircuitBreaker(isTriggered); - verify(knnSettings, times(1)).updateBooleanSetting(KNN_CIRCUIT_BREAKER_TRIGGERED, isTriggered); - } - - public void testGetCircuitBreakerLimit() { - ByteSizeValue circuitBreakerLimit = new ByteSizeValue(randomIntBetween(10, 10000), ByteSizeUnit.KB); - when(knnSettings.getSettingValue(KNN_MEMORY_CIRCUIT_BREAKER_LIMIT)).thenReturn(circuitBreakerLimit); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = new NativeMemoryCircuitBreakerService( - knnSettings, - threadPool, - clusterService, - client - ); - assertEquals(circuitBreakerLimit, nativeMemoryCircuitBreakerService.getCircuitBreakerLimit()); - } - - public void testIsCircuitBreakerTriggered() { - boolean isTriggered = randomBoolean(); - when(knnSettings.getSettingValue(KNN_CIRCUIT_BREAKER_TRIGGERED)).thenReturn(isTriggered); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = new NativeMemoryCircuitBreakerService( - knnSettings, - threadPool, - clusterService, - client - ); - assertEquals(isTriggered, nativeMemoryCircuitBreakerService.isCircuitBreakerTriggered()); - } - - public void testIsCircuitBreakerEnabled() { - boolean isEnabled = randomBoolean(); - when(knnSettings.getSettingValue(KNN_MEMORY_CIRCUIT_BREAKER_ENABLED)).thenReturn(isEnabled); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = new NativeMemoryCircuitBreakerService( - knnSettings, - threadPool, - clusterService, - client - ); - assertEquals(isEnabled, nativeMemoryCircuitBreakerService.isCircuitBreakerEnabled()); - } - - public void testStartStopLifeCycle() { - Scheduler.Cancellable cancellable = createCancellable(); - when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(cancellable); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = createNativeMemoryCircuitBreakerServiceWithNoopMonitor( - knnSettings, - threadPool, - clusterService, - client - ); - nativeMemoryCircuitBreakerService.doStart(); - assertEquals(cancellable, nativeMemoryCircuitBreakerService.circuitBreakerFuture.get()); - assertFalse(nativeMemoryCircuitBreakerService.circuitBreakerFuture.get().isCancelled()); - nativeMemoryCircuitBreakerService.doStop(); - assertNull(nativeMemoryCircuitBreakerService.circuitBreakerFuture.get()); - } - - public void testStartCloseLifeCycle() { - Scheduler.Cancellable cancellable = createCancellable(); - when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(cancellable); - NativeMemoryCircuitBreakerService nativeMemoryCircuitBreakerService = createNativeMemoryCircuitBreakerServiceWithNoopMonitor( - knnSettings, - threadPool, - clusterService, - client - ); - nativeMemoryCircuitBreakerService.doStart(); - assertEquals(cancellable, nativeMemoryCircuitBreakerService.circuitBreakerFuture.get()); - assertFalse(nativeMemoryCircuitBreakerService.circuitBreakerFuture.get().isCancelled()); - nativeMemoryCircuitBreakerService.doClose(); - assertEquals(cancellable, nativeMemoryCircuitBreakerService.circuitBreakerFuture.get()); - assertTrue(nativeMemoryCircuitBreakerService.circuitBreakerFuture.get().isCancelled()); - } - - public void testMonitorRun_whenDataNodeCacheSizeLowerThanThreshold_thenUnsetCacheCapacityReached() { - // Setup state so that cache capacity is marked as reached but the ratio in the cache is less than - // the unset ratio - when(nativeMemoryCacheManager.isCacheCapacityReached()).thenReturn(true); - when(node.isDataNode()).thenReturn(true); - when(clusterService.localNode()).thenReturn(node); - long cacheSizeInKb = 1; - long cbLimitInKb = 100; - double unsetSizeInKb = 99; - when(nativeMemoryCacheManager.getCacheSizeInKilobytes()).thenReturn(cacheSizeInKb); - when(nativeMemoryCircuitBreakerService.getCircuitBreakerLimit()).thenReturn(new ByteSizeValue(cbLimitInKb, ByteSizeUnit.KB)); - when(nativeMemoryCircuitBreakerService.getCircuitBreakerUnsetPercentage()).thenReturn(unsetSizeInKb); - - // Avoid duties of cluster manager - when(nativeMemoryCircuitBreakerService.isCircuitBreakerTriggered()).thenReturn(false); - - doNothing().when(nativeMemoryCacheManager).setCacheCapacityReached(false); - NativeMemoryCircuitBreakerService.Monitor monitor = new NativeMemoryCircuitBreakerService.Monitor( - nativeMemoryCircuitBreakerService, - nativeMemoryCacheManager, - clusterService, - client - ); - monitor.run(); - verify(nativeMemoryCacheManager, times(1)).setCacheCapacityReached(false); - } - - @SneakyThrows - public void testMonitorRun_whenClusterManagerAndClusterHasCapacity_thenUnsetCircuitBreaker() { - // Setup state so that current node is cluster manager and should unset circuit breaker5 - when(nativeMemoryCircuitBreakerService.isCircuitBreakerTriggered()).thenReturn(true); - when(discoveryNodes.isLocalNodeElectedClusterManager()).thenReturn(true); - when(clusterState.nodes()).thenReturn(discoveryNodes); - when(clusterService.state()).thenReturn(clusterState); - - // Ensure all nodes have cache capacity as not reached - Map reachedMap = ImmutableMap.of(StatNames.CACHE_CAPACITY_REACHED.getName(), false); - List nodeResponses = List.of( - new KNNStatsNodeResponse(node, reachedMap), - new KNNStatsNodeResponse(node, reachedMap), - new KNNStatsNodeResponse(node, reachedMap) - ); - KNNStatsResponse knnStatsResponse = mock(KNNStatsResponse.class); - when(knnStatsResponse.getNodes()).thenReturn(nodeResponses); - - PlainActionFuture actionFuture = new PlainActionFuture<>() { - @Override - public KNNStatsResponse get() { - return knnStatsResponse; - } - }; - - when(client.execute(any(KNNStatsAction.class), any(KNNStatsRequest.class))).thenReturn(actionFuture); - - // Avoid duties of data node - when(nativeMemoryCacheManager.isCacheCapacityReached()).thenReturn(false); - - doNothing().when(nativeMemoryCircuitBreakerService).setCircuitBreaker(false); - NativeMemoryCircuitBreakerService.Monitor monitor = new NativeMemoryCircuitBreakerService.Monitor( - nativeMemoryCircuitBreakerService, - nativeMemoryCacheManager, - clusterService, - client - ); - monitor.run(); - verify(nativeMemoryCircuitBreakerService, times(1)).setCircuitBreaker(false); - } - - private Scheduler.Cancellable createCancellable() { - return new Scheduler.Cancellable() { - boolean isCancelled = false; - - @Override - public boolean cancel() { - isCancelled = true; - return true; - } - - @Override - public boolean isCancelled() { - return isCancelled; - } - }; - } - - private NativeMemoryCircuitBreakerService createNativeMemoryCircuitBreakerServiceWithNoopMonitor( - KNNSettings knnSettings, - ThreadPool threadPool, - ClusterService clusterService, - Client client - ) { - return new NativeMemoryCircuitBreakerService(knnSettings, threadPool, clusterService, client) { - @Override - protected Monitor getMonitor() { - return mock(Monitor.class); - } - }; - } -} diff --git a/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerTests.java b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerTests.java new file mode 100644 index 0000000000..785a3b8c15 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/memory/breaker/NativeMemoryCircuitBreakerTests.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory.breaker; + +import org.opensearch.common.unit.ByteSizeUnit; +import org.opensearch.common.unit.ByteSizeValue; +import org.opensearch.knn.KNNTestCase; + +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.KNNSettings.KNN_CIRCUIT_BREAKER_TRIGGERED; +import static org.opensearch.knn.index.KNNSettings.KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE; +import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED; +import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT; + +public class NativeMemoryCircuitBreakerTests extends KNNTestCase { + + public void testIsTripped() { + boolean isTripped = randomBoolean(); + when(knnSettings.getSettingValue(KNN_CIRCUIT_BREAKER_TRIGGERED)).thenReturn(isTripped); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker = new NativeMemoryCircuitBreaker(knnSettings); + assertEquals(isTripped, nativeMemoryCircuitBreaker.isTripped()); + } + + public void testSet() { + boolean isTripped = randomBoolean(); + doNothing().when(knnSettings).updateBooleanSetting(KNN_CIRCUIT_BREAKER_TRIGGERED, isTripped); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker = new NativeMemoryCircuitBreaker(knnSettings); + nativeMemoryCircuitBreaker.set(isTripped); + verify(knnSettings, times(1)).updateBooleanSetting(KNN_CIRCUIT_BREAKER_TRIGGERED, isTripped); + } + + public void testGetLimit() { + ByteSizeValue circuitBreakerLimit = new ByteSizeValue(randomIntBetween(10, 10000), ByteSizeUnit.KB); + when(knnSettings.getSettingValue(KNN_MEMORY_CIRCUIT_BREAKER_LIMIT)).thenReturn(circuitBreakerLimit); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker = new NativeMemoryCircuitBreaker(knnSettings); + assertEquals(circuitBreakerLimit, nativeMemoryCircuitBreaker.getLimit()); + } + + public void testIsEnabled() { + boolean isEnabled = randomBoolean(); + when(knnSettings.getSettingValue(KNN_MEMORY_CIRCUIT_BREAKER_ENABLED)).thenReturn(isEnabled); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker = new NativeMemoryCircuitBreaker(knnSettings); + assertEquals(isEnabled, nativeMemoryCircuitBreaker.isEnabled()); + } + + public void testGetUnsetPercentage() { + double unsetPercentage = 71; + when(knnSettings.getSettingValue(KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE)).thenReturn(unsetPercentage); + NativeMemoryCircuitBreaker nativeMemoryCircuitBreaker = new NativeMemoryCircuitBreaker(knnSettings); + assertEquals(unsetPercentage, nativeMemoryCircuitBreaker.getUnsetPercentage(), 0.0001); + } +}