diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java index 67f58d69a0..bf8c81756b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java @@ -22,6 +22,7 @@ import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -161,13 +162,15 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod .build(); MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput); - transportService - .sendRequest( - getNodeById(coordinatingNodeId), - MLForwardAction.NAME, - deployModelDoneMessage, - new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new) - ); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + transportService + .sendRequest( + getNodeById(coordinatingNodeId), + MLForwardAction.NAME, + deployModelDoneMessage, + new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new) + ); + } }, e -> { MLForwardInput mlForwardInput = MLForwardInput .builder() @@ -179,13 +182,15 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod .build(); MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput); - transportService - .sendRequest( - getNodeById(coordinatingNodeId), - MLForwardAction.NAME, - deployModelDoneMessage, - new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new) - ); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + transportService + .sendRequest( + getNodeById(coordinatingNodeId), + MLForwardAction.NAME, + deployModelDoneMessage, + new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new) + ); + } }) );