diff --git a/server/src/main/java/org/opensearch/action/support/replication/TransportReplicationAction.java b/server/src/main/java/org/opensearch/action/support/replication/TransportReplicationAction.java index 164e806d04306..cbf1633726b86 100644 --- a/server/src/main/java/org/opensearch/action/support/replication/TransportReplicationAction.java +++ b/server/src/main/java/org/opensearch/action/support/replication/TransportReplicationAction.java @@ -143,7 +143,7 @@ public abstract class TransportReplicationAction< public static final String REPLICA_ACTION_SUFFIX = "[r]"; protected final ThreadPool threadPool; - protected final InternalThreadContextWrapper tcWrapper; + protected volatile InternalThreadContextWrapper tcWrapper; protected final TransportService transportService; protected final ClusterService clusterService; protected final ShardStateAction shardStateAction; @@ -243,8 +243,6 @@ protected TransportReplicationAction( this.threadPool = threadPool; if (threadPool != null) { this.tcWrapper = InternalThreadContextWrapper.from(threadPool.getThreadContext()); - } else { - this.tcWrapper = InternalThreadContextWrapper.from(transportService.getThreadPool().getThreadContext()); } this.transportService = transportService; this.clusterService = clusterService; diff --git a/server/src/main/java/org/opensearch/transport/RemoteConnectionStrategy.java b/server/src/main/java/org/opensearch/transport/RemoteConnectionStrategy.java index a4ade7f59709b..07428ad29282a 100644 --- a/server/src/main/java/org/opensearch/transport/RemoteConnectionStrategy.java +++ b/server/src/main/java/org/opensearch/transport/RemoteConnectionStrategy.java @@ -172,7 +172,7 @@ public Writeable.Reader getReader() { ) { this.clusterAlias = clusterAlias; this.transportService = transportService; - this.tcWrapper = InternalThreadContextWrapper.from(transportService.threadPool.getThreadContext()); + this.tcWrapper = InternalThreadContextWrapper.from(transportService.getThreadPool().getThreadContext()); this.connectionManager = connectionManager; this.maxPendingConnectionListeners = REMOTE_MAX_PENDING_CONNECTION_LISTENERS.get(settings); connectionManager.addListener(this); diff --git a/server/src/test/java/org/opensearch/transport/RemoteConnectionStrategyTests.java b/server/src/test/java/org/opensearch/transport/RemoteConnectionStrategyTests.java index e2acbcff3db16..7d9f5dd2a8ead 100644 --- a/server/src/test/java/org/opensearch/transport/RemoteConnectionStrategyTests.java +++ b/server/src/test/java/org/opensearch/transport/RemoteConnectionStrategyTests.java @@ -36,17 +36,21 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class RemoteConnectionStrategyTests extends OpenSearchTestCase { public void testStrategyChangeMeansThatStrategyMustBeRebuilt() { ClusterConnectionManager connectionManager = new ClusterConnectionManager(Settings.EMPTY, mock(Transport.class)); RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager("cluster-alias", connectionManager); + TransportService mockTransportService = mock(TransportService.class); + when(mockTransportService.getThreadPool()).thenReturn(mock(ThreadPool.class)); FakeConnectionStrategy first = new FakeConnectionStrategy( "cluster-alias", - mock(TransportService.class), + mockTransportService, remoteConnectionManager, RemoteConnectionStrategy.ConnectionStrategy.PROXY ); @@ -60,9 +64,11 @@ public void testStrategyChangeMeansThatStrategyMustBeRebuilt() { public void testSameStrategyChangeMeansThatStrategyDoesNotNeedToBeRebuilt() { ClusterConnectionManager connectionManager = new ClusterConnectionManager(Settings.EMPTY, mock(Transport.class)); RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager("cluster-alias", connectionManager); + TransportService mockTransportService = mock(TransportService.class); + when(mockTransportService.getThreadPool()).thenReturn(mock(ThreadPool.class)); FakeConnectionStrategy first = new FakeConnectionStrategy( "cluster-alias", - mock(TransportService.class), + mockTransportService, remoteConnectionManager, RemoteConnectionStrategy.ConnectionStrategy.PROXY ); @@ -78,9 +84,11 @@ public void testChangeInConnectionProfileMeansTheStrategyMustBeRebuilt() { assertEquals(TimeValue.MINUS_ONE, connectionManager.getConnectionProfile().getPingInterval()); assertEquals(false, connectionManager.getConnectionProfile().getCompressionEnabled()); RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager("cluster-alias", connectionManager); + TransportService mockTransportService = mock(TransportService.class); + when(mockTransportService.getThreadPool()).thenReturn(mock(ThreadPool.class)); FakeConnectionStrategy first = new FakeConnectionStrategy( "cluster-alias", - mock(TransportService.class), + mockTransportService, remoteConnectionManager, RemoteConnectionStrategy.ConnectionStrategy.PROXY );