diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java index f7539edde2..adacf88a8c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java @@ -5,66 +5,81 @@ package org.opensearch.ml.engine.tools; -import lombok.Builder; -import lombok.Data; import lombok.Getter; import lombok.Setter; -import lombok.extern.log4j.Log4j2; +import org.apache.logging.log4j.util.Strings; import org.opensearch.action.admin.cluster.health.ClusterHealthRequest; +import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; +import org.opensearch.action.admin.cluster.state.ClusterStateRequest; +import org.opensearch.action.admin.cluster.state.ClusterStateResponse; +import org.opensearch.action.admin.indices.settings.get.GetSettingsRequest; +import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; import org.opensearch.action.admin.indices.stats.CommonStats; import org.opensearch.action.admin.indices.stats.IndexStats; import org.opensearch.action.admin.indices.stats.IndicesStatsRequest; +import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.action.support.GroupedActionListener; import org.opensearch.action.support.IndicesOptions; import org.opensearch.client.Client; import org.opensearch.cluster.health.ClusterIndexHealth; import org.opensearch.cluster.metadata.IndexMetadata; -import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.Table; +import org.opensearch.common.Table.Cell; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.index.IndexSettings; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; - -import java.io.IOException; -import java.util.HashMap; +import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Set; +import java.util.Spliterators; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import static org.opensearch.action.support.clustermanager.ClusterManagerNodeRequest.DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; import static org.opensearch.ml.common.utils.StringUtils.gson; -@Log4j2 @ToolAnnotation(CatIndexTool.NAME) public class CatIndexTool implements Tool { public static final String NAME = "CatIndexTool"; + private static final String DEFAULT_DESCRIPTION = "Use this tool to get index information."; - @Setter @Getter - private String alias; - private static String DEFAULT_DESCRIPTION = "Use this tool to get index information."; - @Getter @Setter + @Setter + @Getter + private String name = CatIndexTool.NAME; + @Getter + @Setter private String description = DEFAULT_DESCRIPTION; + @Getter + private String type; + @Getter + private String version; + private Client client; - private String modelId; @Setter - private Parser inputParser; + private Parser inputParser; @Setter - private Parser outputParser; + private Parser outputParser; + @SuppressWarnings("unused") private ClusterService clusterService; - public CatIndexTool(Client client, ClusterService clusterService, String modelId) { + public CatIndexTool(Client client, ClusterService clusterService) { this.client = client; this.clusterService = clusterService; - this.modelId = modelId; - outputParser = new Parser() { + outputParser = new Parser<>() { @Override public Object parse(Object o) { + @SuppressWarnings("unchecked") List mlModelOutputs = (List) o; return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); } @@ -73,144 +88,208 @@ public Object parse(Object o) { @Override public void run(Map parameters, ActionListener listener) { - List indexList = gson.fromJson(parameters.get("indices"), List.class); - String[] indices = parameters.containsKey("indices")? indexList.toArray(new String[0]) : new String[]{}; + // TODO: This logic exactly matches the OpenSearch _cat/indices REST action. If code at + // o.o.rest/action/cat/RestIndicesAction.java changes those changes need to be reflected here + // https://github.com/opensearch-project/ml-commons/pull/1582#issuecomment-1796962876 + @SuppressWarnings("unchecked") + List indexList = parameters.containsKey("indices") + ? gson.fromJson(parameters.get("indices"), List.class) + : Collections.emptyList(); + final String[] indices = indexList.toArray(Strings.EMPTY_ARRAY); - final IndicesOptions indicesOptions = IndicesOptions.lenientExpandHidden(); + final IndicesOptions indicesOptions = IndicesOptions.strictExpand(); + final boolean local = parameters.containsKey("local") ? Boolean.parseBoolean("local") : false; + final TimeValue clusterManagerNodeTimeout = DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; + final boolean includeUnloadedSegments = parameters.containsKey("include_unloaded_segments") + ? Boolean.parseBoolean(parameters.get("include_unloaded_segments")) + : false; - final IndicesStatsRequest request = new IndicesStatsRequest(); - request.indices(indices); - request.indicesOptions(indicesOptions); - request.all(); - boolean includeUnloadedSegments = parameters.containsKey("include_unloaded_segments")? Boolean.parseBoolean(parameters.get("include_unloaded_segments")) : false; - request.includeUnloadedSegments(includeUnloadedSegments); + final ActionListener internalListener = ActionListener.notifyOnce(ActionListener.wrap(table -> { + // Handle empty table + if (table.getRows().isEmpty()) { + @SuppressWarnings("unchecked") + T empty = (T) ("There were no results searching the indices parameter [" + parameters.get("indices") + "]."); + listener.onResponse(empty); + return; + } + StringBuilder sb = new StringBuilder( + // Currently using c.value which is short header matching _cat/indices + // May prefer to use c.attr.get("desc") for full description + table.getHeaders().stream().map(c -> c.value.toString()).collect(Collectors.joining("\t", "", "\n")) + ); + for (List row : table.getRows()) { + sb.append(row.stream().map(c -> c.value == null ? null : c.value.toString()).collect(Collectors.joining("\t", "", "\n"))); + } + @SuppressWarnings("unchecked") + T response = (T) sb.toString(); + listener.onResponse(response); + }, listener::onFailure)); + + sendGetSettingsRequest( + indices, + indicesOptions, + local, + clusterManagerNodeTimeout, + client, + new ActionListener() { + @Override + public void onResponse(final GetSettingsResponse getSettingsResponse) { + final GroupedActionListener groupedListener = createGroupedListener(4, internalListener); + groupedListener.onResponse(getSettingsResponse); - client.admin().indices().stats(request, ActionListener.wrap(r -> { - try { - Set indexSet = r.getIndices().keySet(); //TODO: handle empty case - XContentBuilder xContentBuilder = XContentBuilder.builder(XContentType.JSON.xContent()); - r.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); - String response = xContentBuilder.toString(); - - Map indexStateMap = new HashMap<>(); - Metadata metadata = clusterService.state().metadata(); - - for (String index : indexSet) { - IndexMetadata indexMetadata = metadata.index(index); - IndexStats indexStats = r.getIndices().get(index); - CommonStats totalStats = indexStats.getTotal(); - CommonStats primaryStats = indexStats.getPrimaries(); - IndexState.IndexStateBuilder indexStateBuilder = IndexState.builder(); - indexStateBuilder.status(indexMetadata.getState().toString()); - indexStateBuilder.index(indexStats.getIndex()); - indexStateBuilder.uuid(indexMetadata.getIndexUUID()); - indexStateBuilder.primaryShard(indexMetadata.getNumberOfShards()); - indexStateBuilder.replicaShard(indexMetadata.getNumberOfReplicas()); - indexStateBuilder.docCount(primaryStats.docs.getCount()); - indexStateBuilder.docDeleted(primaryStats.docs.getDeleted()); - indexStateBuilder.storeSize(totalStats.getStore().size().toString()); - indexStateBuilder.primaryStoreSize(primaryStats.getStore().getSize().toString()); - indexStateMap.put(index, indexStateBuilder.build()); + // The list of indices that will be returned is determined by the indices returned from the Get Settings call. + // All the other requests just provide additional detail, and wildcards may be resolved differently depending on the + // type of request in the presence of security plugins (looking at you, ClusterHealthRequest), so + // force the IndicesOptions for all the sub-requests to be as inclusive as possible. + final IndicesOptions subRequestIndicesOptions = IndicesOptions.lenientExpandHidden(); + + sendIndicesStatsRequest( + indices, + subRequestIndicesOptions, + includeUnloadedSegments, + client, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) + ); + sendClusterStateRequest( + indices, + subRequestIndicesOptions, + local, + clusterManagerNodeTimeout, + client, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) + ); + sendClusterHealthRequest( + indices, + subRequestIndicesOptions, + local, + clusterManagerNodeTimeout, + client, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) + ); } - final ClusterHealthRequest clusterHealthRequest = new ClusterHealthRequest(); - clusterHealthRequest.indices(indexSet.toArray(new String[0])); - clusterHealthRequest.indicesOptions(indicesOptions); - boolean local = parameters.containsKey("local")? Boolean.parseBoolean("local") : false; - clusterHealthRequest.local(local); - clusterHealthRequest.clusterManagerNodeTimeout(DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT); - - client.admin().cluster().health(clusterHealthRequest, ActionListener.wrap(res-> { - Map indexHealthMap = res.getIndices(); - for (String index : indexHealthMap.keySet()) { - IndexStats indexStats = r.getIndices().get(index); - final ClusterIndexHealth indexHealth = indexHealthMap.get(index); - final String health; - if (indexHealth != null) { - health = indexHealth.getStatus().toString().toLowerCase(Locale.ROOT); - } else if (indexStats != null) { - health = "red*"; - } else { - health = ""; - } - indexStateMap.get(index).setHealth(health); - } - StringBuilder responseBuilder = new StringBuilder("health\tstatus\tindex\tuuid\tpri\trep\tdocs.count\tdocs.deleted\tstore.size\tpri.store.size\n"); - for (String index : indexStateMap.keySet()) { - responseBuilder.append(indexStateMap.get(index).toString()).append("\n"); - } - listener.onResponse((T)responseBuilder.toString()); - }, ex->{listener.onFailure(ex);})); - } catch (IOException e) { - listener.onFailure(e); + @Override + public void onFailure(final Exception e) { + internalListener.onFailure(e); + } } - }, e -> { - listener.onFailure(e); - })); + ); } - @Override - public String getType() { - return null; + /** + * We're using the Get Settings API here to resolve the authorized indices for the user. + * This is because the Cluster State and Cluster Health APIs do not filter output based + * on index privileges, so they can't be used to determine which indices are authorized + * or not. On top of this, the Indices Stats API cannot be used either to resolve indices + * as it does not provide information for all existing indices (for example recovering + * indices or non replicated closed indices are not reported in indices stats response). + */ + private void sendGetSettingsRequest( + final String[] indices, + final IndicesOptions indicesOptions, + final boolean local, + final TimeValue clusterManagerNodeTimeout, + final Client client, + final ActionListener listener + ) { + final GetSettingsRequest request = new GetSettingsRequest(); + request.indices(indices); + request.indicesOptions(indicesOptions); + request.local(local); + request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + request.names(IndexSettings.INDEX_SEARCH_THROTTLED.getKey()); + + client.admin().indices().getSettings(request, listener); } - @Override - public String getVersion() { - return null; + private void sendClusterStateRequest( + final String[] indices, + final IndicesOptions indicesOptions, + final boolean local, + final TimeValue clusterManagerNodeTimeout, + final Client client, + final ActionListener listener + ) { + + final ClusterStateRequest request = new ClusterStateRequest(); + request.indices(indices); + request.indicesOptions(indicesOptions); + request.local(local); + request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + + client.admin().cluster().state(request, listener); } - @Data - public static class IndexState { - private String health; - private String status; - private String index; - private String uuid; - private Integer primaryShard; - private Integer replicaShard; - private Long docCount; - private Long docDeleted; - private String storeSize; - private String primaryStoreSize; - - @Builder - public IndexState(String health, String status, String index, String uuid, Integer primaryShard, Integer replicaShard, Long docCount, Long docDeleted, String storeSize, String primaryStoreSize) { - this.health = health; - this.status = status; - this.index = index; - this.uuid = uuid; - this.primaryShard = primaryShard; - this.replicaShard = replicaShard; - this.docCount = docCount; - this.docDeleted = docDeleted; - this.storeSize = storeSize; - this.primaryStoreSize = primaryStoreSize; - } + private void sendClusterHealthRequest( + final String[] indices, + final IndicesOptions indicesOptions, + final boolean local, + final TimeValue clusterManagerNodeTimeout, + final Client client, + final ActionListener listener + ) { - @Override - public String toString() { - return - health + '\t' + - status + '\t' + - index + '\t' + - uuid + '\t' + - primaryShard + '\t' + - replicaShard + '\t' + - docCount + '\t' + - docDeleted + '\t' + - storeSize + '\t' + - primaryStoreSize; - } + final ClusterHealthRequest request = new ClusterHealthRequest(); + request.indices(indices); + request.indicesOptions(indicesOptions); + request.local(local); + request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + + client.admin().cluster().health(request, listener); } + private void sendIndicesStatsRequest( + final String[] indices, + final IndicesOptions indicesOptions, + final boolean includeUnloadedSegments, + final Client client, + final ActionListener listener + ) { - @Override - public String getName() { - return CatIndexTool.NAME; + final IndicesStatsRequest request = new IndicesStatsRequest(); + request.indices(indices); + request.indicesOptions(indicesOptions); + request.all(); + request.includeUnloadedSegments(includeUnloadedSegments); + + client.admin().indices().stats(request, listener); } - @Override - public void setName(String s) { + private GroupedActionListener createGroupedListener(final int size, final ActionListener
listener) { + return new GroupedActionListener<>(new ActionListener>() { + @Override + public void onResponse(final Collection responses) { + try { + GetSettingsResponse settingsResponse = extractResponse(responses, GetSettingsResponse.class); + Map indicesSettings = StreamSupport.stream( + Spliterators.spliterator(settingsResponse.getIndexToSettings().entrySet(), 0), + false + ).collect(Collectors.toMap(cursor -> cursor.getKey(), cursor -> cursor.getValue())); + ClusterStateResponse stateResponse = extractResponse(responses, ClusterStateResponse.class); + Map indicesStates = StreamSupport.stream( + stateResponse.getState().getMetadata().spliterator(), + false + ).collect(Collectors.toMap(indexMetadata -> indexMetadata.getIndex().getName(), Function.identity())); + + ClusterHealthResponse healthResponse = extractResponse(responses, ClusterHealthResponse.class); + Map indicesHealths = healthResponse.getIndices(); + + IndicesStatsResponse statsResponse = extractResponse(responses, IndicesStatsResponse.class); + Map indicesStats = statsResponse.getIndices(); + + Table responseTable = buildTable(indicesSettings, indicesHealths, indicesStats, indicesStates); + listener.onResponse(responseTable); + } catch (Exception e) { + onFailure(e); + } + } + + @Override + public void onFailure(final Exception e) { + listener.onFailure(e); + } + }, size); } @Override @@ -221,11 +300,18 @@ public boolean validate(Map parameters) { return true; } + /** + * Factory for the {@link CatIndexTool} + */ public static class Factory implements Tool.Factory { private Client client; private ClusterService clusterService; private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ public static Factory getInstance() { if (INSTANCE != null) { return INSTANCE; @@ -239,6 +325,11 @@ public static Factory getInstance() { } } + /** + * Initialize this factory + * @param client The OpenSearch client + * @param clusterService The OpenSearch cluster service + */ public void init(Client client, ClusterService clusterService) { this.client = client; this.clusterService = clusterService; @@ -246,7 +337,7 @@ public void init(Client client, ClusterService clusterService) { @Override public CatIndexTool create(Map map) { - return new CatIndexTool(client, clusterService, (String)map.get("model_id")); + return new CatIndexTool(client, clusterService); } @Override @@ -254,4 +345,88 @@ public String getDefaultDescription() { return DEFAULT_DESCRIPTION; } } -} \ No newline at end of file + + private Table getTableWithHeader() { + Table table = new Table(); + table.startHeaders(); + // First param is cell.value which is currently returned + // Second param is cell.attr we may want to use attr.desc in the future + table.addCell("health", "alias:h;desc:current health status"); + table.addCell("status", "alias:s;desc:open/close status"); + table.addCell("index", "alias:i,idx;desc:index name"); + table.addCell("uuid", "alias:id,uuid;desc:index uuid"); + table.addCell("pri", "alias:p,shards.primary,shardsPrimary;text-align:right;desc:number of primary shards"); + table.addCell("rep", "alias:r,shards.replica,shardsReplica;text-align:right;desc:number of replica shards"); + table.addCell("docs.count", "alias:dc,docsCount;text-align:right;desc:available docs"); + table.addCell("docs.deleted", "alias:dd,docsDeleted;text-align:right;desc:deleted docs"); + table.addCell("store.size", "sibling:pri;alias:ss,storeSize;text-align:right;desc:store size of primaries & replicas"); + table.addCell("pri.store.size", "text-align:right;desc:store size of primaries"); + // Above includes all the default fields for cat indices. See RestIndicesAction for a lot more that could be included. + table.endHeaders(); + return table; + } + + private Table buildTable( + final Map indicesSettings, + final Map indicesHealths, + final Map indicesStats, + final Map indicesMetadatas + ) { + final Table table = getTableWithHeader(); + + indicesSettings.forEach((indexName, settings) -> { + if (indicesMetadatas.containsKey(indexName) == false) { + // the index exists in the Get Indices response but is not present in the cluster state: + // it is likely that the index was deleted in the meanwhile, so we ignore it. + return; + } + + final IndexMetadata indexMetadata = indicesMetadatas.get(indexName); + final IndexMetadata.State indexState = indexMetadata.getState(); + final IndexStats indexStats = indicesStats.get(indexName); + + final String health; + final ClusterIndexHealth indexHealth = indicesHealths.get(indexName); + if (indexHealth != null) { + health = indexHealth.getStatus().toString().toLowerCase(Locale.ROOT); + } else if (indexStats != null) { + health = "red*"; + } else { + health = ""; + } + + final CommonStats primaryStats; + final CommonStats totalStats; + + if (indexStats == null || indexState == IndexMetadata.State.CLOSE) { + primaryStats = new CommonStats(); + totalStats = new CommonStats(); + } else { + primaryStats = indexStats.getPrimaries(); + totalStats = indexStats.getTotal(); + } + table.startRow(); + table.addCell(health); + table.addCell(indexState.toString().toLowerCase(Locale.ROOT)); + table.addCell(indexName); + table.addCell(indexMetadata.getIndexUUID()); + table.addCell(indexHealth == null ? null : indexHealth.getNumberOfShards()); + table.addCell(indexHealth == null ? null : indexHealth.getNumberOfReplicas()); + + table.addCell(primaryStats.getDocs() == null ? null : primaryStats.getDocs().getCount()); + table.addCell(primaryStats.getDocs() == null ? null : primaryStats.getDocs().getDeleted()); + + table.addCell(totalStats.getStore() == null ? null : totalStats.getStore().size()); + table.addCell(primaryStats.getStore() == null ? null : primaryStats.getStore().size()); + + table.endRow(); + }); + + return table; + } + + @SuppressWarnings("unchecked") + private static A extractResponse(final Collection responses, Class c) { + return (A) responses.stream().filter(c::isInstance).findFirst().get(); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java new file mode 100644 index 0000000000..b696ecec73 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java @@ -0,0 +1,246 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.shard.ShardPath; +import org.opensearch.ml.common.spi.tools.Tool; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.opensearch.Version; +import org.opensearch.action.admin.cluster.health.ClusterHealthResponse; +import org.opensearch.action.admin.cluster.state.ClusterStateResponse; +import org.opensearch.action.admin.indices.settings.get.GetSettingsResponse; +import org.opensearch.action.admin.indices.stats.CommonStats; +import org.opensearch.action.admin.indices.stats.CommonStatsFlags; +import org.opensearch.action.admin.indices.stats.IndexStats; +import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.action.admin.indices.stats.ShardStats; +import org.opensearch.action.admin.indices.stats.IndexStats.IndexStatsBuilder; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.health.ClusterIndexHealth; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.IndexMetadata.State; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.routing.IndexRoutingTable; +import org.opensearch.cluster.routing.IndexShardRoutingTable; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardRoutingState; +import org.opensearch.cluster.routing.TestShardRouting; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.UUIDs; +import org.opensearch.common.settings.Settings; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class CatIndexToolTests { + + @Mock + private Client client; + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private ClusterAdminClient clusterAdminClient; + @Mock + private ClusterService clusterService; + @Mock + private ClusterState clusterState; + @Mock + private Metadata metadata; + @Mock + private GetSettingsResponse getSettingsResponse; + @Mock + private IndicesStatsResponse indicesStatsResponse; + @Mock + private ClusterStateResponse clusterStateResponse; + @Mock + private ClusterHealthResponse clusterHealthResponse; + @Mock + private IndexMetadata indexMetadata; + @Mock + private IndexRoutingTable indexRoutingTable; + + private Map indicesParams; + private Map otherParams; + private Map emptyParams; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + when(adminClient.indices()).thenReturn(indicesAdminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + when(client.admin()).thenReturn(adminClient); + + when(indexMetadata.getState()).thenReturn(State.OPEN); + when(indexMetadata.getCreationVersion()).thenReturn(Version.CURRENT); + + when(metadata.index(any(String.class))).thenReturn(indexMetadata); + when(clusterState.metadata()).thenReturn(metadata); + when(clusterService.state()).thenReturn(clusterState); + + CatIndexTool.Factory.getInstance().init(client, clusterService); + + indicesParams = Map.of("index", "[\"foo\"]"); + otherParams = Map.of("other", "[\"bar\"]"); + emptyParams = Collections.emptyMap(); + } + + @Test + public void testRunAsyncNoIndices() throws Exception { + @SuppressWarnings("unchecked") + ArgumentCaptor> settingsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).getSettings(any(), settingsActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> statsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).stats(any(), statsActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterStateActionListenerCaptor = ArgumentCaptor.forClass( + ActionListener.class + ); + doNothing().when(clusterAdminClient).state(any(), clusterStateActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterHealthActionListenerCaptor = ArgumentCaptor.forClass( + ActionListener.class + ); + doNothing().when(clusterAdminClient).health(any(), clusterHealthActionListenerCaptor.capture()); + + when(getSettingsResponse.getIndexToSettings()).thenReturn(Collections.emptyMap()); + when(indicesStatsResponse.getIndices()).thenReturn(Collections.emptyMap()); + when(clusterStateResponse.getState()).thenReturn(clusterState); + when(clusterState.getMetadata()).thenReturn(metadata); + when(metadata.spliterator()).thenReturn(Arrays.spliterator(new IndexMetadata[0])); + + when(clusterHealthResponse.getIndices()).thenReturn(Collections.emptyMap()); + + Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + tool.run(otherParams, listener); + settingsActionListenerCaptor.getValue().onResponse(getSettingsResponse); + statsActionListenerCaptor.getValue().onResponse(indicesStatsResponse); + clusterStateActionListenerCaptor.getValue().onResponse(clusterStateResponse); + clusterHealthActionListenerCaptor.getValue().onResponse(clusterHealthResponse); + + future.join(); + assertEquals("There were no results searching the indices parameter [null].", future.get()); + } + + @Test + public void testRunAsyncIndexStats() throws Exception { + String indexName = "foo"; + Index index = new Index(indexName, UUIDs.base64UUID()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> settingsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).getSettings(any(), settingsActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> statsActionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doNothing().when(indicesAdminClient).stats(any(), statsActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterStateActionListenerCaptor = ArgumentCaptor.forClass( + ActionListener.class + ); + doNothing().when(clusterAdminClient).state(any(), clusterStateActionListenerCaptor.capture()); + + @SuppressWarnings("unchecked") + ArgumentCaptor> clusterHealthActionListenerCaptor = ArgumentCaptor.forClass( + ActionListener.class + ); + doNothing().when(clusterAdminClient).health(any(), clusterHealthActionListenerCaptor.capture()); + + when(getSettingsResponse.getIndexToSettings()).thenReturn(Map.of("foo", Settings.EMPTY)); + + int shardId = 0; + ShardId shId = new ShardId(index, shardId); + Path path = Files.createTempDirectory("temp").resolve("indices").resolve(index.getUUID()).resolve(String.valueOf(shardId)); + ShardPath shardPath = new ShardPath(false, path, path, shId); + ShardRouting routing = TestShardRouting.newShardRouting(shId, "node", true, ShardRoutingState.STARTED); + CommonStats commonStats = new CommonStats(CommonStatsFlags.ALL); + IndexStats fooStats = new IndexStatsBuilder(index.getName(), index.getUUID()).add( + new ShardStats(routing, shardPath, commonStats, null, null, null) + ).build(); + when(indicesStatsResponse.getIndices()).thenReturn(Map.of(indexName, fooStats)); + + when(indexMetadata.getIndex()).thenReturn(index); + when(indexMetadata.getNumberOfShards()).thenReturn(5); + when(indexMetadata.getNumberOfReplicas()).thenReturn(1); + when(clusterStateResponse.getState()).thenReturn(clusterState); + when(clusterState.getMetadata()).thenReturn(metadata); + when(metadata.spliterator()).thenReturn( + Arrays.spliterator(new IndexMetadata[] { indexMetadata }) + ); + @SuppressWarnings("unchecked") + Iterator iterator = (Iterator) mock(Iterator.class); + when(iterator.hasNext()).thenReturn(false); + when(indexRoutingTable.iterator()).thenReturn(iterator); + ClusterIndexHealth fooHealth = new ClusterIndexHealth(indexMetadata, indexRoutingTable); + when(clusterHealthResponse.getIndices()).thenReturn(Map.of(indexName, fooHealth)); + + // Now make the call + Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + tool.run(otherParams, listener); + settingsActionListenerCaptor.getValue().onResponse(getSettingsResponse); + statsActionListenerCaptor.getValue().onResponse(indicesStatsResponse); + clusterStateActionListenerCaptor.getValue().onResponse(clusterStateResponse); + clusterHealthActionListenerCaptor.getValue().onResponse(clusterHealthResponse); + + future.orTimeout(10, TimeUnit.SECONDS).join(); + String response = future.get(); + String[] responseRows = response.trim().split("\\n"); + + assertEquals(2, responseRows.length); + String header = responseRows[0]; + String fooRow = responseRows[1]; + assertEquals(header.split("\\t").length, fooRow.split("\\t").length); + assertEquals("health\tstatus\tindex\tuuid\tpri\trep\tdocs.count\tdocs.deleted\tstore.size\tpri.store.size", header); + assertEquals("red\topen\tfoo\tnull\t5\t1\t0\t0\t0b\t0b", fooRow); + } + + @Test + public void testTool() { + Tool tool = CatIndexTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(CatIndexTool.NAME, tool.getName()); + assertTrue(tool.validate(indicesParams)); + assertTrue(tool.validate(otherParams)); + assertFalse(tool.validate(emptyParams)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 11c329d7f8..0b12d607a4 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -809,7 +809,7 @@ public void loadExtensions(ExtensionLoader loader) { externalToolFactories = new HashMap<>(); for (MLCommonsExtension extension : loader.loadExtensions(MLCommonsExtension.class)) { List> toolFactories = extension.getToolFactories(); - for (Tool.Factory toolFactory : toolFactories) { + for (Tool.Factory toolFactory : toolFactories) { ToolAnnotation toolAnnotation = toolFactory.getClass().getDeclaringClass().getAnnotation(ToolAnnotation.class); if (toolAnnotation == null) { throw new IllegalArgumentException( diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Parser.java b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Parser.java index c621cf7467..09bb1994d5 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Parser.java +++ b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Parser.java @@ -14,8 +14,8 @@ public interface Parser { /** * Parse input. - * @param input - * @return output + * @param input the parser input + * @return output the parser output */ T parse(S input); } diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java index 2e2c7745ad..3c58731c20 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java +++ b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java @@ -23,60 +23,66 @@ default T run(Map parameters) { return null; }; + /** + * Run tool and return response asynchronously. + * @param parameters input parameters + * @param listener an action listener for the response + * @param The output type + */ default void run(Map parameters, ActionListener listener) {}; /** * Set input parser. - * @param parser + * @param parser the parser to set */ default void setInputParser(Parser parser) {}; /** * Set output parser. - * @param parser + * @param parser the parser to set */ default void setOutputParser(Parser parser) {}; /** * Get tool type mapping to the run function. - * @return + * @return the tool type mapping */ String getType(); /** * Get tool version. - * @return + * @return the tool version */ String getVersion(); /** * Get tool name which is displayed in prompt. - * @return + * @return the tool name */ String getName(); /** * Set tool name which is displayed in prompt. - * @param name + * @param name the tool name */ void setName(String name); /** * Get tool description. - * @return + * @return the tool description */ String getDescription(); /** * Set tool description. - * @param description + * @param description the description to set */ void setDescription(String description); /** * Validate if the input is good. * @param parameters input parameters - * @return + * @return true if the input is valid */ boolean validate(Map parameters); @@ -84,8 +90,8 @@ default T run(Map parameters) { * Check if should end the whole CoT immediately. * For example, if some critical error detected like high memory pressure, * the tool may end the whole CoT process by returning true. - * @param input - * @param toolParameters + * @param input tool input string + * @param toolParameters map of input parameters * @return true as a signal to CoT to end the chain, false to continue CoT */ default boolean end(String input, Map toolParameters) { @@ -105,7 +111,18 @@ default boolean useOriginalInput() { * @param The subclass this factory produces */ interface Factory { + /** + * Create an instance of this tool. + * + * @param params Parameters for the tool + * @return an instance of this tool + */ T create(Map params); + + /** + * Get the default description of this tool. + * @return the default description + */ String getDefaultDescription(); } }