Skip to content

Commit

Permalink
Use ConcurrentHashMap for batching tasks per executor in TaskBatcher … (
Browse files Browse the repository at this point in the history
#5827)

* Use ConcurrentHashMap for batching tasks per executor in TaskBatcher (#5099)

Signed-off-by: Aman Khare <[email protected]>
  • Loading branch information
amkhar authored Mar 7, 2023
1 parent 9a763e8 commit 30e4e5e
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand All @@ -61,7 +62,7 @@ public abstract class TaskBatcher {
private final Logger logger;
private final PrioritizedOpenSearchThreadPoolExecutor threadExecutor;
// package visible for tests
final Map<Object, LinkedHashSet<BatchedTask>> tasksPerBatchingKey = new HashMap<>();
final Map<Object, LinkedHashSet<BatchedTask>> tasksPerBatchingKey = new ConcurrentHashMap<>();
private final TaskBatcherListener taskBatcherListener;

public TaskBatcher(Logger logger, PrioritizedOpenSearchThreadPoolExecutor threadExecutor, TaskBatcherListener taskBatcherListener) {
Expand Down Expand Up @@ -93,12 +94,8 @@ public void submitTasks(List<? extends BatchedTask> tasks, @Nullable TimeValue t
IdentityHashMap::new
)
);

synchronized (tasksPerBatchingKey) {
LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.computeIfAbsent(
firstTask.batchingKey,
k -> new LinkedHashSet<>(tasks.size())
);
LinkedHashSet<BatchedTask> newTasks = new LinkedHashSet<>(tasks);
tasksPerBatchingKey.merge(firstTask.batchingKey, newTasks, (existingTasks, updatedTasks) -> {
for (BatchedTask existing : existingTasks) {
// check that there won't be two tasks with the same identity for the same batching key
BatchedTask duplicateTask = tasksIdentity.get(existing.getTask());
Expand All @@ -112,8 +109,9 @@ public void submitTasks(List<? extends BatchedTask> tasks, @Nullable TimeValue t
);
}
}
existingTasks.addAll(tasks);
}
existingTasks.addAll(updatedTasks);
return existingTasks;
});
} catch (Exception e) {
taskBatcherListener.onSubmitFailure(tasks);
throw e;
Expand All @@ -139,15 +137,13 @@ private void onTimeoutInternal(List<? extends BatchedTask> tasks, TimeValue time
Object batchingKey = firstTask.batchingKey;
assert tasks.stream().allMatch(t -> t.batchingKey == batchingKey)
: "tasks submitted in a batch should share the same batching key: " + tasks;
synchronized (tasksPerBatchingKey) {
LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.get(batchingKey);
if (existingTasks != null) {
existingTasks.removeAll(toRemove);
if (existingTasks.isEmpty()) {
tasksPerBatchingKey.remove(batchingKey);
}
tasksPerBatchingKey.computeIfPresent(batchingKey, (tasksKey, existingTasks) -> {
existingTasks.removeAll(toRemove);
if (existingTasks.isEmpty()) {
return null;
}
}
return existingTasks;
});
taskBatcherListener.onTimeout(toRemove);
onTimeout(toRemove, timeout);
}
Expand All @@ -165,17 +161,15 @@ void runIfNotProcessed(BatchedTask updateTask) {
if (updateTask.processed.get() == false) {
final List<BatchedTask> toExecute = new ArrayList<>();
final Map<String, List<BatchedTask>> processTasksBySource = new HashMap<>();
synchronized (tasksPerBatchingKey) {
LinkedHashSet<BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
if (pending != null) {
for (BatchedTask task : pending) {
if (task.processed.getAndSet(true) == false) {
logger.trace("will process {}", task);
toExecute.add(task);
processTasksBySource.computeIfAbsent(task.source, s -> new ArrayList<>()).add(task);
} else {
logger.trace("skipping {}, already processed", task);
}
LinkedHashSet<BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
if (pending != null) {
for (BatchedTask task : pending) {
if (task.processed.getAndSet(true) == false) {
logger.trace("will process {}", task);
toExecute.add(task);
processTasksBySource.computeIfAbsent(task.source, s -> new ArrayList<>()).add(task);
} else {
logger.trace("skipping {}, already processed", task);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,87 @@ public void processed(String source) {
}
}

public void testNoTasksAreDroppedInParallelSubmission() throws BrokenBarrierException, InterruptedException {
int numberOfThreads = randomIntBetween(2, 8);
TaskExecutor[] executors = new TaskExecutor[numberOfThreads];
for (int i = 0; i < numberOfThreads; i++) {
executors[i] = new TaskExecutor();
}

int tasksSubmittedPerThread = randomIntBetween(2, 1024);

CopyOnWriteArrayList<Tuple<String, Throwable>> failures = new CopyOnWriteArrayList<>();
CountDownLatch updateLatch = new CountDownLatch(numberOfThreads * tasksSubmittedPerThread);

final TestListener listener = new TestListener() {
@Override
public void onFailure(String source, Exception e) {
logger.error(() -> new ParameterizedMessage("unexpected failure: [{}]", source), e);
failures.add(new Tuple<>(source, e));
updateLatch.countDown();
}

@Override
public void processed(String source) {
updateLatch.countDown();
}
};

CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);

for (int i = 0; i < numberOfThreads; i++) {
final int index = i;
Thread thread = new Thread(() -> {
try {
barrier.await();
CyclicBarrier tasksBarrier = new CyclicBarrier(1 + tasksSubmittedPerThread);
for (int j = 0; j < tasksSubmittedPerThread; j++) {
int taskNumber = j;
Thread taskThread = new Thread(() -> {
try {
tasksBarrier.await();
submitTask(
"[" + index + "][" + taskNumber + "]",
taskNumber,
ClusterStateTaskConfig.build(randomFrom(Priority.values())),
executors[index],
listener
);
tasksBarrier.await();
} catch (InterruptedException | BrokenBarrierException e) {
throw new AssertionError(e);
}
});
// submit tasks per batchingKey in parallel
taskThread.start();
}
// wait for all task threads to be ready
tasksBarrier.await();
// wait for all task threads to finish
tasksBarrier.await();
barrier.await();
} catch (InterruptedException | BrokenBarrierException e) {
throw new AssertionError(e);
}
});
thread.start();
}

// wait for all executor threads to be ready
barrier.await();
// wait for all executor threads to finish
barrier.await();

updateLatch.await();

assertThat(failures, empty());

for (int i = 0; i < numberOfThreads; i++) {
// assert that total executed tasks is same for every executor as we initiated
assertEquals(tasksSubmittedPerThread, executors[i].tasks.size());
}
}

public void testSingleBatchSubmission() throws InterruptedException {
Map<Integer, TestListener> tasks = new HashMap<>();
final int numOfTasks = randomInt(10);
Expand Down

0 comments on commit 30e4e5e

Please sign in to comment.