Skip to content

Commit

Permalink
Prevent data nodes from sending stack traces to coordinator when `err…
Browse files Browse the repository at this point in the history
…or_trace=false` (#118266)

* first iterations

* added tests

* Update docs/changelog/118266.yaml

* constant for error_trace and typos

* centralized putHeader

* moved threadContext to parent class

* uses NodeClient.threadpool

* updated async tests to retrieve final result

* moved test to avoid starting up a node

* added transport version to avoid sending useless bytes

* more async tests
  • Loading branch information
piergm authored Dec 18, 2024
1 parent a2360d1 commit 97bc291
Show file tree
Hide file tree
Showing 18 changed files with 535 additions and 13 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/118266.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 118266
summary: Prevent data nodes from sending stack traces to coordinator when `error_trace=false`
area: Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ public void testCancellationDuringTimeSeriesAggregation() throws Exception {
}

logger.info("Executing search");
// we have to explicitly set error_trace=true for the later exception check for `TimeSeriesIndexSearcher`
client().threadPool().getThreadContext().putHeader("error_trace", "true");
TimeSeriesAggregationBuilder timeSeriesAggregationBuilder = new TimeSeriesAggregationBuilder("test_agg");
ActionFuture<SearchResponse> searchResponse = prepareSearch("test").setQuery(matchAllQuery())
.addAggregation(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.http;

import org.apache.http.entity.ContentType;
import org.apache.http.nio.entity.NByteArrayEntity;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.search.MultiSearchRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.client.Request;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.transport.TransportMessageListener;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.XContentType;
import org.junit.Before;

import java.io.IOException;
import java.nio.charset.Charset;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.elasticsearch.index.query.QueryBuilders.simpleQueryStringQuery;

public class SearchErrorTraceIT extends HttpSmokeTestCase {
private AtomicBoolean hasStackTrace;

@Before
private void setupMessageListener() {
internalCluster().getDataNodeInstances(TransportService.class).forEach(ts -> {
ts.addMessageListener(new TransportMessageListener() {
@Override
public void onResponseSent(long requestId, String action, Exception error) {
TransportMessageListener.super.onResponseSent(requestId, action, error);
if (action.startsWith("indices:data/read/search")) {
Optional<Throwable> throwable = ExceptionsHelper.unwrapCausesAndSuppressed(
error,
t -> t.getStackTrace().length > 0
);
hasStackTrace.set(throwable.isPresent());
}
}
});
});
}

private void setupIndexWithDocs() {
createIndex("test1", "test2");
indexRandom(
true,
prepareIndex("test1").setId("1").setSource("field", "foo"),
prepareIndex("test2").setId("10").setSource("field", 5)
);
refresh();
}

public void testSearchFailingQueryErrorTraceDefault() throws IOException {
hasStackTrace = new AtomicBoolean();
setupIndexWithDocs();

Request searchRequest = new Request("POST", "/_search");
searchRequest.setJsonEntity("""
{
"query": {
"simple_query_string" : {
"query": "foo",
"fields": ["field"]
}
}
}
""");
getRestClient().performRequest(searchRequest);
assertFalse(hasStackTrace.get());
}

public void testSearchFailingQueryErrorTraceTrue() throws IOException {
hasStackTrace = new AtomicBoolean();
setupIndexWithDocs();

Request searchRequest = new Request("POST", "/_search");
searchRequest.setJsonEntity("""
{
"query": {
"simple_query_string" : {
"query": "foo",
"fields": ["field"]
}
}
}
""");
searchRequest.addParameter("error_trace", "true");
getRestClient().performRequest(searchRequest);
assertTrue(hasStackTrace.get());
}

public void testSearchFailingQueryErrorTraceFalse() throws IOException {
hasStackTrace = new AtomicBoolean();
setupIndexWithDocs();

Request searchRequest = new Request("POST", "/_search");
searchRequest.setJsonEntity("""
{
"query": {
"simple_query_string" : {
"query": "foo",
"fields": ["field"]
}
}
}
""");
searchRequest.addParameter("error_trace", "false");
getRestClient().performRequest(searchRequest);
assertFalse(hasStackTrace.get());
}

public void testMultiSearchFailingQueryErrorTraceDefault() throws IOException {
hasStackTrace = new AtomicBoolean();
setupIndexWithDocs();

XContentType contentType = XContentType.JSON;
MultiSearchRequest multiSearchRequest = new MultiSearchRequest().add(
new SearchRequest("test*").source(new SearchSourceBuilder().query(simpleQueryStringQuery("foo").field("field")))
);
Request searchRequest = new Request("POST", "/_msearch");
byte[] requestBody = MultiSearchRequest.writeMultiLineFormat(multiSearchRequest, contentType.xContent());
searchRequest.setEntity(
new NByteArrayEntity(requestBody, ContentType.create(contentType.mediaTypeWithoutParameters(), (Charset) null))
);
getRestClient().performRequest(searchRequest);
assertFalse(hasStackTrace.get());
}

public void testMultiSearchFailingQueryErrorTraceTrue() throws IOException {
hasStackTrace = new AtomicBoolean();
setupIndexWithDocs();

XContentType contentType = XContentType.JSON;
MultiSearchRequest multiSearchRequest = new MultiSearchRequest().add(
new SearchRequest("test*").source(new SearchSourceBuilder().query(simpleQueryStringQuery("foo").field("field")))
);
Request searchRequest = new Request("POST", "/_msearch");
byte[] requestBody = MultiSearchRequest.writeMultiLineFormat(multiSearchRequest, contentType.xContent());
searchRequest.setEntity(
new NByteArrayEntity(requestBody, ContentType.create(contentType.mediaTypeWithoutParameters(), (Charset) null))
);
searchRequest.addParameter("error_trace", "true");
getRestClient().performRequest(searchRequest);
assertTrue(hasStackTrace.get());
}

public void testMultiSearchFailingQueryErrorTraceFalse() throws IOException {
hasStackTrace = new AtomicBoolean();
setupIndexWithDocs();

XContentType contentType = XContentType.JSON;
MultiSearchRequest multiSearchRequest = new MultiSearchRequest().add(
new SearchRequest("test*").source(new SearchSourceBuilder().query(simpleQueryStringQuery("foo").field("field")))
);
Request searchRequest = new Request("POST", "/_msearch");
byte[] requestBody = MultiSearchRequest.writeMultiLineFormat(multiSearchRequest, contentType.xContent());
searchRequest.setEntity(
new NByteArrayEntity(requestBody, ContentType.create(contentType.mediaTypeWithoutParameters(), (Charset) null))
);
searchRequest.addParameter("error_trace", "false");
getRestClient().performRequest(searchRequest);

assertFalse(hasStackTrace.get());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_QUERY_BUILDER_IN_SEARCH_FUNCTIONS = def(8_808_00_0);
public static final TransportVersion EQL_ALLOW_PARTIAL_SEARCH_RESULTS = def(8_809_00_0);
public static final TransportVersion NODE_VERSION_INFORMATION_WITH_MIN_READ_ONLY_INDEX_VERSION = def(8_810_00_0);
public static final TransportVersion ERROR_TRACE_IN_TRANSPORT_HEADER = def(8_811_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,8 @@ public static void registerRequestHandler(TransportService transportService, Sea
(request, channel, task) -> searchService.executeQueryPhase(
request,
(SearchShardTask) task,
new ChannelActionListener<>(channel)
new ChannelActionListener<>(channel),
channel.getVersion()
)
);
TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, true, QuerySearchResult::new);
Expand All @@ -468,7 +469,8 @@ public static void registerRequestHandler(TransportService transportService, Sea
(request, channel, task) -> searchService.executeQueryPhase(
request,
(SearchShardTask) task,
new ChannelActionListener<>(channel)
new ChannelActionListener<>(channel),
channel.getVersion()
)
);
TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, true, ScrollQuerySearchResult::new);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.telemetry.tracing.TraceContext;

Expand Down Expand Up @@ -530,6 +532,17 @@ public String getHeader(String key) {
return value;
}

/**
* Returns the header for the given key or defaultValue if not present
*/
public String getHeaderOrDefault(String key, String defaultValue) {
String value = getHeader(key);
if (value == null) {
return defaultValue;
}
return value;
}

/**
* Returns all of the request headers from the thread's context.<br>
* <b>Be advised, headers might contain credentials.</b>
Expand Down Expand Up @@ -589,6 +602,14 @@ public void putHeader(Map<String, String> header) {
threadLocal.set(threadLocal.get().putHeaders(header));
}

public void setErrorTraceTransportHeader(RestRequest r) {
// set whether data nodes should send back stack trace based on the `error_trace` query parameter
if (r.paramAsBoolean("error_trace", RestController.ERROR_TRACE_DEFAULT)) {
// We only set it if error_trace is true (defaults to false) to avoid sending useless bytes
putHeader("error_trace", "true");
}
}

/**
* Puts a transient header object into this context
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,5 +269,4 @@ protected Set<String> responseParams() {
protected Set<String> responseParams(RestApiVersion restApiVersion) {
return responseParams();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ public class RestController implements HttpServerTransport.Dispatcher {
public static final String STATUS_CODE_KEY = "es_rest_status_code";
public static final String HANDLER_NAME_KEY = "es_rest_handler_name";
public static final String REQUEST_METHOD_KEY = "es_rest_request_method";
public static final boolean ERROR_TRACE_DEFAULT = false;

static {
try (InputStream stream = RestController.class.getResourceAsStream("/config/favicon.ico")) {
Expand Down Expand Up @@ -638,7 +639,7 @@ private void tryAllHandlers(final RestRequest request, final RestChannel channel
private static void validateErrorTrace(RestRequest request, RestChannel channel) {
// error_trace cannot be used when we disable detailed errors
// we consume the error_trace parameter first to ensure that it is always consumed
if (request.paramAsBoolean("error_trace", false) && channel.detailedErrorsEnabled() == false) {
if (request.paramAsBoolean("error_trace", ERROR_TRACE_DEFAULT) && channel.detailedErrorsEnabled() == false) {
throw new IllegalArgumentException("error traces in responses are disabled.");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import static java.util.Collections.singletonMap;
import static org.elasticsearch.ElasticsearchException.REST_EXCEPTION_SKIP_STACK_TRACE;
import static org.elasticsearch.rest.RestController.ELASTIC_PRODUCT_HTTP_HEADER;
import static org.elasticsearch.rest.RestController.ERROR_TRACE_DEFAULT;

public final class RestResponse implements Releasable {

Expand Down Expand Up @@ -143,7 +144,7 @@ public RestResponse(RestChannel channel, RestStatus status, Exception e) throws
// switched in the xcontent rendering parameters.
// For authorization problems (RestStatus.UNAUTHORIZED) we don't want to do this since this could
// leak information to the caller who is unauthorized to make this call
if (params.paramAsBoolean("error_trace", false) && status != RestStatus.UNAUTHORIZED) {
if (params.paramAsBoolean("error_trace", ERROR_TRACE_DEFAULT) && status != RestStatus.UNAUTHORIZED) {
params = new ToXContent.DelegatingMapParams(singletonMap(REST_EXCEPTION_SKIP_STACK_TRACE, "false"), params);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ public String getName() {

@Override
public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException {
if (client.threadPool() != null && client.threadPool().getThreadContext() != null) {
client.threadPool().getThreadContext().setErrorTraceTransportHeader(request);
}
final MultiSearchRequest multiSearchRequest = parseRequest(request, allowExplicitIndex, searchUsageHolder, clusterSupportsFeature);
return channel -> {
final RestCancellableNodeClient cancellableClient = new RestCancellableNodeClient(client, request.getHttpChannel());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ public Set<String> supportedCapabilities() {

@Override
public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException {

if (client.threadPool() != null && client.threadPool().getThreadContext() != null) {
client.threadPool().getThreadContext().setErrorTraceTransportHeader(request);
}
SearchRequest searchRequest = new SearchRequest();
// access the BwC param, but just drop it
// this might be set by old clients
Expand Down
Loading

0 comments on commit 97bc291

Please sign in to comment.