diff --git a/server/src/main/java/org/opensearch/cluster/service/TaskBatcher.java b/server/src/main/java/org/opensearch/cluster/service/TaskBatcher.java
index b5710bab41172..686169f81e837 100644
--- a/server/src/main/java/org/opensearch/cluster/service/TaskBatcher.java
+++ b/server/src/main/java/org/opensearch/cluster/service/TaskBatcher.java
@@ -46,6 +46,7 @@
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -61,7 +62,7 @@ public abstract class TaskBatcher {
     private final Logger logger;
     private final PrioritizedOpenSearchThreadPoolExecutor threadExecutor;
     // package visible for tests
-    final Map<Object, LinkedHashSet<BatchedTask>> tasksPerBatchingKey = new HashMap<>();
+    final Map<Object, LinkedHashSet<BatchedTask>> tasksPerBatchingKey = new ConcurrentHashMap<>();
     private final TaskBatcherListener taskBatcherListener;
 
     public TaskBatcher(Logger logger, PrioritizedOpenSearchThreadPoolExecutor threadExecutor, TaskBatcherListener taskBatcherListener) {
@@ -93,12 +94,8 @@ public void submitTasks(List<? extends BatchedTask> tasks, @Nullable TimeValue t
                     IdentityHashMap::new
                 )
             );
-
-            synchronized (tasksPerBatchingKey) {
-                LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.computeIfAbsent(
-                    firstTask.batchingKey,
-                    k -> new LinkedHashSet<>(tasks.size())
-                );
+            LinkedHashSet<BatchedTask> newTasks = new LinkedHashSet<>(tasks);
+            tasksPerBatchingKey.merge(firstTask.batchingKey, newTasks, (existingTasks, updatedTasks) -> {
                 for (BatchedTask existing : existingTasks) {
                     // check that there won't be two tasks with the same identity for the same batching key
                     BatchedTask duplicateTask = tasksIdentity.get(existing.getTask());
@@ -112,8 +109,9 @@ public void submitTasks(List<? extends BatchedTask> tasks, @Nullable TimeValue t
                         );
                     }
                 }
-                existingTasks.addAll(tasks);
-            }
+                existingTasks.addAll(updatedTasks);
+                return existingTasks;
+            });
         } catch (Exception e) {
             taskBatcherListener.onSubmitFailure(tasks);
             throw e;
@@ -139,15 +137,13 @@ private void onTimeoutInternal(List<? extends BatchedTask> tasks, TimeValue time
             Object batchingKey = firstTask.batchingKey;
             assert tasks.stream().allMatch(t -> t.batchingKey == batchingKey)
                 : "tasks submitted in a batch should share the same batching key: " + tasks;
-            synchronized (tasksPerBatchingKey) {
-                LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.get(batchingKey);
-                if (existingTasks != null) {
-                    existingTasks.removeAll(toRemove);
-                    if (existingTasks.isEmpty()) {
-                        tasksPerBatchingKey.remove(batchingKey);
-                    }
+            tasksPerBatchingKey.computeIfPresent(batchingKey, (tasksKey, existingTasks) -> {
+                existingTasks.removeAll(toRemove);
+                if (existingTasks.isEmpty()) {
+                    return null;
                 }
-            }
+                return existingTasks;
+            });
             taskBatcherListener.onTimeout(toRemove);
             onTimeout(toRemove, timeout);
         }
@@ -165,17 +161,15 @@ void runIfNotProcessed(BatchedTask updateTask) {
         if (updateTask.processed.get() == false) {
             final List<BatchedTask> toExecute = new ArrayList<>();
             final Map<String, List<BatchedTask>> processTasksBySource = new HashMap<>();
-            synchronized (tasksPerBatchingKey) {
-                LinkedHashSet<BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
-                if (pending != null) {
-                    for (BatchedTask task : pending) {
-                        if (task.processed.getAndSet(true) == false) {
-                            logger.trace("will process {}", task);
-                            toExecute.add(task);
-                            processTasksBySource.computeIfAbsent(task.source, s -> new ArrayList<>()).add(task);
-                        } else {
-                            logger.trace("skipping {}, already processed", task);
-                        }
+            LinkedHashSet<BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
+            if (pending != null) {
+                for (BatchedTask task : pending) {
+                    if (task.processed.getAndSet(true) == false) {
+                        logger.trace("will process {}", task);
+                        toExecute.add(task);
+                        processTasksBySource.computeIfAbsent(task.source, s -> new ArrayList<>()).add(task);
+                    } else {
+                        logger.trace("skipping {}, already processed", task);
                     }
                 }
             }
diff --git a/server/src/test/java/org/opensearch/cluster/service/TaskBatcherTests.java b/server/src/test/java/org/opensearch/cluster/service/TaskBatcherTests.java
index 31018d4cef029..b59b70ca60ef8 100644
--- a/server/src/test/java/org/opensearch/cluster/service/TaskBatcherTests.java
+++ b/server/src/test/java/org/opensearch/cluster/service/TaskBatcherTests.java
@@ -279,6 +279,87 @@ public void processed(String source) {
         }
     }
 
+    public void testNoTasksAreDroppedInParallelSubmission() throws BrokenBarrierException, InterruptedException {
+        int numberOfThreads = randomIntBetween(2, 8);
+        TaskExecutor[] executors = new TaskExecutor[numberOfThreads];
+        for (int i = 0; i < numberOfThreads; i++) {
+            executors[i] = new TaskExecutor();
+        }
+
+        int tasksSubmittedPerThread = randomIntBetween(2, 1024);
+
+        CopyOnWriteArrayList<Tuple<String, Throwable>> failures = new CopyOnWriteArrayList<>();
+        CountDownLatch updateLatch = new CountDownLatch(numberOfThreads * tasksSubmittedPerThread);
+
+        final TestListener listener = new TestListener() {
+            @Override
+            public void onFailure(String source, Exception e) {
+                logger.error(() -> new ParameterizedMessage("unexpected failure: [{}]", source), e);
+                failures.add(new Tuple<>(source, e));
+                updateLatch.countDown();
+            }
+
+            @Override
+            public void processed(String source) {
+                updateLatch.countDown();
+            }
+        };
+
+        CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
+
+        for (int i = 0; i < numberOfThreads; i++) {
+            final int index = i;
+            Thread thread = new Thread(() -> {
+                try {
+                    barrier.await();
+                    CyclicBarrier tasksBarrier = new CyclicBarrier(1 + tasksSubmittedPerThread);
+                    for (int j = 0; j < tasksSubmittedPerThread; j++) {
+                        int taskNumber = j;
+                        Thread taskThread = new Thread(() -> {
+                            try {
+                                tasksBarrier.await();
+                                submitTask(
+                                    "[" + index + "][" + taskNumber + "]",
+                                    taskNumber,
+                                    ClusterStateTaskConfig.build(randomFrom(Priority.values())),
+                                    executors[index],
+                                    listener
+                                );
+                                tasksBarrier.await();
+                            } catch (InterruptedException | BrokenBarrierException e) {
+                                throw new AssertionError(e);
+                            }
+                        });
+                        // submit tasks per batchingKey in parallel
+                        taskThread.start();
+                    }
+                    // wait for all task threads to be ready
+                    tasksBarrier.await();
+                    // wait for all task threads to finish
+                    tasksBarrier.await();
+                    barrier.await();
+                } catch (InterruptedException | BrokenBarrierException e) {
+                    throw new AssertionError(e);
+                }
+            });
+            thread.start();
+        }
+
+        // wait for all executor threads to be ready
+        barrier.await();
+        // wait for all executor threads to finish
+        barrier.await();
+
+        updateLatch.await();
+
+        assertThat(failures, empty());
+
+        for (int i = 0; i < numberOfThreads; i++) {
+            // assert that total executed tasks is same for every executor as we initiated
+            assertEquals(tasksSubmittedPerThread, executors[i].tasks.size());
+        }
+    }
+
 
     public void testSingleBatchSubmission() throws InterruptedException {
         Map<Integer, TestListener> tasks = new HashMap<>();
         final int numOfTasks = randomInt(10);