Skip to content

Commit

Permalink
remove duplicate and unuseful remote inference ITs
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed May 10, 2024
1 parent 44ab851 commit eaa5c3c
Showing 1 changed file with 8 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.client.Response;
Expand Down Expand Up @@ -171,19 +172,6 @@ public void testSearchMLTasks_afterCreation() throws IOException {
assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value"));
}

public void testRegisterRemoteModel() throws IOException, InterruptedException {
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
assertNotNull(responseMap.get("model_id"));
}

public void testDeployRemoteModel() throws IOException, InterruptedException {
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
Expand All @@ -202,42 +190,6 @@ public void testDeployRemoteModel() throws IOException, InterruptedException {
waitForTask(taskId, MLTaskState.COMPLETED);
}

public void testPredictRemoteModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
return;
}
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
List responseList = (List) responseMap.get("inference_results");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("output");
responseMap = (Map) responseList.get(0);
responseMap = (Map) responseMap.get("dataAsMap");
responseList = (List) responseMap.get("choices");
if (responseList == null) {
assertTrue(checkThrottlingOpenAI(responseMap));
return;
}
responseMap = (Map) responseList.get(0);
assertFalse(((String) responseMap.get("text")).isEmpty());
}

public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down Expand Up @@ -335,26 +287,6 @@ public void testPredictRemoteModelWithWrongOutputInterface() throws IOException,
});
}

public void testUndeployRemoteModel() throws IOException, InterruptedException {
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = undeployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
assertTrue(responseMap.toString().contains("undeployed"));
}

public void testOpenAIChatCompletionModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down Expand Up @@ -418,8 +350,12 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep
responseMap = parseResponseToMap(response);
// TODO handle throttling error
assertNotNull(responseMap);
}

response = undeployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
assertTrue(responseMap.toString().contains("undeployed"));
}
@Ignore
public void testOpenAIEditsModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down Expand Up @@ -490,7 +426,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException {

assertFalse(((String) responseMap.get("content")).isEmpty());
}

@Ignore
public void testOpenAIModerationsModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down Expand Up @@ -720,7 +656,7 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti
responseMap = (Map) responseList.get(0);
assertFalse(((String) responseMap.get("text")).isEmpty());
}

@Ignore
public void testCohereClassifyModel() throws IOException, InterruptedException {
// Skip test if key is null
if (COHERE_KEY == null) {
Expand Down

0 comments on commit eaa5c3c

Please sign in to comment.