Skip to content

Commit

Permalink
feat(service): use slurm sinfo command to improve "cluster load" in…
Browse files Browse the repository at this point in the history
…dicator (#1664)

Merge pull request #1664 from AntaresSimulatorTeam/fix/launcher_cannot_export_zipped_outputs (ANT-725)
  • Loading branch information
laurent-laporte-pro authored Feb 2, 2024
2 parents c6a09b0 + f380a44 commit ab7bc4b
Show file tree
Hide file tree
Showing 13 changed files with 399 additions and 122 deletions.
2 changes: 2 additions & 0 deletions antarest/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ class SlurmConfig:
default_time_limit: int = 0
default_json_db_name: str = ""
slurm_script_path: str = ""
partition: str = ""
max_cores: int = 64
antares_versions_on_remote_server: List[str] = field(default_factory=list)
enable_nb_cores_detection: bool = False
Expand Down Expand Up @@ -290,6 +291,7 @@ def from_dict(cls, data: JSON) -> "SlurmConfig":
default_time_limit=data.get("default_time_limit", defaults.default_time_limit),
default_json_db_name=data.get("default_json_db_name", defaults.default_json_db_name),
slurm_script_path=data.get("slurm_script_path", defaults.slurm_script_path),
partition=data.get("partition", defaults.partition),
antares_versions_on_remote_server=data.get(
"antares_versions_on_remote_server",
defaults.antares_versions_on_remote_server,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def _init_launcher_parameters(self, local_workspace: Optional[Path] = None) -> M
json_dir=local_workspace or self.slurm_config.local_workspace,
default_json_db_name=self.slurm_config.default_json_db_name,
slurm_script_path=self.slurm_config.slurm_script_path,
partition=self.slurm_config.partition,
antares_versions_on_remote_server=self.slurm_config.antares_versions_on_remote_server,
default_ssh_dict={
"username": self.slurm_config.username,
Expand Down
53 changes: 52 additions & 1 deletion antarest/launcher/model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import enum
import json
import typing as t
from datetime import datetime

from pydantic import BaseModel
from pydantic import BaseModel, Field
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, Sequence, String # type: ignore
from sqlalchemy.orm import relationship # type: ignore

from antarest.core.persistence import Base
from antarest.core.utils.string import to_camel_case
from antarest.login.model import Identity, UserInfo


Expand All @@ -32,6 +34,15 @@ class LauncherParametersDTO(BaseModel):
other_options: t.Optional[str] = None
# add extensions field here

@classmethod
def from_launcher_params(cls, params: t.Optional[str]) -> "LauncherParametersDTO":
"""
Convert the launcher parameters from a string to a `LauncherParametersDTO` object.
"""
if params is None:
return cls()
return cls.parse_obj(json.loads(params))


class LogType(str, enum.Enum):
STDOUT = "STDOUT"
Expand Down Expand Up @@ -214,3 +225,43 @@ class JobCreationDTO(BaseModel):

class LauncherEnginesDTO(BaseModel):
engines: t.List[str]


class LauncherLoadDTO(
BaseModel,
extra="forbid",
validate_assignment=True,
allow_population_by_field_name=True,
alias_generator=to_camel_case,
):
"""
DTO representing the load of the SLURM cluster or local machine.
Attributes:
allocated_cpu_rate: The rate of allocated CPU, in range (0, 100).
cluster_load_rate: The rate of cluster load, in range (0, 100).
nb_queued_jobs: The number of queued jobs.
launcher_status: The status of the launcher: "SUCCESS" or "FAILED".
"""

allocated_cpu_rate: float = Field(
description="The rate of allocated CPU, in range (0, 100)",
ge=0,
le=100,
title="Allocated CPU Rate",
)
cluster_load_rate: float = Field(
description="The rate of cluster load, in range (0, 100)",
ge=0,
le=100,
title="Cluster Load Rate",
)
nb_queued_jobs: int = Field(
description="The number of queued jobs",
ge=0,
title="Number of Queued Jobs",
)
launcher_status: str = Field(
description="The status of the launcher: 'SUCCESS' or 'FAILED'",
title="Launcher Status",
)
78 changes: 42 additions & 36 deletions antarest/launcher/service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import functools
import json
import logging
import os
import shutil
Expand Down Expand Up @@ -33,11 +32,14 @@
JobLogType,
JobResult,
JobStatus,
LauncherLoadDTO,
LauncherParametersDTO,
LogType,
XpansionParametersDTO,
)
from antarest.launcher.repository import JobResultRepository
from antarest.launcher.ssh_client import calculates_slurm_load
from antarest.launcher.ssh_config import SSHConfigDTO
from antarest.study.repository import StudyFilter
from antarest.study.service import StudyService
from antarest.study.storage.utils import assert_permission, extract_output_name, find_single_output_path
Expand Down Expand Up @@ -502,7 +504,7 @@ def _import_output(
launching_user = DEFAULT_ADMIN_USER

study_id = job_result.study_id
job_launch_params = LauncherParametersDTO.parse_raw(job_result.launcher_params or "{}")
job_launch_params = LauncherParametersDTO.from_launcher_params(job_result.launcher_params)

# this now can be a zip file instead of a directory !
output_true_path = find_single_output_path(output_path)
Expand Down Expand Up @@ -585,7 +587,7 @@ def _download_fallback_output(self, job_id: str, params: RequestParameters) -> F
export_path = Path(export_file_download.path)
export_id = export_file_download.id

def export_task(notifier: TaskUpdateNotifier) -> TaskResult:
def export_task(_: TaskUpdateNotifier) -> TaskResult:
try:
#
zip_dir(output_path, export_path)
Expand Down Expand Up @@ -622,43 +624,47 @@ def download_output(self, job_id: str, params: RequestParameters) -> FileDownloa
)
raise JobNotFound()

def get_load(self, from_cluster: bool = False) -> Dict[str, float]:
all_running_jobs = self.job_result_repository.get_running()
local_running_jobs = []
slurm_running_jobs = []
for job in all_running_jobs:
if job.launcher == "slurm":
slurm_running_jobs.append(job)
elif job.launcher == "local":
local_running_jobs.append(job)
def get_load(self) -> LauncherLoadDTO:
"""
Get the load of the SLURM cluster or the local machine.
"""
# SLURM load calculation
if self.config.launcher.default == "slurm":
if slurm_config := self.config.launcher.slurm:
ssh_config = SSHConfigDTO(
config_path=Path(),
username=slurm_config.username,
hostname=slurm_config.hostname,
port=slurm_config.port,
private_key_file=slurm_config.private_key_file,
key_password=slurm_config.key_password,
password=slurm_config.password,
)
partition = slurm_config.partition
allocated_cpus, cluster_load, queued_jobs = calculates_slurm_load(ssh_config, partition)
return LauncherLoadDTO(
allocated_cpu_rate=allocated_cpus,
cluster_load_rate=cluster_load,
nb_queued_jobs=queued_jobs,
launcher_status="SUCCESS",
)
else:
logger.warning(f"Unknown job launcher {job.launcher}")
raise KeyError("Default launcher is slurm but it is not registered in the config file")

load = {}
# local load calculation
local_used_cpus = sum(
LauncherParametersDTO.from_launcher_params(job.launcher_params).nb_cpu or 1
for job in self.job_result_repository.get_running()
)

slurm_config = self.config.launcher.slurm
if slurm_config is not None:
if from_cluster:
raise NotImplementedError("Cluster load not implemented yet")
default_cpu = slurm_config.nb_cores.default
slurm_used_cpus = 0
for job in slurm_running_jobs:
obj = json.loads(job.launcher_params) if job.launcher_params else {}
launch_params = LauncherParametersDTO(**obj)
slurm_used_cpus += launch_params.nb_cpu or default_cpu
load["slurm"] = slurm_used_cpus / slurm_config.max_cores
cluster_load_approx = min(1.0, local_used_cpus / (os.cpu_count() or 1))

local_config = self.config.launcher.local
if local_config is not None:
default_cpu = local_config.nb_cores.default
local_used_cpus = 0
for job in local_running_jobs:
obj = json.loads(job.launcher_params) if job.launcher_params else {}
launch_params = LauncherParametersDTO(**obj)
local_used_cpus += launch_params.nb_cpu or default_cpu
load["local"] = local_used_cpus / local_config.nb_cores.max

return load
return LauncherLoadDTO(
allocated_cpu_rate=cluster_load_approx,
cluster_load_rate=cluster_load_approx,
nb_queued_jobs=0,
launcher_status="SUCCESS",
)

def get_solver_versions(self, solver: str) -> List[str]:
"""
Expand Down
106 changes: 106 additions & 0 deletions antarest/launcher/ssh_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import contextlib
import socket
import shlex
from typing import Any, List, Tuple

import paramiko

from antarest.launcher.ssh_config import SSHConfigDTO


@contextlib.contextmanager # type: ignore
def ssh_client(ssh_config: SSHConfigDTO) -> paramiko.SSHClient: # type: ignore
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(
hostname=ssh_config.hostname,
port=ssh_config.port,
username=ssh_config.username,
pkey=paramiko.RSAKey.from_private_key_file(filename=str(ssh_config.private_key_file)),
timeout=600,
allow_agent=False,
)
with contextlib.closing(client):
yield client


class SlurmError(Exception):
pass


def execute_command(ssh_config: SSHConfigDTO, args: List[str]) -> Any:
command = " ".join(args)
try:
with ssh_client(ssh_config) as client: # type: ignore
stdin, stdout, stderr = client.exec_command(command, timeout=10)
output = stdout.read().decode("utf-8").strip()
error = stderr.read().decode("utf-8").strip()
except (
paramiko.AuthenticationException,
paramiko.SSHException,
socket.timeout,
socket.error,
) as e:
raise SlurmError(f"Can't retrieve SLURM information: {e}") from e
if error:
raise SlurmError(f"Can't retrieve SLURM information: {error}")
return output


def parse_cpu_used(sinfo_output: str) -> float:
"""
Returns the percentage of used CPUs in the cluster, in range [0, 100].
"""
cpu_info_split = sinfo_output.split("/")
cpu_used_count = int(cpu_info_split[0])
cpu_inactive_count = int(cpu_info_split[1])
return 100 * cpu_used_count / (cpu_used_count + cpu_inactive_count)


def parse_cpu_load(sinfo_output: str) -> float:
"""
Returns the percentage of CPU load in the cluster, in range [0, 100].
"""
lines = sinfo_output.splitlines()
cpus_used = 0.0
cpus_available = 0.0
for line in lines:
values = line.split()
if "N/A" in values:
continue
cpus_used += float(values[0])
cpus_available += float(values[1])
ratio = cpus_used / max(cpus_available, 1)
return 100 * min(1.0, ratio)


def calculates_slurm_load(ssh_config: SSHConfigDTO, partition: str) -> Tuple[float, float, int]:
"""
Returns the used/oad of the SLURM cluster or local machine in percentage and the number of queued jobs.
"""
partition_arg = f"--partition={partition}" if partition else ""

# allocated cpus
arg_list = ["sinfo", partition_arg, "-O", "NodeAIOT", "--noheader"]
sinfo_cpus_used = execute_command(ssh_config, arg_list)
if not sinfo_cpus_used:
args = " ".join(map(shlex.quote, arg_list))
raise SlurmError(f"Can't retrieve SLURM information: [{args}] returned no result")
allocated_cpus = parse_cpu_used(sinfo_cpus_used)

# cluster load
arg_list = ["sinfo", partition_arg, "-N", "-O", "CPUsLoad,CPUs", "--noheader"]
sinfo_cpus_load = execute_command(ssh_config, arg_list)
if not sinfo_cpus_load:
args = " ".join(map(shlex.quote, arg_list))
raise SlurmError(f"Can't retrieve SLURM information: [{args}] returned no result")
cluster_load = parse_cpu_load(sinfo_cpus_load)

# queued jobs
arg_list = ["squeue", partition_arg, "--noheader", "-t", "pending", "|", "wc", "-l"]
queued_jobs = execute_command(ssh_config, arg_list)
if not queued_jobs:
args = " ".join(map(shlex.quote, arg_list))
raise SlurmError(f"Can't retrieve SLURM information: [{args}] returned no result")

return allocated_cpus, cluster_load, int(queued_jobs)
21 changes: 21 additions & 0 deletions antarest/launcher/ssh_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pathlib
from typing import Any, Dict, Optional

import paramiko
from pydantic import BaseModel, root_validator


class SSHConfigDTO(BaseModel):
config_path: pathlib.Path
username: str
hostname: str
port: int = 22
private_key_file: Optional[pathlib.Path] = None
key_password: Optional[str] = ""
password: Optional[str] = ""

@root_validator()
def validate_connection_information(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "private_key_file" not in values and "password" not in values:
raise paramiko.AuthenticationException("SSH config needs at least a private key or a password")
return values
Loading

0 comments on commit ab7bc4b

Please sign in to comment.