From cf876942dd147c86a12984bcf76338014d35f529 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 2 Oct 2023 15:34:07 +0100 Subject: [PATCH] Preserve order of inference results --- .../MlInferenceNamedXContentProvider.java | 4 + .../results/ErrorInferenceResults.java | 97 +++++++++++++++++++ .../results/ErrorInferenceResultsTests.java | 40 ++++++++ ...portInferTrainedModelDeploymentAction.java | 59 +++++++++-- .../TransportInternalInferModelAction.java | 20 +++- 5 files changed, 207 insertions(+), 13 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResultsTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 7f0d12af5f465..00587936848f8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; import org.elasticsearch.xpack.core.ml.inference.results.NerResults; import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults; @@ -639,6 +640,9 @@ public List getNamedWriteables() { namedWriteables.add( new NamedWriteableRegistry.Entry(InferenceResults.class, WarningInferenceResults.NAME, WarningInferenceResults::new) ); + namedWriteables.add( + new NamedWriteableRegistry.Entry(InferenceResults.class, ErrorInferenceResults.NAME, ErrorInferenceResults::new) + ); namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, NerResults.NAME, NerResults::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, FillMaskResults.NAME, FillMaskResults::new)); namedWriteables.add( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java new file mode 100644 index 0000000000000..daf1175bbc547 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; + +public class ErrorInferenceResults implements InferenceResults { + + public static final String NAME = "error"; + public static final ParseField WARNING = new ParseField("error"); + + private final Exception exception; + + public ErrorInferenceResults(Exception exception) { + this.exception = Objects.requireNonNull(exception); + } + + public ErrorInferenceResults(StreamInput in) throws IOException { + this.exception = in.readException(); + } + + public Exception getException() { + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeException(exception); + } + + @Override + public boolean equals(Object object) { + if (object == this) { + return true; + } + if (object == null || getClass() != object.getClass()) { + return false; + } + ErrorInferenceResults that = (ErrorInferenceResults) object; + // Just compare the message for serialization test purposes + return Objects.equals(exception.getMessage(), that.exception.getMessage()); + } + + @Override + public int hashCode() { + // Just compare the message for serialization test purposes + return Objects.hash(exception.getMessage()); + } + + @Override + public String getResultsField() { + return NAME; + } + + @Override + public Map asMap() { + Map asMap = new LinkedHashMap<>(); + asMap.put(NAME, exception.getMessage()); + return asMap; + } + + @Override + public String toString() { + return Strings.toString(this); + } + + @Override + public Object predictedValue() { + return null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(NAME, exception.getMessage()); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResultsTests.java new file mode 100644 index 0000000000000..e25b2da55b15b --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResultsTests.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.rest.RestStatus; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; + +public class ErrorInferenceResultsTests extends InferenceResultsTestCase { + + @Override + protected Writeable.Reader instanceReader() { + return ErrorInferenceResults::new; + } + + @Override + protected ErrorInferenceResults createTestInstance() { + return new ErrorInferenceResults(new ElasticsearchStatusException(randomAlphaOfLength(8), randomFrom(RestStatus.values()))); + } + + @Override + protected ErrorInferenceResults mutateInstance(ErrorInferenceResults instance) throws IOException { + return null;// TODO implement https://github.com/elastic/elasticsearch/issues/25929 + } + + @Override + void assertFieldValues(ErrorInferenceResults createdInstance, IngestDocument document, String resultsField) { + assertThat(document.getFieldValue(resultsField + ".error", String.class), equalTo(createdInstance.getException().getMessage())); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java index 43040b4f94823..e0088a0c21f3a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java @@ -12,23 +12,24 @@ import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.TaskOperationFailure; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput; import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; import java.util.ArrayList; -import java.util.Collection; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; public class TransportInferTrainedModelDeploymentAction extends TransportTasksAction< TrainedModelDeploymentTask, @@ -96,13 +97,55 @@ protected void taskOperation( } // Multiple documents to infer on, wait for all results - ActionListener> collectingListener = ActionListener.wrap(pyTorchResults -> { - listener.onResponse(new InferTrainedModelDeploymentAction.Response(new ArrayList<>(pyTorchResults))); - }, listener::onFailure); - - GroupedActionListener groupedListener = new GroupedActionListener<>(nlpInputs.size(), collectingListener); + // and return order the results to match the request order + AtomicInteger count = new AtomicInteger(); + AtomicArray results = new AtomicArray<>(nlpInputs.size()); + int slot = 0; for (var input : nlpInputs) { - task.infer(input, request.getUpdate(), request.isHighPriority(), request.getInferenceTimeout(), actionTask, groupedListener); + task.infer( + input, + request.getUpdate(), + request.isHighPriority(), + request.getInferenceTimeout(), + actionTask, + orderedListener(count, results, slot++, nlpInputs.size(), listener) + ); } } + + /** + * Create a listener that groups the results is the correct order. + * Exceptions are converted to {@link ErrorInferenceResults}, + * the listener will never call {@code finalListener::onFailure} + * instead failures are returned as inference results. + */ + private ActionListener orderedListener( + AtomicInteger count, + AtomicArray results, + int slot, + int totalNumberOfResponses, + ActionListener finalListener + ) { + return new ActionListener<>() { + @Override + public void onResponse(InferenceResults response) { + results.setOnce(slot, response); + if (count.incrementAndGet() == totalNumberOfResponses) { + sendResponse(); + } + } + + @Override + public void onFailure(Exception e) { + results.setOnce(slot, new ErrorInferenceResults(e)); + if (count.incrementAndGet() == totalNumberOfResponses) { + sendResponse(); + } + } + + private void sendResponse() { + finalListener.onResponse(new InferTrainedModelDeploymentAction.Response(results.asList())); + } + }; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index c30b7a3232f57..233ea8a5bf989 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; @@ -348,15 +349,24 @@ public void onFailure(Exception e) { } private void sendResponse() { - if (results.nonNullLength() > 0) { + if (failure.get() != null) { + finalListener.onFailure(failure.get()); + } else { for (int i = 0; i < results.length(); i++) { - if (results.get(i) != null) { - responseBuilder.addInferenceResults(results.get(i)); + var resultList = results.get(i); + if (resultList != null) { + for (var result : resultList) { + if (result instanceof ErrorInferenceResults errorResult) { + // Any failure fails all requests + // TODO is this the correct behaviour for batched requests? + finalListener.onFailure(errorResult.getException()); + return; + } + } + responseBuilder.addInferenceResults(resultList); } } finalListener.onResponse(responseBuilder.build()); - } else { - finalListener.onFailure(failure.get()); } } };