[Develop] Backing instance max count updates - `_is_node_in_replacement_valid()`, `is_backing_instance_valid()`, and unit tests (#622)

* Add logic to `is_backing_instance_valid()` to check the IP and make sure the instance matches the one being tracked (#618)

The missing instance map did not track the IP address associated with the slurm node.
Because of this, if a new instance was launched before the previous instance became healthy, the count
in the instance count map was not reset. This change uses a class object to track the data and links the node name to the IP.
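
As an illustration, a minimal sketch of the reset behavior (the `MissingInstance` shape matches the diff below; `updated_loop_count` is a hypothetical helper standing in for the inline logic in `is_backing_instance_valid()`):

```python
class MissingInstance:
    # Tracks the node name, the nodeaddr (IP) observed while the backing
    # instance was missing, and how many clustermgtd iterations it has been missing.
    def __init__(self, name, ip, count):
        self.name = name
        self.ip = ip
        self.count = count


def updated_loop_count(node_name, nodeaddr, count_map):
    # A changed IP means a new instance was launched behind the node, so the
    # previous count no longer applies and the grace window starts over.
    missing = count_map.get(node_name)
    if missing is None or missing.ip != nodeaddr:
        return 0
    return missing.count


# A relaunch shows up as a new nodeaddr, so the count resets to zero.
counts = {"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-1", 2)}
print(updated_loop_count("queue1-st-c5xlarge-1", "ip-2", counts))  # -> 0
```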

Also use the `is_backing_instance_valid()` function in `is_state_healthy()` instead of the plain `node.instance`
object check, to allow for the delay in EC2's eventual consistency.
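
The `is_state_healthy()` hunk itself is collapsed out of this diff; the idea is roughly the following sketch (the signature and surrounding checks are assumptions, since the real method also inspects Slurm state):

```python
def is_state_healthy_sketch(node, config, count_map):
    # Before: a plain `if not node.instance:` check failed the node
    # immediately, racing against EC2's eventually consistent data.
    # After: is_backing_instance_valid() grants the node up to
    # ec2_instance_missing_max_count iterations of grace before failing it.
    return node.is_backing_instance_valid(
        config.ec2_instance_missing_max_count,
        count_map,
        log_warn_if_unhealthy=True,
    )
```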

* Refactor logic in `_is_node_in_replacement_valid()` to account for `node.instance` being `None` (#620); see the sketch after this list

* Add unit tests to cover `max_count > 0` in `_is_node_in_replacement_valid()`
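
For the `_is_node_in_replacement_valid()` refactor above, the new guard reduces to treating a missing `node.instance` as "timer not started" rather than an error. A condensed sketch, assuming `time_is_up` matches the plugin helper's semantics (`replacement_timer_expired` is a hypothetical standalone name for the inline logic):

```python
from datetime import datetime, timedelta
from types import SimpleNamespace


def time_is_up(initial_time, current_time, grace_time):
    # Assumed to mirror the plugin's helper: have grace_time seconds
    # elapsed between initial_time and current_time?
    return (current_time - initial_time) >= timedelta(seconds=grace_time)


def replacement_timer_expired(node, current_time, timeout):
    # Without an instance there is no launch_time to measure against,
    # so the replacement timeout cannot have expired yet.
    if node.instance is None:
        return False
    return time_is_up(node.instance.launch_time, current_time, grace_time=timeout)


# A node still waiting on its instance never reports an expired timer.
node = SimpleNamespace(instance=None)
print(replacement_timer_expired(node, datetime(2020, 1, 1), timeout=30))  # -> False
```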
dreambeyondorange authored Mar 1, 2024
1 parent 293efb7 commit 7bb99a9
Showing 4 changed files with 95 additions and 19 deletions.
22 changes: 18 additions & 4 deletions src/slurm_plugin/clustermgtd.py
@@ -1200,17 +1200,31 @@ def _find_active_nodes(partitions_name_map):
active_nodes += partition.slurm_nodes
return list(dict.fromkeys(active_nodes))

def _is_node_in_replacement_valid(self, node, check_node_is_valid):
def _is_node_in_replacement_valid(self, node: SlurmNode, check_node_is_valid):
"""
Check node is replacement timeout or in replacement.
If check_node_is_valid=True, check whether a node is in replacement,
If check_node_is_valid=False, check whether a node is replacement timeout.
"""
if node.instance and node.name in self._static_nodes_in_replacement:
time_is_expired = time_is_up(
node.instance.launch_time, self._current_time, grace_time=self._config.node_replacement_timeout
log.debug(f"Checking if node is in replacement {node}")
if (
node.is_backing_instance_valid(
self._config.ec2_instance_missing_max_count,
self._nodes_without_backing_instance_count_map,
log_warn_if_unhealthy=True,
)
and node.name in self._static_nodes_in_replacement
):
# Set `time_is_expired` to `False` if `node.instance` is `None` since we don't have a launch time yet
time_is_expired = (
False
if not node.instance
else time_is_up(
node.instance.launch_time, self._current_time, grace_time=self._config.node_replacement_timeout
)
)
log.debug(f"Node {node} is in replacement and timer expired? {time_is_expired}, instance? {node.instance}")
return not time_is_expired if check_node_is_valid else time_is_expired
return False

24 changes: 20 additions & 4 deletions src/slurm_plugin/slurm_resources.py
@@ -180,6 +180,17 @@ class SlurmReservation:
users: str


class MissingInstance:
name: str
ip: str
count: int

def __init__(self, name, ip, count):
self.name = name
self.ip = ip
self.count = count


class SlurmNode(metaclass=ABCMeta):
SLURM_SCONTROL_COMPLETING_STATE = "COMPLETING"
SLURM_SCONTROL_BUSY_STATES = {"MIXED", "ALLOCATED", SLURM_SCONTROL_COMPLETING_STATE}
@@ -427,7 +438,7 @@ def is_powering_down_with_nodeaddr(self):
def is_backing_instance_valid(
self,
ec2_instance_missing_max_count,
nodes_without_backing_instance_count_map: dict,
nodes_without_backing_instance_count_map: dict[str, MissingInstance],
log_warn_if_unhealthy=True,
):
"""Check if a slurm node's addr is set, it points to a valid instance in EC2."""
@@ -445,7 +456,11 @@
)
# Allow a few iterations for the eventual consistency of EC2 data
logger.debug(f"Map of slurm nodes without backing instances {nodes_without_backing_instance_count_map}")
missing_instance_loop_count = nodes_without_backing_instance_count_map.get(self.name, 0)
missing_instance = nodes_without_backing_instance_count_map.get(self.name, None)
missing_instance_loop_count = missing_instance.count if missing_instance else 0
if missing_instance and self.nodeaddr != missing_instance.ip:
# Reset the loop count since the nodeaddr has changed
missing_instance_loop_count = 0
# If the loop count has been reached, the instance is unhealthy and will be terminated
if missing_instance_loop_count >= ec2_instance_missing_max_count:
if log_warn_if_unhealthy:
@@ -454,11 +469,12 @@
nodes_without_backing_instance_count_map.pop(self.name, None)
self.ec2_backing_instance_valid = False
else:
nodes_without_backing_instance_count_map[self.name] = missing_instance_loop_count + 1
instance_to_add = MissingInstance(self.name, self.nodeaddr, missing_instance_loop_count + 1)
nodes_without_backing_instance_count_map[self.name] = instance_to_add
if log_warn_if_unhealthy:
logger.warning(
f"Incrementing missing EC2 instance count for node {self.name} to "
f"{nodes_without_backing_instance_count_map[self.name]}."
f"{nodes_without_backing_instance_count_map[self.name].count}."
)
else:
# Remove the slurm node from the map since the instance is healthy
24 changes: 18 additions & 6 deletions tests/slurm_plugin/slurm_resources/test_slurm_resources.py
@@ -18,6 +18,7 @@
DynamicNode,
EC2InstanceHealthState,
InvalidNodenameError,
MissingInstance,
SlurmPartition,
SlurmResumeJob,
StaticNode,
@@ -1185,26 +1186,34 @@ def test_slurm_node_is_powering_down_with_nodeaddr(node, expected_result):
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
None,
2,
{"queue1-st-c5xlarge-1": 1},
{"queue1-st-c5xlarge-1": 2},
{"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-1", 1)},
{"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-1", 2)},
True,
),
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
None,
2,
{"queue1-st-c5xlarge-1": 2},
{"queue1-st-c5xlarge-1": 2},
{"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-1", 2)},
{"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-1", 2)},
False,
),
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
"Instance",
2,
{"queue1-st-c5xlarge-1": 3},
{"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-1", 3)},
{},
True,
),
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD+POWER", "queue1"),
"Instance",
3,
{"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-2", 2)},
{"queue1-st-c5xlarge-1": MissingInstance("queue1-st-c5xlarge-1", "ip-1", 1)},
True,
),
],
ids=[
"static_no_backing_zero_max_count",
@@ -1214,6 +1223,7 @@ def test_slurm_node_is_powering_down_with_nodeaddr(node, expected_result):
"static_no_backing_count_not_exceeded",
"static_no_backing_with_count_exceeded",
"static_backed_with_count_exceeded",
"static_no_backing_count_not_exceeded_with_wrong_ip",
],
)
def test_slurm_node_is_backing_instance_valid(node, instance, max_count, count_map, final_map, expected_result):
@@ -1225,7 +1235,9 @@
).is_equal_to(expected_result)
assert_that(node.ec2_backing_instance_valid).is_equal_to(expected_result)
if count_map:
assert_that(count_map[node.name]).is_equal_to(final_map.get(node.name, None))
assert_that(count_map[node.name].count).is_equal_to(final_map.get(node.name, None).count)
assert_that(count_map[node.name].ip).is_equal_to(final_map.get(node.name, None).ip)
assert_that(count_map[node.name].ip).is_equal_to(node.nodeaddr)


@pytest.mark.parametrize(
44 changes: 39 additions & 5 deletions tests/slurm_plugin/test_clustermgtd.py
@@ -767,6 +767,7 @@ def test_handle_health_check(
region="region",
boto3_config=None,
fleet_config={},
ec2_instance_missing_max_count=0,
)

cluster_manager = ClusterManager(mock_sync_config)
@@ -831,6 +832,7 @@ def test_update_static_nodes_in_replacement(current_replacing_nodes, slurm_nodes
region="region",
boto3_config=None,
fleet_config={},
ec2_instance_missing_max_count=0,
)
cluster_manager = ClusterManager(mock_sync_config)
cluster_manager._static_nodes_in_replacement = current_replacing_nodes
@@ -2646,43 +2648,61 @@ def initialize_console_logger_mock(mocker):


@pytest.mark.parametrize(
"current_replacing_nodes, node, instance, current_time, expected_result",
"current_replacing_nodes, node, instance, current_time, max_count, expected_result",
[
(
set(),
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD", "queue1"),
EC2Instance("id-1", "ip-1", "hostname", {"ip-1"}, datetime(2020, 1, 1, 0, 0, 0)),
datetime(2020, 1, 1, 0, 0, 29),
0,
False,
),
(
{"queue1-st-c5xlarge-1"},
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD", "queue1"),
None,
datetime(2020, 1, 1, 0, 0, 29),
0,
False,
),
(
{"queue1-st-c5xlarge-1"},
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "DOWN+CLOUD", "queue1"),
EC2Instance("id-1", "ip-1", "hostname", {"ip-1"}, datetime(2020, 1, 1, 0, 0, 0)),
datetime(2020, 1, 1, 0, 0, 29),
0,
True,
),
(
{"queue1-st-c5xlarge-1"},
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD", "queue1"),
EC2Instance("id-1", "ip-1", "hostname", {"ip-1"}, datetime(2020, 1, 1, 0, 0, 0)),
datetime(2020, 1, 1, 0, 0, 30),
0,
False,
),
(
{"queue1-st-c5xlarge-1"},
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "IDLE+CLOUD", "queue1"),
None,
datetime(2020, 1, 1, 0, 0, 30),
1,
True,
),
],
ids=[
"not_in_replacement",
"no-backing-instance",
"in_replacement",
"timeout",
"no-backing-instance-with-max-count",
],
ids=["not_in_replacement", "no-backing-instance", "in_replacement", "timeout"],
)
@pytest.mark.usefixtures(
"initialize_instance_manager_mock", "initialize_executor_mock", "initialize_console_logger_mock"
)
def test_is_node_being_replaced(current_replacing_nodes, node, instance, current_time, expected_result):
def test_is_node_being_replaced(current_replacing_nodes, node, instance, current_time, max_count, expected_result):
mock_sync_config = SimpleNamespace(
node_replacement_timeout=30,
insufficient_capacity_timeout=3,
Expand All @@ -2691,6 +2711,7 @@ def test_is_node_being_replaced(current_replacing_nodes, node, instance, current
region="region",
boto3_config=None,
fleet_config={},
ec2_instance_missing_max_count=max_count,
)
cluster_manager = ClusterManager(mock_sync_config)
cluster_manager._current_time = current_time
@@ -2700,24 +2721,34 @@ def test_is_node_being_replaced(current_replacing_nodes, node, instance, current


@pytest.mark.parametrize(
"node, instance, current_node_in_replacement, is_replacement_timeout",
"node, instance, current_node_in_replacement, max_count, is_replacement_timeout",
[
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "DOWN+CLOUD+NOT_RESPONDING", "queue1"),
None,
{"queue1-st-c5xlarge-1"},
0,
False,
),
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "DOWN+CLOUD+NOT_RESPONDING", "queue1"),
None,
{"queue1-st-c5xlarge-1"},
1,
False,
),
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "DOWN+CLOUD+NOT_RESPONDING", "queue1"),
EC2Instance("id-1", "ip-1", "hostname", {"ip-1"}, datetime(2020, 1, 1, 0, 0, 0)),
{"queue1-st-c5xlarge-1"},
0,
True,
),
(
DynamicNode("queue1-dy-c5xlarge-1", "ip-1", "hostname", "MIXED+CLOUD+NOT_RESPONDING+POWERING_UP", "queue1"),
None,
{"some_node_in_replacement"},
0,
False,
),
(
@@ -2730,20 +2761,22 @@ def test_is_node_being_replaced(current_replacing_nodes, node, instance, current
),
EC2Instance("id-1", "ip-1", "hostname", {"ip-1"}, datetime(2020, 1, 1, 0, 0, 0)),
{"some_node_in_replacement"},
0,
False,
),
(
StaticNode("queue1-st-c5xlarge-1", "ip-1", "hostname", "DOWN+CLOUD+NOT_RESPONDING", "queue1"),
EC2Instance("id-1", "ip-1", "hostname", {"ip-1"}, datetime(2020, 1, 1, 0, 0, 0)),
{"some_node_in_replacement"},
0,
False,
),
],
)
@pytest.mark.usefixtures(
"initialize_instance_manager_mock", "initialize_executor_mock", "initialize_console_logger_mock"
)
def test_is_node_replacement_timeout(node, current_node_in_replacement, is_replacement_timeout, instance):
def test_is_node_replacement_timeout(node, current_node_in_replacement, max_count, is_replacement_timeout, instance):
node.instance = instance
mock_sync_config = SimpleNamespace(
node_replacement_timeout=30,
@@ -2753,6 +2786,7 @@ def test_is_node_replacement_timeout(node, current_node_in_replacement, is_repla
region="region",
boto3_config=None,
fleet_config={},
ec2_instance_missing_max_count=0,
)
cluster_manager = ClusterManager(mock_sync_config)
cluster_manager._current_time = datetime(2020, 1, 2, 0, 0, 0)
