From 88a724a29321fae468db917406757b5dbde4f8d9 Mon Sep 17 00:00:00 2001
From: David Kyle
Date: Mon, 9 Dec 2024 16:36:12 +0000
Subject: [PATCH] [ML] Refactor the Chunker classes to return offsets (#117977) (#118279)

---
 .../xpack/inference/chunking/Chunker.java     |  4 +-
 .../chunking/EmbeddingRequestChunker.java     | 55 ++++++++++++-------
 .../chunking/SentenceBoundaryChunker.java     | 20 ++++---
 .../chunking/WordBoundaryChunker.java         | 22 +++-----
 .../EmbeddingRequestChunkerTests.java         | 24 ++++----
 .../SentenceBoundaryChunkerTests.java         | 34 +++++++++---
 .../chunking/WordBoundaryChunkerTests.java    | 36 ++++++++----
 7 files changed, 119 insertions(+), 76 deletions(-)

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java
index af7c706c807ec..b8908ee139c29 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java
@@ -12,5 +12,7 @@
 import java.util.List;
 
 public interface Chunker {
-    List<String> chunk(String input, ChunkingSettings chunkingSettings);
+    record ChunkOffset(int start, int end) {};
+
+    List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings);
 }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java
index c5897f32d6eb8..2aef54e56f4b9 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java
@@ -68,7 +68,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El
     private final EmbeddingType embeddingType;
     private final ChunkingSettings chunkingSettings;
 
-    private List<List<String>> chunkedInputs;
+    private List<ChunkOffsetsAndInput> chunkedOffsets;
     private List>> floatResults;
    private List>> byteResults;
    private List>> sparseResults;
@@ -109,7 +109,7 @@ public EmbeddingRequestChunker(
     }
 
     private void splitIntoBatchedRequests(List<String> inputs) {
-        Function<String, List<String>> chunkFunction;
+        Function<String, List<ChunkOffset>> chunkFunction;
         if (chunkingSettings != null) {
             var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
             chunkFunction = input -> chunker.chunk(input, chunkingSettings);
@@ -118,7 +118,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
             chunkFunction = input -> chunker.chunk(input, wordsPerChunk, chunkOverlap);
         }
 
-        chunkedInputs = new ArrayList<>(inputs.size());
+        chunkedOffsets = new ArrayList<>(inputs.size());
         switch (embeddingType) {
             case FLOAT -> floatResults = new ArrayList<>(inputs.size());
             case BYTE -> byteResults = new ArrayList<>(inputs.size());
@@ -128,18 +128,19 @@
 
         for (int i = 0; i < inputs.size(); i++) {
             var chunks = chunkFunction.apply(inputs.get(i));
-            int numberOfSubBatches = addToBatches(chunks, i);
+            var offSetsAndInput = new ChunkOffsetsAndInput(chunks, inputs.get(i));
+            int numberOfSubBatches = addToBatches(offSetsAndInput, i);
             // size the results array with the expected number of request/responses
             switch (embeddingType) {
                 case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches));
                 case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches));
                 case SPARSE ->
sparseResults.add(new AtomicArray<>(numberOfSubBatches)); } - chunkedInputs.add(chunks); + chunkedOffsets.add(offSetsAndInput); } } - private int addToBatches(List chunks, int inputIndex) { + private int addToBatches(ChunkOffsetsAndInput chunk, int inputIndex) { BatchRequest lastBatch; if (batchedRequests.isEmpty()) { lastBatch = new BatchRequest(new ArrayList<>()); @@ -157,16 +158,24 @@ private int addToBatches(List chunks, int inputIndex) { if (freeSpace > 0) { // use any free space in the previous batch before creating new batches - int toAdd = Math.min(freeSpace, chunks.size()); - lastBatch.addSubBatch(new SubBatch(chunks.subList(0, toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd))); + int toAdd = Math.min(freeSpace, chunk.offsets().size()); + lastBatch.addSubBatch( + new SubBatch( + new ChunkOffsetsAndInput(chunk.offsets().subList(0, toAdd), chunk.input()), + new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd) + ) + ); } int start = freeSpace; - while (start < chunks.size()) { - int toAdd = Math.min(maxNumberOfInputsPerBatch, chunks.size() - start); + while (start < chunk.offsets().size()) { + int toAdd = Math.min(maxNumberOfInputsPerBatch, chunk.offsets().size() - start); var batch = new BatchRequest(new ArrayList<>()); batch.addSubBatch( - new SubBatch(chunks.subList(start, start + toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)) + new SubBatch( + new ChunkOffsetsAndInput(chunk.offsets().subList(start, start + toAdd), chunk.input()), + new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd) + ) ); batchedRequests.add(batch); start += toAdd; @@ -333,8 +342,8 @@ public void onFailure(Exception e) { } private void sendResponse() { - var response = new ArrayList(chunkedInputs.size()); - for (int i = 0; i < chunkedInputs.size(); i++) { + var response = new ArrayList(chunkedOffsets.size()); + for (int i = 0; i < chunkedOffsets.size(); i++) { if (errors.get(i) != null) { response.add(errors.get(i)); } else { @@ -348,9 +357,9 @@ private void sendResponse() { private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) { return switch (embeddingType) { - case FLOAT -> mergeFloatResultsWithInputs(chunkedInputs.get(resultIndex), floatResults.get(resultIndex)); - case BYTE -> mergeByteResultsWithInputs(chunkedInputs.get(resultIndex), byteResults.get(resultIndex)); - case SPARSE -> mergeSparseResultsWithInputs(chunkedInputs.get(resultIndex), sparseResults.get(resultIndex)); + case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), floatResults.get(resultIndex)); + case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), byteResults.get(resultIndex)); + case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), sparseResults.get(resultIndex)); }; } @@ -428,7 +437,7 @@ public void addSubBatch(SubBatch sb) { } public List inputs() { - return subBatches.stream().flatMap(s -> s.requests().stream()).collect(Collectors.toList()); + return subBatches.stream().flatMap(s -> s.requests().toChunkText().stream()).collect(Collectors.toList()); } } @@ -441,9 +450,15 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener requests, SubBatchPositionsAndCount positions) { - public int size() { - return requests.size(); + record SubBatch(ChunkOffsetsAndInput requests, SubBatchPositionsAndCount positions) { + int size() { + return requests.offsets().size(); + } + } + + record ChunkOffsetsAndInput(List offsets, String 
input) { + List toChunkText() { + return offsets.stream().map(o -> input.substring(o.start(), o.end())).collect(Collectors.toList()); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java index 5df940d6a3fba..b2d6c83b89211 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java @@ -34,7 +34,6 @@ public class SentenceBoundaryChunker implements Chunker { public SentenceBoundaryChunker() { sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT); wordIterator = BreakIterator.getWordInstance(Locale.ROOT); - } /** @@ -45,7 +44,7 @@ public SentenceBoundaryChunker() { * @return The input text chunked */ @Override - public List chunk(String input, ChunkingSettings chunkingSettings) { + public List chunk(String input, ChunkingSettings chunkingSettings) { if (chunkingSettings instanceof SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings) { return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize, sentenceBoundaryChunkingSettings.sentenceOverlap > 0); } else { @@ -65,8 +64,8 @@ public List chunk(String input, ChunkingSettings chunkingSettings) { * @param maxNumberWordsPerChunk Maximum size of the chunk * @return The input text chunked */ - public List chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) { - var chunks = new ArrayList(); + public List chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) { + var chunks = new ArrayList(); sentenceIterator.setText(input); wordIterator.setText(input); @@ -91,7 +90,7 @@ public List chunk(String input, int maxNumberWordsPerChunk, boolean incl int nextChunkWordCount = wordsInSentenceCount; if (chunkWordCount > 0) { // add a new chunk containing all the input up to this sentence - chunks.add(input.substring(chunkStart, chunkEnd)); + chunks.add(new ChunkOffset(chunkStart, chunkEnd)); if (includePrecedingSentence) { if (wordsInPrecedingSentenceCount + wordsInSentenceCount > maxNumberWordsPerChunk) { @@ -127,12 +126,17 @@ public List chunk(String input, int maxNumberWordsPerChunk, boolean incl for (; i < sentenceSplits.size() - 1; i++) { // Because the substring was passed to splitLongSentence() // the returned positions need to be offset by chunkStart - chunks.add(input.substring(chunkStart + sentenceSplits.get(i).start(), chunkStart + sentenceSplits.get(i).end())); + chunks.add( + new ChunkOffset( + chunkStart + sentenceSplits.get(i).offsets().start(), + chunkStart + sentenceSplits.get(i).offsets().end() + ) + ); } // The final split is partially filled. // Set the next chunk start to the beginning of the // final split of the long sentence. 
- chunkStart = chunkStart + sentenceSplits.get(i).start(); // start pos needs to be offset by chunkStart + chunkStart = chunkStart + sentenceSplits.get(i).offsets().start(); // start pos needs to be offset by chunkStart chunkWordCount = sentenceSplits.get(i).wordCount(); } } else { @@ -151,7 +155,7 @@ public List chunk(String input, int maxNumberWordsPerChunk, boolean incl } if (chunkWordCount > 0) { - chunks.add(input.substring(chunkStart)); + chunks.add(new ChunkOffset(chunkStart, input.length())); } return chunks; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java index c9c752b9aabbc..b15e2134f4cf7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java @@ -15,6 +15,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.stream.Collectors; /** * Breaks text into smaller strings or chunks on Word boundaries. @@ -35,7 +36,7 @@ public WordBoundaryChunker() { wordIterator = BreakIterator.getWordInstance(Locale.ROOT); } - record ChunkPosition(int start, int end, int wordCount) {} + record ChunkPosition(ChunkOffset offsets, int wordCount) {} /** * Break the input text into small chunks as dictated @@ -45,7 +46,7 @@ record ChunkPosition(int start, int end, int wordCount) {} * @return List of chunked text */ @Override - public List chunk(String input, ChunkingSettings chunkingSettings) { + public List chunk(String input, ChunkingSettings chunkingSettings) { if (chunkingSettings instanceof WordBoundaryChunkingSettings wordBoundaryChunkerSettings) { return chunk(input, wordBoundaryChunkerSettings.maxChunkSize, wordBoundaryChunkerSettings.overlap); } else { @@ -64,18 +65,9 @@ public List chunk(String input, ChunkingSettings chunkingSettings) { * Can be 0 but must be non-negative. 
* @return List of chunked text */ - public List chunk(String input, int chunkSize, int overlap) { - - if (input.isEmpty()) { - return List.of(""); - } - + public List chunk(String input, int chunkSize, int overlap) { var chunkPositions = chunkPositions(input, chunkSize, overlap); - var chunks = new ArrayList(chunkPositions.size()); - for (var pos : chunkPositions) { - chunks.add(input.substring(pos.start, pos.end)); - } - return chunks; + return chunkPositions.stream().map(ChunkPosition::offsets).collect(Collectors.toList()); } /** @@ -127,7 +119,7 @@ List chunkPositions(String input, int chunkSize, int overlap) { wordsSinceStartWindowWasMarked++; if (wordsInChunkCountIncludingOverlap >= chunkSize) { - chunkPositions.add(new ChunkPosition(windowStart, boundary, wordsInChunkCountIncludingOverlap)); + chunkPositions.add(new ChunkPosition(new ChunkOffset(windowStart, boundary), wordsInChunkCountIncludingOverlap)); wordsInChunkCountIncludingOverlap = overlap; if (overlap == 0) { @@ -149,7 +141,7 @@ List chunkPositions(String input, int chunkSize, int overlap) { // if it ends on a boundary than the count should equal overlap in which case // we can ignore it, unless this is the first chunk in which case we want to add it if (wordsInChunkCountIncludingOverlap > overlap || chunkPositions.isEmpty()) { - chunkPositions.add(new ChunkPosition(windowStart, input.length(), wordsInChunkCountIncludingOverlap)); + chunkPositions.add(new ChunkPosition(new ChunkOffset(windowStart, input.length()), wordsInChunkCountIncludingOverlap)); } return chunkPositions; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 4fdf254101d3e..a82d2f474ca4a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -62,7 +62,7 @@ public void testMultipleShortInputsAreSingleBatch() { var subBatches = batches.get(0).batch().subBatches(); for (int i = 0; i < inputs.size(); i++) { var subBatch = subBatches.get(i); - assertThat(subBatch.requests(), contains(inputs.get(i))); + assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i))); assertEquals(0, subBatch.positions().chunkIndex()); assertEquals(i, subBatch.positions().inputIndex()); assertEquals(1, subBatch.positions().embeddingCount()); @@ -102,7 +102,7 @@ public void testManyInputsMakeManyBatches() { var subBatches = batches.get(0).batch().subBatches(); for (int i = 0; i < batches.size(); i++) { var subBatch = subBatches.get(i); - assertThat(subBatch.requests(), contains(inputs.get(i))); + assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i))); assertEquals(0, subBatch.positions().chunkIndex()); assertEquals(inputIndex, subBatch.positions().inputIndex()); assertEquals(1, subBatch.positions().embeddingCount()); @@ -146,7 +146,7 @@ public void testChunkingSettingsProvided() { var subBatches = batches.get(0).batch().subBatches(); for (int i = 0; i < batches.size(); i++) { var subBatch = subBatches.get(i); - assertThat(subBatch.requests(), contains(inputs.get(i))); + assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i))); assertEquals(0, subBatch.positions().chunkIndex()); assertEquals(inputIndex, subBatch.positions().inputIndex()); assertEquals(1, 
subBatch.positions().embeddingCount()); @@ -184,17 +184,17 @@ public void testLongInputChunkedOverMultipleBatches() { assertEquals(0, subBatch.positions().inputIndex()); assertEquals(0, subBatch.positions().chunkIndex()); assertEquals(1, subBatch.positions().embeddingCount()); - assertThat(subBatch.requests(), contains("1st small")); + assertThat(subBatch.requests().toChunkText(), contains("1st small")); } { var subBatch = batch.subBatches().get(1); assertEquals(1, subBatch.positions().inputIndex()); // 2nd input assertEquals(0, subBatch.positions().chunkIndex()); // 1st part of the 2nd input assertEquals(4, subBatch.positions().embeddingCount()); // 4 chunks - assertThat(subBatch.requests().get(0), startsWith("passage_input0 ")); - assertThat(subBatch.requests().get(1), startsWith(" passage_input20 ")); - assertThat(subBatch.requests().get(2), startsWith(" passage_input40 ")); - assertThat(subBatch.requests().get(3), startsWith(" passage_input60 ")); + assertThat(subBatch.requests().toChunkText().get(0), startsWith("passage_input0 ")); + assertThat(subBatch.requests().toChunkText().get(1), startsWith(" passage_input20 ")); + assertThat(subBatch.requests().toChunkText().get(2), startsWith(" passage_input40 ")); + assertThat(subBatch.requests().toChunkText().get(3), startsWith(" passage_input60 ")); } } { @@ -207,22 +207,22 @@ public void testLongInputChunkedOverMultipleBatches() { assertEquals(1, subBatch.positions().inputIndex()); // 2nd input assertEquals(1, subBatch.positions().chunkIndex()); // 2nd part of the 2nd input assertEquals(2, subBatch.positions().embeddingCount()); - assertThat(subBatch.requests().get(0), startsWith(" passage_input80 ")); - assertThat(subBatch.requests().get(1), startsWith(" passage_input100 ")); + assertThat(subBatch.requests().toChunkText().get(0), startsWith(" passage_input80 ")); + assertThat(subBatch.requests().toChunkText().get(1), startsWith(" passage_input100 ")); } { var subBatch = batch.subBatches().get(1); assertEquals(2, subBatch.positions().inputIndex()); // 3rd input assertEquals(0, subBatch.positions().chunkIndex()); // 1st and only part assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk - assertThat(subBatch.requests(), contains("2nd small")); + assertThat(subBatch.requests().toChunkText(), contains("2nd small")); } { var subBatch = batch.subBatches().get(2); assertEquals(3, subBatch.positions().inputIndex()); // 4th input assertEquals(0, subBatch.positions().chunkIndex()); // 1st and only part assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk - assertThat(subBatch.requests(), contains("3rd small")); + assertThat(subBatch.requests().toChunkText(), contains("3rd small")); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java index afce8c57e0350..de943f7f57ab8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java @@ -15,7 +15,9 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Locale; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT; import static org.hamcrest.Matchers.containsString; @@ -27,10 +29,24 @@ 
public class SentenceBoundaryChunkerTests extends ESTestCase { + /** + * Utility method for testing. + * Use the chunk functions that return offsets where possible + */ + private List textChunks( + SentenceBoundaryChunker chunker, + String input, + int maxNumberWordsPerChunk, + boolean includePrecedingSentence + ) { + var chunkPositions = chunker.chunk(input, maxNumberWordsPerChunk, includePrecedingSentence); + return chunkPositions.stream().map(offset -> input.substring(offset.start(), offset.end())).collect(Collectors.toList()); + } + public void testChunkSplitLargeChunkSizes() { for (int maxWordsPerChunk : new int[] { 100, 200 }) { var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false); int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(numChunks)); @@ -48,7 +64,7 @@ public void testChunkSplitLargeChunkSizes_withOverlap() { boolean overlap = true; for (int maxWordsPerChunk : new int[] { 70, 80, 100, 120, 150, 200 }) { var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, overlap); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, overlap); int[] overlaps = chunkOverlaps(sentenceSizes(TEST_TEXT), maxWordsPerChunk, overlap); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(overlaps.length)); @@ -107,7 +123,7 @@ public void testWithOverlap_SentencesFitInChunks() { } var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(sb.toString(), chunkSize, true); + var chunks = textChunks(chunker, sb.toString(), chunkSize, true); assertThat(chunks, hasSize(numChunks)); for (int i = 0; i < numChunks; i++) { assertThat("num sentences " + numSentences, chunks.get(i), startsWith("SStart" + sentenceStartIndexes[i])); @@ -128,10 +144,10 @@ private String makeSentence(int numWords, int sentenceIndex) { public void testChunk_ChunkSizeLargerThanText() { int maxWordsPerChunk = 500; var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false); assertEquals(chunks.get(0), TEST_TEXT); - chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, true); + chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, true); assertEquals(chunks.get(0), TEST_TEXT); } @@ -142,7 +158,7 @@ public void testChunkSplit_SentencesLongerThanChunkSize() { for (int i = 0; i < chunkSizes.length; i++) { int maxWordsPerChunk = chunkSizes[i]; var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(expectedNumberOFChunks[i])); for (var chunk : chunks) { @@ -171,7 +187,7 @@ public void testChunkSplit_SentencesLongerThanChunkSize_WithOverlap() { for (int i = 0; i < chunkSizes.length; i++) { int maxWordsPerChunk = chunkSizes[i]; var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, true); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, true); assertThat(chunks.get(0), containsString("Word segmentation is the problem of dividing")); assertThat(chunks.get(chunks.size() - 1), containsString(", with solidification being a stronger norm.")); } @@ -190,7 +206,7 
@@ public void testShortLongShortSentences_WithOverlap() { } var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(sb.toString(), maxWordsPerChunk, true); + var chunks = textChunks(chunker, sb.toString(), maxWordsPerChunk, true); assertThat(chunks, hasSize(5)); assertTrue(chunks.get(0).trim().startsWith("SStart0")); // Entire sentence assertTrue(chunks.get(0).trim().endsWith(".")); // Entire sentence @@ -303,7 +319,7 @@ public void testChunkSplitLargeChunkSizesWithChunkingSettings() { for (int maxWordsPerChunk : new int[] { 100, 200 }) { var chunker = new SentenceBoundaryChunker(); SentenceBoundaryChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(maxWordsPerChunk, 0); - var chunks = chunker.chunk(TEST_TEXT, chunkingSettings); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false); int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(numChunks)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java index ef643a4b36fdc..2ef28f2cf2e77 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java @@ -14,6 +14,7 @@ import java.util.List; import java.util.Locale; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; @@ -65,9 +66,22 @@ public class WordBoundaryChunkerTests extends ESTestCase { NUM_WORDS_IN_TEST_TEXT = wordCount; } + /** + * Utility method for testing. 
+ * Use the chunk functions that return offsets where possible + */ + List textChunks(WordBoundaryChunker chunker, String input, int chunkSize, int overlap) { + if (input.isEmpty()) { + return List.of(""); + } + + var chunkPositions = chunker.chunk(input, chunkSize, overlap); + return chunkPositions.stream().map(p -> input.substring(p.start(), p.end())).collect(Collectors.toList()); + } + public void testSingleSplit() { var chunker = new WordBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, 10_000, 0); + var chunks = textChunks(chunker, TEST_TEXT, 10_000, 0); assertThat(chunks, hasSize(1)); assertEquals(TEST_TEXT, chunks.get(0)); } @@ -168,11 +182,11 @@ public void testWindowSpanningWords() { } var whiteSpacedText = input.toString().stripTrailing(); - var chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 20, 10); + var chunks = textChunks(new WordBoundaryChunker(), whiteSpacedText, 20, 10); assertChunkContents(chunks, numWords, 20, 10); - chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 10, 4); + chunks = textChunks(new WordBoundaryChunker(), whiteSpacedText, 10, 4); assertChunkContents(chunks, numWords, 10, 4); - chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 15, 3); + chunks = textChunks(new WordBoundaryChunker(), whiteSpacedText, 15, 3); assertChunkContents(chunks, numWords, 15, 3); } @@ -217,28 +231,28 @@ public void testWindowSpanning_TextShorterThanWindow() { } public void testEmptyString() { - var chunks = new WordBoundaryChunker().chunk("", 10, 5); - assertThat(chunks, contains("")); + var chunks = textChunks(new WordBoundaryChunker(), "", 10, 5); + assertThat(chunks.toString(), chunks, contains("")); } public void testWhitespace() { - var chunks = new WordBoundaryChunker().chunk(" ", 10, 5); + var chunks = textChunks(new WordBoundaryChunker(), " ", 10, 5); assertThat(chunks, contains(" ")); } public void testPunctuation() { int chunkSize = 1; - var chunks = new WordBoundaryChunker().chunk("Comma, separated", chunkSize, 0); + var chunks = textChunks(new WordBoundaryChunker(), "Comma, separated", chunkSize, 0); assertThat(chunks, contains("Comma", ", separated")); - chunks = new WordBoundaryChunker().chunk("Mme. Thénardier", chunkSize, 0); + chunks = textChunks(new WordBoundaryChunker(), "Mme. Thénardier", chunkSize, 0); assertThat(chunks, contains("Mme", ". Thénardier")); - chunks = new WordBoundaryChunker().chunk("Won't you chunk", chunkSize, 0); + chunks = textChunks(new WordBoundaryChunker(), "Won't you chunk", chunkSize, 0); assertThat(chunks, contains("Won't", " you", " chunk")); chunkSize = 10; - chunks = new WordBoundaryChunker().chunk("Won't you chunk", chunkSize, 0); + chunks = textChunks(new WordBoundaryChunker(), "Won't you chunk", chunkSize, 0); assertThat(chunks, contains("Won't you chunk")); }
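Note (illustrative only, not part of the patch): after this refactor the chunkers return Chunker.ChunkOffset records instead of substrings, and callers recover the chunk text from the original input when they need it, as EmbeddingRequestChunker.ChunkOffsetsAndInput#toChunkText and the test helpers above do. Below is a minimal sketch of that calling pattern; WordBoundaryChunker, Chunker.ChunkOffset and the chunk(String, int, int) signature come from the patch, while the example class, input text and parameter values are assumptions made up for illustration.

import org.elasticsearch.xpack.inference.chunking.Chunker;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunker;

import java.util.List;
import java.util.stream.Collectors;

class ChunkOffsetUsageSketch {
    public static void main(String[] args) {
        String input = "Some text that will be split on word boundaries into overlapping chunks.";

        // 10 words per chunk, with an overlap of 2 words between consecutive chunks
        var chunker = new WordBoundaryChunker();
        List<Chunker.ChunkOffset> offsets = chunker.chunk(input, 10, 2);

        // The offsets index into the original input, so the chunk text is recovered with substring()
        List<String> chunks = offsets.stream()
            .map(offset -> input.substring(offset.start(), offset.end()))
            .collect(Collectors.toList());

        chunks.forEach(System.out::println);
    }
}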