From 8be04463e4ae5795fc3fad45f2d01314eaf81035 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Mon, 4 Dec 2023 08:32:54 -0500 Subject: [PATCH] [ML] Fix text embedding response format for TransportCoordinatedInferenceAction (#102890) * Fix for response format * Adding tests --- .../inference/InferenceServiceResults.java | 11 +++++++++ .../results/SparseEmbeddingResults.java | 5 ++++ .../results/TextEmbeddingResults.java | 8 +++++++ .../results/SparseEmbeddingResultsTests.java | 21 ++++++++++++++++ .../results/TextEmbeddingResultsTests.java | 24 +++++++++++++++++++ .../TransportCoordinatedInferenceAction.java | 2 +- 6 files changed, 70 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java index 37990caeec097..ab5b74faa6530 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java @@ -16,6 +16,17 @@ public interface InferenceServiceResults extends NamedWriteable, ToXContentFragment { + /** + * Transform the result to match the format required for the TransportCoordinatedInferenceAction. + * For the inference plugin TextEmbeddingResults, the {@link #transformToLegacyFormat()} transforms the + * results into an intermediate format only used by the plugin's return value. It doesn't align with what the + * TransportCoordinatedInferenceAction expects. TransportCoordinatedInferenceAction expects an ml plugin + * TextEmbeddingResults. + * + * For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat. + */ + List transformToCoordinationFormat(); + /** * Transform the result to match the format required for versions prior to * {@link org.elasticsearch.TransportVersions#INFERENCE_SERVICE_RESULTS_ADDED} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java index 20279e82d6c09..910ea5cab214d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java @@ -81,6 +81,11 @@ public Map asMap() { return map; } + @Override + public List transformToCoordinationFormat() { + return transformToLegacyFormat(); + } + @Override public List transformToLegacyFormat() { return embeddings.stream() diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java index 7a7ccab2b4daa..ace5974866038 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java @@ -78,6 +78,14 @@ public String getWriteableName() { return NAME; } + @Override + public List transformToCoordinationFormat() { + return embeddings.stream() + .map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray()) + .map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false)) + .toList(); + } + @Override @SuppressWarnings("deprecation") public List transformToLegacyFormat() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java index 0a8bfd20caaf1..6f8fa0c453d09 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java @@ -11,12 +11,14 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; import static org.hamcrest.Matchers.is; public class SparseEmbeddingResultsTests extends AbstractWireSerializingTestCase { @@ -151,6 +153,25 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I }""")); } + public void testTransformToCoordinationFormat() { + var results = createSparseResult( + List.of( + createEmbedding(List.of(new SparseEmbeddingResults.WeightedToken("token", 0.1F)), false), + createEmbedding(List.of(new SparseEmbeddingResults.WeightedToken("token2", 0.2F)), true) + ) + ).transformToCoordinationFormat(); + + assertThat( + results, + is( + List.of( + new TextExpansionResults(DEFAULT_RESULTS_FIELD, List.of(new TextExpansionResults.WeightedToken("token", 0.1F)), false), + new TextExpansionResults(DEFAULT_RESULTS_FIELD, List.of(new TextExpansionResults.WeightedToken("token2", 0.2F)), true) + ) + ) + ); + } + public record EmbeddingExpectation(Map tokens, boolean isTruncated) {} public static Map buildExpectation(List embeddings) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index 71d14e09872fd..09d9894d98853 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -100,6 +100,30 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I }""")); } + public void testTransformToCoordinationFormat() { + var results = new TextEmbeddingResults( + List.of(new TextEmbeddingResults.Embedding(List.of(0.1F, 0.2F)), new TextEmbeddingResults.Embedding(List.of(0.3F, 0.4F))) + ).transformToCoordinationFormat(); + + assertThat( + results, + is( + List.of( + new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + TextEmbeddingResults.TEXT_EMBEDDING, + new double[] { 0.1F, 0.2F }, + false + ), + new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + TextEmbeddingResults.TEXT_EMBEDDING, + new double[] { 0.3F, 0.4F }, + false + ) + ) + ) + ); + } + @Override protected Writeable.Reader instanceReader() { return TextEmbeddingResults::new; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java index d90c9ec807495..13e04772683eb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java @@ -182,7 +182,7 @@ private void replaceErrorOnMissing( } static InferModelAction.Response translateInferenceServiceResponse(InferenceServiceResults inferenceResults) { - var legacyResults = new ArrayList(inferenceResults.transformToLegacyFormat()); + var legacyResults = new ArrayList(inferenceResults.transformToCoordinationFormat()); return new InferModelAction.Response(legacyResults, null, false); } }