Commit

Bigger batches
davidkyle committed Oct 25, 2024
1 parent 0bcdd9c commit 3ac9e8c
Showing 2 changed files with 48 additions and 12 deletions.
@@ -1016,10 +1016,30 @@ class BatchIterator {
         }
 
         void run() {
-            inferenceExecutor.execute(() -> inferBatchAndRunAfter(requestAndListeners.get(index.get())));
+            inferenceExecutor.execute(this::inferBatchAndRunAfter);
         }
 
-        private void inferBatchAndRunAfter(EmbeddingRequestChunker.BatchRequestAndListener batch) {
+        private void inferBatchAndRunAfter() {
+            int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200
+            int requestCount = 0;
+            // loop does not include the final request
+            while (requestCount < NUM_REQUESTS_INFLIGHT - 1 && index.get() < requestAndListeners.size() - 1) {
+
+                var batch = requestAndListeners.get(index.get());
+                executeRequest(batch);
+                requestCount++;
+                index.incrementAndGet();
+            }
+
+            var batch = requestAndListeners.get(index.get());
+            executeRequest(batch, () -> {
+                if (index.incrementAndGet() < requestAndListeners.size()) {
+                    run(); // start the next batch
+                }
+            });
+        }
+
+        private void executeRequest(EmbeddingRequestChunker.BatchRequestAndListener batch) {
             var inferenceRequest = buildInferenceRequest(
                 esModel.mlNodeDeploymentId(),
                 EmptyConfigUpdate.INSTANCE,
@@ -1033,18 +1053,30 @@ private void inferBatchAndRunAfter(EmbeddingRequestChunker.BatchRequestAndListen
                 (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
             );
 
-            // schedule the next request once the results have been processed
-            var runNextListener = ActionListener.runAfter(mlResultsListener, () -> {
-                if (index.incrementAndGet() < requestAndListeners.size()) {
-                    run();
-                }
-            });
-
-            var maybeDeployListener = runNextListener.delegateResponse(
+            var maybeDeployListener = mlResultsListener.delegateResponse(
                 (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l)
             );
 
             client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
         }
+
+        private void executeRequest(EmbeddingRequestChunker.BatchRequestAndListener batch, Runnable runAfter) {
+            var inferenceRequest = buildInferenceRequest(
+                esModel.mlNodeDeploymentId(),
+                EmptyConfigUpdate.INSTANCE,
+                batch.batch().inputs(),
+                inputType,
+                timeout
+            );
+
+            ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
+                .delegateFailureAndWrap(
+                    (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
+                );
+
+            // schedule the next request once the results have been processed
+            var runNextListener = ActionListener.runAfter(mlResultsListener, runAfter);
+            client.execute(InferModelAction.INSTANCE, inferenceRequest, runNextListener);
+        }
     }
 }
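
The new BatchIterator code above pipelines the work: it fires up to NUM_REQUESTS_INFLIGHT batch requests without waiting, and only the final request of each window schedules the next window from its result listener. Below is a minimal, self-contained sketch of that pattern under simplifying assumptions; the class name PipelinedBatchRunner and the CompletableFuture-based client function are hypothetical stand-ins for the Elasticsearch client and ActionListener plumbing, not the actual API.

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

// Hypothetical illustration only: bounded pipelining of batch requests.
class PipelinedBatchRunner<REQ, RESP> {
    private static final int NUM_REQUESTS_INFLIGHT = 20; // same bound as in the diff

    private final List<REQ> batches;
    private final Function<REQ, CompletableFuture<RESP>> client; // stand-in for the async inference client
    private final AtomicInteger index = new AtomicInteger();

    PipelinedBatchRunner(List<REQ> batches, Function<REQ, CompletableFuture<RESP>> client) {
        this.batches = batches;
        this.client = client;
    }

    void run() {
        if (batches.isEmpty()) {
            return;
        }
        int requestCount = 0;
        // Fire all but the last request of this window without waiting for results.
        while (requestCount < NUM_REQUESTS_INFLIGHT - 1 && index.get() < batches.size() - 1) {
            client.apply(batches.get(index.get()));
            requestCount++;
            index.incrementAndGet();
        }
        // Only the final request of the window chains the next window, so at most
        // NUM_REQUESTS_INFLIGHT requests are outstanding at any time.
        client.apply(batches.get(index.get())).whenComplete((response, error) -> {
            if (index.incrementAndGet() < batches.size()) {
                run();
            }
        });
    }
}

In this sketch, whenComplete plays the role that ActionListener.runAfter(mlResultsListener, runAfter) plays in the diff: the next window starts only after the last request of the current window has been processed.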
@@ -1275,7 +1275,8 @@ public void testChunkingLargeDocument() throws InterruptedException {
         assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
 
         int wordsPerChunk = 10;
-        int numBatches = randomIntBetween(3, 6);
+        int numBatches = 3;
+        randomIntBetween(3, 6);
         int numChunks = randomIntBetween(
             ((numBatches - 1) * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE) + 1,
             numBatches * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE
@@ -1291,6 +1292,9 @@ public void testChunkingLargeDocument() throws InterruptedException {
             numResponsesPerBatch[i] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
         }
         numResponsesPerBatch[numBatches - 1] = numChunks % ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
+        if (numResponsesPerBatch[numBatches - 1] == 0) {
+            numResponsesPerBatch[numBatches - 1] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
+        }
 
         var batchIndex = new AtomicInteger();
         Client client = mock(Client.class);
@@ -1347,7 +1351,7 @@ public void testChunkingLargeDocument() throws InterruptedException {
         );
 
         latch.await();
-        assertTrue("Listener not called", gotResults.get());
+        assertTrue("Listener not called with results", gotResults.get());
     }
 
     public void testParsePersistedConfig_Rerank() {