diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 24b305e382160..fa8e4672d5cf2 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -138,10 +138,10 @@ void chunkedInfer( /** * Stop the model deployment. * The default action does nothing except acknowledge the request (true). - * @param unparsedModel The unparsed model configuration + * @param model The model configuration * @param listener The listener */ - default void stop(UnparsedModel unparsedModel, ActionListener listener) { + default void stop(Model model, ActionListener listener) { listener.onResponse(true); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index c1dbd8cfec9d5..4e80d01f85543 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -115,7 +115,9 @@ private void doExecuteForked( var service = serviceRegistry.getService(unparsedModel.service()); if (service.isPresent()) { - service.get().stop(unparsedModel, listener); + var model = service.get() + .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + service.get().stop(model, listener); } else { listener.onFailure( new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 5f97f3bad3dc8..06401bd766e40 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -22,7 +22,6 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; @@ -119,9 +118,7 @@ public void start(Model model, ActionListener finalListener) { } @Override - public void stop(UnparsedModel unparsedModel, ActionListener listener) { - - var model = parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + public void stop(Model model, ActionListener listener) { if (model instanceof ElasticsearchInternalModel esModel) { var serviceSettings = esModel.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index ee8d4b0fbbc6b..0d3dfeeb76195 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -500,7 +500,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M @Override public void checkModelConfig(Model model, ActionListener listener) { - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); + ModelValidatorBuilder.buildModelValidator(model.getTaskType(), true).validate(this, model, listener); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java new file mode 100644 index 0000000000000..8fefb7e8f3acc --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.validation; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.Model; + +public class ElasticsearchInternalServiceModelValidator implements ModelValidator { + + ModelValidator modelValidator; + + public ElasticsearchInternalServiceModelValidator(ModelValidator modelValidator) { + this.modelValidator = modelValidator; + } + + @Override + public void validate(InferenceService service, Model model, ActionListener listener) { + modelValidator.validate(service, model, listener.delegateResponse((l, exception) -> { + // TODO: Cleanup the below code + service.stop(model, ActionListener.wrap((v) -> listener.onFailure(exception), (e) -> listener.onFailure(exception))); + })); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java index 0464e790ba79a..75be41d0fe278 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java @@ -11,6 +11,15 @@ import org.elasticsearch.inference.TaskType; public class ModelValidatorBuilder { + public static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) { + var modelValidator = buildModelValidator(taskType); + if (isElasticsearchInternalService) { + return new ElasticsearchInternalServiceModelValidator(modelValidator); + } else { + return modelValidator; + } + } + public static ModelValidator buildModelValidator(TaskType taskType) { if (taskType == null) { throw new IllegalArgumentException("Task type can't be null"); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 510a093e9c162..c32c5223e5aee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -1463,7 +1463,7 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { ); var request = (InferModelAction.Request) invocationOnMock.getArguments()[1]; - assertThat(request.getId(), is("custom-model")); + assertThat(request.getId(), is(randomInferenceEntityId)); return Void.TYPE; }).when(client).execute(eq(InferModelAction.INSTANCE), any(), any()); when(client.threadPool()).thenReturn(threadPool);