From c81f06f765a53850c02bc511306dce182f755eaf Mon Sep 17 00:00:00 2001
From: David Kyle
Date: Mon, 16 Dec 2024 11:15:56 +0000
Subject: [PATCH] [8.17][ML] Fix timeout ingesting an empty string into a
 semantic_text field (#118746)

Backport of https://github.com/elastic/elasticsearch/pull/117840

# Conflicts:
#	x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
---
 docs/changelog/117840.yaml                     |  5 ++
 .../chunking/SentenceBoundaryChunker.java      |  8 ++-
 .../chunking/WordBoundaryChunker.java          |  4 --
 .../EmbeddingRequestChunkerTests.java          | 60 +++++++++++++++++++
 .../SentenceBoundaryChunkerTests.java          | 48 +++++++++++++++
 .../chunking/WordBoundaryChunkerTests.java     | 34 +++++++++++
 .../xpack/ml/integration/PyTorchModelIT.java   | 16 +++++
 .../TransportInternalInferModelAction.java     |  5 ++
 8 files changed, 175 insertions(+), 5 deletions(-)
 create mode 100644 docs/changelog/117840.yaml

diff --git a/docs/changelog/117840.yaml b/docs/changelog/117840.yaml
new file mode 100644
index 0000000000000..e1f469643af42
--- /dev/null
+++ b/docs/changelog/117840.yaml
@@ -0,0 +1,5 @@
+pr: 117840
+summary: Fix timeout ingesting an empty string into a `semantic_text` field
+area: Machine Learning
+type: bug
+issues: []
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java
index 5df940d6a3fba..f36dfb694a752 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java
@@ -63,7 +63,8 @@ public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
      *
      * @param input Text to chunk
      * @param maxNumberWordsPerChunk Maximum size of the chunk
-     * @return The input text chunked
+     * @param includePrecedingSentence Include the previous sentence
+     * @return The input text offsets
      */
     public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
         var chunks = new ArrayList<String>();
@@ -154,6 +155,11 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean incl
             chunks.add(input.substring(chunkStart));
         }
 
+        if (chunks.isEmpty()) {
+            // The input did not chunk, return the entire input
+            chunks.add(input);
+        }
+
         return chunks;
     }
 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java
index c9c752b9aabbc..f73044c0e02fc 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java
@@ -104,10 +104,6 @@ List<ChunkPosition> chunkPositions(String input, int chunkSize, int overlap) {
             throw new IllegalArgumentException("Invalid chunking parameters, overlap [" + overlap + "] must be >= 0");
         }
 
-        if (input.isEmpty()) {
-            return List.of();
-        }
-
         var chunkPositions = new ArrayList<ChunkPosition>();
 
         // This position in the chunk is where the next overlapping chunk will start
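
Note: the production fix above hinges on how java.text.BreakIterator treats an empty string: it reports no sentence boundaries, so the chunking loop never runs and the method used to return an empty list, leaving nothing to send to the model and no response to complete the request. The following is a minimal pure-JDK sketch of that failure mode and of the fallback, not the chunker's actual implementation; naiveSentenceChunks is illustrative only.

import java.text.BreakIterator;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

public class EmptyInputSketch {
    static List<String> naiveSentenceChunks(String input) {
        BreakIterator sentences = BreakIterator.getSentenceInstance(Locale.ROOT);
        sentences.setText(input);
        List<String> chunks = new ArrayList<>();
        int start = sentences.first();
        // For "" the iterator reports no boundary after the start, so this loop never runs
        for (int end = sentences.next(); end != BreakIterator.DONE; start = end, end = sentences.next()) {
            chunks.add(input.substring(start, end));
        }
        if (chunks.isEmpty()) {
            chunks.add(input); // the fallback added by this patch: never return zero chunks
        }
        return chunks;
    }

    public static void main(String[] args) {
        System.out.println(naiveSentenceChunks(""));          // [""] with the fallback, [] without
        System.out.println(naiveSentenceChunks("One. Two.")); // [One. , Two.]
    }
}
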
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
index c1be537a6b0a7..760882fa2b39d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
@@ -18,18 +18,78 @@
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
+import org.hamcrest.Matchers;
 
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.startsWith;
 
 public class EmbeddingRequestChunkerTests extends ESTestCase {
 
+    public void testEmptyInput_WordChunker() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, empty());
+    }
+
+    public void testEmptyInput_SentenceChunker() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
+            .batchRequestsWithListeners(testListener());
+        assertThat(batches, empty());
+    }
+
+    public void testWhitespaceInput_SentenceChunker() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(" "), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
+            .batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+        assertThat(batches.get(0).batch().inputs(), hasSize(1));
+        assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(" "));
+    }
+
+    public void testBlankInput_WordChunker() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+        assertThat(batches.get(0).batch().inputs(), hasSize(1));
+        assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
+    }
+
+    public void testBlankInput_SentenceChunker() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(""), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
+            .batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+        assertThat(batches.get(0).batch().inputs(), hasSize(1));
+        assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
+    }
+
+    public void testInputThatDoesNotChunk_WordChunker() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10, embeddingType).batchRequestsWithListeners(
+            testListener()
+        );
+        assertThat(batches, hasSize(1));
+        assertThat(batches.get(0).batch().inputs(), hasSize(1));
+        assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
+    }
+
+    public void testInputThatDoesNotChunk_SentenceChunker() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
+            .batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+        assertThat(batches.get(0).batch().inputs(), hasSize(1));
+        assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
+    }
+
     public void testShortInputsAreSingleBatch() {
         String input = "one chunk";
         var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
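
Note: the two testEmptyInput_* cases pin down a related contract: an empty input list produces no batches at all, so no per-batch callback will ever fire, and the caller must complete the request itself. The last file in this patch adds exactly that guard to TransportInternalInferModelAction.doExecute(). A generic plain-Java analogue of the trap, for illustration only (this is not the Elasticsearch code path):

import java.util.List;
import java.util.concurrent.CompletableFuture;

class ShortCircuitSketch {
    // If the future is only ever completed from a per-batch callback, an empty
    // batch list leaves callers waiting forever - the shape of the original bug.
    static CompletableFuture<List<float[]>> embedAll(List<String> batches) {
        var result = new CompletableFuture<List<float[]>>();
        if (batches.isEmpty()) {
            result.complete(List.of()); // complete up front, as doExecute() now does for zero docs
            return result;
        }
        // ... dispatch batches; the last batch's callback completes `result` ...
        return result;
    }
}
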
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java
index afce8c57e0350..5ec5ee4343c71 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java
@@ -15,6 +15,7 @@
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Locale;
 
 import static org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT;
@@ -27,6 +28,53 @@
 
 public class SentenceBoundaryChunkerTests extends ESTestCase {
 
+    /**
+     * Utility method for testing.
+     */
+    private List<String> textChunks(
+        SentenceBoundaryChunker chunker,
+        String input,
+        int maxNumberWordsPerChunk,
+        boolean includePrecedingSentence
+    ) {
+        return chunker.chunk(input, maxNumberWordsPerChunk, includePrecedingSentence);
+    }
+
+    public void testEmptyString() {
+        var chunks = textChunks(new SentenceBoundaryChunker(), "", 100, randomBoolean());
+        assertThat(chunks, hasSize(1));
+        assertThat(chunks.get(0), Matchers.is(""));
+    }
+
+    public void testBlankString() {
+        var chunks = textChunks(new SentenceBoundaryChunker(), " ", 100, randomBoolean());
+        assertThat(chunks, hasSize(1));
+        assertThat(chunks.get(0), Matchers.is(" "));
+    }
+
+    public void testSingleChar() {
+        var chunks = textChunks(new SentenceBoundaryChunker(), " b", 100, randomBoolean());
+        assertThat(chunks, Matchers.contains(" b"));
+
+        chunks = textChunks(new SentenceBoundaryChunker(), "b", 100, randomBoolean());
+        assertThat(chunks, Matchers.contains("b"));
+
+        chunks = textChunks(new SentenceBoundaryChunker(), ". ", 100, randomBoolean());
+        assertThat(chunks, Matchers.contains(". "));
+
+        chunks = textChunks(new SentenceBoundaryChunker(), " , ", 100, randomBoolean());
+        assertThat(chunks, Matchers.contains(" , "));
+
+        chunks = textChunks(new SentenceBoundaryChunker(), " ,", 100, randomBoolean());
+        assertThat(chunks, Matchers.contains(" ,"));
+    }
+
+    public void testSingleCharRepeated() {
+        var input = "a".repeat(32_000);
+        var chunks = textChunks(new SentenceBoundaryChunker(), input, 100, randomBoolean());
+        assertThat(chunks, Matchers.contains(input));
+    }
+
     public void testChunkSplitLargeChunkSizes() {
         for (int maxWordsPerChunk : new int[] { 100, 200 }) {
             var chunker = new SentenceBoundaryChunker();
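
Note: these tests all assert that degenerate inputs come back as exactly one chunk equal to the input, whether the input is empty, blank, punctuation-only, or a single 32,000-character "word" that can never satisfy the 100-word budget. One way to read them is as an invariant, my paraphrase rather than anything stated in the patch: with includePrecedingSentence disabled the chunks partition the input, so concatenating them must reproduce it. A hypothetical helper that checks that reading:

import java.util.List;

final class ChunkInvariant {
    // Hypothetical check, assuming non-overlapping chunks as in these tests
    static void check(String input, List<String> chunks) {
        if (chunks.isEmpty()) {
            throw new AssertionError("zero chunks would leave the inference request hanging");
        }
        if (!String.join("", chunks).equals(input)) {
            throw new AssertionError("chunks must partition the input exactly");
        }
    }
}
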
")); + + chunks = textChunks(new SentenceBoundaryChunker(), " , ", 100, randomBoolean()); + assertThat(chunks, Matchers.contains(" , ")); + + chunks = textChunks(new SentenceBoundaryChunker(), " ,", 100, randomBoolean()); + assertThat(chunks, Matchers.contains(" ,")); + } + + public void testSingleCharRepeated() { + var input = "a".repeat(32_000); + var chunks = textChunks(new SentenceBoundaryChunker(), input, 100, randomBoolean()); + assertThat(chunks, Matchers.contains(input)); + } + public void testChunkSplitLargeChunkSizes() { for (int maxWordsPerChunk : new int[] { 100, 200 }) { var chunker = new SentenceBoundaryChunker(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java index ef643a4b36fdc..93f951bbfa983 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; import java.util.List; import java.util.Locale; @@ -226,6 +227,39 @@ public void testWhitespace() { assertThat(chunks, contains(" ")); } + private List textChunks(WordBoundaryChunker chunker, String input, int chunkSize, int overlap) { + return chunker.chunk(input, chunkSize, overlap); + } + + public void testBlankString() { + var chunks = textChunks(new WordBoundaryChunker(), " ", 100, 10); + assertThat(chunks, hasSize(1)); + assertThat(chunks.get(0), Matchers.is(" ")); + } + + public void testSingleChar() { + var chunks = textChunks(new WordBoundaryChunker(), " b", 100, 10); + assertThat(chunks, Matchers.contains(" b")); + + chunks = textChunks(new WordBoundaryChunker(), "b", 100, 10); + assertThat(chunks, Matchers.contains("b")); + + chunks = textChunks(new WordBoundaryChunker(), ". ", 100, 10); + assertThat(chunks, Matchers.contains(". 
")); + + chunks = textChunks(new WordBoundaryChunker(), " , ", 100, 10); + assertThat(chunks, Matchers.contains(" , ")); + + chunks = textChunks(new WordBoundaryChunker(), " ,", 100, 10); + assertThat(chunks, Matchers.contains(" ,")); + } + + public void testSingleCharRepeated() { + var input = "a".repeat(32_000); + var chunks = textChunks(new WordBoundaryChunker(), input, 100, 10); + assertThat(chunks, Matchers.contains(input)); + } + public void testPunctuation() { int chunkSize = 1; var chunks = new WordBoundaryChunker().chunk("Comma, separated", chunkSize, 0); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java index 4e92cad1026a3..04f349d67d7fe 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java @@ -1142,6 +1142,22 @@ public void testDeploymentThreadsIncludedInUsage() throws IOException { } } + public void testInferEmptyInput() throws IOException { + String modelId = "empty_input"; + createPassThroughModel(modelId); + putModelDefinition(modelId); + putVocabulary(List.of("these", "are", "my", "words"), modelId); + startDeployment(modelId); + + Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer?timeout=30s"); + request.setJsonEntity(""" + { "docs": [] } + """); + + var inferenceResponse = client().performRequest(request); + assertThat(EntityUtils.toString(inferenceResponse.getEntity()), equalTo("{\"inference_results\":[]}")); + } + private void putModelDefinition(String modelId) throws IOException { putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index e0405b1749536..20a4ceeae59b3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -132,6 +132,11 @@ protected void doExecute(Task task, Request request, ActionListener li Response.Builder responseBuilder = Response.builder(); TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId()); + if (request.numberOfDocuments() == 0) { + listener.onResponse(responseBuilder.setId(request.getId()).build()); + return; + } + if (MachineLearning.INFERENCE_AGG_FEATURE.check(licenseState)) { responseBuilder.setLicensed(true); doInfer(task, request, responseBuilder, parentTaskId, listener);