Reduce complexity of InstanceManager
Reduce complexity of InstanceManager by collapsing JobLevelScalingInstanceManager and NodeListScalingInstanceManager into a single InstanceManager class.
InstanceManagerFactory is no longer needed and is removed.
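A minimal before/after sketch of the call sites (illustrative only; the config object and its attributes are placeholders mirroring the arguments visible in the diff):

    # Before: the factory picked a subclass based on the job_level_scaling flag.
    instance_manager = InstanceManagerFactory.get_manager(
        region=config.region,
        cluster_name=config.cluster_name,
        boto3_config=config.boto3_config,
        job_level_scaling=True,  # -> JobLevelScalingInstanceManager, otherwise NodeListScalingInstanceManager
    )

    # After: callers build InstanceManager directly; the flag is stored on the instance
    # and add_instances() dispatches internally between resume-file (job-level) scaling
    # and node-list scaling.
    instance_manager = InstanceManager(
        region=config.region,
        cluster_name=config.cluster_name,
        boto3_config=config.boto3_config,
        job_level_scaling=True,
    )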

Signed-off-by: Luca Carrogu <[email protected]>
lukeseawalker committed Feb 1, 2024
1 parent 4d3849f commit e416bac
Showing 4 changed files with 155 additions and 1,244 deletions.
7 changes: 4 additions & 3 deletions src/slurm_plugin/clustermgtd.py
@@ -42,9 +42,9 @@
from retrying import retry
from slurm_plugin.capacity_block_manager import CapacityBlockManager
from slurm_plugin.cluster_event_publisher import ClusterEventPublisher
from slurm_plugin.common import TIMESTAMP_FORMAT, log_exception, print_with_count
from slurm_plugin.common import TIMESTAMP_FORMAT, log_exception, print_with_count, ScalingStrategy
from slurm_plugin.console_logger import ConsoleLogger
from slurm_plugin.instance_manager import InstanceManagerFactory
from slurm_plugin.instance_manager import InstanceManager
from slurm_plugin.slurm_resources import (
CONFIG_FILE_DIR,
ComputeResourceFailureEvent,
@@ -424,7 +424,7 @@ def shutdown(self):
@staticmethod
def _initialize_instance_manager(config):
"""Initialize instance manager class that will be used to launch/terminate/describe instances."""
return InstanceManagerFactory.get_manager(
return InstanceManager(
config.region,
config.cluster_name,
config.boto3_config,
@@ -877,6 +877,7 @@ def _handle_unhealthy_static_nodes(self, unhealthy_static_nodes):
node_list=node_list,
launch_batch_size=self._config.launch_max_batch_size,
update_node_address=self._config.update_node_address,
scaling_strategy=ScalingStrategy.BEST_EFFORT
)
# Add launched nodes to list of nodes being replaced, excluding any nodes that failed to launch
failed_nodes = set().union(*self._instance_manager.failed_nodes.values())
308 changes: 33 additions & 275 deletions src/slurm_plugin/instance_manager.py
@@ -63,56 +63,7 @@ class NodeAddrUpdateError(Exception):
"""Raised when error occurs while updating NodeAddrs in Slurm node."""


class InstanceManagerFactory:
@staticmethod
def get_manager(
region: str,
cluster_name: str,
boto3_config: Config,
table_name: str = None,
hosted_zone: str = None,
dns_domain: str = None,
use_private_hostname: bool = False,
head_node_private_ip: str = None,
head_node_hostname: str = None,
fleet_config: Dict[str, any] = None,
run_instances_overrides: dict = None,
create_fleet_overrides: dict = None,
job_level_scaling: bool = False,
):
if job_level_scaling:
return JobLevelScalingInstanceManager(
region=region,
cluster_name=cluster_name,
boto3_config=boto3_config,
table_name=table_name,
hosted_zone=hosted_zone,
dns_domain=dns_domain,
use_private_hostname=use_private_hostname,
head_node_private_ip=head_node_private_ip,
head_node_hostname=head_node_hostname,
fleet_config=fleet_config,
run_instances_overrides=run_instances_overrides,
create_fleet_overrides=create_fleet_overrides,
)
else:
return NodeListScalingInstanceManager(
region=region,
cluster_name=cluster_name,
boto3_config=boto3_config,
table_name=table_name,
hosted_zone=hosted_zone,
dns_domain=dns_domain,
use_private_hostname=use_private_hostname,
head_node_private_ip=head_node_private_ip,
head_node_hostname=head_node_hostname,
fleet_config=fleet_config,
run_instances_overrides=run_instances_overrides,
create_fleet_overrides=create_fleet_overrides,
)


class InstanceManager(ABC):
class InstanceManager:
"""
InstanceManager class.
@@ -134,6 +85,7 @@ def __init__(
fleet_config: Dict[str, any] = None,
run_instances_overrides: dict = None,
create_fleet_overrides: dict = None,
job_level_scaling: bool = False,
):
"""Initialize InstanceLauncher with required attributes."""
self._region = region
@@ -154,23 +106,13 @@ def __init__(
resource_name, region_name=region, config=boto3_config
)
self.nodes_assigned_to_instances = {}
self.unused_launched_instances = {}
self.job_level_scaling = job_level_scaling

def _clear_failed_nodes(self):
"""Clear and reset failed nodes list."""
self.failed_nodes = {}

@abstractmethod
def add_instances(
self,
node_list: List[str],
launch_batch_size: int,
update_node_address: bool = True,
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
slurm_resume: Dict[str, any] = None,
assign_node_batch_size: int = None,
terminate_batch_size: int = None,
):
"""Add EC2 instances to Slurm nodes."""

@log_exception(
logger, "saving assigned hostnames in DynamoDB", raise_on_error=True, exception_to_raise=HostnameTableStoreError
@@ -472,39 +414,6 @@ def _create_request_for_nodes(table_name, node_names):
}
}


class JobLevelScalingInstanceManager(InstanceManager):
def __init__(
self,
region: str,
cluster_name: str,
boto3_config: Config,
table_name: str = None,
hosted_zone: str = None,
dns_domain: str = None,
use_private_hostname: bool = False,
head_node_private_ip: str = None,
head_node_hostname: str = None,
fleet_config: Dict[str, any] = None,
run_instances_overrides: dict = None,
create_fleet_overrides: dict = None,
):
super().__init__(
region=region,
cluster_name=cluster_name,
boto3_config=boto3_config,
table_name=table_name,
hosted_zone=hosted_zone,
dns_domain=dns_domain,
use_private_hostname=use_private_hostname,
head_node_private_ip=head_node_private_ip,
head_node_hostname=head_node_hostname,
fleet_config=fleet_config,
run_instances_overrides=run_instances_overrides,
create_fleet_overrides=create_fleet_overrides,
)
self.unused_launched_instances = {}

def _clear_unused_launched_instances(self):
"""Clear and reset unused launched instances list."""
self.unused_launched_instances = {}
@@ -539,21 +448,36 @@ def add_instances(
# Reset unused instances pool
self._clear_unused_launched_instances()

if slurm_resume:
logger.debug("Performing job level scaling using Slurm resume fle")
self._add_instances_for_resume_file(
slurm_resume=slurm_resume,
node_list=node_list,
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
scaling_strategy=scaling_strategy,
)
if self.job_level_scaling:
if slurm_resume:
logger.debug("Performing job level scaling using Slurm resume fle")
self._add_instances_for_resume_file(
slurm_resume=slurm_resume,
node_list=node_list,
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
scaling_strategy=scaling_strategy,
)
else:
logger.error(
"Not possible to perform job level scaling because Slurm resume file content is empty. "
"No scaling actions will be taken."
)
else:
logger.error(
"Not possible to perform job level scaling because Slurm resume file content is empty. "
"No scaling actions will be taken."
)
if node_list:
logger.debug("Performing node list scaling using Slurm node resume list")
self._add_instances_for_nodes(
node_list=node_list,
launch_batch_size=launch_batch_size,
update_node_address=update_node_address,
scaling_strategy=scaling_strategy,
)
else:
logger.error(
"Not possible to perform scaling because Slurm node resume list is empty. "
"No scaling actions will be taken."
)

self._terminate_unassigned_launched_instances(terminate_batch_size)

@@ -1158,169 +1082,3 @@ def _update_slurm_node_addrs(self, slurm_nodes: List[str], launched_instances: L
print_with_count(launched_instances),
)
raise NodeAddrUpdateError


class NodeListScalingInstanceManager(InstanceManager):
def __init__(
self,
region: str,
cluster_name: str,
boto3_config: Config,
table_name: str = None,
hosted_zone: str = None,
dns_domain: str = None,
use_private_hostname: bool = False,
head_node_private_ip: str = None,
head_node_hostname: str = None,
fleet_config: Dict[str, any] = None,
run_instances_overrides: dict = None,
create_fleet_overrides: dict = None,
):
super().__init__(
region=region,
cluster_name=cluster_name,
boto3_config=boto3_config,
table_name=table_name,
hosted_zone=hosted_zone,
dns_domain=dns_domain,
use_private_hostname=use_private_hostname,
head_node_private_ip=head_node_private_ip,
head_node_hostname=head_node_hostname,
fleet_config=fleet_config,
run_instances_overrides=run_instances_overrides,
create_fleet_overrides=create_fleet_overrides,
)

def add_instances(
self,
node_list: List[str],
launch_batch_size: int,
update_node_address: bool = True,
# Default to BEST_EFFORT since clustermgtd is not yet adapted for Job Level Scaling
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
slurm_resume: Dict[str, any] = None,
assign_node_batch_size: int = None,
terminate_batch_size: int = None,
):
"""Add EC2 instances to Slurm nodes."""
# Reset failed_nodes
self._clear_failed_nodes()

logger.debug("Node Scaling using Slurm Node Resume List")
self._add_instances_for_nodes(
node_list=node_list,
launch_batch_size=launch_batch_size,
update_node_address=update_node_address,
scaling_strategy=scaling_strategy,
)

def _add_instances_for_nodes(
self,
node_list: List[str],
launch_batch_size: int,
update_node_address: bool = True,
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
):
"""Launch requested EC2 instances for nodes."""
# At fleet management level, the scaling strategies can be grouped based on the actual
# launch behaviour i.e. all-or-nothing or best-effort
all_or_nothing_batch = scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING]

nodes_to_launch = self._parse_nodes_resume_list(node_list)
for queue, compute_resources in nodes_to_launch.items():
for compute_resource, slurm_node_list in compute_resources.items():
logger.info("Launching instances for Slurm nodes %s", print_with_count(slurm_node_list))

# each compute resource can be configured to use create_fleet or run_instances
fleet_manager = FleetManagerFactory.get_manager(
self._cluster_name,
self._region,
self._boto3_config,
self._fleet_config,
queue,
compute_resource,
all_or_nothing_batch,
self._run_instances_overrides,
self._create_fleet_overrides,
)
for batch_nodes in grouper(slurm_node_list, launch_batch_size):
try:
launched_instances = fleet_manager.launch_ec2_instances(len(batch_nodes))

if update_node_address:
assigned_nodes = self._update_slurm_node_addrs_and_failed_nodes(
list(batch_nodes), launched_instances
)
try:
self._store_assigned_hostnames(assigned_nodes)
self._update_dns_hostnames(assigned_nodes)
except (HostnameTableStoreError, HostnameDnsStoreError):
self._update_failed_nodes(set(assigned_nodes.keys()))
except ClientError as e:
logger.error(
"Encountered exception when launching instances for nodes %s: %s",
print_with_count(batch_nodes),
e,
)
error_code = e.response.get("Error", {}).get("Code")
self._update_failed_nodes(set(batch_nodes), error_code)
except Exception as e:
logger.error(
"Encountered exception when launching instances for nodes %s: %s",
print_with_count(batch_nodes),
e,
)
self._update_failed_nodes(set(batch_nodes))

def _update_slurm_node_addrs_and_failed_nodes(self, slurm_nodes: List[str], launched_instances: List[EC2Instance]):
"""Update node information in slurm with info from launched EC2 instance."""
try:
# There could be fewer launched instances than nodes requested to be launched if best-effort scaling
# Group nodes into successfully launched and failed to launch based on number of launched instances
# fmt: off
launched_nodes = slurm_nodes[:len(launched_instances)]
fail_launch_nodes = slurm_nodes[len(launched_instances):]
# fmt: on
if launched_nodes:
# When using a cluster DNS domain we don't need to pass nodehostnames
# because they are equal to node names.
# It is possible to force the use of private hostnames by setting
# use_private_hostname = "true" as extra json parameter
node_hostnames = (
None if not self._use_private_hostname else [instance.hostname for instance in launched_instances]
)
update_nodes(
launched_nodes,
nodeaddrs=[instance.private_ip for instance in launched_instances],
nodehostnames=node_hostnames,
)
logger.info(
"Nodes are now configured with instances %s",
print_with_count(zip(launched_nodes, launched_instances)),
)
if fail_launch_nodes:
if launched_nodes:
logger.warning(
"Failed to launch instances due to limited EC2 capacity for following nodes: %s",
print_with_count(fail_launch_nodes),
)
self._update_failed_nodes(set(fail_launch_nodes), "LimitedInstanceCapacity")
else:
# EC2 Fleet doesn't trigger any exception in case of ICEs and may return more than one error
# for each request. So when no instances were launched we force the reason to ICE
logger.error(
"Failed to launch instances due to limited EC2 capacity for following nodes: %s",
print_with_count(fail_launch_nodes),
)
self._update_failed_nodes(set(fail_launch_nodes), "InsufficientInstanceCapacity")

return dict(zip(launched_nodes, launched_instances))

except subprocess.CalledProcessError:
logger.error(
"Encountered error when updating nodes %s with instances %s",
print_with_count(slurm_nodes),
print_with_count(launched_instances),
)
self._update_failed_nodes(set(slurm_nodes))
return {}
4 changes: 2 additions & 2 deletions src/slurm_plugin/resume.py
@@ -22,7 +22,7 @@
from common.utils import read_json
from slurm_plugin.cluster_event_publisher import ClusterEventPublisher
from slurm_plugin.common import ScalingStrategy, is_clustermgtd_heartbeat_valid, print_with_count
from slurm_plugin.instance_manager import InstanceManagerFactory
from slurm_plugin.instance_manager import InstanceManager
from slurm_plugin.slurm_resources import CONFIG_FILE_DIR

log = logging.getLogger(__name__)
@@ -192,7 +192,7 @@ def _resume(arg_nodes, resume_config, slurm_resume):
node_list_with_status.append((node.name, node.state_string))
log.info("Current state of Slurm nodes to resume: %s", node_list_with_status)

instance_manager = InstanceManagerFactory.get_manager(
instance_manager = InstanceManager(
region=resume_config.region,
cluster_name=resume_config.cluster_name,
boto3_config=resume_config.boto3_config,
