Skip to content

Commit

Permalink
add IT for remote model automatic deploy with TTL
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed May 9, 2024
1 parent 89f23d2 commit f0b0bec
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -861,6 +862,9 @@ public static Map parseResponseToMap(Response response) throws IOException {
public Map getModelProfile(String modelId, Consumer verifyFunction) throws IOException {
Response response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/profile/models/" + modelId, null, (String) null, null);
Map profile = parseResponseToMap(response);
if (profile == null || profile.get("nodes") == null) {
return new HashMap();
}
Map<String, Object> nodeProfiles = (Map) profile.get("nodes");
for (Map.Entry<String, Object> entry : nodeProfiles.entrySet()) {
Map<String, Object> modelProfiles = (Map) entry.getValue();
Expand Down Expand Up @@ -918,6 +922,17 @@ public Consumer<Map<String, Object>> verifyTextEmbeddingModelDeployed() {
};
}

/**
 * Builds a verifier for one model-profile map of a remote model.
 *
 * <p>When the profile carries a {@code model_state} entry, the verifier asserts that the
 * state equals {@code MLModelState.DEPLOYED} and that the {@code predictor} string starts
 * with the remote-model class prefix. In every case it asserts that the
 * {@code worker_nodes} list has at least one element.
 *
 * @return a consumer that performs the assertions above on the profile map it receives
 */
public Consumer<Map<String, Object>> verifyRemoteModelDeployed() {
    return profile -> {
        // State and predictor are only checked when the profile reports a model state.
        if (profile.containsKey("model_state")) {
            Object reportedState = profile.get("model_state");
            assertEquals(MLModelState.DEPLOYED.name(), reportedState);

            String predictor = (String) profile.get("predictor");
            assertTrue(predictor.startsWith("org.opensearch.ml.engine.algorithms.remote.RemoteModel@"));
        }

        // The worker-node list is checked unconditionally.
        List<String> workerNodes = (List) profile.get("worker_nodes");
        assertTrue(!workerNodes.isEmpty());
    };
}

public Map undeployModel(String modelId) throws IOException {
Response response = TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, (String) null, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import org.apache.commons.lang3.exception.ExceptionUtils;
Expand Down Expand Up @@ -237,6 +238,39 @@ public void testPredictRemoteModel() throws IOException, InterruptedException {
assertFalse(((String) responseMap.get("text")).isEmpty());
}

/**
 * Integration test for predicting against a remote model registered with a TTL.
 *
 * <p>Flow: create a connector, register a remote model with a TTL value of 1 (sent as
 * {@code model_ttl_minutes}), issue a predict request without any explicit deploy call,
 * check the completion text, verify the model profile reports the model as deployed, wait,
 * and finally assert that the model profile lookup returns an empty map.
 *
 * <p>The test returns early without asserting anything when {@code OPENAI_KEY} is null, and
 * falls back to the throttling check when the response contains no {@code choices}.
 *
 * @throws IOException if one of the REST calls fails
 * @throws InterruptedException if the wait between the two profile checks is interrupted
 */
public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, InterruptedException {
    // Skip test if key is null
    if (OPENAI_KEY == null) {
        System.out.println("OPENAI_KEY is null");
        return;
    }

    // Connector and model registration.
    Map connectorResult = parseResponseToMap(createConnector(completionModelConnectorEntity));
    String connectorId = (String) connectorResult.get("connector_id");
    Map registerResult = parseResponseToMap(registerRemoteModelWithTTL("openAI-GPT-3.5 completions", connectorId, 1));
    String modelId = (String) registerResult.get("model_id");

    // Predict directly; no deploy request is sent beforehand.
    String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
    Map predictResult = parseResponseToMap(predictRemoteModel(modelId, predictInput));

    // Walk inference_results[0].output[0].dataAsMap.choices.
    List inferenceResults = (List) predictResult.get("inference_results");
    Map firstInference = (Map) inferenceResults.get(0);
    List outputs = (List) firstInference.get("output");
    Map firstOutput = (Map) outputs.get(0);
    Map dataAsMap = (Map) firstOutput.get("dataAsMap");
    List choices = (List) dataAsMap.get("choices");
    if (choices == null) {
        // No choices returned: accept the run only if the throttling check passes.
        assertTrue(checkThrottlingOpenAI(dataAsMap));
        return;
    }
    Map firstChoice = (Map) choices.get(0);
    String completionText = (String) firstChoice.get("text");
    assertFalse(completionText.isEmpty());

    // The model should show up as deployed right after the prediction.
    getModelProfile(modelId, verifyRemoteModelDeployed());
    // NOTE(review): 71 seconds presumably covers the 1-minute TTL plus a margin - confirm.
    TimeUnit.SECONDS.sleep(71);
    Map profileAfterWait = getModelProfile(modelId, verifyRemoteModelDeployed());
    assertTrue(profileAfterWait.isEmpty());
}

public void testPredictRemoteModelWithInterface(String testCase, Consumer<Map> verifyResponse, Consumer<Exception> verifyException)
throws IOException,
InterruptedException {
Expand Down Expand Up @@ -841,6 +875,46 @@ public static Response registerRemoteModel(String name, String connectorId) thro
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
}

/**
 * Registers a model group and then a remote model whose deploy setting carries a TTL.
 *
 * <p>A model group named {@code remote_model_group} is registered first and the call is
 * asserted to report status {@code CREATED}. The returned group id is then used in the
 * model registration request, together with a {@code deploy_setting.model_ttl_minutes}
 * value taken from {@code ttl}.
 *
 * @param name        model name written into the registration body (inserted verbatim, not escaped)
 * @param connectorId id of the connector the remote model uses
 * @param ttl         value written to {@code model_ttl_minutes}
 * @return the raw response of the model registration request
 * @throws IOException if either REST request fails
 */
public static Response registerRemoteModelWithTTL(String name, String connectorId, int ttl) throws IOException {
    String registerModelGroupEntity = "{\n"
        + " \"name\": \"remote_model_group\",\n"
        + " \"description\": \"This is an example description\"\n"
        + "}";
    Response response = TestHelper
        .makeRequest(
            client(),
            "POST",
            "/_plugins/_ml/model_groups/_register",
            null,
            TestHelper.toHttpEntity(registerModelGroupEntity),
            null
        );
    Map responseMap = parseResponseToMap(response);
    // Expected value first, actual second, so a failure message reports them correctly.
    assertEquals("CREATED", responseMap.get("status"));
    String modelGroupId = (String) responseMap.get("model_group_id");

    String registerModelEntity = "{\n"
        + " \"name\": \""
        + name
        + "\",\n"
        + " \"function_name\": \"remote\",\n"
        + " \"model_group_id\": \""
        + modelGroupId
        + "\",\n"
        + " \"version\": \"1.0.0\",\n"
        + " \"description\": \"test model\",\n"
        + " \"connector_id\": \""
        + connectorId
        + "\",\n"
        + " \"deploy_setting\": "
        + " { \"model_ttl_minutes\": "
        + ttl
        + "}\n"
        + "}";
    return TestHelper
        .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
}

public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException {
String registerModelGroupEntity = "{\n"
+ " \"name\": \"remote_model_group\",\n"
Expand Down

0 comments on commit f0b0bec

Please sign in to comment.