diff --git a/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java index bad7c6d610..8c1900904d 100644 --- a/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/bwc/MLCommonsBackwardsCompatibilityRestTestCase.java @@ -5,6 +5,7 @@ package org.opensearch.ml.bwc; +import static org.hamcrest.Matchers.equalTo; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH; import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD; @@ -735,7 +736,7 @@ public String getTaskState(String taskId) throws IOException { return (String) task.get("state"); } - public void waitForTask(String taskId, MLTaskState targetState) throws InterruptedException { + public void waitForTask(String taskId, MLTaskState targetState) throws InterruptedException, IOException { AtomicBoolean taskDone = new AtomicBoolean(false); waitUntil(() -> { try { @@ -748,6 +749,7 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt } return taskDone.get(); }, CUSTOM_MODEL_TIMEOUT, TimeUnit.SECONDS); + assertThat(getTaskState(taskId), equalTo(targetState.name())); assertTrue(taskDone.get()); } }