From 8ab54c443f705eb7b89789d0e86e61edc99923c0 Mon Sep 17 00:00:00 2001 From: Tanqiu Liu Date: Mon, 6 Nov 2023 22:55:11 -0800 Subject: [PATCH 1/2] Fix flaky integ tests Signed-off-by: Tanqiu Liu --- .../common/BaseNeuralSearchIT.java | 41 +++++++++++++++++-- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 33cdff9a0..ce021e25b 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -49,6 +49,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.SpaceType; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.neuralsearch.OpenSearchSecureRestTestCase; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.test.ClusterServiceUtils; @@ -610,8 +611,8 @@ protected void deleteModel(String modelId) { ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) ); - // after model undeploy returns, the max interval to update model status is 3s in ml-commons CronJob. - Thread.sleep(3000); + // wait for model undeploy to complete + pollForModelState(modelId, MLModelState.UNDEPLOYED, 3000, 5); makeRequest( client(), @@ -623,6 +624,16 @@ protected void deleteModel(String modelId) { ); } + @SneakyThrows + protected void pollForModelState(String modelId, MLModelState expectedModelState, int intervalMs, int maxAttempts) { + for (int i = 0; i < maxAttempts; i++) { + Thread.sleep(intervalMs); + if (expectedModelState.equals(getModelState(modelId))) { + return; + } + } + } + public boolean isUpdateClusterSettings() { return true; } @@ -698,7 +709,7 @@ protected void deleteSearchPipeline(final String pipelineId) { } /** - * Find all modesl that are currently deployed in the cluster + * Find all models that are currently deployed in the cluster * @return set of model ids */ @SneakyThrows @@ -733,11 +744,33 @@ protected Set findDeployedModels() { List> innerHitsMap = (List>) hits.get("hits"); return innerHitsMap.stream() .map(hit -> (Map) hit.get("_source")) - .filter(hitsMap -> !Objects.isNull(hitsMap) && hitsMap.containsKey("model_id")) + .filter( + hitsMap -> !Objects.isNull(hitsMap) + && hitsMap.containsKey("model_id") + && MLModelState.DEPLOYED.equals(getModelState(hitsMap.get("model_id").toString())) + ) .map(hitsMap -> (String) hitsMap.get("model_id")) .collect(Collectors.toSet()); } + @SneakyThrows + protected MLModelState getModelState(String modelId) { + Response getModelResponse = makeRequest( + client(), + "GET", + String.format(LOCALE, "/_plugins/_ml/models/%s", modelId), + null, + toHttpEntity(""), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map getModelResponseJson = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(getModelResponse.getEntity()), + false + ); + return MLModelState.valueOf((String) getModelResponseJson.get("model_state")); + } + /** * Get the id for model currently deployed in the cluster. If there are no models deployed or it's more than 1 model * fail on assertion From 371a9ddfe87659a266527d9965d5da5194574dfd Mon Sep 17 00:00:00 2001 From: Tanqiu Liu Date: Wed, 8 Nov 2023 01:07:14 -0800 Subject: [PATCH 2/2] Address PR comments; Added CHANGELOG Signed-off-by: Tanqiu Liu --- CHANGELOG.md | 1 + .../common/BaseNeuralSearchIT.java | 24 +++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b92edd850..dc6eeb46c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +- Fixed flaky integration tests caused by model_state transition latency. ### Infrastructure ### Documentation ### Maintenance diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index ce021e25b..16a71fcd1 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -193,6 +193,8 @@ protected void loadModel(String modelId) throws Exception { isComplete = checkComplete(taskQueryResult); Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND); } + // wait for model state update to DEPLOYED + pollForModelState(modelId, Set.of(MLModelState.DEPLOYED), 3000, 5); } /** @@ -612,7 +614,7 @@ protected void deleteModel(String modelId) { ); // wait for model undeploy to complete - pollForModelState(modelId, MLModelState.UNDEPLOYED, 3000, 5); + pollForModelState(modelId, Set.of(MLModelState.UNDEPLOYED, MLModelState.DEPLOY_FAILED), 3000, 5); makeRequest( client(), @@ -624,14 +626,26 @@ protected void deleteModel(String modelId) { ); } - @SneakyThrows - protected void pollForModelState(String modelId, MLModelState expectedModelState, int intervalMs, int maxAttempts) { + protected void pollForModelState(String modelId, Set exitModelStates, int intervalMs, int maxAttempts) + throws InterruptedException { + MLModelState currentState = null; for (int i = 0; i < maxAttempts; i++) { Thread.sleep(intervalMs); - if (expectedModelState.equals(getModelState(modelId))) { + currentState = getModelState(modelId); + if (exitModelStates.contains(currentState)) { return; } } + fail( + String.format( + LOCALE, + "Model state does not reached exit states %s after %d attempts with interval of %d ms, latest model state: %s.", + StringUtils.join(exitModelStates, ","), + maxAttempts, + intervalMs, + currentState + ) + ); } public boolean isUpdateClusterSettings() { @@ -747,7 +761,7 @@ protected Set findDeployedModels() { .filter( hitsMap -> !Objects.isNull(hitsMap) && hitsMap.containsKey("model_id") - && MLModelState.DEPLOYED.equals(getModelState(hitsMap.get("model_id").toString())) + && getModelState(hitsMap.get("model_id").toString()) == MLModelState.DEPLOYED ) .map(hitsMap -> (String) hitsMap.get("model_id")) .collect(Collectors.toSet());