From d7011abff035a710b828f1019e9acd0ba51a996b Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 25 Sep 2023 14:53:10 +0100 Subject: [PATCH] Make Inference Services plugins --- server/src/main/java/module-info.java | 1 + .../inference/InferenceResults.java | 14 ++ .../inference}/InferenceService.java | 30 ++- .../inference/InferenceServiceRegistry.java | 69 +++++ .../org/elasticsearch}/inference/Model.java | 7 +- .../inference/ServiceSettings.java | 7 +- .../inference/TaskSettings.java | 7 +- .../elasticsearch}/inference/TaskType.java | 11 +- .../java/org/elasticsearch/node/Node.java | 12 +- .../plugins/InferenceServicePlugin.java | 43 ++++ .../MlInferenceNamedXContentProvider.java | 2 +- .../inference/results/InferenceResults.java | 4 +- .../integration/MockInferenceServiceIT.java | 131 ++++++++++ .../integration/ModelRegistryIT.java | 8 +- .../TestInferenceServicePlugin.java | 238 ++++++++++++++++++ .../inference/src/main/java/module-info.java | 1 - .../InferenceNamedWriteablesProvider.java | 9 +- .../xpack/inference/InferencePlugin.java | 28 +-- .../xpack/inference/UnparsedModel.java | 2 + .../action/DeleteInferenceModelAction.java | 2 +- .../action/GetInferenceModelAction.java | 2 +- .../inference/action/InferenceAction.java | 14 +- .../action/PutInferenceModelAction.java | 4 +- .../TransportGetInferenceModelAction.java | 6 +- .../action/TransportInferenceAction.java | 16 +- .../TransportPutInferenceModelAction.java | 10 +- .../inference/registry/ModelRegistry.java | 2 +- .../inference/registry/ServiceRegistry.java | 31 --- .../inference/results/InferenceResult.java | 13 - .../results/SparseEmbeddingResult.java | 82 ------ .../inference/services/MapParsingUtils.java | 6 + .../services/elser/ElserMlNodeModel.java | 4 +- .../services/elser/ElserMlNodeService.java | 46 ++-- .../elser/ElserMlNodeServiceSettings.java | 4 +- .../elser/ElserMlNodeTaskSettings.java | 2 +- .../xpack/inference/ModelTests.java | 4 + .../action/GetInferenceModelRequestTests.java | 2 +- .../action/InferenceActionRequestTests.java | 2 +- .../action/InferenceActionResponseTests.java | 12 +- .../action/PutInferenceModelRequestTests.java | 2 +- .../registry/ServiceRegistryTests.java | 28 --- .../results/SparseEmbeddingResultTests.java | 47 ---- .../elser/ElserMlNodeServiceTests.java | 14 +- 43 files changed, 646 insertions(+), 333 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/inference/InferenceResults.java rename {x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services => server/src/main/java/org/elasticsearch/inference}/InferenceService.java (74%) create mode 100644 server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java rename {x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack => server/src/main/java/org/elasticsearch}/inference/Model.java (94%) rename {x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack => server/src/main/java/org/elasticsearch}/inference/ServiceSettings.java (62%) rename {x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack => server/src/main/java/org/elasticsearch}/inference/TaskSettings.java (62%) rename {x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack => server/src/main/java/org/elasticsearch}/inference/TaskType.java (77%) create mode 100644 server/src/main/java/org/elasticsearch/plugins/InferenceServicePlugin.java create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ServiceRegistry.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/InferenceResult.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResult.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ServiceRegistryTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultTests.java diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 1568f47461e6..99ce5910c977 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -278,6 +278,7 @@ exports org.elasticsearch.indices.recovery; exports org.elasticsearch.indices.recovery.plan; exports org.elasticsearch.indices.store; + exports org.elasticsearch.inference; exports org.elasticsearch.ingest; exports org.elasticsearch.internal to diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceResults.java new file mode 100644 index 000000000000..5cd1b548d8ec --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/InferenceResults.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xcontent.ToXContentFragment; + +public interface InferenceResults extends NamedWriteable, ToXContentFragment {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java similarity index 74% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/InferenceService.java rename to server/src/main/java/org/elasticsearch/inference/InferenceService.java index 18704f4c3274..3bc9cf2a5a11 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -1,16 +1,14 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.inference.services; +package org.elasticsearch.inference; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.TaskType; -import org.elasticsearch.xpack.inference.results.InferenceResult; import java.util.Map; @@ -45,20 +43,20 @@ public interface InferenceService { */ Model parseConfigLenient(String modelId, TaskType taskType, Map config); - /** - * Start or prepare the model for use. - * @param model The model - * @param listener The listener - */ - void start(Model model, ActionListener listener); - /** * Perform inference on the model. * - * @param model Model configuration + * @param model The model * @param input Inference input - * @param requestTaskSettings Settings in the request to override the model's defaults + * @param taskSettings Settings in the request to override the model's defaults * @param listener Inference result listener */ - void infer(Model model, String input, Map requestTaskSettings, ActionListener listener); + void infer(Model model, String input, Map taskSettings, ActionListener listener); + + /** + * Start or prepare the model for use. + * @param model The model + * @param listener The listener + */ + void start(Model model, ActionListener listener); } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java new file mode 100644 index 000000000000..ac1439150f8e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.common.component.AbstractLifecycleComponent; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.plugins.InferenceServicePlugin; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class InferenceServiceRegistry extends AbstractLifecycleComponent { + + private final Map services; + private final List namedWriteables = new ArrayList<>(); + + public InferenceServiceRegistry( + List inferenceServicePlugins, + InferenceServicePlugin.InferenceServiceFactoryContext factoryContext + ) { + // TODO check names are unique + services = inferenceServicePlugins.stream() + .flatMap(r -> r.getInferenceServiceFactories().stream()) + .map(factory -> factory.create(factoryContext)) + .collect(Collectors.toMap(InferenceService::name, Function.identity())); + + for (var plugin : inferenceServicePlugins) { + namedWriteables.addAll(plugin.getInferenceServiceNamedWriteables()); + } + } + + public Map getServices() { + return services; + } + + public Optional getService(String serviceName) { + return Optional.ofNullable(services.get(serviceName)); + } + + public List getNamedWriteables() { + return namedWriteables; + } + + @Override + protected void doStart() { + + } + + @Override + protected void doStop() { + + } + + @Override + protected void doClose() throws IOException { + + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/Model.java b/server/src/main/java/org/elasticsearch/inference/Model.java similarity index 94% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/Model.java rename to server/src/main/java/org/elasticsearch/inference/Model.java index c0032ebf25ca..67ee58bad733 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/Model.java +++ b/server/src/main/java/org/elasticsearch/inference/Model.java @@ -1,11 +1,12 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.inference; +package org.elasticsearch.inference; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/ServiceSettings.java b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java similarity index 62% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/ServiceSettings.java rename to server/src/main/java/org/elasticsearch/inference/ServiceSettings.java index 16f39ad8e560..01a75158f5ca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/ServiceSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ServiceSettings.java @@ -1,11 +1,12 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.inference; +package org.elasticsearch.inference; import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xcontent.ToXContentObject; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/TaskSettings.java b/server/src/main/java/org/elasticsearch/inference/TaskSettings.java similarity index 62% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/TaskSettings.java rename to server/src/main/java/org/elasticsearch/inference/TaskSettings.java index 200f1b309822..e346fa89b866 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/TaskSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskSettings.java @@ -1,11 +1,12 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.inference; +package org.elasticsearch.inference; import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xcontent.ToXContentObject; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/TaskType.java b/server/src/main/java/org/elasticsearch/inference/TaskType.java similarity index 77% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/TaskType.java rename to server/src/main/java/org/elasticsearch/inference/TaskType.java index 5e9cc9327003..9e96a7c4c52d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/TaskType.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskType.java @@ -1,11 +1,12 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.inference; +package org.elasticsearch.inference; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.io.stream.StreamInput; @@ -49,4 +50,8 @@ public static TaskType fromStream(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { out.writeEnum(this); } + + public static String unsupportedTaskTypeErrorMsg(TaskType taskType, String serviceName) { + return "The [" + serviceName + "] service does not support task type [" + taskType + "]"; + } } diff --git a/server/src/main/java/org/elasticsearch/node/Node.java b/server/src/main/java/org/elasticsearch/node/Node.java index ad8ad68bf765..a2da65bdaa8b 100644 --- a/server/src/main/java/org/elasticsearch/node/Node.java +++ b/server/src/main/java/org/elasticsearch/node/Node.java @@ -146,6 +146,7 @@ import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; import org.elasticsearch.indices.store.IndicesStore; +import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.MonitorService; import org.elasticsearch.monitor.fs.FsHealthService; @@ -165,6 +166,7 @@ import org.elasticsearch.plugins.EnginePlugin; import org.elasticsearch.plugins.HealthPlugin; import org.elasticsearch.plugins.IndexStorePlugin; +import org.elasticsearch.plugins.InferenceServicePlugin; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.MetadataUpgrader; @@ -532,6 +534,12 @@ protected Node( Supplier documentParsingObserverSupplier = getDocumentParsingObserverSupplier(); + var factoryContext = new InferenceServicePlugin.InferenceServiceFactoryContext(client); + final InferenceServiceRegistry inferenceServiceRegistry = new InferenceServiceRegistry( + pluginsService.filterPlugins(InferenceServicePlugin.class), + factoryContext + ); + final IngestService ingestService = new IngestService( clusterService, threadPool, @@ -555,7 +563,8 @@ protected Node( searchModule.getNamedWriteables().stream(), pluginsService.flatMap(Plugin::getNamedWriteables), ClusterModule.getNamedWriteables().stream(), - SystemIndexMigrationExecutor.getNamedWriteables().stream() + SystemIndexMigrationExecutor.getNamedWriteables().stream(), + inferenceServiceRegistry.getNamedWriteables().stream() ).flatMap(Function.identity()).toList(); final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables); NamedXContentRegistry xContentRegistry = new NamedXContentRegistry( @@ -1170,6 +1179,7 @@ protected Node( b.bind(WriteLoadForecaster.class).toInstance(writeLoadForecaster); b.bind(HealthPeriodicLogger.class).toInstance(healthPeriodicLogger); b.bind(CompatibilityVersions.class).toInstance(compatibilityVersions); + b.bind(InferenceServiceRegistry.class).toInstance(inferenceServiceRegistry); }); if (ReadinessService.enabled(environment)) { diff --git a/server/src/main/java/org/elasticsearch/plugins/InferenceServicePlugin.java b/server/src/main/java/org/elasticsearch/plugins/InferenceServicePlugin.java new file mode 100644 index 000000000000..2672a4b8fcbc --- /dev/null +++ b/server/src/main/java/org/elasticsearch/plugins/InferenceServicePlugin.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.plugins; + +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.inference.InferenceService; + +import java.util.List; + +/** + * InferenceServicePlugins implement an inference service + */ +public interface InferenceServicePlugin { + + List getInferenceServiceFactories(); + + record InferenceServiceFactoryContext(Client client) {} + + interface Factory { + /** + * InferenceServices are created from the factory context + */ + InferenceService create(InferenceServiceFactoryContext context); + } + + /** + * The named writables defined and used by each of the implemented + * InferenceServices. Each service should define named writables for + * - {@link org.elasticsearch.inference.TaskSettings} + * - {@link org.elasticsearch.inference.ServiceSettings} + * And optionally for {@link org.elasticsearch.inference.InferenceResults} + * if the service uses a new type of result. + * @return All named writables defined by the services + */ + List getInferenceServiceNamedWriteables(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 1ef60af3584e..7f0d12af5f46 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; @@ -21,7 +22,6 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.NerResults; import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java index 83f08391c656..0f8935b7206e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java @@ -6,14 +6,12 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; -import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.ingest.IngestDocument; -import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.util.Map; -public interface InferenceResults extends NamedWriteable, ToXContentFragment { +public interface InferenceResults extends org.elasticsearch.inference.InferenceResults { String PREDICTION_PROBABILITY = "prediction_probability"; String MODEL_ID_RESULTS_FIELD = "model_id"; diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java new file mode 100644 index 000000000000..411c29255fd7 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.action.PutInferenceModelAction; + +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; + +public class MockInferenceServiceIT extends ESIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return List.of(InferencePlugin.class, TestInferenceServicePlugin.class); + } + + @Override + protected Function getClientWrapper() { + final Map headers = Map.of( + "Authorization", + basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING) + ); + // we need to wrap node clients because we do not specify a user for nodes and all requests will use the system + // user. This is ok for internal n2n stuff but the test framework does other things like wiping indices, repositories, etc + // that the system user cannot do. so we wrap the node client with a user that can do these things since the client() calls + // return a node client + return client -> client.filterWithHeader(headers); + } + + public void testMockService() { + String modelId = "test-mock"; + Model putModel = putMockService(modelId, TaskType.SPARSE_EMBEDDING); + Model readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING); + assertModelsAreEqual(putModel, readModel); + + // The response is randomly generated, the input can be anything + inferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, randomAlphaOfLength(10)); + } + + private Model putMockService(String modelId, TaskType taskType) { + String body = """ + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + "temperature": 3 + } + } + """; + var request = new PutInferenceModelAction.Request( + taskType.toString(), + modelId, + new BytesArray(body.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON + ); + + var response = client().execute(PutInferenceModelAction.INSTANCE, request).actionGet(); + assertEquals("test_service", response.getModel().getService()); + + assertThat(response.getModel().getServiceSettings(), instanceOf(TestInferenceServicePlugin.TestServiceSettings.class)); + var serviceSettings = (TestInferenceServicePlugin.TestServiceSettings) response.getModel().getServiceSettings(); + assertEquals("my_model", serviceSettings.model()); + assertEquals("abc64", serviceSettings.apiKey()); + + assertThat(response.getModel().getTaskSettings(), instanceOf(TestInferenceServicePlugin.TestTaskSettings.class)); + var taskSettings = (TestInferenceServicePlugin.TestTaskSettings) response.getModel().getTaskSettings(); + assertEquals(3, (int) taskSettings.temperature()); + + return response.getModel(); + } + + public Model getModel(String modelId, TaskType taskType) { + var response = client().execute(GetInferenceModelAction.INSTANCE, new GetInferenceModelAction.Request(modelId, taskType.toString())) + .actionGet(); + return response.getModel(); + } + + private void inferOnMockService(String modelId, TaskType taskType, String input) { + var response = client().execute(InferenceAction.INSTANCE, new InferenceAction.Request(taskType, modelId, input, Map.of())) + .actionGet(); + if (taskType == TaskType.SPARSE_EMBEDDING) { + assertThat(response.getResult(), instanceOf(TextExpansionResults.class)); + var teResult = (TextExpansionResults) response.getResult(); + assertThat(teResult.getWeightedTokens(), not(empty())); + } else { + fail("test with task type [" + taskType + "] are not supported yet"); + } + } + + private void assertModelsAreEqual(Model model1, Model model2) { + // The test can't rely on Model::equals as the specific subclass + // may be different. Model loses information about it's implemented + // subtype when it is streamed across the wire. + assertEquals(model1.getModelId(), model2.getModelId()); + assertEquals(model1.getService(), model2.getService()); + assertEquals(model1.getTaskType(), model2.getTaskType()); + + // TaskSettings and Service settings are named writables so + // the actual implementing class type is not lost when streamed \ + assertEquals(model1.getServiceSettings(), model2.getServiceSettings()); + assertEquals(model1.getTaskSettings(), model2.getTaskSettings()); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index fbb6bb7e316f..a400f84e3c2e 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -9,15 +9,15 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.ServiceSettings; -import org.elasticsearch.xpack.inference.TaskSettings; -import org.elasticsearch.xpack.inference.TaskType; import org.elasticsearch.xpack.inference.UnparsedModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeModel; diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java new file mode 100644 index 000000000000..af81a8378a86 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java @@ -0,0 +1,238 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.InferenceServicePlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; +import org.elasticsearch.xpack.inference.services.MapParsingUtils; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.MapParsingUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.MapParsingUtils.throwIfNotEmptyMap; + +public class TestInferenceServicePlugin extends Plugin implements InferenceServicePlugin { + + @Override + public List getInferenceServiceFactories() { + return List.of(TestInferenceService::new); + } + + @Override + public List getInferenceServiceNamedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry(ServiceSettings.class, TestServiceSettings.NAME, TestServiceSettings::new), + new NamedWriteableRegistry.Entry(TaskSettings.class, TestTaskSettings.NAME, TestTaskSettings::new) + ); + } + + public class TestInferenceService implements InferenceService { + + private static final String NAME = "test_service"; + + public static TestServiceModel parseConfig( + boolean throwOnUnknownFields, + String modelId, + TaskType taskType, + Map settings + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(settings, Model.SERVICE_SETTINGS); + var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap); + + Map taskSettingsMap; + // task settings are optional + if (settings.containsKey(Model.TASK_SETTINGS)) { + taskSettingsMap = removeFromMapOrThrowIfNull(settings, Model.TASK_SETTINGS); + } else { + taskSettingsMap = Map.of(); + } + + var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); + + if (throwOnUnknownFields) { + throwIfNotEmptyMap(settings, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + } + + return new TestServiceModel(modelId, taskType, NAME, serviceSettings, taskSettings); + } + + public TestInferenceService(InferenceServicePlugin.InferenceServiceFactoryContext context) { + + } + + @Override + public String name() { + return NAME; + } + + @Override + public TestServiceModel parseConfigStrict(String modelId, TaskType taskType, Map config) { + return parseConfig(true, modelId, taskType, config); + } + + @Override + public TestServiceModel parseConfigLenient(String modelId, TaskType taskType, Map config) { + return parseConfig(false, modelId, taskType, config); + } + + @Override + public void infer(Model model, String input, Map taskSettings, ActionListener listener) { + switch (model.getTaskType()) { + case SPARSE_EMBEDDING -> listener.onResponse(TextExpansionResultsTests.createRandomResults()); + default -> listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getTaskType(), NAME), + RestStatus.BAD_REQUEST + ) + ); + } + + } + + @Override + public void start(Model model, ActionListener listener) { + listener.onResponse(true); + } + } + + public static class TestServiceModel extends Model { + + public TestServiceModel( + String modelId, + TaskType taskType, + String service, + TestServiceSettings serviceSettings, + TestTaskSettings taskSettings + ) { + super(modelId, taskType, service, serviceSettings, taskSettings); + } + + @Override + public TestServiceSettings getServiceSettings() { + return (TestServiceSettings) super.getServiceSettings(); + } + + @Override + public TestTaskSettings getTaskSettings() { + return (TestTaskSettings) super.getTaskSettings(); + } + } + + public record TestServiceSettings(String model, String apiKey) implements ServiceSettings { + + private static final String NAME = "test_service_settings"; + + public static TestServiceSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String model = MapParsingUtils.removeAsType(map, "model", String.class); + String apiKey = MapParsingUtils.removeAsType(map, "api_key", String.class); + + if (model == null) { + validationException.addValidationError(MapParsingUtils.missingSettingErrorMsg("model", Model.SERVICE_SETTINGS)); + } + if (apiKey == null) { + validationException.addValidationError(MapParsingUtils.missingSettingErrorMsg("api_key", Model.SERVICE_SETTINGS)); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new TestServiceSettings(model, apiKey); + } + + public TestServiceSettings(StreamInput in) throws IOException { + this(in.readString(), in.readString()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("model", model); + builder.field("api_key", apiKey); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(model); + out.writeString(apiKey); + } + } + + public record TestTaskSettings(Integer temperature) implements TaskSettings { + + private static final String NAME = "test_task_settings"; + + public static TestTaskSettings fromMap(Map map) { + Integer temperature = MapParsingUtils.removeAsType(map, "temperature", Integer.class); + return new TestTaskSettings(temperature); + } + + public TestTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalVInt()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalVInt(temperature); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (temperature != null) { + builder.field("temperature", temperature); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index e80d828e4e48..b21f919bbdc8 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -17,6 +17,5 @@ exports org.elasticsearch.xpack.inference.rest; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; - exports org.elasticsearch.xpack.inference.results; exports org.elasticsearch.xpack.inference; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 3bbc0a53a997..3ef93c6c275d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -8,8 +8,8 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.xpack.inference.results.InferenceResult; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResult; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; @@ -31,11 +31,6 @@ public static List getNamedWriteables() { new NamedWriteableRegistry.Entry(TaskSettings.class, ElserMlNodeTaskSettings.NAME, ElserMlNodeTaskSettings::new) ); - // Inference results - namedWriteables.add( - new NamedWriteableRegistry.Entry(InferenceResult.class, SparseEmbeddingResult.NAME, SparseEmbeddingResult::new) - ); - return namedWriteables; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index a5b0754a7f17..ba0f1b142a79 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -27,6 +27,7 @@ import org.elasticsearch.indices.IndicesService; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.InferenceServicePlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.repositories.RepositoriesService; @@ -47,7 +48,6 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.ServiceRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestInferenceAction; @@ -55,11 +55,10 @@ import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.function.Supplier; -public class InferencePlugin extends Plugin implements ActionPlugin, SystemIndexPlugin { +public class InferencePlugin extends Plugin implements ActionPlugin, InferenceServicePlugin, SystemIndexPlugin { public static final String NAME = "inference"; @@ -75,16 +74,6 @@ public class InferencePlugin extends Plugin implements ActionPlugin, SystemIndex ); } - @Override - public List getNamedWriteables() { - return InferenceNamedWriteablesProvider.getNamedWriteables(); - } - - @Override - public List getNamedXContent() { - return Collections.emptyList(); - } - @Override public List getRestHandlers( Settings settings, @@ -121,8 +110,7 @@ public Collection createComponents( IndicesService indicesService ) { ModelRegistry modelRegistry = new ModelRegistry(client); - ServiceRegistry serviceRegistry = new ServiceRegistry(new ElserMlNodeService(client)); - return List.of(modelRegistry, serviceRegistry); + return List.of(modelRegistry); } @Override @@ -155,4 +143,14 @@ public String getFeatureName() { public String getFeatureDescription() { return "Inference plugin for managing inference services and inference"; } + + @Override + public List getInferenceServiceFactories() { + return List.of(ElserMlNodeService::new); + } + + @Override + public List getInferenceServiceNamedWriteables() { + return InferenceNamedWriteablesProvider.getNamedWriteables(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnparsedModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnparsedModel.java index 29b4accf1f4f..b6dd41df174e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnparsedModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnparsedModel.java @@ -8,6 +8,8 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import java.util.Map; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/DeleteInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/DeleteInferenceModelAction.java index 2ae1a33f8f5e..4062946935b2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/DeleteInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/DeleteInferenceModelAction.java @@ -13,7 +13,7 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xpack.inference.TaskType; +import org.elasticsearch.inference.TaskType; import java.io.IOException; import java.util.Objects; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/GetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/GetInferenceModelAction.java index a80ed84d8a6e..7e47d3d93e3c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/GetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/GetInferenceModelAction.java @@ -12,7 +12,7 @@ import org.elasticsearch.action.support.master.AcknowledgedRequest; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xpack.inference.TaskType; +import org.elasticsearch.inference.TaskType; import java.io.IOException; import java.util.Objects; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/InferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/InferenceAction.java index 2b95e6153a36..7938c2abd8d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/InferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/InferenceAction.java @@ -14,14 +14,14 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xpack.inference.TaskType; -import org.elasticsearch.xpack.inference.results.InferenceResult; import java.io.IOException; import java.util.Map; @@ -168,15 +168,19 @@ public Request build() { public static class Response extends ActionResponse implements ToXContentObject { - private final InferenceResult result; + private final InferenceResults result; - public Response(InferenceResult result) { + public Response(InferenceResults result) { this.result = result; } public Response(StreamInput in) throws IOException { super(in); - result = in.readNamedWriteable(InferenceResult.class); + result = in.readNamedWriteable(InferenceResults.class); + } + + public InferenceResults getResult() { + return result; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/PutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/PutInferenceModelAction.java index 6ec020e6b479..6c59fc89fd15 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/PutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/PutInferenceModelAction.java @@ -15,11 +15,11 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.TaskType; import java.io.IOException; import java.util.Objects; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index f11e6101e1e2..1e208e83985c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -12,26 +12,26 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.inference.UnparsedModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.ServiceRegistry; public class TransportGetInferenceModelAction extends HandledTransportAction< GetInferenceModelAction.Request, PutInferenceModelAction.Response> { private final ModelRegistry modelRegistry; - private final ServiceRegistry serviceRegistry; + private final InferenceServiceRegistry serviceRegistry; @Inject public TransportGetInferenceModelAction( TransportService transportService, ActionFilters actionFilters, ModelRegistry modelRegistry, - ServiceRegistry serviceRegistry + InferenceServiceRegistry serviceRegistry ) { super(GetInferenceModelAction.NAME, transportService, actionFilters, GetInferenceModelAction.Request::new); this.modelRegistry = modelRegistry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 42fa61b406e9..aab8ed98f424 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -11,33 +11,27 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.inference.Model; import org.elasticsearch.xpack.inference.UnparsedModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.ServiceRegistry; -import org.elasticsearch.xpack.inference.services.InferenceService; public class TransportInferenceAction extends HandledTransportAction { private final ModelRegistry modelRegistry; - private final ServiceRegistry serviceRegistry; + private final InferenceServiceRegistry serviceRegistry; @Inject public TransportInferenceAction( - Settings settings, TransportService transportService, - ClusterService clusterService, - ThreadPool threadPool, ActionFilters actionFilters, ModelRegistry modelRegistry, - ServiceRegistry serviceRegistry + InferenceServiceRegistry serviceRegistry ) { super(InferenceAction.NAME, transportService, actionFilters, InferenceAction.Request::new); this.modelRegistry = modelRegistry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 0f3552372665..7ccaef2464f6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -18,16 +18,16 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xpack.inference.Model; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.registry.ServiceRegistry; -import org.elasticsearch.xpack.inference.services.InferenceService; import java.io.IOException; import java.util.Map; @@ -37,7 +37,7 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction< PutInferenceModelAction.Response> { private final ModelRegistry modelRegistry; - private final ServiceRegistry serviceRegistry; + private final InferenceServiceRegistry serviceRegistry; @Inject public TransportPutInferenceModelAction( @@ -47,7 +47,7 @@ public TransportPutInferenceModelAction( ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, ModelRegistry modelRegistry, - ServiceRegistry serviceRegistry + InferenceServiceRegistry serviceRegistry ) { super( PutInferenceModelAction.NAME, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 5ad9554959a2..4403ec53e7a1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -25,6 +25,7 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.reindex.DeleteByQueryAction; import org.elasticsearch.index.reindex.DeleteByQueryRequest; +import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; @@ -32,7 +33,6 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.inference.InferenceIndex; -import org.elasticsearch.xpack.inference.Model; import java.io.IOException; import java.util.Map; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ServiceRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ServiceRegistry.java deleted file mode 100644 index 8767630f5625..000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ServiceRegistry.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.registry; - -import org.elasticsearch.xpack.inference.services.InferenceService; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService; - -import java.util.Optional; - -public class ServiceRegistry { - - ElserMlNodeService elserService; - - public ServiceRegistry(ElserMlNodeService elserService) { - this.elserService = elserService; - } - - public Optional getService(String name) { - if (name.equals(ElserMlNodeService.NAME)) { - return Optional.of(elserService); - } - - return Optional.empty(); - } - -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/InferenceResult.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/InferenceResult.java deleted file mode 100644 index 8d8351dbe38d..000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/InferenceResult.java +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.io.stream.VersionedNamedWriteable; -import org.elasticsearch.xcontent.ToXContentFragment; - -public interface InferenceResult extends ToXContentFragment, VersionedNamedWriteable {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResult.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResult.java deleted file mode 100644 index 3f84c91b055a..000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResult.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -public class SparseEmbeddingResult implements InferenceResult { - - public static final String NAME = "sparse_embedding_result"; - - private final List weightedTokens; - - public SparseEmbeddingResult(List weightedTokens) { - this.weightedTokens = weightedTokens; - } - - public SparseEmbeddingResult(StreamInput in) throws IOException { - this.weightedTokens = in.readCollectionAsImmutableList(TextExpansionResults.WeightedToken::new); - } - - public List getWeightedTokens() { - return weightedTokens; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject("sparse_embedding"); - for (var weightedToken : weightedTokens) { - weightedToken.toXContent(builder, params); - } - builder.endObject(); - return builder; - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_8_500_074; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(weightedTokens); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - SparseEmbeddingResult that = (SparseEmbeddingResult) o; - return Objects.equals(weightedTokens, that.weightedTokens); - } - - @Override - public int hashCode() { - return Objects.hash(weightedTokens); - } - - @Override - public String toString() { - return Strings.toString(this); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/MapParsingUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/MapParsingUtils.java index c2b4986b84dc..31228b645cff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/MapParsingUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/MapParsingUtils.java @@ -54,6 +54,12 @@ public static Map removeFromMapOrThrowIfNull(Map return value; } + public static void throwIfNotEmptyMap(Map settingsMap, String serviceName) { + if (settingsMap.isEmpty() == false) { + throw MapParsingUtils.unknownSettingsError(settingsMap, serviceName); + } + } + public static ElasticsearchStatusException unknownSettingsError(Map config, String serviceName) { // TOOD map as JSON return new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeModel.java index 499a336c5d1a..6c7e36e5d81e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeModel.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.inference.services.elser; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.TaskType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; public class ElserMlNodeModel extends Model { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java index 602048f2e3e7..45acc467b047 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java @@ -9,27 +9,26 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.InferenceServicePlugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.TaskType; -import org.elasticsearch.xpack.inference.results.InferenceResult; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResult; -import org.elasticsearch.xpack.inference.services.InferenceService; -import org.elasticsearch.xpack.inference.services.MapParsingUtils; import java.util.List; import java.util.Map; import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED; import static org.elasticsearch.xpack.inference.services.MapParsingUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.MapParsingUtils.throwIfNotEmptyMap; public class ElserMlNodeService implements InferenceService { @@ -57,9 +56,9 @@ public static ElserMlNodeModel parseConfig( var taskSettings = taskSettingsFromMap(taskType, taskSettingsMap); if (throwOnUnknownFields) { - throwIfNotEmptyMap(settings); - throwIfNotEmptyMap(serviceSettingsMap); - throwIfNotEmptyMap(taskSettingsMap); + throwIfNotEmptyMap(settings, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); } return new ElserMlNodeModel(modelId, taskType, NAME, serviceSettings, taskSettings); @@ -67,8 +66,8 @@ public static ElserMlNodeModel parseConfig( private final OriginSettingClient client; - public ElserMlNodeService(Client client) { - this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); + public ElserMlNodeService(InferenceServicePlugin.InferenceServiceFactoryContext context) { + this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN); } @Override @@ -89,7 +88,7 @@ public void start(Model model, ActionListener listener) { } if (model.getTaskType() != TaskType.SPARSE_EMBEDDING) { - listener.onFailure(new IllegalStateException(unsupportedTaskTypeErrorMsg(model.getTaskType()))); + listener.onFailure(new IllegalStateException(TaskType.unsupportedTaskTypeErrorMsg(model.getTaskType(), NAME))); return; } @@ -109,11 +108,13 @@ public void start(Model model, ActionListener listener) { } @Override - public void infer(Model model, String input, Map requestTaskSettings, ActionListener listener) { + public void infer(Model model, String input, Map taskSettings, ActionListener listener) { // No task settings to override with requestTaskSettings if (model.getTaskType() != TaskType.SPARSE_EMBEDDING) { - listener.onFailure(new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(model.getTaskType()), RestStatus.BAD_REQUEST)); + listener.onFailure( + new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(model.getTaskType(), NAME), RestStatus.BAD_REQUEST) + ); return; } @@ -125,8 +126,7 @@ public void infer(Model model, String input, Map requestTaskSett ); client.execute(InferTrainedModelDeploymentAction.INSTANCE, request, ActionListener.wrap(inferenceResult -> { var textExpansionResult = (TextExpansionResults) inferenceResult.getResults().get(0); - var sparseEmbeddingResult = new SparseEmbeddingResult(textExpansionResult.getWeightedTokens()); - listener.onResponse(sparseEmbeddingResult); + listener.onResponse(textExpansionResult); }, listener::onFailure)); } @@ -136,7 +136,7 @@ private static ElserMlNodeServiceSettings serviceSettingsFromMap(Map config) { if (taskType != TaskType.SPARSE_EMBEDDING) { - throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType), RestStatus.BAD_REQUEST); + throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); } // no config options yet @@ -147,14 +147,4 @@ private static ElserMlNodeTaskSettings taskSettingsFromMap(TaskType taskType, Ma public String name() { return NAME; } - - private static void throwIfNotEmptyMap(Map settingsMap) { - if (settingsMap.isEmpty() == false) { - throw MapParsingUtils.unknownSettingsError(settingsMap, NAME); - } - } - - private static String unsupportedTaskTypeErrorMsg(TaskType taskType) { - return "The [" + NAME + "] service does not support task type [" + taskType + "]"; - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java index 1d6a5106a195..1314e6eab4f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java @@ -12,9 +12,9 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.ServiceSettings; import org.elasticsearch.xpack.inference.services.MapParsingUtils; import java.io.IOException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java index f4c75683783f..c1d84af5b5fb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java @@ -11,8 +11,8 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.inference.TaskSettings; import java.io.IOException; import java.util.Objects; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelTests.java index 57f4c0650b93..778f4703767a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelTests.java @@ -9,6 +9,10 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java index 22a4981e092d..0b30dc902103 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/GetInferenceModelRequestTests.java @@ -8,8 +8,8 @@ package org.elasticsearch.xpack.inference.action; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.TaskType; public class GetInferenceModelRequestTests extends AbstractWireSerializingTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java index f937ba03ae86..3e1bea005165 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java @@ -9,8 +9,8 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.TaskType; import java.io.IOException; import java.util.HashMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java index 13896607fe9a..795923e56c6b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java @@ -10,16 +10,22 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider; -import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultTests; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; public class InferenceActionResponseTests extends AbstractWireSerializingTestCase { @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - return new NamedWriteableRegistry(InferenceNamedWriteablesProvider.getNamedWriteables()); + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + entries.addAll(InferenceNamedWriteablesProvider.getNamedWriteables()); + return new NamedWriteableRegistry(entries); } @Override @@ -29,7 +35,7 @@ protected Writeable.Reader instanceReader() { @Override protected InferenceAction.Response createTestInstance() { - return new InferenceAction.Response(SparseEmbeddingResultTests.createRandomResult()); + return new InferenceAction.Response(TextExpansionResultsTests.createRandomResults()); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java index 770faf19585c..9aefea9a942d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.inference.action; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.TaskType; public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase { @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ServiceRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ServiceRegistryTests.java deleted file mode 100644 index 492fb29f910b..000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ServiceRegistryTests.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.registry; - -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeService; - -import static org.mockito.Mockito.mock; - -public class ServiceRegistryTests extends ESTestCase { - - public void testGetService() { - ServiceRegistry registry = new ServiceRegistry(mock(ElserMlNodeService.class)); - var service = registry.getService(ElserMlNodeService.NAME); - assertTrue(service.isPresent()); - } - - public void testGetUnknownService() { - ServiceRegistry registry = new ServiceRegistry(mock(ElserMlNodeService.class)); - var service = registry.getService("foo"); - assertFalse(service.isPresent()); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultTests.java deleted file mode 100644 index 360dc3e97d14..000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultTests.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.results; - -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; - -import java.util.ArrayList; -import java.util.List; - -public class SparseEmbeddingResultTests extends AbstractWireSerializingTestCase { - - public static SparseEmbeddingResult createRandomResult() { - int numTokens = randomIntBetween(1, 20); - List tokenList = new ArrayList<>(); - for (int i = 0; i < numTokens; i++) { - tokenList.add(new TextExpansionResults.WeightedToken(Integer.toString(i), (float) randomDoubleBetween(0.0, 5.0, false))); - } - return new SparseEmbeddingResult(tokenList); - } - - @Override - protected Writeable.Reader instanceReader() { - return SparseEmbeddingResult::new; - } - - @Override - protected SparseEmbeddingResult createTestInstance() { - return createRandomResult(); - } - - @Override - protected SparseEmbeddingResult mutateInstance(SparseEmbeddingResult instance) { - if (instance.getWeightedTokens().size() > 0) { - var tokens = instance.getWeightedTokens(); - return new SparseEmbeddingResult(tokens.subList(0, tokens.size() - 1)); - } else { - return new SparseEmbeddingResult(List.of(new TextExpansionResults.WeightedToken("a", 1.0f))); - } - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java index bdbb4c545900..0449c1b4a7d5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java @@ -9,9 +9,10 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.InferenceServicePlugin; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.Model; -import org.elasticsearch.xpack.inference.TaskType; import java.util.HashMap; import java.util.Map; @@ -35,7 +36,7 @@ public static Model randomModelConfig(String modelId, TaskType taskType) { } public void testParseConfigStrict() { - var service = new ElserMlNodeService(mock(Client.class)); + var service = createService(mock(Client.class)); var settings = new HashMap(); settings.put( @@ -59,7 +60,7 @@ public void testParseConfigStrict() { } public void testParseConfigStrictWithNoTaskSettings() { - var service = new ElserMlNodeService(mock(Client.class)); + var service = createService(mock(Client.class)); var settings = new HashMap(); settings.put( @@ -153,4 +154,9 @@ public void testParseConfigStrictWithUnknownSettings() { } } } + + private ElserMlNodeService createService(Client client) { + var context = new InferenceServicePlugin.InferenceServiceFactoryContext(client); + return new ElserMlNodeService(context); + } }