From ee4845ee2eefeee94ee5e8599b33c3d868954c47 Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Fri, 6 Dec 2024 15:44:11 -0800 Subject: [PATCH] Fixes Two Flaky IT classes RestMLGuardrailsIT & ToolIntegrationWithLLMTest (#3253) * fix unneeded call to the get-task API for model_id within RestMLGuardrailsIT Following #3244, this IT called the task API to check the model ID again; however, this is redundant. Instead, one can directly pull the model_id upon creating the model group. Manual testing was done to verify that the behavior is intact. This should help reduce the number of calls within an IT, making it less flaky. Signed-off-by: Brian Flores * fix ToolIntegrationWithLLMTest model undeploy race condition Previously the test class attempted to delete a model without fully knowing whether the model had been undeployed in time. This change adds a wait of up to 5 retries, 1 second apart, to check the status of the model; only when it is undeployed will the test proceed to delete the model. When the number of retries is exceeded, it throws an error indicating a deeper problem. Manual testing was done to check that the model is undeployed by searching for the specific model via the checkForModelUndeployedStatus method. 
Signed-off-by: Brian Flores --------- Signed-off-by: Brian Flores (cherry picked from commit 1a659c8048dd23cdab3ab881622bf8cd355e1631) --- .../ml/rest/MLCommonsRestTestCase.java | 2 +- .../ml/rest/RestMLGuardrailsIT.java | 63 +++++++++---------- .../ml/tools/ToolIntegrationWithLLMTest.java | 47 ++++++++++++++ 3 files changed, 78 insertions(+), 34 deletions(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index c7fa379f39..5c64332c1d 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -962,7 +962,7 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt } return taskDone.get(); }, CUSTOM_MODEL_TIMEOUT, TimeUnit.SECONDS); - assertTrue(taskDone.get()); + assertTrue(String.format(Locale.ROOT, "Task Id %s could not get to %s state", taskId, targetState.name()), taskDone.get()); } public String registerConnector(String createConnectorInput) throws IOException, InterruptedException { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java index fbabf1dbb7..bd1c9536bf 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java @@ -124,17 +124,16 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, 
MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); @@ -144,6 +143,7 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep responseMap = (Map) responseList.get(0); responseMap = (Map) responseMap.get("dataAsMap"); responseList = (List) responseMap.get("choices"); + if (responseList == null) { assertTrue(checkThrottlingOpenAI(responseMap)); return; @@ -160,18 +160,18 @@ public void testPredictRemoteModelFailed() throws IOException, InterruptedExcept exceptionRule.expect(ResponseException.class); exceptionRule.expectMessage("guardrails triggered for user input"); Response response = createConnector(completionModelConnectorEntity); + Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); String predictInput = "{\n" + " \"parameters\": {\n" + " 
\"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}"; predictRemoteModel(modelId, predictInput); @@ -187,17 +187,16 @@ public void testPredictRemoteModelFailedNonType() throws IOException, Interrupte Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelNonTypeGuardrails("openAI-GPT-3.5 completions", connectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}"; predictRemoteModel(modelId, predictInput); } @@ -211,17 +210,16 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException Response response = createConnector(completionModelConnectorEntityWithGuardrail); Map responseMap = parseResponseToMap(response); String guardrailConnectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String guardrailModelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(guardrailModelId); responseMap = parseResponseToMap(response); - taskId = (String) 
responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + // Check the response from guardrails model that should be "accept". String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}"; response = predictRemoteModel(guardrailModelId, predictInput); @@ -233,21 +231,21 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException responseMap = (Map) responseMap.get("dataAsMap"); String validationResult = (String) responseMap.get("response"); Assert.assertTrue(validateRegex(validationResult, acceptRegex)); + // Create predict model. response = createConnector(completionModelConnectorEntity); responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + // Predict. 
predictInput = "{\n" + " \"parameters\": {\n" @@ -282,17 +280,17 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException, Response response = createConnector(completionModelConnectorEntityWithGuardrail); Map responseMap = parseResponseToMap(response); String guardrailConnectorId = (String) responseMap.get("connector_id"); + + // Create the model ID response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId); responseMap = parseResponseToMap(response); - String taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String guardrailModelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(guardrailModelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); + String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); + // Check the response from guardrails model that should be "reject". String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"I will be executed or tortured.\"\n" + " }\n" + "}"; response = predictRemoteModel(guardrailModelId, predictInput); @@ -304,17 +302,16 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException, responseMap = (Map) responseMap.get("dataAsMap"); String validationResult = (String) responseMap.get("response"); Assert.assertTrue(validateRegex(validationResult, rejectRegex)); + // Create predict model. 
response = createConnector(completionModelConnectorEntity); responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId); responseMap = parseResponseToMap(response); - taskId = (String) responseMap.get("task_id"); - waitForTask(taskId, MLTaskState.COMPLETED); - response = getTask(taskId); - responseMap = parseResponseToMap(response); String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java index cc111010cb..5689022edd 100644 --- a/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java +++ b/plugin/src/test/java/org/opensearch/ml/tools/ToolIntegrationWithLLMTest.java @@ -8,23 +8,31 @@ import java.io.IOException; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.UUID; import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; import org.junit.After; import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.rest.RestBaseAgentToolsIT; import org.opensearch.ml.utils.TestHelper; import com.sun.net.httpserver.HttpServer; +import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; @Log4j2 public abstract class ToolIntegrationWithLLMTest extends RestBaseAgentToolsIT { + + private static final int MAX_RETRIES = 5; + private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; + protected HttpServer server; protected String modelId; protected String agentId; @@ 
-62,9 +70,48 @@ public void stopMockLLM() { @After public void deleteModel() throws IOException { undeployModel(modelId); + checkForModelUndeployedStatus(modelId); deleteModel(client(), modelId, null); } + @SneakyThrows + private void checkForModelUndeployedStatus(String modelId) { + Predicate condition = response -> { + try { + Map responseInMap = parseResponseToMap(response); + MLModelState state = MLModelState.from(responseInMap.get(MLModel.MODEL_STATE_FIELD).toString()); + return MLModelState.UNDEPLOYED.equals(state); + } catch (Exception e) { + return false; + } + }; + waitResponseMeetingCondition("GET", "/_plugins/_ml/models/" + modelId, null, condition); + } + + @SneakyThrows + protected Response waitResponseMeetingCondition(String method, String endpoint, String jsonEntity, Predicate condition) { + for (int attempt = 1; attempt <= MAX_RETRIES; attempt++) { + Response response = TestHelper.makeRequest(client(), method, endpoint, null, jsonEntity, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + if (condition.test(response)) { + return response; + } + logger.info("The {}-th attempt on {}:{} . response: {}", attempt, method, endpoint, response.toString()); + Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND); + } + fail( + String + .format( + Locale.ROOT, + "The response failed to meet condition after %d attempts. Attempted to perform %s : %s", + MAX_RETRIES, + method, + endpoint + ) + ); + return null; + } + private String setUpConnectorWithRetry(int maxRetryTimes) throws InterruptedException { int retryTimes = 0; String connectorId = null;