From 3cc2fcaa90f37e1eec6f5f7ba2599bdc88f1b637 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Tue, 16 Apr 2024 12:50:47 -0700 Subject: [PATCH] filtering stats api for hidden model (#2307) (#2328) * filtering stats api for hidden model Signed-off-by: Dhrubo Saha * addressing comments Signed-off-by: Dhrubo Saha * reverting to excludes instead of includes Signed-off-by: Dhrubo Saha * addressing comments Signed-off-by: Dhrubo Saha * spotlessApply Signed-off-by: Dhrubo Saha * renamed method name Signed-off-by: Dhrubo Saha * addressing comments Signed-off-by: Dhrubo Saha * spotlessApply Signed-off-by: Dhrubo Saha * fixing a test Signed-off-by: Dhrubo Saha * addressed comments Signed-off-by: Dhrubo Saha --------- Signed-off-by: Dhrubo Saha (cherry picked from commit a83a78f714e2aab869e61884b2754d50550689ef) Co-authored-by: Dhrubo Saha --- .../stats/MLStatsNodesTransportAction.java | 116 +++++++++- .../MLStatsNodesTransportActionTests.java | 213 +++++++++++++++--- 2 files changed, 291 insertions(+), 38 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java index f585a62edb..75ecd3137a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java @@ -7,17 +7,31 @@ import java.io.IOException; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.opensearch.action.FailedNodeException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; +import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.env.Environment; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.ActionName; import org.opensearch.ml.stats.MLActionStats; import org.opensearch.ml.stats.MLAlgoStats; @@ -26,15 +40,26 @@ import org.opensearch.ml.stats.MLStatLevel; import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.stats.MLStatsInput; +import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.search.SearchHit; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import com.google.common.annotations.VisibleForTesting; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 public class MLStatsNodesTransportAction extends TransportNodesAction { private MLStats mlStats; private final JvmService jvmService; + private final Client client; + + private final MLModelManager mlModelManager; + /** * Constructor * @@ -52,7 +77,9 @@ public MLStatsNodesTransportAction( TransportService transportService, ActionFilters actionFilters, MLStats mlStats, - Environment environment + Environment environment, + Client client, + MLModelManager mlModelManager ) { super( MLStatsNodesAction.NAME, @@ -67,6 +94,8 @@ public MLStatsNodesTransportAction( ); this.mlStats = mlStats; this.jvmService = new JvmService(environment.settings()); + this.client = client; + this.mlModelManager = mlModelManager; } @Override @@ -127,21 +156,88 @@ MLStatsNodeResponse createMLStatsNodeResponse(MLStatsNodesRequest mlStatsNodesRe } Map modelStats = new HashMap<>(); - // return model level stats if (mlStatsInput.includeModelStats()) { - for (String modelId : mlStats.getAllModels()) { - if (mlStatsInput.retrieveStatsForModel(modelId)) { - Map actionStatsMap = new HashMap<>(); - for (Map.Entry entry : mlStats.getModelStats(modelId).entrySet()) { - if (mlStatsInput.retrieveStatsForAction(entry.getKey())) { - actionStatsMap.put(entry.getKey(), entry.getValue()); + CountDownLatch latch = new CountDownLatch(1); + boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client); + searchHiddenModels(ActionListener.wrap(hiddenModels -> { + for (String modelId : mlStats.getAllModels()) { + if (isSuperAdmin || !hiddenModels.contains(modelId)) { + if (mlStatsInput.retrieveStatsForModel(modelId)) { + Map actionStatsMap = new HashMap<>(); + for (Map.Entry entry : mlStats.getModelStats(modelId).entrySet()) { + if (mlStatsInput.retrieveStatsForAction(entry.getKey())) { + actionStatsMap.put(entry.getKey(), entry.getValue()); + } + } + modelStats.put(modelId, new MLModelStats(actionStatsMap)); } } - modelStats.put(modelId, new MLModelStats(actionStatsMap)); } + }, e -> { log.error("Search Hidden model wasn't successful"); }), latch); + // Wait for the asynchronous call to complete + try { + latch.await(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + // Handle interruption if necessary + Thread.currentThread().interrupt(); } } - return new MLStatsNodeResponse(clusterService.localNode(), statValues, algorithmStats, modelStats); } + + @VisibleForTesting + void searchHiddenModels(ActionListener> listener, CountDownLatch latch) { + SearchRequest searchRequest = buildHiddenModelSearchRequest(); + // Use a try-with-resources block to ensure resources are properly released + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + // Wrap the listener to restore thread context before calling it + ActionListener> internalListener = ActionListener.runAfter(listener, () -> { + latch.countDown(); + threadContext.restore(); + }); + // Wrap the search response handler to handle success and failure cases + // Notify the listener of any search failures + ActionListener al = ActionListener.wrap(response -> { + // Initialize the result set + Set result = new HashSet<>(response.getHits().getHits().length); // Set initial capacity to the number of hits + + // Iterate over the search hits and add their IDs to the result set + for (SearchHit hit : response.getHits()) { + result.add(hit.getId()); + } + // Notify the listener of the search results + internalListener.onResponse(result); + }, internalListener::onFailure); + + // Execute the search request asynchronously + client.search(searchRequest, al); + } catch (Exception e) { + // Notify the listener of any unexpected errors + listener.onFailure(e); + } + } + + private SearchRequest buildHiddenModelSearchRequest() { + SearchRequest searchRequest = new SearchRequest(CommonValue.ML_MODEL_INDEX); + // Build the query + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + boolQueryBuilder + .filter( + QueryBuilders + .boolQuery() + .must(QueryBuilders.termQuery(MLModel.IS_HIDDEN_FIELD, true)) + // Add the additional filter to exclude documents where "chunk_number" exists + .mustNot(QueryBuilders.existsQuery("chunk_number")) + ); + searchRequest.source().query(boolQueryBuilder); + // Specify the fields to include in the search results (only the "_id" field) + // No fields to exclude + searchRequest.source().fetchSource(new String[] { "_id" }, new String[] {}); + return searchRequest; + } + + @VisibleForTesting + boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) { + return RestActionUtils.isSuperAdminUser(clusterService, client); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java index 35249f8ecf..1cff986b00 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java @@ -5,40 +5,54 @@ package org.opensearch.ml.action.stats; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MCORR_ML_VERSION; import static org.opensearch.ml.stats.MLNodeLevelStat.ML_JVM_HEAP_USAGE; +import static org.opensearch.ml.utils.TestHelper.builder; import java.io.IOException; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; +import java.util.*; +import java.util.concurrent.CountDownLatch; +import org.apache.lucene.search.TotalHits; import org.junit.Assert; import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; import org.opensearch.Version; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.env.Environment; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.stats.ActionName; -import org.opensearch.ml.stats.MLActionLevelStat; -import org.opensearch.ml.stats.MLActionStats; -import org.opensearch.ml.stats.MLAlgoStats; -import org.opensearch.ml.stats.MLClusterLevelStat; -import org.opensearch.ml.stats.MLModelStats; -import org.opensearch.ml.stats.MLNodeLevelStat; -import org.opensearch.ml.stats.MLStat; -import org.opensearch.ml.stats.MLStatLevel; -import org.opensearch.ml.stats.MLStats; -import org.opensearch.ml.stats.MLStatsInput; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.stats.*; import org.opensearch.ml.stats.suppliers.CounterSupplier; import org.opensearch.ml.stats.suppliers.SettableSupplier; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.profile.SearchProfileShardResults; +import org.opensearch.search.suggest.Suggest; import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableSet; @@ -51,12 +65,22 @@ public class MLStatsNodesTransportActionTests extends OpenSearchIntegTestCase { private MLNodeLevelStat nodeStatName1; private Environment environment; + @Mock + private MLModelManager mlModelManager; + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + private final String modelId = "model_id"; @Override @Before public void setUp() throws Exception { super.setUp(); + MockitoAnnotations.openMocks(this); clusterStatName1 = MLClusterLevelStat.ML_MODEL_COUNT; nodeStatName1 = MLNodeLevelStat.ML_EXECUTING_TASK_COUNT; @@ -74,16 +98,24 @@ public void setUp() throws Exception { Settings settings = Settings.builder().build(); when(environment.settings()).thenReturn(settings); + when(client.threadPool()).thenReturn(threadPool); + + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + when(threadPool.getThreadContext()).thenReturn(threadContext); + action = new MLStatsNodesTransportAction( client().threadPool(), clusterService(), mock(TransportService.class), mock(ActionFilters.class), mlStats, - environment + environment, + client, + mlModelManager ); } + @Test public void testNewNodeRequest() { String nodeId = "nodeId1"; MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { nodeId }, new MLStatsInput()); @@ -94,6 +126,7 @@ public void testNewNodeRequest() { assertEquals(mlStatsNodeRequest1.getMlStatsNodesRequest(), mlStatsNodeRequest2.getMlStatsNodesRequest()); } + @Test public void testNewNodeResponse() throws IOException { Map statValues = new HashMap<>(); DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT); @@ -106,6 +139,7 @@ public void testNewNodeResponse() throws IOException { Assert.assertEquals(statsNodeResponse.getNodeLevelStatSize(), newStatsNodeResponse.getModelStatSize()); } + @Test public void testNodeOperation() { String nodeId = clusterService().localNode().getId(); MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { nodeId }, new MLStatsInput()); @@ -119,6 +153,7 @@ public void testNodeOperation() { assertNotNull(response.getNodeLevelStat(nodeStatName1)); } + @Test public void testNodeOperationWithJvmHeapUsage() { String nodeId = clusterService().localNode().getId(); MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { nodeId }, new MLStatsInput()); @@ -133,9 +168,10 @@ public void testNodeOperationWithJvmHeapUsage() { assertNotNull(response.getNodeLevelStat(ML_JVM_HEAP_USAGE)); } + @Test public void testNodeOperation_NoNodeLevelStat() { String nodeId = clusterService().localNode().getId(); - MLStatsInput mlStatsInput = MLStatsInput.builder().targetStatLevels(EnumSet.of(MLStatLevel.ALGORITHM, MLStatLevel.MODEL)).build(); + MLStatsInput mlStatsInput = MLStatsInput.builder().targetStatLevels(EnumSet.of(MLStatLevel.ALGORITHM)).build(); MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { nodeId }, mlStatsInput); MLStatsNodeResponse response = action.nodeOperation(new MLStatsNodeRequest(mlStatsNodesRequest)); @@ -143,19 +179,34 @@ public void testNodeOperation_NoNodeLevelStat() { assertEquals(0, response.getNodeLevelStatSize()); } - public void testNodeOperation_NoNodeLevelStat_AlgoStat() { + @Test + public void testNodeOperation_NoNodeLevelStat_AlgoStatWithoutHiddenModel() { MLStats mlStats = new MLStats(statsMap); mlStats.createCounterStatIfAbsent(FunctionName.KMEANS, ActionName.TRAIN, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); mlStats.createModelCounterStatIfAbsent(modelId, ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); - MLStatsNodesTransportAction action = new MLStatsNodesTransportAction( - client().threadPool(), - clusterService(), - mock(TransportService.class), - mock(ActionFilters.class), - mlStats, - environment - ); + MLStatsNodesTransportAction action = Mockito + .spy( + new MLStatsNodesTransportAction( + client().threadPool(), + clusterService(), + mock(TransportService.class), + mock(ActionFilters.class), + mlStats, + environment, + client, + mlModelManager + ) + ); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + Set result = new HashSet<>(); + listener.onResponse(result); + CountDownLatch latch = invocation.getArgument(1); + latch.countDown(); // Ensure the latch is counted down after the listener is notified + return null; + }).when(action).searchHiddenModels(isA(ActionListener.class), isA(CountDownLatch.class)); String nodeId = clusterService().localNode().getId(); MLStatsInput mlStatsInput = MLStatsInput.builder().targetStatLevels(EnumSet.of(MLStatLevel.ALGORITHM, MLStatLevel.MODEL)).build(); @@ -179,4 +230,110 @@ public void testNodeOperation_NoNodeLevelStat_AlgoStat() { assertEquals(1l, actionStats.getActionStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); } + @Test + public void testNodeOperation_NoNodeLevelStat_AlgoStat_hiddenModel() { + MLStats mlStats = new MLStats(statsMap); + mlStats.createCounterStatIfAbsent(FunctionName.KMEANS, ActionName.TRAIN, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); + mlStats.createModelCounterStatIfAbsent(modelId, ActionName.PREDICT, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment(); + + MLStatsNodesTransportAction action = Mockito + .spy( + new MLStatsNodesTransportAction( + client().threadPool(), + clusterService(), + mock(TransportService.class), + mock(ActionFilters.class), + mlStats, + environment, + client, + mlModelManager + ) + ); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + Set result = new HashSet<>(); + result.add(modelId); + listener.onResponse(result); + CountDownLatch latch = invocation.getArgument(1); + latch.countDown(); // Ensure the latch is counted down after the listener is notified + return null; + }).when(action).searchHiddenModels(isA(ActionListener.class), isA(CountDownLatch.class)); + + String nodeId = clusterService().localNode().getId(); + MLStatsInput mlStatsInput = MLStatsInput.builder().targetStatLevels(EnumSet.of(MLStatLevel.ALGORITHM, MLStatLevel.MODEL)).build(); + MLStatsNodesRequest mlStatsNodesRequest = new MLStatsNodesRequest(new String[] { nodeId }, mlStatsInput); + + MLStatsNodeResponse response = action.nodeOperation(new MLStatsNodeRequest(mlStatsNodesRequest)); + + assertEquals(0, response.getNodeLevelStatSize()); + assertEquals(1, response.getAlgorithmStatSize()); + assertEquals(0, response.getModelStatSize()); + MLAlgoStats algorithmStats = response.getAlgorithmStats(FunctionName.KMEANS); + assertNotNull(algorithmStats); + MLActionStats actionStats = algorithmStats.getActionStats(ActionName.TRAIN); + assertNotNull(actionStats); + assertEquals(1l, actionStats.getActionStat(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); + + MLModelStats modelStats = response.getModelStats(modelId); + assertNull(modelStats); + } + + @Test + public void testSearchHiddenModels_successfulSearch() throws IOException { + + SearchResponse response = createSearchModelResponse(); + + ActionListener> mockListener = mock(ActionListener.class); + CountDownLatch latch = mock(CountDownLatch.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(SearchRequest.class); + + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(response); + return null; + }).when(client).search(captor.capture(), any()); + + action.searchHiddenModels(mockListener, latch); + + ArgumentCaptor> argumentCaptor = ArgumentCaptor.forClass(Set.class); + + verify(mockListener).onResponse(argumentCaptor.capture()); + Set capturedSet = argumentCaptor.getValue(); + assertEquals(argumentCaptor.getValue().size(), 1); + assertTrue("Expected set to contain modelId", capturedSet.contains(modelId)); + verify(client).search(any(SearchRequest.class), any(ActionListener.class)); + } + + private SearchResponse createSearchModelResponse() throws IOException { + XContentBuilder content = builder(); + content.startObject(); + content.field(MLModel.MODEL_NAME_FIELD, FunctionName.METRICS_CORRELATION.name()); + content.field(MLModel.MODEL_VERSION_FIELD, MCORR_ML_VERSION); + content.field(MLModel.MODEL_ID_FIELD, modelId); + content.endObject(); + + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, modelId, null, null).sourceRef(BytesReference.bytes(content)); + + return new SearchResponse( + new InternalSearchResponse( + new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f), + InternalAggregations.EMPTY, + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 5, + 5, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + } + }