Skip to content

Commit

Permalink
stash thread context before running forward action (#1904) (#1907)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
(cherry picked from commit a14521d)

Co-authored-by: Yaliang Wu <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and ylwu-amzn authored Jan 23, 2024
1 parent 7375ccc commit dedaefc
Showing 1 changed file with 19 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
Expand Down Expand Up @@ -161,13 +162,15 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod
.build();
MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput);

transportService
.sendRequest(
getNodeById(coordinatingNodeId),
MLForwardAction.NAME,
deployModelDoneMessage,
new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new)
);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
transportService
.sendRequest(
getNodeById(coordinatingNodeId),
MLForwardAction.NAME,
deployModelDoneMessage,
new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new)
);
}
}, e -> {
MLForwardInput mlForwardInput = MLForwardInput
.builder()
Expand All @@ -179,13 +182,15 @@ private MLDeployModelNodeResponse createDeployModelNodeResponse(MLDeployModelNod
.build();
MLForwardRequest deployModelDoneMessage = new MLForwardRequest(mlForwardInput);

transportService
.sendRequest(
getNodeById(coordinatingNodeId),
MLForwardAction.NAME,
deployModelDoneMessage,
new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new)
);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
transportService
.sendRequest(
getNodeById(coordinatingNodeId),
MLForwardAction.NAME,
deployModelDoneMessage,
new ActionListenerResponseHandler<>(taskDoneListener, MLForwardResponse::new)
);
}
})
);

Expand Down

0 comments on commit dedaefc

Please sign in to comment.