From c92699658e0163698ec5c4bf7af0506a288b386e Mon Sep 17 00:00:00 2001 From: Kaushal Kumar Date: Tue, 23 Jul 2024 12:07:02 -0700 Subject: [PATCH] add transport interceptor to populate queryGroupId in task headers Signed-off-by: Kaushal Kumar --- .../action/search/TransportSearchAction.java | 7 +- .../org/opensearch/search/SearchService.java | 6 -- .../main/java/org/opensearch/tasks/Task.java | 13 +-- .../opensearch/wlm/QueryGroupConstants.java | 19 +++++ .../wlm/SearchWorkloadTransportHandler.java | 53 ++++++++++++ .../SearchWorkloadTransportInterceptor.java | 37 ++++++++ .../admin/cluster/node/tasks/TaskTests.java | 9 +- .../SearchWorkloadTransportHandlerTests.java | 84 +++++++++++++++++++ ...archWorkloadTransportInterceptorTests.java | 37 ++++++++ 9 files changed, 248 insertions(+), 17 deletions(-) create mode 100644 server/src/main/java/org/opensearch/wlm/QueryGroupConstants.java create mode 100644 server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportHandler.java create mode 100644 server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportInterceptor.java create mode 100644 server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportHandlerTests.java create mode 100644 server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportInterceptorTests.java diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 8772e74ce7acf..a6930e1aa7798 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -101,6 +101,7 @@ import org.opensearch.transport.RemoteTransportException; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportService; +import org.opensearch.wlm.QueryGroupConstants; import java.util.ArrayList; import java.util.Arrays; @@ -444,7 +445,11 @@ private void executeRequest( // At this point either the QUERY_GROUP_ID header will be present in ThreadContext either via ActionFilter // or HTTP header (HTTP header will be deprecated once ActionFilter is implemented) - task.addQueryGroupHeaders(threadPool.getThreadContext()); + task.addHeader( + QueryGroupConstants.QUERY_GROUP_ID_HEADER, + threadPool.getThreadContext(), + QueryGroupConstants.DEFAULT_QUERY_GROUP_ID_SUPPLIER + ); PipelinedRequest searchRequest; ActionListener listener; diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index aa3e409190ae5..a53a7198c366f 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -557,7 +557,6 @@ public void executeDfsPhase( ActionListener listener ) { final IndexShard shard = getShard(request); - task.addQueryGroupHeaders(threadPool.getThreadContext()); rewriteAndFetchShardRequest(shard, request, new ActionListener() { @Override public void onResponse(ShardSearchRequest rewritten) { @@ -611,7 +610,6 @@ public void executeQueryPhase( ) { assert request.canReturnNullResponseIfMatchNoDocs() == false || request.numberOfShards() > 1 : "empty responses require more than one shard"; - task.addQueryGroupHeaders(threadPool.getThreadContext()); final IndexShard shard = getShard(request); rewriteAndFetchShardRequest(shard, request, new ActionListener() { @Override @@ -721,7 +719,6 @@ public void executeQueryPhase( freeReaderContext(readerContext.id()); throw e; } - task.addQueryGroupHeaders(threadPool.getThreadContext()); runAsync(getExecutor(readerContext.indexShard()), () -> { final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(null); try ( @@ -748,7 +745,6 @@ public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task, final ReaderContext readerContext = findReaderContext(request.contextId(), request.shardSearchRequest()); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.shardSearchRequest()); final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest)); - task.addQueryGroupHeaders(threadPool.getThreadContext()); runAsync(getExecutor(readerContext.indexShard()), () -> { readerContext.setAggregatedDfs(request.dfs()); try ( @@ -799,7 +795,6 @@ public void executeFetchPhase( ) { final LegacyReaderContext readerContext = (LegacyReaderContext) findReaderContext(request.contextId(), request); final Releasable markAsUsed; - task.addQueryGroupHeaders(threadPool.getThreadContext()); try { markAsUsed = readerContext.markAsUsed(getScrollKeepAlive(request.scroll())); } catch (Exception e) { @@ -835,7 +830,6 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A final ReaderContext readerContext = findReaderContext(request.contextId(), request); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest()); final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest)); - task.addQueryGroupHeaders(threadPool.getThreadContext()); runAsync(getExecutor(readerContext.indexShard()), () -> { try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false)) { if (request.lastEmittedDoc() != null) { diff --git a/server/src/main/java/org/opensearch/tasks/Task.java b/server/src/main/java/org/opensearch/tasks/Task.java index 01a2781dd5c1c..bb1bf5630a3aa 100644 --- a/server/src/main/java/org/opensearch/tasks/Task.java +++ b/server/src/main/java/org/opensearch/tasks/Task.java @@ -58,6 +58,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; /** * Current task information @@ -529,20 +530,20 @@ public String getHeader(String header) { * hence it is not possible to copy this header from request headers. This header is required to group the tasks into queryGroups to account for the QueryGroup level resource footprint * @param threadContext current thread context */ - public void addQueryGroupHeaders(final ThreadContext threadContext) { + public void addHeader(final String headerName, final ThreadContext threadContext, final Supplier defaultValueSupplier) { // For now this header will be coming from HTTP headers but in second phase this header // We will use this constant from QueryGroup Service once the framework changes are done - final String QUERY_GROUP_ID_HEADER = "queryGroupId"; - String requestQueryGroupId = threadContext.getHeader(QUERY_GROUP_ID_HEADER); - if (requestQueryGroupId == null) { - requestQueryGroupId = "DEFAULT_QUERY_GROUP_ID"; // TODO: move this constant either to QueryGroupService or Tracking equivalent + String headerValue = threadContext.getHeader(headerName); + + if (headerValue == null) { + headerValue = defaultValueSupplier.get(); } final Map newHeaders = new HashMap<>(headers); - newHeaders.put(QUERY_GROUP_ID_HEADER, requestQueryGroupId); + newHeaders.put(headerName, headerValue); this.headers = newHeaders; } diff --git a/server/src/main/java/org/opensearch/wlm/QueryGroupConstants.java b/server/src/main/java/org/opensearch/wlm/QueryGroupConstants.java new file mode 100644 index 0000000000000..e7b8df29d5b54 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/QueryGroupConstants.java @@ -0,0 +1,19 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import java.util.function.Supplier; + +/** + * This class will hold all the QueryGroup related constants + */ +public class QueryGroupConstants { + public static final String QUERY_GROUP_ID_HEADER = "queryGroupId"; + public static final Supplier DEFAULT_QUERY_GROUP_ID_SUPPLIER = () -> "DEFAULT_QUERY_GROUP"; +} diff --git a/server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportHandler.java b/server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportHandler.java new file mode 100644 index 0000000000000..8006960790d27 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportHandler.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.search.fetch.ShardFetchRequest; +import org.opensearch.search.internal.InternalScrollSearchRequest; +import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.query.QuerySearchRequest; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestHandler; + +/** + * This class is mainly used to populate the queryGroupId header + * @param T is Search related request + */ +public class SearchWorkloadTransportHandler implements TransportRequestHandler { + + private final ThreadPool threadPool; + TransportRequestHandler actualHandler; + + public SearchWorkloadTransportHandler(ThreadPool threadPool, TransportRequestHandler actualHandler) { + this.threadPool = threadPool; + this.actualHandler = actualHandler; + } + + @Override + public void messageReceived(T request, TransportChannel channel, Task task) throws Exception { + if (isSearchWorkloadRequest(request)) { + task.addHeader( + QueryGroupConstants.QUERY_GROUP_ID_HEADER, + threadPool.getThreadContext(), + QueryGroupConstants.DEFAULT_QUERY_GROUP_ID_SUPPLIER + ); + } + actualHandler.messageReceived(request, channel, task); + } + + private boolean isSearchWorkloadRequest(TransportRequest request) { + return (request instanceof ShardSearchRequest) + || (request instanceof ShardFetchRequest) + || (request instanceof InternalScrollSearchRequest) + || (request instanceof QuerySearchRequest); + } +} diff --git a/server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportInterceptor.java b/server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportInterceptor.java new file mode 100644 index 0000000000000..2583158a98113 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/SearchWorkloadTransportInterceptor.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportInterceptor; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestHandler; + +/** + * This class is used to intercept search traffic requests and populate the queryGroupId header in task headers + * TODO: We still need to add this interceptor in {@link org.opensearch.node.Node} class to enable, + * leaving it until the feature is tested and done. + */ +public class SearchWorkloadTransportInterceptor implements TransportInterceptor { + private final ThreadPool threadPool; + + public SearchWorkloadTransportInterceptor(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + @Override + public TransportRequestHandler interceptHandler( + String action, + String executor, + boolean forceExecution, + TransportRequestHandler actualHandler + ) { + return new SearchWorkloadTransportHandler(threadPool, actualHandler); + } +} diff --git a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskTests.java b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskTests.java index ad95ffc59e5ac..69491689f4686 100644 --- a/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskTests.java +++ b/server/src/test/java/org/opensearch/action/admin/cluster/node/tasks/TaskTests.java @@ -43,6 +43,7 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.wlm.QueryGroupConstants; import java.nio.charset.StandardCharsets; import java.util.Collections; @@ -253,9 +254,9 @@ public void testAddQueryGroupHeaders() { threadPool.getThreadContext().putHeader("queryGroupId", "afakgkagj09532059"); - task.addQueryGroupHeaders(threadPool.getThreadContext()); + task.addHeader(QueryGroupConstants.QUERY_GROUP_ID_HEADER, threadPool.getThreadContext(), () -> "default_val"); - String queryGroupId = task.getHeader("queryGroupId"); + String queryGroupId = task.getHeader(QueryGroupConstants.QUERY_GROUP_ID_HEADER); assertEquals("afakgkagj09532059", queryGroupId); } finally { @@ -275,11 +276,11 @@ public void testAddQueryGroupHeadersWhenHeaderIsNotPresentInThreadContext() { Collections.emptyMap() ); - task.addQueryGroupHeaders(threadPool.getThreadContext()); + task.addHeader(QueryGroupConstants.QUERY_GROUP_ID_HEADER, threadPool.getThreadContext(), () -> "default_val"); String queryGroupId = task.getHeader("queryGroupId"); - assertEquals("DEFAULT_QUERY_GROUP_ID", queryGroupId); + assertEquals("default_val", queryGroupId); } finally { threadPool.shutdown(); } diff --git a/server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportHandlerTests.java b/server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportHandlerTests.java new file mode 100644 index 0000000000000..9e3cf020ccd61 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportHandlerTests.java @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.action.index.IndexRequest; +import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestHandler; + +import java.util.Collections; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class SearchWorkloadTransportHandlerTests extends OpenSearchTestCase { + private SearchWorkloadTransportHandler sut; + private ThreadPool threadPool; + + private TransportRequestHandler actualHandler; + + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool(getTestName()); + actualHandler = new TestTransportRequestHandler<>(); + + sut = new SearchWorkloadTransportHandler<>(threadPool, actualHandler); + } + + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + } + + public void testMessageReceivedForSearchWorkload() throws Exception { + ShardSearchRequest request = mock(ShardSearchRequest.class); + Task spyTask = getSpyTask(); + + sut.messageReceived(request, mock(TransportChannel.class), spyTask); + + verify(spyTask, times(1)).addHeader( + QueryGroupConstants.QUERY_GROUP_ID_HEADER, + threadPool.getThreadContext(), + QueryGroupConstants.DEFAULT_QUERY_GROUP_ID_SUPPLIER + ); + } + + public void testMessageReceivedForNonSearchWorkload() throws Exception { + IndexRequest indexRequest = mock(IndexRequest.class); + Task spyTask = getSpyTask(); + sut.messageReceived(indexRequest, mock(TransportChannel.class), spyTask); + + verify(spyTask, times(0)).addHeader(any(), any(), any()); + } + + private static Task getSpyTask() { + final Task task = new Task(123, "transport", "Search", "test task", null, Collections.emptyMap()); + + return spy(task); + } + + private static class TestTransportRequestHandler implements TransportRequestHandler { + int invokeCount = 0; + + @Override + public void messageReceived(TransportRequest request, TransportChannel channel, Task task) throws Exception { + invokeCount += 1; + } + + }; +} diff --git a/server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportInterceptorTests.java b/server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportInterceptorTests.java new file mode 100644 index 0000000000000..0dbb3e9f88b4b --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/SearchWorkloadTransportInterceptorTests.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestHandler; + +import static org.opensearch.threadpool.ThreadPool.Names.SAME; + +public class SearchWorkloadTransportInterceptorTests extends OpenSearchTestCase { + + private ThreadPool threadPool; + private SearchWorkloadTransportInterceptor sut; + + public void setUp() throws Exception { + threadPool = new TestThreadPool(getTestName()); + sut = new SearchWorkloadTransportInterceptor(threadPool); + } + + public void tearDown() throws Exception { + threadPool.shutdown(); + } + + public void testInterceptHandler() { + TransportRequestHandler requestHandler = sut.interceptHandler("Search", SAME, false, null); + assertTrue(requestHandler instanceof SearchWorkloadTransportHandler); + } +}