From 0ad8abaf4ab854a8c0005e7ab747c0b9822433ca Mon Sep 17 00:00:00 2001
From: David Turner
Date: Wed, 20 Sep 2023 13:03:54 +0100
Subject: [PATCH] Fix thread context in getRepositoryData (#99627)

Listeners which subscribe to `BlobStoreRepository#repoDataInitialized`
are today completed in the thread context of the thread which first
triggers the initialization of repository generation tracking, but we
must instead capture each listener's own thread context to avoid
cross-context pollution.
---
 docs/changelog/99627.yaml                     |   5 +
 .../blobstore/BlobStoreRepository.java        | 245 ++++++++++--------
 .../blobstore/BlobStoreRepositoryTests.java   |  62 ++++-
 .../SingleResultDeduplicatorTests.java        |  51 ++++
 4 files changed, 241 insertions(+), 122 deletions(-)
 create mode 100644 docs/changelog/99627.yaml

diff --git a/docs/changelog/99627.yaml b/docs/changelog/99627.yaml
new file mode 100644
index 0000000000000..84abdf6418dc2
--- /dev/null
+++ b/docs/changelog/99627.yaml
@@ -0,0 +1,5 @@
+pr: 99627
+summary: Fix thread context in `getRepositoryData`
+area: Snapshot/Restore
+type: bug
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java
index c0ef6581db94b..bfa4cc5be7863 100644
--- a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java
+++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java
@@ -26,7 +26,6 @@
 import org.elasticsearch.action.ActionRunnable;
 import org.elasticsearch.action.SingleResultDeduplicator;
 import org.elasticsearch.action.support.GroupedActionListener;
-import org.elasticsearch.action.support.ListenableActionFuture;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.action.support.RefCountingListener;
 import org.elasticsearch.action.support.RefCountingRunnable;
@@ -68,6 +67,7 @@
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.FutureUtils;
 import org.elasticsearch.common.util.concurrent.ListenableFuture;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
@@ -1788,36 +1788,41 @@ public void getRepositoryData(ActionListener<RepositoryData> listener) {
         // master-eligible or not.
         assert clusterService.localNode().isMasterNode() : "should only load repository data on master nodes";
-        if (lifecycle.started() == false) {
-            listener.onFailure(notStartedException());
-            return;
-        }
+        while (true) {
+            // retry loop, in case the state changes underneath us somehow
-        if (latestKnownRepoGen.get() == RepositoryData.CORRUPTED_REPO_GEN) {
-            listener.onFailure(corruptedStateException(null, null));
-            return;
-        }
-        final RepositoryData cached = latestKnownRepositoryData.get();
-        // Fast path loading repository data directly from cache if we're in fully consistent mode and the cache matches up with
-        // the latest known repository generation
-        if (bestEffortConsistency == false && cached.getGenId() == latestKnownRepoGen.get()) {
-            listener.onResponse(cached);
-            return;
-        }
-        if (metadata.generation() == RepositoryData.UNKNOWN_REPO_GEN && isReadOnly() == false) {
-            logger.debug(
-                "[{}] loading repository metadata for the first time, trying to determine correct generation and to store "
-                    + "it in the cluster state",
-                metadata.name()
-            );
-            initializeRepoGenerationTracking(listener);
-        } else {
-            logger.trace(
-                "[{}] loading un-cached repository data with best known repository generation [{}]",
-                metadata.name(),
-                latestKnownRepoGen
-            );
-            repoDataLoadDeduplicator.execute(listener);
+            if (lifecycle.started() == false) {
+                listener.onFailure(notStartedException());
+                return;
+            }
+
+            if (latestKnownRepoGen.get() == RepositoryData.CORRUPTED_REPO_GEN) {
+                listener.onFailure(corruptedStateException(null, null));
+                return;
+            }
+            final RepositoryData cached = latestKnownRepositoryData.get();
+            // Fast path loading repository data directly from cache if we're in fully consistent mode and the cache matches up with
+            // the latest known repository generation
+            if (bestEffortConsistency == false && cached.getGenId() == latestKnownRepoGen.get()) {
+                listener.onResponse(cached);
+                return;
+            }
+            if (metadata.generation() == RepositoryData.UNKNOWN_REPO_GEN && isReadOnly() == false) {
+                logger.debug("""
+                    [{}] loading repository metadata for the first time, trying to determine correct generation and to store it in the \
+                    cluster state""", metadata.name());
+                if (initializeRepoGenerationTracking(listener)) {
+                    return;
+                } // else there was a concurrent modification, retry from the start
+            } else {
+                logger.trace(
+                    "[{}] loading un-cached repository data with best known repository generation [{}]",
+                    metadata.name(),
+                    latestKnownRepoGen
+                );
+                repoDataLoadDeduplicator.execute(listener);
+                return;
+            }
         }
     }
@@ -1826,7 +1831,8 @@ private RepositoryException notStartedException() {
     }
 
     // Listener used to ensure that repository data is only initialized once in the cluster state by #initializeRepoGenerationTracking
-    private ListenableActionFuture<RepositoryData> repoDataInitialized;
+    @Nullable // unless we're in the process of initializing repo-generation tracking
+    private SubscribableListener<RepositoryData> repoDataInitialized;
 
     /**
      * Method used to set the current repository generation in the cluster state's {@link RepositoryMetadata} to the latest generation that
@@ -1835,103 +1841,120 @@ private RepositoryException notStartedException() {
      * have a consistent view of the {@link RepositoryData} before any data has been written to the repository.
      *
      * @param listener listener to resolve with new repository data
+     * @return {@code true} if this method at least started the initialization process successfully and will eventually complete the
+     * listener, {@code false} if there was some concurrent state change which prevents us from starting repo generation tracking (typically
+     * that some other node got there first) and the caller should check again and possibly retry or complete the listener in some other
+     * way.
      */
-    private void initializeRepoGenerationTracking(ActionListener<RepositoryData> listener) {
+    private boolean initializeRepoGenerationTracking(ActionListener<RepositoryData> listener) {
+        final SubscribableListener<RepositoryData> listenerToSubscribe;
+        final ActionListener<RepositoryData> listenerToComplete;
+
         synchronized (this) {
             if (repoDataInitialized == null) {
-                // double check the generation since we checked it outside the mutex in the caller and it could have changed by a
+                // double-check the generation since we checked it outside the mutex in the caller and it could have changed by a
                 // concurrent initialization of the repo metadata and just load repository normally in case we already finished the
                 // initialization
                 if (metadata.generation() != RepositoryData.UNKNOWN_REPO_GEN) {
-                    getRepositoryData(listener);
-                    return;
+                    return false; // retry
                 }
                 logger.trace("[{}] initializing repository generation in cluster state", metadata.name());
-                repoDataInitialized = new ListenableActionFuture<>();
-                repoDataInitialized.addListener(listener);
-                final Consumer<Exception> onFailure = e -> {
-                    logger.warn(
-                        () -> format("[%s] Exception when initializing repository generation in cluster state", metadata.name()),
-                        e
-                    );
-                    final ActionListener<RepositoryData> existingListener;
-                    synchronized (BlobStoreRepository.this) {
-                        existingListener = repoDataInitialized;
-                        repoDataInitialized = null;
+                repoDataInitialized = listenerToSubscribe = new SubscribableListener<>();
+                listenerToComplete = new ActionListener<>() {
+                    private ActionListener<RepositoryData> acquireAndClearRepoDataInitialized() {
+                        synchronized (BlobStoreRepository.this) {
+                            assert repoDataInitialized == listenerToSubscribe;
+                            repoDataInitialized = null;
+                            return listenerToSubscribe;
+                        }
                     }
-                    existingListener.onFailure(e);
-                };
-                repoDataLoadDeduplicator.execute(
-                    ActionListener.wrap(
-                        repoData -> submitUnbatchedTask(
-                            "set initial safe repository generation [" + metadata.name() + "][" + repoData.getGenId() + "]",
-                            new ClusterStateUpdateTask() {
-                                @Override
-                                public ClusterState execute(ClusterState currentState) {
-                                    RepositoryMetadata metadata = getRepoMetadata(currentState);
-                                    // No update to the repository generation should have occurred concurrently in general except
-                                    // for
-                                    // extreme corner cases like failing over to an older version master node and back to the
-                                    // current
-                                    // node concurrently
-                                    if (metadata.generation() != RepositoryData.UNKNOWN_REPO_GEN) {
-                                        throw new RepositoryException(
-                                            metadata.name(),
-                                            "Found unexpected initialized repo metadata [" + metadata + "]"
-                                        );
-                                    }
-                                    return ClusterState.builder(currentState)
-                                        .metadata(
-                                            Metadata.builder(currentState.getMetadata())
-                                                .putCustom(
-                                                    RepositoriesMetadata.TYPE,
-                                                    RepositoriesMetadata.get(currentState)
-                                                        .withUpdatedGeneration(metadata.name(), repoData.getGenId(), repoData.getGenId())
-                                                )
-                                        )
-                                        .build();
-                                }
-                                @Override
-                                public void onFailure(Exception e) {
-                                    onFailure.accept(e);
-                                }
+                    @Override
+                    public void onResponse(RepositoryData repositoryData) {
+                        acquireAndClearRepoDataInitialized().onResponse(repositoryData);
+                    }
-                                @Override
-                                public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
-                                    logger.trace(
-                                        "[{}] initialized repository generation in cluster state to [{}]",
-                                        metadata.name(),
-                                        repoData.getGenId()
-                                    );
-                                    // Resolve listeners on generic pool since some callbacks for repository data do additional IO
-                                    threadPool.generic().execute(() -> {
-                                        final ActionListener<RepositoryData> existingListener;
-                                        synchronized (BlobStoreRepository.this) {
-                                            existingListener = repoDataInitialized;
-                                            repoDataInitialized = null;
-                                        }
-                                        existingListener.onResponse(repoData);
-                                        logger.trace(
-                                            "[{}] called listeners after initializing repository to generation [{}]",
-                                            metadata.name(),
-                                            repoData.getGenId()
-                                        );
-                                    });
-                                }
-                            }
-                        ),
-                        onFailure
-                    )
-                );
+                    @Override
+                    public void onFailure(Exception e) {
+                        logger.warn(
+                            () -> format("[%s] Exception when initializing repository generation in cluster state", metadata.name()),
+                            e
+                        );
+                        acquireAndClearRepoDataInitialized().onFailure(e);
+                    }
+                };
             } else {
                 logger.trace(
                     "[{}] waiting for existing initialization of repository metadata generation in cluster state",
                     metadata.name()
                 );
-                repoDataInitialized.addListener(listener);
-            }
+                listenerToComplete = null;
+                listenerToSubscribe = repoDataInitialized;
+            }
+        }
+
+        if (listenerToComplete != null) {
+            SubscribableListener
+                // load the current repository data
+                .newForked(repoDataLoadDeduplicator::execute)
+                // write its generation to the cluster state
+                .andThen(
+                    (l, repoData) -> submitUnbatchedTask(
+                        "set initial safe repository generation [" + metadata.name() + "][" + repoData.getGenId() + "]",
+                        new ClusterStateUpdateTask() {
+                            @Override
+                            public ClusterState execute(ClusterState currentState) {
+                                return getClusterStateWithUpdatedRepositoryGeneration(currentState, repoData);
+                            }
+
+                            @Override
+                            public void onFailure(Exception e) {
+                                l.onFailure(e);
+                            }
+
+                            @Override
+                            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
+                                l.onResponse(repoData);
+                            }
+                        }
+                    )
+                )
+                // fork to generic pool since we're on the applier thread and some callbacks for repository data do additional IO
+                .andThen((l, repoData) -> {
+                    logger.trace("[{}] initialized repository generation in cluster state to [{}]", metadata.name(), repoData.getGenId());
+                    threadPool.generic().execute(ActionRunnable.supply(ActionListener.runAfter(l, () -> {
+                        logger.trace(
+                            "[{}] called listeners after initializing repository to generation [{}]",
+                            metadata.name(),
+                            repoData.getGenId()
+                        );
+                    }), () -> repoData));
+                })
+                // and finally complete the listener
+                .addListener(listenerToComplete);
         }
+
+        listenerToSubscribe.addListener(listener, EsExecutors.DIRECT_EXECUTOR_SERVICE, threadPool.getThreadContext());
+        return true;
+    }
+
+    private ClusterState getClusterStateWithUpdatedRepositoryGeneration(ClusterState currentState, RepositoryData repoData) {
+        // In theory we might have failed over to a different master which initialized the repo and then failed back to this node, so we
+        // must check the repository generation in the cluster state is still unknown here.
+        final RepositoryMetadata repoMetadata = getRepoMetadata(currentState);
+        if (repoMetadata.generation() != RepositoryData.UNKNOWN_REPO_GEN) {
+            throw new RepositoryException(repoMetadata.name(), "Found unexpected initialized repo metadata [" + repoMetadata + "]");
+        }
+        return ClusterState.builder(currentState)
+            .metadata(
+                Metadata.builder(currentState.getMetadata())
+                    .putCustom(
+                        RepositoriesMetadata.TYPE,
+                        RepositoriesMetadata.get(currentState)
+                            .withUpdatedGeneration(repoMetadata.name(), repoData.getGenId(), repoData.getGenId())
+                    )
+            )
+            .build();
     }
 
     /**
diff --git a/server/src/test/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryTests.java b/server/src/test/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryTests.java
index 78f1e6c46956e..624ad6a9fc7da 100644
--- a/server/src/test/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryTests.java
+++ b/server/src/test/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryTests.java
@@ -9,8 +9,10 @@
 package org.elasticsearch.repositories.blobstore;
 
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRunnable;
 import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotResponse;
 import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.action.support.RefCountingListener;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
@@ -24,6 +26,7 @@
 import org.elasticsearch.common.unit.ByteSizeUnit;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.TestEnvironment;
 import org.elasticsearch.index.IndexVersion;
@@ -36,6 +39,7 @@
 import org.elasticsearch.repositories.Repository;
 import org.elasticsearch.repositories.RepositoryData;
 import org.elasticsearch.repositories.RepositoryException;
+import org.elasticsearch.repositories.RepositoryMissingException;
 import org.elasticsearch.repositories.ShardGeneration;
 import org.elasticsearch.repositories.ShardGenerations;
 import org.elasticsearch.repositories.SnapshotShardContext;
@@ -46,6 +50,7 @@
 import org.elasticsearch.test.ESSingleNodeTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.junit.After;
 
 import java.io.IOException;
 import java.nio.file.Path;
@@ -55,7 +60,9 @@
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -75,6 +82,7 @@ public class BlobStoreRepositoryTests extends ESSingleNodeTestCase {
 
     static final String REPO_TYPE = "fsLike";
+    private static final String TEST_REPO_NAME = "test-repo";
 
     protected Collection<Class<? extends Plugin>> getPlugins() {
         return Arrays.asList(FsLikeRepoPlugin.class);
     }
@@ -106,12 +114,11 @@ protected void assertSnapshotOrGenericThread() {
     public void testRetrieveSnapshots() throws Exception {
         final Client client = client();
         final Path location = ESIntegTestCase.randomRepoPath(node().settings());
-        final String repositoryName = "test-repo";
         logger.info("--> creating repository");
         AcknowledgedResponse putRepositoryResponse = client.admin()
             .cluster()
-            .preparePutRepository(repositoryName)
+            .preparePutRepository(TEST_REPO_NAME)
             .setType(REPO_TYPE)
             .setSettings(Settings.builder().put(node().settings()).put("location", location))
             .get();
         assertThat(putRepositoryResponse.isAcknowledged(), equalTo(true));
@@ -131,7 +138,7 @@ public void testRetrieveSnapshots() throws Exception {
         logger.info("--> create first snapshot");
         CreateSnapshotResponse createSnapshotResponse = client.admin()
            .cluster()
-            .prepareCreateSnapshot(repositoryName, "test-snap-1")
+            .prepareCreateSnapshot(TEST_REPO_NAME, "test-snap-1")
            .setWaitForCompletion(true)
            .setIndices(indexName)
            .get();
@@ -140,7 +147,7 @@ public void testRetrieveSnapshots() throws Exception {
         logger.info("--> create second snapshot");
         createSnapshotResponse = client.admin()
            .cluster()
-            .prepareCreateSnapshot(repositoryName, "test-snap-2")
+            .prepareCreateSnapshot(TEST_REPO_NAME, "test-snap-2")
            .setWaitForCompletion(true)
            .setIndices(indexName)
            .get();
@@ -148,7 +155,7 @@ public void testRetrieveSnapshots() throws Exception {
         logger.info("--> make sure the node's repository can resolve the snapshots");
         final RepositoriesService repositoriesService = getInstanceFromNode(RepositoriesService.class);
-        final BlobStoreRepository repository = (BlobStoreRepository) repositoriesService.repository(repositoryName);
+        final BlobStoreRepository repository = (BlobStoreRepository) repositoriesService.repository(TEST_REPO_NAME);
 
         final List<SnapshotId> originalSnapshots = Arrays.asList(snapshotId1, snapshotId2);
         List<SnapshotId> snapshotIds = ESBlobStoreRepositoryIntegTestCase.getRepositoryData(repository)
@@ -255,13 +262,12 @@ public void testRepositoryDataConcurrentModificationNotAllowed() throws Exceptio
     public void testBadChunksize() throws Exception {
         final Client client = client();
         final Path location = ESIntegTestCase.randomRepoPath(node().settings());
-        final String repositoryName = "test-repo";
 
         expectThrows(
             RepositoryException.class,
             () -> client.admin()
                 .cluster()
-                .preparePutRepository(repositoryName)
+                .preparePutRepository(TEST_REPO_NAME)
                 .setType(REPO_TYPE)
                 .setSettings(
                     Settings.builder()
@@ -345,7 +351,6 @@ private static void writeIndexGen(BlobStoreRepository repository, RepositoryData
     private BlobStoreRepository setupRepo() {
         final Client client = client();
         final Path location = ESIntegTestCase.randomRepoPath(node().settings());
-        final String repositoryName = "test-repo";
 
         Settings.Builder repoSettings = Settings.builder().put(node().settings()).put("location", location);
         boolean compress = randomBoolean();
@@ -354,20 +359,29 @@ private BlobStoreRepository setupRepo() {
         }
         AcknowledgedResponse putRepositoryResponse = client.admin()
             .cluster()
-            .preparePutRepository(repositoryName)
+            .preparePutRepository(TEST_REPO_NAME)
             .setType(REPO_TYPE)
             .setSettings(repoSettings)
             .setVerify(false) // prevent eager reading of repo data
-            .get();
+            .get(TimeValue.timeValueSeconds(10));
         assertThat(putRepositoryResponse.isAcknowledged(), equalTo(true));
         final RepositoriesService repositoriesService = getInstanceFromNode(RepositoriesService.class);
-        final BlobStoreRepository repository = (BlobStoreRepository) repositoriesService.repository(repositoryName);
+        final BlobStoreRepository repository = (BlobStoreRepository) repositoriesService.repository(TEST_REPO_NAME);
         assertThat("getBlobContainer has to be lazy initialized", repository.getBlobContainer(), nullValue());
         assertEquals("Compress must be set to", compress, repository.isCompress());
         return repository;
     }
 
+    @After
+    public void removeRepo() {
+        try {
+            client().admin().cluster().prepareDeleteRepository(TEST_REPO_NAME).get(TimeValue.timeValueSeconds(10));
+        } catch (RepositoryMissingException e) {
+            // ok, not all tests create the test repo
+        }
+    }
+
     private RepositoryData addRandomSnapshotsToRepoData(RepositoryData repoData, boolean inclIndices) {
         int numSnapshots = randomIntBetween(1, 20);
         for (int i = 0; i < numSnapshots; i++) {
@@ -441,6 +455,32 @@ protected void snapshotFile(SnapshotShardContext context, BlobStoreIndexShardSna
         listenerCalled.get();
     }
 
+    public void testGetRepositoryDataThreadContext() {
+        final var future = new PlainActionFuture<Void>();
+        try (var listeners = new RefCountingListener(future)) {
+            final var repo = setupRepo();
+            final int threads = between(1, 5);
+            final var barrier = new CyclicBarrier(threads);
+            final var headerName = "test-header";
+            final var threadPool = client().threadPool();
+            final var threadContext = threadPool.getThreadContext();
+            for (int i = 0; i < threads; i++) {
+                final var headerValue = randomAlphaOfLength(10);
+                try (var ignored = threadContext.stashContext()) {
+                    threadContext.putHeader(headerName, headerValue);
+                    threadPool.generic().execute(ActionRunnable.wrap(listeners.acquire(), l -> {
+                        safeAwait(barrier);
+                        repo.getRepositoryData(l.map(repositoryData -> {
+                            assertEquals(headerValue, threadContext.getHeader(headerName));
+                            return null;
+                        }));
+                    }));
+                }
+            }
+        }
+        future.actionGet(10, TimeUnit.SECONDS);
+    }
+
     private Environment createEnvironment() {
         Path home = createTempDir();
         return TestEnvironment.newEnvironment(
diff --git a/server/src/test/java/org/elasticsearch/transport/SingleResultDeduplicatorTests.java b/server/src/test/java/org/elasticsearch/transport/SingleResultDeduplicatorTests.java
index 56bfe72241f28..fb4c9df512a5a 100644
--- a/server/src/test/java/org/elasticsearch/transport/SingleResultDeduplicatorTests.java
+++ b/server/src/test/java/org/elasticsearch/transport/SingleResultDeduplicatorTests.java
@@ -10,10 +10,20 @@
 
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRunnable;
 import org.elasticsearch.action.SingleResultDeduplicator;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.action.support.RefCountingListener;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.TimeUnit;
 
 public class SingleResultDeduplicatorTests extends ESTestCase {
 
@@ -74,4 +84,45 @@ public void onFailure(Exception e) {
             assertTrue(called[i]);
         }
     }
+
+    public void testThreadContextPreservation() {
+        final var resources = new Releasable[1];
+        try {
+            final var future = new PlainActionFuture<Void>();
+            try (var listeners = new RefCountingListener(future)) {
+                final var threadContext = new ThreadContext(Settings.EMPTY);
+                final var deduplicator = new SingleResultDeduplicator<Void>(threadContext, l -> l.onResponse(null));
+                final var threads = between(1, 5);
+                final var executor = EsExecutors.newFixed(
+                    "test",
+                    threads,
+                    0,
+                    EsExecutors.daemonThreadFactory("test"),
+                    threadContext,
+                    EsExecutors.TaskTrackingConfig.DO_NOT_TRACK
+                );
+                resources[0] = () -> ThreadPool.terminate(executor, 10, TimeUnit.SECONDS);
+                final var barrier = new CyclicBarrier(threads);
+                final var headerName = "test-header";
+                for (int i = 0; i < threads; i++) {
+                    try (var ignored = threadContext.stashContext()) {
+                        final var headerValue = randomAlphaOfLength(10);
+                        threadContext.putHeader(headerName, headerValue);
+                        executor.execute(
+                            ActionRunnable.wrap(
+                                listeners.acquire(v -> assertEquals(headerValue, threadContext.getHeader(headerName))),
+                                listener -> {
+                                    safeAwait(barrier);
+                                    deduplicator.execute(listener);
+                                }
+                            )
+                        );
+                    }
+                }
+            }
+            future.actionGet(10, TimeUnit.SECONDS);
+        } finally {
+            Releasables.closeExpectNoException(resources);
+        }
+    }
 }
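
The thread-context behaviour that the two new tests above pin down can be sketched in isolation as
follows. This is an illustrative snippet only, not part of the patch: the class name, header names
and printed messages are invented, while the APIs used (SubscribableListener#addListener with an
explicit ThreadContext, ThreadContext#stashContext, putHeader/getHeader, and
EsExecutors.DIRECT_EXECUTOR_SERVICE) are exactly the ones exercised by the change itself. Each
subscriber is registered together with its own ThreadContext, so completing the shared listener
from a thread carrying a different context still restores the subscriber's own headers.

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;

class ThreadContextCaptureSketch {
    public static void main(String[] args) {
        final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
        final SubscribableListener<String> shared = new SubscribableListener<>();

        // Subscribe while a request-scoped header is present; passing threadContext here captures it
        // for this listener, independently of whichever thread later completes the shared listener.
        try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
            threadContext.putHeader("test-header", "subscriber-context");
            shared.addListener(
                ActionListener.wrap(
                    value -> System.out.println("completed, header = " + threadContext.getHeader("test-header")),
                    e -> { throw new AssertionError(e); }
                ),
                EsExecutors.DIRECT_EXECUTOR_SERVICE,
                threadContext
            );
        }

        // Complete from a different context, as the thread that finishes repository-generation
        // initialization would; the subscriber still observes "subscriber-context".
        try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
            threadContext.putHeader("test-header", "completer-context");
            shared.onResponse("repository data stand-in");
        }
    }
}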