diff --git a/antarest/core/config.py b/antarest/core/config.py
index 849209d826..2ba7a72745 100644
--- a/antarest/core/config.py
+++ b/antarest/core/config.py
@@ -254,6 +254,7 @@ class SlurmConfig:
     default_time_limit: int = 0
     default_json_db_name: str = ""
     slurm_script_path: str = ""
+    partition: str = ""
     max_cores: int = 64
     antares_versions_on_remote_server: List[str] = field(default_factory=list)
     enable_nb_cores_detection: bool = False
@@ -290,6 +291,7 @@ def from_dict(cls, data: JSON) -> "SlurmConfig":
             default_time_limit=data.get("default_time_limit", defaults.default_time_limit),
             default_json_db_name=data.get("default_json_db_name", defaults.default_json_db_name),
             slurm_script_path=data.get("slurm_script_path", defaults.slurm_script_path),
+            partition=data.get("partition", defaults.partition),
             antares_versions_on_remote_server=data.get(
                 "antares_versions_on_remote_server",
                 defaults.antares_versions_on_remote_server,
diff --git a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py
index e4412a344e..d2ce3ad1ec 100644
--- a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py
+++ b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py
@@ -182,6 +182,7 @@ def _init_launcher_parameters(self, local_workspace: Optional[Path] = None) -> M
             json_dir=local_workspace or self.slurm_config.local_workspace,
             default_json_db_name=self.slurm_config.default_json_db_name,
             slurm_script_path=self.slurm_config.slurm_script_path,
+            partition=self.slurm_config.partition,
             antares_versions_on_remote_server=self.slurm_config.antares_versions_on_remote_server,
             default_ssh_dict={
                 "username": self.slurm_config.username,
diff --git a/antarest/launcher/model.py b/antarest/launcher/model.py
index 6fb40a98df..a8f856269d 100644
--- a/antarest/launcher/model.py
+++ b/antarest/launcher/model.py
@@ -1,12 +1,14 @@
 import enum
+import json
 import typing as t
 from datetime import datetime

-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, Sequence, String  # type: ignore
 from sqlalchemy.orm import relationship  # type: ignore

 from antarest.core.persistence import Base
+from antarest.core.utils.string import to_camel_case
 from antarest.login.model import Identity, UserInfo


@@ -32,6 +34,15 @@ class LauncherParametersDTO(BaseModel):
     other_options: t.Optional[str] = None
     # add extensions field here

+    @classmethod
+    def from_launcher_params(cls, params: t.Optional[str]) -> "LauncherParametersDTO":
+        """
+        Convert the launcher parameters from a string to a `LauncherParametersDTO` object.
+        """
+        if params is None:
+            return cls()
+        return cls.parse_obj(json.loads(params))
+

 class LogType(str, enum.Enum):
     STDOUT = "STDOUT"
@@ -214,3 +225,43 @@ class JobCreationDTO(BaseModel):

 class LauncherEnginesDTO(BaseModel):
     engines: t.List[str]
+
+
+class LauncherLoadDTO(
+    BaseModel,
+    extra="forbid",
+    validate_assignment=True,
+    allow_population_by_field_name=True,
+    alias_generator=to_camel_case,
+):
+    """
+    DTO representing the load of the SLURM cluster or local machine.
+
+    Attributes:
+        allocated_cpu_rate: The rate of allocated CPU, in range [0, 100].
+        cluster_load_rate: The rate of cluster load, in range [0, 100].
+        nb_queued_jobs: The number of queued jobs.
+        launcher_status: The status of the launcher: "SUCCESS" or "FAILED".
+ """ + + allocated_cpu_rate: float = Field( + description="The rate of allocated CPU, in range (0, 100)", + ge=0, + le=100, + title="Allocated CPU Rate", + ) + cluster_load_rate: float = Field( + description="The rate of cluster load, in range (0, 100)", + ge=0, + le=100, + title="Cluster Load Rate", + ) + nb_queued_jobs: int = Field( + description="The number of queued jobs", + ge=0, + title="Number of Queued Jobs", + ) + launcher_status: str = Field( + description="The status of the launcher: 'SUCCESS' or 'FAILED'", + title="Launcher Status", + ) diff --git a/antarest/launcher/service.py b/antarest/launcher/service.py index 86b65ec9ce..4c4ea9aa15 100644 --- a/antarest/launcher/service.py +++ b/antarest/launcher/service.py @@ -1,5 +1,4 @@ import functools -import json import logging import os import shutil @@ -33,11 +32,14 @@ JobLogType, JobResult, JobStatus, + LauncherLoadDTO, LauncherParametersDTO, LogType, XpansionParametersDTO, ) from antarest.launcher.repository import JobResultRepository +from antarest.launcher.ssh_client import calculates_slurm_load +from antarest.launcher.ssh_config import SSHConfigDTO from antarest.study.repository import StudyFilter from antarest.study.service import StudyService from antarest.study.storage.utils import assert_permission, extract_output_name, find_single_output_path @@ -502,7 +504,7 @@ def _import_output( launching_user = DEFAULT_ADMIN_USER study_id = job_result.study_id - job_launch_params = LauncherParametersDTO.parse_raw(job_result.launcher_params or "{}") + job_launch_params = LauncherParametersDTO.from_launcher_params(job_result.launcher_params) # this now can be a zip file instead of a directory ! output_true_path = find_single_output_path(output_path) @@ -585,7 +587,7 @@ def _download_fallback_output(self, job_id: str, params: RequestParameters) -> F export_path = Path(export_file_download.path) export_id = export_file_download.id - def export_task(notifier: TaskUpdateNotifier) -> TaskResult: + def export_task(_: TaskUpdateNotifier) -> TaskResult: try: # zip_dir(output_path, export_path) @@ -622,43 +624,47 @@ def download_output(self, job_id: str, params: RequestParameters) -> FileDownloa ) raise JobNotFound() - def get_load(self, from_cluster: bool = False) -> Dict[str, float]: - all_running_jobs = self.job_result_repository.get_running() - local_running_jobs = [] - slurm_running_jobs = [] - for job in all_running_jobs: - if job.launcher == "slurm": - slurm_running_jobs.append(job) - elif job.launcher == "local": - local_running_jobs.append(job) + def get_load(self) -> LauncherLoadDTO: + """ + Get the load of the SLURM cluster or the local machine. 
+ """ + # SLURM load calculation + if self.config.launcher.default == "slurm": + if slurm_config := self.config.launcher.slurm: + ssh_config = SSHConfigDTO( + config_path=Path(), + username=slurm_config.username, + hostname=slurm_config.hostname, + port=slurm_config.port, + private_key_file=slurm_config.private_key_file, + key_password=slurm_config.key_password, + password=slurm_config.password, + ) + partition = slurm_config.partition + allocated_cpus, cluster_load, queued_jobs = calculates_slurm_load(ssh_config, partition) + return LauncherLoadDTO( + allocated_cpu_rate=allocated_cpus, + cluster_load_rate=cluster_load, + nb_queued_jobs=queued_jobs, + launcher_status="SUCCESS", + ) else: - logger.warning(f"Unknown job launcher {job.launcher}") + raise KeyError("Default launcher is slurm but it is not registered in the config file") - load = {} + # local load calculation + local_used_cpus = sum( + LauncherParametersDTO.from_launcher_params(job.launcher_params).nb_cpu or 1 + for job in self.job_result_repository.get_running() + ) - slurm_config = self.config.launcher.slurm - if slurm_config is not None: - if from_cluster: - raise NotImplementedError("Cluster load not implemented yet") - default_cpu = slurm_config.nb_cores.default - slurm_used_cpus = 0 - for job in slurm_running_jobs: - obj = json.loads(job.launcher_params) if job.launcher_params else {} - launch_params = LauncherParametersDTO(**obj) - slurm_used_cpus += launch_params.nb_cpu or default_cpu - load["slurm"] = slurm_used_cpus / slurm_config.max_cores + cluster_load_approx = min(1.0, local_used_cpus / (os.cpu_count() or 1)) - local_config = self.config.launcher.local - if local_config is not None: - default_cpu = local_config.nb_cores.default - local_used_cpus = 0 - for job in local_running_jobs: - obj = json.loads(job.launcher_params) if job.launcher_params else {} - launch_params = LauncherParametersDTO(**obj) - local_used_cpus += launch_params.nb_cpu or default_cpu - load["local"] = local_used_cpus / local_config.nb_cores.max - - return load + return LauncherLoadDTO( + allocated_cpu_rate=cluster_load_approx, + cluster_load_rate=cluster_load_approx, + nb_queued_jobs=0, + launcher_status="SUCCESS", + ) def get_solver_versions(self, solver: str) -> List[str]: """ diff --git a/antarest/launcher/ssh_client.py b/antarest/launcher/ssh_client.py new file mode 100644 index 0000000000..acace09c9b --- /dev/null +++ b/antarest/launcher/ssh_client.py @@ -0,0 +1,106 @@ +import contextlib +import socket +import shlex +from typing import Any, List, Tuple + +import paramiko + +from antarest.launcher.ssh_config import SSHConfigDTO + + +@contextlib.contextmanager # type: ignore +def ssh_client(ssh_config: SSHConfigDTO) -> paramiko.SSHClient: # type: ignore + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect( + hostname=ssh_config.hostname, + port=ssh_config.port, + username=ssh_config.username, + pkey=paramiko.RSAKey.from_private_key_file(filename=str(ssh_config.private_key_file)), + timeout=600, + allow_agent=False, + ) + with contextlib.closing(client): + yield client + + +class SlurmError(Exception): + pass + + +def execute_command(ssh_config: SSHConfigDTO, args: List[str]) -> Any: + command = " ".join(args) + try: + with ssh_client(ssh_config) as client: # type: ignore + stdin, stdout, stderr = client.exec_command(command, timeout=10) + output = stdout.read().decode("utf-8").strip() + error = stderr.read().decode("utf-8").strip() + except ( + paramiko.AuthenticationException, + 
+        paramiko.SSHException,
+        socket.timeout,
+        socket.error,
+    ) as e:
+        raise SlurmError(f"Can't retrieve SLURM information: {e}") from e
+    if error:
+        raise SlurmError(f"Can't retrieve SLURM information: {error}")
+    return output
+
+
+def parse_cpu_used(sinfo_output: str) -> float:
+    """
+    Returns the percentage of used CPUs in the cluster, in range [0, 100].
+    """
+    cpu_info_split = sinfo_output.split("/")
+    cpu_used_count = int(cpu_info_split[0])
+    cpu_inactive_count = int(cpu_info_split[1])
+    return 100 * cpu_used_count / (cpu_used_count + cpu_inactive_count)
+
+
+def parse_cpu_load(sinfo_output: str) -> float:
+    """
+    Returns the percentage of CPU load in the cluster, in range [0, 100].
+    """
+    lines = sinfo_output.splitlines()
+    cpus_used = 0.0
+    cpus_available = 0.0
+    for line in lines:
+        values = line.split()
+        if "N/A" in values:
+            continue
+        cpus_used += float(values[0])
+        cpus_available += float(values[1])
+    ratio = cpus_used / max(cpus_available, 1)
+    return 100 * min(1.0, ratio)
+
+
+def calculates_slurm_load(ssh_config: SSHConfigDTO, partition: str) -> Tuple[float, float, int]:
+    """
+    Returns the allocated CPU rate and the load of the SLURM cluster (in percent),
+    together with the number of queued jobs.
+    """
+    partition_arg = f"--partition={partition}" if partition else ""
+
+    # allocated cpus
+    arg_list = ["sinfo", partition_arg, "-O", "NodeAIOT", "--noheader"]
+    sinfo_cpus_used = execute_command(ssh_config, arg_list)
+    if not sinfo_cpus_used:
+        args = " ".join(map(shlex.quote, arg_list))
+        raise SlurmError(f"Can't retrieve SLURM information: [{args}] returned no result")
+    allocated_cpus = parse_cpu_used(sinfo_cpus_used)
+
+    # cluster load
+    arg_list = ["sinfo", partition_arg, "-N", "-O", "CPUsLoad,CPUs", "--noheader"]
+    sinfo_cpus_load = execute_command(ssh_config, arg_list)
+    if not sinfo_cpus_load:
+        args = " ".join(map(shlex.quote, arg_list))
+        raise SlurmError(f"Can't retrieve SLURM information: [{args}] returned no result")
+    cluster_load = parse_cpu_load(sinfo_cpus_load)
+
+    # queued jobs
+    arg_list = ["squeue", partition_arg, "--noheader", "-t", "pending", "|", "wc", "-l"]
+    queued_jobs = execute_command(ssh_config, arg_list)
+    if not queued_jobs:
+        args = " ".join(map(shlex.quote, arg_list))
+        raise SlurmError(f"Can't retrieve SLURM information: [{args}] returned no result")
+
+    return allocated_cpus, cluster_load, int(queued_jobs)
diff --git a/antarest/launcher/ssh_config.py b/antarest/launcher/ssh_config.py
new file mode 100644
index 0000000000..1fa4a4393c
--- /dev/null
+++ b/antarest/launcher/ssh_config.py
@@ -0,0 +1,21 @@
+import pathlib
+from typing import Any, Dict, Optional
+
+import paramiko
+from pydantic import BaseModel, root_validator
+
+
+class SSHConfigDTO(BaseModel):
+    config_path: pathlib.Path
+    username: str
+    hostname: str
+    port: int = 22
+    private_key_file: Optional[pathlib.Path] = None
+    key_password: Optional[str] = ""
+    password: Optional[str] = ""
+
+    @root_validator()
+    def validate_connection_information(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if values.get("private_key_file") is None and not values.get("password"):
+            raise paramiko.AuthenticationException("SSH config needs at least a private key or a password")
+        return values
diff --git a/antarest/launcher/web.py b/antarest/launcher/web.py
index a4a1c45ba6..051eba2cc4 100644
--- a/antarest/launcher/web.py
+++ b/antarest/launcher/web.py
@@ -11,8 +11,16 @@
 from antarest.core.jwt import JWTUser
 from antarest.core.requests import RequestParameters
 from antarest.core.utils.web import APITag
-from antarest.launcher.model import JobCreationDTO, JobResultDTO, LauncherEnginesDTO, LauncherParametersDTO, LogType
+from antarest.launcher.model import (
+    JobCreationDTO,
+    JobResultDTO,
+    LauncherEnginesDTO,
+    LauncherLoadDTO,
+    LauncherParametersDTO,
+    LogType,
+)
 from antarest.launcher.service import LauncherService
+from antarest.launcher.ssh_client import SlurmError
 from antarest.login.auth import Auth

 logger = logging.getLogger(__name__)
@@ -34,12 +42,12 @@ def __init__(self, solver: str) -> None:


 def create_launcher_api(service: LauncherService, config: Config) -> APIRouter:
-    bp = APIRouter(prefix="/v1")
+    bp = APIRouter(prefix="/v1/launcher")
     auth = Auth(config)

     @bp.post(
-        "/launcher/run/{study_id}",
+        "/run/{study_id}",
         tags=[APITag.launcher],
         summary="Run study",
         response_model=JobCreationDTO,
@@ -69,7 +77,7 @@ def run(
     )

     @bp.get(
-        "/launcher/jobs",
+        "/jobs",
         tags=[APITag.launcher],
         summary="Retrieve jobs",
         response_model=List[JobResultDTO],
@@ -88,7 +96,7 @@ def get_job(
         return [job.to_dto() for job in service.get_jobs(study, params, filter_orphans, latest)]

     @bp.get(
-        "/launcher/jobs/{job_id}/logs",
+        "/jobs/{job_id}/logs",
         tags=[APITag.launcher],
         summary="Retrieve job logs from job id",
     )
@@ -102,7 +110,7 @@ def get_job_log(
         return service.get_log(job_id, log_type, params)

     @bp.get(
-        "/launcher/jobs/{job_id}/output",
+        "/jobs/{job_id}/output",
         tags=[APITag.launcher],
         summary="Export job output",
         response_model=FileDownloadTaskDTO,
@@ -119,7 +127,7 @@ def export_job_output(
         return service.download_output(job_id, params)

     @bp.post(
-        "/launcher/jobs/{job_id}/kill",
+        "/jobs/{job_id}/kill",
         tags=[APITag.launcher],
         summary="Kill job",
     )
@@ -136,7 +144,7 @@ def kill_job(
         ).to_dto()

     @bp.get(
-        "/launcher/jobs/{job_id}",
+        "/jobs/{job_id}",
         tags=[APITag.launcher],
         summary="Retrieve job info from job id",
         response_model=JobResultDTO,
@@ -147,7 +155,7 @@ def get_result(job_id: UUID, current_user: JWTUser = Depends(auth.get_current_us
         return service.get_result(job_id, params).to_dto()

     @bp.get(
-        "/launcher/jobs/{job_id}/progress",
+        "/jobs/{job_id}/progress",
         tags=[APITag.launcher],
         summary="Retrieve job progress from job id",
         response_model=int,
@@ -161,7 +169,7 @@ def get_progress(job_id: str, current_user: JWTUser = Depends(auth.get_current_u
         return int(service.get_launch_progress(job_id, params))

     @bp.delete(
-        "/launcher/jobs/{job_id}",
+        "/jobs/{job_id}",
         tags=[APITag.launcher],
         summary="Remove job",
         responses={204: {"description": "Job removed"}},
@@ -172,7 +180,7 @@ def remove_result(job_id: str, current_user: JWTUser = Depends(auth.get_current_
         service.remove_job(job_id, params)

     @bp.get(
-        "/launcher/engines",
+        "/engines",
         tags=[APITag.launcher],
         summary="Retrieve available engines",
         response_model=LauncherEnginesDTO,
@@ -182,20 +190,26 @@ def get_engines() -> Any:
         return LauncherEnginesDTO(engines=service.get_launchers())

     @bp.get(
-        "/launcher/load",
+        "/load",
         tags=[APITag.launcher],
-        summary="Get the cluster load in usage percent",
+        summary="Get the SLURM cluster or local machine load",
+        response_model=LauncherLoadDTO,
     )
-    def get_load(
-        from_cluster: bool = False,
-        current_user: JWTUser = Depends(auth.get_current_user),
-    ) -> Dict[str, float]:
-        params = RequestParameters(user=current_user)
+    def get_load() -> LauncherLoadDTO:
         logger.info("Fetching launcher load")
-        return service.get_load(from_cluster)
+        try:
+            return service.get_load()
+        except SlurmError as e:
+            logger.warning(e, exc_info=e)
+            return LauncherLoadDTO(
+                allocated_cpu_rate=0.0,
+                cluster_load_rate=0.0,
+                nb_queued_jobs=0,
+                launcher_status=f"FAILED: {e}",
+            )

     @bp.get(
-        "/launcher/versions",
+        "/versions",
         tags=[APITag.launcher],
         summary="Get list of supported solver versions",
         response_model=List[str],
@@ -232,7 +246,7 @@ def get_solver_versions(

     # noinspection SpellCheckingInspection
     @bp.get(
-        "/launcher/nbcores",  # We avoid "nb_cores" and "nb-cores" in endpoints
+        "/nbcores",  # We avoid "nb_cores" and "nb-cores" in endpoints
         tags=[APITag.launcher],
         summary="Retrieving Min, Default, and Max Core Count",
         response_model=Dict[str, int],
diff --git a/docs/install/1-CONFIG.md b/docs/install/1-CONFIG.md
index 8af3e22615..9ad48dad03 100644
--- a/docs/install/1-CONFIG.md
+++ b/docs/install/1-CONFIG.md
@@ -505,6 +505,15 @@ port: 22
 - If SLURM is connected to `dev-server-name` (*recette* and *integration*), use this path:
   `/applis/antares/launchAntaresRec.sh`

+### **partition**
+
+- **Type:** String
+- **Default value:** ""
+- **Description:** SLURM partition name. The partition refers to a logical division of the computing resources
+  available on a cluster managed by SLURM.
+  - If not specified, the default behavior is to allow the SLURM controller
+    to select the default partition as designated by the system administrator.
+
 ### **antares_versions_on_remote_server**

 - **Type:** List of String
@@ -530,6 +539,7 @@ launcher:
     default_n_cpu: 20
     default_json_db_name: launcher_db.json
     slurm_script_path: /applis/antares/launchAntares.sh
+    partition: calin1
     db_primary_key: name
     antares_versions_on_remote_server:
       - '610'
diff --git a/requirements.txt b/requirements.txt
index 70ded33e8c..4e12840d32 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,6 +14,7 @@ jsonref~=0.2
 MarkupSafe~=2.0.1
 numpy~=1.22.1
 pandas~=1.4.0
+paramiko~=2.12.0
 plyer~=2.0.0
 psycopg2-binary==2.9.4
 py7zr~=0.20.6
diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py
index 1ef69cc2dc..03e3cbafe6 100644
--- a/tests/integration/test_integration.py
+++ b/tests/integration/test_integration.py
@@ -1,4 +1,5 @@
 import io
+import os
 from http import HTTPStatus
 from pathlib import Path
 from unittest.mock import ANY
@@ -6,6 +7,7 @@
 from starlette.testclient import TestClient

 from antarest.core.model import PublicMode
+from antarest.launcher.model import LauncherLoadDTO
 from antarest.study.business.adequacy_patch_management import PriceTakingOrder
 from antarest.study.business.area_management import AreaType, LayerInfoDTO
 from antarest.study.business.areas.properties_management import AdequacyPatchMode
@@ -296,6 +298,15 @@ def test_main(client: TestClient, admin_access_token: str, study_id: str) -> Non
         headers={"Authorization": f'Bearer {fred_credentials["access_token"]}'},
     )
     job_id = res.json()["job_id"]
+
+    res = client.get("/v1/launcher/load", headers=admin_headers)
+    assert res.status_code == 200, res.json()
+    launcher_load = LauncherLoadDTO.parse_obj(res.json())
+    assert launcher_load.allocated_cpu_rate == 1 / (os.cpu_count() or 1)
+    assert launcher_load.cluster_load_rate == 1 / (os.cpu_count() or 1)
+    assert launcher_load.nb_queued_jobs == 0
+    assert launcher_load.launcher_status == "SUCCESS"
+
     res = client.get(
         f"/v1/launcher/jobs?study_id={study_id}",
         headers={"Authorization": f'Bearer {fred_credentials["access_token"]}'},
diff --git a/tests/launcher/test_service.py b/tests/launcher/test_service.py
index a8e20283c8..72f95a782f 100644
--- a/tests/launcher/test_service.py
+++ b/tests/launcher/test_service.py
@@ -1,9 +1,10 @@
 import json
+import math
 import os
 import time
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Dict, List, Union
+from typing import Any, Dict, List, Union
 from unittest.mock import Mock, call, patch
 from uuid import uuid4
 from zipfile import ZIP_DEFLATED, ZipFile
@@ -30,7 +31,15 @@
 from antarest.core.requests import RequestParameters, UserHasNotPermissionError
 from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware
 from antarest.dbmodel import Base
-from antarest.launcher.model import JobLog, JobLogType, JobResult, JobStatus, LauncherParametersDTO, LogType
+from antarest.launcher.model import (
+    JobLog,
+    JobLogType,
+    JobResult,
+    JobStatus,
+    LauncherLoadDTO,
+    LauncherParametersDTO,
+    LogType,
+)
 from antarest.launcher.service import (
     EXECUTION_INFO_FILE,
     LAUNCHER_PARAM_NAME_SUFFIX,
@@ -900,16 +909,73 @@ def test_save_solver_stats(self, tmp_path: Path) -> None:
         )
         assert actual_obj.to_dto().dict() == expected_obj.to_dto().dict()

-    def test_get_load(self, tmp_path: Path) -> None:
+    @pytest.mark.parametrize(
+        ["running_jobs", "expected_result", "default_launcher"],
+        [
+            pytest.param(
+                [],
+                {
+                    "allocated_cpu_rate": 0.0,
+                    "cluster_load_rate": 0.0,
+                    "nb_queued_jobs": 0,
+                    "launcher_status": "SUCCESS",
+                },
+                "local",
+                id="local_no_running_job",
+            ),
+            pytest.param(
+                [
+                    Mock(
+                        spec=JobResult,
+                        launcher="local",
+                        launcher_params=None,
+                    ),
+                    Mock(
+                        spec=JobResult,
+                        launcher="local",
+                        launcher_params='{"nb_cpu": 7}',
+                    ),
+                ],
+                {
+                    "allocated_cpu_rate": min(1.0, 8.0 / (os.cpu_count() or 1)),
+                    "cluster_load_rate": min(1.0, 8.0 / (os.cpu_count() or 1)),
+                    "nb_queued_jobs": 0,
+                    "launcher_status": "SUCCESS",
+                },
+                "local",
+                id="local_with_running_jobs",
+            ),
+            pytest.param(
+                [],
+                {
+                    "allocated_cpu_rate": 0.0,
+                    "cluster_load_rate": 0.0,
+                    "nb_queued_jobs": 0,
+                    "launcher_status": "SUCCESS",
+                },
+                "slurm",
+                id="slurm launcher with no config",
+                marks=pytest.mark.xfail(
+                    reason="Default launcher is slurm but it is not registered in the config file",
+                    raises=KeyError,
+                    strict=True,
+                ),
+            ),
+        ],
+    )
+    def test_get_load(
+        self,
+        tmp_path: Path,
+        running_jobs: List[JobResult],
+        expected_result: Dict[str, Any],
+        default_launcher: str,
+    ) -> None:
         study_service = Mock()
         job_repository = Mock()

         config = Config(
             storage=StorageConfig(tmp_dir=tmp_path),
-            launcher=LauncherConfig(
-                local=LocalConfig(),
-                slurm=SlurmConfig(nb_cores=NbCoresConfig(min=1, default=12, max=24)),
-            ),
+            launcher=LauncherConfig(default=default_launcher, local=LocalConfig(), slurm=None),
         )
         launcher_service = LauncherService(
             config=config,
@@ -922,61 +988,18 @@ def test_get_load(self, tmp_path: Path) -> None:
             cache=Mock(),
         )

-        job_repository.get_running.side_effect = [
-            # call #1
-            [],
-            # call #2
-            [],
-            # call #3
-            [
-                Mock(
-                    spec=JobResult,
-                    launcher="slurm",
-                    launcher_params=None,
-                ),
-            ],
-            # call #4
-            [
-                Mock(
-                    spec=JobResult,
-                    launcher="slurm",
-                    launcher_params='{"nb_cpu": 18}',
-                ),
-                Mock(
-                    spec=JobResult,
-                    launcher="local",
-                    launcher_params=None,
-                ),
-                Mock(
-                    spec=JobResult,
-                    launcher="slurm",
-                    launcher_params=None,
-                ),
-                Mock(
-                    spec=JobResult,
-                    launcher="local",
-                    launcher_params='{"nb_cpu": 7}',
-                ),
-            ],
-        ]
+        job_repository.get_running.return_value = running_jobs
+
+        launcher_expected_result = LauncherLoadDTO.parse_obj(expected_result)
+        actual_result = launcher_service.get_load()

-        # call #1
-        with pytest.raises(NotImplementedError):
-            launcher_service.get_load(from_cluster=True)
-
-        # call #2
-        load = launcher_service.get_load()
-        assert load["slurm"] == 0
-        assert load["local"] == 0
-
-        # call #3
-        load = launcher_service.get_load()
-        slurm_config = config.launcher.slurm
-        assert load["slurm"] == slurm_config.nb_cores.default / slurm_config.max_cores
-        assert load["local"] == 0
-
-        # call #4
-        load = launcher_service.get_load()
-        local_config = config.launcher.local
-        assert load["slurm"] == (18 + slurm_config.nb_cores.default) / slurm_config.max_cores
-        assert load["local"] == (7 + local_config.nb_cores.default) / local_config.nb_cores.max
+        assert launcher_expected_result.launcher_status == actual_result.launcher_status
+        assert launcher_expected_result.nb_queued_jobs == actual_result.nb_queued_jobs
+        assert math.isclose(
+            launcher_expected_result.cluster_load_rate,
+            actual_result.cluster_load_rate,
+        )
+        assert math.isclose(
+            launcher_expected_result.allocated_cpu_rate,
+            actual_result.allocated_cpu_rate,
+        )
diff --git a/tests/launcher/test_slurm_launcher.py b/tests/launcher/test_slurm_launcher.py
index dfb6846e89..a3a8cfe90b 100644
--- a/tests/launcher/test_slurm_launcher.py
+++ b/tests/launcher/test_slurm_launcher.py
@@ -40,6 +40,7 @@ def launcher_config(tmp_path: Path) -> Config:
         "default_time_limit": 20,
         "default_json_db_name": "antares.db",
         "slurm_script_path": "/path/to/slurm/launcher.sh",
+        "partition": "fake_partition",
         "max_cores": 32,
         "antares_versions_on_remote_server": ["840", "850", "860"],
         "enable_nb_cores_detection": False,
@@ -103,6 +104,7 @@ def test_init_slurm_launcher_parameters(tmp_path: Path) -> None:
         local_workspace=tmp_path,
         default_json_db_name="default_json_db_name",
         slurm_script_path="slurm_script_path",
+        partition="fake_partition",
         antares_versions_on_remote_server=["42"],
         username="username",
         hostname="hostname",
@@ -122,6 +124,7 @@ def test_init_slurm_launcher_parameters(tmp_path: Path) -> None:
     assert main_parameters.json_dir == slurm_config.local_workspace
     assert main_parameters.default_json_db_name == slurm_config.default_json_db_name
     assert main_parameters.slurm_script_path == slurm_config.slurm_script_path
+    assert main_parameters.partition == config.launcher.slurm.partition
     assert main_parameters.antares_versions_on_remote_server == slurm_config.antares_versions_on_remote_server
     assert main_parameters.default_ssh_dict == {
         "username": slurm_config.username,
@@ -474,6 +477,7 @@ def test_kill_job(
         json_dir=Path(tmp_path),
         default_json_db_name=slurm_config.default_json_db_name,
         slurm_script_path=slurm_config.slurm_script_path,
+        partition="fake_partition",
         antares_versions_on_remote_server=slurm_config.antares_versions_on_remote_server,
         default_ssh_dict={
             "username": slurm_config.username,
diff --git a/tests/launcher/test_ssh_client.py b/tests/launcher/test_ssh_client.py
new file mode 100644
index 0000000000..0ced719cde
--- /dev/null
+++ b/tests/launcher/test_ssh_client.py
@@ -0,0 +1,27 @@
+import math
+from unittest.mock import Mock
+
+import pytest
+
+from antarest.launcher.ssh_client import SlurmError, calculates_slurm_load, parse_cpu_load, parse_cpu_used
+
+
+@pytest.mark.unit_test
+def test_parse_cpu_used() -> None:
+    assert parse_cpu_used("3/28/1/32") == 100 * 3 / (3 + 28)
+
+
+@pytest.mark.unit_test
+def test_parse_cpu_load() -> None:
+    sinfo_output = "0.01 24 \n0.01 24 \nN/A 24 \n9.94 24 "
+    assert math.isclose(
+        parse_cpu_load(sinfo_output),
+        100 * (0.01 + 0.01 + 9.94) / (24 + 24 + 24),
+    )
+
+
+@pytest.mark.unit_test
+def test_calculates_slurm_load_without_private_key_fails() -> None:
+    ssh_config = Mock()
+    with pytest.raises(SlurmError):
+        calculates_slurm_load(ssh_config, "fake_partition")