Skip to content

Commit

Permalink
Preserve order of inference results
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Oct 2, 2023
1 parent c24cc0f commit cf87694
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
Expand Down Expand Up @@ -639,6 +640,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceResults.class, WarningInferenceResults.NAME, WarningInferenceResults::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceResults.class, ErrorInferenceResults.NAME, ErrorInferenceResults::new)
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, NerResults.NAME, NerResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, FillMaskResults.NAME, FillMaskResults::new));
namedWriteables.add(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

public class ErrorInferenceResults implements InferenceResults {

public static final String NAME = "error";
public static final ParseField WARNING = new ParseField("error");

private final Exception exception;

public ErrorInferenceResults(Exception exception) {
this.exception = Objects.requireNonNull(exception);
}

public ErrorInferenceResults(StreamInput in) throws IOException {
this.exception = in.readException();
}

public Exception getException() {
return exception;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeException(exception);
}

@Override
public boolean equals(Object object) {
if (object == this) {
return true;
}
if (object == null || getClass() != object.getClass()) {
return false;
}
ErrorInferenceResults that = (ErrorInferenceResults) object;
// Just compare the message for serialization test purposes
return Objects.equals(exception.getMessage(), that.exception.getMessage());
}

@Override
public int hashCode() {
// Just compare the message for serialization test purposes
return Objects.hash(exception.getMessage());
}

@Override
public String getResultsField() {
return NAME;
}

@Override
public Map<String, Object> asMap() {
Map<String, Object> asMap = new LinkedHashMap<>();
asMap.put(NAME, exception.getMessage());
return asMap;
}

@Override
public String toString() {
return Strings.toString(this);
}

@Override
public Object predictedValue() {
return null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(NAME, exception.getMessage());
return builder;
}

@Override
public String getWriteableName() {
return NAME;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.rest.RestStatus;

import java.io.IOException;

import static org.hamcrest.Matchers.equalTo;

public class ErrorInferenceResultsTests extends InferenceResultsTestCase<ErrorInferenceResults> {

@Override
protected Writeable.Reader<ErrorInferenceResults> instanceReader() {
return ErrorInferenceResults::new;
}

@Override
protected ErrorInferenceResults createTestInstance() {
return new ErrorInferenceResults(new ElasticsearchStatusException(randomAlphaOfLength(8), randomFrom(RestStatus.values())));
}

@Override
protected ErrorInferenceResults mutateInstance(ErrorInferenceResults instance) throws IOException {
return null;// TODO implement https://github.com/elastic/elasticsearch/issues/25929
}

@Override
void assertFieldValues(ErrorInferenceResults createdInstance, IngestDocument document, String resultsField) {
assertThat(document.getFieldValue(resultsField + ".error", String.class), equalTo(createdInstance.getException().getMessage()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,24 @@
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public class TransportInferTrainedModelDeploymentAction extends TransportTasksAction<
TrainedModelDeploymentTask,
Expand Down Expand Up @@ -96,13 +97,55 @@ protected void taskOperation(
}

// Multiple documents to infer on, wait for all results
ActionListener<Collection<InferenceResults>> collectingListener = ActionListener.wrap(pyTorchResults -> {
listener.onResponse(new InferTrainedModelDeploymentAction.Response(new ArrayList<>(pyTorchResults)));
}, listener::onFailure);

GroupedActionListener<InferenceResults> groupedListener = new GroupedActionListener<>(nlpInputs.size(), collectingListener);
// and return order the results to match the request order
AtomicInteger count = new AtomicInteger();
AtomicArray<InferenceResults> results = new AtomicArray<>(nlpInputs.size());
int slot = 0;
for (var input : nlpInputs) {
task.infer(input, request.getUpdate(), request.isHighPriority(), request.getInferenceTimeout(), actionTask, groupedListener);
task.infer(
input,
request.getUpdate(),
request.isHighPriority(),
request.getInferenceTimeout(),
actionTask,
orderedListener(count, results, slot++, nlpInputs.size(), listener)
);
}
}

/**
* Create a listener that groups the results is the correct order.
* Exceptions are converted to {@link ErrorInferenceResults},
* the listener will never call {@code finalListener::onFailure}
* instead failures are returned as inference results.
*/
private ActionListener<InferenceResults> orderedListener(
AtomicInteger count,
AtomicArray<InferenceResults> results,
int slot,
int totalNumberOfResponses,
ActionListener<InferTrainedModelDeploymentAction.Response> finalListener
) {
return new ActionListener<>() {
@Override
public void onResponse(InferenceResults response) {
results.setOnce(slot, response);
if (count.incrementAndGet() == totalNumberOfResponses) {
sendResponse();
}
}

@Override
public void onFailure(Exception e) {
results.setOnce(slot, new ErrorInferenceResults(e));
if (count.incrementAndGet() == totalNumberOfResponses) {
sendResponse();
}
}

private void sendResponse() {
finalListener.onResponse(new InferTrainedModelDeploymentAction.Response(results.asList()));
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
Expand Down Expand Up @@ -348,15 +349,24 @@ public void onFailure(Exception e) {
}

private void sendResponse() {
if (results.nonNullLength() > 0) {
if (failure.get() != null) {
finalListener.onFailure(failure.get());
} else {
for (int i = 0; i < results.length(); i++) {
if (results.get(i) != null) {
responseBuilder.addInferenceResults(results.get(i));
var resultList = results.get(i);
if (resultList != null) {
for (var result : resultList) {
if (result instanceof ErrorInferenceResults errorResult) {
// Any failure fails all requests
// TODO is this the correct behaviour for batched requests?
finalListener.onFailure(errorResult.getException());
return;
}
}
responseBuilder.addInferenceResults(resultList);
}
}
finalListener.onResponse(responseBuilder.build());
} else {
finalListener.onFailure(failure.get());
}
}
};
Expand Down

0 comments on commit cf87694

Please sign in to comment.