From c0324b43925b94a188fde56cdf2f53d5bc8e9b71 Mon Sep 17 00:00:00 2001 From: Craig Perkins Date: Wed, 31 Jul 2024 12:19:58 -0400 Subject: [PATCH] Use ThreadContextAccess Signed-off-by: Craig Perkins --- .../client/OriginSettingClient.java | 7 ++++++- .../client/support/AbstractClient.java | 5 ++++- .../util/concurrent/ThreadContextTests.java | 20 +++++++++++++++---- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/opensearch/client/OriginSettingClient.java b/server/src/main/java/org/opensearch/client/OriginSettingClient.java index 1b0e08cc489c4..27d87227df7bc 100644 --- a/server/src/main/java/org/opensearch/client/OriginSettingClient.java +++ b/server/src/main/java/org/opensearch/client/OriginSettingClient.java @@ -36,6 +36,7 @@ import org.opensearch.action.ActionType; import org.opensearch.action.support.ContextPreservingActionListener; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.util.concurrent.ThreadContextAccess; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; @@ -65,7 +66,11 @@ protected void ActionListener listener ) { final Supplier supplier = in().threadPool().getThreadContext().newRestorableContext(false); - try (ThreadContext.StoredContext ignore = in().threadPool().getThreadContext().stashWithOrigin(origin)) { + try ( + ThreadContext.StoredContext ignore = ThreadContextAccess.doPrivileged( + () -> in().threadPool().getThreadContext().stashWithOrigin(origin) + ) + ) { super.doExecute(action, request, new ContextPreservingActionListener<>(supplier, listener)); } } diff --git a/server/src/main/java/org/opensearch/client/support/AbstractClient.java b/server/src/main/java/org/opensearch/client/support/AbstractClient.java index 6c6049f04231b..509cd732357d6 100644 --- a/server/src/main/java/org/opensearch/client/support/AbstractClient.java +++ b/server/src/main/java/org/opensearch/client/support/AbstractClient.java @@ -416,6 +416,7 @@ import org.opensearch.common.action.ActionFuture; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.util.concurrent.ThreadContextAccess; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.bytes.BytesReference; @@ -2148,7 +2149,9 @@ protected void ActionListener listener ) { ThreadContext threadContext = threadPool().getThreadContext(); - try (ThreadContext.StoredContext ctx = threadContext.stashAndMergeHeaders(headers)) { + try ( + ThreadContext.StoredContext ctx = ThreadContextAccess.doPrivileged(() -> threadContext.stashAndMergeHeaders(headers)) + ) { super.doExecute(action, request, listener); } } diff --git a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java index 4c7cd4513412d..5992ffa1465b4 100644 --- a/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java +++ b/server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java @@ -206,7 +206,7 @@ public void testStashWithOrigin() { } assertNull(threadContext.getTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME)); - try (ThreadContext.StoredContext storedContext = threadContext.stashWithOrigin(origin)) { + try (ThreadContext.StoredContext storedContext = ThreadContextAccess.doPrivileged(() -> threadContext.stashWithOrigin(origin))) { assertEquals(origin, threadContext.getTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME)); assertNull(threadContext.getTransient("foo")); assertNull(threadContext.getTransient("bar")); @@ -231,7 +231,7 @@ public void testStashAndMerge() { HashMap toMerge = new HashMap<>(); toMerge.put("foo", "baz"); toMerge.put("simon", "says"); - try (ThreadContext.StoredContext ctx = threadContext.stashAndMergeHeaders(toMerge)) { + try (ThreadContext.StoredContext ctx = ThreadContextAccess.doPrivileged(() -> threadContext.stashAndMergeHeaders(toMerge))) { assertEquals("bar", threadContext.getHeader("foo")); assertEquals("says", threadContext.getHeader("simon")); assertNull(threadContext.getTransient("ctx.foo")); @@ -493,7 +493,13 @@ public void testStashAndMergeWithModifiedDefaults() { ThreadContext threadContext = new ThreadContext(build); HashMap toMerge = new HashMap<>(); toMerge.put("default", "2"); - try (ThreadContext.StoredContext ctx = threadContext.stashAndMergeHeaders(toMerge)) { + ThreadContext finalThreadContext1 = threadContext; + HashMap finalToMerge1 = toMerge; + try ( + ThreadContext.StoredContext ctx = ThreadContextAccess.doPrivileged( + () -> finalThreadContext1.stashAndMergeHeaders(finalToMerge1) + ) + ) { assertEquals("2", threadContext.getHeader("default")); } @@ -502,7 +508,13 @@ public void testStashAndMergeWithModifiedDefaults() { threadContext.putHeader("default", "4"); toMerge = new HashMap<>(); toMerge.put("default", "2"); - try (ThreadContext.StoredContext ctx = threadContext.stashAndMergeHeaders(toMerge)) { + ThreadContext finalThreadContext2 = threadContext; + HashMap finalToMerge2 = toMerge; + try ( + ThreadContext.StoredContext ctx = ThreadContextAccess.doPrivileged( + () -> finalThreadContext2.stashAndMergeHeaders(finalToMerge2) + ) + ) { assertEquals("4", threadContext.getHeader("default")); } }