diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 33cdff9a0..ce021e25b 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -49,6 +49,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.SpaceType; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.neuralsearch.OpenSearchSecureRestTestCase; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.test.ClusterServiceUtils; @@ -610,8 +611,8 @@ protected void deleteModel(String modelId) { ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) ); - // after model undeploy returns, the max interval to update model status is 3s in ml-commons CronJob. - Thread.sleep(3000); + // wait for model undeploy to complete + pollForModelState(modelId, MLModelState.UNDEPLOYED, 3000, 5); makeRequest( client(), @@ -623,6 +624,16 @@ protected void deleteModel(String modelId) { ); } + @SneakyThrows + protected void pollForModelState(String modelId, MLModelState expectedModelState, int intervalMs, int maxAttempts) { + for (int i = 0; i < maxAttempts; i++) { + Thread.sleep(intervalMs); + if (expectedModelState.equals(getModelState(modelId))) { + return; + } + } + } + public boolean isUpdateClusterSettings() { return true; } @@ -698,7 +709,7 @@ protected void deleteSearchPipeline(final String pipelineId) { } /** - * Find all modesl that are currently deployed in the cluster + * Find all models that are currently deployed in the cluster * @return set of model ids */ @SneakyThrows @@ -733,11 +744,33 @@ protected Set findDeployedModels() { List> innerHitsMap = (List>) hits.get("hits"); return innerHitsMap.stream() .map(hit -> (Map) hit.get("_source")) - .filter(hitsMap -> !Objects.isNull(hitsMap) && hitsMap.containsKey("model_id")) + .filter( + hitsMap -> !Objects.isNull(hitsMap) + && hitsMap.containsKey("model_id") + && MLModelState.DEPLOYED.equals(getModelState(hitsMap.get("model_id").toString())) + ) .map(hitsMap -> (String) hitsMap.get("model_id")) .collect(Collectors.toSet()); } + @SneakyThrows + protected MLModelState getModelState(String modelId) { + Response getModelResponse = makeRequest( + client(), + "GET", + String.format(LOCALE, "/_plugins/_ml/models/%s", modelId), + null, + toHttpEntity(""), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map getModelResponseJson = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(getModelResponse.getEntity()), + false + ); + return MLModelState.valueOf((String) getModelResponseJson.get("model_state")); + } + /** * Get the id for model currently deployed in the cluster. If there are no models deployed or it's more than 1 model * fail on assertion