From dfa1c877e05d65454194c50352b176ea0d611d06 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 16 Dec 2024 13:26:24 +0000 Subject: [PATCH] [ML] Batch the chunks (#115477) (#116823) Models running on an ml node have a queue of requests, when that queue is full new requests are rejected. A large document can chunk into hundreds of requests and in extreme cases a single large document can overflow the queue. Avoid this by batches of chunks keeping certain number of requests in flight. --- .../ElasticsearchInternalService.java | 103 ++++++++++++--- .../EmbeddingRequestChunkerTests.java | 13 ++ .../ElasticsearchInternalServiceTests.java | 122 ++++++++++++++++-- 3 files changed, 205 insertions(+), 33 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 868b5168df493..43ac87ec56787 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -60,6 +60,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.function.Function; @@ -656,25 +657,13 @@ public void chunkedInfer( esModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); - for (var batch : batchedRequests) { - var inferenceRequest = buildInferenceRequest( - esModel.mlNodeDeploymentId(), - EmptyConfigUpdate.INSTANCE, - batch.batch().inputs(), - inputType, - timeout - ); - - ActionListener mlResultsListener = batch.listener() - .delegateFailureAndWrap( - (l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l) - ); - - var maybeDeployListener = mlResultsListener.delegateResponse( - (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener) - ); - - client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener); + if (batchedRequests.isEmpty()) { + listener.onResponse(List.of()); + } else { + // Avoid filling the inference queue by executing the batches in series + // Each batch contains up to EMBEDDING_MAX_BATCH_SIZE inference request + var sequentialRunner = new BatchIterator(esModel, inputType, timeout, batchedRequests); + sequentialRunner.run(); } } else { listener.onFailure(notElasticsearchModelException(model)); @@ -990,4 +979,80 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) { return null; } } + + /** + * Iterates over the batch executing a limited number requests at a time to avoid + * filling the ML node inference queue. + * + * First, a single request is executed, which can also trigger deploying a model + * if necessary. When this request is successfully executed, a callback executes + * N requests in parallel next. Each of these requests also has a callback that + * executes one more request, so that at all time N requests are in-flight. This + * continues until all requests are executed. + */ + class BatchIterator { + private static final int NUM_REQUESTS_INFLIGHT = 20; // * batch size = 200 + + private final AtomicInteger index = new AtomicInteger(); + private final ElasticsearchInternalModel esModel; + private final List requestAndListeners; + private final InputType inputType; + private final TimeValue timeout; + + BatchIterator( + ElasticsearchInternalModel esModel, + InputType inputType, + TimeValue timeout, + List requestAndListeners + ) { + this.esModel = esModel; + this.requestAndListeners = requestAndListeners; + this.inputType = inputType; + this.timeout = timeout; + } + + void run() { + // The first request may deploy the model, and upon completion runs + // NUM_REQUESTS_INFLIGHT in parallel. + inferenceExecutor.execute(() -> inferBatch(NUM_REQUESTS_INFLIGHT, true)); + } + + private void inferBatch(int runAfterCount, boolean maybeDeploy) { + int batchIndex = index.getAndIncrement(); + if (batchIndex >= requestAndListeners.size()) { + return; + } + executeRequest(batchIndex, maybeDeploy, () -> { + for (int i = 0; i < runAfterCount; i++) { + // Subsequent requests may not deploy the model, because the first request + // already did so. Upon completion, it runs one more request. + inferenceExecutor.execute(() -> inferBatch(1, false)); + } + }); + } + + private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAfter) { + EmbeddingRequestChunker.BatchRequestAndListener batch = requestAndListeners.get(batchIndex); + var inferenceRequest = buildInferenceRequest( + esModel.mlNodeDeploymentId(), + EmptyConfigUpdate.INSTANCE, + batch.batch().inputs(), + inputType, + timeout + ); + logger.trace("Executing batch index={}", batchIndex); + + ActionListener listener = batch.listener() + .delegateFailureAndWrap( + (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l) + ); + if (runAfter != null) { + listener = ActionListener.runAfter(listener, runAfter); + } + if (maybeDeploy) { + listener = listener.delegateResponse((l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, l)); + } + client.execute(InferModelAction.INSTANCE, inferenceRequest, listener); + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index c1be537a6b0a7..4fdf254101d3e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -24,12 +24,25 @@ import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.startsWith; public class EmbeddingRequestChunkerTests extends ESTestCase { + public void testEmptyInput() { + var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values()); + var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener()); + assertThat(batches, empty()); + } + + public void testBlankInput() { + var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values()); + var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener()); + assertThat(batches, hasSize(1)); + } + public void testShortInputsAreSingleBatch() { String input = "one chunk"; var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index d90579831a3da..a055274bf10c9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Level; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.service.ClusterService; @@ -59,6 +60,7 @@ import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.junit.After; import org.junit.Before; @@ -66,12 +68,14 @@ import org.mockito.Mockito; import java.util.ArrayList; +import java.util.Arrays; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -824,16 +828,16 @@ public void testParsePersistedConfig() { } } - public void testChunkInfer_E5WithNullChunkingSettings() { + public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException { testChunkInfer_e5(null); } - public void testChunkInfer_E5ChunkingSettingsSet() { + public void testChunkInfer_E5ChunkingSettingsSet() throws InterruptedException { testChunkInfer_e5(ChunkingSettingsTests.createRandomChunkingSettings()); } @SuppressWarnings("unchecked") - private void testChunkInfer_e5(ChunkingSettings chunkingSettings) { + private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws InterruptedException { var mlTrainedModelResults = new ArrayList(); mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); @@ -881,6 +885,9 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) { gotResults.set(true); }, ESTestCase::fail); + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener<>(resultsListener, latch); + service.chunkedInfer( model, null, @@ -889,22 +896,23 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) { InputType.SEARCH, new ChunkingOptions(null, null), InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.runAfter(resultsListener, () -> terminate(threadPool)) + latchedListener ); + latch.await(); assertTrue("Listener not called", gotResults.get()); } - public void testChunkInfer_SparseWithNullChunkingSettings() { + public void testChunkInfer_SparseWithNullChunkingSettings() throws InterruptedException { testChunkInfer_Sparse(null); } - public void testChunkInfer_SparseWithChunkingSettingsSet() { + public void testChunkInfer_SparseWithChunkingSettingsSet() throws InterruptedException { testChunkInfer_Sparse(ChunkingSettingsTests.createRandomChunkingSettings()); } @SuppressWarnings("unchecked") - private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) { + private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws InterruptedException { var mlTrainedModelResults = new ArrayList(); mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults()); mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults()); @@ -928,6 +936,7 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) { var service = createService(client); var gotResults = new AtomicBoolean(); + var resultsListener = ActionListener.>wrap(chunkedResponse -> { assertThat(chunkedResponse, hasSize(2)); assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); @@ -947,6 +956,9 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) { gotResults.set(true); }, ESTestCase::fail); + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener<>(resultsListener, latch); + service.chunkedInfer( model, null, @@ -955,22 +967,23 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) { InputType.SEARCH, new ChunkingOptions(null, null), InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.runAfter(resultsListener, () -> terminate(threadPool)) + latchedListener ); + latch.await(); assertTrue("Listener not called", gotResults.get()); } - public void testChunkInfer_ElserWithNullChunkingSettings() { + public void testChunkInfer_ElserWithNullChunkingSettings() throws InterruptedException { testChunkInfer_Elser(null); } - public void testChunkInfer_ElserWithChunkingSettingsSet() { + public void testChunkInfer_ElserWithChunkingSettingsSet() throws InterruptedException { testChunkInfer_Elser(ChunkingSettingsTests.createRandomChunkingSettings()); } @SuppressWarnings("unchecked") - private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) { + private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws InterruptedException { var mlTrainedModelResults = new ArrayList(); mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults()); mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults()); @@ -1014,6 +1027,9 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) { gotResults.set(true); }, ESTestCase::fail); + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener<>(resultsListener, latch); + service.chunkedInfer( model, null, @@ -1022,9 +1038,10 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) { InputType.SEARCH, new ChunkingOptions(null, null), InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.runAfter(resultsListener, () -> terminate(threadPool)) + latchedListener ); + latch.await(); assertTrue("Listener not called", gotResults.get()); } @@ -1085,7 +1102,7 @@ public void testChunkInferSetsTokenization() { } @SuppressWarnings("unchecked") - public void testChunkInfer_FailsBatch() { + public void testChunkInfer_FailsBatch() throws InterruptedException { var mlTrainedModelResults = new ArrayList(); mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); @@ -1121,6 +1138,9 @@ public void testChunkInfer_FailsBatch() { gotResults.set(true); }, ESTestCase::fail); + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener<>(resultsListener, latch); + service.chunkedInfer( model, null, @@ -1129,12 +1149,86 @@ public void testChunkInfer_FailsBatch() { InputType.SEARCH, new ChunkingOptions(null, null), InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.runAfter(resultsListener, () -> terminate(threadPool)) + latchedListener ); + latch.await(); assertTrue("Listener not called", gotResults.get()); } + @SuppressWarnings("unchecked") + public void testChunkingLargeDocument() throws InterruptedException { + int numBatches = randomIntBetween(3, 6); + + // how many response objects to return in each batch + int[] numResponsesPerBatch = new int[numBatches]; + for (int i = 0; i < numBatches - 1; i++) { + numResponsesPerBatch[i] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE; + } + numResponsesPerBatch[numBatches - 1] = randomIntBetween(1, ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE); + int numChunks = Arrays.stream(numResponsesPerBatch).sum(); + + // build a doc with enough words to make numChunks of chunks + int wordsPerChunk = 10; + int numWords = numChunks * wordsPerChunk; + var input = "word ".repeat(numWords); + + Client client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + + // mock the inference response + doAnswer(invocationOnMock -> { + var request = (InferModelAction.Request) invocationOnMock.getArguments()[1]; + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + var mlTrainedModelResults = new ArrayList(); + for (int i = 0; i < request.numberOfDocuments(); i++) { + mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); + } + var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); + listener.onResponse(response); + return null; + }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); + + var service = createService(client); + + var gotResults = new AtomicBoolean(); + var resultsListener = ActionListener.>wrap(chunkedResponse -> { + assertThat(chunkedResponse, hasSize(1)); + assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); + var sparseResults = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0); + assertThat(sparseResults.chunks(), hasSize(numChunks)); + + gotResults.set(true); + }, ESTestCase::fail); + + // Create model using the word boundary chunker. + var model = new MultilingualE5SmallModel( + "foo", + TaskType.TEXT_EMBEDDING, + "e5", + new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null), + new WordBoundaryChunkingSettings(wordsPerChunk, 0) + ); + + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener<>(resultsListener, latch); + + // For the given input we know how many requests will be made + service.chunkedInfer( + model, + null, + List.of(input), + Map.of(), + InputType.SEARCH, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + latchedListener + ); + + latch.await(); + assertTrue("Listener not called with results", gotResults.get()); + } + public void testParsePersistedConfig_Rerank() { // with task settings {