diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 613e7fb553..d766e4551d 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -962,7 +962,7 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt } return taskDone.get(); }, CUSTOM_MODEL_TIMEOUT, TimeUnit.SECONDS); - assertTrue(taskDone.get()); + assertTrue(String.format(Locale.ROOT, "Task Id %s could not get to %s state", taskId, targetState.name()), taskDone.get()); } public String registerConnector(String createConnectorInput) throws IOException, InterruptedException { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java index fbabf1dbb7..bd1c9536bf 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java @@ -124,17 +124,16 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, 
MLTaskState.COMPLETED); + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); @@ -144,6 +143,7 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep responseMap = (Map) responseList.get(0); responseMap = (Map) responseMap.get("dataAsMap"); responseList = (List) responseMap.get("choices"); + if (responseList == null) { assertTrue(checkThrottlingOpenAI(responseMap)); return; @@ -160,18 +160,18 @@ public void testPredictRemoteModelFailed() throws IOException, InterruptedExcept exceptionRule.expect(ResponseException.class); exceptionRule.expectMessage("guardrails triggered for user input"); Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}"; predictRemoteModel(modelId, predictInput); @@ -187,17 +187,16 @@ public void testPredictRemoteModelFailedNonType() throws IOException, Interrupte Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) 
responseMap.get("connector_id"); + response = registerRemoteModelNonTypeGuardrails("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}"; predictRemoteModel(modelId, predictInput); } @@ -211,17 +210,16 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException Response response = createConnector(completionModelConnectorEntityWithGuardrail); Map responseMap = parseResponseToMap(response); String guardrailConnectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String guardrailModelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(guardrailModelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + // Check the response from guardrails model that should be "accept". 
String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}"; response = predictRemoteModel(guardrailModelId, predictInput); @@ -233,21 +231,21 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException responseMap = (Map) responseMap.get("dataAsMap"); String validationResult = (String) responseMap.get("response"); Assert.assertTrue(validateRegex(validationResult, acceptRegex)); + // Create predict model. response = createConnector(completionModelConnectorEntity); responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + // Predict. 
predictInput = "{\n" + " \"parameters\": {\n" @@ -282,17 +280,17 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException, Response response = createConnector(completionModelConnectorEntityWithGuardrail); Map responseMap = parseResponseToMap(response); String guardrailConnectorId = (String) responseMap.get("connector_id"); + + // Create the model ID response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String guardrailModelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(guardrailModelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + // Check the response from guardrails model that should be "reject". String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"I will be executed or tortured.\"\n" + " }\n" + "}"; response = predictRemoteModel(guardrailModelId, predictInput); @@ -304,17 +302,16 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException, responseMap = (Map) responseMap.get("dataAsMap"); String validationResult = (String) responseMap.get("response"); Assert.assertTrue(validateRegex(validationResult, rejectRegex)); + // Create predict model. 
response = createConnector(completionModelConnectorEntity); responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java index cc111010cb..5689022edd 100644 --- a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java +++ b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java @@ -8,23 +8,31 @@ import java.io.IOException; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.UUID; import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; import org.junit.After; import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.rest.RestBaseAgentToolsIT; import org.opensearch.ml.utils.TestHelper; import com.sun.net.httpserver.HttpServer; +import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; @Log4j2 public abstract class ToolIntegrationWithLLMTest extends RestBaseAgentToolsIT { + + private static final int MAX_RETRIES = 5; + private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; + protected HttpServer server; protected String modelId; protected String agentId; @@ 
-62,9 +70,48 @@ public void stopMockLLM() { @After public void deleteModel() throws IOException { undeployModel(modelId); + checkForModelUndeployedStatus(modelId); deleteModel(client(), modelId, null); } + @SneakyThrows + private void checkForModelUndeployedStatus(String modelId) { + Predicate<Response> condition = response -> { + try { + Map responseInMap = parseResponseToMap(response); + MLModelState state = MLModelState.from(responseInMap.get(MLModel.MODEL_STATE_FIELD).toString()); + return MLModelState.UNDEPLOYED.equals(state); + } catch (Exception e) { + return false; + } + }; + waitResponseMeetingCondition("GET", "/_plugins/_ml/models/" + modelId, null, condition); + } + + @SneakyThrows + protected Response waitResponseMeetingCondition(String method, String endpoint, String jsonEntity, Predicate<Response> condition) { + for (int attempt = 1; attempt <= MAX_RETRIES; attempt++) { + Response response = TestHelper.makeRequest(client(), method, endpoint, null, jsonEntity, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + if (condition.test(response)) { + return response; + } + logger.info("The {}-th attempt on {}:{} . response: {}", attempt, method, endpoint, response.toString()); + Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND); + } + fail( + String + .format( + Locale.ROOT, + "The response failed to meet condition after %d attempts. Attempted to perform %s : %s", + MAX_RETRIES, + method, + endpoint + ) + ); + return null; + } + private String setUpConnectorWithRetry(int maxRetryTimes) throws InterruptedException { int retryTimes = 0; String connectorId = null;