diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java index 0a45981b9c..8a2d0fafdb 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java @@ -10,10 +10,14 @@ import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; +import java.util.Queue; import java.util.Spliterators; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.stream.Collectors; @@ -30,6 +34,9 @@ import org.opensearch.action.admin.indices.stats.IndexStats; import org.opensearch.action.admin.indices.stats.IndicesStatsRequest; import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; +import org.opensearch.action.pagination.IndexPaginationStrategy; +import org.opensearch.action.pagination.PageParams; +import org.opensearch.action.pagination.PageToken; import org.opensearch.action.support.GroupedActionListener; import org.opensearch.action.support.IndicesOptions; import org.opensearch.client.Client; @@ -38,6 +45,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Table; import org.opensearch.common.Table.Cell; +import org.opensearch.common.collect.Tuple; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; @@ -50,10 +58,15 @@ import lombok.Getter; import lombok.Setter; +import lombok.extern.slf4j.Slf4j; +@Slf4j @ToolAnnotation(CatIndexTool.TYPE) public class CatIndexTool implements Tool { public static final String TYPE = "CatIndexTool"; + // This needs to be changed once it's changed in opensearch core in RestIndicesListAction. + private static final int MAX_SUPPORTED_LIST_INDICES_PAGE_SIZE = 5000; + private static final int DEFAULT_PAGE_SIZE = 1; private static final String DEFAULT_DESCRIPTION = String .join( " ", @@ -106,13 +119,13 @@ public void run(Map parameters, ActionListener listener) final String[] indices = indexList.toArray(Strings.EMPTY_ARRAY); final IndicesOptions indicesOptions = IndicesOptions.strictExpand(); - final boolean local = parameters.containsKey("local") ? Boolean.parseBoolean("local") : false; - final TimeValue clusterManagerNodeTimeout = DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT; + final boolean local = parameters.containsKey("local") && Boolean.parseBoolean(parameters.get("local")); final boolean includeUnloadedSegments = Boolean.parseBoolean(parameters.get("include_unloaded_segments")); + final PageParams pageParams = new PageParams(null, PageParams.PARAM_ASC_SORT_VALUE, DEFAULT_PAGE_SIZE); final ActionListener internalListener = ActionListener.notifyOnce(ActionListener.wrap(table -> { // Handle empty table - if (table.getRows().isEmpty()) { + if (table == null || table.getRows().isEmpty()) { @SuppressWarnings("unchecked") T empty = (T) ("There were no results searching the indices parameter [" + parameters.get("indices") + "]."); listener.onResponse(empty); @@ -131,57 +144,150 @@ public void run(Map parameters, ActionListener listener) listener.onResponse(response); }, listener::onFailure)); - sendGetSettingsRequest( + fetchClusterInfoAndPages( indices, + local, + includeUnloadedSegments, + pageParams, indicesOptions, + new ConcurrentLinkedQueue<>(), + internalListener + ); + } + + private void fetchClusterInfoAndPages( + String[] indices, + boolean local, + boolean includeUnloadedSegments, + PageParams pageParams, + IndicesOptions indicesOptions, + Queue> pageResults, + ActionListener
originalListener + ) { + // First fetch metadata like index setting and cluster states and then fetch index details in batches to save efforts. + sendGetSettingsRequest(indices, indicesOptions, local, client, new ActionListener<>() { + @Override + public void onResponse(final GetSettingsResponse getSettingsResponse) { + // The list of indices that will be returned is determined by the indices returned from the Get Settings call. + // All the other requests just provide additional detail, and wildcards may be resolved differently depending on the + // type of request in the presence of security plugins (looking at you, ClusterHealthRequest), so + // force the IndicesOptions for all the sub-requests to be as inclusive as possible. + final IndicesOptions subRequestIndicesOptions = IndicesOptions.lenientExpandHidden(); + // Indices that were successfully resolved during the get settings request might be deleted when the + // subsequent cluster state, cluster health and indices stats requests execute. We have to distinguish two cases: + // 1) the deleted index was explicitly passed as parameter to the /_cat/indices request. In this case we + // want the subsequent requests to fail. + // 2) the deleted index was resolved as part of a wildcard or _all. In this case, we want the subsequent + // requests not to fail on the deleted index (as we want to ignore wildcards that cannot be resolved). + // This behavior can be ensured by letting the cluster state, cluster health and indices stats requests + // re-resolve the index names with the same indices options that we used for the initial cluster state + // request (strictExpand). + sendClusterStateRequest(indices, subRequestIndicesOptions, local, client, new ActionListener<>() { + @Override + public void onResponse(ClusterStateResponse clusterStateResponse) { + // Starts to fetch index details here, if a batch fails build whatever we have and return. + fetchPages( + indices, + local, + DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT, + includeUnloadedSegments, + pageParams, + pageResults, + clusterStateResponse, + getSettingsResponse, + subRequestIndicesOptions, + originalListener + ); + } + + @Override + public void onFailure(final Exception e) { + originalListener.onFailure(e); + } + }); + } + + @Override + public void onFailure(final Exception e) { + originalListener.onFailure(e); + } + }); + } + + private void fetchPages( + String[] indices, + boolean local, + TimeValue clusterManagerNodeTimeout, + boolean includeUnloadedSegments, + PageParams pageParams, + Queue> pageResults, + ClusterStateResponse clusterStateResponse, + GetSettingsResponse getSettingsResponse, + IndicesOptions subRequestIndicesOptions, + ActionListener
originalListener + ) { + final ActionListener iterativeListener = ActionListener.wrap(r -> { + // when previous response returns, build next request with response and invoke again. + PageParams nextPageParams = new PageParams(r.getNextToken(), pageParams.getSort(), pageParams.getSize()); + // when next page doesn't exist or reaches max supported page size, return. + if (r.getNextToken() == null || pageResults.size() >= MAX_SUPPORTED_LIST_INDICES_PAGE_SIZE) { + Table table = buildTable(clusterStateResponse, getSettingsResponse, pageResults); + originalListener.onResponse(table); + } else { + fetchPages( + indices, + local, + clusterManagerNodeTimeout, + includeUnloadedSegments, + nextPageParams, + pageResults, + clusterStateResponse, + getSettingsResponse, + subRequestIndicesOptions, + originalListener + ); + } + }, e -> { + log.error("Failed to fetch index info for page: {}", pageParams.getRequestedToken()); + // Do not throw the exception, just return whatever we have. + originalListener.onResponse(buildTable(clusterStateResponse, getSettingsResponse, pageResults)); + }); + IndexPaginationStrategy paginationStrategy = getPaginationStrategy(pageParams, clusterStateResponse); + // For non-paginated queries, indicesToBeQueried would be same as indices retrieved from + // rest request and unresolved, while for paginated queries, it would be a list of indices + // already resolved by ClusterStateRequest and to be displayed in a page. + final String[] indicesToBeQueried = Objects.isNull(paginationStrategy) + ? indices + : paginationStrategy.getRequestedEntities().toArray(new String[0]); + // After the group listener returns, one page complete and prepare for next page. + final GroupedActionListener groupedListener = createGroupedListener( + pageResults, + paginationStrategy.getResponseToken(), + iterativeListener + ); + + sendIndicesStatsRequest( + indicesToBeQueried, + subRequestIndicesOptions, + includeUnloadedSegments, + client, + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) + ); + + sendClusterHealthRequest( + indicesToBeQueried, + subRequestIndicesOptions, local, clusterManagerNodeTimeout, client, - new ActionListener() { - @Override - public void onResponse(final GetSettingsResponse getSettingsResponse) { - final GroupedActionListener groupedListener = createGroupedListener(4, internalListener); - groupedListener.onResponse(getSettingsResponse); - - // The list of indices that will be returned is determined by the indices returned from the Get Settings call. - // All the other requests just provide additional detail, and wildcards may be resolved differently depending on the - // type of request in the presence of security plugins (looking at you, ClusterHealthRequest), so - // force the IndicesOptions for all the sub-requests to be as inclusive as possible. - final IndicesOptions subRequestIndicesOptions = IndicesOptions.lenientExpandHidden(); - - sendIndicesStatsRequest( - indices, - subRequestIndicesOptions, - includeUnloadedSegments, - client, - ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) - ); - sendClusterStateRequest( - indices, - subRequestIndicesOptions, - local, - clusterManagerNodeTimeout, - client, - ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) - ); - sendClusterHealthRequest( - indices, - subRequestIndicesOptions, - local, - clusterManagerNodeTimeout, - client, - ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) - ); - } - - @Override - public void onFailure(final Exception e) { - internalListener.onFailure(e); - } - } + ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure) ); } + protected IndexPaginationStrategy getPaginationStrategy(PageParams pageParams, ClusterStateResponse clusterStateResponse) { + return new IndexPaginationStrategy(pageParams, clusterStateResponse.getState()); + } + @Override public String getType() { return TYPE; @@ -199,7 +305,6 @@ private void sendGetSettingsRequest( final String[] indices, final IndicesOptions indicesOptions, final boolean local, - final TimeValue clusterManagerNodeTimeout, final Client client, final ActionListener listener ) { @@ -207,7 +312,7 @@ private void sendGetSettingsRequest( request.indices(indices); request.indicesOptions(indicesOptions); request.local(local); - request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + request.clusterManagerNodeTimeout(DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT); request.names(IndexSettings.INDEX_SEARCH_THROTTLED.getKey()); client.admin().indices().getSettings(request, listener); @@ -217,7 +322,6 @@ private void sendClusterStateRequest( final String[] indices, final IndicesOptions indicesOptions, final boolean local, - final TimeValue clusterManagerNodeTimeout, final Client client, final ActionListener listener ) { @@ -226,7 +330,7 @@ private void sendClusterStateRequest( request.indices(indices); request.indicesOptions(indicesOptions); request.local(local); - request.clusterManagerNodeTimeout(clusterManagerNodeTimeout); + request.clusterManagerNodeTimeout(DEFAULT_CLUSTER_MANAGER_NODE_TIMEOUT); client.admin().cluster().state(request, listener); } @@ -266,39 +370,24 @@ private void sendIndicesStatsRequest( client.admin().indices().stats(request, listener); } - private GroupedActionListener createGroupedListener(final int size, final ActionListener
listener) { - return new GroupedActionListener<>(new ActionListener>() { + // group listener only accept two action response: IndicesStatsResponse and ClusterHealthResponse + private GroupedActionListener createGroupedListener( + final Queue> pageResults, + final PageToken pageToken, + final ActionListener listener + ) { + return new GroupedActionListener<>(new ActionListener<>() { @Override public void onResponse(final Collection responses) { - try { - GetSettingsResponse settingsResponse = extractResponse(responses, GetSettingsResponse.class); - Map indicesSettings = StreamSupport - .stream(Spliterators.spliterator(settingsResponse.getIndexToSettings().entrySet(), 0), false) - .collect(Collectors.toMap(cursor -> cursor.getKey(), cursor -> cursor.getValue())); - - ClusterStateResponse stateResponse = extractResponse(responses, ClusterStateResponse.class); - Map indicesStates = StreamSupport - .stream(stateResponse.getState().getMetadata().spliterator(), false) - .collect(Collectors.toMap(indexMetadata -> indexMetadata.getIndex().getName(), Function.identity())); - - ClusterHealthResponse healthResponse = extractResponse(responses, ClusterHealthResponse.class); - Map indicesHealths = healthResponse.getIndices(); - - IndicesStatsResponse statsResponse = extractResponse(responses, IndicesStatsResponse.class); - Map indicesStats = statsResponse.getIndices(); - - Table responseTable = buildTable(indicesSettings, indicesHealths, indicesStats, indicesStates); - listener.onResponse(responseTable); - } catch (Exception e) { - onFailure(e); - } + pageResults.add(responses); + listener.onResponse(pageToken); } @Override public void onFailure(final Exception e) { listener.onFailure(e); } - }, size); + }, 2); } @Override @@ -396,21 +485,34 @@ private Table getTableWithHeader() { } private Table buildTable( - final Map indicesSettings, - final Map indicesHealths, - final Map indicesStats, - final Map indicesMetadatas + ClusterStateResponse clusterStateResponse, + GetSettingsResponse getSettingsResponse, + Queue> responses ) { + if (responses == null || responses.isEmpty()) { + return null; + } + Tuple, Map> tuple = aggregateResults(responses); final Table table = getTableWithHeader(); AtomicInteger rowNum = new AtomicInteger(0); + Map indicesSettings = StreamSupport + .stream(Spliterators.spliterator(getSettingsResponse.getIndexToSettings().entrySet(), 0), false) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + Map indicesStates = StreamSupport + .stream(clusterStateResponse.getState().getMetadata().spliterator(), false) + .collect(Collectors.toMap(indexMetadata -> indexMetadata.getIndex().getName(), Function.identity())); + + Map indicesHealths = tuple.v2(); + Map indicesStats = tuple.v1(); indicesSettings.forEach((indexName, settings) -> { - if (!indicesMetadatas.containsKey(indexName)) { + if (!indicesStates.containsKey(indexName)) { // the index exists in the Get Indices response but is not present in the cluster state: // it is likely that the index was deleted in the meanwhile, so we ignore it. return; } - final IndexMetadata indexMetadata = indicesMetadatas.get(indexName); + final IndexMetadata indexMetadata = indicesStates.get(indexName); final IndexMetadata.State indexState = indexMetadata.getState(); final IndexStats indexStats = indicesStats.get(indexName); @@ -448,15 +550,28 @@ private Table buildTable( table.addCell(totalStats.getStore() == null ? null : totalStats.getStore().size()); table.addCell(primaryStats.getStore() == null ? null : primaryStats.getStore().size()); - table.endRow(); }); - return table; } - @SuppressWarnings("unchecked") - private static A extractResponse(final Collection responses, Class c) { - return (A) responses.stream().filter(c::isInstance).findFirst().get(); + private Tuple, Map> aggregateResults(Queue> responses) { + // Each batch produces a collection of action response, aggregate them together to build table easier. + Map indexStatsMap = new HashMap<>(); + Map clusterIndexHealthMap = new HashMap<>(); + for (Collection response : responses) { + if (response != null && !response.isEmpty()) { + response.forEach(x -> { + if (x instanceof IndicesStatsResponse) { + indexStatsMap.putAll(((IndicesStatsResponse) x).getIndices()); + } else if (x instanceof ClusterHealthResponse) { + clusterIndexHealthMap.putAll(((ClusterHealthResponse) x).getIndices()); + } else { + throw new IllegalStateException("Unexpected action response type: " + x.getClass().getName()); + } + }); + } + } + return new Tuple<>(indexStatsMap, clusterIndexHealthMap); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java index 11b29070f3..d981cc40a1 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/CatIndexToolTests.java @@ -150,8 +150,8 @@ public void testRunAsyncNoIndices() throws Exception { tool.run(otherParams, listener); settingsActionListenerCaptor.getValue().onResponse(getSettingsResponse); - statsActionListenerCaptor.getValue().onResponse(indicesStatsResponse); clusterStateActionListenerCaptor.getValue().onResponse(clusterStateResponse); + statsActionListenerCaptor.getValue().onResponse(indicesStatsResponse); clusterHealthActionListenerCaptor.getValue().onResponse(clusterHealthResponse); future.join(); @@ -214,8 +214,8 @@ public void testRunAsyncIndexStats() throws Exception { tool.run(otherParams, listener); settingsActionListenerCaptor.getValue().onResponse(getSettingsResponse); - statsActionListenerCaptor.getValue().onResponse(indicesStatsResponse); clusterStateActionListenerCaptor.getValue().onResponse(clusterStateResponse); + statsActionListenerCaptor.getValue().onResponse(indicesStatsResponse); clusterHealthActionListenerCaptor.getValue().onResponse(clusterHealthResponse); future.orTimeout(10, TimeUnit.SECONDS).join(); diff --git a/plugin/src/test/java/org/opensearch/ml/tools/CatIndexToolIT.java b/plugin/src/test/java/org/opensearch/ml/tools/CatIndexToolIT.java new file mode 100644 index 0000000000..7dc21694e0 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/tools/CatIndexToolIT.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.tools; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.ml.rest.RestBaseAgentToolsIT; +import org.opensearch.ml.utils.TestHelper; + +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class CatIndexToolIT extends RestBaseAgentToolsIT { + private String agentId; + private final String question = "{\"parameters\":{\"question\":\"please help list all the index status in the current cluster?\"}}"; + + @Before + public void setUpCluster() throws Exception { + registerCatIndexFlowAgent(); + } + + private List createIndices(int count) throws IOException { + List indices = new ArrayList<>(); + for (int i = 0; i < count; i++) { + String indexName = "test" + i; + createIndex(indexName, Settings.EMPTY); + indices.add(indexName); + } + return indices; + } + + private void registerCatIndexFlowAgent() throws Exception { + String requestBody = Files + .readString( + Path.of(this.getClass().getClassLoader().getResource("org/opensearch/ml/tools/CatIndexAgentRegistration.json").toURI()) + ); + registerMLAgent(client(), requestBody, response -> agentId = (String) response.get("agent_id")); + } + + public void testCatIndexWithFewIndices() throws IOException { + List indices = createIndices(10); + Response response = TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, question, null); + String responseStr = TestHelper.httpEntityToString(response.getEntity()); + String toolOutput = extractResult(responseStr); + String[] actualLines = toolOutput.split("\\n"); + // plus 2 as there are one line of header and one line of system agent index, but sometimes the ml-config index will be created + // then there will be one more line. + assert actualLines.length == indices.size() + 2 || actualLines.length == indices.size() + 3; + for (String index : indices) { + assert Objects.requireNonNull(toolOutput).contains(index); + } + } + + public void testCatIndexWithMoreThan100Indices() throws IOException { + List indices = createIndices(101); + Response response = TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, question, null); + String responseStr = TestHelper.httpEntityToString(response.getEntity()); + String toolOutput = extractResult(responseStr); + String[] actualLines = toolOutput.split("\\n"); + assert actualLines.length == indices.size() + 2 || actualLines.length == indices.size() + 3; + for (String index : indices) { + assert Objects.requireNonNull(toolOutput).contains(index); + } + } + + private String extractResult(String responseStr) { + JsonArray output = JsonParser + .parseString(responseStr) + .getAsJsonObject() + .get("inference_results") + .getAsJsonArray() + .get(0) + .getAsJsonObject() + .get("output") + .getAsJsonArray(); + for (JsonElement element : output) { + if ("response".equals(element.getAsJsonObject().get("name").getAsString())) { + return element.getAsJsonObject().get("result").getAsString(); + } + } + return null; + } +} diff --git a/plugin/src/test/resources/org/opensearch/ml/tools/CatIndexAgentRegistration.json b/plugin/src/test/resources/org/opensearch/ml/tools/CatIndexAgentRegistration.json new file mode 100644 index 0000000000..e2b41cafad --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/tools/CatIndexAgentRegistration.json @@ -0,0 +1,19 @@ +{ + "name": "list index tool flow agent", + "type": "flow", + "description": "this is a test agent", + "llm": { + "model_id": "dummy_model", + "parameters": { + "max_iteration": 5, + "stop_when_no_tool_found": true + } + }, + "tools": [ + { + "type": "CatIndexTool", + "name": "CatIndexTool" + } + ], + "app_type": "my_app" +}