diff --git a/antarest/core/config.py b/antarest/core/config.py index 3a3e882df8..f258db429a 100644 --- a/antarest/core/config.py +++ b/antarest/core/config.py @@ -1,19 +1,14 @@ -import logging import multiprocessing import tempfile -from dataclasses import dataclass, field -from http import HTTPStatus +from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional import yaml -from fastapi import HTTPException from antarest.core.model import JSON from antarest.core.roles import RoleType -logger = logging.getLogger(__name__) - @dataclass(frozen=True) class ExternalAuthConfig: @@ -26,13 +21,16 @@ class ExternalAuthConfig: add_ext_groups: bool = False group_mapping: Dict[str, str] = field(default_factory=dict) - @staticmethod - def from_dict(data: JSON) -> "ExternalAuthConfig": - return ExternalAuthConfig( - url=data.get("url", None), - default_group_role=RoleType(data.get("default_group_role", RoleType.READER.value)), - add_ext_groups=data.get("add_ext_groups", False), - group_mapping=data.get("group_mapping", {}), + @classmethod + def from_dict(cls, data: JSON) -> "ExternalAuthConfig": + defaults = cls() + return cls( + url=data.get("url", defaults.url), + default_group_role=( + RoleType(data["default_group_role"]) if "default_group_role" in data else defaults.default_group_role + ), + add_ext_groups=data.get("add_ext_groups", defaults.add_ext_groups), + group_mapping=data.get("group_mapping", defaults.group_mapping), ) @@ -47,13 +45,18 @@ class SecurityConfig: disabled: bool = False external_auth: ExternalAuthConfig = ExternalAuthConfig() - @staticmethod - def from_dict(data: JSON) -> "SecurityConfig": - return SecurityConfig( - jwt_key=data.get("jwt", {}).get("key", ""), - admin_pwd=data.get("login", {}).get("admin", {}).get("pwd", ""), - disabled=data.get("disabled", False), - external_auth=ExternalAuthConfig.from_dict(data.get("external_auth", {})), + @classmethod + def from_dict(cls, data: JSON) -> "SecurityConfig": + defaults = cls() + return cls( + jwt_key=data.get("jwt", {}).get("key", defaults.jwt_key), + admin_pwd=data.get("login", {}).get("admin", {}).get("pwd", defaults.admin_pwd), + disabled=data.get("disabled", defaults.disabled), + external_auth=( + ExternalAuthConfig.from_dict(data["external_auth"]) + if "external_auth" in data + else defaults.external_auth + ), ) @@ -68,13 +71,14 @@ class WorkspaceConfig: groups: List[str] = field(default_factory=lambda: []) path: Path = Path() - @staticmethod - def from_dict(data: JSON) -> "WorkspaceConfig": - return WorkspaceConfig( - path=Path(data["path"]), - groups=data.get("groups", []), - filter_in=data.get("filter_in", [".*"]), - filter_out=data.get("filter_out", []), + @classmethod + def from_dict(cls, data: JSON) -> "WorkspaceConfig": + defaults = cls() + return cls( + filter_in=data.get("filter_in", defaults.filter_in), + filter_out=data.get("filter_out", defaults.filter_out), + groups=data.get("groups", defaults.groups), + path=Path(data["path"]) if "path" in data else defaults.path, ) @@ -94,18 +98,19 @@ class DbConfig: pool_size: int = 5 pool_use_lifo: bool = False - @staticmethod - def from_dict(data: JSON) -> "DbConfig": - return DbConfig( - db_admin_url=data.get("admin_url", None), - db_url=data.get("url", ""), - db_connect_timeout=data.get("db_connect_timeout", 10), - pool_recycle=data.get("pool_recycle", None), - pool_pre_ping=data.get("pool_pre_ping", False), - pool_use_null=data.get("pool_use_null", False), - pool_max_overflow=data.get("pool_max_overflow", 10), - pool_size=data.get("pool_size", 5), - pool_use_lifo=data.get("pool_use_lifo", False), + @classmethod + def from_dict(cls, data: JSON) -> "DbConfig": + defaults = cls() + return cls( + db_admin_url=data.get("admin_url", defaults.db_admin_url), + db_url=data.get("url", defaults.db_url), + db_connect_timeout=data.get("db_connect_timeout", defaults.db_connect_timeout), + pool_recycle=data.get("pool_recycle", defaults.pool_recycle), + pool_pre_ping=data.get("pool_pre_ping", defaults.pool_pre_ping), + pool_use_null=data.get("pool_use_null", defaults.pool_use_null), + pool_max_overflow=data.get("pool_max_overflow", defaults.pool_max_overflow), + pool_size=data.get("pool_size", defaults.pool_size), + pool_use_lifo=data.get("pool_use_lifo", defaults.pool_use_lifo), ) @@ -118,7 +123,7 @@ class StorageConfig: matrixstore: Path = Path("./matrixstore") archive_dir: Path = Path("./archives") tmp_dir: Path = Path(tempfile.gettempdir()) - workspaces: Dict[str, WorkspaceConfig] = field(default_factory=lambda: {}) + workspaces: Dict[str, WorkspaceConfig] = field(default_factory=dict) allow_deletion: bool = False watcher_lock: bool = True watcher_lock_delay: int = 10 @@ -130,36 +135,35 @@ class StorageConfig: auto_archive_sleeping_time: int = 3600 auto_archive_max_parallel: int = 5 - @staticmethod - def from_dict(data: JSON) -> "StorageConfig": - return StorageConfig( - tmp_dir=Path(data.get("tmp_dir", tempfile.gettempdir())), - matrixstore=Path(data["matrixstore"]), - workspaces={n: WorkspaceConfig.from_dict(w) for n, w in data["workspaces"].items()}, - allow_deletion=data.get("allow_deletion", False), - archive_dir=Path(data["archive_dir"]), - watcher_lock=data.get("watcher_lock", True), - watcher_lock_delay=data.get("watcher_lock_delay", 10), - download_default_expiration_timeout_minutes=data.get("download_default_expiration_timeout_minutes", 1440), - matrix_gc_sleeping_time=data.get("matrix_gc_sleeping_time", 3600), - matrix_gc_dry_run=data.get("matrix_gc_dry_run", False), - auto_archive_threshold_days=data.get("auto_archive_threshold_days", 60), - auto_archive_dry_run=data.get("auto_archive_dry_run", False), - auto_archive_sleeping_time=data.get("auto_archive_sleeping_time", 3600), - auto_archive_max_parallel=data.get("auto_archive_max_parallel", 5), + @classmethod + def from_dict(cls, data: JSON) -> "StorageConfig": + defaults = cls() + workspaces = ( + {key: WorkspaceConfig.from_dict(value) for key, value in data["workspaces"].items()} + if "workspaces" in data + else defaults.workspaces + ) + return cls( + matrixstore=Path(data["matrixstore"]) if "matrixstore" in data else defaults.matrixstore, + archive_dir=Path(data["archive_dir"]) if "archive_dir" in data else defaults.archive_dir, + tmp_dir=Path(data["tmp_dir"]) if "tmp_dir" in data else defaults.tmp_dir, + workspaces=workspaces, + allow_deletion=data.get("allow_deletion", defaults.allow_deletion), + watcher_lock=data.get("watcher_lock", defaults.watcher_lock), + watcher_lock_delay=data.get("watcher_lock_delay", defaults.watcher_lock_delay), + download_default_expiration_timeout_minutes=( + data.get( + "download_default_expiration_timeout_minutes", + defaults.download_default_expiration_timeout_minutes, + ) + ), + matrix_gc_sleeping_time=data.get("matrix_gc_sleeping_time", defaults.matrix_gc_sleeping_time), + matrix_gc_dry_run=data.get("matrix_gc_dry_run", defaults.matrix_gc_dry_run), + auto_archive_threshold_days=data.get("auto_archive_threshold_days", defaults.auto_archive_threshold_days), + auto_archive_dry_run=data.get("auto_archive_dry_run", defaults.auto_archive_dry_run), + auto_archive_sleeping_time=data.get("auto_archive_sleeping_time", defaults.auto_archive_sleeping_time), + auto_archive_max_parallel=data.get("auto_archive_max_parallel", defaults.auto_archive_max_parallel), ) - - -class InvalidConfigurationError(Exception): - """ - Check if configuration launcher is available - """ - - def __init__(self, launcher) -> None: - msg = f""" - The configuration: {launcher} is not available - """ - super().__init__(msg) @dataclass(frozen=True) @@ -172,43 +176,23 @@ class NbCoresConfig: default: int = 22 max: int = 24 - @classmethod - def from_dict(cls, data: JSON) -> "NbCoresConfig": - """ - Creates an instance of NBCoresConfig from a data dictionary - Args: - data: Parse config from dict. - Returns: object NbCoresConfig - """ - return cls(min=data["min"], max=data["max"], default=data["defaultValue"]) - def to_json(self) -> Dict[str, int]: """ Retrieves the number of cores parameters, returning a dictionary containing the values "min" (minimum allowed value), "defaultValue" (default value), and "max" (maximum allowed value) - Returns: Dict of core config + + Returns: + A dictionary: `{"min": min, "defaultValue": default, "max": max}`. + Because ReactJs Material UI expects "min", "defaultValue" and "max" keys. """ return {"min": self.min, "defaultValue": self.default, "max": self.max} def __post_init__(self) -> None: - """validation of cpu configuration""" - self.__validate_nb_cores(self.min, self.default, self.max) - - def __validate_nb_cores(self, min_cpu: int, default: int, max_cpu: int) -> None: - """ - Validates the number of cores parameters, raising an exception if they are - invalid (i.e., if 1 ≤ min ≤ default ≤ max is false) - Args: - min_cpu: min cpu - default: default cpu - max_cpu: max cpu - """ - msg = "" - if not (1 <= min_cpu <= default <= max_cpu): - msg = f"value min_cpu:{min_cpu} must be equal to 1" - msg = f"{msg} {default} must be less than max_cpu:{max_cpu} or greater than 1" - if msg: - raise ValueError(msg) + """validation of CPU configuration""" + if 1 <= self.min <= self.default <= self.max: + return + msg = f"Invalid configuration: 1 <= {self.min=} <= {self.default=} <= {self.max=}" + raise ValueError(msg) @dataclass(frozen=True) @@ -216,26 +200,27 @@ class LocalConfig: """Sub config object dedicated to launcher module (local)""" binaries: Dict[str, Path] = field(default_factory=dict) - enable_nb_core_detection: bool = False + enable_nb_cores_detection: bool = True nb_cores: NbCoresConfig = NbCoresConfig() @classmethod - def from_dict(cls, data: JSON) -> Optional["LocalConfig"]: + def from_dict(cls, data: JSON) -> "LocalConfig": """ - Creates an instance of NBCoresConfig from a data dictionary + Creates an instance of LocalConfig from a data dictionary Args: data: Parse config from dict. Returns: object NbCoresConfig """ - if data.get("enable_nb_cores_detection", False): - cpu = cls._autodetect_nb_cores() - nb_cores = NbCoresConfig(min=cpu["min"], default=cpu["default"], max=cpu["max"]) - else: - nb_cores = NbCoresConfig() + defaults = cls() + binaries = data.get("binaries", defaults.binaries) + enable_nb_cores_detection = data.get("enable_nb_cores_detection", defaults.enable_nb_cores_detection) + nb_cores = data.get("nb_cores", asdict(defaults.nb_cores)) + if enable_nb_cores_detection: + nb_cores.update(cls._autodetect_nb_cores()) return cls( - binaries={str(v): Path(p) for v, p in data["binaries"].items()}, - enable_nb_core_detection=data["enable_nb_cores_detection"], - nb_cores=nb_cores, + binaries={str(v): Path(p) for v, p in binaries.items()}, + enable_nb_cores_detection=enable_nb_cores_detection, + nb_cores=NbCoresConfig(**nb_cores), ) @classmethod @@ -265,45 +250,67 @@ class SlurmConfig: password: str = "" default_wait_time: int = 0 default_time_limit: int = 0 - default_n_cpu: int = 1 default_json_db_name: str = "" slurm_script_path: str = "" max_cores: int = 64 antares_versions_on_remote_server: List[str] = field(default_factory=list) - enable_nb_core_detection: bool = False + enable_nb_cores_detection: bool = False nb_cores: NbCoresConfig = NbCoresConfig() @classmethod def from_dict(cls, data: JSON) -> "SlurmConfig": """ Creates an instance of SlurmConfig from a data dictionary + Args: - data: Parse config from dict. + data: Parsed config from dict. Returns: object SlurmConfig """ - nb_cores = NbCoresConfig() + defaults = cls() + enable_nb_cores_detection = data.get("enable_nb_cores_detection", defaults.enable_nb_cores_detection) + nb_cores = data.get("nb_cores", asdict(defaults.nb_cores)) + if "default_n_cpu" in data: + # Use the old way to configure the NB cores for backward compatibility + nb_cores["default"] = int(data["default_n_cpu"]) + nb_cores["min"] = min(nb_cores["min"], nb_cores["default"]) + nb_cores["max"] = min(nb_cores["max"], nb_cores["default"]) + if enable_nb_cores_detection: + nb_cores.update(cls._autodetect_nb_cores()) return cls( - local_workspace=Path(data["local_workspace"]), - username=data["username"], - hostname=data["hostname"], - port=data["port"], - private_key_file=data["private_key_file"], - key_password=data["key_password"], - password=data["password"], - default_wait_time=data["default_wait_time"], - default_time_limit=data["default_time_limit"], - default_n_cpu=data["default_n_cpu"], - default_json_db_name=data["default_json_db_name"], - slurm_script_path=data["slurm_script_path"], - antares_versions_on_remote_server=data["antares_versions_on_remote_server"], - max_cores=data.get("max_cores", 64), - nb_cores=nb_cores, - enable_nb_core_detection=data.get("enable_nb_cores_detection", False), + local_workspace=Path(data.get("local_workspace", defaults.local_workspace)), + username=data.get("username", defaults.username), + hostname=data.get("hostname", defaults.hostname), + port=data.get("port", defaults.port), + private_key_file=data.get("private_key_file", defaults.private_key_file), + key_password=data.get("key_password", defaults.key_password), + password=data.get("password", defaults.password), + default_wait_time=data.get("default_wait_time", defaults.default_wait_time), + default_time_limit=data.get("default_time_limit", defaults.default_time_limit), + default_json_db_name=data.get("default_json_db_name", defaults.default_json_db_name), + slurm_script_path=data.get("slurm_script_path", defaults.slurm_script_path), + antares_versions_on_remote_server=data.get( + "antares_versions_on_remote_server", + defaults.antares_versions_on_remote_server, + ), + max_cores=data.get("max_cores", defaults.max_cores), + enable_nb_cores_detection=enable_nb_cores_detection, + nb_cores=NbCoresConfig(**nb_cores), ) - @staticmethod - def _autodetect_nb_cores() -> Dict[str, int]: - raise NotImplementedError() + @classmethod + def _autodetect_nb_cores(cls) -> Dict[str, int]: + raise NotImplementedError("NB Cores auto-detection is not implemented for SLURM server") + + +class InvalidConfigurationError(Exception): + """ + Exception raised when an attempt is made to retrieve the number of cores + of a launcher that doesn't exist in the configuration. + """ + + def __init__(self, launcher: str): + msg = f"Configuration is not available for the '{launcher}' launcher" + super().__init__(msg) @dataclass(frozen=True) @@ -313,49 +320,52 @@ class LauncherConfig: """ default: str = "local" - local: Optional[LocalConfig] = LocalConfig() - slurm: Optional[SlurmConfig] = SlurmConfig() + local: Optional[LocalConfig] = None + slurm: Optional[SlurmConfig] = None batch_size: int = 9999 @classmethod def from_dict(cls, data: JSON) -> "LauncherConfig": - local: Optional[LocalConfig] = None - if "local" in data: - local = LocalConfig.from_dict(data["local"]) - - slurm: Optional[SlurmConfig] = None - if "slurm" in data: - slurm = SlurmConfig.from_dict(data["slurm"]) - + defaults = cls() + default = data.get("default", cls.default) + local = LocalConfig.from_dict(data["local"]) if "local" in data else defaults.local + slurm = SlurmConfig.from_dict(data["slurm"]) if "slurm" in data else defaults.slurm + batch_size = data.get("batch_size", defaults.batch_size) return cls( - default=data.get("default", "local"), + default=default, local=local, slurm=slurm, - batch_size=data.get("batch_size", 9999), + batch_size=batch_size, ) + def __post_init__(self) -> None: + possible = {"local", "slurm"} + if self.default in possible: + return + msg = f"Invalid configuration: {self.default=} must be one of {possible!r}" + raise ValueError(msg) + def get_nb_cores(self, launcher: str) -> "NbCoresConfig": """ - This method retrieves the number of cores configuration for a given - launcher: "local," "slurm," or "default." + Retrieve the number of cores configuration for a given launcher: "local" or "slurm". + If "default" is specified, retrieve the configuration of the default launcher. + Args: - launcher: type of launcher local or slurm or default - Returns: min, max, default of cpu configuration + launcher: type of launcher "local", "slurm" or "default". + + Returns: + Number of cores of the given launcher. + + Raises: + InvalidConfigurationError: Exception raised when an attempt is made to retrieve + the number of cores of a launcher that doesn't exist in the configuration. """ - here = Path(__file__).parent.resolve() - project_path = next(iter(p for p in here.parents if p.joinpath("antarest").exists())) - file = project_path / "resources/application.yaml" - info_data = Config.from_yaml_file(file).launcher - - if launcher == "default": - launcher = info_data.default - if launcher == "slurm": - cpu = LauncherConfig.slurm.nb_cores - if launcher == "local": - cpu = info_data.local.nb_cores - elif launcher not in ("slurm", "local"): - raise InvalidConfigurationError("launcher") - return cpu + config_map = {"local": self.local, "slurm": self.slurm} + config_map["default"] = config_map[self.default] + launcher_config = config_map.get(launcher) + if launcher_config is None: + raise InvalidConfigurationError(launcher) + return launcher_config.nb_cores @dataclass(frozen=True) @@ -368,14 +378,13 @@ class LoggingConfig: json: bool = False level: str = "INFO" - @staticmethod - def from_dict(data: JSON) -> "LoggingConfig": - logging_config: Dict[str, Any] = data or {} - logfile: Optional[str] = logging_config.get("logfile") - return LoggingConfig( - logfile=Path(logfile) if logfile is not None else None, - json=logging_config.get("json", False), - level=logging_config.get("level", "INFO"), + @classmethod + def from_dict(cls, data: JSON) -> "LoggingConfig": + defaults = cls() + return cls( + logfile=Path(data["logfile"]) if "logfile" in data else defaults.logfile, + json=data.get("json", defaults.json), + level=data.get("level", defaults.level), ) @@ -389,12 +398,13 @@ class RedisConfig: port: int = 6379 password: Optional[str] = None - @staticmethod - def from_dict(data: JSON) -> "RedisConfig": - return RedisConfig( - host=data["host"], - port=data["port"], - password=data.get("password", None), + @classmethod + def from_dict(cls, data: JSON) -> "RedisConfig": + defaults = cls() + return cls( + host=data.get("host", defaults.host), + port=data.get("port", defaults.port), + password=data.get("password", defaults.password), ) @@ -405,9 +415,9 @@ class EventBusConfig: """ # noinspection PyUnusedLocal - @staticmethod - def from_dict(data: JSON) -> "EventBusConfig": - return EventBusConfig() + @classmethod + def from_dict(cls, data: JSON) -> "EventBusConfig": + return cls() @dataclass(frozen=True) @@ -418,10 +428,11 @@ class CacheConfig: checker_delay: float = 0.2 # in seconds - @staticmethod - def from_dict(data: JSON) -> "CacheConfig": - return CacheConfig( - checker_delay=float(data["checker_delay"]) if "checker_delay" in data else 0.2, + @classmethod + def from_dict(cls, data: JSON) -> "CacheConfig": + defaults = cls() + return cls( + checker_delay=data.get("checker_delay", defaults.checker_delay), ) @@ -430,9 +441,13 @@ class RemoteWorkerConfig: name: str queues: List[str] = field(default_factory=list) - @staticmethod - def from_dict(data: JSON) -> "RemoteWorkerConfig": - return RemoteWorkerConfig(name=data["name"], queues=data.get("queues", [])) + @classmethod + def from_dict(cls, data: JSON) -> "RemoteWorkerConfig": + defaults = cls(name="") # `name` is mandatory + return cls( + name=data["name"], + queues=data.get("queues", defaults.queues), + ) @dataclass(frozen=True) @@ -444,16 +459,17 @@ class TaskConfig: max_workers: int = 5 remote_workers: List[RemoteWorkerConfig] = field(default_factory=list) - @staticmethod - def from_dict(data: JSON) -> "TaskConfig": - return TaskConfig( - max_workers=int(data["max_workers"]) if "max_workers" in data else 5, - remote_workers=list( - map( - lambda x: RemoteWorkerConfig.from_dict(x), - data.get("remote_workers", []), - ) - ), + @classmethod + def from_dict(cls, data: JSON) -> "TaskConfig": + defaults = cls() + remote_workers = ( + [RemoteWorkerConfig.from_dict(d) for d in data["remote_workers"]] + if "remote_workers" in data + else defaults.remote_workers + ) + return cls( + max_workers=data.get("max_workers", defaults.max_workers), + remote_workers=remote_workers, ) @@ -466,11 +482,12 @@ class ServerConfig: worker_threadpool_size: int = 5 services: List[str] = field(default_factory=list) - @staticmethod - def from_dict(data: JSON) -> "ServerConfig": - return ServerConfig( - worker_threadpool_size=int(data["worker_threadpool_size"]) if "worker_threadpool_size" in data else 5, - services=data.get("services", []), + @classmethod + def from_dict(cls, data: JSON) -> "ServerConfig": + defaults = cls() + return cls( + worker_threadpool_size=data.get("worker_threadpool_size", defaults.worker_threadpool_size), + services=data.get("services", defaults.services), ) @@ -494,36 +511,27 @@ class Config: tasks: TaskConfig = TaskConfig() root_path: str = "" - @staticmethod - def from_dict(data: JSON, res: Optional[Path] = None) -> "Config": - """ - Parse config from dict. - - Args: - data: dict struct to parse - res: resources path is not present in yaml file. - - Returns: - - """ - return Config( - security=SecurityConfig.from_dict(data.get("security", {})), - storage=StorageConfig.from_dict(data["storage"]), - launcher=LauncherConfig.from_dict(data.get("launcher", {})), - db=DbConfig.from_dict(data["db"]) if "db" in data else DbConfig(), - logging=LoggingConfig.from_dict(data.get("logging", {})), - debug=data.get("debug", False), - resources_path=res or Path(), - root_path=data.get("root_path", ""), - redis=RedisConfig.from_dict(data["redis"]) if "redis" in data else None, - eventbus=EventBusConfig.from_dict(data["eventbus"]) if "eventbus" in data else EventBusConfig(), - cache=CacheConfig.from_dict(data["cache"]) if "cache" in data else CacheConfig(), - tasks=TaskConfig.from_dict(data["tasks"]) if "tasks" in data else TaskConfig(), - server=ServerConfig.from_dict(data["server"]) if "server" in data else ServerConfig(), + @classmethod + def from_dict(cls, data: JSON) -> "Config": + defaults = cls() + return cls( + server=ServerConfig.from_dict(data["server"]) if "server" in data else defaults.server, + security=SecurityConfig.from_dict(data["security"]) if "security" in data else defaults.security, + storage=StorageConfig.from_dict(data["storage"]) if "storage" in data else defaults.storage, + launcher=LauncherConfig.from_dict(data["launcher"]) if "launcher" in data else defaults.launcher, + db=DbConfig.from_dict(data["db"]) if "db" in data else defaults.db, + logging=LoggingConfig.from_dict(data["logging"]) if "logging" in data else defaults.logging, + debug=data.get("debug", defaults.debug), + resources_path=data["resources_path"] if "resources_path" in data else defaults.resources_path, + redis=RedisConfig.from_dict(data["redis"]) if "redis" in data else defaults.redis, + eventbus=EventBusConfig.from_dict(data["eventbus"]) if "eventbus" in data else defaults.eventbus, + cache=CacheConfig.from_dict(data["cache"]) if "cache" in data else defaults.cache, + tasks=TaskConfig.from_dict(data["tasks"]) if "tasks" in data else defaults.tasks, + root_path=data.get("root_path", defaults.root_path), ) - @staticmethod - def from_yaml_file(file: Path, res: Optional[Path] = None) -> "Config": + @classmethod + def from_yaml_file(cls, file: Path, res: Optional[Path] = None) -> "Config": """ Parse config from yaml file. @@ -534,5 +542,8 @@ def from_yaml_file(file: Path, res: Optional[Path] = None) -> "Config": Returns: """ - data = yaml.safe_load(open(file)) - return Config.from_dict(data, res) + with open(file) as f: + data = yaml.safe_load(f) + if res is not None: + data["resources_path"] = res + return cls.from_dict(data) diff --git a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py index 926f2b50a3..00283b9ce8 100644 --- a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py +++ b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py @@ -32,7 +32,6 @@ logger = logging.getLogger(__name__) logging.getLogger("paramiko").setLevel("WARN") -MAX_NB_CPU = 24 MAX_TIME_LIMIT = 864000 MIN_TIME_LIMIT = 3600 WORKSPACE_LOCK_FILE_NAME = ".lock" @@ -153,7 +152,7 @@ def _init_launcher_arguments(self, local_workspace: Optional[Path] = None) -> ar main_options_parameters = ParserParameters( default_wait_time=self.slurm_config.default_wait_time, default_time_limit=self.slurm_config.default_time_limit, - default_n_cpu=self.slurm_config.default_n_cpu, + default_n_cpu=self.slurm_config.nb_cores.default, studies_in_dir=str((Path(local_workspace or self.slurm_config.local_workspace) / STUDIES_INPUT_DIR_NAME)), log_dir=str((Path(self.slurm_config.local_workspace) / LOG_DIR_NAME)), finished_dir=str((Path(local_workspace or self.slurm_config.local_workspace) / STUDIES_OUTPUT_DIR_NAME)), @@ -440,7 +439,7 @@ def _run_study( _override_solver_version(study_path, version) append_log(launch_uuid, "Submitting study to slurm launcher") - launcher_args = self._check_and_apply_launcher_params(launcher_params) + launcher_args = self._apply_params(launcher_params) self._call_launcher(launcher_args, self.launcher_params) launch_success = self._check_if_study_is_in_launcher_db(launch_uuid) @@ -481,23 +480,40 @@ def _check_if_study_is_in_launcher_db(self, job_id: str) -> bool: studies = self.data_repo_tinydb.get_list_of_studies() return any(s.name == job_id for s in studies) - def _check_and_apply_launcher_params(self, launcher_params: LauncherParametersDTO) -> argparse.Namespace: + def _apply_params(self, launcher_params: LauncherParametersDTO) -> argparse.Namespace: + """ + Populate a `argparse.Namespace` object with the user parameters. + + Args: + launcher_params: + Contains the launcher parameters selected by the user. + If a parameter is not provided (`None`), the default value should be retrieved + from the configuration. + + Returns: + The `argparse.Namespace` object which is then passed to `antarestlauncher.main.run_with`, + to launch a simulation using Antares Launcher. + """ if launcher_params: launcher_args = deepcopy(self.launcher_args) - other_options = [] + if launcher_params.other_options: - options = re.split("\\s+", launcher_params.other_options) - for opt in options: - other_options.append(re.sub("[^a-zA-Z0-9_,-]", "", opt)) - if launcher_params.xpansion is not None: - launcher_args.xpansion_mode = "r" if launcher_params.xpansion_r_version else "cpp" + options = launcher_params.other_options.split() + other_options = [re.sub("[^a-zA-Z0-9_,-]", "", opt) for opt in options] + else: + other_options = [] + + # launcher_params.xpansion can be an `XpansionParametersDTO`, a bool or `None` + if launcher_params.xpansion: # not None and not False + launcher_args.xpansion_mode = {True: "r", False: "cpp"}[launcher_params.xpansion_r_version] if ( isinstance(launcher_params.xpansion, XpansionParametersDTO) and launcher_params.xpansion.sensitivity_mode ): other_options.append("xpansion_sensitivity") + time_limit = launcher_params.time_limit - if time_limit and isinstance(time_limit, int): + if time_limit is not None: if MIN_TIME_LIMIT > time_limit: logger.warning( f"Invalid slurm launcher time limit ({time_limit})," @@ -512,15 +528,23 @@ def _check_and_apply_launcher_params(self, launcher_params: LauncherParametersDT launcher_args.time_limit = MAX_TIME_LIMIT - 3600 else: launcher_args.time_limit = time_limit + post_processing = launcher_params.post_processing - if isinstance(post_processing, bool): + if post_processing is not None: launcher_args.post_processing = post_processing + nb_cpu = launcher_params.nb_cpu - if nb_cpu and isinstance(nb_cpu, int): - if 0 < nb_cpu <= MAX_NB_CPU: + if nb_cpu is not None: + nb_cores = self.slurm_config.nb_cores + if nb_cores.min <= nb_cpu <= nb_cores.max: launcher_args.n_cpu = nb_cpu else: - logger.warning(f"Invalid slurm launcher nb_cpu ({nb_cpu}), should be between 1 and 24") + logger.warning( + f"Invalid slurm launcher nb_cpu ({nb_cpu})," + f" should be between {nb_cores.min} and {nb_cores.max}" + ) + launcher_args.n_cpu = nb_cores.default + if launcher_params.adequacy_patch is not None: # the adequacy patch can be an empty object launcher_args.post_processing = True diff --git a/antarest/launcher/model.py b/antarest/launcher/model.py index 7a3c615811..a9bf0f6fde 100644 --- a/antarest/launcher/model.py +++ b/antarest/launcher/model.py @@ -17,14 +17,14 @@ class XpansionParametersDTO(BaseModel): class LauncherParametersDTO(BaseModel): - # Warning ! This class must be retrocompatible (that's the reason for the weird bool/XpansionParametersDTO union) + # Warning ! This class must be retro-compatible (that's the reason for the weird bool/XpansionParametersDTO union) # The reason is that it's stored in json format in database and deserialized using the latest class version # If compatibility is to be broken, an (alembic) data migration script should be added adequacy_patch: Optional[Dict[str, Any]] = None nb_cpu: Optional[int] = None post_processing: bool = False - time_limit: Optional[int] = None - xpansion: Union[bool, Optional[XpansionParametersDTO]] = None + time_limit: Optional[int] = None # 3600 <= time_limit < 864000 (10 days) + xpansion: Union[XpansionParametersDTO, bool, None] = None xpansion_r_version: bool = False archive_output: bool = True auto_unzip: bool = True diff --git a/antarest/launcher/service.py b/antarest/launcher/service.py index 9c23505d86..aa081e5462 100644 --- a/antarest/launcher/service.py +++ b/antarest/launcher/service.py @@ -1,4 +1,5 @@ import functools +import json import logging import os import shutil @@ -10,7 +11,7 @@ from fastapi import HTTPException -from antarest.core.config import Config, LauncherConfig +from antarest.core.config import Config, NbCoresConfig from antarest.core.exceptions import StudyNotFoundError from antarest.core.filetransfer.model import FileDownloadTaskDTO from antarest.core.filetransfer.service import FileTransferManager @@ -99,14 +100,20 @@ def _init_extensions(self) -> Dict[str, ILauncherExtension]: def get_launchers(self) -> List[str]: return list(self.launchers.keys()) - @staticmethod - def get_nb_cores(launcher: str) -> Dict[str, int]: + def get_nb_cores(self, launcher: str) -> NbCoresConfig: """ - Retrieving Min, Default, and Max Core Count. + Retrieve the configuration of the launcher's nb of cores. + Args: - launcher: name of the configuration : "default", "slurm" or "local". + launcher: name of the launcher: "default", "slurm" or "local". + + Returns: + Number of cores of the launcher + + Raises: + InvalidConfigurationError: if the launcher configuration is not available """ - return LauncherConfig().get_nb_cores(launcher).to_json() + return self.config.launcher.get_nb_cores(launcher) def _after_export_flat_hooks( self, @@ -595,27 +602,31 @@ def get_load(self, from_cluster: bool = False) -> Dict[str, float]: local_running_jobs.append(job) else: logger.warning(f"Unknown job launcher {job.launcher}") + load = {} - if self.config.launcher.slurm: + + slurm_config = self.config.launcher.slurm + if slurm_config is not None: if from_cluster: - raise NotImplementedError - slurm_used_cpus = functools.reduce( - lambda count, j: count - + ( - LauncherParametersDTO.parse_raw(j.launcher_params or "{}").nb_cpu - or self.config.launcher.slurm.default_n_cpu # type: ignore - ), - slurm_running_jobs, - 0, - ) - load["slurm"] = float(slurm_used_cpus) / self.config.launcher.slurm.max_cores - if self.config.launcher.local: - local_used_cpus = functools.reduce( - lambda count, j: count + (LauncherParametersDTO.parse_raw(j.launcher_params or "{}").nb_cpu or 1), - local_running_jobs, - 0, - ) - load["local"] = float(local_used_cpus) / (os.cpu_count() or 1) + raise NotImplementedError("Cluster load not implemented yet") + default_cpu = slurm_config.nb_cores.default + slurm_used_cpus = 0 + for job in slurm_running_jobs: + obj = json.loads(job.launcher_params) if job.launcher_params else {} + launch_params = LauncherParametersDTO(**obj) + slurm_used_cpus += launch_params.nb_cpu or default_cpu + load["slurm"] = slurm_used_cpus / slurm_config.max_cores + + local_config = self.config.launcher.local + if local_config is not None: + default_cpu = local_config.nb_cores.default + local_used_cpus = 0 + for job in local_running_jobs: + obj = json.loads(job.launcher_params) if job.launcher_params else {} + launch_params = LauncherParametersDTO(**obj) + local_used_cpus += launch_params.nb_cpu or default_cpu + load["local"] = local_used_cpus / local_config.nb_cores.max + return load def get_solver_versions(self, solver: str) -> List[str]: diff --git a/antarest/launcher/web.py b/antarest/launcher/web.py index e31e816013..51b3582997 100644 --- a/antarest/launcher/web.py +++ b/antarest/launcher/web.py @@ -230,8 +230,9 @@ def get_solver_versions( raise UnknownSolverConfig(solver) return service.get_solver_versions(solver) + # noinspection SpellCheckingInspection @bp.get( - "/launcher/nbcores", + "/launcher/nbcores", # We avoid "nb_cores" and "nb-cores" in endpoints tags=[APITag.launcher], summary="Retrieving Min, Default, and Max Core Count", response_model=Dict[str, int], @@ -256,13 +257,18 @@ def get_nb_cores( ) ) -> Dict[str, int]: """ - Retrieving Min, Default, and Max Core Count. + Retrieve the numer of cores of the launcher. + Args: - - `launcher`: name of the configuration to read: "default", "slurm" or "local". + - `launcher`: name of the configuration to read: "slurm" or "local". + If "default" is specified, retrieve the configuration of the default launcher. + + Returns: + - "min": min number of cores + - "defaultValue": default number of cores + - "max": max number of cores """ - logger.info(f"Fetching the number of cpu for the '{launcher}' configuration") - if launcher not in {"default", "slurm", "local"}: - raise UnknownSolverConfig(launcher) + logger.info(f"Fetching the number of cores for the '{launcher}' configuration") try: return service.config.launcher.get_nb_cores(launcher).to_json() except InvalidConfigurationError: diff --git a/tests/conftest_db.py b/tests/conftest_db.py index 877ca119d1..bcb4177766 100644 --- a/tests/conftest_db.py +++ b/tests/conftest_db.py @@ -3,7 +3,8 @@ import pytest from sqlalchemy import create_engine # type: ignore -from sqlalchemy.orm import sessionmaker +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import Session, sessionmaker # type: ignore from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.dbmodel import Base @@ -12,7 +13,7 @@ @pytest.fixture(name="db_engine") -def db_engine_fixture() -> Generator[Any, None, None]: +def db_engine_fixture() -> Generator[Engine, None, None]: """ Fixture that creates an in-memory SQLite database engine for testing. @@ -26,7 +27,7 @@ def db_engine_fixture() -> Generator[Any, None, None]: @pytest.fixture(name="db_session") -def db_session_fixture(db_engine) -> Generator: +def db_session_fixture(db_engine: Engine) -> Generator[Session, None, None]: """ Fixture that creates a database session for testing purposes. @@ -46,7 +47,7 @@ def db_session_fixture(db_engine) -> Generator: @pytest.fixture(name="db_middleware", autouse=True) def db_middleware_fixture( - db_engine: Any, + db_engine: Engine, ) -> Generator[DBSessionMiddleware, None, None]: """ Fixture that sets up a database session middleware with custom engine settings. diff --git a/tests/core/assets/__init__.py b/tests/core/assets/__init__.py new file mode 100644 index 0000000000..773f16ec60 --- /dev/null +++ b/tests/core/assets/__init__.py @@ -0,0 +1,3 @@ +from pathlib import Path + +ASSETS_DIR = Path(__file__).parent.resolve() diff --git a/tests/core/assets/config/application-2.14.yaml b/tests/core/assets/config/application-2.14.yaml new file mode 100644 index 0000000000..650093286d --- /dev/null +++ b/tests/core/assets/config/application-2.14.yaml @@ -0,0 +1,61 @@ +security: + disabled: false + jwt: + key: super-secret + login: + admin: + pwd: admin + +db: + url: "sqlite:////home/john/antares_data/database.db" + +storage: + tmp_dir: /tmp + matrixstore: /home/john/antares_data/matrices + archive_dir: /home/john/antares_data/archives + allow_deletion: false + workspaces: + default: + path: /home/john/antares_data/internal_studies/ + studies: + path: /home/john/antares_data/studies/ + +launcher: + default: slurm + local: + binaries: + 850: /home/john/opt/antares-8.5.0-Ubuntu-20.04/antares-solver + 860: /home/john/opt/antares-8.6.0-Ubuntu-20.04/antares-8.6-solver + + slurm: + local_workspace: /home/john/antares_data/slurm_workspace + + username: antares + hostname: slurm-prod-01 + + port: 22 + private_key_file: /home/john/.ssh/id_rsa + key_password: + default_wait_time: 900 + default_time_limit: 172800 + default_n_cpu: 20 + default_json_db_name: launcher_db.json + slurm_script_path: /applis/antares/launchAntares.sh + db_primary_key: name + antares_versions_on_remote_server: + - '850' # 8.5.1/antares-8.5-solver + - '860' # 8.6.2/antares-8.6-solver + - '870' # 8.7.0/antares-8.7-solver + +debug: false + +root_path: "" + +server: + worker_threadpool_size: 12 + services: + - watcher + - matrix_gc + +logging: + level: INFO diff --git a/tests/core/assets/config/application-2.15.yaml b/tests/core/assets/config/application-2.15.yaml new file mode 100644 index 0000000000..c51d32aaae --- /dev/null +++ b/tests/core/assets/config/application-2.15.yaml @@ -0,0 +1,66 @@ +security: + disabled: false + jwt: + key: super-secret + login: + admin: + pwd: admin + +db: + url: "sqlite:////home/john/antares_data/database.db" + +storage: + tmp_dir: /tmp + matrixstore: /home/john/antares_data/matrices + archive_dir: /home/john/antares_data/archives + allow_deletion: false + workspaces: + default: + path: /home/john/antares_data/internal_studies/ + studies: + path: /home/john/antares_data/studies/ + +launcher: + default: slurm + local: + binaries: + 850: /home/john/opt/antares-8.5.0-Ubuntu-20.04/antares-solver + 860: /home/john/opt/antares-8.6.0-Ubuntu-20.04/antares-8.6-solver + enable_nb_cores_detection: True + + slurm: + local_workspace: /home/john/antares_data/slurm_workspace + + username: antares + hostname: slurm-prod-01 + + port: 22 + private_key_file: /home/john/.ssh/id_rsa + key_password: + default_wait_time: 900 + default_time_limit: 172800 + enable_nb_cores_detection: False + nb_cores: + min: 1 + default: 22 + max: 24 + default_json_db_name: launcher_db.json + slurm_script_path: /applis/antares/launchAntares.sh + db_primary_key: name + antares_versions_on_remote_server: + - '850' # 8.5.1/antares-8.5-solver + - '860' # 8.6.2/antares-8.6-solver + - '870' # 8.7.0/antares-8.7-solver + +debug: false + +root_path: "" + +server: + worker_threadpool_size: 12 + services: + - watcher + - matrix_gc + +logging: + level: INFO diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 1c1c96a180..48c08a0f1f 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -1,15 +1,253 @@ from pathlib import Path +from unittest import mock import pytest -from antarest.core.config import Config +from antarest.core.config import ( + Config, + InvalidConfigurationError, + LauncherConfig, + LocalConfig, + NbCoresConfig, + SlurmConfig, +) +from tests.core.assets import ASSETS_DIR +LAUNCHER_CONFIG = { + "default": "slurm", + "local": { + "binaries": {"860": Path("/bin/solver-860.exe")}, + "enable_nb_cores_detection": False, + "nb_cores": {"min": 2, "default": 10, "max": 20}, + }, + "slurm": { + "local_workspace": Path("/home/john/antares/workspace"), + "username": "john", + "hostname": "slurm-001", + "port": 22, + "private_key_file": Path("/home/john/.ssh/id_rsa"), + "key_password": "password", + "password": "password", + "default_wait_time": 10, + "default_time_limit": 20, + "default_json_db_name": "antares.db", + "slurm_script_path": "/path/to/slurm/launcher.sh", + "max_cores": 32, + "antares_versions_on_remote_server": ["860"], + "enable_nb_cores_detection": False, + "nb_cores": {"min": 1, "default": 34, "max": 36}, + }, + "batch_size": 100, +} -@pytest.mark.unit_test -def test_get_yaml(project_path: Path): - config = Config.from_yaml_file(file=project_path / "resources/application.yaml") - assert config.security.admin_pwd == "admin" - assert config.storage.workspaces["default"].path == Path("examples/internal_studies/") - assert not config.logging.json - assert config.logging.level == "INFO" +class TestNbCoresConfig: + def test_init__default_values(self): + config = NbCoresConfig() + assert config.min == 1 + assert config.default == 22 + assert config.max == 24 + + def test_init__invalid_values(self): + with pytest.raises(ValueError): + # default < min + NbCoresConfig(min=2, default=1, max=24) + with pytest.raises(ValueError): + # default > max + NbCoresConfig(min=1, default=25, max=24) + with pytest.raises(ValueError): + # min < 0 + NbCoresConfig(min=0, default=22, max=23) + with pytest.raises(ValueError): + # min > max + NbCoresConfig(min=22, default=22, max=21) + + def test_to_json(self): + config = NbCoresConfig() + # ReactJs Material UI expects "min", "defaultValue" and "max" keys + assert config.to_json() == {"min": 1, "defaultValue": 22, "max": 24} + + +class TestLocalConfig: + def test_init__default_values(self): + config = LocalConfig() + assert config.binaries == {}, "binaries should be empty by default" + assert config.enable_nb_cores_detection, "nb cores auto-detection should be enabled by default" + assert config.nb_cores == NbCoresConfig() + + def test_from_dict(self): + config = LocalConfig.from_dict( + { + "binaries": {"860": Path("/bin/solver-860.exe")}, + "enable_nb_cores_detection": False, + "nb_cores": {"min": 2, "default": 10, "max": 20}, + } + ) + assert config.binaries == {"860": Path("/bin/solver-860.exe")} + assert not config.enable_nb_cores_detection + assert config.nb_cores == NbCoresConfig(min=2, default=10, max=20) + + def test_from_dict__auto_detect(self): + with mock.patch("multiprocessing.cpu_count", return_value=8): + config = LocalConfig.from_dict( + { + "binaries": {"860": Path("/bin/solver-860.exe")}, + "enable_nb_cores_detection": True, + } + ) + assert config.binaries == {"860": Path("/bin/solver-860.exe")} + assert config.enable_nb_cores_detection + assert config.nb_cores == NbCoresConfig(min=1, default=6, max=8) + + +class TestSlurmConfig: + def test_init__default_values(self): + config = SlurmConfig() + assert config.local_workspace == Path() + assert config.username == "" + assert config.hostname == "" + assert config.port == 0 + assert config.private_key_file == Path() + assert config.key_password == "" + assert config.password == "" + assert config.default_wait_time == 0 + assert config.default_time_limit == 0 + assert config.default_json_db_name == "" + assert config.slurm_script_path == "" + assert config.max_cores == 64 + assert config.antares_versions_on_remote_server == [], "solver versions should be empty by default" + assert not config.enable_nb_cores_detection, "nb cores auto-detection shouldn't be enabled by default" + assert config.nb_cores == NbCoresConfig() + + def test_from_dict(self): + config = SlurmConfig.from_dict( + { + "local_workspace": Path("/home/john/antares/workspace"), + "username": "john", + "hostname": "slurm-001", + "port": 22, + "private_key_file": Path("/home/john/.ssh/id_rsa"), + "key_password": "password", + "password": "password", + "default_wait_time": 10, + "default_time_limit": 20, + "default_json_db_name": "antares.db", + "slurm_script_path": "/path/to/slurm/launcher.sh", + "max_cores": 32, + "antares_versions_on_remote_server": ["860"], + "enable_nb_cores_detection": False, + "nb_cores": {"min": 2, "default": 10, "max": 20}, + } + ) + assert config.local_workspace == Path("/home/john/antares/workspace") + assert config.username == "john" + assert config.hostname == "slurm-001" + assert config.port == 22 + assert config.private_key_file == Path("/home/john/.ssh/id_rsa") + assert config.key_password == "password" + assert config.password == "password" + assert config.default_wait_time == 10 + assert config.default_time_limit == 20 + assert config.default_json_db_name == "antares.db" + assert config.slurm_script_path == "/path/to/slurm/launcher.sh" + assert config.max_cores == 32 + assert config.antares_versions_on_remote_server == ["860"] + assert not config.enable_nb_cores_detection + assert config.nb_cores == NbCoresConfig(min=2, default=10, max=20) + + def test_from_dict__default_n_cpu__backport(self): + config = SlurmConfig.from_dict( + { + "local_workspace": Path("/home/john/antares/workspace"), + "username": "john", + "hostname": "slurm-001", + "port": 22, + "private_key_file": Path("/home/john/.ssh/id_rsa"), + "key_password": "password", + "password": "password", + "default_wait_time": 10, + "default_time_limit": 20, + "default_json_db_name": "antares.db", + "slurm_script_path": "/path/to/slurm/launcher.sh", + "max_cores": 32, + "antares_versions_on_remote_server": ["860"], + "default_n_cpu": 15, + } + ) + assert config.nb_cores == NbCoresConfig(min=1, default=15, max=15) + + def test_from_dict__auto_detect(self): + with pytest.raises(NotImplementedError): + SlurmConfig.from_dict({"enable_nb_cores_detection": True}) + + +class TestLauncherConfig: + def test_init__default_values(self): + config = LauncherConfig() + assert config.default == "local", "default launcher should be local" + assert config.local is None + assert config.slurm is None + assert config.batch_size == 9999 + + def test_from_dict(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + assert config.default == "slurm" + assert config.local == LocalConfig( + binaries={"860": Path("/bin/solver-860.exe")}, + enable_nb_cores_detection=False, + nb_cores=NbCoresConfig(min=2, default=10, max=20), + ) + assert config.slurm == SlurmConfig( + local_workspace=Path("/home/john/antares/workspace"), + username="john", + hostname="slurm-001", + port=22, + private_key_file=Path("/home/john/.ssh/id_rsa"), + key_password="password", + password="password", + default_wait_time=10, + default_time_limit=20, + default_json_db_name="antares.db", + slurm_script_path="/path/to/slurm/launcher.sh", + max_cores=32, + antares_versions_on_remote_server=["860"], + enable_nb_cores_detection=False, + nb_cores=NbCoresConfig(min=1, default=34, max=36), + ) + assert config.batch_size == 100 + + def test_init__invalid_launcher(self): + with pytest.raises(ValueError): + LauncherConfig(default="invalid_launcher") + + def test_get_nb_cores__default(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + # default == "slurm" + assert config.get_nb_cores(launcher="default") == NbCoresConfig(min=1, default=34, max=36) + + def test_get_nb_cores__local(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + assert config.get_nb_cores(launcher="local") == NbCoresConfig(min=2, default=10, max=20) + + def test_get_nb_cores__slurm(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + assert config.get_nb_cores(launcher="slurm") == NbCoresConfig(min=1, default=34, max=36) + + def test_get_nb_cores__invalid_configuration(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + with pytest.raises(InvalidConfigurationError): + config.get_nb_cores("invalid_launcher") + config = LauncherConfig.from_dict({}) + with pytest.raises(InvalidConfigurationError): + config.get_nb_cores("slurm") + + +class TestConfig: + @pytest.mark.parametrize("config_name", ["application-2.14.yaml", "application-2.15.yaml"]) + def test_from_yaml_file(self, config_name: str) -> None: + yaml_path = ASSETS_DIR.joinpath("config", config_name) + config = Config.from_yaml_file(yaml_path) + assert config.security.admin_pwd == "admin" + assert config.storage.workspaces["default"].path == Path("/home/john/antares_data/internal_studies") + assert not config.logging.json + assert config.logging.level == "INFO" diff --git a/tests/integration/launcher_blueprint/test_launcher_local.py b/tests/integration/launcher_blueprint/test_launcher_local.py index ea5f429e04..7244fba8ee 100644 --- a/tests/integration/launcher_blueprint/test_launcher_local.py +++ b/tests/integration/launcher_blueprint/test_launcher_local.py @@ -1,26 +1,26 @@ -import pytest +import http -import multiprocessing +import pytest from starlette.testclient import TestClient +from antarest.core.config import LocalConfig + +# noinspection SpellCheckingInspection @pytest.mark.integration_test -class TestlauncherNbcores: +class TestLauncherNbCores: """ The purpose of this unit test is to check the `/v1/launcher/nbcores` endpoint. """ - def test_get_launcher_nbcore( + def test_get_launcher_nb_cores( self, client: TestClient, user_access_token: str, ) -> None: - # Test The endpoint /v1/launcher/nbcores - # Fetch the default server version from the configuration file. - # NOTE: the value is defined in `tests/integration/assets/config.template.yml`. - max_cpu = multiprocessing.cpu_count() - default = max(1, max_cpu - 2) - nb_cores_expected = {"defaultValue": default, "max": max_cpu, "min": 1} + # NOTE: we have `enable_nb_cores_detection: True` in `tests/integration/assets/config.template.yml`. + local_nb_cores = LocalConfig.from_dict({"enable_nb_cores_detection": True}).nb_cores + nb_cores_expected = local_nb_cores.to_json() res = client.get( "/v1/launcher/nbcores", headers={"Authorization": f"Bearer {user_access_token}"}, @@ -44,3 +44,27 @@ def test_get_launcher_nbcore( res.raise_for_status() actual = res.json() assert actual == nb_cores_expected + + # Check that the endpoint raise an exception when the "slurm" launcher is requested. + res = client.get( + "/v1/launcher/nbcores?launcher=slurm", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY, res.json() + actual = res.json() + assert actual == { + "description": "Unknown solver configuration: 'slurm'", + "exception": "UnknownSolverConfig", + } + + # Check that the endpoint raise an exception when an unknown launcher is requested. + res = client.get( + "/v1/launcher/nbcores?launcher=unknown", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY, res.json() + actual = res.json() + assert actual == { + "description": "Unknown solver configuration: 'unknown'", + "exception": "UnknownSolverConfig", + } diff --git a/tests/launcher/test_local_launcher.py b/tests/launcher/test_local_launcher.py index 04741d319a..53adc03bf0 100644 --- a/tests/launcher/test_local_launcher.py +++ b/tests/launcher/test_local_launcher.py @@ -1,19 +1,28 @@ import os import textwrap +import uuid from pathlib import Path from unittest.mock import Mock, call -from uuid import uuid4 import pytest -from sqlalchemy import create_engine from antarest.core.config import Config, LauncherConfig, LocalConfig -from antarest.core.persistence import Base -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.launcher.adapters.abstractlauncher import LauncherInitException from antarest.launcher.adapters.local_launcher.local_launcher import LocalLauncher from antarest.launcher.model import JobStatus, LauncherParametersDTO +SOLVER_NAME = "solver.bat" if os.name == "nt" else "solver.sh" + + +@pytest.fixture +def launcher_config(tmp_path: Path) -> Config: + """ + Fixture to create a launcher config with a local launcher. + """ + solver_path = tmp_path.joinpath(SOLVER_NAME) + data = {"binaries": {"700": solver_path}, "enable_nb_cores_detection": True} + return Config(launcher=LauncherConfig(local=LocalConfig.from_dict(data))) + @pytest.mark.unit_test def test_local_launcher__launcher_init_exception(): @@ -30,21 +39,12 @@ def test_local_launcher__launcher_init_exception(): @pytest.mark.unit_test -def test_compute(tmp_path: Path): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - local_launcher = LocalLauncher(Config(), callbacks=Mock(), event_bus=Mock(), cache=Mock()) +def test_compute(tmp_path: Path, launcher_config: Config): + local_launcher = LocalLauncher(launcher_config, callbacks=Mock(), event_bus=Mock(), cache=Mock()) # prepare a dummy executable to simulate Antares Solver if os.name == "nt": - solver_name = "solver.bat" - solver_path = tmp_path.joinpath(solver_name) + solver_path = tmp_path.joinpath(SOLVER_NAME) solver_path.write_text( textwrap.dedent( """\ @@ -55,8 +55,7 @@ def test_compute(tmp_path: Path): ) ) else: - solver_name = "solver.sh" - solver_path = tmp_path.joinpath(solver_name) + solver_path = tmp_path.joinpath(SOLVER_NAME) solver_path.write_text( textwrap.dedent( """\ @@ -68,8 +67,8 @@ def test_compute(tmp_path: Path): ) solver_path.chmod(0o775) - uuid = uuid4() - local_launcher.job_id_to_study_id = {str(uuid): ("study-id", tmp_path / "run", Mock())} + study_id = uuid.uuid4() + local_launcher.job_id_to_study_id = {str(study_id): ("study-id", tmp_path / "run", Mock())} local_launcher.callbacks.import_output.return_value = "some output" launcher_parameters = LauncherParametersDTO( adequacy_patch=None, @@ -86,15 +85,15 @@ def test_compute(tmp_path: Path): local_launcher._compute( antares_solver_path=solver_path, study_uuid="study-id", - uuid=uuid, + uuid=study_id, launcher_parameters=launcher_parameters, ) # noinspection PyUnresolvedReferences local_launcher.callbacks.update_status.assert_has_calls( [ - call(str(uuid), JobStatus.RUNNING, None, None), - call(str(uuid), JobStatus.SUCCESS, None, "some output"), + call(str(study_id), JobStatus.RUNNING, None, None), + call(str(study_id), JobStatus.SUCCESS, None, "some output"), ] ) diff --git a/tests/launcher/test_service.py b/tests/launcher/test_service.py index 665acc4bdc..a6177c5e61 100644 --- a/tests/launcher/test_service.py +++ b/tests/launcher/test_service.py @@ -1,18 +1,26 @@ import json -import multiprocessing import os import time from datetime import datetime, timedelta from pathlib import Path -from typing import Dict, List, Literal, Union +from typing import Dict, List, Union from unittest.mock import Mock, call, patch from uuid import uuid4 from zipfile import ZIP_DEFLATED, ZipFile import pytest from sqlalchemy import create_engine - -from antarest.core.config import Config, LauncherConfig, LocalConfig, SlurmConfig, StorageConfig +from typing_extensions import Literal + +from antarest.core.config import ( + Config, + InvalidConfigurationError, + LauncherConfig, + LocalConfig, + NbCoresConfig, + SlurmConfig, + StorageConfig, +) from antarest.core.exceptions import StudyNotFoundError from antarest.core.filetransfer.model import FileDownload, FileDownloadDTO, FileDownloadTaskDTO from antarest.core.interfaces.eventbus import Event, EventType @@ -21,7 +29,7 @@ from antarest.core.requests import RequestParameters, UserHasNotPermissionError from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.dbmodel import Base -from antarest.launcher.model import JobLog, JobLogType, JobResult, JobStatus, LogType +from antarest.launcher.model import JobLog, JobLogType, JobResult, JobStatus, LauncherParametersDTO, LogType from antarest.launcher.service import ( EXECUTION_INFO_FILE, LAUNCHER_PARAM_NAME_SUFFIX, @@ -34,911 +42,908 @@ from antarest.study.model import OwnerInfo, PublicMode, Study, StudyMetadataDTO -@pytest.mark.unit_test -@patch.object(Auth, "get_current_user") -def test_service_run_study(get_current_user_mock): - get_current_user_mock.return_value = None - storage_service_mock = Mock() - storage_service_mock.get_study_information.return_value = StudyMetadataDTO( - id="id", - name="name", - created=1, - updated=1, - type="rawstudy", - owner=OwnerInfo(id=0, name="author"), - groups=[], - public_mode=PublicMode.NONE, - version=42, - workspace="default", - managed=True, - archived=False, - ) - storage_service_mock.get_study_path.return_value = Path("path/to/study") - - uuid = uuid4() - launcher_mock = Mock() - factory_launcher_mock = Mock() - factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} - - event_bus = Mock() - - pending = JobResult( - id=str(uuid), - study_id="study_uuid", - job_status=JobStatus.PENDING, - launcher="local", - ) - repository = Mock() - repository.save.return_value = pending - - launcher_service = LauncherService( - config=Config(), - study_service=storage_service_mock, - job_result_repository=repository, - factory_launcher=factory_launcher_mock, - event_bus=event_bus, - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - launcher_service._generate_new_id = lambda: str(uuid) - - job_id = launcher_service.run_study( - "study_uuid", - "local", - None, - RequestParameters( - user=JWTUser( - id=0, - impersonator=0, - type="users", - ) - ), - ) - - assert job_id == str(uuid) - repository.save.assert_called_once_with(pending) - event_bus.push.assert_called_once_with( - Event( - type=EventType.STUDY_JOB_STARTED, - payload=pending.to_dto().dict(), - permissions=PermissionInfo(owner=0), +class TestLauncherService: + @pytest.mark.unit_test + @patch.object(Auth, "get_current_user") + def test_service_run_study(self, get_current_user_mock) -> None: + get_current_user_mock.return_value = None + storage_service_mock = Mock() + # noinspection SpellCheckingInspection + storage_service_mock.get_study_information.return_value = StudyMetadataDTO( + id="id", + name="name", + created="1", + updated="1", + type="rawstudy", + owner=OwnerInfo(id=0, name="author"), + groups=[], + public_mode=PublicMode.NONE, + version=42, + workspace="default", + managed=True, + archived=False, ) - ) - + storage_service_mock.get_study_path.return_value = Path("path/to/study") -@pytest.mark.unit_test -def test_service_get_result_from_launcher(): - launcher_mock = Mock() - fake_execution_result = JobResult( - id=str(uuid4()), - study_id="sid", - job_status=JobStatus.SUCCESS, - msg="Hello, World!", - exit_code=0, - launcher="local", - ) - factory_launcher_mock = Mock() - factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} - - repository = Mock() - repository.get.return_value = fake_execution_result - - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - - launcher_service = LauncherService( - config=Config(), - study_service=study_service, - job_result_repository=repository, - factory_launcher=factory_launcher_mock, - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - - job_id = uuid4() - assert ( - launcher_service.get_result(job_uuid=job_id, params=RequestParameters(user=DEFAULT_ADMIN_USER)) - == fake_execution_result - ) + uuid = uuid4() + launcher_mock = Mock() + factory_launcher_mock = Mock() + factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} + event_bus = Mock() -@pytest.mark.unit_test -def test_service_get_result_from_database(): - launcher_mock = Mock() - fake_execution_result = JobResult( - id=str(uuid4()), - study_id="sid", - job_status=JobStatus.SUCCESS, - msg="Hello, World!", - exit_code=0, - ) - launcher_mock.get_result.return_value = None - factory_launcher_mock = Mock() - factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} - - repository = Mock() - repository.get.return_value = fake_execution_result - - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - - launcher_service = LauncherService( - config=Config(), - study_service=study_service, - job_result_repository=repository, - factory_launcher=factory_launcher_mock, - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) + pending = JobResult( + id=str(uuid), + study_id="study_uuid", + job_status=JobStatus.PENDING, + launcher="local", + launcher_params=LauncherParametersDTO().json(), + ) + repository = Mock() + repository.save.return_value = pending + + launcher_service = LauncherService( + config=Config(), + study_service=storage_service_mock, + job_result_repository=repository, + factory_launcher=factory_launcher_mock, + event_bus=event_bus, + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + launcher_service._generate_new_id = lambda: str(uuid) - assert ( - launcher_service.get_result(job_uuid=uuid4(), params=RequestParameters(user=DEFAULT_ADMIN_USER)) - == fake_execution_result - ) + job_id = launcher_service.run_study( + "study_uuid", + "local", + LauncherParametersDTO(), + RequestParameters( + user=JWTUser( + id=0, + impersonator=0, + type="users", + ) + ), + ) + assert job_id == str(uuid) + repository.save.assert_called_once_with(pending) + event_bus.push.assert_called_once_with( + Event( + type=EventType.STUDY_JOB_STARTED, + payload=pending.to_dto().dict(), + permissions=PermissionInfo(owner=0), + ) + ) -@pytest.mark.unit_test -def test_service_get_jobs_from_database(): - launcher_mock = Mock() - now = datetime.utcnow() - fake_execution_result = [ - JobResult( + @pytest.mark.unit_test + def test_service_get_result_from_launcher(self) -> None: + launcher_mock = Mock() + fake_execution_result = JobResult( id=str(uuid4()), - study_id="a", + study_id="sid", job_status=JobStatus.SUCCESS, msg="Hello, World!", exit_code=0, + launcher="local", ) - ] - returned_faked_execution_results = [ - JobResult( - id="1", - study_id="a", - job_status=JobStatus.SUCCESS, - msg="Hello, World!", - exit_code=0, - creation_date=now, - ), - JobResult( - id="2", - study_id="b", - job_status=JobStatus.SUCCESS, - msg="Hello, World!", - exit_code=0, - creation_date=now, - ), - ] - all_faked_execution_results = returned_faked_execution_results + [ - JobResult( - id="3", - study_id="c", + factory_launcher_mock = Mock() + factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} + + repository = Mock() + repository.get.return_value = fake_execution_result + + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Config(), + study_service=study_service, + job_result_repository=repository, + factory_launcher=factory_launcher_mock, + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + + job_id = uuid4() + assert ( + launcher_service.get_result(job_uuid=job_id, params=RequestParameters(user=DEFAULT_ADMIN_USER)) + == fake_execution_result + ) + + @pytest.mark.unit_test + def test_service_get_result_from_database(self) -> None: + launcher_mock = Mock() + fake_execution_result = JobResult( + id=str(uuid4()), + study_id="sid", job_status=JobStatus.SUCCESS, msg="Hello, World!", exit_code=0, - creation_date=now - timedelta(days=ORPHAN_JOBS_VISIBILITY_THRESHOLD + 1), - ) - ] - launcher_mock.get_result.return_value = None - factory_launcher_mock = Mock() - factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} - - repository = Mock() - repository.find_by_study.return_value = fake_execution_result - repository.get_all.return_value = all_faked_execution_results - - study_service = Mock() - study_service.repository = Mock() - study_service.repository.get_list.return_value = [ - Mock( - spec=Study, - id="b", - groups=[], - owner=User(id=2), - public_mode=PublicMode.NONE, ) - ] - - launcher_service = LauncherService( - config=Config(), - study_service=study_service, - job_result_repository=repository, - factory_launcher=factory_launcher_mock, - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) + launcher_mock.get_result.return_value = None + factory_launcher_mock = Mock() + factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} + + repository = Mock() + repository.get.return_value = fake_execution_result + + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Config(), + study_service=study_service, + job_result_repository=repository, + factory_launcher=factory_launcher_mock, + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - study_id = uuid4() - assert ( - launcher_service.get_jobs(str(study_id), params=RequestParameters(user=DEFAULT_ADMIN_USER)) - == fake_execution_result - ) - repository.find_by_study.assert_called_once_with(str(study_id)) - assert ( - launcher_service.get_jobs(None, params=RequestParameters(user=DEFAULT_ADMIN_USER)) - == all_faked_execution_results - ) - assert ( - launcher_service.get_jobs( - None, - params=RequestParameters( - user=JWTUser( - id=2, - impersonator=2, - type="users", - groups=[], - ) - ), + assert ( + launcher_service.get_result(job_uuid=uuid4(), params=RequestParameters(user=DEFAULT_ADMIN_USER)) + == fake_execution_result ) - == returned_faked_execution_results - ) - with pytest.raises(UserHasNotPermissionError): - launcher_service.remove_job( - "some job", - RequestParameters( - user=JWTUser( - id=2, - impersonator=2, - type="users", - groups=[], - ) + @pytest.mark.unit_test + def test_service_get_jobs_from_database(self) -> None: + launcher_mock = Mock() + now = datetime.utcnow() + fake_execution_result = [ + JobResult( + id=str(uuid4()), + study_id="a", + job_status=JobStatus.SUCCESS, + msg="Hello, World!", + exit_code=0, + ) + ] + returned_faked_execution_results = [ + JobResult( + id="1", + study_id="a", + job_status=JobStatus.SUCCESS, + msg="Hello, World!", + exit_code=0, + creation_date=now, ), + JobResult( + id="2", + study_id="b", + job_status=JobStatus.SUCCESS, + msg="Hello, World!", + exit_code=0, + creation_date=now, + ), + ] + all_faked_execution_results = returned_faked_execution_results + [ + JobResult( + id="3", + study_id="c", + job_status=JobStatus.SUCCESS, + msg="Hello, World!", + exit_code=0, + creation_date=now - timedelta(days=ORPHAN_JOBS_VISIBILITY_THRESHOLD + 1), + ) + ] + launcher_mock.get_result.return_value = None + factory_launcher_mock = Mock() + factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} + + repository = Mock() + repository.find_by_study.return_value = fake_execution_result + repository.get_all.return_value = all_faked_execution_results + + study_service = Mock() + study_service.repository = Mock() + study_service.repository.get_list.return_value = [ + Mock( + spec=Study, + id="b", + groups=[], + owner=User(id=2), + public_mode=PublicMode.NONE, + ) + ] + + launcher_service = LauncherService( + config=Config(), + study_service=study_service, + job_result_repository=repository, + factory_launcher=factory_launcher_mock, + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + + study_id = uuid4() + assert ( + launcher_service.get_jobs(str(study_id), params=RequestParameters(user=DEFAULT_ADMIN_USER)) + == fake_execution_result + ) + repository.find_by_study.assert_called_once_with(str(study_id)) + assert ( + launcher_service.get_jobs(None, params=RequestParameters(user=DEFAULT_ADMIN_USER)) + == all_faked_execution_results + ) + assert ( + launcher_service.get_jobs( + None, + params=RequestParameters( + user=JWTUser( + id=2, + impersonator=2, + type="users", + groups=[], + ) + ), + ) + == returned_faked_execution_results ) - launcher_service.remove_job("some job", RequestParameters(user=DEFAULT_ADMIN_USER)) - repository.delete.assert_called_with("some job") + with pytest.raises(UserHasNotPermissionError): + launcher_service.remove_job( + "some job", + RequestParameters( + user=JWTUser( + id=2, + impersonator=2, + type="users", + groups=[], + ) + ), + ) + launcher_service.remove_job("some job", RequestParameters(user=DEFAULT_ADMIN_USER)) + repository.delete.assert_called_with("some job") -@pytest.mark.unit_test -@pytest.mark.parametrize( - "config, solver, expected", - [ - pytest.param( - { - "default": "local", - "local": [], - "slurm": [], - }, - "default", - [], - id="empty-config", - ), - pytest.param( - { - "default": "local", - "local": ["456", "123", "798"], - }, - "default", - ["123", "456", "798"], - id="local-config-default", - ), - pytest.param( - { - "default": "local", - "local": ["456", "123", "798"], - }, - "slurm", - [], - id="local-config-slurm", - ), - pytest.param( - { - "default": "local", - "local": ["456", "123", "798"], - }, - "unknown", - [], - id="local-config-unknown", - marks=pytest.mark.xfail( - reason="Unknown solver configuration: 'unknown'", - raises=KeyError, - strict=True, + @pytest.mark.unit_test + @pytest.mark.parametrize( + "config, solver, expected", + [ + pytest.param( + { + "default": "local", + "local": [], + "slurm": [], + }, + "default", + [], + id="empty-config", ), - ), - pytest.param( - { - "default": "slurm", - "slurm": ["258", "147", "369"], - }, - "default", - ["147", "258", "369"], - id="slurm-config-default", - ), - pytest.param( - { - "default": "slurm", - "slurm": ["258", "147", "369"], - }, - "local", - [], - id="slurm-config-local", - ), - pytest.param( - { - "default": "slurm", - "slurm": ["258", "147", "369"], - }, - "unknown", - [], - id="slurm-config-unknown", - marks=pytest.mark.xfail( - reason="Unknown solver configuration: 'unknown'", - raises=KeyError, - strict=True, + pytest.param( + { + "default": "local", + "local": ["456", "123", "798"], + }, + "default", + ["123", "456", "798"], + id="local-config-default", ), - ), - pytest.param( - { - "default": "slurm", - "local": ["456", "123", "798"], - "slurm": ["258", "147", "369"], - }, - "local", - ["123", "456", "798"], - id="local+slurm-config-local", - ), - ], -) -def test_service_get_solver_versions( - config: Dict[str, Union[str, List[str]]], - solver: Literal["default", "local", "slurm", "unknown"], - expected: List[str], -) -> None: - # Prepare the configuration - # the default server version from the configuration file. - # the default server is initialised to local - default = config.get("default", "local") - local = LocalConfig(binaries={k: Path(f"solver-{k}.exe") for k in config.get("local", [])}) - slurm = SlurmConfig(antares_versions_on_remote_server=config.get("slurm", [])) - launcher_config = LauncherConfig( - default=default, - local=local if local else None, - slurm=slurm if slurm else None, - ) - config = Config(launcher=launcher_config) - launcher_service = LauncherService( - config=config, - study_service=Mock(), - job_result_repository=Mock(), - factory_launcher=Mock(), - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), + pytest.param( + { + "default": "local", + "local": ["456", "123", "798"], + }, + "slurm", + [], + id="local-config-slurm", + ), + pytest.param( + { + "default": "local", + "local": ["456", "123", "798"], + }, + "unknown", + [], + id="local-config-unknown", + marks=pytest.mark.xfail( + reason="Unknown solver configuration: 'unknown'", + raises=KeyError, + strict=True, + ), + ), + pytest.param( + { + "default": "slurm", + "slurm": ["258", "147", "369"], + }, + "default", + ["147", "258", "369"], + id="slurm-config-default", + ), + pytest.param( + { + "default": "slurm", + "slurm": ["258", "147", "369"], + }, + "local", + [], + id="slurm-config-local", + ), + pytest.param( + { + "default": "slurm", + "slurm": ["258", "147", "369"], + }, + "unknown", + [], + id="slurm-config-unknown", + marks=pytest.mark.xfail( + reason="Unknown solver configuration: 'unknown'", + raises=KeyError, + strict=True, + ), + ), + pytest.param( + { + "default": "slurm", + "local": ["456", "123", "798"], + "slurm": ["258", "147", "369"], + }, + "local", + ["123", "456", "798"], + id="local+slurm-config-local", + ), + ], ) + def test_service_get_solver_versions( + self, + config: Dict[str, Union[str, List[str]]], + solver: Literal["default", "local", "slurm", "unknown"], + expected: List[str], + ) -> None: + # Prepare the configuration + # the default server version from the configuration file. + # the default server is initialised to local + default = config.get("default", "local") + local = LocalConfig(binaries={k: Path(f"solver-{k}.exe") for k in config.get("local", [])}) + slurm = SlurmConfig(antares_versions_on_remote_server=config.get("slurm", [])) + launcher_config = LauncherConfig( + default=default, + local=local if local else None, + slurm=slurm if slurm else None, + ) + config = Config(launcher=launcher_config) + launcher_service = LauncherService( + config=config, + study_service=Mock(), + job_result_repository=Mock(), + factory_launcher=Mock(), + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - # Fetch the solver versions - actual = launcher_service.get_solver_versions(solver) - assert actual == expected - + # Fetch the solver versions + actual = launcher_service.get_solver_versions(solver) + assert actual == expected -@pytest.mark.unit_test -@pytest.mark.parametrize( - "config, launcher, expected", - [ - pytest.param( - { - "default": "local", - "local": [], - "slurm": [], - }, - "default", - [], - id="empty-config", - ), - pytest.param( - { - "default": "local", - "local": ["456", "123", "798"], - }, - "default", - ["123", "456", "798"], - id="local-config-default", - ), - pytest.param( - { - "default": "local", - "local": ["456", "123", "798"], - }, - "slurm", - [], - id="local-config-slurm", - ), - pytest.param( - { - "default": "local", - "local": ["456", "123", "798"], - }, - "unknown", - [], - id="local-config-unknown", - marks=pytest.mark.xfail( - reason="Unknown solver configuration: 'unknown'", - raises=KeyError, - strict=True, + @pytest.mark.unit_test + @pytest.mark.parametrize( + "config_map, solver, expected", + [ + pytest.param( + {"default": "local", "local": {}, "slurm": {}}, + "default", + {}, + id="empty-config", ), - ), - pytest.param( - { - "default": "slurm", - "slurm": ["258", "147", "369"], - }, - "default", - ["147", "258", "369"], - id="slurm-config-default", - ), - pytest.param( - { - "default": "slurm", - "slurm": ["258", "147", "369"], - }, - "local", - [], - id="slurm-config-local", - ), - pytest.param( - { - "default": "slurm", - "slurm": ["258", "147", "369"], - }, - "unknown", - [], - id="slurm-config-unknown", - marks=pytest.mark.xfail( - reason="Unknown solver configuration: 'unknown'", - raises=KeyError, - strict=True, + pytest.param( + { + "default": "local", + "local": {"min": 1, "default": 11, "max": 12}, + }, + "default", + {"min": 1, "default": 11, "max": 12}, + id="local-config-default", ), - ), - pytest.param( - { - "default": "slurm", - "local": ["456", "123", "798"], - "slurm": ["258", "147", "369"], - }, - "local", - ["123", "456", "798"], - id="local+slurm-config-local", - ), - ], -) -def test_service_nb_core( - config: Dict[str, Union[str, List[str]]], - launcher: Literal["default", "local", "slurm", "unknown"], - expected: List[str], -) -> None: - # Prepare the configuration - # the default server version from the configuration file. - # the default server is initialised to local - # Test nb core of a launcher - default = config.get("default", "local") - local = LocalConfig(binaries={k: Path(f"solver-{k}.exe") for k in config.get("local", [])}) - slurm = SlurmConfig(antares_versions_on_remote_server=config.get("slurm", [])) - launcher_config = LauncherConfig( - default=default, - local=local if local else None, - slurm=slurm if slurm else None, - ) - config = Config(launcher=launcher_config) - launcher_service = LauncherService( - config=config, - study_service=Mock(), - job_result_repository=Mock(), - factory_launcher=Mock(), - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - # Fetch the launcher(launcher) nb_cores - launcher_service.get_solver_versions(launcher) - nb_core = launcher_service.config.launcher.get_nb_cores(launcher).to_json() - if launcher in ("local", "default"): - max_cpu = multiprocessing.cpu_count() - default = max(1, max_cpu - 2) - nb_cores_expected = {"defaultValue": default, "max": max_cpu, "min": 1} - else: - nb_cores_expected = {"min": 1, "defaultValue": 22, "max": 24} - # Check the result - assert nb_core == nb_cores_expected - - -@pytest.mark.unit_test -def test_service_kill_job(tmp_path: Path): - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - - launcher_service = LauncherService( - config=Config(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - launcher = "slurm" - job_id = "job_id" - job_result_mock = Mock() - job_result_mock.id = job_id - job_result_mock.study_id = "study_id" - job_result_mock.launcher = launcher - launcher_service.job_result_repository.get.return_value = job_result_mock - launcher_service.launchers = {"slurm": Mock()} - - job_status = launcher_service.kill_job( - job_id=job_id, - params=RequestParameters(user=DEFAULT_ADMIN_USER), + pytest.param( + { + "default": "local", + "local": {"min": 1, "default": 11, "max": 12}, + }, + "slurm", + {}, + id="local-config-slurm", + ), + pytest.param( + { + "default": "local", + "local": {"min": 1, "default": 11, "max": 12}, + }, + "unknown", + {}, + id="local-config-unknown", + marks=pytest.mark.xfail( + reason="Configuration is not available for the 'unknown' launcher", + raises=InvalidConfigurationError, + strict=True, + ), + ), + pytest.param( + { + "default": "slurm", + "slurm": {"min": 4, "default": 8, "max": 16}, + }, + "default", + {"min": 4, "default": 8, "max": 16}, + id="slurm-config-default", + ), + pytest.param( + { + "default": "slurm", + "slurm": {"min": 4, "default": 8, "max": 16}, + }, + "local", + {}, + id="slurm-config-local", + ), + pytest.param( + { + "default": "slurm", + "slurm": {"min": 4, "default": 8, "max": 16}, + }, + "unknown", + {}, + id="slurm-config-unknown", + marks=pytest.mark.xfail( + reason="Configuration is not available for the 'unknown' launcher", + raises=InvalidConfigurationError, + strict=True, + ), + ), + pytest.param( + { + "default": "slurm", + "local": {"min": 1, "default": 11, "max": 12}, + "slurm": {"min": 4, "default": 8, "max": 16}, + }, + "local", + {"min": 1, "default": 11, "max": 12}, + id="local+slurm-config-local", + ), + ], ) + def test_get_nb_cores( + self, + config_map: Dict[str, Union[str, Dict[str, int]]], + solver: Literal["default", "local", "slurm", "unknown"], + expected: Dict[str, int], + ) -> None: + # Prepare the configuration + default = config_map.get("default", "local") + local_nb_cores = config_map.get("local", {}) + slurm_nb_cores = config_map.get("slurm", {}) + launcher_config = LauncherConfig( + default=default, + local=LocalConfig.from_dict({"enable_nb_cores_detection": False, "nb_cores": local_nb_cores}), + slurm=SlurmConfig.from_dict({"enable_nb_cores_detection": False, "nb_cores": slurm_nb_cores}), + ) + launcher_service = LauncherService( + config=Config(launcher=launcher_config), + study_service=Mock(), + job_result_repository=Mock(), + factory_launcher=Mock(), + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - launcher_service.launchers[launcher].kill_job.assert_called_once_with(job_id=job_id) + # Fetch the number of cores + actual = launcher_service.get_nb_cores(solver) + + # Check the result + assert actual == NbCoresConfig(**expected) + + @pytest.mark.unit_test + def test_service_kill_job(self, tmp_path: Path) -> None: + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Config(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + launcher = "slurm" + job_id = "job_id" + job_result_mock = Mock() + job_result_mock.id = job_id + job_result_mock.study_id = "study_id" + job_result_mock.launcher = launcher + launcher_service.job_result_repository.get.return_value = job_result_mock + launcher_service.launchers = {"slurm": Mock()} + + job_status = launcher_service.kill_job( + job_id=job_id, + params=RequestParameters(user=DEFAULT_ADMIN_USER), + ) - assert job_status.job_status == JobStatus.FAILED - launcher_service.job_result_repository.save.assert_called_once_with(job_status) + launcher_service.launchers[launcher].kill_job.assert_called_once_with(job_id=job_id) + assert job_status.job_status == JobStatus.FAILED + launcher_service.job_result_repository.save.assert_called_once_with(job_status) -def test_append_logs(tmp_path: Path): - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + def test_append_logs(self, tmp_path: Path) -> None: + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - launcher_service = LauncherService( - config=Config(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - launcher = "slurm" - job_id = "job_id" - job_result_mock = Mock() - job_result_mock.id = job_id - job_result_mock.study_id = "study_id" - job_result_mock.output_id = None - job_result_mock.launcher = launcher - job_result_mock.logs = [] - launcher_service.job_result_repository.get.return_value = job_result_mock - - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - launcher_service.append_log(job_id, "test", JobLogType.BEFORE) - launcher_service.job_result_repository.save.assert_called_with(job_result_mock) - assert job_result_mock.logs[0].message == "test" - assert job_result_mock.logs[0].job_id == "job_id" - assert job_result_mock.logs[0].log_type == str(JobLogType.BEFORE) - - -def test_get_logs(tmp_path: Path): - study_service = Mock() - launcher_service = LauncherService( - config=Config(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - launcher = "slurm" - job_id = "job_id" - job_result_mock = Mock() - job_result_mock.id = job_id - job_result_mock.study_id = "study_id" - job_result_mock.output_id = None - job_result_mock.launcher = launcher - job_result_mock.logs = [ - JobLog(message="first message", log_type=str(JobLogType.BEFORE)), - JobLog(message="second message", log_type=str(JobLogType.BEFORE)), - JobLog(message="last message", log_type=str(JobLogType.AFTER)), - ] - job_result_mock.launcher_params = '{"archive_output": false}' - - launcher_service.job_result_repository.get.return_value = job_result_mock - slurm_launcher = Mock() - launcher_service.launchers = {"slurm": slurm_launcher} - slurm_launcher.get_log.return_value = "launcher logs" - - logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER)) - assert logs == "first message\nsecond message\nlauncher logs\nlast message" - logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER)) - assert logs == "launcher logs" - - study_service.get_logs.side_effect = ["some sim log", "error log"] - - job_result_mock.output_id = "some id" - logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER)) - assert logs == "first message\nsecond message\nsome sim log\nlast message" - - logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER)) - assert logs == "error log" - - study_service.get_logs.assert_has_calls( - [ - call( - "study_id", - "some id", - job_id, - False, - params=RequestParameters(DEFAULT_ADMIN_USER), - ), - call( - "study_id", - "some id", - job_id, - True, - params=RequestParameters(DEFAULT_ADMIN_USER), - ), + launcher_service = LauncherService( + config=Config(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + launcher = "slurm" + job_id = "job_id" + job_result_mock = Mock() + job_result_mock.id = job_id + job_result_mock.study_id = "study_id" + job_result_mock.output_id = None + job_result_mock.launcher = launcher + job_result_mock.logs = [] + launcher_service.job_result_repository.get.return_value = job_result_mock + + engine = create_engine("sqlite:///:memory:", echo=False) + Base.metadata.create_all(engine) + # noinspection SpellCheckingInspection + DBSessionMiddleware( + None, + custom_engine=engine, + session_args={"autocommit": False, "autoflush": False}, + ) + launcher_service.append_log(job_id, "test", JobLogType.BEFORE) + launcher_service.job_result_repository.save.assert_called_with(job_result_mock) + assert job_result_mock.logs[0].message == "test" + assert job_result_mock.logs[0].job_id == "job_id" + assert job_result_mock.logs[0].log_type == str(JobLogType.BEFORE) + + def test_get_logs(self, tmp_path: Path) -> None: + study_service = Mock() + launcher_service = LauncherService( + config=Config(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + launcher = "slurm" + job_id = "job_id" + job_result_mock = Mock() + job_result_mock.id = job_id + job_result_mock.study_id = "study_id" + job_result_mock.output_id = None + job_result_mock.launcher = launcher + job_result_mock.logs = [ + JobLog(message="first message", log_type=str(JobLogType.BEFORE)), + JobLog(message="second message", log_type=str(JobLogType.BEFORE)), + JobLog(message="last message", log_type=str(JobLogType.AFTER)), ] - ) - + job_result_mock.launcher_params = '{"archive_output": false}' + + launcher_service.job_result_repository.get.return_value = job_result_mock + slurm_launcher = Mock() + launcher_service.launchers = {"slurm": slurm_launcher} + slurm_launcher.get_log.return_value = "launcher logs" + + logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER)) + assert logs == "first message\nsecond message\nlauncher logs\nlast message" + logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER)) + assert logs == "launcher logs" + + study_service.get_logs.side_effect = ["some sim log", "error log"] + + job_result_mock.output_id = "some id" + logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER)) + assert logs == "first message\nsecond message\nsome sim log\nlast message" + + logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER)) + assert logs == "error log" + + study_service.get_logs.assert_has_calls( + [ + call( + "study_id", + "some id", + job_id, + False, + params=RequestParameters(DEFAULT_ADMIN_USER), + ), + call( + "study_id", + "some id", + job_id, + True, + params=RequestParameters(DEFAULT_ADMIN_USER), + ), + ] + ) -def test_manage_output(tmp_path: Path): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) + def test_manage_output(self, tmp_path: Path) -> None: + engine = create_engine("sqlite:///:memory:", echo=False) + Base.metadata.create_all(engine) + # noinspection SpellCheckingInspection + DBSessionMiddleware( + None, + custom_engine=engine, + session_args={"autocommit": False, "autoflush": False}, + ) - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - - launcher_service = LauncherService( - config=Mock(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Mock(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - output_path = tmp_path / "output" - zipped_output_path = tmp_path / "zipped_output" - os.mkdir(output_path) - os.mkdir(zipped_output_path) - new_output_path = output_path / "new_output" - os.mkdir(new_output_path) - (new_output_path / "log").touch() - (new_output_path / "data").touch() - additional_log = tmp_path / "output.log" - additional_log.write_text("some log") - new_output_zipped_path = zipped_output_path / "test.zip" - with ZipFile(new_output_zipped_path, "w", ZIP_DEFLATED) as output_data: - output_data.writestr("some output", "0\n1") - job_id = "job_id" - zipped_job_id = "zipped_job_id" - study_id = "study_id" - launcher_service.job_result_repository.get.side_effect = [ - None, - JobResult(id=job_id, study_id=study_id), - JobResult(id=job_id, study_id=study_id, output_id="some id"), - JobResult(id=zipped_job_id, study_id=study_id), - JobResult( - id=job_id, - study_id=study_id, - ), - JobResult( - id=job_id, - study_id=study_id, - launcher_params=json.dumps( - { - "archive_output": False, - f"{LAUNCHER_PARAM_NAME_SUFFIX}": "hello", - } + output_path = tmp_path / "output" + zipped_output_path = tmp_path / "zipped_output" + os.mkdir(output_path) + os.mkdir(zipped_output_path) + new_output_path = output_path / "new_output" + os.mkdir(new_output_path) + (new_output_path / "log").touch() + (new_output_path / "data").touch() + additional_log = tmp_path / "output.log" + additional_log.write_text("some log") + new_output_zipped_path = zipped_output_path / "test.zip" + with ZipFile(new_output_zipped_path, "w", ZIP_DEFLATED) as output_data: + output_data.writestr("some output", "0\n1") + job_id = "job_id" + zipped_job_id = "zipped_job_id" + study_id = "study_id" + launcher_service.job_result_repository.get.side_effect = [ + None, + JobResult(id=job_id, study_id=study_id), + JobResult(id=job_id, study_id=study_id, output_id="some id"), + JobResult(id=zipped_job_id, study_id=study_id), + JobResult( + id=job_id, + study_id=study_id, + ), + JobResult( + id=job_id, + study_id=study_id, + launcher_params=json.dumps( + { + "archive_output": False, + f"{LAUNCHER_PARAM_NAME_SUFFIX}": "hello", + } + ), ), - ), - ] - with pytest.raises(JobNotFound): - launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) - - launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) - assert not launcher_service._get_job_output_fallback_path(job_id).exists() - launcher_service.study_service.import_output.assert_called() - - launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) - launcher_service.study_service.export_output.assert_called() - - launcher_service._import_output( - zipped_job_id, - zipped_output_path, - { - "out.log": [additional_log], - "antares-out": [additional_log], - "antares-err": [additional_log], - }, - ) - launcher_service.study_service.save_logs.has_calls( - [ - call(study_id, zipped_job_id, "out.log", "some log"), - call(study_id, zipped_job_id, "out", "some log"), - call(study_id, zipped_job_id, "err", "some log"), ] - ) + with pytest.raises(JobNotFound): + launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) + + launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) + assert not launcher_service._get_job_output_fallback_path(job_id).exists() + launcher_service.study_service.import_output.assert_called() - launcher_service.study_service.import_output.side_effect = [ - StudyNotFoundError(""), - StudyNotFoundError(""), - ] - - assert launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) is None - - (new_output_path / "info.antares-output").write_text(f"[general]\nmode=eco\nname=foo\ntimestamp={time.time()}") - output_name = launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) - assert output_name is not None - assert output_name.endswith("-hello") - assert launcher_service._get_job_output_fallback_path(job_id).exists() - assert (launcher_service._get_job_output_fallback_path(job_id) / output_name / "out.log").exists() - - launcher_service.job_result_repository.get.reset_mock() - launcher_service.job_result_repository.get.side_effect = [ - None, - JobResult(id=job_id, study_id=study_id, output_id=output_name), - ] - with pytest.raises(JobNotFound): launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) + launcher_service.study_service.export_output.assert_called() - study_service.get_study.reset_mock() - study_service.get_study.side_effect = StudyNotFoundError("") + launcher_service._import_output( + zipped_job_id, + zipped_output_path, + { + "out.log": [additional_log], + "antares-out": [additional_log], + "antares-err": [additional_log], + }, + ) + launcher_service.study_service.save_logs.has_calls( + [ + call(study_id, zipped_job_id, "out.log", "some log"), + call(study_id, zipped_job_id, "out", "some log"), + call(study_id, zipped_job_id, "err", "some log"), + ] + ) - export_file = FileDownloadDTO(id="a", name="a", filename="a", ready=True) - launcher_service.file_transfer_manager.request_download.return_value = FileDownload( - id="a", name="a", filename="a", ready=True, path="a" - ) - launcher_service.task_service.add_task.return_value = "some id" + launcher_service.study_service.import_output.side_effect = [ + StudyNotFoundError(""), + StudyNotFoundError(""), + ] - assert launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) == FileDownloadTaskDTO( - task="some id", file=export_file - ) + assert launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) is None - launcher_service.remove_job(job_id, RequestParameters(user=DEFAULT_ADMIN_USER)) - assert not launcher_service._get_job_output_fallback_path(job_id).exists() + (new_output_path / "info.antares-output").write_text(f"[general]\nmode=eco\nname=foo\ntimestamp={time.time()}") + output_name = launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) + assert output_name is not None + assert output_name.endswith("-hello") + assert launcher_service._get_job_output_fallback_path(job_id).exists() + assert (launcher_service._get_job_output_fallback_path(job_id) / output_name / "out.log").exists() + launcher_service.job_result_repository.get.reset_mock() + launcher_service.job_result_repository.get.side_effect = [ + None, + JobResult(id=job_id, study_id=study_id, output_id=output_name), + ] + with pytest.raises(JobNotFound): + launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) -def test_save_stats(tmp_path: Path) -> None: - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + study_service.get_study.reset_mock() + study_service.get_study.side_effect = StudyNotFoundError("") - launcher_service = LauncherService( - config=Mock(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) + export_file = FileDownloadDTO(id="a", name="a", filename="a", ready=True) + launcher_service.file_transfer_manager.request_download.return_value = FileDownload( + id="a", name="a", filename="a", ready=True, path="a" + ) + launcher_service.task_service.add_task.return_value = "some id" - job_id = "job_id" - study_id = "study_id" - job_result = JobResult(id=job_id, study_id=study_id, job_status=JobStatus.SUCCESS) - - output_path = tmp_path / "some-output" - output_path.mkdir() - - launcher_service._save_solver_stats(job_result, output_path) - launcher_service.job_result_repository.save.assert_not_called() - - expected_saved_stats = """#item duration_ms NbOccurences -mc_years 216328 1 -study_loading 4304 1 -survey_report 158 1 -total 244581 1 -tsgen_hydro 1683 1 -tsgen_load 2702 1 -tsgen_solar 21606 1 -tsgen_thermal 407 2 -tsgen_wind 2500 1 - """ - (output_path / EXECUTION_INFO_FILE).write_text(expected_saved_stats) - - launcher_service._save_solver_stats(job_result, output_path) - launcher_service.job_result_repository.save.assert_called_with( - JobResult( - id=job_id, - study_id=study_id, - job_status=JobStatus.SUCCESS, - solver_stats=expected_saved_stats, + assert launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) == FileDownloadTaskDTO( + task="some id", file=export_file ) - ) - zip_file = tmp_path / "test.zip" - with ZipFile(zip_file, "w", ZIP_DEFLATED) as output_data: - output_data.writestr(EXECUTION_INFO_FILE, "0\n1") + launcher_service.remove_job(job_id, RequestParameters(user=DEFAULT_ADMIN_USER)) + assert not launcher_service._get_job_output_fallback_path(job_id).exists() + + def test_save_solver_stats(self, tmp_path: Path) -> None: + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Mock(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - launcher_service._save_solver_stats(job_result, zip_file) - launcher_service.job_result_repository.save.assert_called_with( - JobResult( - id=job_id, - study_id=study_id, - job_status=JobStatus.SUCCESS, - solver_stats="0\n1", + job_id = "job_id" + study_id = "study_id" + job_result = JobResult(id=job_id, study_id=study_id, job_status=JobStatus.SUCCESS) + + output_path = tmp_path / "some-output" + output_path.mkdir() + + launcher_service._save_solver_stats(job_result, output_path) + launcher_service.job_result_repository.save.assert_not_called() + + expected_saved_stats = """#item duration_ms NbOccurences + mc_years 216328 1 + study_loading 4304 1 + survey_report 158 1 + total 244581 1 + tsgen_hydro 1683 1 + tsgen_load 2702 1 + tsgen_solar 21606 1 + tsgen_thermal 407 2 + tsgen_wind 2500 1 + """ + (output_path / EXECUTION_INFO_FILE).write_text(expected_saved_stats) + + launcher_service._save_solver_stats(job_result, output_path) + launcher_service.job_result_repository.save.assert_called_with( + JobResult( + id=job_id, + study_id=study_id, + job_status=JobStatus.SUCCESS, + solver_stats=expected_saved_stats, + ) ) - ) + zip_file = tmp_path / "test.zip" + with ZipFile(zip_file, "w", ZIP_DEFLATED) as output_data: + output_data.writestr(EXECUTION_INFO_FILE, "0\n1") + + launcher_service._save_solver_stats(job_result, zip_file) + launcher_service.job_result_repository.save.assert_called_with( + JobResult( + id=job_id, + study_id=study_id, + job_status=JobStatus.SUCCESS, + solver_stats="0\n1", + ) + ) -def test_get_load(tmp_path: Path): - study_service = Mock() - job_repository = Mock() + def test_get_load(self, tmp_path: Path) -> None: + study_service = Mock() + job_repository = Mock() - launcher_service = LauncherService( - config=Mock( + config = Config( storage=StorageConfig(tmp_dir=tmp_path), - launcher=LauncherConfig(local=LocalConfig(), slurm=SlurmConfig(default_n_cpu=12)), - ), - study_service=study_service, - job_result_repository=job_repository, - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - - job_repository.get_running.side_effect = [ - [], - [], - [ - Mock( - spec=JobResult, - launcher="slurm", - launcher_params=None, - ), - ], - [ - Mock( - spec=JobResult, - launcher="slurm", - launcher_params='{"nb_cpu": 18}', + launcher=LauncherConfig( + local=LocalConfig(), + slurm=SlurmConfig(nb_cores=NbCoresConfig(min=1, default=12, max=24)), ), - Mock( - spec=JobResult, - launcher="local", - launcher_params=None, - ), - Mock( - spec=JobResult, - launcher="slurm", - launcher_params=None, - ), - Mock( - spec=JobResult, - launcher="local", - launcher_params='{"nb_cpu": 7}', - ), - ], - ] - - with pytest.raises(NotImplementedError): - launcher_service.get_load(from_cluster=True) - - load = launcher_service.get_load() - assert load["slurm"] == 0 - assert load["local"] == 0 - load = launcher_service.get_load() - assert load["slurm"] == 12.0 / 64 - assert load["local"] == 0 - load = launcher_service.get_load() - assert load["slurm"] == 30.0 / 64 - assert load["local"] == 8.0 / os.cpu_count() + ) + launcher_service = LauncherService( + config=config, + study_service=study_service, + job_result_repository=job_repository, + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + + job_repository.get_running.side_effect = [ + # call #1 + [], + # call #2 + [], + # call #3 + [ + Mock( + spec=JobResult, + launcher="slurm", + launcher_params=None, + ), + ], + # call #4 + [ + Mock( + spec=JobResult, + launcher="slurm", + launcher_params='{"nb_cpu": 18}', + ), + Mock( + spec=JobResult, + launcher="local", + launcher_params=None, + ), + Mock( + spec=JobResult, + launcher="slurm", + launcher_params=None, + ), + Mock( + spec=JobResult, + launcher="local", + launcher_params='{"nb_cpu": 7}', + ), + ], + ] + + # call #1 + with pytest.raises(NotImplementedError): + launcher_service.get_load(from_cluster=True) + + # call #2 + load = launcher_service.get_load() + assert load["slurm"] == 0 + assert load["local"] == 0 + + # call #3 + load = launcher_service.get_load() + slurm_config = config.launcher.slurm + assert load["slurm"] == slurm_config.nb_cores.default / slurm_config.max_cores + assert load["local"] == 0 + + # call #4 + load = launcher_service.get_load() + local_config = config.launcher.local + assert load["slurm"] == (18 + slurm_config.nb_cores.default) / slurm_config.max_cores + assert load["local"] == (7 + local_config.nb_cores.default) / local_config.nb_cores.max diff --git a/tests/launcher/test_slurm_launcher.py b/tests/launcher/test_slurm_launcher.py index 7820abcdea..dfb6846e89 100644 --- a/tests/launcher/test_slurm_launcher.py +++ b/tests/launcher/test_slurm_launcher.py @@ -10,11 +10,9 @@ from antareslauncher.data_repo.data_repo_tinydb import DataRepoTinydb from antareslauncher.main import MainParameters from antareslauncher.study_dto import StudyDTO -from sqlalchemy import create_engine +from sqlalchemy.orm import Session # type: ignore -from antarest.core.config import Config, LauncherConfig, SlurmConfig -from antarest.core.persistence import Base -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware +from antarest.core.config import Config, LauncherConfig, NbCoresConfig, SlurmConfig from antarest.launcher.adapters.abstractlauncher import LauncherInitException from antarest.launcher.adapters.slurm_launcher.slurm_launcher import ( LOG_DIR_NAME, @@ -24,32 +22,34 @@ SlurmLauncher, VersionNotSupportedError, ) -from antarest.launcher.model import JobStatus, LauncherParametersDTO +from antarest.launcher.model import JobStatus, LauncherParametersDTO, XpansionParametersDTO from antarest.tools.admin_lib import clean_locks_from_config @pytest.fixture def launcher_config(tmp_path: Path) -> Config: - return Config( - launcher=LauncherConfig( - slurm=SlurmConfig( - local_workspace=tmp_path, - default_json_db_name="default_json_db_name", - slurm_script_path="slurm_script_path", - antares_versions_on_remote_server=["42", "45"], - username="username", - hostname="hostname", - port=42, - private_key_file=Path("private_key_file"), - key_password="key_password", - password="password", - ) - ) - ) + data = { + "local_workspace": tmp_path, + "username": "john", + "hostname": "slurm-001", + "port": 22, + "private_key_file": Path("/home/john/.ssh/id_rsa"), + "key_password": "password", + "password": "password", + "default_wait_time": 10, + "default_time_limit": 20, + "default_json_db_name": "antares.db", + "slurm_script_path": "/path/to/slurm/launcher.sh", + "max_cores": 32, + "antares_versions_on_remote_server": ["840", "850", "860"], + "enable_nb_cores_detection": False, + "nb_cores": {"min": 1, "default": 34, "max": 36}, + } + return Config(launcher=LauncherConfig(slurm=SlurmConfig.from_dict(data))) @pytest.mark.unit_test -def test_slurm_launcher__launcher_init_exception(): +def test_slurm_launcher__launcher_init_exception() -> None: with pytest.raises( LauncherInitException, match="Missing parameter 'launcher.slurm'", @@ -63,13 +63,13 @@ def test_slurm_launcher__launcher_init_exception(): @pytest.mark.unit_test -def test_init_slurm_launcher_arguments(tmp_path: Path): +def test_init_slurm_launcher_arguments(tmp_path: Path) -> None: config = Config( launcher=LauncherConfig( slurm=SlurmConfig( default_wait_time=42, default_time_limit=43, - default_n_cpu=44, + nb_cores=NbCoresConfig(min=1, default=30, max=36), local_workspace=tmp_path, ) ) @@ -88,13 +88,15 @@ def test_init_slurm_launcher_arguments(tmp_path: Path): assert not arguments.xpansion_mode assert not arguments.version assert not arguments.post_processing - assert Path(arguments.studies_in) == config.launcher.slurm.local_workspace / "STUDIES_IN" - assert Path(arguments.output_dir) == config.launcher.slurm.local_workspace / "OUTPUT" - assert Path(arguments.log_dir) == config.launcher.slurm.local_workspace / "LOGS" + slurm_config = config.launcher.slurm + assert slurm_config is not None + assert Path(arguments.studies_in) == slurm_config.local_workspace / "STUDIES_IN" + assert Path(arguments.output_dir) == slurm_config.local_workspace / "OUTPUT" + assert Path(arguments.log_dir) == slurm_config.local_workspace / "LOGS" @pytest.mark.unit_test -def test_init_slurm_launcher_parameters(tmp_path: Path): +def test_init_slurm_launcher_parameters(tmp_path: Path) -> None: config = Config( launcher=LauncherConfig( slurm=SlurmConfig( @@ -115,23 +117,25 @@ def test_init_slurm_launcher_parameters(tmp_path: Path): slurm_launcher = SlurmLauncher(config=config, callbacks=Mock(), event_bus=Mock(), cache=Mock()) main_parameters = slurm_launcher._init_launcher_parameters() - assert main_parameters.json_dir == config.launcher.slurm.local_workspace - assert main_parameters.default_json_db_name == config.launcher.slurm.default_json_db_name - assert main_parameters.slurm_script_path == config.launcher.slurm.slurm_script_path - assert main_parameters.antares_versions_on_remote_server == config.launcher.slurm.antares_versions_on_remote_server + slurm_config = config.launcher.slurm + assert slurm_config is not None + assert main_parameters.json_dir == slurm_config.local_workspace + assert main_parameters.default_json_db_name == slurm_config.default_json_db_name + assert main_parameters.slurm_script_path == slurm_config.slurm_script_path + assert main_parameters.antares_versions_on_remote_server == slurm_config.antares_versions_on_remote_server assert main_parameters.default_ssh_dict == { - "username": config.launcher.slurm.username, - "hostname": config.launcher.slurm.hostname, - "port": config.launcher.slurm.port, - "private_key_file": config.launcher.slurm.private_key_file, - "key_password": config.launcher.slurm.key_password, - "password": config.launcher.slurm.password, + "username": slurm_config.username, + "hostname": slurm_config.hostname, + "port": slurm_config.port, + "private_key_file": slurm_config.private_key_file, + "key_password": slurm_config.key_password, + "password": slurm_config.password, } assert main_parameters.db_primary_key == "name" @pytest.mark.unit_test -def test_slurm_launcher_delete_function(tmp_path: str): +def test_slurm_launcher_delete_function(tmp_path: str) -> None: config = Config(launcher=LauncherConfig(slurm=SlurmConfig(local_workspace=Path(tmp_path)))) slurm_launcher = SlurmLauncher( config=config, @@ -155,64 +159,104 @@ def test_slurm_launcher_delete_function(tmp_path: str): assert not file_path.exists() -def test_extra_parameters(launcher_config: Config): +def test_extra_parameters(launcher_config: Config) -> None: + """ + The goal of this unit test is to control the protected method `_check_and_apply_launcher_params`, + which is called by the `SlurmLauncher.run_study` function, in a separate thread. + + The `_check_and_apply_launcher_params` method extract the parameters from the configuration + and populate a `argparse.Namespace` which is used to launch a simulation using Antares Launcher. + + We want to make sure all the parameters are populated correctly. + """ slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), event_bus=Mock(), cache=Mock(), ) - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO()) - assert launcher_params.n_cpu == 1 - assert launcher_params.time_limit == 0 + + apply_params = slurm_launcher._apply_params + launcher_params = apply_params(LauncherParametersDTO()) + slurm_config = slurm_launcher.config.launcher.slurm + assert slurm_config is not None + assert launcher_params.n_cpu == slurm_config.nb_cores.default + assert launcher_params.time_limit == slurm_config.default_time_limit assert not launcher_params.xpansion_mode assert not launcher_params.post_processing - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(nb_cpu=12)) + launcher_params = apply_params(LauncherParametersDTO(other_options="")) + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(other_options="foo\tbar baz ")) + assert launcher_params.other_options == "foo bar baz" + + launcher_params = apply_params(LauncherParametersDTO(other_options="/foo?bar")) + assert launcher_params.other_options == "foobar" + + launcher_params = apply_params(LauncherParametersDTO(nb_cpu=12)) assert launcher_params.n_cpu == 12 - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(nb_cpu=48)) - assert launcher_params.n_cpu == 1 + launcher_params = apply_params(LauncherParametersDTO(nb_cpu=999)) + assert launcher_params.n_cpu == slurm_config.nb_cores.default # out of range - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(time_limit=10)) + launcher_params = apply_params(LauncherParametersDTO(time_limit=10)) assert launcher_params.time_limit == MIN_TIME_LIMIT - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(time_limit=999999999)) + launcher_params = apply_params(LauncherParametersDTO(time_limit=999999999)) assert launcher_params.time_limit == MAX_TIME_LIMIT - 3600 - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(time_limit=99999)) + launcher_params = apply_params(LauncherParametersDTO(time_limit=99999)) assert launcher_params.time_limit == 99999 - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(xpansion=True)) - assert launcher_params.xpansion_mode + launcher_params = apply_params(LauncherParametersDTO(xpansion=False)) + assert launcher_params.xpansion_mode is None + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(xpansion=True)) + assert launcher_params.xpansion_mode == "cpp" + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(xpansion=True, xpansion_r_version=True)) + assert launcher_params.xpansion_mode == "r" + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(xpansion=XpansionParametersDTO(sensitivity_mode=False))) + assert launcher_params.xpansion_mode == "cpp" + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(xpansion=XpansionParametersDTO(sensitivity_mode=True))) + assert launcher_params.xpansion_mode == "cpp" + assert launcher_params.other_options == "xpansion_sensitivity" + + launcher_params = apply_params(LauncherParametersDTO(post_processing=False)) + assert launcher_params.post_processing is False - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(post_processing=True)) - assert launcher_params.post_processing + launcher_params = apply_params(LauncherParametersDTO(post_processing=True)) + assert launcher_params.post_processing is True - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(adequacy_patch={})) - assert launcher_params.post_processing + launcher_params = apply_params(LauncherParametersDTO(adequacy_patch={})) + assert launcher_params.post_processing is True # noinspection PyUnresolvedReferences @pytest.mark.parametrize( - "version, job_status", - [(42, JobStatus.RUNNING), (99, JobStatus.FAILED), (45, JobStatus.FAILED)], + "version, launcher_called, job_status", + [ + (840, True, JobStatus.RUNNING), + (860, False, JobStatus.FAILED), + pytest.param( + 999, False, JobStatus.FAILED, marks=pytest.mark.xfail(raises=VersionNotSupportedError, strict=True) + ), + ], ) @pytest.mark.unit_test def test_run_study( - tmp_path: Path, launcher_config: Config, version: int, + launcher_called: bool, job_status: JobStatus, -): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) +) -> None: slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), @@ -231,7 +275,8 @@ def test_run_study( job_id = str(uuid.uuid4()) study_dir = argument.studies_in / job_id study_dir.mkdir(parents=True) - (study_dir / "study.antares").write_text( + study_antares_path = study_dir.joinpath("study.antares") + study_antares_path.write_text( textwrap.dedent( """\ [antares] @@ -242,22 +287,20 @@ def test_run_study( # noinspection PyUnusedLocal def call_launcher_mock(arguments: Namespace, parameters: MainParameters): - if version != 45: + if launcher_called: slurm_launcher.data_repo_tinydb.save_study(StudyDTO(job_id)) slurm_launcher._call_launcher = call_launcher_mock - if version == 99: - with pytest.raises(VersionNotSupportedError): - slurm_launcher._run_study(study_uuid, job_id, LauncherParametersDTO(), str(version)) - else: - slurm_launcher._run_study(study_uuid, job_id, LauncherParametersDTO(), str(version)) + # When the launcher is called + slurm_launcher._run_study(study_uuid, job_id, LauncherParametersDTO(), str(version)) + # Check the results assert ( version not in launcher_config.launcher.slurm.antares_versions_on_remote_server - or f"solver_version = {version}" in (study_dir / "study.antares").read_text(encoding="utf-8") + or f"solver_version = {version}" in study_antares_path.read_text(encoding="utf-8") ) - # slurm_launcher._clean_local_workspace.assert_called_once() + slurm_launcher.callbacks.export_study.assert_called_once() slurm_launcher.callbacks.update_status.assert_called_once_with(ANY, job_status, ANY, None) if job_status == JobStatus.RUNNING: @@ -266,7 +309,7 @@ def call_launcher_mock(arguments: Namespace, parameters: MainParameters): @pytest.mark.unit_test -def test_check_state(tmp_path: Path, launcher_config: Config): +def test_check_state(tmp_path: Path, launcher_config: Config) -> None: slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), @@ -308,16 +351,7 @@ def test_check_state(tmp_path: Path, launcher_config: Config): @pytest.mark.unit_test -def test_clean_local_workspace(tmp_path: Path, launcher_config: Config): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - +def test_clean_local_workspace(tmp_path: Path, launcher_config: Config) -> None: slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), @@ -325,7 +359,6 @@ def test_clean_local_workspace(tmp_path: Path, launcher_config: Config): use_private_workspace=False, cache=Mock(), ) - (launcher_config.launcher.slurm.local_workspace / "machin.txt").touch() assert os.listdir(launcher_config.launcher.slurm.local_workspace) @@ -335,7 +368,7 @@ def test_clean_local_workspace(tmp_path: Path, launcher_config: Config): # noinspection PyUnresolvedReferences @pytest.mark.unit_test -def test_import_study_output(launcher_config, tmp_path): +def test_import_study_output(launcher_config, tmp_path) -> None: slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), @@ -399,7 +432,7 @@ def test_kill_job( run_with_mock, tmp_path: Path, launcher_config: Config, -): +) -> None: launch_id = "launch_id" mock_study = Mock() mock_study.name = launch_id @@ -419,35 +452,36 @@ def test_kill_job( slurm_launcher.kill_job(job_id=launch_id) + slurm_config = launcher_config.launcher.slurm launcher_arguments = Namespace( antares_version=0, check_queue=False, - job_id_to_kill=42, + job_id_to_kill=mock_study.job_id, json_ssh_config=None, log_dir=str(tmp_path / "LOGS"), - n_cpu=1, + n_cpu=slurm_config.nb_cores.default, output_dir=str(tmp_path / "OUTPUT"), post_processing=False, studies_in=str(tmp_path / "STUDIES_IN"), - time_limit=0, + time_limit=slurm_config.default_time_limit, version=False, wait_mode=False, - wait_time=0, + wait_time=slurm_config.default_wait_time, xpansion_mode=None, other_options=None, ) launcher_parameters = MainParameters( json_dir=Path(tmp_path), - default_json_db_name="default_json_db_name", - slurm_script_path="slurm_script_path", - antares_versions_on_remote_server=["42", "45"], + default_json_db_name=slurm_config.default_json_db_name, + slurm_script_path=slurm_config.slurm_script_path, + antares_versions_on_remote_server=slurm_config.antares_versions_on_remote_server, default_ssh_dict={ - "username": "username", - "hostname": "hostname", - "port": 42, - "private_key_file": Path("private_key_file"), - "key_password": "key_password", - "password": "password", + "username": slurm_config.username, + "hostname": slurm_config.hostname, + "port": slurm_config.port, + "private_key_file": slurm_config.private_key_file, + "key_password": slurm_config.key_password, + "password": slurm_config.password, }, db_primary_key="name", ) @@ -456,7 +490,7 @@ def test_kill_job( @patch("antarest.launcher.adapters.slurm_launcher.slurm_launcher.run_with") -def test_launcher_workspace_init(run_with_mock, tmp_path: Path, launcher_config: Config): +def test_launcher_workspace_init(run_with_mock, tmp_path: Path, launcher_config: Config) -> None: callbacks = Mock() (tmp_path / LOG_DIR_NAME).mkdir() @@ -474,11 +508,7 @@ def test_launcher_workspace_init(run_with_mock, tmp_path: Path, launcher_config: clean_locks_from_config(launcher_config) assert not (workspaces[0] / WORKSPACE_LOCK_FILE_NAME).exists() - slurm_launcher.data_repo_tinydb.save_study( - StudyDTO( - path="somepath", - ) - ) + slurm_launcher.data_repo_tinydb.save_study(StudyDTO(path="some_path")) run_with_mock.assert_not_called() # will use existing private workspace