Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for conditional Transient header propagation #11490

Merged
merged 10 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix template setting override for replication type ([#11417](https://github.com/opensearch-project/OpenSearch/pull/11417))
- Fix Automatic addition of protocol broken in #11512 ([#11609](https://github.com/opensearch-project/OpenSearch/pull/11609))
- Fix issue when calling Delete PIT endpoint and no PITs exist ([#11711](https://github.com/opensearch-project/OpenSearch/pull/11711))
- Fix tracing context propagation for local transport instrumentation ([#11490](https://github.com/opensearch-project/OpenSearch/pull/11490))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ public StoredContext stashContext() {
);
}

final Map<String, Object> transientHeaders = propagateTransients(context.transientHeaders);
final Map<String, Object> transientHeaders = propagateTransients(context.transientHeaders, context.isSystemContext);
if (!transientHeaders.isEmpty()) {
threadContextStruct = threadContextStruct.putTransient(transientHeaders);
}
Expand All @@ -182,7 +182,7 @@ public StoredContext stashContext() {
public Writeable captureAsWriteable() {
final ThreadContextStruct context = threadLocal.get();
return out -> {
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders);
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext);
context.writeTo(out, defaultHeader, propagatedHeaders);
};
}
Expand Down Expand Up @@ -245,7 +245,7 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders, Collectio
final Map<String, Object> newTransientHeaders = new HashMap<>(originalContext.transientHeaders);

boolean transientHeadersModified = false;
final Map<String, Object> transientHeaders = propagateTransients(originalContext.transientHeaders);
final Map<String, Object> transientHeaders = propagateTransients(originalContext.transientHeaders, originalContext.isSystemContext);
if (!transientHeaders.isEmpty()) {
newTransientHeaders.putAll(transientHeaders);
transientHeadersModified = true;
Expand Down Expand Up @@ -322,7 +322,7 @@ public Supplier<StoredContext> wrapRestorable(StoredContext storedContext) {
@Override
public void writeTo(StreamOutput out) throws IOException {
final ThreadContextStruct context = threadLocal.get();
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders);
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext);
context.writeTo(out, defaultHeader, propagatedHeaders);
}

Expand Down Expand Up @@ -534,7 +534,7 @@ boolean isDefaultContext() {
* by the system itself rather than by a user action.
*/
public void markAsSystemContext() {
threadLocal.set(threadLocal.get().setSystemContext());
threadLocal.set(threadLocal.get().setSystemContext(propagators));
}

/**
Expand Down Expand Up @@ -573,15 +573,15 @@ public static Map<String, String> buildDefaultHeaders(Settings settings) {
}
}

private Map<String, Object> propagateTransients(Map<String, Object> source) {
private Map<String, Object> propagateTransients(Map<String, Object> source, boolean isSystemContext) {
final Map<String, Object> transients = new HashMap<>();
propagators.forEach(p -> transients.putAll(p.transients(source)));
propagators.forEach(p -> transients.putAll(p.transients(source, isSystemContext)));
return transients;
}

private Map<String, String> propagateHeaders(Map<String, Object> source) {
private Map<String, String> propagateHeaders(Map<String, Object> source, boolean isSystemContext) {
final Map<String, String> headers = new HashMap<>();
propagators.forEach(p -> headers.putAll(p.headers(source)));
propagators.forEach(p -> headers.putAll(p.headers(source, isSystemContext)));
return headers;
}

Expand All @@ -603,11 +603,13 @@ private static final class ThreadContextStruct {
// saving current warning headers' size not to recalculate the size with every new warning header
private final long warningHeadersSize;

private ThreadContextStruct setSystemContext() {
private ThreadContextStruct setSystemContext(final List<ThreadContextStatePropagator> propagators) {
if (isSystemContext) {
return this;
}
return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, persistentHeaders, true);
final Map<String, Object> transients = new HashMap<>();
propagators.forEach(p -> transients.putAll(p.transients(transientHeaders, true)));
return new ThreadContextStruct(requestHeaders, responseHeaders, transients, persistentHeaders, true);
}

private ThreadContextStruct(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,41 @@
public interface ThreadContextStatePropagator {
/**
* Returns the list of transient headers that needs to be propagated from current context to new thread context.
* @param source current context transient headers
*
* @param source current context transient headers
* @return the list of transient headers that needs to be propagated from current context to new thread context
*/
@Deprecated(since = "2.12.0", forRemoval = true)
Map<String, Object> transients(Map<String, Object> source);
reta marked this conversation as resolved.
Show resolved Hide resolved

/**
* Returns the list of transient headers that needs to be propagated from current context to new thread context.
*
* @param source current context transient headers
* @param isSystemContext if the propagation is for system context.
* @return the list of transient headers that needs to be propagated from current context to new thread context
*/
default Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
return transients(source);

Check warning on line 40 in server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java#L40

Added line #L40 was not covered by tests
};

/**
* Returns the list of request headers that needs to be propagated from current context to request.
* @param source current context headers
*
* @param source current context headers
* @return the list of request headers that needs to be propagated from current context to request
*/
@Deprecated(since = "2.12.0", forRemoval = true)
Map<String, String> headers(Map<String, Object> source);

/**
* Returns the list of request headers that needs to be propagated from current context to request.
*
* @param source current context headers
* @param isSystemContext if the propagation is for system context.
* @return the list of request headers that needs to be propagated from current context to request
*/
default Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
return headers(source);

Check warning on line 60 in server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java#L60

Added line #L60 was not covered by tests
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
* Propagates TASK_ID across thread contexts
*/
public class TaskThreadContextStatePropagator implements ThreadContextStatePropagator {

@Override
@SuppressWarnings("removal")
reta marked this conversation as resolved.
Show resolved Hide resolved
public Map<String, Object> transients(Map<String, Object> source) {
final Map<String, Object> transients = new HashMap<>();

Expand All @@ -32,7 +34,18 @@ public Map<String, Object> transients(Map<String, Object> source) {
}

@Override
public Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
return transients(source);
}

@Override
@SuppressWarnings("removal")
public Map<String, String> headers(Map<String, Object> source) {
return Collections.emptyMap();
}

@Override
public Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
return headers(source);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextStatePropagator;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -50,20 +51,29 @@ public void put(String key, Span span) {
}

@Override
@SuppressWarnings("removal")
public Map<String, Object> transients(Map<String, Object> source) {
final Map<String, Object> transients = new HashMap<>();

if (source.containsKey(CURRENT_SPAN)) {
final SpanReference current = (SpanReference) source.get(CURRENT_SPAN);
if (current != null) {
transients.put(CURRENT_SPAN, new SpanReference(current.getSpan()));
}
}

return transients;
}

@Override
public Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
reta marked this conversation as resolved.
Show resolved Hide resolved
if (isSystemContext == true) {
return Collections.emptyMap();
} else {
return transients(source);
}
}

@Override
@SuppressWarnings("removal")
public Map<String, String> headers(Map<String, Object> source) {
final Map<String, String> headers = new HashMap<>();

Expand All @@ -77,6 +87,11 @@ public Map<String, String> headers(Map<String, Object> source) {
return headers;
}

@Override
public Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
return headers(source);
}

Span getCurrentSpan(String key) {
SpanReference currentSpanRef = threadContext.getTransient(key);
return (currentSpanRef == null) ? null : currentSpanRef.getSpan();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -868,19 +868,10 @@ public final <T extends TransportResponse> void sendRequest(
final TransportRequestOptions options,
final TransportResponseHandler<T> handler
) {
if (connection == localNodeConnection) {
// See please https://github.com/opensearch-project/OpenSearch/issues/10291
sendRequestAsync(connection, action, request, options, handler);
} else {
final Span span = tracer.startSpan(SpanBuilder.from(action, connection));
try (SpanScope spanScope = tracer.withSpanInScope(span)) {
TransportResponseHandler<T> traceableTransportResponseHandler = TraceableTransportResponseHandler.create(
handler,
span,
tracer
);
sendRequestAsync(connection, action, request, options, traceableTransportResponseHandler);
}
final Span span = tracer.startSpan(SpanBuilder.from(action, connection));
try (SpanScope spanScope = tracer.withSpanInScope(span)) {
TransportResponseHandler<T> traceableTransportResponseHandler = TraceableTransportResponseHandler.create(handler, span, tracer);
sendRequestAsync(connection, action, request, options, traceableTransportResponseHandler);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
import java.util.Map;
import java.util.function.Supplier;

import org.mockito.Mockito;

import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
Expand Down Expand Up @@ -740,6 +742,71 @@ public void testMarkAsSystemContext() throws IOException {
assertFalse(threadContext.isSystemContext());
}

public void testSystemContextWithPropagator() {
Settings build = Settings.builder().put("request.headers.default", "1").build();
Map<String, Object> transientHeaderMap = Collections.singletonMap("test_transient_propagation_key", "test");
Map<String, Object> transientHeaderTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test");
Map<String, Object> headerMap = Collections.singletonMap("test_transient_propagation_key", "test");
Map<String, String> headerTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test");
ThreadContext threadContext = new ThreadContext(build);
ThreadContextStatePropagator mockPropagator = Mockito.mock(ThreadContextStatePropagator.class);
Mockito.when(mockPropagator.transients(transientHeaderMap, true)).thenReturn(Collections.emptyMap());
Mockito.when(mockPropagator.transients(transientHeaderMap, false)).thenReturn(transientHeaderTransformedMap);

Mockito.when(mockPropagator.headers(headerMap, true)).thenReturn(headerTransformedMap);
Mockito.when(mockPropagator.headers(headerMap, false)).thenReturn(headerTransformedMap);
threadContext.registerThreadContextStatePropagator(mockPropagator);
threadContext.putHeader("foo", "bar");
threadContext.putTransient("test_transient_propagation_key", 1);
assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key"));
assertEquals("bar", threadContext.getHeader("foo"));
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
threadContext.markAsSystemContext();
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
}

assertEquals("bar", threadContext.getHeader("foo"));
assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
}

public void testSerializeSystemContext() throws IOException {
Settings build = Settings.builder().put("request.headers.default", "1").build();
Map<String, Object> transientHeaderMap = Collections.singletonMap("test_transient_propagation_key", "test");
Map<String, Object> transientHeaderTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test");
Map<String, Object> headerMap = Collections.singletonMap("test_transient_propagation_key", "test");
Map<String, String> headerTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test");
ThreadContext threadContext = new ThreadContext(build);
ThreadContextStatePropagator mockPropagator = Mockito.mock(ThreadContextStatePropagator.class);
Mockito.when(mockPropagator.transients(transientHeaderMap, true)).thenReturn(Collections.emptyMap());
Mockito.when(mockPropagator.transients(transientHeaderMap, false)).thenReturn(transientHeaderTransformedMap);

Mockito.when(mockPropagator.headers(headerMap, true)).thenReturn(headerTransformedMap);
Mockito.when(mockPropagator.headers(headerMap, false)).thenReturn(headerTransformedMap);
threadContext.registerThreadContextStatePropagator(mockPropagator);
threadContext.putHeader("foo", "bar");
threadContext.putTransient("test_transient_propagation_key", "test");
BytesStreamOutput out = new BytesStreamOutput();
BytesStreamOutput outFromSystemContext = new BytesStreamOutput();
threadContext.writeTo(out);
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
assertEquals("test", threadContext.getTransient("test_transient_propagation_key"));
threadContext.markAsSystemContext();
threadContext.writeTo(outFromSystemContext);
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("test_transient_propagation_key"));
threadContext.readHeaders(outFromSystemContext.bytes().streamInput());
assertNull(threadContext.getHeader("test_transient_propagation_key"));
}
assertEquals("test", threadContext.getTransient("test_transient_propagation_key"));
threadContext.readHeaders(out.bytes().streamInput());
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("test", threadContext.getHeader("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
}

public void testPutHeaders() {
Settings build = Settings.builder().put("request.headers.default", "1").build();
ThreadContext threadContext = new ThreadContext(build);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.tasks;

import org.opensearch.test.OpenSearchTestCase;

import java.util.HashMap;
import java.util.Map;

import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;

public class TaskThreadContextStatePropagatorTests extends OpenSearchTestCase {
private final TaskThreadContextStatePropagator taskThreadContextStatePropagator = new TaskThreadContextStatePropagator();

public void testTransient() {
Map<String, Object> transientHeader = new HashMap<>();
transientHeader.put(TASK_ID, "t_1");
Map<String, Object> transientPropagatedHeader = taskThreadContextStatePropagator.transients(transientHeader, false);
assertEquals("t_1", transientPropagatedHeader.get(TASK_ID));
}

public void testTransientForSystemContext() {
Map<String, Object> transientHeader = new HashMap<>();
transientHeader.put(TASK_ID, "t_1");
Map<String, Object> transientPropagatedHeader = taskThreadContextStatePropagator.transients(transientHeader, true);
assertEquals("t_1", transientPropagatedHeader.get(TASK_ID));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -252,4 +252,20 @@ public void run() {
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}

public void testSpanNotPropagatedToChildSystemThreadContext() {
final Span span = tracer.startSpan(SpanCreationContext.internal().name("test"));

try (SpanScope scope = tracer.withSpanInScope(span)) {
try (StoredContext ignored = threadContext.stashContext()) {
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(span));
threadContext.markAsSystemContext();
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}
}

assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}
}
Loading