diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 65767339d2..5fde0e8be8 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -27,6 +27,7 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -861,6 +862,9 @@ public static Map parseResponseToMap(Response response) throws IOException { public Map getModelProfile(String modelId, Consumer verifyFunction) throws IOException { Response response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/profile/models/" + modelId, null, (String) null, null); Map profile = parseResponseToMap(response); + if (profile == null || profile.get("nodes") == null) { + return new HashMap(); + } Map nodeProfiles = (Map) profile.get("nodes"); for (Map.Entry entry : nodeProfiles.entrySet()) { Map modelProfiles = (Map) entry.getValue(); @@ -918,6 +922,17 @@ public Consumer> verifyTextEmbeddingModelDeployed() { }; } + public Consumer> verifyRemoteModelDeployed() { + return (modelProfile) -> { + if (modelProfile.containsKey("model_state")) { + assertEquals(MLModelState.DEPLOYED.name(), modelProfile.get("model_state")); + assertTrue(((String) modelProfile.get("predictor")).startsWith("org.opensearch.ml.engine.algorithms.remote.RemoteModel@")); + } + List workNodes = (List) modelProfile.get("worker_nodes"); + assertTrue(workNodes.size() > 0); + }; + } + public Map undeployModel(String modelId) throws IOException { Response response = TestHelper .makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, (String) null, null); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 8a6bd85727..fdfd8b0f6a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -8,6 +8,7 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import org.apache.commons.lang3.exception.ExceptionUtils; @@ -237,6 +238,39 @@ public void testPredictRemoteModel() throws IOException, InterruptedException { assertFalse(((String) responseMap.get("text")).isEmpty()); } + public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, InterruptedException { + // Skip test if key is null + if (OPENAI_KEY == null) { + System.out.println("OPENAI_KEY is null"); + return; + } + Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithTTL("openAI-GPT-3.5 completions", connectorId, 1); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}"; + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + responseList = (List) responseMap.get("choices"); + if (responseList == null) { + assertTrue(checkThrottlingOpenAI(responseMap)); + return; + } + responseMap = (Map) responseList.get(0); + assertFalse(((String) responseMap.get("text")).isEmpty()); + + getModelProfile(modelId, verifyRemoteModelDeployed()); + TimeUnit.SECONDS.sleep(71); + assertTrue(getModelProfile(modelId, verifyRemoteModelDeployed()).isEmpty()); + } + public void testPredictRemoteModelWithInterface(String testCase, Consumer verifyResponse, Consumer verifyException) throws IOException, InterruptedException { @@ -841,6 +875,46 @@ public static Response registerRemoteModel(String name, String connectorId) thro .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); } + public static Response registerRemoteModelWithTTL(String name, String connectorId, int ttl) throws IOException { + String registerModelGroupEntity = "{\n" + + " \"name\": \"remote_model_group\",\n" + + " \"description\": \"This is an example description\"\n" + + "}"; + Response response = TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/model_groups/_register", + null, + TestHelper.toHttpEntity(registerModelGroupEntity), + null + ); + Map responseMap = parseResponseToMap(response); + assertEquals((String) responseMap.get("status"), "CREATED"); + String modelGroupId = (String) responseMap.get("model_group_id"); + + String registerModelEntity = "{\n" + + " \"name\": \"" + + name + + "\",\n" + + " \"function_name\": \"remote\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + + " \"version\": \"1.0.0\",\n" + + " \"description\": \"test model\",\n" + + " \"connector_id\": \"" + + connectorId + + "\",\n" + + " \"deploy_setting\": " + + " { \"model_ttl_minutes\": " + + ttl + + "}\n" + + "}"; + return TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); + } + public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException { String registerModelGroupEntity = "{\n" + " \"name\": \"remote_model_group\",\n"