Skip to content

Commit

Permalink
add IT for remote model automatic deploy with TTL (opensearch-project…
Browse files Browse the repository at this point in the history
…#2431)

* add IT for remote model automatic deploy with TTL

Signed-off-by: Xun Zhang <[email protected]>

* remove duplicate and unuseful remote inference ITs

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt authored May 13, 2024
1 parent aa09014 commit 4f7dc90
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -861,6 +862,9 @@ public static Map parseResponseToMap(Response response) throws IOException {
public Map getModelProfile(String modelId, Consumer verifyFunction) throws IOException {
Response response = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/profile/models/" + modelId, null, (String) null, null);
Map profile = parseResponseToMap(response);
if (profile == null || profile.get("nodes") == null) {
return new HashMap();
}
Map<String, Object> nodeProfiles = (Map) profile.get("nodes");
for (Map.Entry<String, Object> entry : nodeProfiles.entrySet()) {
Map<String, Object> modelProfiles = (Map) entry.getValue();
Expand Down Expand Up @@ -918,6 +922,17 @@ public Consumer<Map<String, Object>> verifyTextEmbeddingModelDeployed() {
};
}

public Consumer<Map<String, Object>> verifyRemoteModelDeployed() {
return (modelProfile) -> {
if (modelProfile.containsKey("model_state")) {
assertEquals(MLModelState.DEPLOYED.name(), modelProfile.get("model_state"));
assertTrue(((String) modelProfile.get("predictor")).startsWith("org.opensearch.ml.engine.algorithms.remote.RemoteModel@"));
}
List<String> workNodes = (List) modelProfile.get("worker_nodes");
assertTrue(workNodes.size() > 0);
};
}

public Map undeployModel(String modelId) throws IOException {
Response response = TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, (String) null, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.client.Response;
Expand Down Expand Up @@ -170,19 +172,6 @@ public void testSearchMLTasks_afterCreation() throws IOException {
assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value"));
}

public void testRegisterRemoteModel() throws IOException, InterruptedException {
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
assertNotNull(responseMap.get("model_id"));
}

public void testDeployRemoteModel() throws IOException, InterruptedException {
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
Expand All @@ -201,25 +190,18 @@ public void testDeployRemoteModel() throws IOException, InterruptedException {
waitForTask(taskId, MLTaskState.COMPLETED);
}

public void testPredictRemoteModel() throws IOException, InterruptedException {
public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
System.out.println("OPENAI_KEY is null");
return;
}
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
response = registerRemoteModelWithTTL("openAI-GPT-3.5 completions", connectorId, 1);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
Expand All @@ -235,6 +217,10 @@ public void testPredictRemoteModel() throws IOException, InterruptedException {
}
responseMap = (Map) responseList.get(0);
assertFalse(((String) responseMap.get("text")).isEmpty());

getModelProfile(modelId, verifyRemoteModelDeployed());
TimeUnit.SECONDS.sleep(71);
assertTrue(getModelProfile(modelId, verifyRemoteModelDeployed()).isEmpty());
}

public void testPredictRemoteModelWithInterface(String testCase, Consumer<Map> verifyResponse, Consumer<Exception> verifyException)
Expand Down Expand Up @@ -301,26 +287,6 @@ public void testPredictRemoteModelWithWrongOutputInterface() throws IOException,
});
}

public void testUndeployRemoteModel() throws IOException, InterruptedException {
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = undeployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
assertTrue(responseMap.toString().contains("undeployed"));
}

public void testOpenAIChatCompletionModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down Expand Up @@ -384,8 +350,13 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep
responseMap = parseResponseToMap(response);
// TODO handle throttling error
assertNotNull(responseMap);

response = undeployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
assertTrue(responseMap.toString().contains("undeployed"));
}

@Ignore
public void testOpenAIEditsModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down Expand Up @@ -457,6 +428,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException {
assertFalse(((String) responseMap.get("content")).isEmpty());
}

@Ignore
public void testOpenAIModerationsModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down Expand Up @@ -687,6 +659,7 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti
assertFalse(((String) responseMap.get("text")).isEmpty());
}

@Ignore
public void testCohereClassifyModel() throws IOException, InterruptedException {
// Skip test if key is null
if (COHERE_KEY == null) {
Expand Down Expand Up @@ -841,6 +814,46 @@ public static Response registerRemoteModel(String name, String connectorId) thro
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
}

public static Response registerRemoteModelWithTTL(String name, String connectorId, int ttl) throws IOException {
String registerModelGroupEntity = "{\n"
+ " \"name\": \"remote_model_group\",\n"
+ " \"description\": \"This is an example description\"\n"
+ "}";
Response response = TestHelper
.makeRequest(
client(),
"POST",
"/_plugins/_ml/model_groups/_register",
null,
TestHelper.toHttpEntity(registerModelGroupEntity),
null
);
Map responseMap = parseResponseToMap(response);
assertEquals((String) responseMap.get("status"), "CREATED");
String modelGroupId = (String) responseMap.get("model_group_id");

String registerModelEntity = "{\n"
+ " \"name\": \""
+ name
+ "\",\n"
+ " \"function_name\": \"remote\",\n"
+ " \"model_group_id\": \""
+ modelGroupId
+ "\",\n"
+ " \"version\": \"1.0.0\",\n"
+ " \"description\": \"test model\",\n"
+ " \"connector_id\": \""
+ connectorId
+ "\",\n"
+ " \"deploy_setting\": "
+ " { \"model_ttl_minutes\": "
+ ttl
+ "}\n"
+ "}";
return TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
}

public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException {
String registerModelGroupEntity = "{\n"
+ " \"name\": \"remote_model_group\",\n"
Expand Down

0 comments on commit 4f7dc90

Please sign in to comment.