From 2944ffa3afd1bfba281a91ba988a82727ccbb348 Mon Sep 17 00:00:00 2001 From: Dan Rubinstein Date: Wed, 11 Sep 2024 11:20:43 -0400 Subject: [PATCH] Adding validation for invalid chunking settings inputs and improving error messaging --- .../action/TransportPutInferenceModelAction.java | 2 +- .../chunking/EmbeddingRequestChunker.java | 3 +-- .../chunking/SentenceBoundaryChunker.java | 7 ++++++- .../SentenceBoundaryChunkingSettings.java | 15 +++++++++++++++ .../inference/chunking/WordBoundaryChunker.java | 4 +++- .../chunking/WordBoundaryChunkingSettings.java | 16 ++++++++++++++++ .../inference/chunking/ChunkerBuilderTests.java | 6 ++++-- .../SentenceBoundaryChunkingSettingsTests.java | 13 +++++++++---- .../WordBoundaryChunkingSettingsTests.java | 7 +++++++ 9 files changed, 62 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 7a8434f835dcc..9665316552985 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -346,7 +346,7 @@ private void createInferenceIndex(SystemIndexDescriptor indexDescriptor, ActionL logger.debug("Creating index [{}]", indexDescriptor.getPrimaryIndex()); final String indexName = indexDescriptor.getPrimaryIndex(); var request = new CreateIndexRequest(indexName).mapping(indexDescriptor.getMappings()).settings(indexDescriptor.getSettings()); - request.origin(indexDescriptor.getOrigin()); // Setting the origin + request.origin(indexDescriptor.getOrigin()); // Setting the origin allows the internal user to create the system index final OriginSettingClient originSettingClient = new OriginSettingClient(this.client, 
indexDescriptor.getOrigin()); originSettingClient.admin().indices().create(request, new ActionListener<>() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 9ae05a83c8abf..81ebebdb47e4f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -106,8 +106,7 @@ public EmbeddingRequestChunker( private void splitIntoBatchedRequests(List inputs) { Function> chunkFunction; if (chunkingSettings != null) { - var chunkingStrategy = chunkingSettings.getChunkingStrategy() != null ? chunkingSettings.getChunkingStrategy() : null; - var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingStrategy); + var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); chunkFunction = input -> chunker.chunk(input, chunkingSettings); } else { var chunker = new WordBoundaryChunker(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java index 89091bf8e594d..3a53ecc7ae958 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java @@ -48,7 +48,12 @@ public List chunk(String input, ChunkingSettings chunkingSettings) { if (chunkingSettings instanceof SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings) { return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize); } else { - throw new 
IllegalArgumentException(Strings.format("SentenceBoundaryChunker can't use ChunkingSettings %s", chunkingSettings)); + throw new IllegalArgumentException( + Strings.format( + "SentenceBoundaryChunker can't use ChunkingSettings with strategy [%s]", + chunkingSettings.getChunkingStrategy() + ) + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java index a0534987f84e5..0d1903895f615 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -9,6 +9,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -19,12 +20,18 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils; import java.io.IOException; +import java.util.Arrays; import java.util.Map; import java.util.Objects; +import java.util.Set; public class SentenceBoundaryChunkingSettings implements ChunkingSettings { public static final String NAME = "SentenceBoundaryChunkingSettings"; private static final ChunkingStrategy STRATEGY = ChunkingStrategy.SENTENCE; + private static final Set VALID_KEYS = Set.of( + ChunkingSettingsOptions.STRATEGY.toString(), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString() + ); protected final int maxChunkSize; public SentenceBoundaryChunkingSettings(Integer maxChunkSize) { @@ -37,6 +44,14 @@ public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException { public static SentenceBoundaryChunkingSettings fromMap(Map map) { 
ValidationException validationException = new ValidationException(); + + var invalidSettings = map.keySet().stream().filter(key -> VALID_KEYS.contains(key) == false).toArray(); + if (invalidSettings.length > 0) { + validationException.addValidationError( + Strings.format("Sentence based chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings)) + ); + } + Integer maxChunkSize = ServiceUtils.extractRequiredPositiveInteger( map, ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java index cbffc88ac528a..c9c752b9aabbc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java @@ -49,7 +49,9 @@ public List chunk(String input, ChunkingSettings chunkingSettings) { if (chunkingSettings instanceof WordBoundaryChunkingSettings wordBoundaryChunkerSettings) { return chunk(input, wordBoundaryChunkerSettings.maxChunkSize, wordBoundaryChunkerSettings.overlap); } else { - throw new IllegalArgumentException(Strings.format("WordBoundaryChunker can't use ChunkingSettings %s", chunkingSettings)); + throw new IllegalArgumentException( + Strings.format("WordBoundaryChunker can't use ChunkingSettings with strategy [%s]", chunkingSettings.getChunkingStrategy()) + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index deb0b4376c99d..6517e0eea14d9 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -9,6 +9,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -19,12 +20,19 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils; import java.io.IOException; +import java.util.Arrays; import java.util.Map; import java.util.Objects; +import java.util.Set; public class WordBoundaryChunkingSettings implements ChunkingSettings { public static final String NAME = "WordBoundaryChunkingSettings"; private static final ChunkingStrategy STRATEGY = ChunkingStrategy.WORD; + private static final Set VALID_KEYS = Set.of( + ChunkingSettingsOptions.STRATEGY.toString(), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + ChunkingSettingsOptions.OVERLAP.toString() + ); protected final int maxChunkSize; protected final int overlap; @@ -40,6 +48,14 @@ public WordBoundaryChunkingSettings(StreamInput in) throws IOException { public static WordBoundaryChunkingSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); + + var invalidSettings = map.keySet().stream().filter(key -> VALID_KEYS.contains(key) == false).toArray(); + if (invalidSettings.length > 0) { + validationException.addValidationError( + Strings.format("Word based chunking settings can not have the following settings: %s", Arrays.toString(invalidSettings)) + ); + } + Integer maxChunkSize = ServiceUtils.extractRequiredPositiveInteger( map, ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), diff --git
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilderTests.java index ae7c534db4d5d..d2aea45d4603c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilderTests.java @@ -12,15 +12,17 @@ import java.util.Map; +import static org.hamcrest.Matchers.instanceOf; + public class ChunkerBuilderTests extends ESTestCase { public void testNullChunkingStrategy() { - assert (ChunkerBuilder.fromChunkingStrategy(null) instanceof WordBoundaryChunker); + assertThat(ChunkerBuilder.fromChunkingStrategy(null), instanceOf(WordBoundaryChunker.class)); } public void testValidChunkingStrategy() { chunkingStrategyToExpectedChunkerClassMap().forEach((chunkingStrategy, chunkerClass) -> { - assert (ChunkerBuilder.fromChunkingStrategy(chunkingStrategy).getClass().equals(chunkerClass)); + assertThat(ChunkerBuilder.fromChunkingStrategy(chunkingStrategy), instanceOf(chunkerClass)); }); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java index 034feab42dc45..3f304a593144b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; import 
java.io.IOException; import java.util.HashMap; @@ -26,6 +27,13 @@ public void testMaxChunkSizeNotProvided() { ); } + public void testInvalidInputsProvided() { + var chunkingSettingsMap = buildChunkingSettingsMap(Optional.of(randomNonNegativeInt())); + chunkingSettingsMap.put(randomAlphaOfLength(10), randomNonNegativeInt()); + + assertThrows(ValidationException.class, () -> { SentenceBoundaryChunkingSettings.fromMap(chunkingSettingsMap); }); + } + public void testValidInputsProvided() { int maxChunkSize = randomNonNegativeInt(); SentenceBoundaryChunkingSettings settings = SentenceBoundaryChunkingSettings.fromMap( @@ -56,10 +64,7 @@ protected SentenceBoundaryChunkingSettings createTestInstance() { @Override protected SentenceBoundaryChunkingSettings mutateInstance(SentenceBoundaryChunkingSettings instance) throws IOException { - var chunkSize = instance.maxChunkSize; - while (chunkSize == instance.maxChunkSize) { - chunkSize = randomNonNegativeInt(); - } + var chunkSize = randomValueOtherThan(instance.maxChunkSize, ESTestCase::randomNonNegativeInt); return new SentenceBoundaryChunkingSettings(chunkSize); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettingsTests.java index fcdc13dd62bf3..c5515f7bf0512 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettingsTests.java @@ -32,6 +32,13 @@ public void testOverlapNotProvided() { }); } + public void testInvalidInputsProvided() { + var chunkingSettingsMap = buildChunkingSettingsMap(Optional.of(randomNonNegativeInt()), Optional.of(randomNonNegativeInt())); + chunkingSettingsMap.put(randomAlphaOfLength(10), randomNonNegativeInt()); + + 
assertThrows(ValidationException.class, () -> { WordBoundaryChunkingSettings.fromMap(chunkingSettingsMap); }); + } + public void testOverlapGreaterThanHalfMaxChunkSize() { var maxChunkSize = randomNonNegativeInt(); var overlap = randomIntBetween((maxChunkSize / 2) + 1, maxChunkSize);