diff --git a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java index 276ce1774e..4e956e9c7e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java @@ -10,12 +10,16 @@ import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; import static org.opensearch.ml.utils.MLExceptionUtils.logException; import static org.opensearch.ml.utils.MLExceptionUtils.toJsonString; +import static org.opensearch.ml.utils.RestActionUtils.getAllNodes; import java.time.Instant; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; @@ -131,10 +135,29 @@ protected void doExecute(Task task, ActionRequest request, ActionListener workNodesRemovedFromCluster = new HashSet<>(); + + if (workNodes != null && !workNodes.isEmpty()) { + Set allNodesInCluster = new HashSet<>(List.of(getAllNodes(clusterService))); + + workNodesRemovedFromCluster = workNodes + .stream() + .filter(node -> !allNodesInCluster.contains(node)) + .collect(Collectors.toSet()); + + if (!workNodesRemovedFromCluster.isEmpty()) { + workNodes.removeAll(workNodesRemovedFromCluster); + } + } + + if (workNodes == null || workNodes.isEmpty()) { + if (!workNodesRemovedFromCluster.isEmpty()) { + mlTaskCache.updateWorkerNode(workNodesRemovedFromCluster); + mlModelManager.removeModelWorkerNode(modelId, false, workNodesRemovedFromCluster.toArray(new String[0])); + } int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize(); MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED; - if (mlTaskCache.allNodeFailed()) { + if (mlTaskCache.allNodeFailed() || mlTaskCache.getWorkerNodeSize() == 0) { taskState = MLTaskState.FAILED; currentWorkerNodeCount = 0; } else { @@ -150,11 +173,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener updateFields = new HashMap<>(); updateFields.put(MLModel.MODEL_STATE_FIELD, modelState); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java index a6b519a2aa..0b3d7d116f 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java @@ -62,4 +62,9 @@ public int errorNodesCount() { public boolean allNodeFailed() { return workerNodeSize != null && errors.size() == workerNodeSize; } + + public void updateWorkerNode(Set nodesRemovedFromCluster) { + this.workerNodes.removeAll(nodesRemovedFromCluster); + this.workerNodeSize = this.workerNodeSize - nodesRemovedFromCluster.size(); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java index b5c1a4e1f7..62a4fa779f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java @@ -29,6 +29,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE; import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import static org.opensearch.ml.utils.TestHelper.clusterSetting; +import static org.opensearch.ml.utils.TestHelper.setupTestClusterState; import java.util.Arrays; import java.util.HashSet; @@ -43,6 +44,7 @@ import org.opensearch.Version; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -94,6 +96,8 @@ public class TransportForwardActionTests extends OpenSearchTestCase { private TransportForwardAction forwardAction; + private ClusterState testState; + Settings settings = Settings .builder() .put(ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE.getKey(), true) @@ -137,6 +141,9 @@ public void setup() { ) ); + testState = setupTestClusterState("test_node_id2"); + when(clusterService.state()).thenReturn(testState); + node1 = new DiscoveryNode(nodeId1, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); node2 = new DiscoveryNode(nodeId2, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index c40b11e1ce..523b8fee36 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -118,7 +118,7 @@ public void setup() throws IOException { encryptor = spy(new EncryptorImpl(null)); syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor); - testState = setupTestClusterState(); + testState = setupTestClusterState("node"); when(clusterService.state()).thenReturn(testState); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java index 446dc74213..4c644300f2 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLProfileActionTests.java @@ -151,7 +151,7 @@ public void setup() throws IOException { .build(); clusterName = new ClusterName("test cluster"); - testState = setupTestClusterState(); + testState = setupTestClusterState("node"); when(clusterService.state()).thenReturn(testState); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java index cdd9136255..a034af502a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLStatsActionTests.java @@ -147,7 +147,7 @@ public void setup() throws IOException { roleSet, Version.CURRENT ); - testState = setupTestClusterState(); + testState = setupTestClusterState("node"); when(clusterService.state()).thenReturn(testState); clusterName = new ClusterName(clusterNameStr); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java index cf8e93765c..e98d1a5f31 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUndeployModelActionTests.java @@ -67,7 +67,7 @@ public class RestMLUndeployModelActionTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - testState = setupTestClusterState(); + testState = setupTestClusterState("node"); when(clusterService.state()).thenReturn(testState); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index 0a1a78050c..49d4200b53 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -461,11 +461,11 @@ public static ClusterState state(int numDataNodes, String indexName, String mapp return state(new ClusterName("test"), indexName, mapping, clusterManagerNode, clusterManagerNode, allNodes); } - public static ClusterState setupTestClusterState() { + public static ClusterState setupTestClusterState(String nodeId) { Set roleSet = new HashSet<>(); roleSet.add(DiscoveryNodeRole.DATA_ROLE); DiscoveryNode node = new DiscoveryNode( - "node", + nodeId, new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()), new HashMap<>(), roleSet,