diff --git a/server/src/main/java/org/opensearch/cluster/service/TaskBatcher.java b/server/src/main/java/org/opensearch/cluster/service/TaskBatcher.java index 463e59e877884..4cb0cf2bc5bd5 100644 --- a/server/src/main/java/org/opensearch/cluster/service/TaskBatcher.java +++ b/server/src/main/java/org/opensearch/cluster/service/TaskBatcher.java @@ -94,13 +94,8 @@ public void submitTasks(List tasks, @Nullable TimeValue t IdentityHashMap::new ) ); - - LinkedHashSet existingTasks = tasksPerBatchingKey.computeIfAbsent( - firstTask.batchingKey, - k -> new LinkedHashSet<>(tasks.size()) - ); - // Locking on LinkedHashSet is necessary as it is being modified in concurrent manner. - synchronized (existingTasks) { + LinkedHashSet newTasks = new LinkedHashSet<>(tasks); + tasksPerBatchingKey.merge(firstTask.batchingKey, newTasks, (existingTasks, updatedTasks) -> { for (BatchedTask existing : existingTasks) { // check that there won't be two tasks with the same identity for the same batching key BatchedTask duplicateTask = tasksIdentity.get(existing.getTask()); @@ -114,9 +109,9 @@ public void submitTasks(List tasks, @Nullable TimeValue t ); } } - existingTasks.addAll(tasks); - } - + existingTasks.addAll(updatedTasks); + return existingTasks; + }); } catch (Exception e) { taskBatcherListener.onSubmitFailure(tasks); throw e; @@ -142,15 +137,13 @@ private void onTimeoutInternal(List tasks, TimeValue time Object batchingKey = firstTask.batchingKey; assert tasks.stream().allMatch(t -> t.batchingKey == batchingKey) : "tasks submitted in a batch should share the same batching key: " + tasks; - LinkedHashSet existingTasks = tasksPerBatchingKey.get(batchingKey); - if (existingTasks != null) { - synchronized (existingTasks) { - existingTasks.removeAll(toRemove); + tasksPerBatchingKey.computeIfPresent(batchingKey, (tasksKey, currentTasks) -> { + currentTasks.removeAll(toRemove); + if (currentTasks.isEmpty()) { + return null; } - if (existingTasks.isEmpty()) { - tasksPerBatchingKey.remove(batchingKey); - } - } + return currentTasks; + }); taskBatcherListener.onTimeout(toRemove); onTimeout(toRemove, timeout); } @@ -170,15 +163,13 @@ void runIfNotProcessed(BatchedTask updateTask) { final Map> processTasksBySource = new HashMap<>(); LinkedHashSet pending = tasksPerBatchingKey.remove(updateTask.batchingKey); if (pending != null) { - synchronized (pending) { - for (BatchedTask task : pending) { - if (task.processed.getAndSet(true) == false) { - logger.trace("will process {}", task); - toExecute.add(task); - processTasksBySource.computeIfAbsent(task.source, s -> new ArrayList<>()).add(task); - } else { - logger.trace("skipping {}, already processed", task); - } + for (BatchedTask task : pending) { + if (task.processed.getAndSet(true) == false) { + logger.trace("will process {}", task); + toExecute.add(task); + processTasksBySource.computeIfAbsent(task.source, s -> new ArrayList<>()).add(task); + } else { + logger.trace("skipping {}, already processed", task); } } }