diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkedInferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/ChunkedInferenceServiceResults.java deleted file mode 100644 index 10e00e9860200..0000000000000 --- a/server/src/main/java/org/elasticsearch/inference/ChunkedInferenceServiceResults.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.inference; - -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.xcontent.XContent; - -import java.util.Iterator; - -public interface ChunkedInferenceServiceResults extends InferenceServiceResults { - - /** - * Implementations of this function serialize their embeddings to {@link BytesReference} for storage in semantic text fields. - * The iterator iterates over all the chunks stored in the {@link ChunkedInferenceServiceResults}. - * - * @param xcontent provided by the SemanticTextField - * @return an iterator of the serialized {@link Chunk} which includes the matched text (input) and bytes reference (output/embedding). - */ - Iterator chunksAsMatchedTextAndByteReference(XContent xcontent); - - /** - * A chunk of inference results containing matched text and the bytes reference. - * @param matchedText - * @param bytesReference - */ - record Chunk(String matchedText, BytesReference bytesReference) {} -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ErrorChunkedInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ErrorChunkedInferenceResults.java deleted file mode 100644 index 18f88a8ff022a..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ErrorChunkedInferenceResults.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContent; - -import java.io.IOException; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Stream; - -public class ErrorChunkedInferenceResults implements ChunkedInferenceServiceResults { - - public static final String NAME = "error_chunked"; - - private final Exception exception; - - public ErrorChunkedInferenceResults(Exception exception) { - this.exception = Objects.requireNonNull(exception); - } - - public ErrorChunkedInferenceResults(StreamInput in) throws IOException { - this.exception = in.readException(); - } - - public Exception getException() { - return exception; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeException(exception); - } - - @Override - public boolean equals(Object object) { - if (object == this) { - return true; - } - if (object == null || getClass() != object.getClass()) { - return false; - } - ErrorChunkedInferenceResults that = (ErrorChunkedInferenceResults) object; - // Just compare the message for serialization test purposes - return Objects.equals(exception.getMessage(), that.exception.getMessage()); - } - - @Override - public int hashCode() { - // Just compare the message for serialization test purposes - return Objects.hash(exception.getMessage()); - } - - @Override - public List transformToCoordinationFormat() { - return null; - } - - @Override - public List transformToLegacyFormat() { - return null; - } - - @Override - public Map asMap() { - Map asMap = new LinkedHashMap<>(); - asMap.put(NAME, exception.getMessage()); - return asMap; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContentHelper.field(NAME, exception.getMessage()); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { - return Stream.of(exception).map(e -> new Chunk(e.getMessage(), BytesArray.EMPTY)).iterator(); - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedSparseEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedSparseEmbeddingResults.java deleted file mode 100644 index c961050acefdb..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedSparseEmbeddingResults.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; - -public class InferenceChunkedSparseEmbeddingResults implements ChunkedInferenceServiceResults { - - public static final String NAME = "chunked_sparse_embedding_results"; - public static final String FIELD_NAME = "sparse_embedding_chunk"; - - public static InferenceChunkedSparseEmbeddingResults ofMlResult(MlChunkedTextExpansionResults mlInferenceResults) { - return new InferenceChunkedSparseEmbeddingResults(mlInferenceResults.getChunks()); - } - - /** - * Returns a list of {@link InferenceChunkedSparseEmbeddingResults}. The number of entries in the list will match the input list size. - * Each {@link InferenceChunkedSparseEmbeddingResults} will have a single chunk containing the entire results from the - * {@link SparseEmbeddingResults}. - */ - public static List listOf(List inputs, SparseEmbeddingResults sparseEmbeddingResults) { - validateInputSizeAgainstEmbeddings(inputs, sparseEmbeddingResults.embeddings().size()); - - var results = new ArrayList(inputs.size()); - for (int i = 0; i < inputs.size(); i++) { - results.add(ofSingle(inputs.get(i), sparseEmbeddingResults.embeddings().get(i))); - } - - return results; - } - - private static InferenceChunkedSparseEmbeddingResults ofSingle(String input, SparseEmbeddingResults.Embedding embedding) { - var weightedTokens = embedding.tokens() - .stream() - .map(weightedToken -> new WeightedToken(weightedToken.token(), weightedToken.weight())) - .toList(); - - return new InferenceChunkedSparseEmbeddingResults(List.of(new MlChunkedTextExpansionResults.ChunkedResult(input, weightedTokens))); - } - - private final List chunkedResults; - - public InferenceChunkedSparseEmbeddingResults(List chunks) { - this.chunkedResults = chunks; - } - - public InferenceChunkedSparseEmbeddingResults(StreamInput in) throws IOException { - this.chunkedResults = in.readCollectionAsList(MlChunkedTextExpansionResults.ChunkedResult::new); - } - - public List getChunkedResults() { - return chunkedResults; - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContent.builder(params).array(FIELD_NAME, chunkedResults.iterator()); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(chunkedResults); - } - - @Override - public List transformToCoordinationFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the coordindated action"); - } - - @Override - public List transformToLegacyFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the legacy format"); - } - - @Override - public Map asMap() { - return Map.of( - FIELD_NAME, - chunkedResults.stream().map(MlChunkedTextExpansionResults.ChunkedResult::asMap).collect(Collectors.toList()) - ); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceChunkedSparseEmbeddingResults that = (InferenceChunkedSparseEmbeddingResults) o; - return Objects.equals(chunkedResults, that.chunkedResults); - } - - @Override - public int hashCode() { - return Objects.hash(chunkedResults); - } - - @Override - public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { - return chunkedResults.stream() - .map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.weightedTokens()))) - .iterator(); - } - - /** - * Serialises the {@link WeightedToken} list, according to the provided {@link XContent}, - * into a {@link BytesReference}. - */ - private static BytesReference toBytesReference(XContent xContent, List tokens) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startObject(); - for (var weightedToken : tokens) { - weightedToken.toXContent(b, ToXContent.EMPTY_PARAMS); - } - b.endObject(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingByteResults.java deleted file mode 100644 index 6bd66664068d5..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingByteResults.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; - -public record InferenceChunkedTextEmbeddingByteResults(List chunks, boolean isTruncated) - implements - ChunkedInferenceServiceResults { - - public static final String NAME = "chunked_text_embedding_service_byte_results"; - public static final String FIELD_NAME = "text_embedding_byte_chunk"; - - /** - * Returns a list of {@link InferenceChunkedTextEmbeddingByteResults}. The number of entries in the list will match the input list size. - * Each {@link InferenceChunkedTextEmbeddingByteResults} will have a single chunk containing the entire results from the - * {@link InferenceTextEmbeddingByteResults}. - */ - public static List listOf(List inputs, InferenceTextEmbeddingByteResults textEmbeddings) { - validateInputSizeAgainstEmbeddings(inputs, textEmbeddings.embeddings().size()); - - var results = new ArrayList(inputs.size()); - for (int i = 0; i < inputs.size(); i++) { - results.add(ofSingle(inputs.get(i), textEmbeddings.embeddings().get(i).values())); - } - - return results; - } - - private static InferenceChunkedTextEmbeddingByteResults ofSingle(String input, byte[] byteEmbeddings) { - return new InferenceChunkedTextEmbeddingByteResults(List.of(new InferenceByteEmbeddingChunk(input, byteEmbeddings)), false); - } - - public InferenceChunkedTextEmbeddingByteResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(InferenceByteEmbeddingChunk::new), in.readBoolean()); - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - return ChunkedToXContent.builder(params).array(FIELD_NAME, chunks.iterator()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(chunks); - out.writeBoolean(isTruncated); - } - - @Override - public List transformToCoordinationFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action"); - } - - @Override - public List transformToLegacyFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the legacy format"); - } - - @Override - public Map asMap() { - return Map.of(FIELD_NAME, chunks); - } - - @Override - public String getWriteableName() { - return NAME; - } - - public List getChunks() { - return chunks; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceChunkedTextEmbeddingByteResults that = (InferenceChunkedTextEmbeddingByteResults) o; - return isTruncated == that.isTruncated && Objects.equals(chunks, that.chunks); - } - - @Override - public int hashCode() { - return Objects.hash(chunks, isTruncated); - } - - public record InferenceByteEmbeddingChunk(String matchedText, byte[] embedding) implements Writeable, ToXContentObject { - - public InferenceByteEmbeddingChunk(StreamInput in) throws IOException { - this(in.readString(), in.readByteArray()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(matchedText); - out.writeByteArray(embedding); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(ChunkedNlpInferenceResults.TEXT, matchedText); - - builder.startArray(ChunkedNlpInferenceResults.INFERENCE); - for (byte value : embedding) { - builder.value(value); - } - builder.endArray(); - - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceByteEmbeddingChunk that = (InferenceByteEmbeddingChunk) o; - return Objects.equals(matchedText, that.matchedText) && Arrays.equals(embedding, that.embedding); - } - - @Override - public int hashCode() { - int result = Objects.hash(matchedText); - result = 31 * result + Arrays.hashCode(embedding); - return result; - } - } - - @Override - public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { - return chunks.stream().map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.embedding()))).iterator(); - } - - private static BytesReference toBytesReference(XContent xContent, byte[] value) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startArray(); - for (byte v : value) { - b.value(v); - } - b.endArray(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResults.java deleted file mode 100644 index 369f22a807913..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResults.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.utils.FloatConversionUtils; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils.validateInputSizeAgainstEmbeddings; - -public record InferenceChunkedTextEmbeddingFloatResults(List chunks) - implements - ChunkedInferenceServiceResults { - - public static final String NAME = "chunked_text_embedding_service_float_results"; - public static final String FIELD_NAME = "text_embedding_float_chunk"; - - public InferenceChunkedTextEmbeddingFloatResults(StreamInput in) throws IOException { - this(in.readCollectionAsList(InferenceFloatEmbeddingChunk::new)); - } - - /** - * Returns a list of {@link InferenceChunkedTextEmbeddingFloatResults}. - * Each {@link InferenceChunkedTextEmbeddingFloatResults} contain a single chunk with the text and the - * {@link InferenceTextEmbeddingFloatResults}. - */ - public static List listOf(List inputs, InferenceTextEmbeddingFloatResults textEmbeddings) { - validateInputSizeAgainstEmbeddings(inputs, textEmbeddings.embeddings().size()); - - var results = new ArrayList(inputs.size()); - - for (int i = 0; i < inputs.size(); i++) { - results.add( - new InferenceChunkedTextEmbeddingFloatResults( - List.of(new InferenceFloatEmbeddingChunk(inputs.get(i), textEmbeddings.embeddings().get(i).values())) - ) - ); - } - - return results; - } - - public static InferenceChunkedTextEmbeddingFloatResults ofMlResults(MlChunkedTextEmbeddingFloatResults mlInferenceResult) { - return new InferenceChunkedTextEmbeddingFloatResults( - mlInferenceResult.getChunks() - .stream() - .map(chunk -> new InferenceFloatEmbeddingChunk(chunk.matchedText(), FloatConversionUtils.floatArrayOf(chunk.embedding()))) - .toList() - ); - } - - @Override - public Iterator toXContentChunked(ToXContent.Params params) { - // TODO add isTruncated flag - return ChunkedToXContent.builder(params).array(FIELD_NAME, chunks.iterator()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(chunks); - } - - @Override - public List transformToCoordinationFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action"); - } - - @Override - public List transformToLegacyFormat() { - throw new UnsupportedOperationException("Chunked results are not returned in the legacy format"); - } - - @Override - public Map asMap() { - return Map.of(FIELD_NAME, chunks); - } - - @Override - public String getWriteableName() { - return NAME; - } - - public List getChunks() { - return chunks; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceChunkedTextEmbeddingFloatResults that = (InferenceChunkedTextEmbeddingFloatResults) o; - return Objects.equals(chunks, that.chunks); - } - - @Override - public int hashCode() { - return Objects.hash(chunks); - } - - public record InferenceFloatEmbeddingChunk(String matchedText, float[] embedding) implements Writeable, ToXContentObject { - - public InferenceFloatEmbeddingChunk(StreamInput in) throws IOException { - this(in.readString(), in.readFloatArray()); - } - - public static InferenceFloatEmbeddingChunk of(String matchedText, double[] doubleEmbedding) { - return new InferenceFloatEmbeddingChunk(matchedText, FloatConversionUtils.floatArrayOf(doubleEmbedding)); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(matchedText); - out.writeFloatArray(embedding); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(ChunkedNlpInferenceResults.TEXT, matchedText); - - builder.startArray(ChunkedNlpInferenceResults.INFERENCE); - for (float value : embedding) { - builder.value(value); - } - builder.endArray(); - - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceFloatEmbeddingChunk that = (InferenceFloatEmbeddingChunk) o; - return Objects.equals(matchedText, that.matchedText) && Arrays.equals(embedding, that.embedding); - } - - @Override - public int hashCode() { - int result = Objects.hash(matchedText); - result = 31 * result + Arrays.hashCode(embedding); - return result; - } - } - - @Override - public Iterator chunksAsMatchedTextAndByteReference(XContent xcontent) { - return chunks.stream().map(chunk -> new Chunk(chunk.matchedText(), toBytesReference(xcontent, chunk.embedding()))).iterator(); - } - - /** - * Serialises the {@code value} array, according to the provided {@link XContent}, into a {@link BytesReference}. - */ - private static BytesReference toBytesReference(XContent xContent, float[] value) { - try { - XContentBuilder b = XContentBuilder.builder(xContent); - b.startArray(); - for (float v : value) { - b.value(v); - } - b.endArray(); - return BytesReference.bytes(b); - } catch (IOException exc) { - throw new RuntimeException(exc); - } - } -} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResultsTests.java deleted file mode 100644 index 83678cd030bc2..0000000000000 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/InferenceChunkedTextEmbeddingFloatResultsTests.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.core.inference.results; - -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Map; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults.INFERENCE; -import static org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults.TEXT; - -public class InferenceChunkedTextEmbeddingFloatResultsTests extends ESTestCase { - /** - * Similar to {@link org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults#asMap()} but it converts the - * embeddings float array into a list of floats to make testing equality easier. - */ - public static Map asMapWithListsInsteadOfArrays(InferenceChunkedTextEmbeddingFloatResults result) { - return Map.of( - InferenceChunkedTextEmbeddingFloatResults.FIELD_NAME, - result.getChunks() - .stream() - .map(InferenceChunkedTextEmbeddingFloatResultsTests::inferenceFloatEmbeddingChunkAsMapWithListsInsteadOfArrays) - .collect(Collectors.toList()) - ); - } - - /** - * Similar to {@link MlChunkedTextEmbeddingFloatResults.EmbeddingChunk#asMap()} but it converts the double array into a list of doubles - * to make testing equality easier. - */ - public static Map inferenceFloatEmbeddingChunkAsMapWithListsInsteadOfArrays( - InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk chunk - ) { - var chunkAsList = new ArrayList(chunk.embedding().length); - for (double embedding : chunk.embedding()) { - chunkAsList.add((float) embedding); - } - var map = new HashMap(); - map.put(TEXT, chunk.matchedText()); - map.put(INFERENCE, chunkAsList); - return map; - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ErrorChunkedInferenceResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ErrorChunkedInferenceResultsTests.java deleted file mode 100644 index 4be00ea9e5822..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ErrorChunkedInferenceResultsTests.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.ElasticsearchTimeoutException; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; - -import java.io.IOException; - -public class ErrorChunkedInferenceResultsTests extends AbstractWireSerializingTestCase { - - public static ErrorChunkedInferenceResults createRandomResults() { - return new ErrorChunkedInferenceResults( - randomBoolean() - ? new ElasticsearchTimeoutException(randomAlphaOfLengthBetween(10, 50)) - : new ElasticsearchStatusException(randomAlphaOfLengthBetween(10, 50), randomFrom(RestStatus.values())) - ); - } - - @Override - protected Writeable.Reader instanceReader() { - return ErrorChunkedInferenceResults::new; - } - - @Override - protected ErrorChunkedInferenceResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected ErrorChunkedInferenceResults mutateInstance(ErrorChunkedInferenceResults instance) throws IOException { - return new ErrorChunkedInferenceResults(new RuntimeException(randomAlphaOfLengthBetween(10, 50))); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedSparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedSparseEmbeddingResultsTests.java deleted file mode 100644 index 8685ad9f0e124..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedSparseEmbeddingResultsTests.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -import static org.hamcrest.Matchers.is; - -public class InferenceChunkedSparseEmbeddingResultsTests extends AbstractWireSerializingTestCase { - - public static InferenceChunkedSparseEmbeddingResults createRandomResults() { - var chunks = new ArrayList(); - int numChunks = randomIntBetween(1, 5); - - for (int i = 0; i < numChunks; i++) { - var tokenWeights = new ArrayList(); - int numTokens = randomIntBetween(1, 8); - for (int j = 0; j < numTokens; j++) { - tokenWeights.add(new WeightedToken(Integer.toString(j), (float) randomDoubleBetween(0.0, 5.0, false))); - } - chunks.add(new MlChunkedTextExpansionResults.ChunkedResult(randomAlphaOfLength(6), tokenWeights)); - } - - return new InferenceChunkedSparseEmbeddingResults(chunks); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk() { - var entity = new InferenceChunkedSparseEmbeddingResults( - List.of(new MlChunkedTextExpansionResults.ChunkedResult("text", List.of(new WeightedToken("token", 0.1f)))) - ); - - assertThat( - entity.asMap(), - is( - Map.of( - InferenceChunkedSparseEmbeddingResults.FIELD_NAME, - List.of(Map.of(ChunkedNlpInferenceResults.TEXT, "text", ChunkedNlpInferenceResults.INFERENCE, Map.of("token", 0.1f))) - ) - ) - ); - - String xContentResult = Strings.toString(entity, true, true); - assertThat(xContentResult, is(""" - { - "sparse_embedding_chunk" : [ - { - "text" : "text", - "inference" : { - "token" : 0.1 - } - } - ] - }""")); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk_FromSparseEmbeddingResults() { - var entity = InferenceChunkedSparseEmbeddingResults.listOf( - List.of("text"), - new SparseEmbeddingResults(List.of(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token", 0.1f)), false))) - ); - - assertThat(entity.size(), is(1)); - - var firstEntry = entity.get(0); - - assertThat( - firstEntry.asMap(), - is( - Map.of( - InferenceChunkedSparseEmbeddingResults.FIELD_NAME, - List.of(Map.of(ChunkedNlpInferenceResults.TEXT, "text", ChunkedNlpInferenceResults.INFERENCE, Map.of("token", 0.1f))) - ) - ) - ); - - String xContentResult = Strings.toString(firstEntry, true, true); - assertThat(xContentResult, is(""" - { - "sparse_embedding_chunk" : [ - { - "text" : "text", - "inference" : { - "token" : 0.1 - } - } - ] - }""")); - } - - public void testToXContent_ThrowsWhenInputSizeIsDifferentThanEmbeddings() { - var exception = expectThrows( - IllegalArgumentException.class, - () -> InferenceChunkedSparseEmbeddingResults.listOf( - List.of("text", "text2"), - new SparseEmbeddingResults(List.of(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken("token", 0.1f)), false))) - ) - ); - - assertThat(exception.getMessage(), is("The number of inputs [2] does not match the embeddings [1]")); - } - - @Override - protected Writeable.Reader instanceReader() { - return InferenceChunkedSparseEmbeddingResults::new; - } - - @Override - protected InferenceChunkedSparseEmbeddingResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected InferenceChunkedSparseEmbeddingResults mutateInstance(InferenceChunkedSparseEmbeddingResults instance) throws IOException { - return randomValueOtherThan(instance, InferenceChunkedSparseEmbeddingResultsTests::createRandomResults); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedTextEmbeddingByteResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedTextEmbeddingByteResultsTests.java deleted file mode 100644 index c1215e8a3d71b..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceChunkedTextEmbeddingByteResultsTests.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -import static org.hamcrest.Matchers.is; - -public class InferenceChunkedTextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase< - InferenceChunkedTextEmbeddingByteResults> { - - public static InferenceChunkedTextEmbeddingByteResults createRandomResults() { - int numChunks = randomIntBetween(1, 5); - var chunks = new ArrayList(numChunks); - - for (int i = 0; i < numChunks; i++) { - chunks.add(createRandomChunk()); - } - - return new InferenceChunkedTextEmbeddingByteResults(chunks, randomBoolean()); - } - - private static InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk createRandomChunk() { - int columns = randomIntBetween(1, 10); - byte[] bytes = new byte[columns]; - for (int i = 0; i < columns; i++) { - bytes[i] = randomByte(); - } - - return new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk(randomAlphaOfLength(6), bytes); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk() { - var entity = new InferenceChunkedTextEmbeddingByteResults( - List.of(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk("text", new byte[] { (byte) 1 })), - false - ); - - assertThat( - entity.asMap(), - is( - Map.of( - InferenceChunkedTextEmbeddingByteResults.FIELD_NAME, - List.of(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk("text", new byte[] { (byte) 1 })) - ) - ) - ); - String xContentResult = Strings.toString(entity, true, true); - assertThat(xContentResult, is(""" - { - "text_embedding_byte_chunk" : [ - { - "text" : "text", - "inference" : [ - 1 - ] - } - ] - }""")); - } - - public void testToXContent_CreatesTheRightJsonForASingleChunk_ForTextEmbeddingByteResults() { - var entity = InferenceChunkedTextEmbeddingByteResults.listOf( - List.of("text"), - new InferenceTextEmbeddingByteResults( - List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 1 })) - ) - ); - - assertThat(entity.size(), is(1)); - - var firstEntry = entity.get(0); - - assertThat( - firstEntry.asMap(), - is( - Map.of( - InferenceChunkedTextEmbeddingByteResults.FIELD_NAME, - List.of(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk("text", new byte[] { (byte) 1 })) - ) - ) - ); - String xContentResult = Strings.toString(firstEntry, true, true); - assertThat(xContentResult, is(""" - { - "text_embedding_byte_chunk" : [ - { - "text" : "text", - "inference" : [ - 1 - ] - } - ] - }""")); - } - - public void testToXContent_ThrowsWhenInputSizeIsDifferentThanEmbeddings() { - var exception = expectThrows( - IllegalArgumentException.class, - () -> InferenceChunkedTextEmbeddingByteResults.listOf( - List.of("text", "text2"), - new InferenceTextEmbeddingByteResults( - List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 1 })) - ) - ) - ); - - assertThat(exception.getMessage(), is("The number of inputs [2] does not match the embeddings [1]")); - } - - @Override - protected Writeable.Reader instanceReader() { - return InferenceChunkedTextEmbeddingByteResults::new; - } - - @Override - protected InferenceChunkedTextEmbeddingByteResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected InferenceChunkedTextEmbeddingByteResults mutateInstance(InferenceChunkedTextEmbeddingByteResults instance) - throws IOException { - return randomValueOtherThan(instance, InferenceChunkedTextEmbeddingByteResultsTests::createRandomResults); - } -}