From 8cf0eb498020eb6587d12363acd1225ef6970a17 Mon Sep 17 00:00:00 2001 From: Aman Khare Date: Wed, 11 Jan 2023 12:30:09 +0530 Subject: [PATCH] Use ConcurrentHashMap for batching tasks per executor in TaskBatcher (#5099) Signed-off-by: Aman Khare --- .../cluster/service/TaskBatcher.java | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/server/src/main/java/org/opensearch/cluster/service/TaskBatcher.java b/server/src/main/java/org/opensearch/cluster/service/TaskBatcher.java index b5710bab41172..463e59e877884 100644 --- a/server/src/main/java/org/opensearch/cluster/service/TaskBatcher.java +++ b/server/src/main/java/org/opensearch/cluster/service/TaskBatcher.java @@ -46,6 +46,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import java.util.stream.Collectors; @@ -61,7 +62,7 @@ public abstract class TaskBatcher { private final Logger logger; private final PrioritizedOpenSearchThreadPoolExecutor threadExecutor; // package visible for tests - final Map> tasksPerBatchingKey = new HashMap<>(); + final Map> tasksPerBatchingKey = new ConcurrentHashMap<>(); private final TaskBatcherListener taskBatcherListener; public TaskBatcher(Logger logger, PrioritizedOpenSearchThreadPoolExecutor threadExecutor, TaskBatcherListener taskBatcherListener) { @@ -94,11 +95,12 @@ public void submitTasks(List tasks, @Nullable TimeValue t ) ); - synchronized (tasksPerBatchingKey) { - LinkedHashSet existingTasks = tasksPerBatchingKey.computeIfAbsent( - firstTask.batchingKey, - k -> new LinkedHashSet<>(tasks.size()) - ); + LinkedHashSet existingTasks = tasksPerBatchingKey.computeIfAbsent( + firstTask.batchingKey, + k -> new LinkedHashSet<>(tasks.size()) + ); + // Locking on LinkedHashSet is necessary as it is being modified in concurrent manner. + synchronized (existingTasks) { for (BatchedTask existing : existingTasks) { // check that there won't be two tasks with the same identity for the same batching key BatchedTask duplicateTask = tasksIdentity.get(existing.getTask()); @@ -114,6 +116,7 @@ public void submitTasks(List tasks, @Nullable TimeValue t } existingTasks.addAll(tasks); } + } catch (Exception e) { taskBatcherListener.onSubmitFailure(tasks); throw e; @@ -139,13 +142,13 @@ private void onTimeoutInternal(List tasks, TimeValue time Object batchingKey = firstTask.batchingKey; assert tasks.stream().allMatch(t -> t.batchingKey == batchingKey) : "tasks submitted in a batch should share the same batching key: " + tasks; - synchronized (tasksPerBatchingKey) { - LinkedHashSet existingTasks = tasksPerBatchingKey.get(batchingKey); - if (existingTasks != null) { + LinkedHashSet existingTasks = tasksPerBatchingKey.get(batchingKey); + if (existingTasks != null) { + synchronized (existingTasks) { existingTasks.removeAll(toRemove); - if (existingTasks.isEmpty()) { - tasksPerBatchingKey.remove(batchingKey); - } + } + if (existingTasks.isEmpty()) { + tasksPerBatchingKey.remove(batchingKey); } } taskBatcherListener.onTimeout(toRemove); @@ -165,9 +168,9 @@ void runIfNotProcessed(BatchedTask updateTask) { if (updateTask.processed.get() == false) { final List toExecute = new ArrayList<>(); final Map> processTasksBySource = new HashMap<>(); - synchronized (tasksPerBatchingKey) { - LinkedHashSet pending = tasksPerBatchingKey.remove(updateTask.batchingKey); - if (pending != null) { + LinkedHashSet pending = tasksPerBatchingKey.remove(updateTask.batchingKey); + if (pending != null) { + synchronized (pending) { for (BatchedTask task : pending) { if (task.processed.getAndSet(true) == false) { logger.trace("will process {}", task);