Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add logic to add queryGroupId to task headers
Browse files Browse the repository at this point in the history
Signed-off-by: Kaushal Kumar <[email protected]>
kaushalmahi12 committed Jul 10, 2024
1 parent e1d62fd commit 0e4bc0c
Showing 4 changed files with 59 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -462,6 +462,10 @@ private void executeRequest(
);
searchRequestContext.getSearchRequestOperationsListener().onRequestStart(searchRequestContext);

// At this point either the QUERY_GROUP_ID header will be present in ThreadContext either via ActionFilter
// or HTTP header (HTTP header will be deprecated once ActionFilter is implemented)
task.addQueryGroupHeadersTo(threadPool.getThreadContext());

PipelinedRequest searchRequest;
ActionListener<SearchResponse> listener;
try {
7 changes: 7 additions & 0 deletions server/src/main/java/org/opensearch/search/SearchService.java
Original file line number Diff line number Diff line change
@@ -557,6 +557,7 @@ public void executeDfsPhase(
ActionListener<SearchPhaseResult> listener
) {
final IndexShard shard = getShard(request);
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
rewriteAndFetchShardRequest(shard, request, new ActionListener<ShardSearchRequest>() {
@Override
public void onResponse(ShardSearchRequest rewritten) {
@@ -574,6 +575,7 @@ public void onFailure(Exception exc) {
private DfsSearchResult executeDfsPhase(ShardSearchRequest request, SearchShardTask task, boolean keepStatesInContext)
throws IOException {
ReaderContext readerContext = createOrGetReaderContext(request, keepStatesInContext);
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
try (
Releasable ignored = readerContext.markAsUsed(getKeepAlive(request));
SearchContext context = createContext(readerContext, request, task, true)
@@ -610,6 +612,7 @@ public void executeQueryPhase(
) {
assert request.canReturnNullResponseIfMatchNoDocs() == false || request.numberOfShards() > 1
: "empty responses require more than one shard";
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
final IndexShard shard = getShard(request);
rewriteAndFetchShardRequest(shard, request, new ActionListener<ShardSearchRequest>() {
@Override
@@ -719,6 +722,7 @@ public void executeQueryPhase(
freeReaderContext(readerContext.id());
throw e;
}
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
runAsync(getExecutor(readerContext.indexShard()), () -> {
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(null);
try (
@@ -745,6 +749,7 @@ public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task,
final ReaderContext readerContext = findReaderContext(request.contextId(), request.shardSearchRequest());
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.shardSearchRequest());
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
runAsync(getExecutor(readerContext.indexShard()), () -> {
readerContext.setAggregatedDfs(request.dfs());
try (
@@ -795,6 +800,7 @@ public void executeFetchPhase(
) {
final LegacyReaderContext readerContext = (LegacyReaderContext) findReaderContext(request.contextId(), request);
final Releasable markAsUsed;
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
try {
markAsUsed = readerContext.markAsUsed(getScrollKeepAlive(request.scroll()));
} catch (Exception e) {
@@ -830,6 +836,7 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A
final ReaderContext readerContext = findReaderContext(request.contextId(), request);
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest());
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
task.addQueryGroupHeadersTo(threadPool.getThreadContext());
runAsync(getExecutor(readerContext.indexShard()), () -> {
try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false)) {
if (request.lastEmittedDoc() != null) {
31 changes: 22 additions & 9 deletions server/src/main/java/org/opensearch/tasks/Task.java
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@

import org.opensearch.ExceptionsHelper;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.action.NotifyOnceListener;
import org.opensearch.core.common.io.stream.NamedWriteable;
@@ -88,7 +89,7 @@ public class Task {

private final TaskId parentTask;

private final Map<String, String> headers;
private Map<String, String> headers;

private final Map<Long, List<ThreadResourceInfo>> resourceStats;

@@ -277,6 +278,14 @@ public TaskId getParentTaskId() {
return parentTask;
}

/**
*
* returns the headers for this task
*/
public Map<String, String> getHeaders() {
return headers;
}

/**
* Build a status for this task or null if this task doesn't have status.
* Since most tasks don't have status this defaults to returning null. While
@@ -523,14 +532,18 @@ public String getHeader(String header) {
return headers.get(header);
}

/**
* sets the header value, It is currently not possible to determine the query group for the task at the task creation
* time, hence we need this method to add the headers to task
* @param name header name
* @param value header value
*/
public void putHeader(String name, String value) {
this.headers.put(name, value);
public void addQueryGroupHeadersTo(final ThreadContext threadContext) {
// For now this header will be coming from HTTP headers but in second phase this header

// We will use this constant from QueryGroup Service once the framework changes are done
final String QUERY_GROUP_ID_HEADER = "queryGroupId";
final String requestQueryGroupId = threadContext.getHeader(QUERY_GROUP_ID_HEADER);

final Map<String, String> newHeaders = new HashMap<>(headers);

newHeaders.put(QUERY_GROUP_ID_HEADER, requestQueryGroupId);

this.headers = newHeaders;
}

public TaskResult result(final String nodeId, Exception error) throws IOException {
Original file line number Diff line number Diff line change
@@ -41,6 +41,8 @@
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskInfo;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;

import java.nio.charset.StandardCharsets;
import java.util.Collections;
@@ -236,4 +238,28 @@ public void testTaskResourceStats() {
// pass
}
}

public void testAddQueryGroupHeadersTo() {
ThreadPool threadPool = new TestThreadPool(getClass().getName());
try {
Task task = new Task(
randomLong(),
"transport",
SearchAction.NAME,
"description",
new TaskId(randomLong() + ":" + randomLong()),
Collections.emptyMap()
);

threadPool.getThreadContext().putHeader("queryGroupId", "afakgkagj09532059");

task.addQueryGroupHeadersTo(threadPool.getThreadContext());

String queryGroupId = task.getHeader("queryGroupId");

assertEquals("afakgkagj09532059", queryGroupId);
} finally {
threadPool.shutdown();
}
}
}

0 comments on commit 0e4bc0c

Please sign in to comment.