diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkedInferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/ChunkedInference.java similarity index 73% rename from server/src/main/java/org/elasticsearch/inference/ChunkedInferenceServiceResults.java rename to server/src/main/java/org/elasticsearch/inference/ChunkedInference.java index 10e00e9860200..c54e5a98d56cc 100644 --- a/server/src/main/java/org/elasticsearch/inference/ChunkedInferenceServiceResults.java +++ b/server/src/main/java/org/elasticsearch/inference/ChunkedInference.java @@ -12,23 +12,27 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.xcontent.XContent; +import java.io.IOException; import java.util.Iterator; -public interface ChunkedInferenceServiceResults extends InferenceServiceResults { +public interface ChunkedInference { /** * Implementations of this function serialize their embeddings to {@link BytesReference} for storage in semantic text fields. - * The iterator iterates over all the chunks stored in the {@link ChunkedInferenceServiceResults}. * * @param xcontent provided by the SemanticTextField * @return an iterator of the serialized {@link Chunk} which includes the matched text (input) and bytes reference (output/embedding). */ - Iterator chunksAsMatchedTextAndByteReference(XContent xcontent); + Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException; /** - * A chunk of inference results containing matched text and the bytes reference. + * A chunk of inference results containing matched text, the substring location + * in the original text and the bytes reference. * @param matchedText + * @param textOffset * @param bytesReference */ - record Chunk(String matchedText, BytesReference bytesReference) {} + record Chunk(String matchedText, TextOffset textOffset, BytesReference bytesReference) {} + + record TextOffset(int start, int end) {} } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 4497254aad1f0..e892646481e0c 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -127,7 +127,7 @@ void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ); /** diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingByte.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingByte.java new file mode 100644 index 0000000000000..c2f70b0be2916 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingByte.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +public record ChunkedInferenceEmbeddingByte(List chunks) implements ChunkedInference { + + @Override + public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException { + var asChunk = new ArrayList(); + for (var chunk : chunks) { + asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.embedding()))); + } + return asChunk.iterator(); + } + + /** + * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. + */ + private static BytesReference toBytesReference(XContent xContent, byte[] value) throws IOException { + XContentBuilder builder = XContentBuilder.builder(xContent); + builder.startArray(); + for (byte v : value) { + builder.value(v); + } + builder.endArray(); + return BytesReference.bytes(builder); + } + + public record ByteEmbeddingChunk(byte[] embedding, String matchedText, TextOffset offset) {} +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingFloat.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingFloat.java new file mode 100644 index 0000000000000..651d135b761dd --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingFloat.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +public record ChunkedInferenceEmbeddingFloat(List chunks) implements ChunkedInference { + + @Override + public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException { + var asChunk = new ArrayList(); + for (var chunk : chunks) { + asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.embedding()))); + } + return asChunk.iterator(); + } + + /** + * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. + */ + private static BytesReference toBytesReference(XContent xContent, float[] value) throws IOException { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startArray(); + for (float v : value) { + b.value(v); + } + b.endArray(); + return BytesReference.bytes(b); + } + + public record FloatEmbeddingChunk(float[] embedding, String matchedText, TextOffset offset) {} +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingSparse.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingSparse.java new file mode 100644 index 0000000000000..37bf92e0dbfce --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceEmbeddingSparse.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; + +public record ChunkedInferenceEmbeddingSparse(List chunks) implements ChunkedInference { + + public static List listOf(List inputs, SparseEmbeddingResults sparseEmbeddingResults) { + validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size()); + + var results = new ArrayList(inputs.size()); + for (int i = 0; i < inputs.size(); i++) { + results.add( + new ChunkedInferenceEmbeddingSparse( + List.of( + new SparseEmbeddingChunk( + sparseEmbeddingResults.embeddings().get(i).tokens(), + inputs.get(i), + new TextOffset(0, inputs.get(i).length()) + ) + ) + ) + ); + } + + return results; + } + + @Override + public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) throws IOException { + var asChunk = new ArrayList(); + for (var chunk : chunks) { + asChunk.add(new Chunk(chunk.matchedText(), chunk.offset(), toBytesReference(xcontent, chunk.weightedTokens()))); + } + return asChunk.iterator(); + } + + private static BytesReference toBytesReference(XContent xContent, List tokens) throws IOException { + XContentBuilder b = XContentBuilder.builder(xContent); + b.startObject(); + for (var weightedToken : tokens) { + weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS); + } + b.endObject(); + return BytesReference.bytes(b); + } + + public record SparseEmbeddingChunk(List weightedTokens, String matchedText, TextOffset offset) {} +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceError.java new file mode 100644 index 0000000000000..65be9f12d7686 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedInferenceError.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.xcontent.XContent; + +import java.util.Iterator; +import java.util.stream.Stream; + +public record ChunkedInferenceError(Exception exception) implements ChunkedInference { + + @Override + public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { + return Stream.of(exception).map(e -> new Chunk(e.getMessage(), new TextOffset(0, 0), BytesArray.EMPTY)).iterator(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ErrorChunkedInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ErrorChunkedInferenceResults.java deleted file mode 100644 index 18f88a8ff022a..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ErrorChunkedInferenceResults.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContent; - -import java.io.IOException; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Stream; - -public class ErrorChunkedInferenceResults implements ChunkedInferenceServiceResults { - - public static final String NAME = "error_chunked"; - - private final Exception exception; - - public ErrorChunkedInferenceResults(Exception exception) { - this.exception = Objects.requireNonNull(exception); - } - - public ErrorChunkedInferenceResults(StreamInput in) throws IOException { - this.exception = in.readException(); - } - - public Exception getException() { - return exception; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeException(exception); - } - - @Override - public boolean equals(Object object) { - if (object == this) { - return true; - } - if (object == null || getClass() != object.getClass()) { - return false; - } - ErrorChunkedInferenceResults that = (ErrorChunkedInferenceResults) object; - // Just compare the message for serialization test purposes - return Objects.equals(exception.getMessage(), that.exception.getMessage()); - } - - @Override - public int hashCode() { - // Just compare the message for serialization test purposes - return Objects.hash(exception.getMessage()); - } - - @Override - public List transformToCoordinationFormat() { - return null; - } - - @Override - public List transformToLegacyFormat() { - return null; - } - - @Override - public Map asMap() { - Map asMap = new LinkedHashMap<>(); - asMap.put(NAME, exception.getMessage()); - return asMap; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContentHelper.field(NAME, exception.getMessage()); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { - return Stream.of(exception).map(e -> new Chunk(e.getMessage(), BytesArray.EMPTY)).iterator(); - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedSparseEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedSparseEmbeddingResults.java deleted file mode 100644 index c961050acefdb..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedSparseEmbeddingResults.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; - -public class InferenceChunkedSparseEmbeddingResults implements ChunkedInferenceServiceResults { - - public static final String NAME = "chunked_sparse_embedding_results"; - public static final String FIELD_NAME = "sparse_embedding_chunk"; - - public static InferenceChunkedSparseEmbeddingResults ofMlResult(MlChunkedTextExpansionResults mlInferenceResults) { - return new InferenceChunkedSparseEmbeddingResults(mlInferenceResults.getChunks()); - } - - /** - * Returns a list of {@link InferenceChunkedSparseEmbeddingResults}. The number of entries in the list will match the input list size. - * Each {@link InferenceChunkedSparseEmbeddingResults} will have a single chunk containing the entire results from the - * {@link SparseEmbeddingResults}. - */ - public static List listOf(List inputs, SparseEmbeddingResults sparseEmbeddingResults) { - validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size()); - - var results = new ArrayList(inputs.size()); - for (int i = 0; i < inputs.size(); i++) { - results.add(ofSingle(inputs.get(i), sparseEmbeddingResults.embeddings().get(i))); - } - - return results; - } - - private static InferenceChunkedSparseEmbeddingResults ofSingle(String input, SparseEmbeddingResults.Embedding embedding) { - var weightedTokens = embedding.tokens() - .stream() - .map(weightedToken -> new WeightedToken(weightedToken.token(), weightedToken.weight())) - .toList(); - - return new InferenceChunkedSparseEmbeddingResults(List.of(new MlChunkedTextExpansionResults.ChunkedResult(input, weightedTokens))); - } - - private final List chunkedResults; - - public InferenceChunkedSparseEmbeddingResults(List chunks) { - this.chunkedResults = chunks; - } - - public InferenceChunkedSparseEmbeddingResults(StreamInput in) throws IOException { - this.chunkedResults = in.readCollectionAsList(MlChunkedTextExpansionResults.ChunkedResult::new); - } - - public List getChunkedResults() { - return chunkedResults; - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContent.builder(params).array(FIELD_NAME, chunkedResults.iterator()); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(chunkedResults); - } - - @Override - public List transformToCoordinationFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the coordindated action"); - } - - @Override - public List transformToLegacyFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the legacy format"); - } - - @Override - public Map asMap() { - return Map.of( - FIELD_NAME, - chunkedResults.stream().map(MlChunkedTextExpansionResults.ChunkedResult::asMap).collect(Collectors.toList()) - ); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceChunkedSparseEmbeddingResults that = (InferenceChunkedSparseEmbeddingResults) o; - return Objects.equals(chunkedResults, that.chunkedResults); - } - - @Override - public int hashCode() { - return Objects.hash(chunkedResults); - } - - @Override - public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { - return chunkedResults.stream() - .map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.weightedTokens()))) - .iterator(); - } - - /** - * Serialises the {@link WeightedToken} list, according to the provided {@link XContent}, - * into a {@link BytesReference}. - */ - private static BytesReference toBytesReference(XContent xContent, List tokens) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startObject(); - for (var weightedToken : tokens) { - weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS); - } - b.endObject(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingByteResults.java deleted file mode 100644 index 6bd66664068d5..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingByteResults.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; - -public record InferenceChunkedTextEmbeddingByteResults(List chunks, boolean isTruncated) - implements - ChunkedInferenceServiceResults { - - public static final String NAME = "chunked_text_embedding_service_byte_results"; - public static final String FIELD_NAME = "text_embedding_byte_chunk"; - - /** - * Returns a list of {@link InferenceChunkedTextEmbeddingByteResults}. The number of entries in the list will match the input list size. - * Each {@link InferenceChunkedTextEmbeddingByteResults} will have a single chunk containing the entire results from the - * {@link InferenceTextEmbeddingByteResults}. - */ - public static List listOf(List inputs, InferenceTextEmbeddingByteResults textEmbeddings) { - validateInputSizeAgainstEmbeddings(inputs, textEmbeddings.embeddings().size()); - - var results = new ArrayList(inputs.size()); - for (int i = 0; i < inputs.size(); i++) { - results.add(ofSingle(inputs.get(i), textEmbeddings.embeddings().get(i).values())); - } - - return results; - } - - private static InferenceChunkedTextEmbeddingByteResults ofSingle(String input, byte[] byteEmbeddings) { - return new InferenceChunkedTextEmbeddingByteResults(List.of(new InferenceByteEmbeddingChunk(input, byteEmbeddings)), false); - } - - public InferenceChunkedTextEmbeddingByteResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(InferenceByteEmbeddingChunk::new), in.readBoolean()); - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContent.builder(params).array(FIELD_NAME, chunks.iterator()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(chunks); - out.writeBoolean(isTruncated); - } - - @Override - public List transformToCoordinationFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action"); - } - - @Override - public List transformToLegacyFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the legacy format"); - } - - @Override - public Map asMap() { - return Map.of(FIELD_NAME, chunks); - } - - @Override - public String getWriteableName() { - return NAME; - } - - public List getChunks() { - return chunks; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceChunkedTextEmbeddingByteResults that = (InferenceChunkedTextEmbeddingByteResults) o; - return isTruncated == that.isTruncated && Objects.equals(chunks, that.chunks); - } - - @Override - public int hashCode() { - return Objects.hash(chunks, isTruncated); - } - - public record InferenceByteEmbeddingChunk(String matchedText, byte[] embedding) implements Writeable, ToXContentObject { - - public InferenceByteEmbeddingChunk(StreamInput in) throws IOException { - this(in.readString(), in.readByteArray()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(matchedText); - out.writeByteArray(embedding); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(ChunkedNlpInferenceResults.TEXT, matchedText); - - builder.startArray(ChunkedNlpInferenceResults.INFERENCE); - for (byte value : embedding) { - builder.value(value); - } - builder.endArray(); - - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceByteEmbeddingChunk that = (InferenceByteEmbeddingChunk) o; - return Objects.equals(matchedText, that.matchedText) && Arrays.equals(embedding, that.embedding); - } - - @Override - public int hashCode() { - int result = Objects.hash(matchedText); - result = 31 * result + Arrays.hashCode(embedding); - return result; - } - } - - @Override - public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { - return chunks.stream().map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.embedding()))).iterator(); - } - - private static BytesReference toBytesReference(XContent xContent, byte[] value) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startArray(); - for (byte v : value) { - b.value(v); - } - b.endArray(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResults.java deleted file mode 100644 index 369f22a807913..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResults.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.utils.FloatConversionUtils; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; - -public record InferenceChunkedTextEmbeddingFloatResults(List chunks) - implements - ChunkedInferenceServiceResults { - - public static final String NAME = "chunked_text_embedding_service_float_results"; - public static final String FIELD_NAME = "text_embedding_float_chunk"; - - public InferenceChunkedTextEmbeddingFloatResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(InferenceFloatEmbeddingChunk::new)); - } - - /** - * Returns a list of {@link InferenceChunkedTextEmbeddingFloatResults}. - * Each {@link InferenceChunkedTextEmbeddingFloatResults} contain a single chunk with the text and the - * {@link InferenceTextEmbeddingFloatResults}. - */ - public static List listOf(List inputs, InferenceTextEmbeddingFloatResults textEmbeddings) { - validateInputSizeAgainstEmbeddings(inputs, textEmbeddings.embeddings().size()); - - var results = new ArrayList(inputs.size()); - - for (int i = 0; i < inputs.size(); i++) { - results.add( - new InferenceChunkedTextEmbeddingFloatResults( - List.of(new InferenceFloatEmbeddingChunk(inputs.get(i), textEmbeddings.embeddings().get(i).values())) - ) - ); - } - - return results; - } - - public static InferenceChunkedTextEmbeddingFloatResults ofMlResults(MlChunkedTextEmbeddingFloatResults mlInferenceResult) { - return new InferenceChunkedTextEmbeddingFloatResults( - mlInferenceResult.getChunks() - .stream() - .map(chunk -> new InferenceFloatEmbeddingChunk(chunk.matchedText(), FloatConversionUtils.floatArrayOf(chunk.embedding()))) - .toList() - ); - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - // TODO add isTruncated flag - return ChunkedToXContent.builder(params).array(FIELD_NAME, chunks.iterator()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(chunks); - } - - @Override - public List transformToCoordinationFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action"); - } - - @Override - public List transformToLegacyFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the legacy format"); - } - - @Override - public Map asMap() { - return Map.of(FIELD_NAME, chunks); - } - - @Override - public String getWriteableName() { - return NAME; - } - - public List getChunks() { - return chunks; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceChunkedTextEmbeddingFloatResults that = (InferenceChunkedTextEmbeddingFloatResults) o; - return Objects.equals(chunks, that.chunks); - } - - @Override - public int hashCode() { - return Objects.hash(chunks); - } - - public record InferenceFloatEmbeddingChunk(String matchedText, float[] embedding) implements Writeable, ToXContentObject { - - public InferenceFloatEmbeddingChunk(StreamInput in) throws IOException { - this(in.readString(), in.readFloatArray()); - } - - public static InferenceFloatEmbeddingChunk of(String matchedText, double[] doubleEmbedding) { - return new InferenceFloatEmbeddingChunk(matchedText, FloatConversionUtils.floatArrayOf(doubleEmbedding)); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(matchedText); - out.writeFloatArray(embedding); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(ChunkedNlpInferenceResults.TEXT, matchedText); - - builder.startArray(ChunkedNlpInferenceResults.INFERENCE); - for (float value : embedding) { - builder.value(value); - } - builder.endArray(); - - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceFloatEmbeddingChunk that = (InferenceFloatEmbeddingChunk) o; - return Objects.equals(matchedText, that.matchedText) && Arrays.equals(embedding, that.embedding); - } - - @Override - public int hashCode() { - int result = Objects.hash(matchedText); - result = 31 * result + Arrays.hashCode(embedding); - return result; - } - } - - @Override - public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { - return chunks.stream().map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.embedding()))).iterator(); - } - - /** - * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. - */ - private static BytesReference toBytesReference(XContent xContent, float[] value) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startArray(); - for (float v : value) { - b.value(v); - } - b.endArray(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java index 4c68d02264457..cb69f1e403e9c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingUtils.java @@ -31,7 +31,7 @@ public static int getFirstEmbeddingSize(List embeddings) throws Il * Throws an exception if the number of elements in the input text list is different than the results in text embedding * response. */ - static void validateInputSizeAgainstEmbeddings(List inputs, int embeddingSize) { + public static void validateInputSizeAgainstEmbeddings(List inputs, int embeddingSize) { if (inputs.size() != embeddingSize) { throw new IllegalArgumentException( Strings.format("The number of inputs [%s] does not match the embeddings [%s]", inputs.size(), embeddingSize) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResultsTests.java deleted file mode 100644 index 83678cd030bc2..0000000000000 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResultsTests.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Map; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults.INFERENCE; -import static org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults.TEXT; - -public class InferenceChunkedTextEmbeddingFloatResultsTests extends ESTestCase { - /** - * Similar to {@link org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults#asMap()} but it converts the - * embeddings float array into a list of floats to make testing equality easier. - */ - public static Map asMapWithListsInsteadOfArrays(InferenceChunkedTextEmbeddingFloatResults result) { - return Map.of( - InferenceChunkedTextEmbeddingFloatResults.FIELD_NAME, - result.getChunks() - .stream() - .map(InferenceChunkedTextEmbeddingFloatResultsTests::inferenceFloatEmbeddingChunkAsMapWithListsInsteadOfArrays) - .collect(Collectors.toList()) - ); - } - - /** - * Similar to {@link MlChunkedTextEmbeddingFloatResults.EmbeddingChunk#asMap()} but it converts the double array into a list of doubles - * to make testing equality easier. - */ - public static Map inferenceFloatEmbeddingChunkAsMapWithListsInsteadOfArrays( - InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk chunk - ) { - var chunkAsList = new ArrayList(chunk.embedding().length); - for (double embedding : chunk.embedding()) { - chunkAsList.add((float) embedding); - } - var map = new HashMap(); - map.put(TEXT, chunk.matchedText()); - map.put(INFERENCE, chunkAsList); - return map; - } -} diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index ae11a02d312e2..488361ed29c27 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -17,7 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -36,7 +36,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import java.io.IOException; @@ -140,7 +140,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { switch (model.getConfigurations().getTaskType()) { case ANY, TEXT_EMBEDDING -> { @@ -165,9 +165,24 @@ private InferenceTextEmbeddingFloatResults makeResults(List input, int d return new InferenceTextEmbeddingFloatResults(embeddings); } - private List makeChunkedResults(List input, int dimensions) { + private List makeChunkedResults(List input, int dimensions) { InferenceTextEmbeddingFloatResults nonChunkedResults = makeResults(input, dimensions); - return InferenceChunkedTextEmbeddingFloatResults.listOf(input, nonChunkedResults); + + var results = new ArrayList(); + for (int i = 0; i < input.size(); i++) { + results.add( + new ChunkedInferenceEmbeddingFloat( + List.of( + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk( + nonChunkedResults.embeddings().get(i).values(), + input.get(i), + new ChunkedInference.TextOffset(0, input.get(i).length()) + ) + ) + ) + ); + } + return results; } protected ServiceSettings getServiceSettingsFromMap(Map serviceSettingsMap) { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 9320571572f0a..04225503a6373 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -16,7 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -128,7 +128,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { listener.onFailure( new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index fe0223cce0323..6fa4023c12c7e 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -16,7 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -34,9 +34,8 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.io.IOException; @@ -131,7 +130,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { switch (model.getConfigurations().getTaskType()) { case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeChunkedResults(input)); @@ -156,16 +155,22 @@ private SparseEmbeddingResults makeResults(List input) { return new SparseEmbeddingResults(embeddings); } - private List makeChunkedResults(List input) { - List results = new ArrayList<>(); + private List makeChunkedResults(List input) { + List results = new ArrayList<>(); for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); for (int j = 0; j < 5; j++) { tokens.add(new WeightedToken("feature_" + j, generateEmbedding(input.get(i), j))); } results.add( - new InferenceChunkedSparseEmbeddingResults( - List.of(new MlChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)) + new ChunkedInferenceEmbeddingSparse( + List.of( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk( + tokens, + input.get(i), + new ChunkedInference.TextOffset(0, input.get(i).length()) + ) + ) ) ); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 6d7983bc8cb53..5b20c25bcb226 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -18,7 +18,7 @@ import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -160,7 +160,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { listener.onFailure( new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 673b841317a3d..af72934565f74 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -17,10 +17,6 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; @@ -107,7 +103,6 @@ public static List getNamedWriteables() { ); addInferenceResultsNamedWriteables(namedWriteables); - addChunkedInferenceResultsNamedWriteables(namedWriteables); // Empty default task settings namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new)); @@ -425,37 +420,6 @@ private static void addInternalNamedWriteables(List namedWriteables) { - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceServiceResults.class, - ErrorChunkedInferenceResults.NAME, - ErrorChunkedInferenceResults::new - ) - ); - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceServiceResults.class, - InferenceChunkedSparseEmbeddingResults.NAME, - InferenceChunkedSparseEmbeddingResults::new - ) - ); - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceServiceResults.class, - InferenceChunkedTextEmbeddingFloatResults.NAME, - InferenceChunkedTextEmbeddingFloatResults::new - ) - ); - namedWriteables.add( - new NamedWriteableRegistry.Entry( - InferenceServiceResults.class, - InferenceChunkedTextEmbeddingByteResults.NAME, - InferenceChunkedTextEmbeddingByteResults::new - ) - ); - } - private static void addChunkingSettingsNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry(ChunkingSettings.class, WordBoundaryChunkingSettings.NAME, WordBoundaryChunkingSettings::new) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index d178e927aa65d..a9195ea24af3a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -29,7 +29,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InputType; @@ -37,11 +37,12 @@ import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -141,7 +142,7 @@ private record FieldInferenceResponse( int inputOrder, boolean isOriginalFieldInput, Model model, - ChunkedInferenceServiceResults chunkedResults + ChunkedInference chunkedResults ) {} private record FieldInferenceResponseAccumulator( @@ -273,19 +274,19 @@ public void onFailure(Exception exc) { final List currentBatch = requests.subList(0, currentBatchSize); final List nextBatch = requests.subList(currentBatchSize, requests.size()); final List inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); - ActionListener> completionListener = new ActionListener<>() { + ActionListener> completionListener = new ActionListener<>() { @Override - public void onResponse(List results) { + public void onResponse(List results) { try { var requestsIterator = requests.iterator(); - for (ChunkedInferenceServiceResults result : results) { + for (ChunkedInference result : results) { var request = requestsIterator.next(); var acc = inferenceResults.get(request.index); - if (result instanceof ErrorChunkedInferenceResults error) { + if (result instanceof ChunkedInferenceError error) { acc.addFailure( new ElasticsearchException( "Exception when running inference id [{}] on field [{}]", - error.getException(), + error.exception(), inferenceProvider.model.getInferenceEntityId(), request.field ) @@ -359,7 +360,7 @@ private void addInferenceResponseFailure(int id, Exception failure) { * Otherwise, the source of the request is augmented with the field inference results under the * {@link SemanticTextField#INFERENCE_FIELD} field. */ - private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) { + private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) throws IOException { if (response.failures().isEmpty() == false) { for (var failure : response.failures()) { item.abort(item.index(), failure); @@ -376,7 +377,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons // ensure that the order in the original field is consistent in case of multiple inputs Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder)); List inputs = responses.stream().filter(r -> r.isOriginalFieldInput).map(r -> r.input).collect(Collectors.toList()); - List results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList()); + List results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList()); var result = new SemanticTextField( fieldName, inputs, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 2aef54e56f4b9..9b0b1104df660 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -11,18 +11,17 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingByte; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; import java.util.ArrayList; import java.util.List; @@ -72,8 +71,8 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El private List>> floatResults; private List>> byteResults; private List>> sparseResults; - private AtomicArray errors; - private ActionListener> finalListener; + private AtomicArray errors; + private ActionListener> finalListener; public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, EmbeddingType embeddingType) { this(inputs, maxNumberOfInputsPerBatch, DEFAULT_WORDS_PER_CHUNK, DEFAULT_CHUNK_OVERLAP, embeddingType); @@ -189,7 +188,7 @@ private int addToBatches(ChunkOffsetsAndInput chunk, int inputIndex) { * @param finalListener The listener to call once all the batches are processed * @return Batches and listeners */ - public List batchRequestsWithListeners(ActionListener> finalListener) { + public List batchRequestsWithListeners(ActionListener> finalListener) { this.finalListener = finalListener; int numberOfRequests = batchedRequests.size(); @@ -331,9 +330,8 @@ private ElasticsearchStatusException unexpectedResultTypeException(String got, S @Override public void onFailure(Exception e) { - var errorResult = new ErrorChunkedInferenceResults(e); for (var pos : positions) { - errors.set(pos.inputIndex(), errorResult); + errors.set(pos.inputIndex(), e); } if (resultCount.incrementAndGet() == totalNumberOfRequests) { @@ -342,10 +340,10 @@ public void onFailure(Exception e) { } private void sendResponse() { - var response = new ArrayList(chunkedOffsets.size()); + var response = new ArrayList(chunkedOffsets.size()); for (int i = 0; i < chunkedOffsets.size(); i++) { if (errors.get(i) != null) { - response.add(errors.get(i)); + response.add(new ChunkedInferenceError(errors.get(i))); } else { response.add(mergeResultsWithInputs(i)); } @@ -355,16 +353,16 @@ private void sendResponse() { } } - private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) { + private ChunkedInference mergeResultsWithInputs(int resultIndex) { return switch (embeddingType) { - case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), floatResults.get(resultIndex)); - case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), byteResults.get(resultIndex)); - case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), sparseResults.get(resultIndex)); + case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex), floatResults.get(resultIndex)); + case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex), byteResults.get(resultIndex)); + case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex), sparseResults.get(resultIndex)); }; } - private InferenceChunkedTextEmbeddingFloatResults mergeFloatResultsWithInputs( - List chunks, + private ChunkedInferenceEmbeddingFloat mergeFloatResultsWithInputs( + ChunkOffsetsAndInput chunks, AtomicArray> debatchedResults ) { var all = new ArrayList(); @@ -375,18 +373,22 @@ private InferenceChunkedTextEmbeddingFloatResults mergeFloatResultsWithInputs( assert chunks.size() == all.size(); - var embeddingChunks = new ArrayList(); + var embeddingChunks = new ArrayList(); for (int i = 0; i < chunks.size(); i++) { embeddingChunks.add( - new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk(chunks.get(i), all.get(i).values()) + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk( + all.get(i).values(), + chunks.chunkText(i), + new ChunkedInference.TextOffset(chunks.offsets().get(i).start(), chunks.offsets().get(i).end()) + ) ); } - return new InferenceChunkedTextEmbeddingFloatResults(embeddingChunks); + return new ChunkedInferenceEmbeddingFloat(embeddingChunks); } - private InferenceChunkedTextEmbeddingByteResults mergeByteResultsWithInputs( - List chunks, + private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs( + ChunkOffsetsAndInput chunks, AtomicArray> debatchedResults ) { var all = new ArrayList(); @@ -397,18 +399,22 @@ private InferenceChunkedTextEmbeddingByteResults mergeByteResultsWithInputs( assert chunks.size() == all.size(); - var embeddingChunks = new ArrayList(); + var embeddingChunks = new ArrayList(); for (int i = 0; i < chunks.size(); i++) { embeddingChunks.add( - new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk(chunks.get(i), all.get(i).values()) + new ChunkedInferenceEmbeddingByte.ByteEmbeddingChunk( + all.get(i).values(), + chunks.chunkText(i), + new ChunkedInference.TextOffset(chunks.offsets().get(i).start(), chunks.offsets().get(i).end()) + ) ); } - return new InferenceChunkedTextEmbeddingByteResults(embeddingChunks, false); + return new ChunkedInferenceEmbeddingByte(embeddingChunks); } - private InferenceChunkedSparseEmbeddingResults mergeSparseResultsWithInputs( - List chunks, + private ChunkedInferenceEmbeddingSparse mergeSparseResultsWithInputs( + ChunkOffsetsAndInput chunks, AtomicArray> debatchedResults ) { var all = new ArrayList(); @@ -419,12 +425,18 @@ private InferenceChunkedSparseEmbeddingResults mergeSparseResultsWithInputs( assert chunks.size() == all.size(); - var embeddingChunks = new ArrayList(); + var embeddingChunks = new ArrayList(); for (int i = 0; i < chunks.size(); i++) { - embeddingChunks.add(new MlChunkedTextExpansionResults.ChunkedResult(chunks.get(i), all.get(i).tokens())); + embeddingChunks.add( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk( + all.get(i).tokens(), + chunks.chunkText(i), + new ChunkedInference.TextOffset(chunks.offsets().get(i).start(), chunks.offsets().get(i).end()) + ) + ); } - return new InferenceChunkedSparseEmbeddingResults(embeddingChunks); + return new ChunkedInferenceEmbeddingSparse(embeddingChunks); } public record BatchRequest(List subBatches) { @@ -460,5 +472,13 @@ record ChunkOffsetsAndInput(List offsets, String input) { List toChunkText() { return offsets.stream().map(o -> input.substring(o.start(), o.end())).collect(Collectors.toList()); } + + int size() { + return offsets.size(); + } + + String chunkText(int index) { + return input.substring(offsets.get(index).start(), offsets.get(index).end()); + } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 0f26f6577860f..d651729dee259 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -13,7 +13,7 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -31,7 +31,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -70,7 +69,7 @@ public record SemanticTextField(String fieldName, List originalValues, I public record InferenceResult(String inferenceId, ModelSettings modelSettings, List chunks) {} - public record Chunk(String text, BytesReference rawEmbeddings) {} + record Chunk(String text, BytesReference rawEmbeddings) {} public record ModelSettings( TaskType taskType, @@ -307,13 +306,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } /** - * Converts the provided {@link ChunkedInferenceServiceResults} into a list of {@link Chunk}. + * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}. */ - public static List toSemanticTextFieldChunks(List results, XContentType contentType) { + public static List toSemanticTextFieldChunks(List results, XContentType contentType) throws IOException { List chunks = new ArrayList<>(); for (var result : results) { - for (Iterator it = result.chunksAsMatchedTextAndByteReference(contentType.xContent()); it - .hasNext();) { + for (var it = result.chunksAsMatchedTextAndByteReference(contentType.xContent()); it.hasNext();) { var chunkAsByteReference = it.next(); chunks.add(new Chunk(chunkAsByteReference.matchedText(), chunkAsByteReference.bytesReference())); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index f9890c62a749e..0139d14e89637 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -11,7 +11,7 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -74,7 +74,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { init(); chunkedInfer(model, null, input, taskSettings, inputType, timeout, listener); @@ -88,7 +88,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { init(); // a non-null query is not supported and is dropped by all providers @@ -110,7 +110,7 @@ protected abstract void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ); public void start(Model model, ActionListener listener) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 2637d9755bd55..64eba16492aec 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -289,7 +289,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof AlibabaCloudSearchModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 48b3c3df03e11..e9b646d7cbf39 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -16,7 +16,7 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -114,7 +114,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout); if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index b3d503de8e3eb..9983a9b7927bc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -220,7 +220,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { throw new UnsupportedOperationException("Anthropic service does not support chunked inference"); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index bba331fc0b5df..c3fb92a31ad03 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -107,7 +107,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) { var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 16c94dfa9ad94..1be29d0789545 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -261,7 +261,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof AzureOpenAiModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index b3d8b3b6efce3..12f2bd2a9c00d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -260,7 +260,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof CohereModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 8acef40840636..32175511500b0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -29,8 +29,8 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionCreator; @@ -101,7 +101,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { // Pass-through without actually performing chunking (result will have a single chunk per input) ActionListener inferListener = listener.delegateFailureAndWrap( @@ -253,14 +253,12 @@ public void checkModelConfig(Model model, ActionListener listener) { } } - private static List translateToChunkedResults( - InferenceInputs inputs, - InferenceServiceResults inferenceResults - ) { + private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - return InferenceChunkedSparseEmbeddingResults.listOf(DocumentsOnlyInput.of(inputs).getInputs(), sparseEmbeddingResults); + var inputsAsList = DocumentsOnlyInput.of(inputs).getInputs(); + return ChunkedInferenceEmbeddingSparse.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { - return List.of(new ErrorChunkedInferenceResults(error.getException())); + return List.of(new ChunkedInferenceError(error.getException())); } else { String expectedClass = Strings.format("%s", SparseEmbeddingResults.class.getSimpleName()); throw createInvalidChunkedResultException(expectedClass, inferenceResults.getWriteableName()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 2d60e7343f762..5c6edbdf60778 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -18,7 +18,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceResults; @@ -687,7 +687,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { chunkedInfer(model, null, input, taskSettings, inputType, timeout, listener); } @@ -700,7 +700,7 @@ public void chunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if ((TaskType.TEXT_EMBEDDING.equals(model.getTaskType()) || TaskType.SPARSE_EMBEDDING.equals(model.getTaskType())) == false) { listener.onFailure( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 57a8a66a3f3a6..15594676f6f41 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -315,7 +315,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 857d475499aae..dbb3cd4b54b4e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -213,7 +213,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model; var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 51cca72f26054..027dc5b8fe73c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -116,7 +116,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof HuggingFaceModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 75920efa251f2..1ed240d5d37af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -28,9 +28,9 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; @@ -43,12 +43,14 @@ import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import java.util.ArrayList; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; public class HuggingFaceElserService extends HuggingFaceBaseService { @@ -88,7 +90,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { ActionListener inferListener = listener.delegateFailureAndWrap( (delegate, response) -> delegate.onResponse(translateToChunkedResults(inputs, response)) @@ -98,16 +100,31 @@ protected void doChunkedInfer( doInfer(model, inputs, taskSettings, inputType, timeout, inferListener); } - private static List translateToChunkedResults( - DocumentsOnlyInput inputs, - InferenceServiceResults inferenceResults - ) { + private static List translateToChunkedResults(DocumentsOnlyInput inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) { - return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs.getInputs(), textEmbeddingResults); + validateInputSizeAgainstEmbeddings(inputs.getInputs(), textEmbeddingResults.embeddings().size()); + + var results = new ArrayList(inputs.getInputs().size()); + + for (int i = 0; i < inputs.getInputs().size(); i++) { + results.add( + new ChunkedInferenceEmbeddingFloat( + List.of( + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk( + textEmbeddingResults.embeddings().get(i).values(), + inputs.getInputs().get(i), + new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length()) + ) + ) + ) + ); + } + return results; } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - return InferenceChunkedSparseEmbeddingResults.listOf(inputs.getInputs(), sparseEmbeddingResults); + var inputsAsList = DocumentsOnlyInput.of(inputs).getInputs(); + return ChunkedInferenceEmbeddingSparse.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { - return List.of(new ErrorChunkedInferenceResults(error.getException())); + return List.of(new ChunkedInferenceError(error.getException())); } else { String expectedClasses = Strings.format( "One of [%s,%s]", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 981a3e95808ef..8bb1a1837de0d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -283,7 +283,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { IbmWatsonxModel ibmWatsonxModel = (IbmWatsonxModel) model; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index fe0edb851902b..5bd3cd0b56c4f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -95,7 +95,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { var actionCreator = new MistralActionCreator(getSender(), getServiceComponents()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 20ff1c617d21f..ec673164a260b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -264,7 +264,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { if (model instanceof OpenAiModel == false) { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index c68a629b999c5..0b7d136ffb04c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -21,7 +21,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.Model; @@ -34,8 +34,8 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; @@ -57,9 +57,9 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSparseEmbeddings; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.semanticTextFieldFromChunkedInferenceResults; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult; import static org.hamcrest.Matchers.containsString; @@ -160,8 +160,8 @@ public void testItemFailures() throws Exception { Map.of(model.getInferenceEntityId(), model), randomIntBetween(1, 10) ); - model.putResult("I am a failure", new ErrorChunkedInferenceResults(new IllegalArgumentException("boom"))); - model.putResult("I am a success", randomSparseEmbeddings(List.of("I am a success"))); + model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); + model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success"))); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -290,10 +290,9 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool Answer chunkedInferAnswer = invocationOnMock -> { StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; List inputs = (List) invocationOnMock.getArguments()[2]; - ActionListener> listener = (ActionListener< - List>) invocationOnMock.getArguments()[6]; + ActionListener> listener = (ActionListener>) invocationOnMock.getArguments()[6]; Runnable runnable = () -> { - List results = new ArrayList<>(); + List results = new ArrayList<>(); for (String input : inputs) { results.add(model.getResults(input)); } @@ -348,7 +347,7 @@ private static BulkItemRequest[] randomBulkItemRequest( // This prevents a situation where embeddings in the expected docMap do not match those in the model, which could happen if // embeddings were overwritten. if (model.hasResult(inputText)) { - ChunkedInferenceServiceResults results = model.getResults(inputText); + var results = model.getResults(inputText); semanticTextField = semanticTextFieldFromChunkedInferenceResults( field, model, @@ -371,7 +370,7 @@ private static BulkItemRequest[] randomBulkItemRequest( } private static class StaticModel extends TestModel { - private final Map resultMap; + private final Map resultMap; StaticModel( String inferenceEntityId, @@ -397,11 +396,11 @@ public static StaticModel createRandomInstance() { ); } - ChunkedInferenceServiceResults getResults(String text) { - return resultMap.getOrDefault(text, new InferenceChunkedSparseEmbeddingResults(List.of())); + ChunkedInference getResults(String text) { + return resultMap.getOrDefault(text, new ChunkedInferenceEmbeddingSparse(List.of())); } - void putResult(String text, ChunkedInferenceServiceResults result) { + void putResult(String text, ChunkedInference result) { resultMap.put(text, result); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index dec7d15760aa6..03249163c7f82 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -8,12 +8,12 @@ package org.elasticsearch.xpack.inference.chunking; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingByte; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; @@ -313,16 +313,16 @@ public void testMergingListener_Float() { assertThat(finalListener.results, hasSize(4)); { var chunkedResult = finalListener.results.get(0); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (InferenceChunkedTextEmbeddingFloatResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var chunkedFloatResult = (ChunkedInferenceEmbeddingFloat) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); assertEquals("1st small", chunkedFloatResult.chunks().get(0).matchedText()); } { // this is the large input split in multiple chunks var chunkedResult = finalListener.results.get(1); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (InferenceChunkedTextEmbeddingFloatResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var chunkedFloatResult = (ChunkedInferenceEmbeddingFloat) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(6)); assertThat(chunkedFloatResult.chunks().get(0).matchedText(), startsWith("passage_input0 ")); assertThat(chunkedFloatResult.chunks().get(1).matchedText(), startsWith(" passage_input20 ")); @@ -333,15 +333,15 @@ public void testMergingListener_Float() { } { var chunkedResult = finalListener.results.get(2); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (InferenceChunkedTextEmbeddingFloatResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var chunkedFloatResult = (ChunkedInferenceEmbeddingFloat) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); assertEquals("2nd small", chunkedFloatResult.chunks().get(0).matchedText()); } { var chunkedResult = finalListener.results.get(3); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var chunkedFloatResult = (InferenceChunkedTextEmbeddingFloatResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var chunkedFloatResult = (ChunkedInferenceEmbeddingFloat) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); assertEquals("3rd small", chunkedFloatResult.chunks().get(0).matchedText()); } @@ -386,16 +386,16 @@ public void testMergingListener_Byte() { assertThat(finalListener.results, hasSize(4)); { var chunkedResult = finalListener.results.get(0); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var chunkedByteResult = (InferenceChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); + var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); assertEquals("1st small", chunkedByteResult.chunks().get(0).matchedText()); } { // this is the large input split in multiple chunks var chunkedResult = finalListener.results.get(1); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var chunkedByteResult = (InferenceChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); + var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); assertThat(chunkedByteResult.chunks().get(0).matchedText(), startsWith("passage_input0 ")); assertThat(chunkedByteResult.chunks().get(1).matchedText(), startsWith(" passage_input20 ")); @@ -406,15 +406,15 @@ public void testMergingListener_Byte() { } { var chunkedResult = finalListener.results.get(2); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var chunkedByteResult = (InferenceChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); + var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); assertEquals("2nd small", chunkedByteResult.chunks().get(0).matchedText()); } { var chunkedResult = finalListener.results.get(3); - assertThat(chunkedResult, instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var chunkedByteResult = (InferenceChunkedTextEmbeddingByteResults) chunkedResult; + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingByte.class)); + var chunkedByteResult = (ChunkedInferenceEmbeddingByte) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); assertEquals("3rd small", chunkedByteResult.chunks().get(0).matchedText()); } @@ -466,34 +466,34 @@ public void testMergingListener_Sparse() { assertThat(finalListener.results, hasSize(4)); { var chunkedResult = finalListener.results.get(0); - assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; - assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); - assertEquals("1st small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var chunkedSparseResult = (ChunkedInferenceEmbeddingSparse) chunkedResult; + assertThat(chunkedSparseResult.chunks(), hasSize(1)); + assertEquals("1st small", chunkedSparseResult.chunks().get(0).matchedText()); } { var chunkedResult = finalListener.results.get(1); - assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; - assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); - assertEquals("2nd small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var chunkedSparseResult = (ChunkedInferenceEmbeddingSparse) chunkedResult; + assertThat(chunkedSparseResult.chunks(), hasSize(1)); + assertEquals("2nd small", chunkedSparseResult.chunks().get(0).matchedText()); } { var chunkedResult = finalListener.results.get(2); - assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; - assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); - assertEquals("3rd small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var chunkedSparseResult = (ChunkedInferenceEmbeddingSparse) chunkedResult; + assertThat(chunkedSparseResult.chunks(), hasSize(1)); + assertEquals("3rd small", chunkedSparseResult.chunks().get(0).matchedText()); } { // this is the large input split in multiple chunks var chunkedResult = finalListener.results.get(3); - assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; - assertThat(chunkedSparseResult.getChunkedResults(), hasSize(9)); // passage is split into 9 chunks, 10 words each - assertThat(chunkedSparseResult.getChunkedResults().get(0).matchedText(), startsWith("passage_input0 ")); - assertThat(chunkedSparseResult.getChunkedResults().get(1).matchedText(), startsWith(" passage_input10 ")); - assertThat(chunkedSparseResult.getChunkedResults().get(8).matchedText(), startsWith(" passage_input80 ")); + assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var chunkedSparseResult = (ChunkedInferenceEmbeddingSparse) chunkedResult; + assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each + assertThat(chunkedSparseResult.chunks().get(0).matchedText(), startsWith("passage_input0 ")); + assertThat(chunkedSparseResult.chunks().get(1).matchedText(), startsWith(" passage_input10 ")); + assertThat(chunkedSparseResult.chunks().get(8).matchedText(), startsWith(" passage_input80 ")); } } @@ -501,13 +501,13 @@ public void testListenerErrorsWithWrongNumberOfResponses() { List inputs = List.of("1st small", "2nd small", "3rd small"); var failureMessage = new AtomicReference(); - var listener = new ActionListener>() { + var listener = new ActionListener>() { @Override - public void onResponse(List chunkedInferenceServiceResults) { - assertThat(chunkedInferenceServiceResults.get(0), instanceOf(ErrorChunkedInferenceResults.class)); - var error = (ErrorChunkedInferenceResults) chunkedInferenceServiceResults.get(0); - failureMessage.set(error.getException().getMessage()); + public void onResponse(List chunkedResults) { + assertThat(chunkedResults.get(0), instanceOf(ChunkedInferenceError.class)); + var error = (ChunkedInferenceError) chunkedResults.get(0); + failureMessage.set(error.exception().getMessage()); } @Override @@ -531,12 +531,12 @@ private ChunkedResultsListener testListener() { return new ChunkedResultsListener(); } - private static class ChunkedResultsListener implements ActionListener> { - List results; + private static class ChunkedResultsListener implements ActionListener> { + List results; @Override - public void onResponse(List chunkedInferenceServiceResults) { - this.results = chunkedInferenceServiceResults; + public void onResponse(List chunks) { + this.results = chunks; } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 563093930c358..dcdd9b3d42341 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -11,7 +11,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -19,9 +19,8 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.core.utils.FloatConversionUtils; import org.elasticsearch.xpack.inference.model.TestModel; @@ -158,38 +157,39 @@ public void testModelSettingsValidation() { assertThat(ex.getMessage(), containsString("required [element_type] field is missing")); } - public static InferenceChunkedTextEmbeddingFloatResults randomInferenceChunkedTextEmbeddingFloatResults( - Model model, - List inputs - ) throws IOException { - List chunks = new ArrayList<>(); + public static ChunkedInferenceEmbeddingFloat randomChunkedInferenceEmbeddingFloat(Model model, List inputs) { + List chunks = new ArrayList<>(); for (String input : inputs) { float[] values = new float[model.getServiceSettings().dimensions()]; for (int j = 0; j < values.length; j++) { values[j] = (float) randomDouble(); } - chunks.add(new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk(input, values)); + chunks.add( + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk(values, input, new ChunkedInference.TextOffset(0, input.length())) + ); } - return new InferenceChunkedTextEmbeddingFloatResults(chunks); + return new ChunkedInferenceEmbeddingFloat(chunks); } - public static InferenceChunkedSparseEmbeddingResults randomSparseEmbeddings(List inputs) { - List chunks = new ArrayList<>(); + public static ChunkedInferenceEmbeddingSparse randomChunkedInferenceEmbeddingSparse(List inputs) { + List chunks = new ArrayList<>(); for (String input : inputs) { var tokens = new ArrayList(); for (var token : input.split("\\s+")) { tokens.add(new WeightedToken(token, randomFloat())); } - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(input, tokens)); + chunks.add( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk(tokens, input, new ChunkedInference.TextOffset(0, input.length())) + ); } - return new InferenceChunkedSparseEmbeddingResults(chunks); + return new ChunkedInferenceEmbeddingSparse(chunks); } public static SemanticTextField randomSemanticText(String fieldName, Model model, List inputs, XContentType contentType) throws IOException { - ChunkedInferenceServiceResults results = switch (model.getTaskType()) { - case TEXT_EMBEDDING -> randomInferenceChunkedTextEmbeddingFloatResults(model, inputs); - case SPARSE_EMBEDDING -> randomSparseEmbeddings(inputs); + ChunkedInference results = switch (model.getTaskType()) { + case TEXT_EMBEDDING -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); }; return semanticTextFieldFromChunkedInferenceResults(fieldName, model, inputs, results, contentType); @@ -199,9 +199,9 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( String fieldName, Model model, List inputs, - ChunkedInferenceServiceResults results, + ChunkedInference results, XContentType contentType - ) { + ) throws IOException { return new SemanticTextField( fieldName, inputs, @@ -232,18 +232,24 @@ public static Object randomSemanticTextInput() { } } - public static ChunkedInferenceServiceResults toChunkedResult(SemanticTextField field) throws IOException { + public static ChunkedInference toChunkedResult(SemanticTextField field) throws IOException { switch (field.inference().modelSettings().taskType()) { case SPARSE_EMBEDDING -> { - List chunks = new ArrayList<>(); + List chunks = new ArrayList<>(); for (var chunk : field.inference().chunks()) { var tokens = parseWeightedTokens(chunk.rawEmbeddings(), field.contentType()); - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(chunk.text(), tokens)); + chunks.add( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk( + tokens, + chunk.text(), + new ChunkedInference.TextOffset(0, chunk.text().length()) + ) + ); } - return new InferenceChunkedSparseEmbeddingResults(chunks); + return new ChunkedInferenceEmbeddingSparse(chunks); } case TEXT_EMBEDDING -> { - List chunks = new ArrayList<>(); + List chunks = new ArrayList<>(); for (var chunk : field.inference().chunks()) { double[] values = parseDenseVector( chunk.rawEmbeddings(), @@ -251,13 +257,14 @@ public static ChunkedInferenceServiceResults toChunkedResult(SemanticTextField f field.contentType() ); chunks.add( - new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk( + new ChunkedInferenceEmbeddingFloat.FloatEmbeddingChunk( + FloatConversionUtils.floatArrayOf(values), chunk.text(), - FloatConversionUtils.floatArrayOf(values) + new ChunkedInference.TextOffset(0, chunk.text().length()) ) ); } - return new InferenceChunkedTextEmbeddingFloatResults(chunks); + return new ChunkedInferenceEmbeddingFloat(chunks); } default -> throw new AssertionError("Invalid task_type: " + field.inference().modelSettings().taskType().name()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ErrorChunkedInferenceResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ErrorChunkedInferenceResultsTests.java deleted file mode 100644 index 4be00ea9e5822..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ErrorChunkedInferenceResultsTests.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.ElasticsearchTimeoutException; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; - -import java.io.IOException; - -public class ErrorChunkedInferenceResultsTests extends AbstractWireSerializingTestCase { - - public static ErrorChunkedInferenceResults createRandomResults() { - return new ErrorChunkedInferenceResults( - randomBoolean() - ? new ElasticsearchTimeoutException(randomAlphaOfLengthBetween(10, 50)) - : new ElasticsearchStatusException(randomAlphaOfLengthBetween(10, 50), randomFrom(RestStatus.values())) - ); - } - - @Override - protected Writeable.Reader instanceReader() { - return ErrorChunkedInferenceResults::new; - } - - @Override - protected ErrorChunkedInferenceResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected ErrorChunkedInferenceResults mutateInstance(ErrorChunkedInferenceResults instance) throws IOException { - return new ErrorChunkedInferenceResults(new RuntimeException(randomAlphaOfLengthBetween(10, 50))); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedSparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedSparseEmbeddingResultsTests.java deleted file mode 100644 index 8685ad9f0e124..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedSparseEmbeddingResultsTests.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -import static org.hamcrest.Matchers.is; - -public class InferenceChunkedSparseEmbeddingResultsTests extends AbstractWireSerializingTestCase { - - public static InferenceChunkedSparseEmbeddingResults createRandomResults() { - var chunks = new ArrayList(); - int numChunks = randomIntBetween(1, 5); - - for (int i = 0; i < numChunks; i++) { - var tokenWeights = new ArrayList(); - int numTokens = randomIntBetween(1, 8); - for (int j = 0; j < numTokens; j++) { - tokenWeights.add(new WeightedToken(Integer.toString(j), (float) randomDoubleBetween(0.0, 5.0, false))); - } - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(randomAlphaOfLength(6), tokenWeights)); - } - - return new InferenceChunkedSparseEmbeddingResults(chunks); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk() { - var entity = new InferenceChunkedSparseEmbeddingResults( - List.of(new MlChunkedTextExpansionResults.ChunkedResult("text", List.of(new WeightedToken("token", 0.1f)))) - ); - - assertThat( - entity.asMap(), - is( - Map.of( - InferenceChunkedSparseEmbeddingResults.FIELD_NAME, - List.of(Map.of(ChunkedNlpInferenceResults.TEXT, "text", ChunkedNlpInferenceResults.INFERENCE, Map.of("token", 0.1f))) - ) - ) - ); - - String xContentResult = Strings.toString(entity, true, true); - assertThat(xContentResult, is(""" - { - "sparse_embedding_chunk" : [ - { - "text" : "text", - "inference" : { - "token" : 0.1 - } - } - ] - }""")); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk_FromSparseEmbeddingResults() { - var entity = InferenceChunkedSparseEmbeddingResults.listOf( - List.of("text"), - new SparseEmbeddingResults(List.of(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token", 0.1f)), false))) - ); - - assertThat(entity.size(), is(1)); - - var firstEntry = entity.get(0); - - assertThat( - firstEntry.asMap(), - is( - Map.of( - InferenceChunkedSparseEmbeddingResults.FIELD_NAME, - List.of(Map.of(ChunkedNlpInferenceResults.TEXT, "text", ChunkedNlpInferenceResults.INFERENCE, Map.of("token", 0.1f))) - ) - ) - ); - - String xContentResult = Strings.toString(firstEntry, true, true); - assertThat(xContentResult, is(""" - { - "sparse_embedding_chunk" : [ - { - "text" : "text", - "inference" : { - "token" : 0.1 - } - } - ] - }""")); - } - - public void testToXContent_ThrowsWhenInputSizeIsDifferentThanEmbeddings() { - var exception = expectThrows( - IllegalArgumentException.class, - () -> InferenceChunkedSparseEmbeddingResults.listOf( - List.of("text", "text2"), - new SparseEmbeddingResults(List.of(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token", 0.1f)), false))) - ) - ); - - assertThat(exception.getMessage(), is("The number of inputs [2] does not match the embeddings [1]")); - } - - @Override - protected Writeable.Reader instanceReader() { - return InferenceChunkedSparseEmbeddingResults::new; - } - - @Override - protected InferenceChunkedSparseEmbeddingResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected InferenceChunkedSparseEmbeddingResults mutateInstance(InferenceChunkedSparseEmbeddingResults instance) throws IOException { - return randomValueOtherThan(instance, InferenceChunkedSparseEmbeddingResultsTests::createRandomResults); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedTextEmbeddingByteResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedTextEmbeddingByteResultsTests.java deleted file mode 100644 index c1215e8a3d71b..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedTextEmbeddingByteResultsTests.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -import static org.hamcrest.Matchers.is; - -public class InferenceChunkedTextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase< - InferenceChunkedTextEmbeddingByteResults> { - - public static InferenceChunkedTextEmbeddingByteResults createRandomResults() { - int numChunks = randomIntBetween(1, 5); - var chunks = new ArrayList(numChunks); - - for (int i = 0; i < numChunks; i++) { - chunks.add(createRandomChunk()); - } - - return new InferenceChunkedTextEmbeddingByteResults(chunks, randomBoolean()); - } - - private static InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk createRandomChunk() { - int columns = randomIntBetween(1, 10); - byte[] bytes = new byte[columns]; - for (int i = 0; i < columns; i++) { - bytes[i] = randomByte(); - } - - return new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk(randomAlphaOfLength(6), bytes); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk() { - var entity = new InferenceChunkedTextEmbeddingByteResults( - List.of(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk("text", new byte[] { (byte) 1 })), - false - ); - - assertThat( - entity.asMap(), - is( - Map.of( - InferenceChunkedTextEmbeddingByteResults.FIELD_NAME, - List.of(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk("text", new byte[] { (byte) 1 })) - ) - ) - ); - String xContentResult = Strings.toString(entity, true, true); - assertThat(xContentResult, is(""" - { - "text_embedding_byte_chunk" : [ - { - "text" : "text", - "inference" : [ - 1 - ] - } - ] - }""")); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk_ForTextEmbeddingByteResults() { - var entity = InferenceChunkedTextEmbeddingByteResults.listOf( - List.of("text"), - new InferenceTextEmbeddingByteResults( - List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 1 })) - ) - ); - - assertThat(entity.size(), is(1)); - - var firstEntry = entity.get(0); - - assertThat( - firstEntry.asMap(), - is( - Map.of( - InferenceChunkedTextEmbeddingByteResults.FIELD_NAME, - List.of(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk("text", new byte[] { (byte) 1 })) - ) - ) - ); - String xContentResult = Strings.toString(firstEntry, true, true); - assertThat(xContentResult, is(""" - { - "text_embedding_byte_chunk" : [ - { - "text" : "text", - "inference" : [ - 1 - ] - } - ] - }""")); - } - - public void testToXContent_ThrowsWhenInputSizeIsDifferentThanEmbeddings() { - var exception = expectThrows( - IllegalArgumentException.class, - () -> InferenceChunkedTextEmbeddingByteResults.listOf( - List.of("text", "text2"), - new InferenceTextEmbeddingByteResults( - List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 1 })) - ) - ) - ); - - assertThat(exception.getMessage(), is("The number of inputs [2] does not match the embeddings [1]")); - } - - @Override - protected Writeable.Reader instanceReader() { - return InferenceChunkedTextEmbeddingByteResults::new; - } - - @Override - protected InferenceChunkedTextEmbeddingByteResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected InferenceChunkedTextEmbeddingByteResults mutateInstance(InferenceChunkedTextEmbeddingByteResults instance) - throws IOException { - return randomValueOtherThan(instance, InferenceChunkedTextEmbeddingByteResultsTests::createRandomResults); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 47a96bf78dda1..06aaad4e73c57 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -11,7 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -126,7 +126,7 @@ protected void doChunkedInfer( Map taskSettings, InputType inputType, TimeValue timeout, - ActionListener> listener + ActionListener> listener ) { } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index a154ded395822..46c3a062f7db0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -29,8 +29,8 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -392,7 +392,7 @@ public void testChunkedInfer_InvalidTaskType() throws IOException { null ); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); try { service.chunkedInfer( model, @@ -417,7 +417,7 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { var model = createModelForTaskType(taskType, chunkingSettings); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener); var results = listener.actionGet(TIMEOUT); @@ -425,9 +425,9 @@ private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettin assertThat(results, hasSize(2)); var firstResult = results.get(0); if (TaskType.TEXT_EMBEDDING.equals(taskType)) { - assertThat(firstResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); + assertThat(firstResult, instanceOf(ChunkedInferenceEmbeddingFloat.class)); } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { - assertThat(firstResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + assertThat(firstResult, instanceOf(ChunkedInferenceEmbeddingSparse.class)); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 197606df02a1f..80c2b672a8feb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -35,7 +35,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender; @@ -1551,7 +1551,7 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep requestSender.enqueue(mockResults2); } - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("abc", "xyz"), @@ -1564,15 +1564,15 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("abc", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.123F, 0.678F }, floatResult.chunks().get(0).embedding(), 0.0f); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("xyz", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.223F, 0.278F }, floatResult.chunks().get(0).embedding(), 0.0f); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 08fc097a56f40..a9ef4bd551175 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -36,7 +36,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -1186,7 +1186,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("foo", "bar"), @@ -1199,15 +1199,15 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("foo", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding(), 0.0f); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("bar", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 1.0123f, -1.0123f }, floatResult.chunks().get(0).embedding(), 0.0f); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index cc68d54b11e91..ac8e769ef13a3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -35,7 +35,7 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -1335,7 +1335,7 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); model.setUri(new URI(getUrl(webServer))); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("foo", "bar"), @@ -1348,15 +1348,15 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("foo", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("bar", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 1.123f, -1.123f }, floatResult.chunks().get(0).embedding(), 0.0f); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index a8d1a1ec28d09..e207bcfdeada5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -36,8 +36,8 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingByte; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -1442,7 +1442,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); // 2 inputs service.chunkedInfer( model, @@ -1456,15 +1456,15 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("foo", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("bar", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding(), 0.0f); @@ -1532,7 +1532,7 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { "model", CohereEmbeddingType.BYTE ); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); // 2 inputs service.chunkedInfer( model, @@ -1546,15 +1546,15 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingByteResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingByte.class)); + var floatResult = (ChunkedInferenceEmbeddingByte) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("foo", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new byte[] { 23, -23 }, floatResult.chunks().get(0).embedding()); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingByteResults.class)); - var byteResult = (InferenceChunkedTextEmbeddingByteResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingByte.class)); + var byteResult = (ChunkedInferenceEmbeddingByte) results.get(1); assertThat(byteResult.chunks(), hasSize(1)); assertEquals("bar", byteResult.chunks().get(0).matchedText()); assertArrayEquals(new byte[] { 24, -24 }, byteResult.chunks().get(0).embedding()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index dae99cea77ec4..11dc7206d959a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -16,7 +16,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -31,8 +31,8 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -453,7 +453,7 @@ public void testChunkedInfer_PassesThrough() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(eisGatewayUrl); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("input text"), @@ -464,22 +464,21 @@ public void testChunkedInfer_PassesThrough() throws IOException { ); var results = listener.actionGet(TIMEOUT); - MatcherAssert.assertThat( - results.get(0).asMap(), - Matchers.is( - Map.of( - InferenceChunkedSparseEmbeddingResults.FIELD_NAME, - List.of( - Map.of( - ChunkedNlpInferenceResults.TEXT, - "input text", - ChunkedNlpInferenceResults.INFERENCE, - Map.of("hello", 2.1259406f, "greet", 1.7073475f) - ) + assertThat(results.get(0), instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var sparseResult = (ChunkedInferenceEmbeddingSparse) results.get(0); + assertThat( + sparseResult.chunks(), + is( + List.of( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk( + List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), + "input text", + new ChunkedInference.TextOffset(0, "input text".length()) ) ) ) ); + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); MatcherAssert.assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 17e6583f11c8f..21d7efbc7b03c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -24,7 +24,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceResults; @@ -42,9 +42,9 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; @@ -865,26 +865,26 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru var service = createService(client); var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { + var resultsListener = ActionListener.>wrap(chunkedResponse -> { assertThat(chunkedResponse, hasSize(2)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var result1 = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0); + assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var result1 = (ChunkedInferenceEmbeddingFloat) chunkedResponse.get(0); assertThat(result1.chunks(), hasSize(1)); assertArrayEquals( ((MlTextEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(), - result1.getChunks().get(0).embedding(), + result1.chunks().get(0).embedding(), 0.0001f ); - assertEquals("foo", result1.getChunks().get(0).matchedText()); - assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var result2 = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(1); + assertEquals("foo", result1.chunks().get(0).matchedText()); + assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var result2 = (ChunkedInferenceEmbeddingFloat) chunkedResponse.get(1); assertThat(result2.chunks(), hasSize(1)); assertArrayEquals( ((MlTextEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(), - result2.getChunks().get(0).embedding(), + result2.chunks().get(0).embedding(), 0.0001f ); - assertEquals("bar", result2.getChunks().get(0).matchedText()); + assertEquals("bar", result2.chunks().get(0).matchedText()); gotResults.set(true); }, ESTestCase::fail); @@ -940,22 +940,22 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws Int var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { + var resultsListener = ActionListener.>wrap(chunkedResponse -> { assertThat(chunkedResponse, hasSize(2)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); + assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var result1 = (ChunkedInferenceEmbeddingSparse) chunkedResponse.get(0); assertEquals( ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(), - result1.getChunkedResults().get(0).weightedTokens() + result1.chunks().get(0).weightedTokens() ); - assertEquals("foo", result1.getChunkedResults().get(0).matchedText()); - assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); + assertEquals("foo", result1.chunks().get(0).matchedText()); + assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var result2 = (ChunkedInferenceEmbeddingSparse) chunkedResponse.get(1); assertEquals( ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(), - result2.getChunkedResults().get(0).weightedTokens() + result2.chunks().get(0).weightedTokens() ); - assertEquals("bar", result2.getChunkedResults().get(0).matchedText()); + assertEquals("bar", result2.chunks().get(0).matchedText()); gotResults.set(true); }, ESTestCase::fail); @@ -1010,22 +1010,22 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws Inte var service = createService(client); var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { + var resultsListener = ActionListener.>wrap(chunkedResponse -> { assertThat(chunkedResponse, hasSize(2)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); + assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var result1 = (ChunkedInferenceEmbeddingSparse) chunkedResponse.get(0); assertEquals( ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(), - result1.getChunkedResults().get(0).weightedTokens() + result1.chunks().get(0).weightedTokens() ); - assertEquals("foo", result1.getChunkedResults().get(0).matchedText()); - assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); + assertEquals("foo", result1.chunks().get(0).matchedText()); + assertThat(chunkedResponse.get(1), instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var result2 = (ChunkedInferenceEmbeddingSparse) chunkedResponse.get(1); assertEquals( ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(), - result2.getChunkedResults().get(0).weightedTokens() + result2.chunks().get(0).weightedTokens() ); - assertEquals("bar", result2.getChunkedResults().get(0).matchedText()); + assertEquals("bar", result2.chunks().get(0).matchedText()); gotResults.set(true); }, ESTestCase::fail); @@ -1126,12 +1126,12 @@ public void testChunkInfer_FailsBatch() throws InterruptedException { var service = createService(client); var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { + var resultsListener = ActionListener.>wrap(chunkedResponse -> { assertThat(chunkedResponse, hasSize(3)); // a single failure fails the batch for (var er : chunkedResponse) { - assertThat(er, instanceOf(ErrorChunkedInferenceResults.class)); - assertEquals("boom", ((ErrorChunkedInferenceResults) er).getException().getMessage()); + assertThat(er, instanceOf(ChunkedInferenceError.class)); + assertEquals("boom", ((ChunkedInferenceError) er).exception().getMessage()); } gotResults.set(true); @@ -1190,10 +1190,10 @@ public void testChunkingLargeDocument() throws InterruptedException { var service = createService(client); var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { + var resultsListener = ActionListener.>wrap(chunkedResponse -> { assertThat(chunkedResponse, hasSize(1)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var sparseResults = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0); + assertThat(chunkedResponse.get(0), instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var sparseResults = (ChunkedInferenceEmbeddingFloat) chunkedResponse.get(0); assertThat(sparseResults.chunks(), hasSize(numChunks)); gotResults.set(true); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 0e2f4847c88ee..ea82c09eef1e8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -18,7 +18,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -36,7 +36,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -868,7 +868,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer(model, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener); var results = listener.actionGet(TIMEOUT); @@ -876,8 +876,8 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed // first result { - assertThat(results.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(input.get(0), floatResult.chunks().get(0).matchedText()); assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding())); @@ -885,8 +885,8 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed // second result { - assertThat(results.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(input.get(1), floatResult.chunks().get(0).matchedText()); assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index c70692eb29a27..64f86a0d0f280 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -24,14 +24,13 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; -import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -39,7 +38,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.concurrent.TimeUnit; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; @@ -49,6 +47,8 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.mockito.Mockito.mock; @@ -90,7 +90,7 @@ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOE webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = HuggingFaceElserModelTests.createModel(getUrl(webServer), "secret"); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("abc"), @@ -101,14 +101,16 @@ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOE ); var result = listener.actionGet(TIMEOUT).get(0); - - MatcherAssert.assertThat( - result.asMap(), - Matchers.is( - Map.of( - InferenceChunkedSparseEmbeddingResults.FIELD_NAME, - List.of( - Map.of(ChunkedNlpInferenceResults.TEXT, "abc", ChunkedNlpInferenceResults.INFERENCE, Map.of(".", 0.13315596f)) + assertThat(result, instanceOf(ChunkedInferenceEmbeddingSparse.class)); + var sparseResult = (ChunkedInferenceEmbeddingSparse) result; + assertThat( + sparseResult.chunks(), + is( + List.of( + new ChunkedInferenceEmbeddingSparse.SparseEmbeddingChunk( + List.of(new WeightedToken(".", 0.13315596f)), + "abc", + new ChunkedInference.TextOffset(0, "abc".length()) ) ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 3e5e2d7c12074..f3d7cbfea38dc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -18,7 +18,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -34,8 +34,7 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -46,7 +45,6 @@ import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests; import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -59,7 +57,6 @@ import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; -import static org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResultsTests.asMapWithListsInsteadOfArrays; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; @@ -774,7 +771,7 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret"); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("abc"), @@ -785,19 +782,12 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th ); var result = listener.actionGet(TIMEOUT).get(0); - assertThat(result, CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - - MatcherAssert.assertThat( - asMapWithListsInsteadOfArrays((InferenceChunkedTextEmbeddingFloatResults) result), - Matchers.is( - Map.of( - InferenceChunkedTextEmbeddingFloatResults.FIELD_NAME, - List.of( - Map.of(ChunkedNlpInferenceResults.TEXT, "abc", ChunkedNlpInferenceResults.INFERENCE, List.of(-0.0123f, 0.0123f)) - ) - ) - ) - ); + assertThat(result, CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var embeddingResult = (ChunkedInferenceEmbeddingFloat) result; + assertThat(embeddingResult.chunks(), hasSize(1)); + assertThat(embeddingResult.chunks().get(0).matchedText(), is("abc")); + assertThat(embeddingResult.chunks().get(0).offset(), is(new ChunkedInference.TextOffset(0, "abc".length()))); + assertArrayEquals(new float[] { -0.0123f, 0.0123f }, embeddingResult.chunks().get(0).embedding(), 0.001f); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat( @@ -828,7 +818,7 @@ public void testChunkedInfer() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret"); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("abc"), @@ -841,8 +831,8 @@ public void testChunkedInfer() throws IOException { var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(1)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("abc", floatResult.chunks().get(0).matchedText()); assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 5aa826f1d80fe..3d298823ea19f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -18,7 +18,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -35,7 +35,7 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.action.ibmwatsonx.IbmWatsonxActionCreator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -684,7 +684,7 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws apiKey, getUrl(webServer) ); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer(model, null, input, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, listener); var results = listener.actionGet(TIMEOUT); @@ -692,8 +692,8 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws // first result { - assertThat(results.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(input.get(0), floatResult.chunks().get(0).matchedText()); assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding())); @@ -701,8 +701,8 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws // second result { - assertThat(results.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals(input.get(1), floatResult.chunks().get(0).matchedText()); assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index e0cfa4a5ca4be..c547531ec1289 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -18,7 +18,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -34,7 +34,7 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.inference.ModelConfigurationsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -665,7 +665,7 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("abc", "def"), @@ -679,14 +679,14 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertTrue(Arrays.equals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding())); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertTrue(Arrays.equals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding())); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 4672bc28b2bf0..915015d43ba2d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -34,7 +34,7 @@ import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -1550,7 +1550,7 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - PlainActionFuture> listener = new PlainActionFuture<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, List.of("foo", "bar"), @@ -1563,15 +1563,15 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { var results = listener.actionGet(TIMEOUT); assertThat(results, hasSize(2)); { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("foo", floatResult.chunks().get(0).matchedText()); assertTrue(Arrays.equals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding())); } { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbeddingFloat.class)); + var floatResult = (ChunkedInferenceEmbeddingFloat) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); assertEquals("bar", floatResult.chunks().get(0).matchedText()); assertTrue(Arrays.equals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding()));