Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ThreadContext.markAsSystemContext package-private #14988

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Bump `org.apache.commons:commons-lang3` from 3.14.0 to 3.15.0 ([#14861](https://github.com/opensearch-project/OpenSearch/pull/14861))

### Changed
- Make ThreadContext.markAsSystemContext package-private ([#14988](https://github.com/opensearch-project/OpenSearch/pull/14988))

### Deprecated

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.core.Assertions;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
Expand Down Expand Up @@ -142,6 +143,7 @@ public abstract class TransportReplicationAction<
public static final String REPLICA_ACTION_SUFFIX = "[r]";

protected final ThreadPool threadPool;
protected volatile InternalThreadContextWrapper tcWrapper;
protected final TransportService transportService;
protected final ClusterService clusterService;
protected final ShardStateAction shardStateAction;
Expand Down Expand Up @@ -239,6 +241,9 @@ protected TransportReplicationAction(
) {
super(actionName, actionFilters, transportService.getTaskManager());
this.threadPool = threadPool;
if (threadPool != null) {
this.tcWrapper = InternalThreadContextWrapper.from(threadPool.getThreadContext());
}
this.transportService = transportService;
this.clusterService = clusterService;
this.indicesService = indicesService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor;
import org.opensearch.common.util.concurrent.ThreadContext;
Expand Down Expand Up @@ -104,6 +105,7 @@ public class ClusterApplierService extends AbstractLifecycleComponent implements

private final ClusterSettings clusterSettings;
protected final ThreadPool threadPool;
protected volatile InternalThreadContextWrapper tcWrapper;

private volatile TimeValue slowTaskLoggingThreshold;

Expand Down Expand Up @@ -173,6 +175,7 @@ protected synchronized void doStart() {
Objects.requireNonNull(nodeConnectionsService, "please set the node connection service before starting");
Objects.requireNonNull(state.get(), "please set initial state before starting");
threadPoolExecutor = createThreadPoolExecutor();
tcWrapper = InternalThreadContextWrapper.from(threadPool.getThreadContext());
}

protected PrioritizedOpenSearchThreadPoolExecutor createThreadPoolExecutor() {
Expand Down Expand Up @@ -396,7 +399,7 @@ private void submitStateUpdateTask(
final ThreadContext threadContext = threadPool.getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(true);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
final UpdateTask updateTask = new UpdateTask(
config.priority(),
source,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.CountDown;
import org.opensearch.common.util.concurrent.FutureUtils;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor;
import org.opensearch.common.util.concurrent.ThreadContext;
Expand Down Expand Up @@ -134,6 +135,7 @@ public class MasterService extends AbstractLifecycleComponent {
private volatile TimeValue slowTaskLoggingThreshold;

protected final ThreadPool threadPool;
protected volatile InternalThreadContextWrapper tcWrapper;

private volatile PrioritizedOpenSearchThreadPoolExecutor threadPoolExecutor;
private volatile Batcher taskBatcher;
Expand Down Expand Up @@ -190,6 +192,7 @@ protected synchronized void doStart() {
Objects.requireNonNull(clusterStateSupplier, "please set a cluster state supplier before starting");
threadPoolExecutor = createThreadPoolExecutor();
taskBatcher = new Batcher(logger, threadPoolExecutor, clusterManagerTaskThrottler);
tcWrapper = InternalThreadContextWrapper.from(threadPool.getThreadContext());
}

protected PrioritizedOpenSearchThreadPoolExecutor createThreadPoolExecutor() {
Expand Down Expand Up @@ -1022,7 +1025,7 @@ public <T> void submitStateUpdateTasks(
final ThreadContext threadContext = threadPool.getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(true);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();

List<Batcher.UpdateTask> safeTasks = tasks.entrySet()
.stream()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.util.concurrent;

import java.util.Objects;

/**
* Wrapper around the ThreadContext to expose methods to the core repo without
* exposing them to plugins
*
* @opensearch.internal
*/
public class InternalThreadContextWrapper {
Copy link
Collaborator

@reta reta Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cwperks definitely +1 for the intent but I don't think the taking implementation path is correct:

  • The usage of the ThreadContext within core should be frictionless. The InternalThreadContextWrapper does break this promise, plus the class exposes public static from so anyone could do InternalThreadContextWrapper.from(threadPool.getThreadContext()); to gain the access
  • The usage of the ThreadContext outside of the core should not be possible. The long term solution for that is using JPMS but meanwhile, we could close it up using security permissions:
private static final Permission ACCESS_SYSTEM_THREAD_CONTEXT_PERMISSION = new RuntimePermission("markAsSystemContext");
SecurityManager sm = System.getSecurityManager();
        if (sm != null) {
            sm.checkPermission(ACCESS_SYSTEM_THREAD_CONTEXT_PERMISSION);
        }

Yes, the compile time checks won't be enforced but runtime checks will be.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@reta That sounds like a good approach to grant/prohibit the usage of certain methods in this class in the plugin ecosystem.

I decided to take this approach initially to take advantage of the @opensearch.internal annotation so that this repo could use InternalThreadContextWrapper.from(threadPool.getThreadContext());, but plugins cannot. I will look into making this a permission that is granted through the JSM policy file.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@reta I opened 2 other PRs similar to this one, but they don't make use of this class.

Do you think the other 2 PRs should be updated similarly, or would the changes in those PRs be ok?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think the other 2 PRs should be updated similarly, or would the changes in those PRs be ok?

Thanks @cwperks , I think if we cannot cleanly seal the ThreadContext methods, we would need same change there (plus, changing the method visibility for public API is breaking).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't able to find any usages of those 2 methods outside of the core:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean those are public, right? We sadly don't know who might have been using them (since we don't host all plugins)

Copy link
Member Author

@cwperks cwperks Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be safe to change the modifier on stashWithOrigin because a plugin would still have access to stashContext after that PR and could refactor to 2 lines.

final ThreadContext.StoredContext storedContext = threadContext.stashContext();
threadContext.putTransient(ACTION_ORIGIN_TRANSIENT_NAME, origin);

I think stashAndMergeHeaders would be safe to change based on the usages I see in the core repo, but yes generally changing an access modifier could be breaking.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think stashAndMergeHeaders would be safe to change based on the usages I see in the core repo, but yes generally changing an access modifier could be breaking.

That's the thing, thank you

private final ThreadContext threadContext;

private InternalThreadContextWrapper(final ThreadContext threadContext) {
this.threadContext = threadContext;
}

public static InternalThreadContextWrapper from(ThreadContext threadContext) {
return new InternalThreadContextWrapper(threadContext);
}

public void markAsSystemContext() {
Objects.requireNonNull(threadContext, "threadContext cannot be null");
threadContext.markAsSystemContext();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ boolean isDefaultContext() {
* Marks this thread context as an internal system context. This signals that actions in this context are issued
* by the system itself rather than by a user action.
*/
public void markAsSystemContext() {
void markAsSystemContext() {
threadLocal.set(threadLocal.get().setSystemContext(propagators));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public GlobalCheckpointSyncAction(
public void updateGlobalCheckpointForShard(final ShardId shardId) {
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
execute(new Request(shardId), ActionListener.wrap(r -> {}, e -> {
if (ExceptionsHelper.unwrap(e, AlreadyClosedException.class, IndexShardClosedException.class) == null) {
logger.info(new ParameterizedMessage("{} global checkpoint sync failed", shardId), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ final void backgroundSync(ShardId shardId, String primaryAllocationId, long prim
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we have to execute under the system context so that if security is enabled the sync is authorized
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
final Request request = new Request(shardId, retentionLeases);
final ReplicationTask task = (ReplicationTask) taskManager.register("transport", "retention_lease_background_sync", request);
transportService.sendChildRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ final void sync(
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we have to execute under the system context so that if security is enabled the sync is authorized
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
final Request request = new Request(shardId, retentionLeases);
final ReplicationTask task = (ReplicationTask) taskManager.register("transport", "retention_lease_sync", request);
transportService.sendChildRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ final void publish(IndexShard indexShard, ReplicationCheckpoint checkpoint) {
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we have to execute under the system context so that if security is enabled the sync is authorized
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
PublishCheckpointRequest request = new PublishCheckpointRequest(checkpoint);
final ReplicationTask task = (ReplicationTask) taskManager.register("transport", "segrep_publish_checkpoint", request);
final ReplicationTimer timer = new ReplicationTimer();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -71,6 +72,7 @@ final class RemoteClusterConnection implements Closeable {
private final RemoteConnectionStrategy connectionStrategy;
private final String clusterAlias;
private final ThreadPool threadPool;
private final InternalThreadContextWrapper tcWrapper;
private volatile boolean skipUnavailable;
private final TimeValue initialConnectionTimeout;

Expand All @@ -91,6 +93,7 @@ final class RemoteClusterConnection implements Closeable {
this.skipUnavailable = RemoteClusterService.REMOTE_CLUSTER_SKIP_UNAVAILABLE.getConcreteSettingForNamespace(clusterAlias)
.get(settings);
this.threadPool = transportService.threadPool;
this.tcWrapper = InternalThreadContextWrapper.from(transportService.threadPool.getThreadContext());
initialConnectionTimeout = RemoteClusterService.REMOTE_INITIAL_CONNECTION_TIMEOUT_SETTING.get(settings);
}

Expand Down Expand Up @@ -136,7 +139,7 @@ void collectNodes(ActionListener<Function<String, DiscoveryNode>> listener) {
new ContextPreservingActionListener<>(threadContext.newRestorableContext(false), listener);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we stash any context here since this is an internal execution and should not leak any existing context information
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();

final ClusterStateRequest request = new ClusterStateRequest();
request.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down Expand Up @@ -160,6 +161,7 @@ public Writeable.Reader<RemoteConnectionInfo.ModeInfo> getReader() {

protected final TransportService transportService;
protected final RemoteConnectionManager connectionManager;
protected final InternalThreadContextWrapper tcWrapper;
protected final String clusterAlias;

RemoteConnectionStrategy(
Expand All @@ -170,6 +172,7 @@ public Writeable.Reader<RemoteConnectionInfo.ModeInfo> getReader() {
) {
this.clusterAlias = clusterAlias;
this.transportService = transportService;
this.tcWrapper = InternalThreadContextWrapper.from(transportService.getThreadPool().getThreadContext());
this.connectionManager = connectionManager;
this.maxPendingConnectionListeners = REMOTE_MAX_PENDING_CONNECTION_LISTENERS.get(settings);
connectionManager.addListener(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ private void collectRemoteNodes(Iterator<Supplier<DiscoveryNode>> seedNodes, Act
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we stash any context here since this is an internal execution and should not leak any
// existing context information.
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
transportService.sendRequest(
connection,
ClusterStateAction.NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
Expand Down Expand Up @@ -224,8 +225,9 @@ public void testUpdateTemplates() {

service.upgradesInProgress.set(additionsCount + deletionsCount + 2); // +2 to skip tryFinishUpgrade
final ThreadContext threadContext = threadPool.getThreadContext();
final InternalThreadContextWrapper tcWrapper = InternalThreadContextWrapper.from(threadContext);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
service.upgradeTemplates(additions, deletions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContext.StoredContext;
import org.opensearch.telemetry.Telemetry;
Expand Down Expand Up @@ -256,11 +257,12 @@ public void run() {
public void testSpanNotPropagatedToChildSystemThreadContext() {
final Span span = tracer.startSpan(SpanCreationContext.internal().name("test"));

final InternalThreadContextWrapper tcWrapper = InternalThreadContextWrapper.from(threadContext);
try (SpanScope scope = tracer.withSpanInScope(span)) {
try (StoredContext ignored = threadContext.stashContext()) {
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(span));
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,21 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class RemoteConnectionStrategyTests extends OpenSearchTestCase {

public void testStrategyChangeMeansThatStrategyMustBeRebuilt() {
ClusterConnectionManager connectionManager = new ClusterConnectionManager(Settings.EMPTY, mock(Transport.class));
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager("cluster-alias", connectionManager);
TransportService mockTransportService = mock(TransportService.class);
when(mockTransportService.getThreadPool()).thenReturn(mock(ThreadPool.class));
FakeConnectionStrategy first = new FakeConnectionStrategy(
"cluster-alias",
mock(TransportService.class),
mockTransportService,
remoteConnectionManager,
RemoteConnectionStrategy.ConnectionStrategy.PROXY
);
Expand All @@ -60,9 +64,11 @@ public void testStrategyChangeMeansThatStrategyMustBeRebuilt() {
public void testSameStrategyChangeMeansThatStrategyDoesNotNeedToBeRebuilt() {
ClusterConnectionManager connectionManager = new ClusterConnectionManager(Settings.EMPTY, mock(Transport.class));
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager("cluster-alias", connectionManager);
TransportService mockTransportService = mock(TransportService.class);
when(mockTransportService.getThreadPool()).thenReturn(mock(ThreadPool.class));
FakeConnectionStrategy first = new FakeConnectionStrategy(
"cluster-alias",
mock(TransportService.class),
mockTransportService,
remoteConnectionManager,
RemoteConnectionStrategy.ConnectionStrategy.PROXY
);
Expand All @@ -78,9 +84,11 @@ public void testChangeInConnectionProfileMeansTheStrategyMustBeRebuilt() {
assertEquals(TimeValue.MINUS_ONE, connectionManager.getConnectionProfile().getPingInterval());
assertEquals(false, connectionManager.getConnectionProfile().getCompressionEnabled());
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager("cluster-alias", connectionManager);
TransportService mockTransportService = mock(TransportService.class);
when(mockTransportService.getThreadPool()).thenReturn(mock(ThreadPool.class));
FakeConnectionStrategy first = new FakeConnectionStrategy(
"cluster-alias",
mock(TransportService.class),
mockTransportService,
remoteConnectionManager,
RemoteConnectionStrategy.ConnectionStrategy.PROXY
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor;
import org.opensearch.common.util.concurrent.ThreadContext;
Expand Down Expand Up @@ -133,8 +134,9 @@ public void run() {
taskInProgress = true;
scheduledNextTask = false;
final ThreadContext threadContext = threadPool.getThreadContext();
final InternalThreadContextWrapper tcWrapper = InternalThreadContextWrapper.from(threadContext);
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
task.run();
}
if (waitForPublish == false) {
Expand Down
Loading