Skip to content

Commit

Permalink
[ML] Fix text embedding response format for TransportCoordinatedInfer…
Browse files Browse the repository at this point in the history
…enceAction (elastic#102890)

* Fix for response format

* Adding tests
  • Loading branch information
jonathan-buttner authored Dec 4, 2023
1 parent a67d5b8 commit 8be0446
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@

public interface InferenceServiceResults extends NamedWriteable, ToXContentFragment {

/**
* Transform the result to match the format required for the TransportCoordinatedInferenceAction.
* For the inference plugin TextEmbeddingResults, the {@link #transformToLegacyFormat()} transforms the
* results into an intermediate format only used by the plugin's return value. It doesn't align with what the
* TransportCoordinatedInferenceAction expects. TransportCoordinatedInferenceAction expects an ml plugin
* TextEmbeddingResults.
*
* For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat.
*/
List<? extends InferenceResults> transformToCoordinationFormat();

/**
* Transform the result to match the format required for versions prior to
* {@link org.elasticsearch.TransportVersions#INFERENCE_SERVICE_RESULTS_ADDED}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ public Map<String, Object> asMap() {
return map;
}

@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
return transformToLegacyFormat();
}

@Override
public List<? extends InferenceResults> transformToLegacyFormat() {
return embeddings.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ public String getWriteableName() {
return NAME;
}

@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream()
.map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray())
.map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false))
.toList();
}

@Override
@SuppressWarnings("deprecation")
public List<? extends InferenceResults> transformToLegacyFormat() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
import static org.hamcrest.Matchers.is;

public class SparseEmbeddingResultsTests extends AbstractWireSerializingTestCase<SparseEmbeddingResults> {
Expand Down Expand Up @@ -151,6 +153,25 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I
}"""));
}

public void testTransformToCoordinationFormat() {
var results = createSparseResult(
List.of(
createEmbedding(List.of(new SparseEmbeddingResults.WeightedToken("token", 0.1F)), false),
createEmbedding(List.of(new SparseEmbeddingResults.WeightedToken("token2", 0.2F)), true)
)
).transformToCoordinationFormat();

assertThat(
results,
is(
List.of(
new TextExpansionResults(DEFAULT_RESULTS_FIELD, List.of(new TextExpansionResults.WeightedToken("token", 0.1F)), false),
new TextExpansionResults(DEFAULT_RESULTS_FIELD, List.of(new TextExpansionResults.WeightedToken("token2", 0.2F)), true)
)
)
);
}

public record EmbeddingExpectation(Map<String, Float> tokens, boolean isTruncated) {}

public static Map<String, Object> buildExpectation(List<EmbeddingExpectation> embeddings) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,30 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I
}"""));
}

public void testTransformToCoordinationFormat() {
var results = new TextEmbeddingResults(
List.of(new TextEmbeddingResults.Embedding(List.of(0.1F, 0.2F)), new TextEmbeddingResults.Embedding(List.of(0.3F, 0.4F)))
).transformToCoordinationFormat();

assertThat(
results,
is(
List.of(
new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(
TextEmbeddingResults.TEXT_EMBEDDING,
new double[] { 0.1F, 0.2F },
false
),
new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(
TextEmbeddingResults.TEXT_EMBEDDING,
new double[] { 0.3F, 0.4F },
false
)
)
)
);
}

@Override
protected Writeable.Reader<TextEmbeddingResults> instanceReader() {
return TextEmbeddingResults::new;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ private void replaceErrorOnMissing(
}

static InferModelAction.Response translateInferenceServiceResponse(InferenceServiceResults inferenceResults) {
var legacyResults = new ArrayList<InferenceResults>(inferenceResults.transformToLegacyFormat());
var legacyResults = new ArrayList<InferenceResults>(inferenceResults.transformToCoordinationFormat());
return new InferModelAction.Response(legacyResults, null, false);
}
}

0 comments on commit 8be0446

Please sign in to comment.