From 027d06e567954e51bfe2e866075ef06cb844aa2e Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Wed, 1 Nov 2023 18:18:54 -0700 Subject: [PATCH] CatIndexTool implementation Signed-off-by: Daniel Widdis --- .../ml/engine/tools/CatIndexTool.java | 178 +++++++-------- .../ml/engine/tools/CatIndexToolTests.java | 202 ++++++++++++++++++ .../ml/plugin/MachineLearningPlugin.java | 3 +- .../opensearch/ml/common/spi/tools/Tool.java | 6 + 4 files changed, 299 insertions(+), 90 deletions(-) create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java index 3c16ef18ae..a0fe4bd1fc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java @@ -20,16 +20,12 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; -import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Locale; @@ -52,9 +48,9 @@ public class CatIndexTool implements Tool { private Client client; private String modelId; @Setter - private Parser inputParser; + private Parser inputParser; @Setter - private Parser outputParser; + private Parser outputParser; private ClusterService clusterService; public CatIndexTool(Client client, ClusterService 
clusterService, String modelId) { @@ -62,9 +58,10 @@ public CatIndexTool(Client client, ClusterService clusterService, String modelId this.clusterService = clusterService; this.modelId = modelId; - outputParser = new Parser() { + outputParser = new Parser<>() { @Override public Object parse(Object o) { + @SuppressWarnings("unchecked") List mlModelOutputs = (List) o; return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); } @@ -73,80 +70,98 @@ public Object parse(Object o) { @Override public void run(Map parameters, ActionListener listener) { - List indexList = gson.fromJson(parameters.get("indices"), List.class); - String[] indices = parameters.containsKey("indices")? indexList.toArray(new String[0]) : new String[]{}; - + String[] indices = null; + if (parameters.containsKey("indices")) { + @SuppressWarnings("unchecked") + List indexList = gson.fromJson(parameters.get("indices"), List.class); + indices = indexList.toArray(new String[0]); + } final IndicesOptions indicesOptions = IndicesOptions.lenientExpandHidden(); + final boolean includeUnloadedSegments = parameters.containsKey("include_unloaded_segments") + ? Boolean.parseBoolean(parameters.get("include_unloaded_segments")) + : false; - final IndicesStatsRequest request = new IndicesStatsRequest(); - request.indices(indices); - request.indicesOptions(indicesOptions); - request.all(); - boolean includeUnloadedSegments = parameters.containsKey("include_unloaded_segments")? 
Boolean.parseBoolean(parameters.get("include_unloaded_segments")) : false; - request.includeUnloadedSegments(includeUnloadedSegments); + final IndicesStatsRequest request = new IndicesStatsRequest().indices(indices) + .indicesOptions(indicesOptions) + .all() + .includeUnloadedSegments(includeUnloadedSegments); client.admin().indices().stats(request, ActionListener.wrap(r -> { - try { - Set indexSet = r.getIndices().keySet(); //TODO: handle empty case - XContentBuilder xContentBuilder = XContentBuilder.builder(XContentType.JSON.xContent()); - r.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); - String response = xContentBuilder.toString(); + Set indexSet = r.getIndices().keySet(); + // Handle empty set + if (indexSet.isEmpty()) { + @SuppressWarnings("unchecked") + T empty = (T) ("There were no results searching the indices parameter [" + parameters.get("indices") + "]."); + listener.onResponse(empty); + return; + } + + // Iterate indices in response and map index to stats + Map indexStateMap = new HashMap<>(); + Metadata metadata = clusterService.state().metadata(); + + for (String index : indexSet) { + IndexMetadata indexMetadata = metadata.index(index); + IndexStats indexStats = r.getIndices().get(index); + CommonStats totalStats = indexStats.getTotal(); + CommonStats primaryStats = indexStats.getPrimaries(); + IndexState.IndexStateBuilder indexStateBuilder = IndexState.builder(); + indexStateBuilder.status(indexMetadata.getState().toString()); + indexStateBuilder.index(indexStats.getIndex()); + indexStateBuilder.uuid(indexMetadata.getIndexUUID()); + indexStateBuilder.primaryShard(indexMetadata.getNumberOfShards()); + indexStateBuilder.replicaShard(indexMetadata.getNumberOfReplicas()); + indexStateBuilder.docCount(primaryStats.docs.getCount()); + indexStateBuilder.docDeleted(primaryStats.docs.getDeleted()); + indexStateBuilder.storeSize(totalStats.getStore().size().toString()); + 
indexStateBuilder.primaryStoreSize(primaryStats.getStore().getSize().toString()); + indexStateMap.put(index, indexStateBuilder.build()); + } - Map indexStateMap = new HashMap<>(); - Metadata metadata = clusterService.state().metadata(); + // Get cluster health for each index + final ClusterHealthRequest clusterHealthRequest = new ClusterHealthRequest(indexSet.toArray(new String[0])) + .indicesOptions(indicesOptions) + .local(parameters.containsKey("local") ? Boolean.parseBoolean("local") : false) + .clusterManagerNodeTimeout(DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT); - for (String index : indexSet) { - IndexMetadata indexMetadata = metadata.index(index); + client.admin().cluster().health(clusterHealthRequest, ActionListener.wrap(res -> { + // Add health to index stats + Map indexHealthMap = res.getIndices(); + for (String index : indexHealthMap.keySet()) { IndexStats indexStats = r.getIndices().get(index); - CommonStats totalStats = indexStats.getTotal(); - CommonStats primaryStats = indexStats.getPrimaries(); - IndexState.IndexStateBuilder indexStateBuilder = IndexState.builder(); - indexStateBuilder.status(indexMetadata.getState().toString()); - indexStateBuilder.index(indexStats.getIndex()); - indexStateBuilder.uuid(indexMetadata.getIndexUUID()); - indexStateBuilder.primaryShard(indexMetadata.getNumberOfShards()); - indexStateBuilder.replicaShard(indexMetadata.getNumberOfReplicas()); - indexStateBuilder.docCount(primaryStats.docs.getCount()); - indexStateBuilder.docDeleted(primaryStats.docs.getDeleted()); - indexStateBuilder.storeSize(totalStats.getStore().size().toString()); - indexStateBuilder.primaryStoreSize(primaryStats.getStore().getSize().toString()); - indexStateMap.put(index, indexStateBuilder.build()); - } - - final ClusterHealthRequest clusterHealthRequest = new ClusterHealthRequest(); - clusterHealthRequest.indices(indexSet.toArray(new String[0])); - clusterHealthRequest.indicesOptions(indicesOptions); - boolean local = parameters.containsKey("local")? 
Boolean.parseBoolean("local") : false; - clusterHealthRequest.local(local); - clusterHealthRequest.clusterManagerNodeTimeout(DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT); - - client.admin().cluster().health(clusterHealthRequest, ActionListener.wrap(res-> { - Map indexHealthMap = res.getIndices(); - for (String index : indexHealthMap.keySet()) { - IndexStats indexStats = r.getIndices().get(index); - final ClusterIndexHealth indexHealth = indexHealthMap.get(index); - final String health; - if (indexHealth != null) { - health = indexHealth.getStatus().toString().toLowerCase(Locale.ROOT); - } else if (indexStats != null) { - health = "red*"; - } else { - health = ""; - } - indexStateMap.get(index).setHealth(health); - } - StringBuilder responseBuilder = new StringBuilder("health\tstatus\tindex\tuuid\tpri\trep\tdocs.count\tdocs.deleted\tstore.size\tpri.store.size\n"); - for (String index : indexStateMap.keySet()) { - responseBuilder.append(indexStateMap.get(index).toString()).append("\n"); + final ClusterIndexHealth indexHealth = indexHealthMap.get(index); + final String health; + if (indexHealth != null) { + health = indexHealth.getStatus().toString().toLowerCase(Locale.ROOT); + } else if (indexStats != null) { + health = "red*"; + } else { + health = ""; } - listener.onResponse((T)responseBuilder.toString()); - }, ex->{listener.onFailure(ex);})); - } catch (IOException e) { - listener.onFailure(e); - } - }, e -> { - listener.onFailure(e); - })); + indexStateMap.get(index).setHealth(health); + } + // Prepare output with header row + StringBuilder responseBuilder = new StringBuilder( + "health\tstatus\tindex\tuuid\tpri\trep\tdocs.count\tdocs.deleted\tstore.size\tpri.store.size\n" + ); + // Output a row for each index + for (IndexState state : indexStateMap.values()) { + responseBuilder.append(state.getHealth()).append('\t'); + responseBuilder.append(state.getStatus()).append('\t'); + responseBuilder.append(state.getIndex()).append('\t'); + 
responseBuilder.append(state.getUuid()).append('\t'); + responseBuilder.append(state.getPrimaryShard()).append('\t'); + responseBuilder.append(state.getReplicaShard()).append('\t'); + responseBuilder.append(state.getDocCount()).append('\t'); + responseBuilder.append(state.getDocDeleted()).append('\t'); + responseBuilder.append(state.getStoreSize()).append('\t'); + responseBuilder.append(state.getPrimaryStoreSize()).append('\n'); + } + @SuppressWarnings("unchecked") + T s = (T) responseBuilder.toString(); + listener.onResponse(s); + }, ex -> { listener.onFailure(ex); })); + }, e -> { listener.onFailure(e); })); } @Data @@ -175,21 +190,6 @@ public IndexState(String health, String status, String index, String uuid, Integ this.storeSize = storeSize; this.primaryStoreSize = primaryStoreSize; } - - @Override - public String toString() { - return - health + '\t' + - status + '\t' + - index + '\t' + - uuid + '\t' + - primaryShard + '\t' + - replicaShard + '\t' + - docCount + '\t' + - docDeleted + '\t' + - storeSize + '\t' + - primaryStoreSize; - } } @@ -231,7 +231,7 @@ public void init(Client client, ClusterService clusterService) { @Override public CatIndexTool create(Map map) { - return new CatIndexTool(client, clusterService, (String)map.get("model_id")); + return new CatIndexTool(client, clusterService, (String) map.get("model_id")); } @Override diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java new file mode 100644 index 0000000000..e244823b0a --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java @@ -0,0 +1,202 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; 
+import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.shard.ShardPath; +import org.opensearch.ml.common.spi.tools.Tool; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; +import org.opensearch.action.admin.indices.stats.CommonStats; +import org.opensearch.action.admin.indices.stats.CommonStatsFlags; +import org.opensearch.action.admin.indices.stats.IndexStats; +import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.action.admin.indices.stats.ShardStats; +import org.opensearch.action.admin.indices.stats.IndexStats.IndexStatsBuilder; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.health.ClusterIndexHealth; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.IndexMetadata.State; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.routing.IndexRoutingTable; +import org.opensearch.cluster.routing.IndexShardRoutingTable; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardRoutingState; +import org.opensearch.cluster.routing.TestShardRouting; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.UUIDs; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static 
org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class CatIndexToolTests { + + @Mock + private Client client; + + @Mock + private AdminClient adminClient; + + @Mock + private IndicesAdminClient indicesAdminClient; + + @Mock + private ClusterAdminClient clusterAdminClient; + + @Mock + private ClusterService clusterService; + + @Mock + private ClusterState clusterState; + + @Mock + private Metadata metadata; + + @Mock + private IndicesStatsResponse indicesStatsResponse; + + @Mock + private ClusterHealthResponse clusterHealthResponse; + + @Mock + private IndexMetadata indexMetadata; + + @Mock + private IndexRoutingTable indexRoutingTable; + + private Map indicesParams; + private Map otherParams; + private Map emptyParams; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + when(adminClient.indices()).thenReturn(indicesAdminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + when(client.admin()).thenReturn(adminClient); + + when(indexMetadata.getState()).thenReturn(State.OPEN); + when(metadata.index(any(String.class))).thenReturn(indexMetadata); + when(clusterState.metadata()).thenReturn(metadata); + when(clusterService.state()).thenReturn(clusterState); + + CatIndexTool.Factory.getInstance().init(client, clusterService); + + indicesParams = Map.of("indices", "foo"); + otherParams = Map.of("other", "bar"); + emptyParams = Collections.emptyMap(); + } + + @Test + public void testRunAsyncNoIndices() throws Exception { + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).stats(any(), actionListenerCaptor.capture()); + when(indicesStatsResponse.getIndices()).thenReturn(Collections.emptyMap()); + + Tool tool = CatIndexTool.Factory.getInstance().create(Map.of("model_id", "test")); + final 
CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + tool.run(otherParams, listener); + actionListenerCaptor.getValue().onResponse(indicesStatsResponse); + future.join(); + assertEquals("There were no results searching the indices parameter [null].", future.get()); + } + + @Test + public void testRunAsyncIndexStats() throws Exception { + String indexName = "foo"; + Index index = new Index(indexName, UUIDs.base64UUID()); + + // Setup indices query + @SuppressWarnings("unchecked") + ArgumentCaptor> indicesStatsListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).stats(any(), indicesStatsListenerCaptor.capture()); + + int shardId = 0; + ShardId shId = new ShardId(index, shardId); + Path path = Files.createTempDirectory("temp").resolve("indices").resolve(index.getUUID()).resolve(String.valueOf(shardId)); + ShardPath shardPath = new ShardPath(false, path, path, shId); + ShardRouting routing = TestShardRouting.newShardRouting(shId, "node", true, ShardRoutingState.STARTED); + CommonStats commonStats = new CommonStats(CommonStatsFlags.ALL); + IndexStats fooStats = new IndexStatsBuilder(index.getName(), index.getUUID()).add( + new ShardStats(routing, shardPath, commonStats, null, null, null) + ).build(); + when(indicesStatsResponse.getIndices()).thenReturn(Map.of(indexName, fooStats)); + + // Setup cluster health query + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterHealthListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(clusterAdminClient).health(any(), clusterHealthListenerCaptor.capture()); + + when(indexMetadata.getIndex()).thenReturn(index); + when(indexMetadata.getNumberOfShards()).thenReturn(1); + when(indexMetadata.getNumberOfReplicas()).thenReturn(0); + @SuppressWarnings("unchecked") + Iterator iterator = (Iterator) mock(Iterator.class); + 
when(iterator.hasNext()).thenReturn(false); + when(indexRoutingTable.iterator()).thenReturn(iterator); + ClusterIndexHealth fooHealth = new ClusterIndexHealth(indexMetadata, indexRoutingTable); + when(clusterHealthResponse.getIndices()).thenReturn(Map.of(indexName, fooHealth)); + + // Now make the call + Tool tool = CatIndexTool.Factory.getInstance().create(Map.of("model_id", "test")); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + tool.run(otherParams, listener); + indicesStatsListenerCaptor.getValue().onResponse(indicesStatsResponse); + clusterHealthListenerCaptor.getValue().onResponse(clusterHealthResponse); + future.orTimeout(10, TimeUnit.SECONDS).join(); + String response = future.get(); + assertEquals( + "health\tstatus\tindex\tuuid\tpri\trep\tdocs.count\tdocs.deleted\tstore.size\tpri.store.size\n" + + "red\tOPEN\tfoo\tnull\t1\t0\t0\t0\t0b\t0b\n", + response + ); + } + + @Test + public void testRun() { + Tool tool = CatIndexTool.Factory.getInstance().create(Map.of("model_id", "test")); + // TODO This is not implemented on the interface, need to change this test if/when it is + assertNull(tool.run(emptyParams)); + } + + @Test + public void testTool() { + Tool tool = CatIndexTool.Factory.getInstance().create(Map.of("model_id", "test")); + assertEquals(CatIndexTool.NAME, tool.getName()); + assertTrue(tool.validate(indicesParams)); + assertTrue(tool.validate(otherParams)); + assertFalse(tool.validate(emptyParams)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 463cd8fe69..de3d4ce8ab 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -94,6 +94,7 @@ import 
org.opensearch.ml.common.spi.MLCommonsExtension; import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.Tool.Factory; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction; @@ -818,7 +819,7 @@ public void loadExtensions(ExtensionLoader loader) { } List> toolFactories = extension.getToolFactories(); - for (Tool.Factory toolFactory : toolFactories) { + for (Tool.Factory toolFactory : toolFactories) { ToolAnnotation toolAnnotation = toolFactory.getClass().getDeclaringClass().getAnnotation(ToolAnnotation.class); if (toolAnnotation == null) { throw new IllegalArgumentException( diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java index 05fb90b804..72504444a9 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java +++ b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java @@ -23,6 +23,12 @@ default T run(Map parameters) { return null; }; + /** + * Run tool and return response asynchronously. + * @param parameters input parameters + * @param listener an action listener for the response + * @param <T> The output type + */ default void run(Map parameters, ActionListener listener) {}; /**