From d20b45bb8bd0124ad26b891a9252d230bbd8f079 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com> Date: Thu, 12 Oct 2023 09:53:12 -0400 Subject: [PATCH] [ML] Fixing put inference service request hanging (#100725) (#100763) * Handling in cluster non cloud listener on response * Cleaning up code * spelling (cherry picked from commit 0848c2b85bb9ffeecfa9dff4f6d9a2e9b13b0b0d) Co-authored-by: Elastic Machine --- .../integration/MockInferenceServiceIT.java | 26 ++++++--- .../TestInferenceServicePlugin.java | 54 +++++++++++++------ .../TransportPutInferenceModelAction.java | 17 ++++-- 3 files changed, 70 insertions(+), 27 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java index e211cb0647774..0da0340084cba 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java @@ -75,7 +75,17 @@ protected Function getClientWrapper() { public void testMockService() { String modelId = "test-mock"; - ModelConfigurations putModel = putMockService(modelId, TaskType.SPARSE_EMBEDDING); + ModelConfigurations putModel = putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING); + ModelConfigurations readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING); + assertModelsAreEqual(putModel, readModel); + + // The response is randomly generated, the input can be anything + inferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, randomAlphaOfLength(10)); + } + + public void testMockInClusterService() { + String modelId = "test-mock-in-cluster"; + ModelConfigurations putModel = putMockService(modelId, "test_service_in_cluster_service", TaskType.SPARSE_EMBEDDING); ModelConfigurations readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING); assertModelsAreEqual(putModel, readModel); @@ -85,7 +95,7 @@ public void testMockService() { public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException { String modelId = "test-mock"; - putMockService(modelId, TaskType.SPARSE_EMBEDDING); + putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING); ModelConfigurations readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING); assertThat(readModel.getServiceSettings(), instanceOf(TestInferenceServicePlugin.TestServiceSettings.class)); @@ -103,7 +113,7 @@ public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOExcepti public void testGetUnparsedModelMap_ForTestServiceModel_ReturnsSecretsPopulated() { String modelId = "test-unparsed"; - putMockService(modelId, TaskType.SPARSE_EMBEDDING); + putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING); var listener = new PlainActionFuture(); modelRegistry.getUnparsedModelMap(modelId, listener); @@ -114,10 +124,10 @@ public void testGetUnparsedModelMap_ForTestServiceModel_ReturnsSecretsPopulated( assertThat(secrets.apiKey(), is("abc64")); } - private ModelConfigurations putMockService(String modelId, TaskType taskType) { - String body = """ + private ModelConfigurations putMockService(String modelId, String serviceName, TaskType taskType) { + String body = Strings.format(""" { - "service": "test_service", + "service": "%s", "service_settings": { "model": "my_model", "api_key": "abc64" @@ -126,7 +136,7 @@ private ModelConfigurations putMockService(String modelId, TaskType taskType) { "temperature": 3 } } - """; + """, serviceName); var request = new PutInferenceModelAction.Request( taskType.toString(), modelId, @@ -135,7 +145,7 @@ private ModelConfigurations putMockService(String modelId, TaskType taskType) { ); var response = client().execute(PutInferenceModelAction.INSTANCE, request).actionGet(); - assertEquals("test_service", response.getModel().getService()); + assertEquals(serviceName, response.getModel().getService()); assertThat(response.getModel().getServiceSettings(), instanceOf(TestInferenceServicePlugin.TestServiceSettings.class)); var serviceSettings = (TestInferenceServicePlugin.TestServiceSettings) response.getModel().getServiceSettings(); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java index b72fa99efaf72..96625e6bec031 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java @@ -42,7 +42,7 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi @Override public List getInferenceServiceFactories() { - return List.of(TestInferenceService::new); + return List.of(TestInferenceService::new, TestInferenceServiceClusterService::new); } @Override @@ -54,10 +54,39 @@ public List getInferenceServiceNamedWriteables() { ); } - public static class TestInferenceService implements InferenceService { - + public static class TestInferenceService extends TestInferenceServiceBase { private static final String NAME = "test_service"; + public TestInferenceService(InferenceServiceFactoryContext context) { + super(context); + } + + @Override + public String name() { + return NAME; + } + } + + public static class TestInferenceServiceClusterService extends TestInferenceServiceBase { + private static final String NAME = "test_service_in_cluster_service"; + + public TestInferenceServiceClusterService(InferenceServiceFactoryContext context) { + super(context); + } + + @Override + public boolean isInClusterService() { + return true; + } + + @Override + public String name() { + return NAME; + } + } + + public abstract static class TestInferenceServiceBase implements InferenceService { + private static Map getTaskSettingsMap(Map settings) { Map taskSettingsMap; // task settings are optional @@ -70,13 +99,8 @@ private static Map getTaskSettingsMap(Map settin return taskSettingsMap; } - public TestInferenceService(InferenceServicePlugin.InferenceServiceFactoryContext context) { - - } + public TestInferenceServiceBase(InferenceServicePlugin.InferenceServiceFactoryContext context) { - @Override - public String name() { - return NAME; } @Override @@ -93,11 +117,11 @@ public TestServiceModel parseRequestConfig( var taskSettingsMap = getTaskSettingsMap(config); var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); - throwIfNotEmptyMap(config, NAME); - throwIfNotEmptyMap(serviceSettingsMap, NAME); - throwIfNotEmptyMap(taskSettingsMap, NAME); + throwIfNotEmptyMap(config, name()); + throwIfNotEmptyMap(serviceSettingsMap, name()); + throwIfNotEmptyMap(taskSettingsMap, name()); - return new TestServiceModel(modelId, taskType, NAME, serviceSettings, taskSettings, secretSettings); + return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings); } @Override @@ -116,7 +140,7 @@ public TestServiceModel parsePersistedConfig( var taskSettingsMap = getTaskSettingsMap(config); var taskSettings = TestTaskSettings.fromMap(taskSettingsMap); - return new TestServiceModel(modelId, taskType, NAME, serviceSettings, taskSettings, secretSettings); + return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings); } @Override @@ -125,7 +149,7 @@ public void infer(Model model, String input, Map taskSettings, A case SPARSE_EMBEDDING -> listener.onResponse(TextExpansionResultsTests.createRandomResults(1, 10)); default -> listener.onFailure( new ElasticsearchStatusException( - TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), NAME), + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), RestStatus.BAD_REQUEST ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 58f781d99b26a..046eff3e6b830 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -102,8 +102,17 @@ protected void masterOperation( // information when creating the model MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(ActionListener.wrap(architectures -> { if (architectures.isEmpty() && clusterIsInElasticCloud(clusterService.getClusterSettings())) { - // In Elastic cloud ml nodes run on Linux x86 - architectures = Set.of("linux-x86_64"); + parseAndStoreModel( + service.get(), + request.getModelId(), + request.getTaskType(), + requestAsMap, + // In Elastic cloud ml nodes run on Linux x86 + Set.of("linux-x86_64"), + listener + ); + } else { + // The architecture field could be an empty set, the individual services will need to handle that parseAndStoreModel(service.get(), request.getModelId(), request.getTaskType(), requestAsMap, architectures, listener); } }, listener::onFailure), client, threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME)); @@ -118,10 +127,10 @@ private void parseAndStoreModel( String modelId, TaskType taskType, Map config, - Set platfromArchitectures, + Set platformArchitectures, ActionListener listener ) { - var model = service.parseRequestConfig(modelId, taskType, config, platfromArchitectures); + var model = service.parseRequestConfig(modelId, taskType, config, platformArchitectures); // model is valid good to persist then start this.modelRegistry.storeModel(model, ActionListener.wrap(r -> { startModel(service, model, listener); }, listener::onFailure)); }