Skip to content

Commit

Permalink
[ML] Batch the chunks (elastic#115477)
Browse files Browse the repository at this point in the history
Models running on an ML node have a queue of requests; when that queue
is full, new requests are rejected. A large document can be chunked into
hundreds of requests, and in extreme cases a single large document can
overflow the queue. Avoid this by executing the chunks in batches, keeping
only a limited number of requests in flight at any time.
  • Loading branch information
davidkyle committed Nov 12, 2024
1 parent 79ceabf commit 63b58a5
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
Expand Down Expand Up @@ -680,25 +681,13 @@ public void chunkedInfer(
esModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

for (var batch : batchedRequests) {
var inferenceRequest = buildInferenceRequest(
esModel.mlNodeDeploymentId(),
EmptyConfigUpdate.INSTANCE,
batch.batch().inputs(),
inputType,
timeout
);

ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
.delegateFailureAndWrap(
(l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l)
);

var maybeDeployListener = mlResultsListener.delegateResponse(
(l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener)
);

client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
if (batchedRequests.isEmpty()) {
listener.onResponse(List.of());
} else {
// Avoid filling the inference queue by executing the batches in series
// Each batch contains up to EMBEDDING_MAX_BATCH_SIZE inference requests
var sequentialRunner = new BatchIterator(esModel, inputType, timeout, batchedRequests);
sequentialRunner.run();
}
} else {
listener.onFailure(notElasticsearchModelException(model));
Expand Down Expand Up @@ -1017,6 +1006,82 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) {
}
}

/**
 * Walks through the list of batched requests, keeping only a limited number of
 * requests outstanding at any time so the ML node inference queue is not filled.
 *
 * <p>Execution starts with a single request, which is the only one allowed to
 * attempt a model deployment if it fails. Once that request completes, its
 * completion callback schedules {@code NUM_REQUESTS_INFLIGHT} further requests.
 * Every one of those schedules exactly one more request when it completes, so
 * that up to {@code NUM_REQUESTS_INFLIGHT} requests are in flight at all times
 * until the list is exhausted.
 */
class BatchIterator {
    private static final int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200

    // Next position in requestAndListeners to be claimed; advanced atomically so
    // that each batch is picked up by exactly one inferBatch invocation.
    private final AtomicInteger index = new AtomicInteger();
    private final ElasticsearchInternalModel esModel;
    private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners;
    private final InputType inputType;
    private final TimeValue timeout;

    BatchIterator(
        ElasticsearchInternalModel esModel,
        InputType inputType,
        TimeValue timeout,
        List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners
    ) {
        this.esModel = esModel;
        this.inputType = inputType;
        this.timeout = timeout;
        this.requestAndListeners = requestAndListeners;
    }

    /**
     * Kicks off the iteration on the inference executor. The initial request may
     * deploy the model and, when it completes, fans out to NUM_REQUESTS_INFLIGHT
     * follow-up requests.
     */
    void run() {
        inferenceExecutor.execute(() -> inferBatch(NUM_REQUESTS_INFLIGHT, true));
    }

    /**
     * Claims the next unprocessed batch, if any, and executes it.
     *
     * @param runAfterCount how many follow-up requests to schedule once this one completes
     * @param maybeDeploy   whether a failure of this request may trigger a model deployment
     */
    private void inferBatch(int runAfterCount, boolean maybeDeploy) {
        final int claimedIndex = index.getAndIncrement();
        if (claimedIndex < requestAndListeners.size()) {
            executeRequest(claimedIndex, maybeDeploy, () -> scheduleFollowUps(runAfterCount));
        }
        // Otherwise every batch has already been claimed and there is nothing left to do.
    }

    /**
     * Schedules {@code count} follow-up requests on the inference executor.
     * Follow-ups never deploy the model, because the first request already had
     * that opportunity, and each of them schedules one more request on completion.
     */
    private void scheduleFollowUps(int count) {
        for (int pending = count; pending > 0; pending--) {
            inferenceExecutor.execute(() -> inferBatch(1, false));
        }
    }

    /**
     * Builds and sends the inference request for the batch at {@code batchIndex}.
     *
     * @param batchIndex  position of the batch in {@code requestAndListeners}
     * @param maybeDeploy when true, a failure is routed to {@code maybeStartDeployment}
     * @param runAfter    callback run after the batch listener completes; may be null
     */
    private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAfter) {
        var requestAndListener = requestAndListeners.get(batchIndex);
        var inferenceRequest = buildInferenceRequest(
            esModel.mlNodeDeploymentId(),
            EmptyConfigUpdate.INSTANCE,
            requestAndListener.batch().inputs(),
            inputType,
            timeout
        );
        logger.trace("Executing batch index={}", batchIndex);

        // Innermost layer: translate the raw inference results for the batch's own listener.
        ActionListener<InferModelAction.Response> translating = requestAndListener.listener()
            .delegateFailureAndWrap(
                (delegate, response) -> translateToChunkedResult(esModel.getTaskType(), response.getInferenceResults(), delegate)
            );

        // Middle layer: trigger the follow-up scheduling once the batch listener has completed.
        ActionListener<InferModelAction.Response> withFollowUp = runAfter == null
            ? translating
            : ActionListener.runAfter(translating, runAfter);

        // Outermost layer: only the deploy-capable request hands failures to maybeStartDeployment.
        ActionListener<InferModelAction.Response> outermost = maybeDeploy
            ? withFollowUp.delegateResponse(
                (delegate, failure) -> maybeStartDeployment(esModel, failure, inferenceRequest, delegate)
            )
            : withFollowUp;

        client.execute(InferModelAction.INSTANCE, inferenceRequest, outermost);
    }
}

public static class Configuration {
public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,25 @@
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;

public class EmbeddingRequestChunkerTests extends ESTestCase {

/** Verifies that an empty input list yields no batches. */
public void testEmptyInput() {
    var chunker = new EmbeddingRequestChunker(
        List.of(),
        100,
        100,
        10,
        randomFrom(EmbeddingRequestChunker.EmbeddingType.values())
    );
    assertThat(chunker.batchRequestsWithListeners(testListener()), empty());
}

/** Verifies that a single blank string input still yields exactly one batch. */
public void testBlankInput() {
    var chunker = new EmbeddingRequestChunker(
        List.of(""),
        100,
        100,
        10,
        randomFrom(EmbeddingRequestChunker.EmbeddingType.values())
    );
    assertThat(chunker.batchRequestsWithListeners(testListener()), hasSize(1));
}

public void testShortInputsAreSingleBatch() {
String input = "one chunk";
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
Expand Down
Loading

0 comments on commit 63b58a5

Please sign in to comment.