diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java b/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java index 5af2fbf66004f..e91287070da13 100644 --- a/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java @@ -11,14 +11,6 @@ import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xcontent.ToXContentObject; -public abstract class ChunkingSettings implements ToXContentObject, VersionedNamedWriteable { - protected String chunkingStrategy; - - public ChunkingSettings(String chunkingStrategy) { - this.chunkingStrategy = chunkingStrategy; - } - - public String getChunkingStrategy() { - return chunkingStrategy; - } +public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable { + ChunkingStrategy getChunkingStrategy(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingStrategy.java b/server/src/main/java/org/elasticsearch/inference/ChunkingStrategy.java similarity index 79% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingStrategy.java rename to server/src/main/java/org/elasticsearch/inference/ChunkingStrategy.java index d8b38ae6d5a34..1a49f8f484ec9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingStrategy.java +++ b/server/src/main/java/org/elasticsearch/inference/ChunkingStrategy.java @@ -1,11 +1,12 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.inference.chunking; +package org.elasticsearch.inference; import org.elasticsearch.common.Strings; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/ChunkingSettingsFeatureFlag.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/ChunkingSettingsFeatureFlag.java similarity index 91% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/ChunkingSettingsFeatureFlag.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/ChunkingSettingsFeatureFlag.java index 767a2c9b1edfa..fae69058df565 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/ChunkingSettingsFeatureFlag.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/ChunkingSettingsFeatureFlag.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.core.ml.inference.assignment; +package org.elasticsearch.xpack.core.inference; import org.elasticsearch.common.util.FeatureFlag; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java index 1f11a17dbc0e3..01d75580cd113 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java @@ -87,6 +87,9 @@ public static XContentBuilder mappings() { .startObject("chunking_settings") .field("dynamic", "false") .startObject("properties") + .startObject("strategy") + .field("type", "keyword") + .endObject() .endObject() .endObject() .endObject() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 1ce16672112ef..c020908a10266 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -26,7 +26,6 @@ import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.inference.chunking.DefaultChunkingSettings; import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings; @@ -426,9 +425,6 @@ private static void addChunkedInferenceResultsNamedWriteables(List namedWriteables) { - namedWriteables.add( - new NamedWriteableRegistry.Entry(ChunkingSettings.class, DefaultChunkingSettings.NAME, DefaultChunkingSettings::new) - ); namedWriteables.add( new NamedWriteableRegistry.Entry(ChunkingSettings.class, WordBoundaryChunkingSettings.NAME, WordBoundaryChunkingSettings::new) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java index fb4854a712d7a..830f1579348f6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilder.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.chunking; +import org.elasticsearch.inference.ChunkingStrategy; + public class ChunkerBuilder { public static Chunker fromChunkingStrategy(ChunkingStrategy chunkingStrategy) { if (chunkingStrategy == null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java index d87bb02789f07..477c3ea6352f5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java @@ -8,13 +8,16 @@ package org.elasticsearch.xpack.inference.chunking; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; import java.util.Map; public class ChunkingSettingsBuilder { + public static final WordBoundaryChunkingSettings DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100); + public static ChunkingSettings fromMap(Map settings) { if (settings.isEmpty()) { - return new DefaultChunkingSettings(); + return DEFAULT_SETTINGS; } if (settings.containsKey(ChunkingSettingsOptions.STRATEGY.toString()) == false) { throw new IllegalArgumentException("Can't generate Chunker without ChunkingStrategy provided"); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/DefaultChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/DefaultChunkingSettings.java deleted file mode 100644 index 678768ac8b2ba..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/DefaultChunkingSettings.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.chunking; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; - -public class DefaultChunkingSettings extends ChunkingSettings { - public static final String NAME = "DefaultChunkingSettings"; - - public DefaultChunkingSettings() { - super(null); - } - - public DefaultChunkingSettings(StreamInput in) { - super(null); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.endObject(); - return builder; - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS; - } - - @Override - public void writeTo(StreamOutput out) throws IOException {} -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 20009110ac1a9..9ae05a83c8abf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -106,9 +106,7 @@ public EmbeddingRequestChunker( private void splitIntoBatchedRequests(List inputs) { Function> chunkFunction; if (chunkingSettings != null) { - var chunkingStrategy = chunkingSettings.getChunkingStrategy() != null - ? ChunkingStrategy.fromString(chunkingSettings.getChunkingStrategy()) - : null; + var chunkingStrategy = chunkingSettings.getChunkingStrategy() != null ? chunkingSettings.getChunkingStrategy() : null; var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingStrategy); chunkFunction = input -> chunker.chunk(input, chunkingSettings); } else { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java index 7ba7292219282..a0534987f84e5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -13,25 +13,25 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ServiceUtils; import java.io.IOException; import java.util.Map; +import java.util.Objects; -public class SentenceBoundaryChunkingSettings extends ChunkingSettings { +public class SentenceBoundaryChunkingSettings implements ChunkingSettings { public static final String NAME = "SentenceBoundaryChunkingSettings"; - protected static final String STRATEGY = ChunkingStrategy.SENTENCE.toString(); + private static final ChunkingStrategy STRATEGY = ChunkingStrategy.SENTENCE; protected final int maxChunkSize; public SentenceBoundaryChunkingSettings(Integer maxChunkSize) { - super(STRATEGY); this.maxChunkSize = maxChunkSize; } public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException { - super(STRATEGY); maxChunkSize = in.readInt(); } @@ -76,4 +76,22 @@ public TransportVersion getMinimalSupportedVersion() { public void writeTo(StreamOutput out) throws IOException { out.writeInt(maxChunkSize); } + + @Override + public ChunkingStrategy getChunkingStrategy() { + return STRATEGY; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SentenceBoundaryChunkingSettings that = (SentenceBoundaryChunkingSettings) o; + return Objects.equals(maxChunkSize, that.maxChunkSize); + } + + @Override + public int hashCode() { + return Objects.hash(maxChunkSize); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java index 4b4983b9a59c4..cbffc88ac528a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java @@ -29,8 +29,6 @@ */ public class WordBoundaryChunker implements Chunker { - private static final int DEFAULT_MAX_CHUNK_SIZE = 250; - private static final int DEFAULT_OVERLAP = 100; private BreakIterator wordIterator; public WordBoundaryChunker() { @@ -50,8 +48,6 @@ record ChunkPosition(int start, int end, int wordCount) {} public List chunk(String input, ChunkingSettings chunkingSettings) { if (chunkingSettings instanceof WordBoundaryChunkingSettings wordBoundaryChunkerSettings) { return chunk(input, wordBoundaryChunkerSettings.maxChunkSize, wordBoundaryChunkerSettings.overlap); - } else if (chunkingSettings instanceof DefaultChunkingSettings) { - return chunk(input, DEFAULT_MAX_CHUNK_SIZE, DEFAULT_OVERLAP); } else { throw new IllegalArgumentException(Strings.format("WordBoundaryChunker can't use ChunkingSettings %s", chunkingSettings)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index b1235438f39a5..deb0b4376c99d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -13,27 +13,27 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ServiceUtils; import java.io.IOException; import java.util.Map; +import java.util.Objects; -public class WordBoundaryChunkingSettings extends ChunkingSettings { +public class WordBoundaryChunkingSettings implements ChunkingSettings { public static final String NAME = "WordBoundaryChunkingSettings"; - protected static final String STRATEGY = ChunkingStrategy.WORD.toString(); + private static final ChunkingStrategy STRATEGY = ChunkingStrategy.WORD; protected final int maxChunkSize; protected final int overlap; public WordBoundaryChunkingSettings(Integer maxChunkSize, Integer overlap) { - super(STRATEGY); this.maxChunkSize = maxChunkSize; this.overlap = overlap; } public WordBoundaryChunkingSettings(StreamInput in) throws IOException { - super(STRATEGY); maxChunkSize = in.readInt(); overlap = in.readInt(); } @@ -46,13 +46,17 @@ public static WordBoundaryChunkingSettings fromMap(Map map) { ModelConfigurations.CHUNKING_SETTINGS, validationException ); - Integer overlap = ServiceUtils.extractRequiredPositiveIntegerLessThanOrEqualToMax( - map, - ChunkingSettingsOptions.OVERLAP.toString(), - maxChunkSize != null ? maxChunkSize / 2 : null, - ModelConfigurations.CHUNKING_SETTINGS, - validationException - ); + + Integer overlap = null; + if (maxChunkSize != null) { + overlap = ServiceUtils.extractRequiredPositiveIntegerLessThanOrEqualToMax( + map, + ChunkingSettingsOptions.OVERLAP.toString(), + maxChunkSize / 2, + ModelConfigurations.CHUNKING_SETTINGS, + validationException + ); + } if (validationException.validationErrors().isEmpty() == false) { throw validationException; @@ -88,4 +92,22 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(maxChunkSize); out.writeInt(overlap); } + + @Override + public ChunkingStrategy getChunkingStrategy() { + return STRATEGY; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WordBoundaryChunkingSettings that = (WordBoundaryChunkingSettings) o; + return Objects.equals(maxChunkSize, that.maxChunkSize) && Objects.equals(overlap, that.overlap); + } + + @Override + public int hashCode() { + return Objects.hash(maxChunkSize, overlap); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 13e082c75a029..35d97b1b5041a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -417,13 +417,13 @@ public static Integer extractRequiredPositiveInteger( public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax( Map map, String settingName, - Integer maxValue, + int maxValue, String scope, ValidationException validationException ) { Integer field = extractRequiredPositiveInteger(map, settingName, scope, validationException); - if (maxValue != null && field != null && field > maxValue) { + if (field != null && field > maxValue) { validationException.addValidationError( ServiceUtils.mustBeLessThanOrEqualNumberErrorMessage(settingName, scope, field, maxValue) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index d652c6c3352ec..710b503b97192 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -25,7 +25,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.ml.inference.assignment.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator; @@ -271,7 +271,7 @@ protected void doChunkedInfer( input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT, - model.getConfigurations().getChunkingSettings() + openAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); } else { batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java index 8b7405aa9e92a..5a1922fd200f5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java @@ -14,6 +14,8 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; @@ -27,7 +29,7 @@ public static ModelConfigurations createRandomInstance() { randomAlphaOfLength(6), randomServiceSettings(), randomTaskSettings(taskType), - null + ChunkingSettingsFeatureFlag.isEnabled() && randomBoolean() ? ChunkingSettingsTests.createRandomChunkingSettings() : null ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilderTests.java index 6b9b15db16408..ae7c534db4d5d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkerBuilderTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.chunking; +import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.test.ESTestCase; import java.util.Map; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java index 690ec963144aa..061ea677e6fe1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.chunking; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.test.ESTestCase; import java.util.Collections; @@ -16,10 +17,12 @@ public class ChunkingSettingsBuilderTests extends ESTestCase { + public static final WordBoundaryChunkingSettings DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100); + public void testEmptyChunkingSettingsMap() { ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(Collections.emptyMap()); - assertTrue(chunkingSettings instanceof DefaultChunkingSettings); + assertEquals(DEFAULT_SETTINGS, chunkingSettings); } public void testChunkingStrategyNotProvided() { @@ -29,13 +32,14 @@ public void testChunkingStrategyNotProvided() { } public void testValidChunkingSettingsMap() { - chunkingSettingsMapToChunkingSettingsClass().forEach((chunkingSettings, chunkingSettingsClass) -> { - assertEquals(chunkingSettingsClass, ChunkingSettingsBuilder.fromMap(new HashMap<>(chunkingSettings)).getClass()); + chunkingSettingsMapToChunkingSettings().forEach((chunkingSettingsMap, chunkingSettings) -> { + assertEquals(chunkingSettings, ChunkingSettingsBuilder.fromMap(new HashMap<>(chunkingSettingsMap))); }); } - private Map, Class> chunkingSettingsMapToChunkingSettingsClass() { + private Map, ChunkingSettings> chunkingSettingsMapToChunkingSettings() { var maxChunkSize = randomNonNegativeInt(); + var overlap = randomIntBetween(1, maxChunkSize / 2); return Map.of( Map.of( ChunkingSettingsOptions.STRATEGY.toString(), @@ -43,16 +47,16 @@ private Map, Class> chunkingSett ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize, ChunkingSettingsOptions.OVERLAP.toString(), - randomIntBetween(1, maxChunkSize / 2) + overlap ), - WordBoundaryChunkingSettings.class, + new WordBoundaryChunkingSettings(maxChunkSize, overlap), Map.of( ChunkingSettingsOptions.STRATEGY.toString(), ChunkingStrategy.SENTENCE.toString(), ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize ), - SentenceBoundaryChunkingSettings.class + new SentenceBoundaryChunkingSettings(maxChunkSize) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java index 42203357ffd5e..2482586c75595 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.chunking; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.test.ESTestCase; import java.util.HashMap; @@ -17,6 +18,7 @@ public class ChunkingSettingsTests extends ESTestCase { public static ChunkingSettings createRandomChunkingSettings() { ChunkingStrategy randomStrategy = randomFrom(ChunkingStrategy.values()); + switch (randomStrategy) { case WORD -> { var maxChunkSize = randomNonNegativeInt(); @@ -25,9 +27,7 @@ public static ChunkingSettings createRandomChunkingSettings() { case SENTENCE -> { return new SentenceBoundaryChunkingSettings(randomNonNegativeInt()); } - default -> { - return new DefaultChunkingSettings(); - } + default -> throw new IllegalArgumentException("Unsupported random strategy [" + randomStrategy + "]"); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingStrategyTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingStrategyTests.java index bb549bae76155..802cea5986b30 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingStrategyTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingStrategyTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.chunking; +import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.test.ESTestCase; public class ChunkingStrategyTests extends ESTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java index f610e907a2bd5..034feab42dc45 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettingsTests.java @@ -8,13 +8,16 @@ package org.elasticsearch.xpack.inference.chunking; import org.elasticsearch.common.ValidationException; -import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import java.io.IOException; import java.util.HashMap; import java.util.Map; import java.util.Optional; -public class SentenceBoundaryChunkingSettingsTests extends ESTestCase { +public class SentenceBoundaryChunkingSettingsTests extends AbstractWireSerializingTestCase { public void testMaxChunkSizeNotProvided() { assertThrows( @@ -29,7 +32,7 @@ public void testValidInputsProvided() { buildChunkingSettingsMap(Optional.of(maxChunkSize)) ); - assertEquals(settings.getChunkingStrategy(), ChunkingStrategy.SENTENCE.toString()); + assertEquals(settings.getChunkingStrategy(), ChunkingStrategy.SENTENCE); assertEquals(settings.maxChunkSize, maxChunkSize); } @@ -40,4 +43,24 @@ public Map buildChunkingSettingsMap(Optional maxChunkSi return settingsMap; } + + @Override + protected Writeable.Reader instanceReader() { + return SentenceBoundaryChunkingSettings::new; + } + + @Override + protected SentenceBoundaryChunkingSettings createTestInstance() { + return new SentenceBoundaryChunkingSettings(randomNonNegativeInt()); + } + + @Override + protected SentenceBoundaryChunkingSettings mutateInstance(SentenceBoundaryChunkingSettings instance) throws IOException { + var chunkSize = instance.maxChunkSize; + while (chunkSize == instance.maxChunkSize) { + chunkSize = randomNonNegativeInt(); + } + + return new SentenceBoundaryChunkingSettings(chunkSize); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java index 8915a5204274e..21d8c65ad7dcd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java @@ -138,20 +138,6 @@ public void testNumberOfChunksWithWordBoundaryChunkingSettings() { } } - public void testNumberOfChunksWithDefaultChunkingSettings() { - for (int numWords : new int[] { 10, 22, 50, 73, 100 }) { - var sb = new StringBuilder(); - for (int i = 0; i < numWords; i++) { - sb.append(i).append(' '); - } - var whiteSpacedText = sb.toString(); - assertExpectedNumberOfChunksWithDefaultChunkingSettings(whiteSpacedText, numWords); - assertExpectedNumberOfChunksWithDefaultChunkingSettings(whiteSpacedText, numWords); - assertExpectedNumberOfChunksWithDefaultChunkingSettings(whiteSpacedText, numWords); - assertExpectedNumberOfChunksWithDefaultChunkingSettings(whiteSpacedText, numWords); - } - } - public void testInvalidChunkingSettingsProvided() { ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(randomNonNegativeInt()); assertThrows(IllegalArgumentException.class, () -> { new WordBoundaryChunker().chunk(TEST_TEXT, chunkingSettings); }); @@ -269,12 +255,6 @@ private void assertExpectedNumberOfChunksWithWordBoundaryChunkingSettings( assertEquals(expected, chunks.size()); } - private void assertExpectedNumberOfChunksWithDefaultChunkingSettings(String input, int numWords) { - var chunks = new WordBoundaryChunker().chunk(input, new DefaultChunkingSettings()); - int expected = expectedNumberOfChunks(numWords, DEFAULT_MAX_CHUNK_SIZE, DEFAULT_OVERLAP); - assertEquals(expected, chunks.size()); - } - private void assertExpectedNumberOfChunks(String input, int numWords, int windowSize, int overlap) { var chunks = new WordBoundaryChunker().chunk(input, windowSize, overlap); int expected = expectedNumberOfChunks(numWords, windowSize, overlap); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettingsTests.java index 9bfb3f2dccc48..fcdc13dd62bf3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettingsTests.java @@ -8,13 +8,17 @@ package org.elasticsearch.xpack.inference.chunking; import org.elasticsearch.common.ValidationException; -import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.ChunkingStrategy; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import java.io.IOException; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; -public class WordBoundaryChunkingSettingsTests extends ESTestCase { +public class WordBoundaryChunkingSettingsTests extends AbstractWireSerializingTestCase { public void testMaxChunkSizeNotProvided() { assertThrows(ValidationException.class, () -> { @@ -43,7 +47,7 @@ public void testValidInputsProvided() { buildChunkingSettingsMap(Optional.of(maxChunkSize), Optional.of(overlap)) ); - assertEquals(settings.getChunkingStrategy(), ChunkingStrategy.WORD.toString()); + assertEquals(settings.getChunkingStrategy(), ChunkingStrategy.WORD); assertEquals(settings.maxChunkSize, maxChunkSize); assertEquals(settings.overlap, overlap); } @@ -56,4 +60,38 @@ public Map buildChunkingSettingsMap(Optional maxChunkSi return settingsMap; } + + @Override + protected Writeable.Reader instanceReader() { + return WordBoundaryChunkingSettings::new; + } + + @Override + protected WordBoundaryChunkingSettings createTestInstance() { + var maxChunkSize = randomNonNegativeInt(); + return new WordBoundaryChunkingSettings(maxChunkSize, randomIntBetween(1, maxChunkSize / 2)); + } + + @Override + protected WordBoundaryChunkingSettings mutateInstance(WordBoundaryChunkingSettings instance) throws IOException { + var valueToMutate = randomFrom(List.of(ChunkingSettingsOptions.MAX_CHUNK_SIZE, ChunkingSettingsOptions.OVERLAP)); + var maxChunkSize = instance.maxChunkSize; + var overlap = instance.overlap; + + if (valueToMutate.equals(ChunkingSettingsOptions.MAX_CHUNK_SIZE)) { + while (maxChunkSize == instance.maxChunkSize) { + maxChunkSize = randomNonNegativeInt(); + } + + if (overlap > maxChunkSize / 2) { + overlap = randomIntBetween(1, maxChunkSize / 2); + } + } else if (valueToMutate.equals(ChunkingSettingsOptions.OVERLAP)) { + while (overlap == instance.overlap) { + overlap = randomIntBetween(1, maxChunkSize / 2); + } + } + + return new WordBoundaryChunkingSettings(maxChunkSize, overlap); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index d07f43dabbdf7..c0d4b2bbcda6a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -29,9 +29,9 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.assignment.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;