Skip to content

Commit

Permalink
Use ConcurrentHashMap for batching tasks per executor in TaskBatcher (o…
Browse files Browse the repository at this point in the history
…pensearch-project#5099)

Signed-off-by: Aman Khare <[email protected]>
  • Loading branch information
Aman Khare committed Jan 11, 2023
1 parent 386ce9b commit e8dde1e
Showing 1 changed file with 18 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand All @@ -61,7 +62,7 @@ public abstract class TaskBatcher {
private final Logger logger;
private final PrioritizedOpenSearchThreadPoolExecutor threadExecutor;
// package visible for tests
final Map<Object, LinkedHashSet<BatchedTask>> tasksPerBatchingKey = new HashMap<>();
final Map<Object, LinkedHashSet<BatchedTask>> tasksPerBatchingKey = new ConcurrentHashMap<>();
private final TaskBatcherListener taskBatcherListener;

public TaskBatcher(Logger logger, PrioritizedOpenSearchThreadPoolExecutor threadExecutor, TaskBatcherListener taskBatcherListener) {
Expand Down Expand Up @@ -94,11 +95,12 @@ public void submitTasks(List<? extends BatchedTask> tasks, @Nullable TimeValue t
)
);

synchronized (tasksPerBatchingKey) {
LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.computeIfAbsent(
firstTask.batchingKey,
k -> new LinkedHashSet<>(tasks.size())
);
LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.computeIfAbsent(
firstTask.batchingKey,
k -> new LinkedHashSet<>(tasks.size())
);
// Locking on LinkedHashSet is necessary as it is being modified in concurrent manner.
synchronized (existingTasks) {
for (BatchedTask existing : existingTasks) {
// check that there won't be two tasks with the same identity for the same batching key
BatchedTask duplicateTask = tasksIdentity.get(existing.getTask());
Expand All @@ -114,6 +116,7 @@ public void submitTasks(List<? extends BatchedTask> tasks, @Nullable TimeValue t
}
existingTasks.addAll(tasks);
}

} catch (Exception e) {
taskBatcherListener.onSubmitFailure(tasks);
throw e;
Expand All @@ -139,13 +142,13 @@ private void onTimeoutInternal(List<? extends BatchedTask> tasks, TimeValue time
Object batchingKey = firstTask.batchingKey;
assert tasks.stream().allMatch(t -> t.batchingKey == batchingKey)
: "tasks submitted in a batch should share the same batching key: " + tasks;
synchronized (tasksPerBatchingKey) {
LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.get(batchingKey);
if (existingTasks != null) {
LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.get(batchingKey);
if (existingTasks != null) {
synchronized (existingTasks) {
existingTasks.removeAll(toRemove);
if (existingTasks.isEmpty()) {
tasksPerBatchingKey.remove(batchingKey);
}
}
if (existingTasks.isEmpty()) {
tasksPerBatchingKey.remove(batchingKey);
}
}
taskBatcherListener.onTimeout(toRemove);
Expand All @@ -165,9 +168,9 @@ void runIfNotProcessed(BatchedTask updateTask) {
if (updateTask.processed.get() == false) {
final List<BatchedTask> toExecute = new ArrayList<>();
final Map<String, List<BatchedTask>> processTasksBySource = new HashMap<>();
synchronized (tasksPerBatchingKey) {
LinkedHashSet<BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
if (pending != null) {
LinkedHashSet<BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
if (pending != null) {
synchronized (pending) {
for (BatchedTask task : pending) {
if (task.processed.getAndSet(true) == false) {
logger.trace("will process {}", task);
Expand Down

0 comments on commit e8dde1e

Please sign in to comment.