diff --git a/src/slurm_plugin/clustermgtd.py b/src/slurm_plugin/clustermgtd.py index f4bfb1f5..ce22febd 100644 --- a/src/slurm_plugin/clustermgtd.py +++ b/src/slurm_plugin/clustermgtd.py @@ -1200,17 +1200,31 @@ def _find_active_nodes(partitions_name_map): active_nodes += partition.slurm_nodes return list(dict.fromkeys(active_nodes)) - def _is_node_in_replacement_valid(self, node, check_node_is_valid): + def _is_node_in_replacement_valid(self, node: SlurmNode, check_node_is_valid): """ Check node is replacement timeout or in replacement. If check_node_is_valid=True, check whether a node is in replacement, If check_node_is_valid=False, check whether a node is replacement timeout. """ - if node.instance and node.name in self._static_nodes_in_replacement: - time_is_expired = time_is_up( - node.instance.launch_time, self._current_time, grace_time=self._config.node_replacement_timeout + log.debug(f"Checking if node is in replacement {node}") + if ( + node.is_backing_instance_valid( + self._config.ec2_instance_missing_max_count, + self._nodes_without_backing_instance_count_map, + log_warn_if_unhealthy=True, + ) + and node.name in self._static_nodes_in_replacement + ): + # Set `time_is_expired` to `False` if `node.instance` is `None` since we don't have a launch time yet + time_is_expired = ( + False + if not node.instance + else time_is_up( + node.instance.launch_time, self._current_time, grace_time=self._config.node_replacement_timeout + ) ) + log.debug(f"Node {node} is in replacement and timer expired? {time_is_expired}, instance? 
{node.instance}") return not time_is_expired if check_node_is_valid else time_is_expired return False diff --git a/src/slurm_plugin/slurm_resources.py b/src/slurm_plugin/slurm_resources.py index 73a9b1e6..a2378a0f 100644 --- a/src/slurm_plugin/slurm_resources.py +++ b/src/slurm_plugin/slurm_resources.py @@ -180,6 +180,17 @@ class SlurmReservation: users: str +class MissingInstance: + name: str + ip: str + count: int + + def __init__(self, name, ip, count): + self.name = name + self.ip = ip + self.count = count + + class SlurmNode(metaclass=ABCMeta): SLURM_SCONTROL_COMPLETING_STATE = "COMPLETING" SLURM_SCONTROL_BUSY_STATES = {"MIXED", "ALLOCATED", SLURM_SCONTROL_COMPLETING_STATE} @@ -427,7 +438,7 @@ def is_powering_down_with_nodeaddr(self): def is_backing_instance_valid( self, ec2_instance_missing_max_count, - nodes_without_backing_instance_count_map: dict, + nodes_without_backing_instance_count_map: dict[str, MissingInstance], log_warn_if_unhealthy=True, ): """Check if a slurm node's addr is set, it points to a valid instance in EC2.""" @@ -445,7 +456,11 @@ def is_backing_instance_valid( ) # Allow a few iterations for the eventual consistency of EC2 data logger.debug(f"Map of slurm nodes without backing instances {nodes_without_backing_instance_count_map}") - missing_instance_loop_count = nodes_without_backing_instance_count_map.get(self.name, 0) + missing_instance = nodes_without_backing_instance_count_map.get(self.name, None) + missing_instance_loop_count = missing_instance.count if missing_instance else 0 + if missing_instance and self.nodeaddr != missing_instance.ip: + # Reset the loop count since the nodeaddr has changed + missing_instance_loop_count = 0 # If the loop count has been reached, the instance is unhealthy and will be terminated if missing_instance_loop_count >= ec2_instance_missing_max_count: if log_warn_if_unhealthy: @@ -454,11 +469,12 @@ def is_backing_instance_valid( nodes_without_backing_instance_count_map.pop(self.name, None) 
self.ec2_backing_instance_valid = False else: - nodes_without_backing_instance_count_map[self.name] = missing_instance_loop_count + 1 + instance_to_add = MissingInstance(self.name, self.nodeaddr, missing_instance_loop_count + 1) + nodes_without_backing_instance_count_map[self.name] = instance_to_add if log_warn_if_unhealthy: logger.warning( f"Incrementing missing EC2 instance count for node {self.name} to " - f"{nodes_without_backing_instance_count_map[self.name]}." + f"{nodes_without_backing_instance_count_map[self.name].count}." ) else: # Remove the slurm node from the map since the instance is healthy diff --git a/tests/slurm_plugin/slurm_resources/test_slurm_resources.py b/tests/slurm_plugin/slurm_resources/test_slurm_resources.py index c25570ae..b3138b27 100644 --- a/tests/slurm_plugin/slurm_resources/test_slurm_resources.py +++ b/tests/slurm_plugin/slurm_resources/test_slurm_resources.py @@ -18,6 +18,7 @@ DynamicNode, EC2InstanceHealthState, InvalidNodenameError, + MissingInstance, SlurmPartition, SlurmResumeJob, StaticNode, @@ -1185,26 +1186,34 @@ def test_slurm_node_is_powering_down_with_nodeaddr(node, expected_result): StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"), None, 2, - {"queue1-st-c5xlarge-1": 1}, - {"queue1-st-c5xlarge-1": 2}, + {"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-1", 1)}, + {"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-1", 2)}, True, ), ( StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"), None, 2, - {"queue1-st-c5xlarge-1": 2}, - {"queue1-st-c5xlarge-1": 2}, + {"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-1", 2)}, + {"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-1", 2)}, False, ), ( StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"), "Instance", 2, - {"queue1-st-c5xlarge-1": 3}, + {"queue1-st-c5xlarge-1": 
MissingInstance("queue1-st-c5xlarge-1", "ip-1", 3)}, {}, True, ), + ( + StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"), + None, + 3, + {"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-2", 2)}, + {"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-1", 1)}, + True, + ), ], ids=[ "static_no_backing_zero_max_count", @@ -1214,6 +1223,7 @@ def test_slurm_node_is_powering_down_with_nodeaddr(node, expected_result): "static_no_backing_count_not_exceeded", "static_no_backing_with_count_exceeded", "static_backed_with_count_exceeded", + "static_no_backing_count_not_exceeded_with_wrong_ip", ], ) def test_slurm_node_is_backing_instance_valid(node, instance, max_count, count_map, final_map, expected_result): @@ -1225,7 +1235,9 @@ def test_slurm_node_is_backing_instance_valid(node, instance, max_count, count_m ).is_equal_to(expected_result) assert_that(node.ec2_backing_instance_valid).is_equal_to(expected_result) if count_map: - assert_that(count_map[node.name]).is_equal_to(final_map.get(node.name, None)) + assert_that(count_map[node.name].count).is_equal_to(final_map.get(node.name, None).count) + assert_that(count_map[node.name].ip).is_equal_to(final_map.get(node.name, None).ip) + assert_that(count_map[node.name].ip).is_equal_to(node.nodeaddr) @pytest.mark.parametrize( diff --git a/tests/slurm_plugin/test_clustermgtd.py b/tests/slurm_plugin/test_clustermgtd.py index 188f03a2..422c516c 100644 --- a/tests/slurm_plugin/test_clustermgtd.py +++ b/tests/slurm_plugin/test_clustermgtd.py @@ -767,6 +767,7 @@ def test_handle_health_check( region="region", boto3_config=None, fleet_config={}, + ec2_instance_missing_max_count=0, ) cluster_manager = ClusterManager(mock_sync_config) @@ -831,6 +832,7 @@ def test_update_static_nodes_in_replacement(current_replacing_nodes, slurm_nodes region="region", boto3_config=None, fleet_config={}, + ec2_instance_missing_max_count=0, ) cluster_manager = 
ClusterManager(mock_sync_config) cluster_manager._static_nodes_in_replacement = current_replacing_nodes @@ -2646,13 +2648,14 @@ def initialize_console_logger_mock(mocker): @pytest.mark.parametrize( - "current_replacing_nodes, node, instance, current_time, expected_result", + "current_replacing_nodes, node, instance, current_time, max_count, expected_result", [ ( set(), StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD", "queue1"), EC2Instance("id-1", "ip-1", "hostname", {"ip-1"}, datetime(2020, 1, 1, 0, 0, 0)), datetime(2020, 1, 1, 0, 0, 29), + 0, False, ), ( @@ -2660,6 +2663,7 @@ def initialize_console_logger_mock(mocker): StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD", "queue1"), None, datetime(2020, 1, 1, 0, 0, 29), + 0, False, ), ( @@ -2667,6 +2671,7 @@ def initialize_console_logger_mock(mocker): StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "DOWN+CLOUD", "queue1"), EC2Instance("id-1", "ip-1", "hostname", {"ip-1"}, datetime(2020, 1, 1, 0, 0, 0)), datetime(2020, 1, 1, 0, 0, 29), + 0, True, ), ( @@ -2674,15 +2679,30 @@ def initialize_console_logger_mock(mocker): StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD", "queue1"), EC2Instance("id-1", "ip-1", "hostname", {"ip-1"}, datetime(2020, 1, 1, 0, 0, 0)), datetime(2020, 1, 1, 0, 0, 30), + 0, False, ), + ( + {"queue1-st-c5xlarge-1"}, + StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD", "queue1"), + None, + datetime(2020, 1, 1, 0, 0, 30), + 1, + True, + ), + ], + ids=[ + "not_in_replacement", + "no-backing-instance", + "in_replacement", + "timeout", + "no-backing-instance-with-max-count", ], - ids=["not_in_replacement", "no-backing-instance", "in_replacement", "timeout"], ) @pytest.mark.usefixtures( "initialize_instance_manager_mock", "initialize_executor_mock", "initialize_console_logger_mock" ) -def test_is_node_being_replaced(current_replacing_nodes, node, instance, current_time, expected_result): +def 
test_is_node_being_replaced(current_replacing_nodes, node, instance, current_time, max_count, expected_result): mock_sync_config = SimpleNamespace( node_replacement_timeout=30, insufficient_capacity_timeout=3, @@ -2691,6 +2711,7 @@ def test_is_node_being_replaced(current_replacing_nodes, node, instance, current region="region", boto3_config=None, fleet_config={}, + ec2_instance_missing_max_count=max_count, ) cluster_manager = ClusterManager(mock_sync_config) cluster_manager._current_time = current_time @@ -2700,24 +2721,34 @@ def test_is_node_being_replaced(current_replacing_nodes, node, instance, current @pytest.mark.parametrize( - "node, instance, current_node_in_replacement, is_replacement_timeout", + "node, instance, current_node_in_replacement, max_count, is_replacement_timeout", [ ( StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "DOWN+CLOUD+NOT_RESPONDING", "queue1"), None, {"queue1-st-c5xlarge-1"}, + 0, + False, + ), + ( + StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "DOWN+CLOUD+NOT_RESPONDING", "queue1"), + None, + {"queue1-st-c5xlarge-1"}, + 1, False, ), ( StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "DOWN+CLOUD+NOT_RESPONDING", "queue1"), EC2Instance("id-1", "ip-1", "hostname", {"ip-1"}, datetime(2020, 1, 1, 0, 0, 0)), {"queue1-st-c5xlarge-1"}, + 0, True, ), ( DynamicNode("queue1-dy-c5xlarge-1", "ip-1", "hostname", "MIXED+CLOUD+NOT_RESPONDING+POWERING_UP", "queue1"), None, {"some_node_in_replacement"}, + 0, False, ), ( @@ -2730,12 +2761,14 @@ def test_is_node_being_replaced(current_replacing_nodes, node, instance, current ), EC2Instance("id-1", "ip-1", "hostname", {"ip-1"}, datetime(2020, 1, 1, 0, 0, 0)), {"some_node_in_replacement"}, + 0, False, ), ( StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "DOWN+CLOUD+NOT_RESPONDING", "queue1"), EC2Instance("id-1", "ip-1", "hostname", {"ip-1"}, datetime(2020, 1, 1, 0, 0, 0)), {"some_node_in_replacement"}, + 0, False, ), ], @@ -2743,7 +2776,7 @@ def 
test_is_node_being_replaced(current_replacing_nodes, node, instance, current @pytest.mark.usefixtures( "initialize_instance_manager_mock", "initialize_executor_mock", "initialize_console_logger_mock" ) -def test_is_node_replacement_timeout(node, current_node_in_replacement, is_replacement_timeout, instance): +def test_is_node_replacement_timeout(node, current_node_in_replacement, max_count, is_replacement_timeout, instance): node.instance = instance mock_sync_config = SimpleNamespace( node_replacement_timeout=30, @@ -2753,6 +2786,7 @@ def test_is_node_replacement_timeout(node, current_node_in_replacement, is_repla region="region", boto3_config=None, fleet_config={}, + ec2_instance_missing_max_count=max_count, ) cluster_manager = ClusterManager(mock_sync_config) cluster_manager._current_time = datetime(2020, 1, 2, 0, 0, 0)