diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 1568f47461e6c..99ce5910c9775 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -278,6 +278,7 @@ exports org.elasticsearch.indices.recovery; exports org.elasticsearch.indices.recovery.plan; exports org.elasticsearch.indices.store; + exports org.elasticsearch.inference; exports org.elasticsearch.ingest; exports org.elasticsearch.internal to diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceResults.java similarity index 68% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java rename to server/src/main/java/org/elasticsearch/inference/InferenceResults.java index 83f08391c656b..76dfcae2ba530 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceResults.java @@ -1,26 +1,27 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.core.ml.inference.results; + +package org.elasticsearch.inference; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.xcontent.ToXContentFragment; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.util.Map; +import java.util.Objects; public interface InferenceResults extends NamedWriteable, ToXContentFragment { - String PREDICTION_PROBABILITY = "prediction_probability"; String MODEL_ID_RESULTS_FIELD = "model_id"; static void writeResult(InferenceResults results, IngestDocument ingestDocument, String resultField, String modelId) { - ExceptionsHelper.requireNonNull(results, "results"); - ExceptionsHelper.requireNonNull(ingestDocument, "ingestDocument"); - ExceptionsHelper.requireNonNull(resultField, "resultField"); + Objects.requireNonNull(results, "results"); + Objects.requireNonNull(ingestDocument, "ingestDocument"); + Objects.requireNonNull(resultField, "resultField"); Map resultMap = results.asMap(); resultMap.put(MODEL_ID_RESULTS_FIELD, modelId); if (ingestDocument.hasField(resultField)) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java similarity index 74% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/InferenceService.java rename to server/src/main/java/org/elasticsearch/inference/InferenceService.java index 18704f4c32740..3bc9cf2a5a116 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -1,16 +1,14 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.inference.services; +package org.elasticsearch.inference; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.TaskType; -import org.elasticsearch.xpack.inference.results.InferenceResult; import java.util.Map; @@ -45,20 +43,20 @@ public interface InferenceService { */ Model parseConfigLenient(String modelId, TaskType taskType, Map config); - /** - * Start or prepare the model for use. - * @param model The model - * @param listener The listener - */ - void start(Model model, ActionListener listener); - /** * Perform inference on the model. * - * @param model Model configuration + * @param model The model * @param input Inference input - * @param requestTaskSettings Settings in the request to override the model's defaults + * @param taskSettings Settings in the request to override the model's defaults * @param listener Inference result listener */ - void infer(Model model, String input, Map requestTaskSettings, ActionListener listener); + void infer(Model model, String input, Map taskSettings, ActionListener listener); + + /** + * Start or prepare the model for use. + * @param model The model + * @param listener The listener + */ + void start(Model model, ActionListener listener); } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java new file mode 100644 index 0000000000000..ac1439150f8ec --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.common.component.AbstractLifecycleComponent; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.plugins.InferenceServicePlugin; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class InferenceServiceRegistry extends AbstractLifecycleComponent { + + private final Map services; + private final List namedWriteables = new ArrayList<>(); + + public InferenceServiceRegistry( + List inferenceServicePlugins, + InferenceServicePlugin.InferenceServiceFactoryContext factoryContext + ) { + // TODO check names are unique + services = inferenceServicePlugins.stream() + .flatMap(r -> r.getInferenceServiceFactories().stream()) + .map(factory -> factory.create(factoryContext)) + .collect(Collectors.toMap(InferenceService::name, Function.identity())); + + for (var plugin : inferenceServicePlugins) { + namedWriteables.addAll(plugin.getInferenceServiceNamedWriteables()); + } + } + + public Map getServices() { + return services; + } + + public Optional getService(String serviceName) { + return Optional.ofNullable(services.get(serviceName)); + } + + public List getNamedWriteables() { + return namedWriteables; + } + + @Override + protected void doStart() { + + } + + @Override + protected void doStop() { + + } + + @Override + protected void doClose() throws IOException { + + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/Model.java b/server/src/main/java/org/elasticsearch/inference/Model.java similarity index 94% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/Model.java rename to server/src/main/java/org/elasticsearch/inference/Model.java index c0032ebf25ca7..67ee58bad733c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/Model.java +++ b/server/src/main/java/org/elasticsearch/inference/Model.java @@ -1,11 +1,12 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.inference; +package org.elasticsearch.inference; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/ServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java similarity index 62% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/ServiceSettings.java rename to server/src/main/java/org/elasticsearch/inference/ServiceSettings.java index 16f39ad8e560c..01a75158f5ca2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/ServiceSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java @@ -1,11 +1,12 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.inference; +package org.elasticsearch.inference; import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xcontent.ToXContentObject; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/TaskSettings.java b/server/src/main/java/org/elasticsearch/inference/TaskSettings.java similarity index 62% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/TaskSettings.java rename to server/src/main/java/org/elasticsearch/inference/TaskSettings.java index 200f1b309822d..e346fa89b8663 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/TaskSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskSettings.java @@ -1,11 +1,12 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.inference; +package org.elasticsearch.inference; import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xcontent.ToXContentObject; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/TaskType.java b/server/src/main/java/org/elasticsearch/inference/TaskType.java similarity index 77% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/TaskType.java rename to server/src/main/java/org/elasticsearch/inference/TaskType.java index 5e9cc93270033..9e96a7c4c52d0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/TaskType.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskType.java @@ -1,11 +1,12 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.inference; +package org.elasticsearch.inference; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.io.stream.StreamInput; @@ -49,4 +50,8 @@ public static TaskType fromStream(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { out.writeEnum(this); } + + public static String unsupportedTaskTypeErrorMsg(TaskType taskType, String serviceName) { + return "The [" + serviceName + "] service does not support task type [" + taskType + "]"; + } } diff --git a/server/src/main/java/org/elasticsearch/node/Node.java b/server/src/main/java/org/elasticsearch/node/Node.java index ad8ad68bf7650..a2da65bdaa8b4 100644 --- a/server/src/main/java/org/elasticsearch/node/Node.java +++ b/server/src/main/java/org/elasticsearch/node/Node.java @@ -146,6 +146,7 @@ import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; import org.elasticsearch.indices.store.IndicesStore; +import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.MonitorService; import org.elasticsearch.monitor.fs.FsHealthService; @@ -165,6 +166,7 @@ import org.elasticsearch.plugins.EnginePlugin; import org.elasticsearch.plugins.HealthPlugin; import org.elasticsearch.plugins.IndexStorePlugin; +import org.elasticsearch.plugins.InferenceServicePlugin; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.MetadataUpgrader; @@ -532,6 +534,12 @@ protected Node( Supplier documentParsingObserverSupplier = getDocumentParsingObserverSupplier(); + var factoryContext = new InferenceServicePlugin.InferenceServiceFactoryContext(client); + final InferenceServiceRegistry inferenceServiceRegistry = new InferenceServiceRegistry( + pluginsService.filterPlugins(InferenceServicePlugin.class), + factoryContext + ); + final IngestService ingestService = new IngestService( clusterService, threadPool, @@ -555,7 +563,8 @@ protected Node( searchModule.getNamedWriteables().stream(), pluginsService.flatMap(Plugin::getNamedWriteables), ClusterModule.getNamedWriteables().stream(), - SystemIndexMigrationExecutor.getNamedWriteables().stream() + SystemIndexMigrationExecutor.getNamedWriteables().stream(), + inferenceServiceRegistry.getNamedWriteables().stream() ).flatMap(Function.identity()).toList(); final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables); NamedXContentRegistry xContentRegistry = new NamedXContentRegistry( @@ -1170,6 +1179,7 @@ protected Node( b.bind(WriteLoadForecaster.class).toInstance(writeLoadForecaster); b.bind(HealthPeriodicLogger.class).toInstance(healthPeriodicLogger); b.bind(CompatibilityVersions.class).toInstance(compatibilityVersions); + b.bind(InferenceServiceRegistry.class).toInstance(inferenceServiceRegistry); }); if (ReadinessService.enabled(environment)) { diff --git a/server/src/main/java/org/elasticsearch/plugins/InferenceServicePlugin.java b/server/src/main/java/org/elasticsearch/plugins/InferenceServicePlugin.java new file mode 100644 index 0000000000000..2672a4b8fcbcf --- /dev/null +++ b/server/src/main/java/org/elasticsearch/plugins/InferenceServicePlugin.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.plugins; + +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.inference.InferenceService; + +import java.util.List; + +/** + * InferenceServicePlugins implement an inference service + */ +public interface InferenceServicePlugin { + + List getInferenceServiceFactories(); + + record InferenceServiceFactoryContext(Client client) {} + + interface Factory { + /** + * InferenceServices are created from the factory context + */ + InferenceService create(InferenceServiceFactoryContext context); + } + + /** + * The named writables defined and used by each of the implemented + * InferenceServices. Each service should define named writables for + * - {@link org.elasticsearch.inference.TaskSettings} + * - {@link org.elasticsearch.inference.ServiceSettings} + * And optionally for {@link org.elasticsearch.inference.InferenceResults} + * if the service uses a new type of result. + * @return All named writables defined by the services + */ + List getInferenceServiceNamedWriteables(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index fdc4040b0be49..a06d0fe0ce0ce 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; @@ -22,7 +23,6 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java index 916f09c84b6f5..524d5f84a177b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; @@ -25,7 +26,6 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 1ef60af3584e7..7f0d12af5f465 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; @@ -21,7 +22,6 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.NerResults; import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java index 475e39c8eb3d2..ebf946aa34716 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -22,6 +22,7 @@ import java.util.stream.Collectors; public class ClassificationInferenceResults extends SingleValueInferenceResults { + public static String PREDICTION_PROBABILITY = "prediction_probability"; public static final String NAME = "classification"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java index 7556b223dd317..a49e81e40a7a6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResults.java @@ -107,7 +107,7 @@ void addMapFields(Map map) { ); } if (predictionProbability != null) { - map.put(PREDICTION_PROBABILITY, predictionProbability); + map.put(ClassificationInferenceResults.PREDICTION_PROBABILITY, predictionProbability); } } @@ -123,7 +123,7 @@ public void doXContentBody(XContentBuilder builder, Params params) throws IOExce builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses); } if (predictionProbability != null) { - builder.field(PREDICTION_PROBABILITY, predictionProbability); + builder.field(ClassificationInferenceResults.PREDICTION_PROBABILITY, predictionProbability); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java index 33ef519f5e41c..ab1939013e976 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java index 0c60b05320cec..dec5343fe8ddf 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java @@ -129,7 +129,7 @@ void addMapFields(Map map) { topClasses.stream().map(TopAnswerEntry::asValueMap).collect(Collectors.toList()) ); } - map.put(PREDICTION_PROBABILITY, score); + map.put(ClassificationInferenceResults.PREDICTION_PROBABILITY, score); } @Override @@ -145,7 +145,7 @@ public void doXContentBody(XContentBuilder builder, Params params) throws IOExce if (topClasses.size() > 0) { builder.field(NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses); } - builder.field(PREDICTION_PROBABILITY, score); + builder.field(ClassificationInferenceResults.PREDICTION_PROBABILITY, score); } public int getStartOffset() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java index 97b2949d0020c..b7744a6c54800 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java index 0daaa1ebc35a9..1da92021a9edb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java @@ -8,6 +8,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; import java.io.IOException; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java index fc399b8298335..a956c91007934 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.logging.LoggerMessageFormat; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java index 30d1eda66a39c..10f0cc6770a9c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java @@ -13,10 +13,10 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java index 7ae5b141a3c0f..636cbcac725f4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java @@ -8,11 +8,11 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceModel.java index 61498648c8c01..cf532cefbeab8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceModel.java @@ -9,7 +9,7 @@ import org.apache.lucene.util.Accountable; import org.elasticsearch.core.Nullable; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java index b1d166a00a461..d23600383b34b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java @@ -13,11 +13,11 @@ import org.elasticsearch.common.Numbers; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java index 8402e0eaade5e..8d010c6fbc650 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java @@ -11,13 +11,13 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java index 591e085ff6f36..4d8035864729a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; @@ -16,7 +17,6 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.NerResults; import org.elasticsearch.xpack.core.ml.inference.results.NerResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java index 492678f035562..28e318c0dab48 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java @@ -23,7 +23,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult; +import static org.elasticsearch.inference.InferenceResults.writeResult; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java index 4119d56be7b5b..432e05b9cc680 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResultsTests.java @@ -14,7 +14,7 @@ import java.util.List; import java.util.Map; -import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.PREDICTION_PROBABILITY; +import static org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults.PREDICTION_PROBABILITY; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD; import static org.hamcrest.Matchers.equalTo; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java index 0b5d6a0b9c92b..27503547e5705 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.TransportVersion; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.ingest.TestIngestDocument; import org.elasticsearch.test.AbstractWireSerializingTestCase; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java index d47730bd579a6..ac3cb638d88d1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java @@ -16,7 +16,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult; +import static org.elasticsearch.inference.InferenceResults.writeResult; import static org.hamcrest.Matchers.equalTo; public class NlpClassificationInferenceResultsTests extends InferenceResultsTestCase { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResultsTests.java index 947f8732aba68..c9c65ea3f3538 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResultsTests.java @@ -16,7 +16,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult; +import static org.elasticsearch.inference.InferenceResults.writeResult; import static org.hamcrest.Matchers.equalTo; public class QuestionAnsweringInferenceResultsTests extends InferenceResultsTestCase { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java index 6ea00fce676e4..27a07e8f996f7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java @@ -19,7 +19,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult; +import static org.elasticsearch.inference.InferenceResults.writeResult; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java index 104838921a434..c3b2fbf6fb556 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java @@ -16,8 +16,13 @@ import java.util.stream.Collectors; public class TextExpansionResultsTests extends InferenceResultsTestCase { + public static TextExpansionResults createRandomResults() { - int numTokens = randomIntBetween(0, 20); + return createRandomResults(0, 20); + } + + public static TextExpansionResults createRandomResults(int min, int max) { + int numTokens = randomIntBetween(min, max); List tokenList = new ArrayList<>(); for (int i = 0; i < numTokens; i++) { tokenList.add(new TextExpansionResults.WeightedToken(Integer.toString(i), (float) randomDoubleBetween(0.0, 5.0, false))); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java new file mode 100644 index 0000000000000..411c29255fd78 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.action.PutInferenceModelAction; + +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; + +public class MockInferenceServiceIT extends ESIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return List.of(InferencePlugin.class, TestInferenceServicePlugin.class); + } + + @Override + protected Function getClientWrapper() { + final Map headers = Map.of( + "Authorization", + basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING) + ); + // we need to wrap node clients because we do not specify a user for nodes and all requests will use the system + // user. This is ok for internal n2n stuff but the test framework does other things like wiping indices, repositories, etc + // that the system user cannot do. so we wrap the node client with a user that can do these things since the client() calls + // return a node client + return client -> client.filterWithHeader(headers); + } + + public void testMockService() { + String modelId = "test-mock"; + Model putModel = putMockService(modelId, TaskType.SPARSE_EMBEDDING); + Model readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING); + assertModelsAreEqual(putModel, readModel); + + // The response is randomly generated, the input can be anything + inferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, randomAlphaOfLength(10)); + } + + private Model putMockService(String modelId, TaskType taskType) { + String body = """ + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + "temperature": 3 + } + } + """; + var request = new PutInferenceModelAction.Request( + taskType.toString(), + modelId, + new BytesArray(body.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON + ); + + var response = client().execute(PutInferenceModelAction.INSTANCE, request).actionGet(); + assertEquals("test_service", response.getModel().getService()); + + assertThat(response.getModel().getServiceSettings(), instanceOf(TestInferenceServicePlugin.TestServiceSettings.class)); + var serviceSettings = (TestInferenceServicePlugin.TestServiceSettings) response.getModel().getServiceSettings(); + assertEquals("my_model", serviceSettings.model()); + assertEquals("abc64", serviceSettings.apiKey()); + + assertThat(response.getModel().getTaskSettings(), instanceOf(TestInferenceServicePlugin.TestTaskSettings.class)); + var taskSettings = (TestInferenceServicePlugin.TestTaskSettings) response.getModel().getTaskSettings(); + assertEquals(3, (int) taskSettings.temperature()); + + return response.getModel(); + } + + public Model getModel(String modelId, TaskType taskType) { + var response = client().execute(GetInferenceModelAction.INSTANCE, new GetInferenceModelAction.Request(modelId, taskType.toString())) + .actionGet(); + return response.getModel(); + } + + private void inferOnMockService(String modelId, TaskType taskType, String input) { + var response = client().execute(InferenceAction.INSTANCE, new InferenceAction.Request(taskType, modelId, input, Map.of())) + .actionGet(); + if (taskType == TaskType.SPARSE_EMBEDDING) { + assertThat(response.getResult(), instanceOf(TextExpansionResults.class)); + var teResult = (TextExpansionResults) response.getResult(); + assertThat(teResult.getWeightedTokens(), not(empty())); + } else { + fail("test with task type [" + taskType + "] are not supported yet"); + } + } + + private void assertModelsAreEqual(Model model1, Model model2) { + // The test can't rely on Model::equals as the specific subclass + // may be different. Model loses information about it's implemented + // subtype when it is streamed across the wire. + assertEquals(model1.getModelId(), model2.getModelId()); + assertEquals(model1.getService(), model2.getService()); + assertEquals(model1.getTaskType(), model2.getTaskType()); + + // TaskSettings and Service settings are named writables so + // the actual implementing class type is not lost when streamed \ + assertEquals(model1.getServiceSettings(), model2.getServiceSettings()); + assertEquals(model1.getTaskSettings(), model2.getTaskSettings()); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index fbb6bb7e316ff..a400f84e3c2ec 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -9,15 +9,15 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.ServiceSettings; -import org.elasticsearch.xpack.inference.TaskSettings; -import org.elasticsearch.xpack.inference.TaskType; import org.elasticsearch.xpack.inference.UnparsedModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeModel; diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java new file mode 100644 index 0000000000000..b9b0d68054ef5 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java @@ -0,0 +1,238 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.InferenceServicePlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; +import org.elasticsearch.xpack.inference.services.MapParsingUtils; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.MapParsingUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.MapParsingUtils.throwIfNotEmptyMap; + +public class TestInferenceServicePlugin extends Plugin implements InferenceServicePlugin { + + @Override + public List getInferenceServiceFactories() { + return List.of(TestInferenceService::new); + } + + @Override + public List getInferenceServiceNamedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry(ServiceSettings.class, TestServiceSettings.NAME, TestServiceSettings::new), + new NamedWriteableRegistry.Entry(TaskSettings.class, TestTaskSettings.NAME, TestTaskSettings::new) + ); + } + + public class TestInferenceService implements InferenceService { + + private static final String NAME = "test_service"; + + public static TestServiceModel parseConfig( + boolean throwOnUnknownFields, + String modelId, + TaskType taskType, + Map settings + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(settings, Model.SERVICE_SETTINGS); + var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); + + Map taskSettingsMap; + // task settings are optional + if (settings.containsKey(Model.TASK_SETTINGS)) { + taskSettingsMap = removeFromMapOrThrowIfNull(settings, Model.TASK_SETTINGS); + } else { + taskSettingsMap = Map.of(); + } + + var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); + + if (throwOnUnknownFields) { + throwIfNotEmptyMap(settings, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + } + + return new TestServiceModel(modelId, taskType, NAME, serviceSettings, taskSettings); + } + + public TestInferenceService(InferenceServicePlugin.InferenceServiceFactoryContext context) { + + } + + @Override + public String name() { + return NAME; + } + + @Override + public TestServiceModel parseConfigStrict(String modelId, TaskType taskType, Map config) { + return parseConfig(true, modelId, taskType, config); + } + + @Override + public TestServiceModel parseConfigLenient(String modelId, TaskType taskType, Map config) { + return parseConfig(false, modelId, taskType, config); + } + + @Override + public void infer(Model model, String input, Map taskSettings, ActionListener listener) { + switch (model.getTaskType()) { + case SPARSE_EMBEDDING -> listener.onResponse(TextExpansionResultsTests.createRandomResults(1, 10)); + default -> listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getTaskType(), NAME), + RestStatus.BAD_REQUEST + ) + ); + } + + } + + @Override + public void start(Model model, ActionListener listener) { + listener.onResponse(true); + } + } + + public static class TestServiceModel extends Model { + + public TestServiceModel( + String modelId, + TaskType taskType, + String service, + TestServiceSettings serviceSettings, + TestTaskSettings taskSettings + ) { + super(modelId, taskType, service, serviceSettings, taskSettings); + } + + @Override + public TestServiceSettings getServiceSettings() { + return (TestServiceSettings) super.getServiceSettings(); + } + + @Override + public TestTaskSettings getTaskSettings() { + return (TestTaskSettings) super.getTaskSettings(); + } + } + + public record TestServiceSettings(String model, String apiKey) implements ServiceSettings { + + private static final String NAME = "test_service_settings"; + + public static TestServiceSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String model = MapParsingUtils.removeAsType(map, "model", String.class); + String apiKey = MapParsingUtils.removeAsType(map, "api_key", String.class); + + if (model == null) { + validationException.addValidationError(MapParsingUtils.missingSettingErrorMsg("model", Model.SERVICE_SETTINGS)); + } + if (apiKey == null) { + validationException.addValidationError(MapParsingUtils.missingSettingErrorMsg("api_key", Model.SERVICE_SETTINGS)); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new TestServiceSettings(model, apiKey); + } + + public TestServiceSettings(StreamInput in) throws IOException { + this(in.readString(), in.readString()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("model", model); + builder.field("api_key", apiKey); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(model); + out.writeString(apiKey); + } + } + + public record TestTaskSettings(Integer temperature) implements TaskSettings { + + private static final String NAME = "test_task_settings"; + + public static TestTaskSettings fromMap(Map map) { + Integer temperature = MapParsingUtils.removeAsType(map, "temperature", Integer.class); + return new TestTaskSettings(temperature); + } + + public TestTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalVInt()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalVInt(temperature); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (temperature != null) { + builder.field("temperature", temperature); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index e80d828e4e48d..b21f919bbdc8a 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -17,6 +17,5 @@ exports org.elasticsearch.xpack.inference.rest; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; - exports org.elasticsearch.xpack.inference.results; exports org.elasticsearch.xpack.inference; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 3bbc0a53a9973..3ef93c6c275d8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -8,8 +8,8 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.xpack.inference.results.InferenceResult; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResult; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; @@ -31,11 +31,6 @@ public static List getNamedWriteables() { new NamedWriteableRegistry.Entry(TaskSettings.class, ElserMlNodeTaskSettings.NAME, ElserMlNodeTaskSettings::new) ); - // Inference results - namedWriteables.add( - new NamedWriteableRegistry.Entry(InferenceResult.class, SparseEmbeddingResult.NAME, SparseEmbeddingResult::new) - ); - return namedWriteables; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index a5b0754a7f178..ba0f1b142a799 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -27,6 +27,7 @@ import org.elasticsearch.indices.IndicesService; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.InferenceServicePlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.repositories.RepositoriesService; @@ -47,7 +48,6 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.ServiceRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestInferenceAction; @@ -55,11 +55,10 @@ import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.function.Supplier; -public class InferencePlugin extends Plugin implements ActionPlugin, SystemIndexPlugin { +public class InferencePlugin extends Plugin implements ActionPlugin, InferenceServicePlugin, SystemIndexPlugin { public static final String NAME = "inference"; @@ -75,16 +74,6 @@ public class InferencePlugin extends Plugin implements ActionPlugin, SystemIndex ); } - @Override - public List getNamedWriteables() { - return InferenceNamedWriteablesProvider.getNamedWriteables(); - } - - @Override - public List getNamedXContent() { - return Collections.emptyList(); - } - @Override public List getRestHandlers( Settings settings, @@ -121,8 +110,7 @@ public Collection createComponents( IndicesService indicesService ) { ModelRegistry modelRegistry = new ModelRegistry(client); - ServiceRegistry serviceRegistry = new ServiceRegistry(new ElserMlNodeService(client)); - return List.of(modelRegistry, serviceRegistry); + return List.of(modelRegistry); } @Override @@ -155,4 +143,14 @@ public String getFeatureName() { public String getFeatureDescription() { return "Inference plugin for managing inference services and inference"; } + + @Override + public List getInferenceServiceFactories() { + return List.of(ElserMlNodeService::new); + } + + @Override + public List getInferenceServiceNamedWriteables() { + return InferenceNamedWriteablesProvider.getNamedWriteables(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnparsedModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnparsedModel.java index 29b4accf1f4f0..b6dd41df174e5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnparsedModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnparsedModel.java @@ -8,6 +8,8 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import java.util.Map; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/DeleteInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/DeleteInferenceModelAction.java index 2ae1a33f8f5e3..4062946935b2e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/DeleteInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/DeleteInferenceModelAction.java @@ -13,7 +13,7 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xpack.inference.TaskType; +import org.elasticsearch.inference.TaskType; import java.io.IOException; import java.util.Objects; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/GetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/GetInferenceModelAction.java index a80ed84d8a6ed..7e47d3d93e3ca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/GetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/GetInferenceModelAction.java @@ -12,7 +12,7 @@ import org.elasticsearch.action.support.master.AcknowledgedRequest; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xpack.inference.TaskType; +import org.elasticsearch.inference.TaskType; import java.io.IOException; import java.util.Objects; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/InferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/InferenceAction.java index 2b95e6153a360..7938c2abd8d99 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/InferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/InferenceAction.java @@ -14,14 +14,14 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.inference.TaskType; -import org.elasticsearch.xpack.inference.results.InferenceResult; import java.io.IOException; import java.util.Map; @@ -168,15 +168,19 @@ public Request build() { public static class Response extends ActionResponse implements ToXContentObject { - private final InferenceResult result; + private final InferenceResults result; - public Response(InferenceResult result) { + public Response(InferenceResults result) { this.result = result; } public Response(StreamInput in) throws IOException { super(in); - result = in.readNamedWriteable(InferenceResult.class); + result = in.readNamedWriteable(InferenceResults.class); + } + + public InferenceResults getResult() { + return result; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/PutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/PutInferenceModelAction.java index 6ec020e6b479a..6c59fc89fd152 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/PutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/PutInferenceModelAction.java @@ -15,11 +15,11 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.TaskType; import java.io.IOException; import java.util.Objects; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index f11e6101e1e2a..1e208e83985cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -12,26 +12,26 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.inference.UnparsedModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.ServiceRegistry; public class TransportGetInferenceModelAction extends HandledTransportAction< GetInferenceModelAction.Request, PutInferenceModelAction.Response> { private final ModelRegistry modelRegistry; - private final ServiceRegistry serviceRegistry; + private final InferenceServiceRegistry serviceRegistry; @Inject public TransportGetInferenceModelAction( TransportService transportService, ActionFilters actionFilters, ModelRegistry modelRegistry, - ServiceRegistry serviceRegistry + InferenceServiceRegistry serviceRegistry ) { super(GetInferenceModelAction.NAME, transportService, actionFilters, GetInferenceModelAction.Request::new); this.modelRegistry = modelRegistry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 42fa61b406e9e..aab8ed98f4241 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -11,33 +11,27 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.inference.Model; import org.elasticsearch.xpack.inference.UnparsedModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.ServiceRegistry; -import org.elasticsearch.xpack.inference.services.InferenceService; public class TransportInferenceAction extends HandledTransportAction { private final ModelRegistry modelRegistry; - private final ServiceRegistry serviceRegistry; + private final InferenceServiceRegistry serviceRegistry; @Inject public TransportInferenceAction( - Settings settings, TransportService transportService, - ClusterService clusterService, - ThreadPool threadPool, ActionFilters actionFilters, ModelRegistry modelRegistry, - ServiceRegistry serviceRegistry + InferenceServiceRegistry serviceRegistry ) { super(InferenceAction.NAME, transportService, actionFilters, InferenceAction.Request::new); this.modelRegistry = modelRegistry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 5b948fc94d36e..8ab09bafbd248 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -19,16 +19,16 @@ import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xpack.inference.Model; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.ServiceRegistry; -import org.elasticsearch.xpack.inference.services.InferenceService; import java.io.IOException; import java.util.Map; @@ -38,7 +38,7 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction< PutInferenceModelAction.Response> { private final ModelRegistry modelRegistry; - private final ServiceRegistry serviceRegistry; + private final InferenceServiceRegistry serviceRegistry; @Inject public TransportPutInferenceModelAction( @@ -48,7 +48,7 @@ public TransportPutInferenceModelAction( ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, ModelRegistry modelRegistry, - ServiceRegistry serviceRegistry + InferenceServiceRegistry serviceRegistry ) { super( PutInferenceModelAction.NAME, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 5ad9554959a27..4403ec53e7a13 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -25,6 +25,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.reindex.DeleteByQueryAction; import org.elasticsearch.index.reindex.DeleteByQueryRequest; +import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; @@ -32,7 +33,6 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.inference.InferenceIndex; -import org.elasticsearch.xpack.inference.Model; import java.io.IOException; import java.util.Map; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ServiceRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ServiceRegistry.java deleted file mode 100644 index 8767630f5625b..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ServiceRegistry.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.registry; - -import org.elasticsearch.xpack.inference.services.InferenceService; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService; - -import java.util.Optional; - -public class ServiceRegistry { - - ElserMlNodeService elserService; - - public ServiceRegistry(ElserMlNodeService elserService) { - this.elserService = elserService; - } - - public Optional getService(String name) { - if (name.equals(ElserMlNodeService.NAME)) { - return Optional.of(elserService); - } - - return Optional.empty(); - } - -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/InferenceResult.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/InferenceResult.java deleted file mode 100644 index 8d8351dbe38d3..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/InferenceResult.java +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.io.stream.VersionedNamedWriteable; -import org.elasticsearch.xcontent.ToXContentFragment; - -public interface InferenceResult extends ToXContentFragment, VersionedNamedWriteable {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResult.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResult.java deleted file mode 100644 index 3f84c91b055a1..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResult.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -public class SparseEmbeddingResult implements InferenceResult { - - public static final String NAME = "sparse_embedding_result"; - - private final List weightedTokens; - - public SparseEmbeddingResult(List weightedTokens) { - this.weightedTokens = weightedTokens; - } - - public SparseEmbeddingResult(StreamInput in) throws IOException { - this.weightedTokens = in.readCollectionAsImmutableList(TextExpansionResults.WeightedToken::new); - } - - public List getWeightedTokens() { - return weightedTokens; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject("sparse_embedding"); - for (var weightedToken : weightedTokens) { - weightedToken.toXContent(builder, params); - } - builder.endObject(); - return builder; - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_8_500_074; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(weightedTokens); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - SparseEmbeddingResult that = (SparseEmbeddingResult) o; - return Objects.equals(weightedTokens, that.weightedTokens); - } - - @Override - public int hashCode() { - return Objects.hash(weightedTokens); - } - - @Override - public String toString() { - return Strings.toString(this); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/MapParsingUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/MapParsingUtils.java index c2b4986b84dc5..31228b645cff2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/MapParsingUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/MapParsingUtils.java @@ -54,6 +54,12 @@ public static Map removeFromMapOrThrowIfNull(Map return value; } + public static void throwIfNotEmptyMap(Map settingsMap, String serviceName) { + if (settingsMap.isEmpty() == false) { + throw MapParsingUtils.unknownSettingsError(settingsMap, serviceName); + } + } + public static ElasticsearchStatusException unknownSettingsError(Map config, String serviceName) { // TOOD map as JSON return new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeModel.java index 499a336c5d1a6..6c7e36e5d81ee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeModel.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.inference.services.elser; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.TaskType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; public class ElserMlNodeModel extends Model { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java index 602048f2e3e76..45acc467b047b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java @@ -9,27 +9,26 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.InferenceServicePlugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.TaskType; -import org.elasticsearch.xpack.inference.results.InferenceResult; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResult; -import org.elasticsearch.xpack.inference.services.InferenceService; -import org.elasticsearch.xpack.inference.services.MapParsingUtils; import java.util.List; import java.util.Map; import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED; import static org.elasticsearch.xpack.inference.services.MapParsingUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.MapParsingUtils.throwIfNotEmptyMap; public class ElserMlNodeService implements InferenceService { @@ -57,9 +56,9 @@ public static ElserMlNodeModel parseConfig( var taskSettings = taskSettingsFromMap(taskType, taskSettingsMap); if (throwOnUnknownFields) { - throwIfNotEmptyMap(settings); - throwIfNotEmptyMap(serviceSettingsMap); - throwIfNotEmptyMap(taskSettingsMap); + throwIfNotEmptyMap(settings, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); } return new ElserMlNodeModel(modelId, taskType, NAME, serviceSettings, taskSettings); @@ -67,8 +66,8 @@ public static ElserMlNodeModel parseConfig( private final OriginSettingClient client; - public ElserMlNodeService(Client client) { - this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); + public ElserMlNodeService(InferenceServicePlugin.InferenceServiceFactoryContext context) { + this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN); } @Override @@ -89,7 +88,7 @@ public void start(Model model, ActionListener listener) { } if (model.getTaskType() != TaskType.SPARSE_EMBEDDING) { - listener.onFailure(new IllegalStateException(unsupportedTaskTypeErrorMsg(model.getTaskType()))); + listener.onFailure(new IllegalStateException(TaskType.unsupportedTaskTypeErrorMsg(model.getTaskType(), NAME))); return; } @@ -109,11 +108,13 @@ public void start(Model model, ActionListener listener) { } @Override - public void infer(Model model, String input, Map requestTaskSettings, ActionListener listener) { + public void infer(Model model, String input, Map taskSettings, ActionListener listener) { // No task settings to override with requestTaskSettings if (model.getTaskType() != TaskType.SPARSE_EMBEDDING) { - listener.onFailure(new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(model.getTaskType()), RestStatus.BAD_REQUEST)); + listener.onFailure( + new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(model.getTaskType(), NAME), RestStatus.BAD_REQUEST) + ); return; } @@ -125,8 +126,7 @@ public void infer(Model model, String input, Map requestTaskSett ); client.execute(InferTrainedModelDeploymentAction.INSTANCE, request, ActionListener.wrap(inferenceResult -> { var textExpansionResult = (TextExpansionResults) inferenceResult.getResults().get(0); - var sparseEmbeddingResult = new SparseEmbeddingResult(textExpansionResult.getWeightedTokens()); - listener.onResponse(sparseEmbeddingResult); + listener.onResponse(textExpansionResult); }, listener::onFailure)); } @@ -136,7 +136,7 @@ private static ElserMlNodeServiceSettings serviceSettingsFromMap(Map config) { if (taskType != TaskType.SPARSE_EMBEDDING) { - throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType), RestStatus.BAD_REQUEST); + throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); } // no config options yet @@ -147,14 +147,4 @@ private static ElserMlNodeTaskSettings taskSettingsFromMap(TaskType taskType, Ma public String name() { return NAME; } - - private static void throwIfNotEmptyMap(Map settingsMap) { - if (settingsMap.isEmpty() == false) { - throw MapParsingUtils.unknownSettingsError(settingsMap, NAME); - } - } - - private static String unsupportedTaskTypeErrorMsg(TaskType taskType) { - return "The [" + NAME + "] service does not support task type [" + taskType + "]"; - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java index 1d6a5106a1959..1314e6eab4f25 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java @@ -12,9 +12,9 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.ServiceSettings; import org.elasticsearch.xpack.inference.services.MapParsingUtils; import java.io.IOException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java index f4c75683783f0..c1d84af5b5fbe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java @@ -11,8 +11,8 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.inference.TaskSettings; import java.io.IOException; import java.util.Objects; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelTests.java index 57f4c0650b932..778f4703767a6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelTests.java @@ -9,6 +9,10 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java index 22a4981e092dc..0b30dc9021038 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java @@ -8,8 +8,8 @@ package org.elasticsearch.xpack.inference.action; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.TaskType; public class GetInferenceModelRequestTests extends AbstractWireSerializingTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java index f937ba03ae864..3e1bea0051656 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java @@ -9,8 +9,8 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.TaskType; import java.io.IOException; import java.util.HashMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java index 13896607fe9ab..795923e56c6bb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java @@ -10,16 +10,22 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultTests; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; public class InferenceActionResponseTests extends AbstractWireSerializingTestCase { @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(InferenceNamedWriteablesProvider.getNamedWriteables()); + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); + return new NamedWriteableRegistry(entries); } @Override @@ -29,7 +35,7 @@ protected Writeable.Reader instanceReader() { @Override protected InferenceAction.Response createTestInstance() { - return new InferenceAction.Response(SparseEmbeddingResultTests.createRandomResult()); + return new InferenceAction.Response(TextExpansionResultsTests.createRandomResults()); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java index 770faf19585c5..9aefea9a942db 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.action; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.TaskType; public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase { @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ServiceRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ServiceRegistryTests.java deleted file mode 100644 index 492fb29f910b8..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ServiceRegistryTests.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.registry; - -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService; - -import static org.mockito.Mockito.mock; - -public class ServiceRegistryTests extends ESTestCase { - - public void testGetService() { - ServiceRegistry registry = new ServiceRegistry(mock(ElserMlNodeService.class)); - var service = registry.getService(ElserMlNodeService.NAME); - assertTrue(service.isPresent()); - } - - public void testGetUnknownService() { - ServiceRegistry registry = new ServiceRegistry(mock(ElserMlNodeService.class)); - var service = registry.getService("foo"); - assertFalse(service.isPresent()); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultTests.java deleted file mode 100644 index 360dc3e97d141..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultTests.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; - -import java.util.ArrayList; -import java.util.List; - -public class SparseEmbeddingResultTests extends AbstractWireSerializingTestCase { - - public static SparseEmbeddingResult createRandomResult() { - int numTokens = randomIntBetween(1, 20); - List tokenList = new ArrayList<>(); - for (int i = 0; i < numTokens; i++) { - tokenList.add(new TextExpansionResults.WeightedToken(Integer.toString(i), (float) randomDoubleBetween(0.0, 5.0, false))); - } - return new SparseEmbeddingResult(tokenList); - } - - @Override - protected Writeable.Reader instanceReader() { - return SparseEmbeddingResult::new; - } - - @Override - protected SparseEmbeddingResult createTestInstance() { - return createRandomResult(); - } - - @Override - protected SparseEmbeddingResult mutateInstance(SparseEmbeddingResult instance) { - if (instance.getWeightedTokens().size() > 0) { - var tokens = instance.getWeightedTokens(); - return new SparseEmbeddingResult(tokens.subList(0, tokens.size() - 1)); - } else { - return new SparseEmbeddingResult(List.of(new TextExpansionResults.WeightedToken("a", 1.0f))); - } - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java index bdbb4c545900c..0449c1b4a7d59 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java @@ -9,9 +9,10 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.InferenceServicePlugin; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.TaskType; import java.util.HashMap; import java.util.Map; @@ -35,7 +36,7 @@ public static Model randomModelConfig(String modelId, TaskType taskType) { } public void testParseConfigStrict() { - var service = new ElserMlNodeService(mock(Client.class)); + var service = createService(mock(Client.class)); var settings = new HashMap(); settings.put( @@ -59,7 +60,7 @@ public void testParseConfigStrict() { } public void testParseConfigStrictWithNoTaskSettings() { - var service = new ElserMlNodeService(mock(Client.class)); + var service = createService(mock(Client.class)); var settings = new HashMap(); settings.put( @@ -153,4 +154,9 @@ public void testParseConfigStrictWithUnknownSettings() { } } } + + private ElserMlNodeService createService(Client client) { + var context = new InferenceServicePlugin.InferenceServiceFactoryContext(client); + return new ElserMlNodeService(context); + } } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 80f671a2c9db3..b9ca3946412bc 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.integration; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.license.License; import org.elasticsearch.xpack.core.ml.MlConfigVersion; import org.elasticsearch.xpack.core.ml.action.InferModelAction; @@ -17,7 +18,6 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java index f0f8287ab0d78..43040b4f94823 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java @@ -17,11 +17,11 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput; import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index 17db1cbc5a234..c30b7a3232f57 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.license.License; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; @@ -34,7 +35,6 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregator.java index f68d1df0e9db4..c94eacad6fb86 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregator.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.aggs.inference; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.search.aggregations.AggregationExecutionException; import org.elasticsearch.search.aggregations.AggregationReduceContext; import org.elasticsearch.search.aggregations.InternalAggregation; @@ -18,7 +19,6 @@ import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.support.AggregationPath; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InternalInferenceAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InternalInferenceAggregation.java index 59ef2cb6e8e3e..979c3ff46dcb3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InternalInferenceAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InternalInferenceAggregation.java @@ -9,10 +9,10 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.search.aggregations.AggregationReduceContext; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import java.io.IOException; import java.util.List; @@ -32,7 +32,7 @@ protected InternalInferenceAggregation(String name, Map metadata public InternalInferenceAggregation(StreamInput in) throws IOException { super(in); - inferenceResult = in.readNamedWriteable(InferenceResults.class); + inferenceResult = in.readNamedWriteable(org.elasticsearch.inference.InferenceResults.class); } public InferenceResults getInferenceResult() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java index 36bca38d9f587..168b0deda87d4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.metrics.Max; @@ -28,7 +29,6 @@ import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.dataframe.DestinationIndex; import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index e31f8d9992b64..fe8366065dac7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -22,6 +22,7 @@ import org.elasticsearch.common.component.LifecycleListener; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; @@ -40,7 +41,6 @@ import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 21bd66e6f35ea..03f34dacb1faf 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.query.IdsQueryBuilder; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.search.SearchHit; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.threadpool.ThreadPool; @@ -32,7 +33,6 @@ import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java index bde15b4403f4c..474933236d196 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchAction.java @@ -12,10 +12,10 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index 025d0ecca00d2..ea06a0a0aba90 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; @@ -22,7 +23,6 @@ import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 8a4d9e569ea71..5d1c60964b8b9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.AbstractProcessor; import org.elasticsearch.ingest.ConfigurationUtils; import org.elasticsearch.ingest.IngestDocument; @@ -24,7 +25,6 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.MlConfigVersion; import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; @@ -65,10 +65,10 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; +import static org.elasticsearch.inference.InferenceResults.MODEL_ID_RESULTS_FIELD; import static org.elasticsearch.ingest.IngestDocument.INGEST_KEY; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; -import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.MODEL_ID_RESULTS_FIELD; public class InferenceProcessor extends AbstractProcessor { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index 1dc689215d74c..ffd70849d8f1c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -8,10 +8,10 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.license.License; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index da2f97e283f2a..7a202df6fa744 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -29,6 +29,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.license.License; import org.elasticsearch.license.XPackLicenseState; @@ -38,7 +39,6 @@ import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java index d8f86eef2c4d0..d171795f86a54 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java @@ -9,9 +9,9 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.ValidationException; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java index 160a53c861a8b..a0c8ddcba5125 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java @@ -9,8 +9,8 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.ValidationException; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.NerResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java index 040b9fbedb289..273120bba09ea 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java @@ -10,7 +10,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.core.Releasable; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java index 03f1084e39184..417f13964f6cb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.ml.inference.nlp; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java index e15f0402db3b1..a7f7f2bb9f538 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessor.java @@ -9,8 +9,8 @@ import org.apache.lucene.util.PriorityQueue; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java index 6fe1e66b8f6de..150ac184d246f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java @@ -8,8 +8,8 @@ package org.elasticsearch.xpack.ml.inference.nlp; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java index c11ea2005a05d..453b689d59cc0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.ml.inference.nlp; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java index 1983dd973c9ef..6483b9d9b3da9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.ml.inference.nlp; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessor.java index 1f296438c796c..0b4f9d8f897e0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessor.java @@ -8,8 +8,8 @@ package org.elasticsearch.xpack.ml.inference.nlp; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java index eff6916d61609..a9422e66b16bc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java @@ -9,8 +9,8 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.logging.LoggerMessageFormat; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorer.java index ea7b064de8464..df3e0756ea39a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/InferenceRescorer.java @@ -16,9 +16,9 @@ import org.apache.lucene.search.TopDocs; import org.elasticsearch.common.util.Maps; import org.elasticsearch.core.Strings; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.rescore.Rescorer; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractor.java index 80bc6f208a501..48b570e927b15 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractor.java @@ -21,8 +21,8 @@ import java.util.Set; import java.util.function.Consumer; +import static org.elasticsearch.inference.InferenceResults.MODEL_ID_RESULTS_FIELD; import static org.elasticsearch.ingest.Pipeline.PROCESSORS_KEY; -import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.MODEL_ID_RESULTS_FIELD; import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.TYPE; /** diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/InternalInferenceAggregationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/InternalInferenceAggregationTests.java index 5dc7f43118dd7..016a89fe4e4b7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/InternalInferenceAggregationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/InternalInferenceAggregationTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.search.aggregations.Aggregation; import org.elasticsearch.search.aggregations.InvalidAggregationPathException; @@ -19,7 +20,6 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/ParsedInference.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/ParsedInference.java index fa275b3f9f400..9b11fc166ef58 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/ParsedInference.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/ParsedInference.java @@ -30,7 +30,7 @@ /** * There isn't enough information in toXContent representation of the - * {@link org.elasticsearch.xpack.core.ml.inference.results.InferenceResults} + * {@link org.elasticsearch.inference.InferenceResults} * objects to fully reconstruct them. In particular, depending on which * fields are written (result value, feature importance) it is not possible to * distinguish between a Regression result and a Classification result. diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java index 9885e1b46045b..b6d53b949ce15 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.search.SearchHit; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; @@ -25,7 +26,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.RegressionTests; import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java index 181b6abd5b549..c937e9be24b01 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskAwareRequest; @@ -19,7 +20,6 @@ import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig; import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index 0029276d20e6d..3854ad984467d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.test.ESTestCase; @@ -34,7 +35,6 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.ml.MlConfigVersion; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 4709925dbe7d7..faa64969fb6b2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -8,6 +8,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.ingest.TestIngestDocument; import org.elasticsearch.license.License; @@ -16,7 +17,6 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; @@ -46,7 +46,7 @@ import java.util.Map; import java.util.TreeMap; -import static org.elasticsearch.xpack.core.ml.inference.results.InferenceResults.writeResult; +import static org.elasticsearch.inference.InferenceResults.writeResult; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModelTests.serializeFromTrainedModel; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.closeTo; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 644b619537741..2f4640cfa38dc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.license.XPackLicenseState; @@ -40,7 +41,6 @@ import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java index 2fa2149312ca2..f7b8b8a0967f9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/InferenceProcessorInfoExtractorTests.java @@ -15,13 +15,13 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.Maps; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; diff --git a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/action/TransportGetTransformStatsAction.java b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/action/TransportGetTransformStatsAction.java index 2649da6959205..98343d5593dfa 100644 --- a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/action/TransportGetTransformStatsAction.java +++ b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/action/TransportGetTransformStatsAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.TaskOperationFailure; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedRequest; import org.elasticsearch.action.support.tasks.BaseTasksResponse; import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.client.internal.Client; @@ -138,7 +139,11 @@ protected void taskOperation(CancellableTask actionTask, Request request, Transf ), // at this point the transport already spend some time budget in `doExecute`, it is hard to tell what is left: // recording the time spend would be complex and crosses machine boundaries, that's why we use a heuristic here - TimeValue.timeValueMillis((long) (request.getTimeout().millis() * CHECKPOINT_INFO_TIMEOUT_SHARE)) + TimeValue.timeValueMillis( + (long) ((request.getTimeout() != null + ? request.getTimeout().millis() + : AcknowledgedRequest.DEFAULT_ACK_TIMEOUT.millis()) * CHECKPOINT_INFO_TIMEOUT_SHARE) + ) ); } else { listener.onResponse(new Response(Collections.emptyList()));