From 316447da99156bff718ee234804e6373f76a433f Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Thu, 8 Feb 2024 09:12:08 -0500 Subject: [PATCH] Moving byte embeddings to text_embedding_bytes field --- .../results/TextEmbeddingByteResults.java | 9 ++++----- .../results/TextEmbeddingByteResultsTests.java | 14 +++++++------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java index 4ffef36359589..c29434d0f1c59 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingByteResults.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -43,7 +42,7 @@ */ public record TextEmbeddingByteResults(List embeddings) implements InferenceServiceResults, TextEmbedding { public static final String NAME = "text_embedding_service_byte_results"; - public static final String TEXT_EMBEDDING = TaskType.TEXT_EMBEDDING.toString(); + public static final String TEXT_EMBEDDING_BYTES = "text_embedding_bytes"; public TextEmbeddingByteResults(StreamInput in) throws IOException { this(in.readCollectionAsList(Embedding::new)); @@ -56,7 +55,7 @@ public int getFirstEmbeddingSize() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startArray(TEXT_EMBEDDING); + builder.startArray(TEXT_EMBEDDING_BYTES); for (Embedding embedding : embeddings) { embedding.toXContent(builder, params); } @@ -78,7 +77,7 @@ public String getWriteableName() { public List transformToCoordinationFormat() { return embeddings.stream() .map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray()) - .map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false)) + .map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING_BYTES, values, false)) .toList(); } @@ -94,7 +93,7 @@ public List transformToLegacyFormat() { public Map asMap() { Map map = new LinkedHashMap<>(); - map.put(TEXT_EMBEDDING, embeddings.stream().map(Embedding::asMap).collect(Collectors.toList())); + map.put(TEXT_EMBEDDING_BYTES, embeddings.stream().map(Embedding::asMap).collect(Collectors.toList())); return map; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java index b9318db6ece34..f12865a9a5db8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java @@ -49,7 +49,7 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE entity.asMap(), is( Map.of( - TextEmbeddingByteResults.TEXT_EMBEDDING, + TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, List.of(Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, List.of((byte) 23))) ) ) @@ -58,7 +58,7 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" { - "text_embedding" : [ + "text_embedding_bytes" : [ { "embedding" : [ 23 @@ -78,7 +78,7 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I entity.asMap(), is( Map.of( - TextEmbeddingByteResults.TEXT_EMBEDDING, + TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, List.of( Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, List.of((byte) 23)), Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, List.of((byte) 24)) @@ -90,7 +90,7 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" { - "text_embedding" : [ + "text_embedding_bytes" : [ { "embedding" : [ 23 @@ -118,12 +118,12 @@ public void testTransformToCoordinationFormat() { is( List.of( new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( - TextEmbeddingByteResults.TEXT_EMBEDDING, + TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 23F, 24F }, false ), new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( - TextEmbeddingByteResults.TEXT_EMBEDDING, + TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, new double[] { 25F, 26F }, false ) @@ -158,7 +158,7 @@ protected TextEmbeddingByteResults mutateInstance(TextEmbeddingByteResults insta public static Map buildExpectation(List> embeddings) { return Map.of( - TextEmbeddingByteResults.TEXT_EMBEDDING, + TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, embeddings.stream().map(embedding -> Map.of(TextEmbeddingByteResults.Embedding.EMBEDDING, embedding)).toList() ); }