diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
index 6732e5719b897..3a5983e2e3549 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
@@ -62,6 +62,7 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -668,25 +669,13 @@ public void chunkedInfer(
             ).batchRequestsWithListeners(listener);
         }
-            for (var batch : batchedRequests) {
-                var inferenceRequest = buildInferenceRequest(
-                    esModel.mlNodeDeploymentId(),
-                    EmptyConfigUpdate.INSTANCE,
-                    batch.batch().inputs(),
-                    inputType,
-                    timeout
-                );
-
-                ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
-                    .delegateFailureAndWrap(
-                        (l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l)
-                    );
-
-                var maybeDeployListener = mlResultsListener.delegateResponse(
-                    (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener)
-                );
-
-                client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
+            if (batchedRequests.isEmpty()) {
+                listener.onResponse(List.of());
+            } else {
+                // Avoid filling the inference queue by executing the batches in series.
+                // Each batch contains up to EMBEDDING_MAX_BATCH_SIZE inference requests.
+                var sequentialRunner = new BatchIterator(esModel, inputType, timeout, batchedRequests);
+                sequentialRunner.doNextRequest();
             }
         } else {
             listener.onFailure(notElasticsearchModelException(model));
         }
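The hunk above is the behavioural change: rather than dispatching every batch to the ml node in one loop, chunkedInfer now hands the whole batch list to a BatchIterator (added below) that sends one request at a time, scheduling the next only after the previous batch's listener has completed. A minimal stand-alone sketch of that serial-dispatch pattern, using a hypothetical SerialDispatcher and a plain Runnable completion callback in place of Elasticsearch's ActionListener plumbing:

    import java.util.List;
    import java.util.concurrent.atomic.AtomicInteger;
    import java.util.function.BiConsumer;

    // Illustrative only: run one async task at a time, starting the next
    // from the previous task's completion callback.
    final class SerialDispatcher<T> {
        private final AtomicInteger index = new AtomicInteger();
        private final List<T> items;
        private final BiConsumer<T, Runnable> task; // (item, onDone)

        SerialDispatcher(List<T> items, BiConsumer<T, Runnable> task) {
            this.items = items;
            this.task = task;
        }

        void start() {
            if (items.isEmpty() == false) {
                run(index.get());
            }
        }

        private void run(int i) {
            // onDone advances the shared index and dispatches the next item;
            // nothing new is submitted until the current item reports done.
            task.accept(items.get(i), () -> {
                int next = index.incrementAndGet();
                if (next < items.size()) {
                    run(next);
                }
            });
        }
    }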
@@ -1004,4 +993,58 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) {
             return null;
         }
     }
+
+    // Iterates over the batches, sending one request at a time to avoid
+    // filling the ml node inference queue.
+    class BatchIterator {
+        private final AtomicInteger index = new AtomicInteger();
+        private final ElasticsearchInternalModel esModel;
+        private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners;
+        private final InputType inputType;
+        private final TimeValue timeout;
+
+        BatchIterator(
+            ElasticsearchInternalModel esModel,
+            InputType inputType,
+            TimeValue timeout,
+            List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners
+        ) {
+            this.esModel = esModel;
+            this.requestAndListeners = requestAndListeners;
+            this.inputType = inputType;
+            this.timeout = timeout;
+        }
+
+        void doNextRequest() {
+            inferenceExecutor.execute(() -> inferOnBatch(requestAndListeners.get(index.get())));
+        }
+
+        private void inferOnBatch(EmbeddingRequestChunker.BatchRequestAndListener batch) {
+            var inferenceRequest = buildInferenceRequest(
+                esModel.mlNodeDeploymentId(),
+                EmptyConfigUpdate.INSTANCE,
+                batch.batch().inputs(),
+                inputType,
+                timeout
+            );
+
+            ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
+                .delegateFailureAndWrap(
+                    (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
+                );
+
+            // schedule the next request once the results have been processed
+            var scheduleNextListener = ActionListener.runAfter(mlResultsListener, () -> {
+                if (index.incrementAndGet() < requestAndListeners.size()) {
+                    doNextRequest();
+                }
+            });
+
+            var maybeDeployListener = scheduleNextListener.delegateResponse(
+                (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, scheduleNextListener)
+            );
+
+            client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
+        }
+    }
 }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
index c1be537a6b0a7..4fdf254101d3e 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
@@ -24,12 +24,25 @@
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.startsWith;
 
 public class EmbeddingRequestChunkerTests extends ESTestCase {
 
+    public void testEmptyInput() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, empty());
+    }
+
+    public void testBlankInput() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+    }
+
     public void testShortInputsAreSingleBatch() {
         String input = "one chunk";
         var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
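Two details in the hunks above are easy to miss. First, an empty input now short-circuits with listener.onResponse(List.of()): batchRequestsWithListeners returns no batches for empty input (covered by the new testEmptyInput), so without the guard no batch listener would ever fire and the caller would hang. Second, BatchIterator chains the next dispatch with ActionListener.runAfter, which runs its Runnable after onResponse and after onFailure alike; chaining only on success would let a single failed batch strand every listener queued behind it. A rough CompletableFuture analogue of that always-continue behaviour (illustrative sketch, not the service's API):

    import java.util.concurrent.CompletableFuture;

    final class RunAfterSketch {
        // whenComplete fires on success and on failure alike, so a failed
        // batch cannot stall the batches queued behind it.
        static <T> CompletableFuture<T> runAfter(CompletableFuture<T> future, Runnable next) {
            return future.whenComplete((result, error) -> next.run());
        }
    }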
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
index 5ec66687752a8..d835d0624f654 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
@@ -12,6 +12,7 @@
 import org.apache.logging.log4j.Level;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.LatchedActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.service.ClusterService;
@@ -60,6 +61,7 @@
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.junit.After;
 import org.junit.Before;
@@ -73,6 +75,7 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
@@ -936,17 +939,17 @@ public void testParsePersistedConfig() {
         }
     }
 
-    public void testChunkInfer_E5WithNullChunkingSettings() {
+    public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException {
         testChunkInfer_e5(null);
     }
 
-    public void testChunkInfer_E5ChunkingSettingsSetAndFeatureFlagEnabled() {
+    public void testChunkInfer_E5ChunkingSettingsSetAndFeatureFlagEnabled() throws InterruptedException {
         assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
         testChunkInfer_e5(ChunkingSettingsTests.createRandomChunkingSettings());
     }
 
     @SuppressWarnings("unchecked")
-    private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
+    private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws InterruptedException {
         var mlTrainedModelResults = new ArrayList<InferenceResults>();
         mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
         mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
@@ -994,6 +997,9 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
             gotResults.set(true);
         }, ESTestCase::fail);
 
+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
         service.chunkedInfer(
             model,
             null,
@@ -1002,23 +1008,24 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
             InputType.SEARCH,
             new ChunkingOptions(null, null),
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
+            latchedListener
         );
 
+        latch.await();
         assertTrue("Listener not called", gotResults.get());
     }
 
-    public void testChunkInfer_SparseWithNullChunkingSettings() {
+    public void testChunkInfer_SparseWithNullChunkingSettings() throws InterruptedException {
         testChunkInfer_Sparse(null);
     }
 
-    public void testChunkInfer_SparseWithChunkingSettingsSetAndFeatureFlagEnabled() {
+    public void testChunkInfer_SparseWithChunkingSettingsSetAndFeatureFlagEnabled() throws InterruptedException {
         assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
         testChunkInfer_Sparse(ChunkingSettingsTests.createRandomChunkingSettings());
     }
 
     @SuppressWarnings("unchecked")
-    private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
+    private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws InterruptedException {
         var mlTrainedModelResults = new ArrayList<InferenceResults>();
         mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
         mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
@@ -1042,6 +1049,7 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
         var service = createService(client);
 
         var gotResults = new AtomicBoolean();
+
         var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
             assertThat(chunkedResponse, hasSize(2));
             assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class));
@@ -1061,6 +1069,9 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
             gotResults.set(true);
         }, ESTestCase::fail);
 
+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
         service.chunkedInfer(
             model,
             null,
@@ -1069,23 +1080,24 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
             InputType.SEARCH,
             new ChunkingOptions(null, null),
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
+            latchedListener
         );
 
+        latch.await();
         assertTrue("Listener not called", gotResults.get());
     }
 
-    public void testChunkInfer_ElserWithNullChunkingSettings() {
+    public void testChunkInfer_ElserWithNullChunkingSettings() throws InterruptedException {
         testChunkInfer_Elser(null);
     }
 
-    public void testChunkInfer_ElserWithChunkingSettingsSetAndFeatureFlagEnabled() {
+    public void testChunkInfer_ElserWithChunkingSettingsSetAndFeatureFlagEnabled() throws InterruptedException {
         assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
         testChunkInfer_Elser(ChunkingSettingsTests.createRandomChunkingSettings());
     }
 
     @SuppressWarnings("unchecked")
-    private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
+    private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws InterruptedException {
         var mlTrainedModelResults = new ArrayList<InferenceResults>();
         mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
         mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
@@ -1129,6 +1141,9 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
             gotResults.set(true);
         }, ESTestCase::fail);
 
+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
         service.chunkedInfer(
             model,
             null,
@@ -1137,9 +1152,10 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
             InputType.SEARCH,
             new ChunkingOptions(null, null),
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
+            latchedListener
         );
 
+        latch.await();
         assertTrue("Listener not called", gotResults.get());
     }
 
@@ -1200,7 +1216,7 @@ public void testChunkInferSetsTokenization() {
     }
 
     @SuppressWarnings("unchecked")
-    public void testChunkInfer_FailsBatch() {
+    public void testChunkInfer_FailsBatch() throws InterruptedException {
         var mlTrainedModelResults = new ArrayList<InferenceResults>();
         mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
         mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
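The test edits above all apply one pattern: the results listener used to terminate the thread pool when it fired, but with batches now dispatched asynchronously on the inference executor the listener completes on a pool thread, so each test instead wraps its listener in a LatchedActionListener and blocks the test thread on latch.await(). A self-contained sketch of the idea, assuming a simple Callback interface in place of ActionListener:

    import java.util.concurrent.CountDownLatch;

    final class LatchedCallbacks {
        interface Callback<T> {
            void onResponse(T result);
            void onFailure(Exception e);
        }

        // Delegate first, then count down on success and failure alike, so
        // the waiting test thread wakes up no matter how the call ends.
        static <T> Callback<T> latched(Callback<T> delegate, CountDownLatch latch) {
            return new Callback<T>() {
                @Override
                public void onResponse(T result) {
                    try {
                        delegate.onResponse(result);
                    } finally {
                        latch.countDown();
                    }
                }

                @Override
                public void onFailure(Exception e) {
                    try {
                        delegate.onFailure(e);
                    } finally {
                        latch.countDown();
                    }
                }
            };
        }
    }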
@@ -1236,6 +1252,9 @@ public void testChunkInfer_FailsBatch() {
             gotResults.set(true);
         }, ESTestCase::fail);
 
+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
         service.chunkedInfer(
             model,
             null,
@@ -1244,9 +1263,93 @@ public void testChunkInfer_FailsBatch() {
             InputType.SEARCH,
             new ChunkingOptions(null, null),
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
+            latchedListener
+        );
+
+        latch.await();
+        assertTrue("Listener not called", gotResults.get());
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testChunkingLargeDocument() throws InterruptedException {
+        assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
+
+        int wordsPerChunk = 10;
+        int numBatches = randomIntBetween(3, 6);
+        int numChunks = randomIntBetween(
+            ((numBatches - 1) * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE) + 1,
+            numBatches * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE
+        );
+
+        // build a doc with enough words to make numChunks chunks
+        int numWords = numChunks * wordsPerChunk;
+        var docBuilder = new StringBuilder();
+        for (int i = 0; i < numWords; i++) {
+            docBuilder.append("word ");
+        }
+
+        // how many response objects to return in each batch; computed by
+        // subtraction rather than modulo so the final batch is never empty
+        int[] numResponsesPerBatch = new int[numBatches];
+        for (int i = 0; i < numBatches - 1; i++) {
+            numResponsesPerBatch[i] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
+        }
+        numResponsesPerBatch[numBatches - 1] = numChunks - ((numBatches - 1) * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE);
+
+        var batchIndex = new AtomicInteger();
+        Client client = mock(Client.class);
+        when(client.threadPool()).thenReturn(threadPool);
+
+        // mock the inference response
+        doAnswer(invocationOnMock -> {
+            var listener = (ActionListener<InferModelAction.Response>) invocationOnMock.getArguments()[2];
+
+            var mlTrainedModelResults = new ArrayList<InferenceResults>();
+            for (int i = 0; i < numResponsesPerBatch[batchIndex.get()]; i++) {
+                mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
+            }
+            batchIndex.incrementAndGet();
+            var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
+            listener.onResponse(response);
+            return null;
+        }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class));
+
+        var service = createService(client);
+
+        var gotResults = new AtomicBoolean();
+        var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
+            assertThat(chunkedResponse, hasSize(1));
+            assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
+            var embeddingResults = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0);
+            assertThat(embeddingResults.chunks(), hasSize(numChunks));
+
+            gotResults.set(true);
+        }, ESTestCase::fail);
+
+        // Create model using the word boundary chunker.
+        var model = new MultilingualE5SmallModel(
+            "foo",
+            TaskType.TEXT_EMBEDDING,
+            "e5",
+            new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null),
+            new WordBoundaryChunkingSettings(wordsPerChunk, 0)
+        );
+
+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
+        // For the given input we know how many requests will be made
+        service.chunkedInfer(
+            model,
+            null,
+            List.of(docBuilder.toString()),
+            Map.of(),
+            InputType.SEARCH,
+            new ChunkingOptions(null, null),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            latchedListener
         );
 
+        latch.await();
         assertTrue("Listener not called", gotResults.get());
     }
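For reference, the arithmetic the new test relies on: numChunks is drawn so the document needs exactly numBatches requests, wordsPerChunk of 10 with zero overlap turns numWords = numChunks * wordsPerChunk words into exactly numChunks chunks, and the final batch holds whatever remains after the full batches. A worked check, with an illustrative batch size of 10 standing in for EMBEDDING_MAX_BATCH_SIZE (whose real value is defined in ElasticsearchInternalService):

    final class BatchArithmetic {
        public static void main(String[] args) {
            int maxBatchSize = 10; // illustrative stand-in for EMBEDDING_MAX_BATCH_SIZE
            int numBatches = 4;    // the test draws randomIntBetween(3, 6)
            // numChunks lies in ((numBatches - 1) * max) + 1 .. numBatches * max,
            // so exactly numBatches requests are needed.
            int numChunks = 35;
            int fullBatches = (numChunks - 1) / maxBatchSize;           // 3 full batches
            int lastBatchSize = numChunks - fullBatches * maxBatchSize; // 5 in the last one
            // Subtraction, not numChunks % maxBatchSize: the modulo is 0 when
            // numChunks is an exact multiple of the batch size, which would make
            // the mocked client return an empty final batch and hang the latch.
            System.out.println(fullBatches + " full batches, then a batch of " + lastBatchSize);
        }
    }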