Skip to content

Commit

Permalink
Fix flaky integ tests
Browse files Browse the repository at this point in the history
Signed-off-by: Tanqiu Liu <[email protected]>
  • Loading branch information
tanqiuliu committed Nov 7, 2023
1 parent cda2f82 commit 8ab54c4
Showing 1 changed file with 37 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.neuralsearch.OpenSearchSecureRestTestCase;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.test.ClusterServiceUtils;
Expand Down Expand Up @@ -610,8 +611,8 @@ protected void deleteModel(String modelId) {
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);

// after model undeploy returns, the max interval to update model status is 3s in ml-commons CronJob.
Thread.sleep(3000);
// wait for model undeploy to complete
pollForModelState(modelId, MLModelState.UNDEPLOYED, 3000, 5);

makeRequest(
client(),
Expand All @@ -623,6 +624,16 @@ protected void deleteModel(String modelId) {
);
}

@SneakyThrows
protected void pollForModelState(String modelId, MLModelState expectedModelState, int intervalMs, int maxAttempts) {
for (int i = 0; i < maxAttempts; i++) {
Thread.sleep(intervalMs);
if (expectedModelState.equals(getModelState(modelId))) {
return;
}
}
}

public boolean isUpdateClusterSettings() {
return true;
}
Expand Down Expand Up @@ -698,7 +709,7 @@ protected void deleteSearchPipeline(final String pipelineId) {
}

/**
* Find all modesl that are currently deployed in the cluster
* Find all models that are currently deployed in the cluster
* @return set of model ids
*/
@SneakyThrows
Expand Down Expand Up @@ -733,11 +744,33 @@ protected Set<String> findDeployedModels() {
List<Map<String, Object>> innerHitsMap = (List<Map<String, Object>>) hits.get("hits");
return innerHitsMap.stream()
.map(hit -> (Map<String, Object>) hit.get("_source"))
.filter(hitsMap -> !Objects.isNull(hitsMap) && hitsMap.containsKey("model_id"))
.filter(
hitsMap -> !Objects.isNull(hitsMap)
&& hitsMap.containsKey("model_id")
&& MLModelState.DEPLOYED.equals(getModelState(hitsMap.get("model_id").toString()))
)
.map(hitsMap -> (String) hitsMap.get("model_id"))
.collect(Collectors.toSet());
}

@SneakyThrows
protected MLModelState getModelState(String modelId) {
Response getModelResponse = makeRequest(
client(),
"GET",
String.format(LOCALE, "/_plugins/_ml/models/%s", modelId),
null,
toHttpEntity(""),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> getModelResponseJson = XContentHelper.convertToMap(
XContentType.JSON.xContent(),
EntityUtils.toString(getModelResponse.getEntity()),
false
);
return MLModelState.valueOf((String) getModelResponseJson.get("model_state"));
}

/**
* Get the id for model currently deployed in the cluster. If there are no models deployed or it's more than 1 model
* fail on assertion
Expand Down

0 comments on commit 8ab54c4

Please sign in to comment.