diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index fdfd8b0f6a..eeefea2b49 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -15,6 +15,7 @@ import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Before; +import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.client.Response; @@ -171,19 +172,6 @@ public void testSearchMLTasks_afterCreation() throws IOException { assertEquals((Double) 1.0, (Double) ((Map) ((Map) responseMap.get("hits")).get("total")).get("value")); } - public void testRegisterRemoteModel() throws IOException, InterruptedException { - Response response = createConnector(completionModelConnectorEntity); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); - assertNotNull(responseMap.get("model_id")); - } - public void testDeployRemoteModel() throws IOException, InterruptedException { Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); @@ -202,42 +190,6 @@ public void testDeployRemoteModel() throws IOException, InterruptedException { waitForTask(taskId, MLTaskState.COMPLETED); } - public void testPredictRemoteModel() throws IOException, InterruptedException { - // Skip test if key is null - if (OPENAI_KEY == null) { - return; - } - Response response = createConnector(completionModelConnectorEntity); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); - String modelId = (String) responseMap.get("model_id"); - response = deployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}"; - response = predictRemoteModel(modelId, predictInput); - responseMap = parseResponseToMap(response); - List responseList = (List) responseMap.get("inference_results"); - responseMap = (Map) responseList.get(0); - responseList = (List) responseMap.get("output"); - responseMap = (Map) responseList.get(0); - responseMap = (Map) responseMap.get("dataAsMap"); - responseList = (List) responseMap.get("choices"); - if (responseList == null) { - assertTrue(checkThrottlingOpenAI(responseMap)); - return; - } - responseMap = (Map) responseList.get(0); - assertFalse(((String) responseMap.get("text")).isEmpty()); - } - public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, InterruptedException { // Skip test if key is null if (OPENAI_KEY == null) { @@ -335,26 +287,6 @@ public void testPredictRemoteModelWithWrongOutputInterface() throws IOException, }); } - public void testUndeployRemoteModel() throws IOException, InterruptedException { - Response response = createConnector(completionModelConnectorEntity); - Map responseMap = parseResponseToMap(response); - String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("openAI-GPT-3.5 completions", connectorId); - responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); - String modelId = (String) responseMap.get("model_id"); - response = deployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = undeployRemoteModel(modelId); - responseMap = parseResponseToMap(response); - assertTrue(responseMap.toString().contains("undeployed")); - } - public void testOpenAIChatCompletionModel() throws IOException, InterruptedException { // Skip test if key is null if (OPENAI_KEY == null) { @@ -418,8 +350,12 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep responseMap = parseResponseToMap(response); // TODO handle throttling error assertNotNull(responseMap); - } + response = undeployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + assertTrue(responseMap.toString().contains("undeployed")); + } + @Ignore public void testOpenAIEditsModel() throws IOException, InterruptedException { // Skip test if key is null if (OPENAI_KEY == null) { @@ -490,7 +426,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { assertFalse(((String) responseMap.get("content")).isEmpty()); } - + @Ignore public void testOpenAIModerationsModel() throws IOException, InterruptedException { // Skip test if key is null if (OPENAI_KEY == null) { @@ -720,7 +656,7 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti responseMap = (Map) responseList.get(0); assertFalse(((String) responseMap.get("text")).isEmpty()); } - + @Ignore public void testCohereClassifyModel() throws IOException, InterruptedException { // Skip test if key is null if (COHERE_KEY == null) {