Skip to content

Commit

Permalink
Removing DefaultChunkingSettings, cleaning up chunking settings class…
Browse files Browse the repository at this point in the history
… and related tests, add chunking strategy to inference index
  • Loading branch information
dan-rubinstein committed Aug 30, 2024
1 parent 1cbe316 commit 943e0ab
Show file tree
Hide file tree
Showing 23 changed files with 163 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xcontent.ToXContentObject;

public abstract class ChunkingSettings implements ToXContentObject, VersionedNamedWriteable {
protected String chunkingStrategy;

public ChunkingSettings(String chunkingStrategy) {
this.chunkingStrategy = chunkingStrategy;
}

public String getChunkingStrategy() {
return chunkingStrategy;
}
/**
 * Settings that configure how input is chunked.
 * <p>
 * Through the interfaces extended here, every implementation is also a
 * {@link ToXContentObject} and a {@link VersionedNamedWriteable}.
 */
public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable {
/**
 * @return the {@link ChunkingStrategy} these settings belong to
 */
ChunkingStrategy getChunkingStrategy();
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.inference;

import org.elasticsearch.common.Strings;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.assignment;
package org.elasticsearch.xpack.core.inference;

import org.elasticsearch.common.util.FeatureFlag;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ public static XContentBuilder mappings() {
.startObject("chunking_settings")
.field("dynamic", "false")
.startObject("properties")
.startObject("strategy")
.field("type", "keyword")
.endObject()
.endObject()
.endObject()
.endObject()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.inference.chunking.DefaultChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings;
Expand Down Expand Up @@ -426,9 +425,6 @@ private static void addChunkedInferenceResultsNamedWriteables(List<NamedWriteabl
}

private static void addChunkingSettingsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ChunkingSettings.class, DefaultChunkingSettings.NAME, DefaultChunkingSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(ChunkingSettings.class, WordBoundaryChunkingSettings.NAME, WordBoundaryChunkingSettings::new)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package org.elasticsearch.xpack.inference.chunking;

import org.elasticsearch.inference.ChunkingStrategy;

public class ChunkerBuilder {
public static Chunker fromChunkingStrategy(ChunkingStrategy chunkingStrategy) {
if (chunkingStrategy == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
package org.elasticsearch.xpack.inference.chunking;

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.ChunkingStrategy;

import java.util.Map;

public class ChunkingSettingsBuilder {
public static final WordBoundaryChunkingSettings DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100);

public static ChunkingSettings fromMap(Map<String, Object> settings) {
if (settings.isEmpty()) {
return new DefaultChunkingSettings();
return DEFAULT_SETTINGS;
}
if (settings.containsKey(ChunkingSettingsOptions.STRATEGY.toString()) == false) {
throw new IllegalArgumentException("Can't generate Chunker without ChunkingStrategy provided");
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ public EmbeddingRequestChunker(
private void splitIntoBatchedRequests(List<String> inputs) {
Function<String, List<String>> chunkFunction;
if (chunkingSettings != null) {
var chunkingStrategy = chunkingSettings.getChunkingStrategy() != null
? ChunkingStrategy.fromString(chunkingSettings.getChunkingStrategy())
: null;
var chunkingStrategy = chunkingSettings.getChunkingStrategy() != null ? chunkingSettings.getChunkingStrategy() : null;
var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingStrategy);
chunkFunction = input -> chunker.chunk(input, chunkingSettings);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,25 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.ChunkingStrategy;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ServiceUtils;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public class SentenceBoundaryChunkingSettings extends ChunkingSettings {
public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
public static final String NAME = "SentenceBoundaryChunkingSettings";
protected static final String STRATEGY = ChunkingStrategy.SENTENCE.toString();
private static final ChunkingStrategy STRATEGY = ChunkingStrategy.SENTENCE;
protected final int maxChunkSize;

public SentenceBoundaryChunkingSettings(Integer maxChunkSize) {
super(STRATEGY);
this.maxChunkSize = maxChunkSize;
}

public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException {
super(STRATEGY);
maxChunkSize = in.readInt();
}

Expand Down Expand Up @@ -76,4 +76,22 @@ public TransportVersion getMinimalSupportedVersion() {
public void writeTo(StreamOutput out) throws IOException {
// Wire format is a single int; the StreamInput constructor reads it back with readInt().
out.writeInt(maxChunkSize);
}

/**
 * Returns the strategy constant of this settings type, which is always
 * {@link ChunkingStrategy#SENTENCE}.
 *
 * @return the fixed chunking strategy for sentence-boundary settings
 */
@Override
public ChunkingStrategy getChunkingStrategy() {
return STRATEGY;
}

/**
 * Compares this instance with another object for equality.
 * <p>
 * Two instances are equal when they are of exactly the same class and carry the same
 * {@code maxChunkSize}.
 *
 * @param o the object to compare against; may be {@code null}
 * @return {@code true} if {@code o} is a {@code SentenceBoundaryChunkingSettings} with the
 *         same {@code maxChunkSize}, {@code false} otherwise
 */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    // Exact class comparison: an instance of a different class is never equal.
    if (o == null || getClass() != o.getClass()) {
        return false;
    }
    SentenceBoundaryChunkingSettings that = (SentenceBoundaryChunkingSettings) o;
    // maxChunkSize is a primitive int, so compare it directly instead of boxing
    // both operands through Objects.equals.
    return maxChunkSize == that.maxChunkSize;
}

/**
 * Computes the hash code from {@code maxChunkSize} only, the same field that
 * {@code equals} compares.
 *
 * @return the hash code of this instance
 */
@Override
public int hashCode() {
return Objects.hash(maxChunkSize);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
*/
public class WordBoundaryChunker implements Chunker {

private static final int DEFAULT_MAX_CHUNK_SIZE = 250;
private static final int DEFAULT_OVERLAP = 100;
private BreakIterator wordIterator;

public WordBoundaryChunker() {
Expand All @@ -50,8 +48,6 @@ record ChunkPosition(int start, int end, int wordCount) {}
public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
if (chunkingSettings instanceof WordBoundaryChunkingSettings wordBoundaryChunkerSettings) {
return chunk(input, wordBoundaryChunkerSettings.maxChunkSize, wordBoundaryChunkerSettings.overlap);
} else if (chunkingSettings instanceof DefaultChunkingSettings) {
return chunk(input, DEFAULT_MAX_CHUNK_SIZE, DEFAULT_OVERLAP);
} else {
throw new IllegalArgumentException(Strings.format("WordBoundaryChunker can't use ChunkingSettings %s", chunkingSettings));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,27 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.ChunkingStrategy;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ServiceUtils;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public class WordBoundaryChunkingSettings extends ChunkingSettings {
public class WordBoundaryChunkingSettings implements ChunkingSettings {
public static final String NAME = "WordBoundaryChunkingSettings";
protected static final String STRATEGY = ChunkingStrategy.WORD.toString();
private static final ChunkingStrategy STRATEGY = ChunkingStrategy.WORD;
protected final int maxChunkSize;
protected final int overlap;

public WordBoundaryChunkingSettings(Integer maxChunkSize, Integer overlap) {
super(STRATEGY);
this.maxChunkSize = maxChunkSize;
this.overlap = overlap;
}

public WordBoundaryChunkingSettings(StreamInput in) throws IOException {
super(STRATEGY);
maxChunkSize = in.readInt();
overlap = in.readInt();
}
Expand All @@ -46,13 +46,17 @@ public static WordBoundaryChunkingSettings fromMap(Map<String, Object> map) {
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);
Integer overlap = ServiceUtils.extractRequiredPositiveIntegerLessThanOrEqualToMax(
map,
ChunkingSettingsOptions.OVERLAP.toString(),
maxChunkSize != null ? maxChunkSize / 2 : null,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);

Integer overlap = null;
if (maxChunkSize != null) {
overlap = ServiceUtils.extractRequiredPositiveIntegerLessThanOrEqualToMax(
map,
ChunkingSettingsOptions.OVERLAP.toString(),
maxChunkSize / 2,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);
}

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
Expand Down Expand Up @@ -88,4 +92,22 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeInt(maxChunkSize);
out.writeInt(overlap);
}

/**
 * Returns the strategy constant of this settings type, which is always
 * {@link ChunkingStrategy#WORD}.
 *
 * @return the fixed chunking strategy for word-boundary settings
 */
@Override
public ChunkingStrategy getChunkingStrategy() {
return STRATEGY;
}

/**
 * Compares this instance with another object for equality.
 * <p>
 * Two instances are equal when they are of exactly the same class and carry the same
 * {@code maxChunkSize} and {@code overlap}.
 *
 * @param o the object to compare against; may be {@code null}
 * @return {@code true} if {@code o} is a {@code WordBoundaryChunkingSettings} with the
 *         same {@code maxChunkSize} and {@code overlap}, {@code false} otherwise
 */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    // Exact class comparison: an instance of a different class is never equal.
    if (o == null || getClass() != o.getClass()) {
        return false;
    }
    WordBoundaryChunkingSettings that = (WordBoundaryChunkingSettings) o;
    // Both fields are primitive ints, so compare them directly instead of boxing
    // each operand through Objects.equals.
    return maxChunkSize == that.maxChunkSize && overlap == that.overlap;
}

/**
 * Computes the hash code from {@code maxChunkSize} and {@code overlap}, the same
 * fields that {@code equals} compares.
 *
 * @return the hash code of this instance
 */
@Override
public int hashCode() {
return Objects.hash(maxChunkSize, overlap);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -417,13 +417,13 @@ public static Integer extractRequiredPositiveInteger(
public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax(
Map<String, Object> map,
String settingName,
Integer maxValue,
int maxValue,
String scope,
ValidationException validationException
) {
Integer field = extractRequiredPositiveInteger(map, settingName, scope, validationException);

if (maxValue != null && field != null && field > maxValue) {
if (field != null && field > maxValue) {
validationException.addValidationError(
ServiceUtils.mustBeLessThanOrEqualNumberErrorMessage(settingName, scope, field, maxValue)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.assignment.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator;
Expand Down Expand Up @@ -271,7 +271,7 @@ protected void doChunkedInfer(
input,
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT,
model.getConfigurations().getChunkingSettings()
openAiModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);
} else {
batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings;

Expand All @@ -27,7 +29,7 @@ public static ModelConfigurations createRandomInstance() {
randomAlphaOfLength(6),
randomServiceSettings(),
randomTaskSettings(taskType),
null
ChunkingSettingsFeatureFlag.isEnabled() && randomBoolean() ? ChunkingSettingsTests.createRandomChunkingSettings() : null
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.chunking;

import org.elasticsearch.inference.ChunkingStrategy;
import org.elasticsearch.test.ESTestCase;

import java.util.Map;
Expand Down
Loading

0 comments on commit 943e0ab

Please sign in to comment.