Skip to content

Commit

Permalink
[8.17][ML] Fix timeout ingesting an empty string into a semantic_text…
Browse files Browse the repository at this point in the history
… field (elastic#118746)

Backport of elastic#117840
# Conflicts:
#	x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
  • Loading branch information
davidkyle committed Dec 16, 2024
1 parent 187aecf commit c81f06f
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 5 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/117840.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117840
summary: Fix timeout ingesting an empty string into a `semantic_text` field
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
*
* @param input Text to chunk
* @param maxNumberWordsPerChunk Maximum size of the chunk
* @return The input text chunked
* @param includePrecedingSentence Include the previous sentence
* @return The input text split into chunks
*/
public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
var chunks = new ArrayList<String>();
Expand Down Expand Up @@ -154,6 +155,11 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean incl
chunks.add(input.substring(chunkStart));
}

if (chunks.isEmpty()) {
// The input did not chunk, return the entire input
chunks.add(input);
}

return chunks;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ List<ChunkPosition> chunkPositions(String input, int chunkSize, int overlap) {
throw new IllegalArgumentException("Invalid chunking parameters, overlap [" + overlap + "] must be >= 0");
}

if (input.isEmpty()) {
return List.of();
}

var chunkPositions = new ArrayList<ChunkPosition>();

// This position in the chunk is where the next overlapping chunk will start
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,78 @@
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.hamcrest.Matchers;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;

public class EmbeddingRequestChunkerTests extends ESTestCase {

/**
 * An empty list of inputs must not produce any request batches.
 */
public void testEmptyInput_WordChunker() {
    var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var chunker = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType);
    var batches = chunker.batchRequestsWithListeners(testListener());
    assertThat(batches, empty());
}

/**
 * An empty list of inputs must not produce any request batches when the chunker
 * is configured with {@link SentenceBoundaryChunkingSettings}.
 */
public void testEmptyInput_SentenceChunker() {
    var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var settings = new SentenceBoundaryChunkingSettings(250, 1);
    var chunker = new EmbeddingRequestChunker(List.of(), 10, embeddingType, settings);
    var batches = chunker.batchRequestsWithListeners(testListener());
    assertThat(batches, empty());
}

/**
 * A single whitespace-only input must come back as exactly one batch that
 * contains that one input unchanged.
 */
public void testWhitespaceInput_SentenceChunker() {
    var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var settings = new SentenceBoundaryChunkingSettings(250, 1);
    var chunker = new EmbeddingRequestChunker(List.of(" "), 10, embeddingType, settings);
    var batches = chunker.batchRequestsWithListeners(testListener());
    assertThat(batches, hasSize(1));

    var inputs = batches.get(0).batch().inputs();
    assertThat(inputs, hasSize(1));
    assertThat(inputs.get(0), Matchers.is(" "));
}

/**
 * A single empty-string input must still produce exactly one batch that
 * contains that one input unchanged.
 */
public void testBlankInput_WordChunker() {
    var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var chunker = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType);
    var batches = chunker.batchRequestsWithListeners(testListener());
    assertThat(batches, hasSize(1));

    var inputs = batches.get(0).batch().inputs();
    assertThat(inputs, hasSize(1));
    assertThat(inputs.get(0), Matchers.is(""));
}

/**
 * A single empty-string input must still produce exactly one batch that
 * contains that one input unchanged when {@link SentenceBoundaryChunkingSettings}
 * are used.
 */
public void testBlankInput_SentenceChunker() {
    var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var settings = new SentenceBoundaryChunkingSettings(250, 1);
    var chunker = new EmbeddingRequestChunker(List.of(""), 10, embeddingType, settings);
    var batches = chunker.batchRequestsWithListeners(testListener());
    assertThat(batches, hasSize(1));

    var inputs = batches.get(0).batch().inputs();
    assertThat(inputs, hasSize(1));
    assertThat(inputs.get(0), Matchers.is(""));
}

/**
 * The input {@code "ABBAABBA"} must be passed through as one batch holding the
 * whole input as its only entry.
 */
public void testInputThatDoesNotChunk_WordChunker() {
    var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var chunker = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10, embeddingType);
    var batches = chunker.batchRequestsWithListeners(testListener());
    assertThat(batches, hasSize(1));

    var inputs = batches.get(0).batch().inputs();
    assertThat(inputs, hasSize(1));
    assertThat(inputs.get(0), Matchers.is("ABBAABBA"));
}

/**
 * With {@link SentenceBoundaryChunkingSettings}, the input {@code "ABBAABBA"}
 * must be passed through as one batch holding the whole input as its only entry.
 */
public void testInputThatDoesNotChunk_SentenceChunker() {
    var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var settings = new SentenceBoundaryChunkingSettings(250, 1);
    var chunker = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, embeddingType, settings);
    var batches = chunker.batchRequestsWithListeners(testListener());
    assertThat(batches, hasSize(1));

    var inputs = batches.get(0).batch().inputs();
    assertThat(inputs, hasSize(1));
    assertThat(inputs.get(0), Matchers.is("ABBAABBA"));
}

public void testShortInputsAreSingleBatch() {
String input = "one chunk";
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

import static org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT;
Expand All @@ -27,6 +28,53 @@

public class SentenceBoundaryChunkerTests extends ESTestCase {

/**
 * Test helper that chunks {@code input} by delegating directly to
 * {@code chunker.chunk(input, maxNumberWordsPerChunk, includePrecedingSentence)}.
 *
 * @param chunker the chunker under test
 * @param input the text to chunk
 * @param maxNumberWordsPerChunk forwarded unchanged to the chunker
 * @param includePrecedingSentence forwarded unchanged to the chunker
 * @return the list of chunks returned by the chunker
 */
private List<String> textChunks(
SentenceBoundaryChunker chunker,
String input,
int maxNumberWordsPerChunk,
boolean includePrecedingSentence
) {
return chunker.chunk(input, maxNumberWordsPerChunk, includePrecedingSentence);
}

/**
 * Chunking the empty string must return exactly one chunk: the empty string itself.
 */
public void testEmptyString() {
    var emptyInput = "";
    var chunker = new SentenceBoundaryChunker();
    var chunks = textChunks(chunker, emptyInput, 100, randomBoolean());

    assertThat(chunks, hasSize(1));
    var onlyChunk = chunks.get(0);
    assertThat(onlyChunk, Matchers.is(""));
}

/**
 * Chunking a string that is a single space must return exactly one chunk
 * equal to that single space.
 */
public void testBlankString() {
    var blankInput = " ";
    var chunker = new SentenceBoundaryChunker();
    var chunks = textChunks(chunker, blankInput, 100, randomBoolean());

    assertThat(chunks, hasSize(1));
    var onlyChunk = chunks.get(0);
    assertThat(onlyChunk, Matchers.is(" "));
}

/**
 * Very short inputs (a single letter or punctuation mark, optionally padded
 * with spaces) must each be returned as a single chunk equal to the input.
 */
public void testSingleChar() {
    // Same inputs, in the same order, as the individually written cases they replace.
    for (String input : new String[] { " b", "b", ". ", " , ", " ," }) {
        var chunks = textChunks(new SentenceBoundaryChunker(), input, 100, randomBoolean());
        assertThat(chunks, Matchers.contains(input));
    }
}

/**
 * A 32,000 character run of the letter 'a' must be returned as a single chunk
 * equal to the whole input.
 */
public void testSingleCharRepeated() {
    var repeated = "a".repeat(32_000);
    var chunker = new SentenceBoundaryChunker();
    var chunks = textChunks(chunker, repeated, 100, randomBoolean());
    assertThat(chunks, Matchers.contains(repeated));
}

public void testChunkSplitLargeChunkSizes() {
for (int maxWordsPerChunk : new int[] { 100, 200 }) {
var chunker = new SentenceBoundaryChunker();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Matchers;

import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -226,6 +227,39 @@ public void testWhitespace() {
assertThat(chunks, contains(" "));
}

/**
 * Test helper that chunks {@code input} by delegating directly to
 * {@code chunker.chunk(input, chunkSize, overlap)}.
 *
 * @param chunker the chunker under test
 * @param input the text to chunk
 * @param chunkSize forwarded unchanged to the chunker
 * @param overlap forwarded unchanged to the chunker
 * @return the list of chunks returned by the chunker
 */
private List<String> textChunks(WordBoundaryChunker chunker, String input, int chunkSize, int overlap) {
return chunker.chunk(input, chunkSize, overlap);
}

/**
 * Chunking a string that is a single space must return exactly one chunk
 * equal to that single space.
 */
public void testBlankString() {
    var blankInput = " ";
    var chunker = new WordBoundaryChunker();
    var chunks = textChunks(chunker, blankInput, 100, 10);

    assertThat(chunks, hasSize(1));
    var onlyChunk = chunks.get(0);
    assertThat(onlyChunk, Matchers.is(" "));
}

/**
 * Very short inputs (a single letter or punctuation mark, optionally padded
 * with spaces) must each be returned as a single chunk equal to the input.
 */
public void testSingleChar() {
    // Same inputs, in the same order, as the individually written cases they replace.
    for (String input : new String[] { " b", "b", ". ", " , ", " ," }) {
        var chunks = textChunks(new WordBoundaryChunker(), input, 100, 10);
        assertThat(chunks, Matchers.contains(input));
    }
}

/**
 * A 32,000 character run of the letter 'a' must be returned as a single chunk
 * equal to the whole input.
 */
public void testSingleCharRepeated() {
    var repeated = "a".repeat(32_000);
    var chunker = new WordBoundaryChunker();
    var chunks = textChunks(chunker, repeated, 100, 10);
    assertThat(chunks, Matchers.contains(repeated));
}

public void testPunctuation() {
int chunkSize = 1;
var chunks = new WordBoundaryChunker().chunk("Comma, separated", chunkSize, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,22 @@ public void testDeploymentThreadsIncludedInUsage() throws IOException {
}
}

/**
 * Sends an {@code _infer} request with an empty {@code docs} array to a started
 * deployment and expects a response body with an empty
 * {@code inference_results} array.
 *
 * @throws IOException if a request to the cluster or reading the response entity fails
 */
public void testInferEmptyInput() throws IOException {
    final String modelId = "empty_input";

    // Set up and start a deployment to send the request to.
    createPassThroughModel(modelId);
    putModelDefinition(modelId);
    putVocabulary(List.of("these", "are", "my", "words"), modelId);
    startDeployment(modelId);

    var inferRequest = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer?timeout=30s");
    inferRequest.setJsonEntity("""
        { "docs": [] }
        """);

    String responseBody = EntityUtils.toString(client().performRequest(inferRequest).getEntity());
    assertThat(responseBody, equalTo("{\"inference_results\":[]}"));
}

/**
 * Convenience overload that stores the model definition for {@code modelId}
 * by calling the three-argument {@code putModelDefinition} with
 * {@code BASE_64_ENCODED_MODEL} and {@code RAW_MODEL_SIZE}.
 *
 * @param modelId id of the model the definition is stored for
 * @throws IOException propagated from the three-argument overload
 */
private void putModelDefinition(String modelId) throws IOException {
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
Response.Builder responseBuilder = Response.builder();
TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());

if (request.numberOfDocuments() == 0) {
listener.onResponse(responseBuilder.setId(request.getId()).build());
return;
}

if (MachineLearning.INFERENCE_AGG_FEATURE.check(licenseState)) {
responseBuilder.setLicensed(true);
doInfer(task, request, responseBuilder, parentTaskId, listener);
Expand Down

0 comments on commit c81f06f

Please sign in to comment.