Skip to content

Commit

Permalink
Move InferenceResults to server
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Sep 26, 2023
1 parent d7011ab commit 12c41ff
Show file tree
Hide file tree
Showing 56 changed files with 85 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,31 @@
package org.elasticsearch.inference;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xcontent.ToXContentFragment;

public interface InferenceResults extends NamedWriteable, ToXContentFragment {}
import java.util.Map;
import java.util.Objects;

public interface InferenceResults extends NamedWriteable, ToXContentFragment {
String MODEL_ID_RESULTS_FIELD = "model_id";

static void writeResult(InferenceResults results, IngestDocument ingestDocument, String resultField, String modelId) {
Objects.requireNonNull(results, "results");
Objects.requireNonNull(ingestDocument, "ingestDocument");
Objects.requireNonNull(resultField, "resultField");
Map<String, Object> resultMap = results.asMap();
resultMap.put(MODEL_ID_RESULTS_FIELD, modelId);
if (ingestDocument.hasField(resultField)) {
ingestDocument.appendFieldValue(resultField, resultMap);
} else {
ingestDocument.setFieldValue(resultField, resultMap);
}
}

String getResultsField();

Map<String, Object> asMap();

Object predictedValue();
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
Expand All @@ -22,7 +23,6 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
Expand All @@ -25,7 +26,6 @@
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.stream.Collectors;

public class ClassificationInferenceResults extends SingleValueInferenceResults {
public static String PREDICTION_PROBABILITY = "prediction_probability";

public static final String NAME = "classification";

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ void addMapFields(Map<String, Object> map) {
);
}
if (predictionProbability != null) {
map.put(PREDICTION_PROBABILITY, predictionProbability);
map.put(ClassificationInferenceResults.PREDICTION_PROBABILITY, predictionProbability);
}
}

Expand All @@ -123,7 +123,7 @@ public void doXContentBody(XContentBuilder builder, Params params) throws IOExce
builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses);
}
if (predictionProbability != null) {
builder.field(PREDICTION_PROBABILITY, predictionProbability);
builder.field(ClassificationInferenceResults.PREDICTION_PROBABILITY, predictionProbability);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ void addMapFields(Map<String, Object> map) {
topClasses.stream().map(TopAnswerEntry::asValueMap).collect(Collectors.toList())
);
}
map.put(PREDICTION_PROBABILITY, score);
map.put(ClassificationInferenceResults.PREDICTION_PROBABILITY, score);
}

@Override
Expand All @@ -145,7 +145,7 @@ public void doXContentBody(XContentBuilder builder, Params params) throws IOExce
if (topClasses.size() > 0) {
builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses);
}
builder.field(PREDICTION_PROBABILITY, score);
builder.field(ClassificationInferenceResults.PREDICTION_PROBABILITY, score);
}

public int getStartOffset() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceResults;

import java.io.IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;

import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import org.apache.lucene.util.Accountable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
import org.elasticsearch.common.Numbers;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult;
import static org.elasticsearch.inference.InferenceResults.writeResult;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.PREDICTION_PROBABILITY;
import static org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults.PREDICTION_PROBABILITY;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD;
import static org.hamcrest.Matchers.equalTo;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.ingest.TestIngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult;
import static org.elasticsearch.inference.InferenceResults.writeResult;
import static org.hamcrest.Matchers.equalTo;

public class NlpClassificationInferenceResultsTests extends InferenceResultsTestCase<NlpClassificationInferenceResults> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult;
import static org.elasticsearch.inference.InferenceResults.writeResult;
import static org.hamcrest.Matchers.equalTo;

public class QuestionAnsweringInferenceResultsTests extends InferenceResultsTestCase<QuestionAnsweringInferenceResults> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult;
import static org.elasticsearch.inference.InferenceResults.writeResult;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.MlConfigVersion;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
Expand All @@ -17,7 +18,6 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
Expand All @@ -34,7 +35,6 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.ml.aggs.inference;

import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregation;
Expand All @@ -18,7 +19,6 @@
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;

import java.io.IOException;
import java.util.List;
Expand All @@ -32,7 +32,7 @@ protected InternalInferenceAggregation(String name, Map<String, Object> metadata

public InternalInferenceAggregation(StreamInput in) throws IOException {
super(in);
inferenceResult = in.readNamedWriteable(InferenceResults.class);
inferenceResult = in.readNamedWriteable(org.elasticsearch.inference.InferenceResults.class);
}

public InferenceResults getInferenceResult() {
Expand Down
Loading

0 comments on commit 12c41ff

Please sign in to comment.