Skip to content

Commit

Permalink
Fixes Two Flaky IT classes RestMLGuardrailsIT & ToolIntegrationWithLL…
Browse files Browse the repository at this point in the history
…MTest (#3253)

* fix unneeded call to get model_id for task api within RestMLGuardrailsIT

Following #3244, this IT called the task api to check the model id again; however, this is redundant. Instead, one can directly pull the model_id upon creating the model group. Manual testing was done to see that the behavior is intact. This should help reduce the calls within an IT to make it less flaky.

Signed-off-by: Brian Flores <[email protected]>

* fix ToolIntegrationWithLLMTest model undeploy race condition

Previously the test class attempted to delete a model without fully knowing whether the model had been undeployed in time. This change adds a wait of up to 5 retries, 1 second apart, to check the status of the model; only once it is undeployed does it proceed to delete the model. When the number of retries is exceeded, it throws an error indicating a deeper problem. Manual testing was done to check that the model is undeployed by searching for the specific model via the checkForModelUndeployedStatus method.

Signed-off-by: Brian Flores <[email protected]>

---------

Signed-off-by: Brian Flores <[email protected]>
(cherry picked from commit 1a659c8)
  • Loading branch information
brianf-aws committed Dec 9, 2024
1 parent 9abed5f commit ee4845e
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt
}
return taskDone.get();
}, CUSTOM_MODEL_TIMEOUT, TimeUnit.SECONDS);
assertTrue(taskDone.get());
assertTrue(String.format(Locale.ROOT, "Task Id %s could not get to %s state", taskId, targetState.name()), taskDone.get());
}

public String registerConnector(String createConnectorInput) throws IOException, InterruptedException {
Expand Down
63 changes: 30 additions & 33 deletions plugin/src/test/java/org/opensearch/ml/rest/RestMLGuardrailsIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,16 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");

response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");

response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);

String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
Expand All @@ -144,6 +143,7 @@ public void testPredictRemoteModelSuccess() throws IOException, InterruptedExcep
responseMap = (Map) responseList.get(0);
responseMap = (Map) responseMap.get("dataAsMap");
responseList = (List) responseMap.get("choices");

if (responseList == null) {
assertTrue(checkThrottlingOpenAI(responseMap));
return;
Expand All @@ -160,18 +160,18 @@ public void testPredictRemoteModelFailed() throws IOException, InterruptedExcept
exceptionRule.expect(ResponseException.class);
exceptionRule.expectMessage("guardrails triggered for user input");
Response response = createConnector(completionModelConnectorEntity);

Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");

response = registerRemoteModelWithLocalRegexGuardrails("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");

response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
String taskId = (String) responseMap.get("task_id");

waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}";
predictRemoteModel(modelId, predictInput);
Expand All @@ -187,17 +187,16 @@ public void testPredictRemoteModelFailedNonType() throws IOException, Interrupte
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");

response = registerRemoteModelNonTypeGuardrails("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");

response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);

String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}";
predictRemoteModel(modelId, predictInput);
}
Expand All @@ -211,17 +210,16 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException
Response response = createConnector(completionModelConnectorEntityWithGuardrail);
Map responseMap = parseResponseToMap(response);
String guardrailConnectorId = (String) responseMap.get("connector_id");

response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String guardrailModelId = (String) responseMap.get("model_id");

response = deployRemoteModel(guardrailModelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);

// Check the response from guardrails model that should be "accept".
String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}";
response = predictRemoteModel(guardrailModelId, predictInput);
Expand All @@ -233,21 +231,21 @@ public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException
responseMap = (Map) responseMap.get("dataAsMap");
String validationResult = (String) responseMap.get("response");
Assert.assertTrue(validateRegex(validationResult, acceptRegex));

// Create predict model.
response = createConnector(completionModelConnectorEntity);
responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");

response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");

response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);

// Predict.
predictInput = "{\n"
+ " \"parameters\": {\n"
Expand Down Expand Up @@ -282,17 +280,17 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException,
Response response = createConnector(completionModelConnectorEntityWithGuardrail);
Map responseMap = parseResponseToMap(response);
String guardrailConnectorId = (String) responseMap.get("connector_id");

// Create the model ID
response = registerRemoteModel("guardrail model group", "openAI-GPT-3.5 completions", guardrailConnectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String guardrailModelId = (String) responseMap.get("model_id");

response = deployRemoteModel(guardrailModelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);

// Check the response from guardrails model that should be "reject".
String predictInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"I will be executed or tortured.\"\n" + " }\n" + "}";
response = predictRemoteModel(guardrailModelId, predictInput);
Expand All @@ -304,17 +302,16 @@ public void testPredictRemoteModelFailedWithModelGuardrail() throws IOException,
responseMap = (Map) responseMap.get("dataAsMap");
String validationResult = (String) responseMap.get("response");
Assert.assertTrue(validateRegex(validationResult, rejectRegex));

// Create predict model.
response = createConnector(completionModelConnectorEntity);
responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");

response = registerRemoteModelWithModelGuardrails("openAI with guardrails", connectorId, guardrailModelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");

response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,31 @@
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;

import org.junit.After;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.rest.RestBaseAgentToolsIT;
import org.opensearch.ml.utils.TestHelper;

import com.sun.net.httpserver.HttpServer;

import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;

@Log4j2
public abstract class ToolIntegrationWithLLMTest extends RestBaseAgentToolsIT {

private static final int MAX_RETRIES = 5;
private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;

protected HttpServer server;
protected String modelId;
protected String agentId;
Expand Down Expand Up @@ -62,9 +70,48 @@ public void stopMockLLM() {
@After
public void deleteModel() throws IOException {
    // Teardown order matters: undeploy the model, then block until the
    // cluster actually reports it as UNDEPLOYED before issuing the delete.
    // Deleting while undeploy is still in flight was a race condition that
    // made this test class flaky.
    undeployModel(modelId);
    checkForModelUndeployedStatus(modelId);
    deleteModel(client(), modelId, null);
}

@SneakyThrows
private void checkForModelUndeployedStatus(String modelId) {
    // Polls the model GET endpoint until the model's state field reads
    // UNDEPLOYED, delegating the retry/sleep loop to
    // waitResponseMeetingCondition.
    Predicate<Response> isUndeployed = resp -> {
        try {
            Map<String, Object> body = parseResponseToMap(resp);
            Object rawState = body.get(MLModel.MODEL_STATE_FIELD);
            return MLModelState.UNDEPLOYED == MLModelState.from(rawState.toString());
        } catch (Exception ignored) {
            // Any parse or lookup failure just means "not undeployed yet";
            // report false so the caller keeps polling.
            return false;
        }
    };
    waitResponseMeetingCondition("GET", "/_plugins/_ml/models/" + modelId, null, isUndeployed);
}

@SneakyThrows
protected Response waitResponseMeetingCondition(String method, String endpoint, String jsonEntity, Predicate<Response> condition) {
    // Repeatedly issues the given request (up to MAX_RETRIES times, pausing
    // DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND between attempts)
    // until the response satisfies {@code condition}.
    //
    // @param method     HTTP method, e.g. "GET"
    // @param endpoint   request path
    // @param jsonEntity optional JSON request body (may be null)
    // @param condition  predicate deciding whether the response is acceptable
    // @return the first response meeting the condition; never returns after
    //         exhausting retries — fail() throws an AssertionError instead
    for (int attempt = 1; attempt <= MAX_RETRIES; attempt++) {
        Response response = TestHelper.makeRequest(client(), method, endpoint, null, jsonEntity, null);
        // Every polled request must itself succeed; a non-OK status means a
        // deeper problem than "condition not yet met".
        assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
        if (condition.test(response)) {
            return response;
        }
        logger.info("The {}-th attempt on {}:{} . response: {}", attempt, method, endpoint, response.toString());
        // Skip the sleep after the final attempt — there is no subsequent
        // request, so sleeping would only delay the failure report.
        if (attempt < MAX_RETRIES) {
            Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND);
        }
    }
    fail(
        String
            .format(
                Locale.ROOT,
                "The response failed to meet condition after %d attempts. Attempted to perform %s : %s",
                MAX_RETRIES,
                method,
                endpoint
            )
    );
    return null;
}

private String setUpConnectorWithRetry(int maxRetryTimes) throws InterruptedException {
int retryTimes = 0;
String connectorId = null;
Expand Down

0 comments on commit ee4845e

Please sign in to comment.