From 4f7dc901c80aaa9290859853bbfb04d6c6dc21cb Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Mon, 13 May 2024 13:35:27 -0700 Subject: [PATCH] add IT for remote model automatic deploy with TTL (#2431) * add IT for remote model automatic deploy with TTL Signed-off-by: Xun Zhang * remove duplicate and unuseful remote inference ITs Signed-off-by: Xun Zhang --------- Signed-off-by: Xun Zhang --- .../ml/rest/MLCommonsRestTestCase.java | 15 +++ .../ml/rest/RestMLRemoteInferenceIT.java | 99 +++++++++++-------- 2 files changed, 71 insertions(+), 43 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 65767339d2..5fde0e8be8 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -27,6 +27,7 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -861,6 +862,9 @@ public static Map parseResponseToMap(Response response) throws IOException { public Map getModelProfile(String modelId, Consumer verifyFunction) throws IOException { Response response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/profile/models/" + modelId, null, (String) null, null); Map profile = parseResponseToMap(response); + if (profile == null || profile.get("nodes") == null) { + return new HashMap(); + } Map nodeProfiles = (Map) profile.get("nodes"); for (Map.Entry entry : nodeProfiles.entrySet()) { Map modelProfiles = (Map) entry.getValue(); @@ -918,6 +922,17 @@ public Consumer> verifyTextEmbeddingModelDeployed() { }; } + public Consumer> verifyRemoteModelDeployed() { + return (modelProfile) -> { + if (modelProfile.containsKey("model_state")) { + assertEquals(MLModelState.DEPLOYED.name(), modelProfile.get("model_state")); + assertTrue(((String) modelProfile.get("predictor")).startsWith("org.opensearch.ml.engine.algorithms.remote.RemoteModel@")); + } + List workNodes = (List) modelProfile.get("worker_nodes"); + assertTrue(workNodes.size() > 0); + }; + } + public Map undeployModel(String modelId) throws IOException { Response response = TestHelper .makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, (String) null, null); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 8a6bd85727..d301e2b381 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -8,12 +8,14 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Before; +import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.client.Response; @@ -170,19 +172,6 @@ public void testSearchMLTasks_afterCreation() throws IOException { assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); } - public void testRegisterRemoteModel() throws IOException, InterruptedException { - Response response = createConnector(completionModelConnectorEntity); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); - assertNotNull(responseMap.get("model_id")); - } - public void testDeployRemoteModel() throws IOException, InterruptedException { Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); @@ -201,25 +190,18 @@ public void testDeployRemoteModel() throws IOException, InterruptedException { waitForTask(taskId, MLTaskState.COMPLETED); } - public void testPredictRemoteModel() throws IOException, InterruptedException { + public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, InterruptedException { // Skip test if key is null if (OPENAI_KEY == null) { + System.out.println("OPENAI_KEY is null"); return; } Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); + response = registerRemoteModelWithTTL("openAI-GPT-3.5 completions", connectorId, 1); responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); - response = deployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); @@ -235,6 +217,10 @@ public void testPredictRemoteModel() throws IOException, InterruptedException { } responseMap = (Map) responseList.get(0); assertFalse(((String) responseMap.get("text")).isEmpty()); + + getModelProfile(modelId, verifyRemoteModelDeployed()); + TimeUnit.SECONDS.sleep(71); + assertTrue(getModelProfile(modelId, verifyRemoteModelDeployed()).isEmpty()); } public void testPredictRemoteModelWithInterface(String testCase, Consumer verifyResponse, Consumer verifyException) @@ -301,26 +287,6 @@ public void testPredictRemoteModelWithWrongOutputInterface() throws IOException, }); } - public void testUndeployRemoteModel() throws IOException, InterruptedException { - Response response = createConnector(completionModelConnectorEntity); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); - String modelId = (String) responseMap.get("model_id"); - response = deployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = undeployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - assertTrue(responseMap.toString().contains("undeployed")); - } - public void testOpenAIChatCompletionModel() throws IOException, InterruptedException { // Skip test if key is null if (OPENAI_KEY == null) { @@ -384,8 +350,13 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep responseMap = parseResponseToMap(response); // TODO handle throttling error assertNotNull(responseMap); + + response = undeployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + assertTrue(responseMap.toString().contains("undeployed")); } + @Ignore public void testOpenAIEditsModel() throws IOException, InterruptedException { // Skip test if key is null if (OPENAI_KEY == null) { @@ -457,6 +428,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { assertFalse(((String) responseMap.get("content")).isEmpty()); } + @Ignore public void testOpenAIModerationsModel() throws IOException, InterruptedException { // Skip test if key is null if (OPENAI_KEY == null) { @@ -687,6 +659,7 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti assertFalse(((String) responseMap.get("text")).isEmpty()); } + @Ignore public void testCohereClassifyModel() throws IOException, InterruptedException { // Skip test if key is null if (COHERE_KEY == null) { @@ -841,6 +814,46 @@ public static Response registerRemoteModel(String name, String connectorId) thro .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); } + public static Response registerRemoteModelWithTTL(String name, String connectorId, int ttl) throws IOException { + String registerModelGroupEntity = "{\n" + + " \"name\": \"remote_model_group\",\n" + + " \"description\": \"This is an example description\"\n" + + "}"; + Response response = TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/model_groups/_register", + null, + TestHelper.toHttpEntity(registerModelGroupEntity), + null + ); + Map responseMap = parseResponseToMap(response); + assertEquals((String) responseMap.get("status"), "CREATED"); + String modelGroupId = (String) responseMap.get("model_group_id"); + + String registerModelEntity = "{\n" + + " \"name\": \"" + + name + + "\",\n" + + " \"function_name\": \"remote\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + + " \"version\": \"1.0.0\",\n" + + " \"description\": \"test model\",\n" + + " \"connector_id\": \"" + + connectorId + + "\",\n" + + " \"deploy_setting\": " + + " { \"model_ttl_minutes\": " + + ttl + + "}\n" + + "}"; + return TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); + } + public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException { String registerModelGroupEntity = "{\n" + " \"name\": \"remote_model_group\",\n"