From c2e4afcfd584fe35aa88a9b9840cf5ff4c3c80b6 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Wed, 27 Nov 2024 13:23:20 -0800 Subject: [PATCH] Try to finish remote sink once (#117592) Currently, we have three clients fetching pages by default, each with its own lifecycle. This can result in scenarios where more than one request is sent to complete the remote sink. While this does not cause correctness issues, it is inefficient, especially for cross-cluster requests. This change tracks the status of the remote sink and tries to send only one finish request per remote sink. --- .../operator/exchange/ExchangeService.java | 28 +++++++++++++++++++ .../exchange/ExchangeServiceTests.java | 9 ++++++ 2 files changed, 37 insertions(+) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java index d633270b5c595..a943a90d02e87 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java @@ -42,6 +42,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; /** @@ -292,6 +293,7 @@ static final class TransportRemoteSink implements RemoteSink { final Executor responseExecutor; final AtomicLong estimatedPageSizeInBytes = new AtomicLong(0L); + final AtomicBoolean finished = new AtomicBoolean(false); TransportRemoteSink( TransportService transportService, @@ -311,6 +313,32 @@ static final class TransportRemoteSink implements RemoteSink { @Override public void fetchPageAsync(boolean allSourcesFinished, ActionListener<ExchangeResponse> listener) { + if (allSourcesFinished) { + if (finished.compareAndSet(false, true)) { + doFetchPageAsync(true, 
listener); + } else { + // already finished or promised + listener.onResponse(new ExchangeResponse(blockFactory, null, true)); + } + } else { + // already finished + if (finished.get()) { + listener.onResponse(new ExchangeResponse(blockFactory, null, true)); + return; + } + doFetchPageAsync(false, ActionListener.wrap(r -> { + if (r.finished()) { + finished.set(true); + } + listener.onResponse(r); + }, e -> { + finished.set(true); + listener.onFailure(e); + })); + } + } + + private void doFetchPageAsync(boolean allSourcesFinished, ActionListener<ExchangeResponse> listener) { final long reservedBytes = allSourcesFinished ? 0 : estimatedPageSizeInBytes.get(); if (reservedBytes > 0) { // This doesn't fully protect ESQL from OOM, but reduces the likelihood. diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java index 8949f61b7420d..4178f02898d79 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java @@ -449,6 +449,15 @@ public void testConcurrentWithTransportActions() { ExchangeService exchange1 = new ExchangeService(Settings.EMPTY, threadPool, ESQL_TEST_EXECUTOR, blockFactory()); exchange1.registerTransportHandler(node1); AbstractSimpleTransportTestCase.connectToNode(node0, node1.getLocalNode()); + Set<String> finishingRequests = ConcurrentCollections.newConcurrentSet(); + node1.addRequestHandlingBehavior(ExchangeService.EXCHANGE_ACTION_NAME, (handler, request, channel, task) -> { + final ExchangeRequest exchangeRequest = (ExchangeRequest) request; + if (exchangeRequest.sourcesFinished()) { + String exchangeId = exchangeRequest.exchangeId(); + assertTrue("tried to finish [" + exchangeId + "] twice", 
finishingRequests.add(exchangeId)); + } + handler.messageReceived(request, channel, task); + }); try (exchange0; exchange1; node0; node1) { String exchangeId = "exchange";