From 83b53df1e5ceaf3a8c70a35eac1d531bb3b40003 Mon Sep 17 00:00:00 2001 From: Chenyang Ji Date: Fri, 7 Jun 2024 09:12:04 -0700 Subject: [PATCH] Query-level resource usages tracking (#13172) * Query-level resource usages tracking Signed-off-by: Chenyang Ji * Moving TaskResourceTrackingService to clusterService Signed-off-by: Chenyang Ji * use shard response header to piggyback task resource usages Signed-off-by: Chenyang Ji * split changes for query insights plugin Signed-off-by: Chenyang Ji * improve the supplier logic and other misc items Signed-off-by: Chenyang Ji * track resource usage for failed requests Signed-off-by: Chenyang Ji * move resource usages interactions into TaskResourceTrackingService Signed-off-by: Chenyang Ji --------- Signed-off-by: Chenyang Ji --- CHANGELOG.md | 1 + .../resourcetracker/ResourceUsageInfo.java | 4 + .../resourcetracker/TaskResourceInfo.java | 225 ++++++++++++++++++ .../search/AbstractSearchAsyncAction.java | 14 ++ .../action/search/FetchSearchPhase.java | 2 + .../action/search/SearchPhaseContext.java | 5 + .../action/search/SearchRequestContext.java | 34 ++- .../SearchRequestOperationsListener.java | 13 + .../action/search/TransportSearchAction.java | 12 +- .../common/util/concurrent/ThreadContext.java | 9 + .../main/java/org/opensearch/node/Node.java | 9 +- .../org/opensearch/search/SearchService.java | 19 +- .../main/java/org/opensearch/tasks/Task.java | 12 + .../tasks/TaskResourceTrackingService.java | 96 ++++++++ .../AbstractSearchAsyncActionTests.java | 94 +++++++- .../CanMatchPreFilterSearchPhaseTests.java | 15 +- .../action/search/MockSearchPhaseContext.java | 8 + .../action/search/SearchAsyncActionTests.java | 12 +- .../SearchQueryThenFetchAsyncActionTests.java | 3 +- ...earchRequestOperationsListenerSupport.java | 3 +- .../search/SearchRequestSlowLogTests.java | 15 +- .../search/SearchRequestStatsTests.java | 6 +- .../search/SearchResponseMergerTests.java | 36 ++- .../search/TransportSearchActionTests.java | 18 +- .../snapshots/SnapshotResiliencyTests.java | 6 +- .../tasks/TaskResourceInfoTests.java | 106 +++++++++ .../TaskResourceTrackingServiceTests.java | 35 +++ .../java/org/opensearch/node/MockNode.java | 10 +- .../opensearch/search/MockSearchService.java | 7 +- 29 files changed, 763 insertions(+), 66 deletions(-) create mode 100644 libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/TaskResourceInfo.java create mode 100644 server/src/test/java/org/opensearch/tasks/TaskResourceInfoTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index b628c4ee2070c..539f5a6628dac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Remote Store] Add support to disable flush based on translog reader count ([#14027](https://github.com/opensearch-project/OpenSearch/pull/14027)) - [Query Insights] Add exporter support for top n queries ([#12982](https://github.com/opensearch-project/OpenSearch/pull/12982)) - [Query Insights] Add X-Opaque-Id to search request metadata for top n queries ([#13374](https://github.com/opensearch-project/OpenSearch/pull/13374)) +- Add support for query level resource usage tracking ([#13172](https://github.com/opensearch-project/OpenSearch/pull/13172)) ### Dependencies - Bump `com.github.spullara.mustache.java:compiler` from 0.9.10 to 0.9.13 ([#13329](https://github.com/opensearch-project/OpenSearch/pull/13329), [#13559](https://github.com/opensearch-project/OpenSearch/pull/13559)) diff --git a/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/ResourceUsageInfo.java b/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/ResourceUsageInfo.java index a278b61894a65..e7b51c3389b52 100644 --- a/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/ResourceUsageInfo.java +++ b/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/ResourceUsageInfo.java @@ -104,6 +104,10 @@ public long getTotalValue() { return endValue.get() - startValue; } + public long getStartValue() { + return startValue; + } + @Override public String toString() { return String.valueOf(getTotalValue()); diff --git a/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/TaskResourceInfo.java b/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/TaskResourceInfo.java new file mode 100644 index 0000000000000..373cdbfa7e9a1 --- /dev/null +++ b/libs/core/src/main/java/org/opensearch/core/tasks/resourcetracker/TaskResourceInfo.java @@ -0,0 +1,225 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.core.tasks.resourcetracker; + +import org.opensearch.common.annotation.PublicApi; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ConstructingObjectParser; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +import static org.opensearch.core.xcontent.ConstructingObjectParser.constructorArg; + +/** + * Task resource usage information with minimal information about the task + *

+ * Writeable TaskResourceInfo objects are used to represent resource usage + * information of running tasks, which can be propagated to coordinator node + * to infer query-level resource usage + * + * @opensearch.api + */ +@PublicApi(since = "2.15.0") +public class TaskResourceInfo implements Writeable, ToXContentObject { + private final String action; + private final long taskId; + private final long parentTaskId; + private final String nodeId; + private final TaskResourceUsage taskResourceUsage; + + private static final ParseField ACTION = new ParseField("action"); + private static final ParseField TASK_ID = new ParseField("taskId"); + private static final ParseField PARENT_TASK_ID = new ParseField("parentTaskId"); + private static final ParseField NODE_ID = new ParseField("nodeId"); + private static final ParseField TASK_RESOURCE_USAGE = new ParseField("taskResourceUsage"); + + public TaskResourceInfo( + final String action, + final long taskId, + final long parentTaskId, + final String nodeId, + final TaskResourceUsage taskResourceUsage + ) { + this.action = action; + this.taskId = taskId; + this.parentTaskId = parentTaskId; + this.nodeId = nodeId; + this.taskResourceUsage = taskResourceUsage; + } + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "task_resource_info", + a -> new Builder().setAction((String) a[0]) + .setTaskId((Long) a[1]) + .setParentTaskId((Long) a[2]) + .setNodeId((String) a[3]) + .setTaskResourceUsage((TaskResourceUsage) a[4]) + .build() + ); + + static { + PARSER.declareString(constructorArg(), ACTION); + PARSER.declareLong(constructorArg(), TASK_ID); + PARSER.declareLong(constructorArg(), PARENT_TASK_ID); + PARSER.declareString(constructorArg(), NODE_ID); + PARSER.declareObject(constructorArg(), TaskResourceUsage.PARSER, TASK_RESOURCE_USAGE); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTION.getPreferredName(), this.action); + builder.field(TASK_ID.getPreferredName(), this.taskId); + builder.field(PARENT_TASK_ID.getPreferredName(), this.parentTaskId); + builder.field(NODE_ID.getPreferredName(), this.nodeId); + builder.startObject(TASK_RESOURCE_USAGE.getPreferredName()); + this.taskResourceUsage.toXContent(builder, params); + builder.endObject(); + builder.endObject(); + return builder; + } + + /** + * Builder for {@link TaskResourceInfo} + */ + public static class Builder { + private TaskResourceUsage taskResourceUsage; + private String action; + private long taskId; + private long parentTaskId; + private String nodeId; + + public Builder setTaskResourceUsage(final TaskResourceUsage taskResourceUsage) { + this.taskResourceUsage = taskResourceUsage; + return this; + } + + public Builder setAction(final String action) { + this.action = action; + return this; + } + + public Builder setTaskId(final long taskId) { + this.taskId = taskId; + return this; + } + + public Builder setParentTaskId(final long parentTaskId) { + this.parentTaskId = parentTaskId; + return this; + } + + public Builder setNodeId(final String nodeId) { + this.nodeId = nodeId; + return this; + } + + public TaskResourceInfo build() { + return new TaskResourceInfo(action, taskId, parentTaskId, nodeId, taskResourceUsage); + } + } + + /** + * Read task info from a stream. + * + * @param in StreamInput to read + * @return {@link TaskResourceInfo} + * @throws IOException IOException + */ + public static TaskResourceInfo readFromStream(StreamInput in) throws IOException { + return new TaskResourceInfo.Builder().setAction(in.readString()) + .setTaskId(in.readLong()) + .setParentTaskId(in.readLong()) + .setNodeId(in.readString()) + .setTaskResourceUsage(TaskResourceUsage.readFromStream(in)) + .build(); + } + + /** + * Get TaskResourceUsage + * + * @return taskResourceUsage + */ + public TaskResourceUsage getTaskResourceUsage() { + return taskResourceUsage; + } + + /** + * Get parent task id + * + * @return parent task id + */ + public long getParentTaskId() { + return parentTaskId; + } + + /** + * Get task id + * @return task id + */ + public long getTaskId() { + return taskId; + } + + /** + * Get node id + * @return node id + */ + public String getNodeId() { + return nodeId; + } + + /** + * Get task action + * @return task action + */ + public String getAction() { + return action; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(action); + out.writeLong(taskId); + out.writeLong(parentTaskId); + out.writeString(nodeId); + taskResourceUsage.writeTo(out); + } + + @Override + public String toString() { + return Strings.toString(MediaTypeRegistry.JSON, this); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || obj.getClass() != TaskResourceInfo.class) { + return false; + } + TaskResourceInfo other = (TaskResourceInfo) obj; + return action.equals(other.action) + && taskId == other.taskId + && parentTaskId == other.parentTaskId + && Objects.equals(nodeId, other.nodeId) + && taskResourceUsage.equals(other.taskResourceUsage); + } + + @Override + public int hashCode() { + return Objects.hash(action, taskId, parentTaskId, nodeId, taskResourceUsage); + } +} diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 9bf4a4b1e18f1..f0fc05c595d6f 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -51,6 +51,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ShardOperationFailedException; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.internal.AliasFilter; @@ -469,6 +470,10 @@ private void onRequestEnd(SearchRequestContext searchRequestContext) { this.searchRequestContext.getSearchRequestOperationsListener().onRequestEnd(this, searchRequestContext); } + private void onRequestFailure(SearchRequestContext searchRequestContext) { + this.searchRequestContext.getSearchRequestOperationsListener().onRequestFailure(this, searchRequestContext); + } + private void executePhase(SearchPhase phase) { Span phaseSpan = tracer.startSpan(SpanCreationContext.server().name("[phase/" + phase.getName() + "]")); try (final SpanScope scope = tracer.withSpanInScope(phaseSpan)) { @@ -507,6 +512,7 @@ ShardSearchFailure[] buildShardFailures() { private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget shard, final SearchShardIterator shardIt, Exception e) { // we always add the shard failure for a specific shard instance // we do make sure to clean it on a successful response from a shard + setPhaseResourceUsages(); onShardFailure(shardIndex, shard, e); SearchShardTarget nextShard = FailAwareWeightedRouting.getInstance() .findNext(shardIt, clusterState, e, () -> totalOps.incrementAndGet()); @@ -618,9 +624,15 @@ protected void onShardResult(Result result, SearchShardIterator shardIt) { if (logger.isTraceEnabled()) { logger.trace("got first-phase result from {}", result != null ? result.getSearchShardTarget() : null); } + this.setPhaseResourceUsages(); results.consumeResult(result, () -> onShardResultConsumed(result, shardIt)); } + public void setPhaseResourceUsages() { + TaskResourceInfo taskResourceUsage = searchRequestContext.getTaskResourceUsageSupplier().get(); + searchRequestContext.recordPhaseResourceUsage(taskResourceUsage); + } + private void onShardResultConsumed(Result result, SearchShardIterator shardIt) { successfulOps.incrementAndGet(); // clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level @@ -751,6 +763,7 @@ public void sendSearchResponse(InternalSearchResponse internalSearchResponse, At @Override public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) { + setPhaseResourceUsages(); if (currentPhaseHasLifecycle) { this.searchRequestContext.getSearchRequestOperationsListener().onPhaseFailure(this, cause); } @@ -780,6 +793,7 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) { }); } Releasables.close(releasables); + onRequestFailure(searchRequestContext); listener.onFailure(exception); } diff --git a/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java index ebb2f33f8f37d..2ad7f8a29896c 100644 --- a/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java @@ -240,6 +240,7 @@ private void executeFetch( public void innerOnResponse(FetchSearchResult result) { try { progressListener.notifyFetchResult(shardIndex); + context.setPhaseResourceUsages(); counter.onResult(result); } catch (Exception e) { context.onPhaseFailure(FetchSearchPhase.this, "", e); @@ -254,6 +255,7 @@ public void onFailure(Exception e) { e ); progressListener.notifyFetchFailure(shardIndex, shardTarget, e); + context.setPhaseResourceUsages(); counter.onFailure(shardIndex, shardTarget, e); } finally { // the search context might not be cleared on the node where the fetch was executed for example diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java index df451e0745e3c..55f2a22749e70 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseContext.java @@ -150,4 +150,9 @@ default void sendReleaseSearchContext( * Registers a {@link Releasable} that will be closed when the search request finishes or fails. */ void addReleasable(Releasable releasable); + + /** + * Set the resource usage info for this phase + */ + void setPhaseResourceUsages(); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java b/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java index 5b133ba0554f4..111d9c64550b3 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequestContext.java @@ -8,13 +8,20 @@ package org.opensearch.action.search; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.search.TotalHits; import org.opensearch.common.annotation.InternalApi; +import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; +import java.util.ArrayList; import java.util.EnumMap; import java.util.HashMap; +import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.function.Supplier; /** * This class holds request-level context for search queries at the coordinator node @@ -23,6 +30,7 @@ */ @InternalApi public class SearchRequestContext { + private static final Logger logger = LogManager.getLogger(); private final SearchRequestOperationsListener searchRequestOperationsListener; private long absoluteStartNanos; private final Map phaseTookMap; @@ -30,13 +38,21 @@ public class SearchRequestContext { private final EnumMap shardStats; private final SearchRequest searchRequest; - - SearchRequestContext(final SearchRequestOperationsListener searchRequestOperationsListener, final SearchRequest searchRequest) { + private final LinkedBlockingQueue phaseResourceUsage; + private final Supplier taskResourceUsageSupplier; + + SearchRequestContext( + final SearchRequestOperationsListener searchRequestOperationsListener, + final SearchRequest searchRequest, + final Supplier taskResourceUsageSupplier + ) { this.searchRequestOperationsListener = searchRequestOperationsListener; this.absoluteStartNanos = System.nanoTime(); this.phaseTookMap = new HashMap<>(); this.shardStats = new EnumMap<>(ShardStatsFieldNames.class); this.searchRequest = searchRequest; + this.phaseResourceUsage = new LinkedBlockingQueue<>(); + this.taskResourceUsageSupplier = taskResourceUsageSupplier; } SearchRequestOperationsListener getSearchRequestOperationsListener() { @@ -108,6 +124,20 @@ String formattedShardStats() { } } + public Supplier getTaskResourceUsageSupplier() { + return taskResourceUsageSupplier; + } + + public void recordPhaseResourceUsage(TaskResourceInfo usage) { + if (usage != null) { + this.phaseResourceUsage.add(usage); + } + } + + public List getPhaseResourceUsage() { + return new ArrayList<>(phaseResourceUsage); + } + public SearchRequest getRequest() { return searchRequest; } diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java b/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java index b944572cef122..61f19977ae5ce 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequestOperationsListener.java @@ -51,6 +51,8 @@ protected void onRequestStart(SearchRequestContext searchRequestContext) {} protected void onRequestEnd(SearchPhaseContext context, SearchRequestContext searchRequestContext) {} + protected void onRequestFailure(SearchPhaseContext context, SearchRequestContext searchRequestContext) {} + protected boolean isEnabled(SearchRequest searchRequest) { return isEnabled(); } @@ -133,6 +135,17 @@ public void onRequestEnd(SearchPhaseContext context, SearchRequestContext search } } + @Override + public void onRequestFailure(SearchPhaseContext context, SearchRequestContext searchRequestContext) { + for (SearchRequestOperationsListener listener : listeners) { + try { + listener.onRequestFailure(context, searchRequestContext); + } catch (Exception e) { + logger.warn(() -> new ParameterizedMessage("onRequestFailure listener [{}] failed", listener), e); + } + } + } + public List getListeners() { return listeners; } diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 143b01af3f62f..6e380775355a2 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -87,6 +87,7 @@ import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.telemetry.metrics.MetricsRegistry; import org.opensearch.telemetry.tracing.Span; import org.opensearch.telemetry.tracing.SpanBuilder; @@ -186,6 +187,7 @@ public class TransportSearchAction extends HandledTransportAction) SearchRequest::new); this.client = client; @@ -224,6 +227,7 @@ public TransportSearchAction( clusterService.getClusterSettings() .addSettingsUpdateConsumer(SEARCH_QUERY_METRICS_ENABLED_SETTING, this::setSearchQueryMetricsEnabled); this.tracer = tracer; + this.taskResourceTrackingService = taskResourceTrackingService; } private void setSearchQueryMetricsEnabled(boolean searchQueryMetricsEnabled) { @@ -451,7 +455,11 @@ private void executeRequest( logger, TraceableSearchRequestOperationsListener.create(tracer, requestSpan) ); - SearchRequestContext searchRequestContext = new SearchRequestContext(requestOperationsListeners, originalSearchRequest); + SearchRequestContext searchRequestContext = new SearchRequestContext( + requestOperationsListeners, + originalSearchRequest, + taskResourceTrackingService::getTaskResourceUsageFromThreadContext + ); searchRequestContext.getSearchRequestOperationsListener().onRequestStart(searchRequestContext); PipelinedRequest searchRequest; diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java index 6580b0e0085ef..0b1aa9a4a759a 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java @@ -483,6 +483,15 @@ public void addResponseHeader(final String key, final String value) { addResponseHeader(key, value, v -> v); } + /** + * Remove the {@code value} for the specified {@code key}. + * + * @param key the header name + */ + public void removeResponseHeader(final String key) { + threadLocal.get().responseHeaders.remove(key); + } + /** * Add the {@code value} for the specified {@code key} with the specified {@code uniqueValue} used for de-duplication. Any duplicate * {@code value} after applying {@code uniqueValue} is ignored. diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index cb1f2caa082fc..f7a901335f34a 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -1261,7 +1261,8 @@ protected Node( searchModule.getFetchPhase(), responseCollectorService, circuitBreakerService, - searchModule.getIndexSearcherExecutor(threadPool) + searchModule.getIndexSearcherExecutor(threadPool), + taskResourceTrackingService ); final List> tasksExecutors = pluginsService.filterPlugins(PersistentTaskPlugin.class) @@ -1905,7 +1906,8 @@ protected SearchService newSearchService( FetchPhase fetchPhase, ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, - Executor indexSearcherExecutor + Executor indexSearcherExecutor, + TaskResourceTrackingService taskResourceTrackingService ) { return new SearchService( clusterService, @@ -1917,7 +1919,8 @@ protected SearchService newSearchService( fetchPhase, responseCollectorService, circuitBreakerService, - indexSearcherExecutor + indexSearcherExecutor, + taskResourceTrackingService ); } diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index d371d69a57804..45f111d889522 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -137,6 +137,7 @@ import org.opensearch.search.sort.SortOrder; import org.opensearch.search.suggest.Suggest; import org.opensearch.search.suggest.completion.CompletionSuggestion; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.threadpool.Scheduler.Cancellable; import org.opensearch.threadpool.ThreadPool; import org.opensearch.threadpool.ThreadPool.Names; @@ -338,6 +339,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv private final AtomicInteger openPitContexts = new AtomicInteger(); private final String sessionId = UUIDs.randomBase64UUID(); private final Executor indexSearcherExecutor; + private final TaskResourceTrackingService taskResourceTrackingService; public SearchService( ClusterService clusterService, @@ -349,7 +351,8 @@ public SearchService( FetchPhase fetchPhase, ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, - Executor indexSearcherExecutor + Executor indexSearcherExecutor, + TaskResourceTrackingService taskResourceTrackingService ) { Settings settings = clusterService.getSettings(); this.threadPool = threadPool; @@ -366,6 +369,7 @@ public SearchService( circuitBreakerService.getBreaker(CircuitBreaker.REQUEST) ); this.indexSearcherExecutor = indexSearcherExecutor; + this.taskResourceTrackingService = taskResourceTrackingService; TimeValue keepAliveInterval = KEEPALIVE_INTERVAL_SETTING.get(settings); setKeepAlives(DEFAULT_KEEPALIVE_SETTING.get(settings), MAX_KEEPALIVE_SETTING.get(settings)); setPitKeepAlives(DEFAULT_KEEPALIVE_SETTING.get(settings), MAX_PIT_KEEPALIVE_SETTING.get(settings)); @@ -558,6 +562,8 @@ private DfsSearchResult executeDfsPhase(ShardSearchRequest request, SearchShardT logger.trace("Dfs phase failed", e); processFailure(readerContext, e); throw e; + } finally { + taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId()); } } @@ -660,6 +666,8 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchSh logger.trace("Query phase failed", e); processFailure(readerContext, e); throw e; + } finally { + taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId()); } } @@ -705,6 +713,8 @@ public void executeQueryPhase( logger.trace("Query phase failed", e); // we handle the failure in the failure listener below throw e; + } finally { + taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId()); } }, wrapFailureListener(listener, readerContext, markAsUsed)); } @@ -737,6 +747,8 @@ public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task, logger.trace("Query phase failed", e); // we handle the failure in the failure listener below throw e; + } finally { + taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId()); } }, wrapFailureListener(listener, readerContext, markAsUsed)); } @@ -786,6 +798,8 @@ public void executeFetchPhase( logger.trace("Fetch phase failed", e); // we handle the failure in the failure listener below throw e; + } finally { + taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId()); } }, wrapFailureListener(listener, readerContext, markAsUsed)); } @@ -816,6 +830,8 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e); // we handle the failure in the failure listener below throw e; + } finally { + taskResourceTrackingService.writeTaskResourceUsage(task, clusterService.localNode().getId()); } }, wrapFailureListener(listener, readerContext, markAsUsed)); } @@ -1749,6 +1765,7 @@ public CanMatchResponse(boolean canMatch, MinAndMax estimatedMinAndMax) { @Override public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); out.writeBoolean(canMatch); out.writeOptionalWriteable(estimatedMinAndMax); } diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index a21a454a65d0e..0fa65bc16516f 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -476,6 +476,18 @@ public void stopThreadResourceTracking(long threadId, ResourceStatsType statsTyp throw new IllegalStateException("cannot update final values if active thread resource entry is not present"); } + public ThreadResourceInfo getActiveThreadResourceInfo(long threadId, ResourceStatsType statsType) { + final List threadResourceInfoList = resourceStats.get(threadId); + if (threadResourceInfoList != null) { + for (ThreadResourceInfo threadResourceInfo : threadResourceInfoList) { + if (threadResourceInfo.getStatsType() == statsType && threadResourceInfo.isActive()) { + return threadResourceInfo; + } + } + } + return null; + } + /** * Individual tasks can override this if they want to support task resource tracking. We just need to make sure that * the ThreadPool on which the task runs on have runnable wrapper similar to diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index f32559f6314c0..564eff6c10df6 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -14,6 +14,7 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.ExceptionsHelper; +import org.opensearch.action.search.SearchShardTask; import org.opensearch.common.SuppressForbidden; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.ClusterSettings; @@ -22,12 +23,23 @@ import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.common.util.concurrent.ConcurrentMapLong; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.tasks.resourcetracker.ResourceStats; +import org.opensearch.core.tasks.resourcetracker.ResourceStatsType; +import org.opensearch.core.tasks.resourcetracker.ResourceUsageInfo; import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric; +import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; +import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage; import org.opensearch.core.tasks.resourcetracker.ThreadResourceInfo; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.threadpool.ThreadPool; +import java.io.IOException; import java.lang.management.ManagementFactory; import java.util.ArrayList; import java.util.Collections; @@ -51,6 +63,7 @@ public class TaskResourceTrackingService implements RunnableTaskExecutionListene Setting.Property.NodeScope ); public static final String TASK_ID = "TASK_ID"; + public static final String TASK_RESOURCE_USAGE = "TASK_RESOURCE_USAGE"; private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); @@ -261,6 +274,89 @@ private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) { return storedContext; } + /** + * Get the current task level resource usage. + * + * @param task {@link SearchShardTask} + * @param nodeId the local nodeId + */ + public void writeTaskResourceUsage(SearchShardTask task, String nodeId) { + try { + // Get resource usages from when the task started + ThreadResourceInfo threadResourceInfo = task.getActiveThreadResourceInfo( + Thread.currentThread().getId(), + ResourceStatsType.WORKER_STATS + ); + if (threadResourceInfo == null) { + return; + } + Map startValues = threadResourceInfo.getResourceUsageInfo().getStatsInfo(); + if (!(startValues.containsKey(ResourceStats.CPU) && startValues.containsKey(ResourceStats.MEMORY))) { + return; + } + // Get current resource usages + ResourceUsageMetric[] endValues = getResourceUsageMetricsForThread(Thread.currentThread().getId()); + long cpu = -1, mem = -1; + for (ResourceUsageMetric endValue : endValues) { + if (endValue.getStats() == ResourceStats.MEMORY) { + mem = endValue.getValue(); + } else if (endValue.getStats() == ResourceStats.CPU) { + cpu = endValue.getValue(); + } + } + if (cpu == -1 || mem == -1) { + logger.debug("Invalid resource usage value, cpu [{}], memory [{}]: ", cpu, mem); + return; + } + + // Build task resource usage info + TaskResourceInfo taskResourceInfo = new TaskResourceInfo.Builder().setAction(task.getAction()) + .setTaskId(task.getId()) + .setParentTaskId(task.getParentTaskId().getId()) + .setNodeId(nodeId) + .setTaskResourceUsage( + new TaskResourceUsage( + cpu - startValues.get(ResourceStats.CPU).getStartValue(), + mem - startValues.get(ResourceStats.MEMORY).getStartValue() + ) + ) + .build(); + // Remove the existing TASK_RESOURCE_USAGE header since it would have come from an earlier phase in the same request. + synchronized (this) { + threadPool.getThreadContext().removeResponseHeader(TASK_RESOURCE_USAGE); + threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString()); + } + } catch (Exception e) { + logger.debug("Error during writing task resource usage: ", e); + } + } + + /** + * Get the task resource usages from {@link ThreadContext} + * + * @return {@link TaskResourceInfo} + */ + public TaskResourceInfo getTaskResourceUsageFromThreadContext() { + List taskResourceUsages = threadPool.getThreadContext().getResponseHeaders().get(TASK_RESOURCE_USAGE); + if (taskResourceUsages != null && taskResourceUsages.size() > 0) { + String usage = taskResourceUsages.get(0); + try { + if (usage != null && !usage.isEmpty()) { + XContentParser parser = XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(usage), + MediaTypeRegistry.JSON + ); + return TaskResourceInfo.PARSER.apply(parser, null); + } + } catch (IOException e) { + logger.debug("fail to parse phase resource usages: ", e); + } + } + return null; + } + /** * Listener that gets invoked when a task execution completes. */ diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index 7dcbf213d6c9d..27336e86e52b0 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -49,6 +49,8 @@ import org.opensearch.core.common.breaker.NoopCircuitBreaker; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; +import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.shard.ShardNotFoundException; import org.opensearch.search.SearchPhaseResult; @@ -87,6 +89,7 @@ import java.util.function.BiFunction; import java.util.stream.IntStream; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; @@ -123,7 +126,8 @@ private AbstractSearchAsyncAction createAction( ArraySearchPhaseResults results, ActionListener listener, final boolean controlled, - final AtomicLong expected + final AtomicLong expected, + final TaskResourceUsage resourceUsage ) { return createAction( request, @@ -133,6 +137,7 @@ private AbstractSearchAsyncAction createAction( false, false, expected, + resourceUsage, new SearchShardIterator(null, null, Collections.emptyList(), null) ); } @@ -145,6 +150,7 @@ private AbstractSearchAsyncAction createAction( final boolean failExecutePhaseOnShard, final boolean catchExceptionWhenExecutePhaseOnShard, final AtomicLong expected, + final TaskResourceUsage resourceUsage, final SearchShardIterator... shards ) { @@ -166,6 +172,14 @@ private AbstractSearchAsyncAction createAction( return null; }; + TaskResourceInfo taskResourceInfo = new TaskResourceInfo.Builder().setTaskResourceUsage(resourceUsage) + .setTaskId(randomLong()) + .setParentTaskId(randomLong()) + .setAction(randomAlphaOfLengthBetween(1, 5)) + .setNodeId(randomAlphaOfLengthBetween(1, 5)) + .build(); + threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString()); + return new AbstractSearchAsyncAction( "test", logger, @@ -186,7 +200,8 @@ private AbstractSearchAsyncAction createAction( SearchResponse.Clusters.EMPTY, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), - request + request, + () -> null ), NoopTracer.INSTANCE ) { @@ -248,7 +263,8 @@ private void runTestTook(final boolean controlled) { new ArraySearchPhaseResults<>(10), null, controlled, - expected + expected, + new TaskResourceUsage(0, 0) ); final long actual = action.buildTookInMillis(); if (controlled) { @@ -268,7 +284,8 @@ public void testBuildShardSearchTransportRequest() { new ArraySearchPhaseResults<>(10), null, false, - expected + expected, + new TaskResourceUsage(randomLong(), randomLong()) ); String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10); SearchShardIterator iterator = new SearchShardIterator( @@ -291,19 +308,39 @@ public void testBuildShardSearchTransportRequest() { public void testBuildSearchResponse() { SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(randomBoolean()); ArraySearchPhaseResults phaseResults = new ArraySearchPhaseResults<>(10); - AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, null, false, new AtomicLong()); + TaskResourceUsage taskResourceUsage = new TaskResourceUsage(randomLong(), randomLong()); + AbstractSearchAsyncAction action = createAction( + searchRequest, + phaseResults, + null, + false, + new AtomicLong(), + taskResourceUsage + ); InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty(); SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null); assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations()); assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest()); assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile()); assertSame(searchResponse.getHits(), internalSearchResponse.hits()); + List resourceUsages = threadPool.getThreadContext().getResponseHeaders().get(TASK_RESOURCE_USAGE); + assertNotNull(resourceUsages); + assertEquals(1, resourceUsages.size()); + assertTrue(resourceUsages.get(0).contains(Long.toString(taskResourceUsage.getCpuTimeInNanos()))); + assertTrue(resourceUsages.get(0).contains(Long.toString(taskResourceUsage.getMemoryInBytes()))); } public void testBuildSearchResponseAllowPartialFailures() { SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); final ArraySearchPhaseResults queryResult = new ArraySearchPhaseResults<>(10); - AbstractSearchAsyncAction action = createAction(searchRequest, queryResult, null, false, new AtomicLong()); + AbstractSearchAsyncAction action = createAction( + searchRequest, + queryResult, + null, + false, + new AtomicLong(), + new TaskResourceUsage(randomLong(), randomLong()) + ); action.onShardFailure( 0, new SearchShardTarget("node", new ShardId("index", "index-uuid", 0), null, OriginalIndices.NONE), @@ -325,7 +362,14 @@ public void testSendSearchResponseDisallowPartialFailures() { List> nodeLookups = new ArrayList<>(); int numFailures = randomIntBetween(1, 5); ArraySearchPhaseResults phaseResults = phaseResults(requestIds, nodeLookups, numFailures); - AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong()); + AbstractSearchAsyncAction action = createAction( + searchRequest, + phaseResults, + listener, + false, + new AtomicLong(), + new TaskResourceUsage(randomLong(), randomLong()) + ); for (int i = 0; i < numFailures; i++) { ShardId failureShardId = new ShardId("index", "index-uuid", i); String failureClusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10); @@ -404,7 +448,14 @@ public void testOnPhaseFailure() { Set requestIds = new HashSet<>(); List> nodeLookups = new ArrayList<>(); ArraySearchPhaseResults phaseResults = phaseResults(requestIds, nodeLookups, 0); - AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong()); + AbstractSearchAsyncAction action = createAction( + searchRequest, + phaseResults, + listener, + false, + new AtomicLong(), + new TaskResourceUsage(randomLong(), randomLong()) + ); action.onPhaseFailure(new SearchPhase("test") { @Override @@ -428,7 +479,14 @@ public void testShardNotAvailableWithDisallowPartialFailures() { ActionListener listener = ActionListener.wrap(response -> fail("onResponse should not be called"), exception::set); int numShards = randomIntBetween(2, 10); ArraySearchPhaseResults phaseResults = new ArraySearchPhaseResults<>(numShards); - AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong()); + AbstractSearchAsyncAction action = createAction( + searchRequest, + phaseResults, + listener, + false, + new AtomicLong(), + new TaskResourceUsage(randomLong(), randomLong()) + ); // skip one to avoid the "all shards failed" failure. SearchShardIterator skipIterator = new SearchShardIterator(null, null, Collections.emptyList(), null); skipIterator.resetAndSkip(); @@ -450,7 +508,14 @@ public void testShardNotAvailableWithIgnoreUnavailable() { ActionListener listener = ActionListener.wrap(response -> {}, exception::set); int numShards = randomIntBetween(2, 10); ArraySearchPhaseResults phaseResults = new ArraySearchPhaseResults<>(numShards); - AbstractSearchAsyncAction action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong()); + AbstractSearchAsyncAction action = createAction( + searchRequest, + phaseResults, + listener, + false, + new AtomicLong(), + new TaskResourceUsage(randomLong(), randomLong()) + ); // skip one to avoid the "all shards failed" failure. SearchShardIterator skipIterator = new SearchShardIterator(null, null, Collections.emptyList(), null); skipIterator.resetAndSkip(); @@ -521,6 +586,7 @@ public void onFailure(Exception e) { true, false, new AtomicLong(), + new TaskResourceUsage(randomLong(), randomLong()), shards ); action.run(); @@ -568,6 +634,7 @@ public void onFailure(Exception e) { false, false, new AtomicLong(), + new TaskResourceUsage(randomLong(), randomLong()), shards ); action.run(); @@ -620,6 +687,7 @@ public void onFailure(Exception e) { false, catchExceptionWhenExecutePhaseOnShard, new AtomicLong(), + new TaskResourceUsage(randomLong(), randomLong()), shards ); action.run(); @@ -771,7 +839,8 @@ private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAct SearchResponse.Clusters.EMPTY, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(searchRequestOperationsListeners, logger), - searchRequest + searchRequest, + () -> null ), NoopTracer.INSTANCE ); @@ -825,7 +894,8 @@ private SearchQueryThenFetchAsyncAction createSearchQueryThenFetchAsyncAction( SearchResponse.Clusters.EMPTY, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(searchRequestOperationsListeners, logger), - searchRequest + searchRequest, + () -> null ), NoopTracer.INSTANCE ) { diff --git a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java index 1881c705fe6b3..bb51aeaeee9dd 100644 --- a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -170,7 +170,7 @@ public void run() throws IOException { } }, SearchResponse.Clusters.EMPTY, - new SearchRequestContext(searchRequestOperationsListener, searchRequest), + new SearchRequestContext(searchRequestOperationsListener, searchRequest, () -> null), NoopTracer.INSTANCE ); @@ -268,7 +268,7 @@ public void run() throws IOException { } }, SearchResponse.Clusters.EMPTY, - new SearchRequestContext(searchRequestOperationsListener, searchRequest), + new SearchRequestContext(searchRequestOperationsListener, searchRequest, () -> null), NoopTracer.INSTANCE ); @@ -366,7 +366,7 @@ public void sendCanMatch( new ArraySearchPhaseResults<>(iter.size()), randomIntBetween(1, 32), SearchResponse.Clusters.EMPTY, - new SearchRequestContext(searchRequestOperationsListener, searchRequest), + new SearchRequestContext(searchRequestOperationsListener, searchRequest, () -> null), NoopTracer.INSTANCE ) { @Override @@ -396,7 +396,7 @@ protected void executePhaseOnShard( ); }, SearchResponse.Clusters.EMPTY, - new SearchRequestContext(searchRequestOperationsListener, searchRequest), + new SearchRequestContext(searchRequestOperationsListener, searchRequest, () -> null), NoopTracer.INSTANCE ); @@ -488,7 +488,7 @@ public void run() { } }, SearchResponse.Clusters.EMPTY, - new SearchRequestContext(searchRequestOperationsListener, searchRequest), + new SearchRequestContext(searchRequestOperationsListener, searchRequest, () -> null), NoopTracer.INSTANCE ); @@ -595,7 +595,7 @@ public void run() { } }, SearchResponse.Clusters.EMPTY, - new SearchRequestContext(searchRequestOperationsListener, searchRequest), + new SearchRequestContext(searchRequestOperationsListener, searchRequest, () -> null), NoopTracer.INSTANCE ); @@ -658,7 +658,8 @@ public void sendCanMatch( ExecutorService executor = OpenSearchExecutors.newDirectExecutorService(); SearchRequestContext searchRequestContext = new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), - searchRequest + searchRequest, + () -> null ); SearchPhaseController controller = new SearchPhaseController( diff --git a/server/src/test/java/org/opensearch/action/search/MockSearchPhaseContext.java b/server/src/test/java/org/opensearch/action/search/MockSearchPhaseContext.java index cc10da8fc1f12..2f3e462f741b8 100644 --- a/server/src/test/java/org/opensearch/action/search/MockSearchPhaseContext.java +++ b/server/src/test/java/org/opensearch/action/search/MockSearchPhaseContext.java @@ -182,6 +182,14 @@ public void addReleasable(Releasable releasable) { // Noop } + /** + * Set the resource usage info for this phase + */ + @Override + public void setPhaseResourceUsages() { + // Noop + } + @Override public void execute(Runnable command) { command.run(); diff --git a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java index 35e90ff662b19..8fe2d9af217d5 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchAsyncActionTests.java @@ -162,7 +162,7 @@ public void testSkipSearchShards() throws InterruptedException { new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY, - new SearchRequestContext(searchRequestOperationsListener, request), + new SearchRequestContext(searchRequestOperationsListener, request, () -> null), NoopTracer.INSTANCE ) { @@ -287,7 +287,7 @@ public void testLimitConcurrentShardRequests() throws InterruptedException { new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY, - new SearchRequestContext(searchRequestOperationsListener, request), + new SearchRequestContext(searchRequestOperationsListener, request, () -> null), NoopTracer.INSTANCE ) { @@ -409,7 +409,8 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI SearchResponse.Clusters.EMPTY, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), - request + request, + () -> null ), NoopTracer.INSTANCE ) { @@ -537,7 +538,8 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI SearchResponse.Clusters.EMPTY, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), - request + request, + () -> null ), NoopTracer.INSTANCE ) { @@ -657,7 +659,7 @@ public void testAllowPartialResults() throws InterruptedException { new ArraySearchPhaseResults<>(shardsIter.size()), request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY, - new SearchRequestContext(searchRequestOperationsListener, request), + new SearchRequestContext(searchRequestOperationsListener, request, () -> null), NoopTracer.INSTANCE ) { @Override diff --git a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java index aefbbe80d5fa1..f6a06a51c7b43 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -240,7 +240,8 @@ public void sendExecuteQuery( SearchResponse.Clusters.EMPTY, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(assertingListener), LogManager.getLogger()), - searchRequest + searchRequest, + () -> null ), NoopTracer.INSTANCE ) { diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerSupport.java b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerSupport.java index 0f737e00478cb..fdac91a0e3124 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerSupport.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestOperationsListenerSupport.java @@ -25,7 +25,8 @@ default void onPhaseEnd(SearchRequestOperationsListener listener, SearchPhaseCon context, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); } diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestSlowLogTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestSlowLogTests.java index 91a2552ac3f04..453fc6cd8a74c 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestSlowLogTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestSlowLogTests.java @@ -178,7 +178,8 @@ public void testConcurrentOnRequestEnd() throws InterruptedException { for (int i = 0; i < numRequests; i++) { SearchRequestContext searchRequestContext = new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(searchListenersList, logger), - searchRequest + searchRequest, + () -> null ); searchRequestContext.setAbsoluteStartNanos((i < numRequestsLogged) ? 0 : System.nanoTime()); searchRequestContexts.add(searchRequestContext); @@ -209,7 +210,8 @@ public void testSearchRequestSlowLogHasJsonFields_EmptySearchRequestContext() th SearchPhaseContext searchPhaseContext = new MockSearchPhaseContext(1, searchRequest); SearchRequestContext searchRequestContext = new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - searchRequest + searchRequest, + () -> null ); SearchRequestSlowLog.SearchRequestSlowLogMessage p = new SearchRequestSlowLog.SearchRequestSlowLogMessage( searchPhaseContext, @@ -233,7 +235,8 @@ public void testSearchRequestSlowLogHasJsonFields_NotEmptySearchRequestContext() SearchPhaseContext searchPhaseContext = new MockSearchPhaseContext(1, searchRequest); SearchRequestContext searchRequestContext = new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - searchRequest + searchRequest, + () -> null ); searchRequestContext.updatePhaseTookMap(SearchPhaseName.FETCH.getName(), 10L); searchRequestContext.updatePhaseTookMap(SearchPhaseName.QUERY.getName(), 50L); @@ -262,7 +265,8 @@ public void testSearchRequestSlowLogHasJsonFields_PartialContext() throws IOExce SearchPhaseContext searchPhaseContext = new MockSearchPhaseContext(1, searchRequest); SearchRequestContext searchRequestContext = new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - searchRequest + searchRequest, + () -> null ); searchRequestContext.updatePhaseTookMap(SearchPhaseName.FETCH.getName(), 10L); searchRequestContext.updatePhaseTookMap(SearchPhaseName.QUERY.getName(), 50L); @@ -291,7 +295,8 @@ public void testSearchRequestSlowLogSearchContextPrinterToLog() throws IOExcepti SearchPhaseContext searchPhaseContext = new MockSearchPhaseContext(1, searchRequest); SearchRequestContext searchRequestContext = new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - searchRequest + searchRequest, + () -> null ); searchRequestContext.updatePhaseTookMap(SearchPhaseName.FETCH.getName(), 10L); searchRequestContext.updatePhaseTookMap(SearchPhaseName.QUERY.getName(), 50L); diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java index fb9b26e3f3ad1..1af3eb2738a58 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestStatsTests.java @@ -60,7 +60,8 @@ public void testSearchRequestStats() { ctx, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); assertEquals(0, testRequestStats.getPhaseCurrent(searchPhaseName)); @@ -120,7 +121,8 @@ public void testSearchRequestStatsOnPhaseEndConcurrently() throws InterruptedExc ctx, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); countDownLatch.countDown(); diff --git a/server/src/test/java/org/opensearch/action/search/SearchResponseMergerTests.java b/server/src/test/java/org/opensearch/action/search/SearchResponseMergerTests.java index ce4d5ca4f7091..0eefa413c1864 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchResponseMergerTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchResponseMergerTests.java @@ -137,7 +137,8 @@ public void testMergeTookInMillis() throws InterruptedException { SearchResponse.Clusters.EMPTY, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); assertEquals(TimeUnit.NANOSECONDS.toMillis(currentRelativeTime), searchResponse.getTook().millis()); @@ -195,7 +196,8 @@ public void testMergeShardFailures() throws InterruptedException { clusters, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); assertSame(clusters, mergedResponse.getClusters()); @@ -252,7 +254,8 @@ public void testMergeShardFailuresNullShardTarget() throws InterruptedException clusters, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); assertSame(clusters, mergedResponse.getClusters()); @@ -304,7 +307,8 @@ public void testMergeShardFailuresNullShardId() throws InterruptedException { SearchResponse.Clusters.EMPTY, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ).getShardFailures(); assertThat(Arrays.asList(shardFailures), containsInAnyOrder(expectedFailures.toArray(ShardSearchFailure.EMPTY_ARRAY))); @@ -344,7 +348,8 @@ public void testMergeProfileResults() throws InterruptedException { clusters, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); assertSame(clusters, mergedResponse.getClusters()); @@ -412,7 +417,8 @@ public void testMergeCompletionSuggestions() throws InterruptedException { clusters, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); assertSame(clusters, mergedResponse.getClusters()); @@ -490,7 +496,8 @@ public void testMergeCompletionSuggestionsTieBreak() throws InterruptedException clusters, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); assertSame(clusters, mergedResponse.getClusters()); @@ -570,7 +577,8 @@ public void testMergeAggs() throws InterruptedException { clusters, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); assertSame(clusters, mergedResponse.getClusters()); @@ -733,7 +741,8 @@ public void testMergeSearchHits() throws InterruptedException { clusters, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); @@ -799,7 +808,8 @@ public void testMergeNoResponsesAdded() { clusters, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); assertSame(clusters, response.getClusters()); @@ -878,7 +888,8 @@ public void testMergeEmptySearchHitsWithNonEmpty() { clusters, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); assertEquals(10, mergedResponse.getHits().getTotalHits().value); @@ -926,7 +937,8 @@ public void testMergeOnlyEmptyHits() { clusters, new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - new SearchRequest() + new SearchRequest(), + () -> null ) ); assertEquals(expectedTotalHits, mergedResponse.getHits().getTotalHits()); diff --git a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java index da19c839f3826..84955d01a59ce 100644 --- a/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/TransportSearchActionTests.java @@ -487,7 +487,8 @@ public void testCCSRemoteReduceMergeFails() throws Exception { (r, l) -> setOnce.set(Tuple.tuple(r, l)), new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - searchRequest + searchRequest, + () -> null ) ); if (localIndices == null) { @@ -549,7 +550,8 @@ public void testCCSRemoteReduce() throws Exception { (r, l) -> setOnce.set(Tuple.tuple(r, l)), new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - searchRequest + searchRequest, + () -> null ) ); if (localIndices == null) { @@ -590,7 +592,8 @@ public void testCCSRemoteReduce() throws Exception { (r, l) -> setOnce.set(Tuple.tuple(r, l)), new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - searchRequest + searchRequest, + () -> null ) ); if (localIndices == null) { @@ -652,7 +655,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti (r, l) -> setOnce.set(Tuple.tuple(r, l)), new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - searchRequest + searchRequest, + () -> null ) ); if (localIndices == null) { @@ -696,7 +700,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti (r, l) -> setOnce.set(Tuple.tuple(r, l)), new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - searchRequest + searchRequest, + () -> null ) ); if (localIndices == null) { @@ -751,7 +756,8 @@ public void onNodeDisconnected(DiscoveryNode node, Transport.Connection connecti (r, l) -> setOnce.set(Tuple.tuple(r, l)), new SearchRequestContext( new SearchRequestOperationsListener.CompositeListener(List.of(), LogManager.getLogger()), - searchRequest + searchRequest, + () -> null ) ); if (localIndices == null) { diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index 86de008b5dee5..622507f885814 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -2291,7 +2291,8 @@ public void onFailure(final Exception e) { new FetchPhase(Collections.emptyList()), responseCollectorService, new NoneCircuitBreakerService(), - null + null, + new TaskResourceTrackingService(settings, clusterSettings, threadPool) ); SearchPhaseController searchPhaseController = new SearchPhaseController( writableRegistry(), @@ -2326,7 +2327,8 @@ public void onFailure(final Exception e) { ), NoopMetricsRegistry.INSTANCE, searchRequestOperationsCompositeListenerFactory, - NoopTracer.INSTANCE + NoopTracer.INSTANCE, + new TaskResourceTrackingService(settings, clusterSettings, threadPool) ) ); actions.put( diff --git a/server/src/test/java/org/opensearch/tasks/TaskResourceInfoTests.java b/server/src/test/java/org/opensearch/tasks/TaskResourceInfoTests.java new file mode 100644 index 0000000000000..e0bfb8710bbaa --- /dev/null +++ b/server/src/test/java/org/opensearch/tasks/TaskResourceInfoTests.java @@ -0,0 +1,106 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; +import org.opensearch.core.tasks.resourcetracker.TaskResourceUsage; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Locale; + +/** + * Test cases for TaskResourceInfo + */ +public class TaskResourceInfoTests extends OpenSearchTestCase { + private final Long cpuUsage = randomNonNegativeLong(); + private final Long memoryUsage = randomNonNegativeLong(); + private final String action = randomAlphaOfLengthBetween(1, 10); + private final Long taskId = randomNonNegativeLong(); + private final Long parentTaskId = randomNonNegativeLong(); + private final String nodeId = randomAlphaOfLengthBetween(1, 10); + private TaskResourceInfo taskResourceInfo; + private TaskResourceUsage taskResourceUsage; + + @Before + public void setUpVariables() { + taskResourceUsage = new TaskResourceUsage(cpuUsage, memoryUsage); + taskResourceInfo = new TaskResourceInfo(action, taskId, parentTaskId, nodeId, taskResourceUsage); + } + + public void testGetters() { + assertEquals(action, taskResourceInfo.getAction()); + assertEquals(taskId.longValue(), taskResourceInfo.getTaskId()); + assertEquals(parentTaskId.longValue(), taskResourceInfo.getParentTaskId()); + assertEquals(nodeId, taskResourceInfo.getNodeId()); + assertEquals(taskResourceUsage, taskResourceInfo.getTaskResourceUsage()); + } + + public void testEqualsAndHashCode() { + TaskResourceInfo taskResourceInfoCopy = new TaskResourceInfo(action, taskId, parentTaskId, nodeId, taskResourceUsage); + assertEquals(taskResourceInfo, taskResourceInfoCopy); + assertEquals(taskResourceInfo.hashCode(), taskResourceInfoCopy.hashCode()); + TaskResourceInfo differentTaskResourceInfo = new TaskResourceInfo( + "differentAction", + taskId, + parentTaskId, + nodeId, + taskResourceUsage + ); + assertNotEquals(taskResourceInfo, differentTaskResourceInfo); + assertNotEquals(taskResourceInfo.hashCode(), differentTaskResourceInfo.hashCode()); + } + + public void testSerialization() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + taskResourceInfo.writeTo(output); + StreamInput input = StreamInput.wrap(output.bytes().toBytesRef().bytes); + TaskResourceInfo deserializedTaskResourceInfo = TaskResourceInfo.readFromStream(input); + assertEquals(taskResourceInfo, deserializedTaskResourceInfo); + } + + public void testToString() { + String expectedString = String.format( + Locale.ROOT, + "{\"action\":\"%s\",\"taskId\":%s,\"parentTaskId\":%s,\"nodeId\":\"%s\",\"taskResourceUsage\":{\"cpu_time_in_nanos\":%s,\"memory_in_bytes\":%s}}", + action, + taskId, + parentTaskId, + nodeId, + taskResourceUsage.getCpuTimeInNanos(), + taskResourceUsage.getMemoryInBytes() + ); + assertTrue(expectedString.equals(taskResourceInfo.toString())); + } + + public void testToXContent() throws IOException { + char[] expectedXcontent = String.format( + Locale.ROOT, + "{\"action\":\"%s\",\"taskId\":%s,\"parentTaskId\":%s,\"nodeId\":\"%s\",\"taskResourceUsage\":{\"cpu_time_in_nanos\":%s,\"memory_in_bytes\":%s}}", + action, + taskId, + parentTaskId, + nodeId, + taskResourceUsage.getCpuTimeInNanos(), + taskResourceUsage.getMemoryInBytes() + ).toCharArray(); + + XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON); + char[] xContent = BytesReference.bytes(taskResourceInfo.toXContent(builder, ToXContent.EMPTY_PARAMS)).utf8ToString().toCharArray(); + assertEquals(Arrays.hashCode(expectedXcontent), Arrays.hashCode(xContent)); + } +} diff --git a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java index 45d438f8d04c9..0c19c331e1510 100644 --- a/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java +++ b/server/src/test/java/org/opensearch/tasks/TaskResourceTrackingServiceTests.java @@ -9,11 +9,15 @@ package org.opensearch.tasks; import org.opensearch.action.admin.cluster.node.tasks.TransportTasksActionTests; +import org.opensearch.action.search.SearchShardTask; import org.opensearch.action.search.SearchTask; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.tasks.TaskId; +import org.opensearch.core.tasks.resourcetracker.ResourceStatsType; +import org.opensearch.core.tasks.resourcetracker.ResourceUsageMetric; +import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; import org.opensearch.core.tasks.resourcetracker.ThreadResourceInfo; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; @@ -31,6 +35,7 @@ import static org.opensearch.core.tasks.resourcetracker.ResourceStats.CPU; import static org.opensearch.core.tasks.resourcetracker.ResourceStats.MEMORY; import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID; +import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE; public class TaskResourceTrackingServiceTests extends OpenSearchTestCase { @@ -142,6 +147,36 @@ public void testStartingTrackingHandlesMultipleThreadsPerTask() throws Interrupt assertEquals(numTasks, numExecutions); } + public void testWriteTaskResourceUsage() { + SearchShardTask task = new SearchShardTask(1, "test", "test", "task", TaskId.EMPTY_TASK_ID, new HashMap<>()); + taskResourceTrackingService.setTaskResourceTrackingEnabled(true); + taskResourceTrackingService.startTracking(task); + task.startThreadResourceTracking( + Thread.currentThread().getId(), + ResourceStatsType.WORKER_STATS, + new ResourceUsageMetric(CPU, 100), + new ResourceUsageMetric(MEMORY, 100) + ); + taskResourceTrackingService.writeTaskResourceUsage(task, "node_1"); + Map> headers = threadPool.getThreadContext().getResponseHeaders(); + assertEquals(1, headers.size()); + assertTrue(headers.containsKey(TASK_RESOURCE_USAGE)); + } + + public void testGetTaskResourceUsageFromThreadContext() { + String taskResourceUsageJson = + "{\"action\":\"testAction\",\"taskId\":1,\"parentTaskId\":2,\"nodeId\":\"nodeId\",\"taskResourceUsage\":{\"cpu_time_in_nanos\":1000,\"memory_in_bytes\":2000}}"; + threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceUsageJson); + TaskResourceInfo result = taskResourceTrackingService.getTaskResourceUsageFromThreadContext(); + assertNotNull(result); + assertEquals("testAction", result.getAction()); + assertEquals(1L, result.getTaskId()); + assertEquals(2L, result.getParentTaskId()); + assertEquals("nodeId", result.getNodeId()); + assertEquals(1000L, result.getTaskResourceUsage().getCpuTimeInNanos()); + assertEquals(2000L, result.getTaskResourceUsage().getMemoryInBytes()); + } + private void verifyThreadContextFixedHeaders(String key, String value) { assertEquals(threadPool.getThreadContext().getHeader(key), value); assertEquals(threadPool.getThreadContext().getTransient(key), value); diff --git a/test/framework/src/main/java/org/opensearch/node/MockNode.java b/test/framework/src/main/java/org/opensearch/node/MockNode.java index e6c7e21d5b3ea..19c65ec169d3c 100644 --- a/test/framework/src/main/java/org/opensearch/node/MockNode.java +++ b/test/framework/src/main/java/org/opensearch/node/MockNode.java @@ -60,6 +60,7 @@ import org.opensearch.search.SearchService; import org.opensearch.search.fetch.FetchPhase; import org.opensearch.search.query.QueryPhase; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.test.MockHttpTransport; import org.opensearch.test.transport.MockTransportService; @@ -155,7 +156,8 @@ protected SearchService newSearchService( FetchPhase fetchPhase, ResponseCollectorService responseCollectorService, CircuitBreakerService circuitBreakerService, - Executor indexSearcherExecutor + Executor indexSearcherExecutor, + TaskResourceTrackingService taskResourceTrackingService ) { if (getPluginsService().filterPlugins(MockSearchService.TestPlugin.class).isEmpty()) { return super.newSearchService( @@ -168,7 +170,8 @@ protected SearchService newSearchService( fetchPhase, responseCollectorService, circuitBreakerService, - indexSearcherExecutor + indexSearcherExecutor, + taskResourceTrackingService ); } return new MockSearchService( @@ -180,7 +183,8 @@ protected SearchService newSearchService( queryPhase, fetchPhase, circuitBreakerService, - indexSearcherExecutor + indexSearcherExecutor, + taskResourceTrackingService ); } diff --git a/test/framework/src/main/java/org/opensearch/search/MockSearchService.java b/test/framework/src/main/java/org/opensearch/search/MockSearchService.java index a0bbcb7be05f9..6c9ace06c8219 100644 --- a/test/framework/src/main/java/org/opensearch/search/MockSearchService.java +++ b/test/framework/src/main/java/org/opensearch/search/MockSearchService.java @@ -42,6 +42,7 @@ import org.opensearch.search.fetch.FetchPhase; import org.opensearch.search.internal.ReaderContext; import org.opensearch.search.query.QueryPhase; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.threadpool.ThreadPool; import java.util.HashMap; @@ -96,7 +97,8 @@ public MockSearchService( QueryPhase queryPhase, FetchPhase fetchPhase, CircuitBreakerService circuitBreakerService, - Executor indexSearcherExecutor + Executor indexSearcherExecutor, + TaskResourceTrackingService taskResourceTrackingService ) { super( clusterService, @@ -108,7 +110,8 @@ public MockSearchService( fetchPhase, null, circuitBreakerService, - indexSearcherExecutor + indexSearcherExecutor, + taskResourceTrackingService ); }