diff --git a/antarest/core/config.py b/antarest/core/config.py index 4c232fa475..b48be8ded0 100644 --- a/antarest/core/config.py +++ b/antarest/core/config.py @@ -1,16 +1,14 @@ -import logging +import multiprocessing import tempfile -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import yaml from antarest.core.model import JSON from antarest.core.roles import RoleType -logger = logging.getLogger(__name__) - @dataclass(frozen=True) class ExternalAuthConfig: @@ -23,13 +21,16 @@ class ExternalAuthConfig: add_ext_groups: bool = False group_mapping: Dict[str, str] = field(default_factory=dict) - @staticmethod - def from_dict(data: JSON) -> "ExternalAuthConfig": - return ExternalAuthConfig( - url=data.get("url", None), - default_group_role=RoleType(data.get("default_group_role", RoleType.READER.value)), - add_ext_groups=data.get("add_ext_groups", False), - group_mapping=data.get("group_mapping", {}), + @classmethod + def from_dict(cls, data: JSON) -> "ExternalAuthConfig": + defaults = cls() + return cls( + url=data.get("url", defaults.url), + default_group_role=( + RoleType(data["default_group_role"]) if "default_group_role" in data else defaults.default_group_role + ), + add_ext_groups=data.get("add_ext_groups", defaults.add_ext_groups), + group_mapping=data.get("group_mapping", defaults.group_mapping), ) @@ -44,13 +45,18 @@ class SecurityConfig: disabled: bool = False external_auth: ExternalAuthConfig = ExternalAuthConfig() - @staticmethod - def from_dict(data: JSON) -> "SecurityConfig": - return SecurityConfig( - jwt_key=data.get("jwt", {}).get("key", ""), - admin_pwd=data.get("login", {}).get("admin", {}).get("pwd", ""), - disabled=data.get("disabled", False), - external_auth=ExternalAuthConfig.from_dict(data.get("external_auth", {})), + @classmethod + def from_dict(cls, data: JSON) -> "SecurityConfig": + defaults = cls() + return cls( + jwt_key=data.get("jwt", {}).get("key", defaults.jwt_key), + admin_pwd=data.get("login", {}).get("admin", {}).get("pwd", defaults.admin_pwd), + disabled=data.get("disabled", defaults.disabled), + external_auth=( + ExternalAuthConfig.from_dict(data["external_auth"]) + if "external_auth" in data + else defaults.external_auth + ), ) @@ -65,13 +71,14 @@ class WorkspaceConfig: groups: List[str] = field(default_factory=lambda: []) path: Path = Path() - @staticmethod - def from_dict(data: JSON) -> "WorkspaceConfig": - return WorkspaceConfig( - path=Path(data["path"]), - groups=data.get("groups", []), - filter_in=data.get("filter_in", [".*"]), - filter_out=data.get("filter_out", []), + @classmethod + def from_dict(cls, data: JSON) -> "WorkspaceConfig": + defaults = cls() + return cls( + filter_in=data.get("filter_in", defaults.filter_in), + filter_out=data.get("filter_out", defaults.filter_out), + groups=data.get("groups", defaults.groups), + path=Path(data["path"]) if "path" in data else defaults.path, ) @@ -91,18 +98,19 @@ class DbConfig: pool_size: int = 5 pool_use_lifo: bool = False - @staticmethod - def from_dict(data: JSON) -> "DbConfig": - return DbConfig( - db_admin_url=data.get("admin_url", None), - db_url=data.get("url", ""), - db_connect_timeout=data.get("db_connect_timeout", 10), - pool_recycle=data.get("pool_recycle", None), - pool_pre_ping=data.get("pool_pre_ping", False), - pool_use_null=data.get("pool_use_null", False), - pool_max_overflow=data.get("pool_max_overflow", 10), - pool_size=data.get("pool_size", 5), - pool_use_lifo=data.get("pool_use_lifo", False), + @classmethod + def from_dict(cls, data: JSON) -> "DbConfig": + defaults = cls() + return cls( + db_admin_url=data.get("admin_url", defaults.db_admin_url), + db_url=data.get("url", defaults.db_url), + db_connect_timeout=data.get("db_connect_timeout", defaults.db_connect_timeout), + pool_recycle=data.get("pool_recycle", defaults.pool_recycle), + pool_pre_ping=data.get("pool_pre_ping", defaults.pool_pre_ping), + pool_use_null=data.get("pool_use_null", defaults.pool_use_null), + pool_max_overflow=data.get("pool_max_overflow", defaults.pool_max_overflow), + pool_size=data.get("pool_size", defaults.pool_size), + pool_use_lifo=data.get("pool_use_lifo", defaults.pool_use_lifo), ) @@ -115,7 +123,7 @@ class StorageConfig: matrixstore: Path = Path("./matrixstore") archive_dir: Path = Path("./archives") tmp_dir: Path = Path(tempfile.gettempdir()) - workspaces: Dict[str, WorkspaceConfig] = field(default_factory=lambda: {}) + workspaces: Dict[str, WorkspaceConfig] = field(default_factory=dict) allow_deletion: bool = False watcher_lock: bool = True watcher_lock_delay: int = 10 @@ -127,39 +135,112 @@ class StorageConfig: auto_archive_sleeping_time: int = 3600 auto_archive_max_parallel: int = 5 - @staticmethod - def from_dict(data: JSON) -> "StorageConfig": - return StorageConfig( - tmp_dir=Path(data.get("tmp_dir", tempfile.gettempdir())), - matrixstore=Path(data["matrixstore"]), - workspaces={n: WorkspaceConfig.from_dict(w) for n, w in data["workspaces"].items()}, - allow_deletion=data.get("allow_deletion", False), - archive_dir=Path(data["archive_dir"]), - watcher_lock=data.get("watcher_lock", True), - watcher_lock_delay=data.get("watcher_lock_delay", 10), - download_default_expiration_timeout_minutes=data.get("download_default_expiration_timeout_minutes", 1440), - matrix_gc_sleeping_time=data.get("matrix_gc_sleeping_time", 3600), - matrix_gc_dry_run=data.get("matrix_gc_dry_run", False), - auto_archive_threshold_days=data.get("auto_archive_threshold_days", 60), - auto_archive_dry_run=data.get("auto_archive_dry_run", False), - auto_archive_sleeping_time=data.get("auto_archive_sleeping_time", 3600), - auto_archive_max_parallel=data.get("auto_archive_max_parallel", 5), + @classmethod + def from_dict(cls, data: JSON) -> "StorageConfig": + defaults = cls() + workspaces = ( + {key: WorkspaceConfig.from_dict(value) for key, value in data["workspaces"].items()} + if "workspaces" in data + else defaults.workspaces + ) + return cls( + matrixstore=Path(data["matrixstore"]) if "matrixstore" in data else defaults.matrixstore, + archive_dir=Path(data["archive_dir"]) if "archive_dir" in data else defaults.archive_dir, + tmp_dir=Path(data["tmp_dir"]) if "tmp_dir" in data else defaults.tmp_dir, + workspaces=workspaces, + allow_deletion=data.get("allow_deletion", defaults.allow_deletion), + watcher_lock=data.get("watcher_lock", defaults.watcher_lock), + watcher_lock_delay=data.get("watcher_lock_delay", defaults.watcher_lock_delay), + download_default_expiration_timeout_minutes=( + data.get( + "download_default_expiration_timeout_minutes", + defaults.download_default_expiration_timeout_minutes, + ) + ), + matrix_gc_sleeping_time=data.get("matrix_gc_sleeping_time", defaults.matrix_gc_sleeping_time), + matrix_gc_dry_run=data.get("matrix_gc_dry_run", defaults.matrix_gc_dry_run), + auto_archive_threshold_days=data.get("auto_archive_threshold_days", defaults.auto_archive_threshold_days), + auto_archive_dry_run=data.get("auto_archive_dry_run", defaults.auto_archive_dry_run), + auto_archive_sleeping_time=data.get("auto_archive_sleeping_time", defaults.auto_archive_sleeping_time), + auto_archive_max_parallel=data.get("auto_archive_max_parallel", defaults.auto_archive_max_parallel), ) +@dataclass(frozen=True) +class NbCoresConfig: + """ + The NBCoresConfig class is designed to manage the configuration of the number of CPU cores + """ + + min: int = 1 + default: int = 22 + max: int = 24 + + def to_json(self) -> Dict[str, int]: + """ + Retrieves the number of cores parameters, returning a dictionary containing the values "min" + (minimum allowed value), "defaultValue" (default value), and "max" (maximum allowed value) + + Returns: + A dictionary: `{"min": min, "defaultValue": default, "max": max}`. + Because ReactJs Material UI expects "min", "defaultValue" and "max" keys. + """ + return {"min": self.min, "defaultValue": self.default, "max": self.max} + + def __post_init__(self) -> None: + """validation of CPU configuration""" + if 1 <= self.min <= self.default <= self.max: + return + msg = f"Invalid configuration: 1 <= {self.min=} <= {self.default=} <= {self.max=}" + raise ValueError(msg) + + @dataclass(frozen=True) class LocalConfig: + """Sub config object dedicated to launcher module (local)""" + binaries: Dict[str, Path] = field(default_factory=dict) + enable_nb_cores_detection: bool = True + nb_cores: NbCoresConfig = NbCoresConfig() - @staticmethod - def from_dict(data: JSON) -> Optional["LocalConfig"]: - return LocalConfig( - binaries={str(v): Path(p) for v, p in data["binaries"].items()}, + @classmethod + def from_dict(cls, data: JSON) -> "LocalConfig": + """ + Creates an instance of LocalConfig from a data dictionary + Args: + data: Parse config from dict. + Returns: object NbCoresConfig + """ + defaults = cls() + binaries = data.get("binaries", defaults.binaries) + enable_nb_cores_detection = data.get("enable_nb_cores_detection", defaults.enable_nb_cores_detection) + nb_cores = data.get("nb_cores", asdict(defaults.nb_cores)) + if enable_nb_cores_detection: + nb_cores.update(cls._autodetect_nb_cores()) + return cls( + binaries={str(v): Path(p) for v, p in binaries.items()}, + enable_nb_cores_detection=enable_nb_cores_detection, + nb_cores=NbCoresConfig(**nb_cores), ) + @classmethod + def _autodetect_nb_cores(cls) -> Dict[str, int]: + """ + Automatically detects the number of cores available on the user's machine + Returns: Instance of NbCoresConfig + """ + min_cpu = cls.nb_cores.min + max_cpu = multiprocessing.cpu_count() + default = max(min_cpu, max_cpu - 2) + return {"min": min_cpu, "max": max_cpu, "default": default} + @dataclass(frozen=True) class SlurmConfig: + """ + Sub config object dedicated to launcher module (slurm) + """ + local_workspace: Path = Path() username: str = "" hostname: str = "" @@ -169,31 +250,68 @@ class SlurmConfig: password: str = "" default_wait_time: int = 0 default_time_limit: int = 0 - default_n_cpu: int = 1 default_json_db_name: str = "" slurm_script_path: str = "" max_cores: int = 64 antares_versions_on_remote_server: List[str] = field(default_factory=list) + enable_nb_cores_detection: bool = False + nb_cores: NbCoresConfig = NbCoresConfig() - @staticmethod - def from_dict(data: JSON) -> "SlurmConfig": - return SlurmConfig( - local_workspace=Path(data["local_workspace"]), - username=data["username"], - hostname=data["hostname"], - port=data["port"], - private_key_file=data.get("private_key_file", None), - key_password=data.get("key_password", None), - password=data.get("password", None), - default_wait_time=data["default_wait_time"], - default_time_limit=data["default_time_limit"], - default_n_cpu=data["default_n_cpu"], - default_json_db_name=data["default_json_db_name"], - slurm_script_path=data["slurm_script_path"], - antares_versions_on_remote_server=data["antares_versions_on_remote_server"], - max_cores=data.get("max_cores", 64), + @classmethod + def from_dict(cls, data: JSON) -> "SlurmConfig": + """ + Creates an instance of SlurmConfig from a data dictionary + + Args: + data: Parsed config from dict. + Returns: object SlurmConfig + """ + defaults = cls() + enable_nb_cores_detection = data.get("enable_nb_cores_detection", defaults.enable_nb_cores_detection) + nb_cores = data.get("nb_cores", asdict(defaults.nb_cores)) + if "default_n_cpu" in data: + # Use the old way to configure the NB cores for backward compatibility + nb_cores["default"] = int(data["default_n_cpu"]) + nb_cores["min"] = min(nb_cores["min"], nb_cores["default"]) + nb_cores["max"] = max(nb_cores["max"], nb_cores["default"]) + if enable_nb_cores_detection: + nb_cores.update(cls._autodetect_nb_cores()) + return cls( + local_workspace=Path(data.get("local_workspace", defaults.local_workspace)), + username=data.get("username", defaults.username), + hostname=data.get("hostname", defaults.hostname), + port=data.get("port", defaults.port), + private_key_file=data.get("private_key_file", defaults.private_key_file), + key_password=data.get("key_password", defaults.key_password), + password=data.get("password", defaults.password), + default_wait_time=data.get("default_wait_time", defaults.default_wait_time), + default_time_limit=data.get("default_time_limit", defaults.default_time_limit), + default_json_db_name=data.get("default_json_db_name", defaults.default_json_db_name), + slurm_script_path=data.get("slurm_script_path", defaults.slurm_script_path), + antares_versions_on_remote_server=data.get( + "antares_versions_on_remote_server", + defaults.antares_versions_on_remote_server, + ), + max_cores=data.get("max_cores", defaults.max_cores), + enable_nb_cores_detection=enable_nb_cores_detection, + nb_cores=NbCoresConfig(**nb_cores), ) + @classmethod + def _autodetect_nb_cores(cls) -> Dict[str, int]: + raise NotImplementedError("NB Cores auto-detection is not implemented for SLURM server") + + +class InvalidConfigurationError(Exception): + """ + Exception raised when an attempt is made to retrieve the number of cores + of a launcher that doesn't exist in the configuration. + """ + + def __init__(self, launcher: str): + msg = f"Configuration is not available for the '{launcher}' launcher" + super().__init__(msg) + @dataclass(frozen=True) class LauncherConfig: @@ -202,27 +320,53 @@ class LauncherConfig: """ default: str = "local" - local: Optional[LocalConfig] = LocalConfig() - slurm: Optional[SlurmConfig] = SlurmConfig() + local: Optional[LocalConfig] = None + slurm: Optional[SlurmConfig] = None batch_size: int = 9999 - @staticmethod - def from_dict(data: JSON) -> "LauncherConfig": - local: Optional[LocalConfig] = None - if "local" in data: - local = LocalConfig.from_dict(data["local"]) - - slurm: Optional[SlurmConfig] = None - if "slurm" in data: - slurm = SlurmConfig.from_dict(data["slurm"]) - - return LauncherConfig( - default=data.get("default", "local"), + @classmethod + def from_dict(cls, data: JSON) -> "LauncherConfig": + defaults = cls() + default = data.get("default", cls.default) + local = LocalConfig.from_dict(data["local"]) if "local" in data else defaults.local + slurm = SlurmConfig.from_dict(data["slurm"]) if "slurm" in data else defaults.slurm + batch_size = data.get("batch_size", defaults.batch_size) + return cls( + default=default, local=local, slurm=slurm, - batch_size=data.get("batch_size", 9999), + batch_size=batch_size, ) + def __post_init__(self) -> None: + possible = {"local", "slurm"} + if self.default in possible: + return + msg = f"Invalid configuration: {self.default=} must be one of {possible!r}" + raise ValueError(msg) + + def get_nb_cores(self, launcher: str) -> "NbCoresConfig": + """ + Retrieve the number of cores configuration for a given launcher: "local" or "slurm". + If "default" is specified, retrieve the configuration of the default launcher. + + Args: + launcher: type of launcher "local", "slurm" or "default". + + Returns: + Number of cores of the given launcher. + + Raises: + InvalidConfigurationError: Exception raised when an attempt is made to retrieve + the number of cores of a launcher that doesn't exist in the configuration. + """ + config_map = {"local": self.local, "slurm": self.slurm} + config_map["default"] = config_map[self.default] + launcher_config = config_map.get(launcher) + if launcher_config is None: + raise InvalidConfigurationError(launcher) + return launcher_config.nb_cores + @dataclass(frozen=True) class LoggingConfig: @@ -234,14 +378,13 @@ class LoggingConfig: json: bool = False level: str = "INFO" - @staticmethod - def from_dict(data: JSON) -> "LoggingConfig": - logging_config: Dict[str, Any] = data or {} - logfile: Optional[str] = logging_config.get("logfile") - return LoggingConfig( - logfile=Path(logfile) if logfile is not None else None, - json=logging_config.get("json", False), - level=logging_config.get("level", "INFO"), + @classmethod + def from_dict(cls, data: JSON) -> "LoggingConfig": + defaults = cls() + return cls( + logfile=Path(data["logfile"]) if "logfile" in data else defaults.logfile, + json=data.get("json", defaults.json), + level=data.get("level", defaults.level), ) @@ -255,12 +398,13 @@ class RedisConfig: port: int = 6379 password: Optional[str] = None - @staticmethod - def from_dict(data: JSON) -> "RedisConfig": - return RedisConfig( - host=data["host"], - port=data["port"], - password=data.get("password", None), + @classmethod + def from_dict(cls, data: JSON) -> "RedisConfig": + defaults = cls() + return cls( + host=data.get("host", defaults.host), + port=data.get("port", defaults.port), + password=data.get("password", defaults.password), ) @@ -271,9 +415,9 @@ class EventBusConfig: """ # noinspection PyUnusedLocal - @staticmethod - def from_dict(data: JSON) -> "EventBusConfig": - return EventBusConfig() + @classmethod + def from_dict(cls, data: JSON) -> "EventBusConfig": + return cls() @dataclass(frozen=True) @@ -284,10 +428,11 @@ class CacheConfig: checker_delay: float = 0.2 # in seconds - @staticmethod - def from_dict(data: JSON) -> "CacheConfig": - return CacheConfig( - checker_delay=float(data["checker_delay"]) if "checker_delay" in data else 0.2, + @classmethod + def from_dict(cls, data: JSON) -> "CacheConfig": + defaults = cls() + return cls( + checker_delay=data.get("checker_delay", defaults.checker_delay), ) @@ -296,9 +441,13 @@ class RemoteWorkerConfig: name: str queues: List[str] = field(default_factory=list) - @staticmethod - def from_dict(data: JSON) -> "RemoteWorkerConfig": - return RemoteWorkerConfig(name=data["name"], queues=data.get("queues", [])) + @classmethod + def from_dict(cls, data: JSON) -> "RemoteWorkerConfig": + defaults = cls(name="") # `name` is mandatory + return cls( + name=data["name"], + queues=data.get("queues", defaults.queues), + ) @dataclass(frozen=True) @@ -310,16 +459,17 @@ class TaskConfig: max_workers: int = 5 remote_workers: List[RemoteWorkerConfig] = field(default_factory=list) - @staticmethod - def from_dict(data: JSON) -> "TaskConfig": - return TaskConfig( - max_workers=int(data["max_workers"]) if "max_workers" in data else 5, - remote_workers=list( - map( - lambda x: RemoteWorkerConfig.from_dict(x), - data.get("remote_workers", []), - ) - ), + @classmethod + def from_dict(cls, data: JSON) -> "TaskConfig": + defaults = cls() + remote_workers = ( + [RemoteWorkerConfig.from_dict(d) for d in data["remote_workers"]] + if "remote_workers" in data + else defaults.remote_workers + ) + return cls( + max_workers=data.get("max_workers", defaults.max_workers), + remote_workers=remote_workers, ) @@ -332,11 +482,12 @@ class ServerConfig: worker_threadpool_size: int = 5 services: List[str] = field(default_factory=list) - @staticmethod - def from_dict(data: JSON) -> "ServerConfig": - return ServerConfig( - worker_threadpool_size=int(data["worker_threadpool_size"]) if "worker_threadpool_size" in data else 5, - services=data.get("services", []), + @classmethod + def from_dict(cls, data: JSON) -> "ServerConfig": + defaults = cls() + return cls( + worker_threadpool_size=data.get("worker_threadpool_size", defaults.worker_threadpool_size), + services=data.get("services", defaults.services), ) @@ -360,36 +511,27 @@ class Config: tasks: TaskConfig = TaskConfig() root_path: str = "" - @staticmethod - def from_dict(data: JSON, res: Optional[Path] = None) -> "Config": - """ - Parse config from dict. - - Args: - data: dict struct to parse - res: resources path is not present in yaml file. - - Returns: - - """ - return Config( - security=SecurityConfig.from_dict(data.get("security", {})), - storage=StorageConfig.from_dict(data["storage"]), - launcher=LauncherConfig.from_dict(data.get("launcher", {})), - db=DbConfig.from_dict(data["db"]) if "db" in data else DbConfig(), - logging=LoggingConfig.from_dict(data.get("logging", {})), - debug=data.get("debug", False), - resources_path=res or Path(), - root_path=data.get("root_path", ""), - redis=RedisConfig.from_dict(data["redis"]) if "redis" in data else None, - eventbus=EventBusConfig.from_dict(data["eventbus"]) if "eventbus" in data else EventBusConfig(), - cache=CacheConfig.from_dict(data["cache"]) if "cache" in data else CacheConfig(), - tasks=TaskConfig.from_dict(data["tasks"]) if "tasks" in data else TaskConfig(), - server=ServerConfig.from_dict(data["server"]) if "server" in data else ServerConfig(), + @classmethod + def from_dict(cls, data: JSON) -> "Config": + defaults = cls() + return cls( + server=ServerConfig.from_dict(data["server"]) if "server" in data else defaults.server, + security=SecurityConfig.from_dict(data["security"]) if "security" in data else defaults.security, + storage=StorageConfig.from_dict(data["storage"]) if "storage" in data else defaults.storage, + launcher=LauncherConfig.from_dict(data["launcher"]) if "launcher" in data else defaults.launcher, + db=DbConfig.from_dict(data["db"]) if "db" in data else defaults.db, + logging=LoggingConfig.from_dict(data["logging"]) if "logging" in data else defaults.logging, + debug=data.get("debug", defaults.debug), + resources_path=data["resources_path"] if "resources_path" in data else defaults.resources_path, + redis=RedisConfig.from_dict(data["redis"]) if "redis" in data else defaults.redis, + eventbus=EventBusConfig.from_dict(data["eventbus"]) if "eventbus" in data else defaults.eventbus, + cache=CacheConfig.from_dict(data["cache"]) if "cache" in data else defaults.cache, + tasks=TaskConfig.from_dict(data["tasks"]) if "tasks" in data else defaults.tasks, + root_path=data.get("root_path", defaults.root_path), ) - @staticmethod - def from_yaml_file(file: Path, res: Optional[Path] = None) -> "Config": + @classmethod + def from_yaml_file(cls, file: Path, res: Optional[Path] = None) -> "Config": """ Parse config from yaml file. @@ -400,5 +542,8 @@ def from_yaml_file(file: Path, res: Optional[Path] = None) -> "Config": Returns: """ - data = yaml.safe_load(open(file)) - return Config.from_dict(data, res) + with open(file) as f: + data = yaml.safe_load(f) + if res is not None: + data["resources_path"] = res + return cls.from_dict(data) diff --git a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py index 926f2b50a3..00283b9ce8 100644 --- a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py +++ b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py @@ -32,7 +32,6 @@ logger = logging.getLogger(__name__) logging.getLogger("paramiko").setLevel("WARN") -MAX_NB_CPU = 24 MAX_TIME_LIMIT = 864000 MIN_TIME_LIMIT = 3600 WORKSPACE_LOCK_FILE_NAME = ".lock" @@ -153,7 +152,7 @@ def _init_launcher_arguments(self, local_workspace: Optional[Path] = None) -> ar main_options_parameters = ParserParameters( default_wait_time=self.slurm_config.default_wait_time, default_time_limit=self.slurm_config.default_time_limit, - default_n_cpu=self.slurm_config.default_n_cpu, + default_n_cpu=self.slurm_config.nb_cores.default, studies_in_dir=str((Path(local_workspace or self.slurm_config.local_workspace) / STUDIES_INPUT_DIR_NAME)), log_dir=str((Path(self.slurm_config.local_workspace) / LOG_DIR_NAME)), finished_dir=str((Path(local_workspace or self.slurm_config.local_workspace) / STUDIES_OUTPUT_DIR_NAME)), @@ -440,7 +439,7 @@ def _run_study( _override_solver_version(study_path, version) append_log(launch_uuid, "Submitting study to slurm launcher") - launcher_args = self._check_and_apply_launcher_params(launcher_params) + launcher_args = self._apply_params(launcher_params) self._call_launcher(launcher_args, self.launcher_params) launch_success = self._check_if_study_is_in_launcher_db(launch_uuid) @@ -481,23 +480,40 @@ def _check_if_study_is_in_launcher_db(self, job_id: str) -> bool: studies = self.data_repo_tinydb.get_list_of_studies() return any(s.name == job_id for s in studies) - def _check_and_apply_launcher_params(self, launcher_params: LauncherParametersDTO) -> argparse.Namespace: + def _apply_params(self, launcher_params: LauncherParametersDTO) -> argparse.Namespace: + """ + Populate a `argparse.Namespace` object with the user parameters. + + Args: + launcher_params: + Contains the launcher parameters selected by the user. + If a parameter is not provided (`None`), the default value should be retrieved + from the configuration. + + Returns: + The `argparse.Namespace` object which is then passed to `antarestlauncher.main.run_with`, + to launch a simulation using Antares Launcher. + """ if launcher_params: launcher_args = deepcopy(self.launcher_args) - other_options = [] + if launcher_params.other_options: - options = re.split("\\s+", launcher_params.other_options) - for opt in options: - other_options.append(re.sub("[^a-zA-Z0-9_,-]", "", opt)) - if launcher_params.xpansion is not None: - launcher_args.xpansion_mode = "r" if launcher_params.xpansion_r_version else "cpp" + options = launcher_params.other_options.split() + other_options = [re.sub("[^a-zA-Z0-9_,-]", "", opt) for opt in options] + else: + other_options = [] + + # launcher_params.xpansion can be an `XpansionParametersDTO`, a bool or `None` + if launcher_params.xpansion: # not None and not False + launcher_args.xpansion_mode = {True: "r", False: "cpp"}[launcher_params.xpansion_r_version] if ( isinstance(launcher_params.xpansion, XpansionParametersDTO) and launcher_params.xpansion.sensitivity_mode ): other_options.append("xpansion_sensitivity") + time_limit = launcher_params.time_limit - if time_limit and isinstance(time_limit, int): + if time_limit is not None: if MIN_TIME_LIMIT > time_limit: logger.warning( f"Invalid slurm launcher time limit ({time_limit})," @@ -512,15 +528,23 @@ def _check_and_apply_launcher_params(self, launcher_params: LauncherParametersDT launcher_args.time_limit = MAX_TIME_LIMIT - 3600 else: launcher_args.time_limit = time_limit + post_processing = launcher_params.post_processing - if isinstance(post_processing, bool): + if post_processing is not None: launcher_args.post_processing = post_processing + nb_cpu = launcher_params.nb_cpu - if nb_cpu and isinstance(nb_cpu, int): - if 0 < nb_cpu <= MAX_NB_CPU: + if nb_cpu is not None: + nb_cores = self.slurm_config.nb_cores + if nb_cores.min <= nb_cpu <= nb_cores.max: launcher_args.n_cpu = nb_cpu else: - logger.warning(f"Invalid slurm launcher nb_cpu ({nb_cpu}), should be between 1 and 24") + logger.warning( + f"Invalid slurm launcher nb_cpu ({nb_cpu})," + f" should be between {nb_cores.min} and {nb_cores.max}" + ) + launcher_args.n_cpu = nb_cores.default + if launcher_params.adequacy_patch is not None: # the adequacy patch can be an empty object launcher_args.post_processing = True diff --git a/antarest/launcher/model.py b/antarest/launcher/model.py index 7a3c615811..a9bf0f6fde 100644 --- a/antarest/launcher/model.py +++ b/antarest/launcher/model.py @@ -17,14 +17,14 @@ class XpansionParametersDTO(BaseModel): class LauncherParametersDTO(BaseModel): - # Warning ! This class must be retrocompatible (that's the reason for the weird bool/XpansionParametersDTO union) + # Warning ! This class must be retro-compatible (that's the reason for the weird bool/XpansionParametersDTO union) # The reason is that it's stored in json format in database and deserialized using the latest class version # If compatibility is to be broken, an (alembic) data migration script should be added adequacy_patch: Optional[Dict[str, Any]] = None nb_cpu: Optional[int] = None post_processing: bool = False - time_limit: Optional[int] = None - xpansion: Union[bool, Optional[XpansionParametersDTO]] = None + time_limit: Optional[int] = None # 3600 <= time_limit < 864000 (10 days) + xpansion: Union[XpansionParametersDTO, bool, None] = None xpansion_r_version: bool = False archive_output: bool = True auto_unzip: bool = True diff --git a/antarest/launcher/service.py b/antarest/launcher/service.py index 340c00c27c..aa081e5462 100644 --- a/antarest/launcher/service.py +++ b/antarest/launcher/service.py @@ -1,4 +1,5 @@ import functools +import json import logging import os import shutil @@ -10,7 +11,7 @@ from fastapi import HTTPException -from antarest.core.config import Config +from antarest.core.config import Config, NbCoresConfig from antarest.core.exceptions import StudyNotFoundError from antarest.core.filetransfer.model import FileDownloadTaskDTO from antarest.core.filetransfer.service import FileTransferManager @@ -99,6 +100,21 @@ def _init_extensions(self) -> Dict[str, ILauncherExtension]: def get_launchers(self) -> List[str]: return list(self.launchers.keys()) + def get_nb_cores(self, launcher: str) -> NbCoresConfig: + """ + Retrieve the configuration of the launcher's nb of cores. + + Args: + launcher: name of the launcher: "default", "slurm" or "local". + + Returns: + Number of cores of the launcher + + Raises: + InvalidConfigurationError: if the launcher configuration is not available + """ + return self.config.launcher.get_nb_cores(launcher) + def _after_export_flat_hooks( self, job_id: str, @@ -586,27 +602,31 @@ def get_load(self, from_cluster: bool = False) -> Dict[str, float]: local_running_jobs.append(job) else: logger.warning(f"Unknown job launcher {job.launcher}") + load = {} - if self.config.launcher.slurm: + + slurm_config = self.config.launcher.slurm + if slurm_config is not None: if from_cluster: - raise NotImplementedError - slurm_used_cpus = functools.reduce( - lambda count, j: count - + ( - LauncherParametersDTO.parse_raw(j.launcher_params or "{}").nb_cpu - or self.config.launcher.slurm.default_n_cpu # type: ignore - ), - slurm_running_jobs, - 0, - ) - load["slurm"] = float(slurm_used_cpus) / self.config.launcher.slurm.max_cores - if self.config.launcher.local: - local_used_cpus = functools.reduce( - lambda count, j: count + (LauncherParametersDTO.parse_raw(j.launcher_params or "{}").nb_cpu or 1), - local_running_jobs, - 0, - ) - load["local"] = float(local_used_cpus) / (os.cpu_count() or 1) + raise NotImplementedError("Cluster load not implemented yet") + default_cpu = slurm_config.nb_cores.default + slurm_used_cpus = 0 + for job in slurm_running_jobs: + obj = json.loads(job.launcher_params) if job.launcher_params else {} + launch_params = LauncherParametersDTO(**obj) + slurm_used_cpus += launch_params.nb_cpu or default_cpu + load["slurm"] = slurm_used_cpus / slurm_config.max_cores + + local_config = self.config.launcher.local + if local_config is not None: + default_cpu = local_config.nb_cores.default + local_used_cpus = 0 + for job in local_running_jobs: + obj = json.loads(job.launcher_params) if job.launcher_params else {} + launch_params = LauncherParametersDTO(**obj) + local_used_cpus += launch_params.nb_cpu or default_cpu + load["local"] = local_used_cpus / local_config.nb_cores.max + return load def get_solver_versions(self, solver: str) -> List[str]: diff --git a/antarest/launcher/web.py b/antarest/launcher/web.py index ffb4cf6ccf..51b3582997 100644 --- a/antarest/launcher/web.py +++ b/antarest/launcher/web.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends, Query from fastapi.exceptions import HTTPException -from antarest.core.config import Config +from antarest.core.config import Config, InvalidConfigurationError from antarest.core.filetransfer.model import FileDownloadTaskDTO from antarest.core.jwt import JWTUser from antarest.core.requests import RequestParameters @@ -230,4 +230,48 @@ def get_solver_versions( raise UnknownSolverConfig(solver) return service.get_solver_versions(solver) + # noinspection SpellCheckingInspection + @bp.get( + "/launcher/nbcores", # We avoid "nb_cores" and "nb-cores" in endpoints + tags=[APITag.launcher], + summary="Retrieving Min, Default, and Max Core Count", + response_model=Dict[str, int], + ) + def get_nb_cores( + launcher: str = Query( + "default", + examples={ + "Default launcher": { + "description": "Min, Default, and Max Core Count", + "value": "default", + }, + "SLURM launcher": { + "description": "Min, Default, and Max Core Count", + "value": "slurm", + }, + "Local launcher": { + "description": "Min, Default, and Max Core Count", + "value": "local", + }, + }, + ) + ) -> Dict[str, int]: + """ + Retrieve the numer of cores of the launcher. + + Args: + - `launcher`: name of the configuration to read: "slurm" or "local". + If "default" is specified, retrieve the configuration of the default launcher. + + Returns: + - "min": min number of cores + - "defaultValue": default number of cores + - "max": max number of cores + """ + logger.info(f"Fetching the number of cores for the '{launcher}' configuration") + try: + return service.config.launcher.get_nb_cores(launcher).to_json() + except InvalidConfigurationError: + raise UnknownSolverConfig(launcher) + return bp diff --git a/resources/application.yaml b/resources/application.yaml index 75a7b8605e..a85357634f 100644 --- a/resources/application.yaml +++ b/resources/application.yaml @@ -20,9 +20,9 @@ db: #pool_recycle: storage: - tmp_dir: /tmp + tmp_dir: ./tmp matrixstore: ./matrices - archive_dir: examples/archives + archive_dir: ./examples/archives allow_deletion: false # indicate if studies found in non default workspace can be deleted by the application #matrix_gc_sleeping_time: 3600 # time in seconds to sleep between two garbage collection #matrix_gc_dry_run: False # Skip matrix effective deletion @@ -32,20 +32,23 @@ storage: #auto_archive_max_parallel: 5 # max auto archival tasks in parallel workspaces: default: # required, no filters applied, this folder is not watched - path: examples/internal_studies/ + path: ./examples/internal_studies/ # other workspaces can be added # if a directory is to be ignored by the watcher, place a file named AW_NO_SCAN inside tmp: - path: examples/studies/ + path: ./examples/studies/ # filter_in: ['.*'] # default to '.*' # filter_out: [] # default to empty # groups: [] # default empty launcher: default: local + local: binaries: 700: path/to/700 + enable_nb_cores_detection: true + # slurm: # local_workspace: path/to/workspace # username: username @@ -56,7 +59,11 @@ launcher: # password: password_is_optional_but_necessary_if_key_is_absent # default_wait_time: 900 # default_time_limit: 172800 -# default_n_cpu: 12 +# enable_nb_cores_detection: False +# nb_cores: +# min: 1 +# default: 22 +# max: 24 # default_json_db_name: launcher_db.json # slurm_script_path: /path/to/launchantares_v1.1.3.sh # db_primary_key: name @@ -70,7 +77,7 @@ launcher: debug: true -root_path: "" +root_path: "api" #tasks: # max_workers: 5 diff --git a/resources/deploy/config.prod.yaml b/resources/deploy/config.prod.yaml index 1bb5e30878..02fbb4b8bc 100644 --- a/resources/deploy/config.prod.yaml +++ b/resources/deploy/config.prod.yaml @@ -32,9 +32,12 @@ storage: launcher: default: local + local: binaries: 800: /antares_simulator/antares-8.2-solver + enable_nb_cores_detection: true + # slurm: # local_workspace: path/to/workspace # username: username @@ -45,7 +48,11 @@ launcher: # password: password_is_optional_but_necessary_if_key_is_absent # default_wait_time: 900 # default_time_limit: 172800 -# default_n_cpu: 12 +# enable_nb_cores_detection: False +# nb_cores: +# min: 1 +# default: 22 +# max: 24 # default_json_db_name: launcher_db.json # slurm_script_path: /path/to/launchantares_v1.1.3.sh # db_primary_key: name @@ -59,7 +66,7 @@ launcher: debug: false -root_path: "/api" +root_path: "api" #tasks: # max_workers: 5 diff --git a/resources/deploy/config.yaml b/resources/deploy/config.yaml index 48cea48a22..810e1f8d24 100644 --- a/resources/deploy/config.yaml +++ b/resources/deploy/config.yaml @@ -29,9 +29,12 @@ storage: launcher: default: local + local: binaries: 700: path/to/700 + enable_nb_cores_detection: true + # slurm: # local_workspace: path/to/workspace # username: username @@ -42,7 +45,11 @@ launcher: # password: password_is_optional_but_necessary_if_key_is_absent # default_wait_time: 900 # default_time_limit: 172800 -# default_n_cpu: 12 +# enable_nb_cores_detection: False +# nb_cores: +# min: 1 +# default: 22 +# max: 24 # default_json_db_name: launcher_db.json # slurm_script_path: /path/to/launchantares_v1.1.3.sh # db_primary_key: name diff --git a/tests/conftest_db.py b/tests/conftest_db.py index 877ca119d1..bcb4177766 100644 --- a/tests/conftest_db.py +++ b/tests/conftest_db.py @@ -3,7 +3,8 @@ import pytest from sqlalchemy import create_engine # type: ignore -from sqlalchemy.orm import sessionmaker +from sqlalchemy.engine.base import Engine # type: ignore +from sqlalchemy.orm import Session, sessionmaker # type: ignore from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.dbmodel import Base @@ -12,7 +13,7 @@ @pytest.fixture(name="db_engine") -def db_engine_fixture() -> Generator[Any, None, None]: +def db_engine_fixture() -> Generator[Engine, None, None]: """ Fixture that creates an in-memory SQLite database engine for testing. @@ -26,7 +27,7 @@ def db_engine_fixture() -> Generator[Any, None, None]: @pytest.fixture(name="db_session") -def db_session_fixture(db_engine) -> Generator: +def db_session_fixture(db_engine: Engine) -> Generator[Session, None, None]: """ Fixture that creates a database session for testing purposes. @@ -46,7 +47,7 @@ def db_session_fixture(db_engine) -> Generator: @pytest.fixture(name="db_middleware", autouse=True) def db_middleware_fixture( - db_engine: Any, + db_engine: Engine, ) -> Generator[DBSessionMiddleware, None, None]: """ Fixture that sets up a database session middleware with custom engine settings. diff --git a/tests/core/assets/__init__.py b/tests/core/assets/__init__.py new file mode 100644 index 0000000000..773f16ec60 --- /dev/null +++ b/tests/core/assets/__init__.py @@ -0,0 +1,3 @@ +from pathlib import Path + +ASSETS_DIR = Path(__file__).parent.resolve() diff --git a/tests/core/assets/config/application-2.14.yaml b/tests/core/assets/config/application-2.14.yaml new file mode 100644 index 0000000000..650093286d --- /dev/null +++ b/tests/core/assets/config/application-2.14.yaml @@ -0,0 +1,61 @@ +security: + disabled: false + jwt: + key: super-secret + login: + admin: + pwd: admin + +db: + url: "sqlite:////home/john/antares_data/database.db" + +storage: + tmp_dir: /tmp + matrixstore: /home/john/antares_data/matrices + archive_dir: /home/john/antares_data/archives + allow_deletion: false + workspaces: + default: + path: /home/john/antares_data/internal_studies/ + studies: + path: /home/john/antares_data/studies/ + +launcher: + default: slurm + local: + binaries: + 850: /home/john/opt/antares-8.5.0-Ubuntu-20.04/antares-solver + 860: /home/john/opt/antares-8.6.0-Ubuntu-20.04/antares-8.6-solver + + slurm: + local_workspace: /home/john/antares_data/slurm_workspace + + username: antares + hostname: slurm-prod-01 + + port: 22 + private_key_file: /home/john/.ssh/id_rsa + key_password: + default_wait_time: 900 + default_time_limit: 172800 + default_n_cpu: 20 + default_json_db_name: launcher_db.json + slurm_script_path: /applis/antares/launchAntares.sh + db_primary_key: name + antares_versions_on_remote_server: + - '850' # 8.5.1/antares-8.5-solver + - '860' # 8.6.2/antares-8.6-solver + - '870' # 8.7.0/antares-8.7-solver + +debug: false + +root_path: "" + +server: + worker_threadpool_size: 12 + services: + - watcher + - matrix_gc + +logging: + level: INFO diff --git a/tests/core/assets/config/application-2.15.yaml b/tests/core/assets/config/application-2.15.yaml new file mode 100644 index 0000000000..c51d32aaae --- /dev/null +++ b/tests/core/assets/config/application-2.15.yaml @@ -0,0 +1,66 @@ +security: + disabled: false + jwt: + key: super-secret + login: + admin: + pwd: admin + +db: + url: "sqlite:////home/john/antares_data/database.db" + +storage: + tmp_dir: /tmp + matrixstore: /home/john/antares_data/matrices + archive_dir: /home/john/antares_data/archives + allow_deletion: false + workspaces: + default: + path: /home/john/antares_data/internal_studies/ + studies: + path: /home/john/antares_data/studies/ + +launcher: + default: slurm + local: + binaries: + 850: /home/john/opt/antares-8.5.0-Ubuntu-20.04/antares-solver + 860: /home/john/opt/antares-8.6.0-Ubuntu-20.04/antares-8.6-solver + enable_nb_cores_detection: True + + slurm: + local_workspace: /home/john/antares_data/slurm_workspace + + username: antares + hostname: slurm-prod-01 + + port: 22 + private_key_file: /home/john/.ssh/id_rsa + key_password: + default_wait_time: 900 + default_time_limit: 172800 + enable_nb_cores_detection: False + nb_cores: + min: 1 + default: 22 + max: 24 + default_json_db_name: launcher_db.json + slurm_script_path: /applis/antares/launchAntares.sh + db_primary_key: name + antares_versions_on_remote_server: + - '850' # 8.5.1/antares-8.5-solver + - '860' # 8.6.2/antares-8.6-solver + - '870' # 8.7.0/antares-8.7-solver + +debug: false + +root_path: "" + +server: + worker_threadpool_size: 12 + services: + - watcher + - matrix_gc + +logging: + level: INFO diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 1c1c96a180..00c6f9458d 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -1,15 +1,253 @@ from pathlib import Path +from unittest import mock import pytest -from antarest.core.config import Config +from antarest.core.config import ( + Config, + InvalidConfigurationError, + LauncherConfig, + LocalConfig, + NbCoresConfig, + SlurmConfig, +) +from tests.core.assets import ASSETS_DIR +LAUNCHER_CONFIG = { + "default": "slurm", + "local": { + "binaries": {"860": Path("/bin/solver-860.exe")}, + "enable_nb_cores_detection": False, + "nb_cores": {"min": 2, "default": 10, "max": 20}, + }, + "slurm": { + "local_workspace": Path("/home/john/antares/workspace"), + "username": "john", + "hostname": "slurm-001", + "port": 22, + "private_key_file": Path("/home/john/.ssh/id_rsa"), + "key_password": "password", + "password": "password", + "default_wait_time": 10, + "default_time_limit": 20, + "default_json_db_name": "antares.db", + "slurm_script_path": "/path/to/slurm/launcher.sh", + "max_cores": 32, + "antares_versions_on_remote_server": ["860"], + "enable_nb_cores_detection": False, + "nb_cores": {"min": 1, "default": 34, "max": 36}, + }, + "batch_size": 100, +} -@pytest.mark.unit_test -def test_get_yaml(project_path: Path): - config = Config.from_yaml_file(file=project_path / "resources/application.yaml") - assert config.security.admin_pwd == "admin" - assert config.storage.workspaces["default"].path == Path("examples/internal_studies/") - assert not config.logging.json - assert config.logging.level == "INFO" +class TestNbCoresConfig: + def test_init__default_values(self): + config = NbCoresConfig() + assert config.min == 1 + assert config.default == 22 + assert config.max == 24 + + def test_init__invalid_values(self): + with pytest.raises(ValueError): + # default < min + NbCoresConfig(min=2, default=1, max=24) + with pytest.raises(ValueError): + # default > max + NbCoresConfig(min=1, default=25, max=24) + with pytest.raises(ValueError): + # min < 0 + NbCoresConfig(min=0, default=22, max=23) + with pytest.raises(ValueError): + # min > max + NbCoresConfig(min=22, default=22, max=21) + + def test_to_json(self): + config = NbCoresConfig() + # ReactJs Material UI expects "min", "defaultValue" and "max" keys + assert config.to_json() == {"min": 1, "defaultValue": 22, "max": 24} + + +class TestLocalConfig: + def test_init__default_values(self): + config = LocalConfig() + assert config.binaries == {}, "binaries should be empty by default" + assert config.enable_nb_cores_detection, "nb cores auto-detection should be enabled by default" + assert config.nb_cores == NbCoresConfig() + + def test_from_dict(self): + config = LocalConfig.from_dict( + { + "binaries": {"860": Path("/bin/solver-860.exe")}, + "enable_nb_cores_detection": False, + "nb_cores": {"min": 2, "default": 10, "max": 20}, + } + ) + assert config.binaries == {"860": Path("/bin/solver-860.exe")} + assert not config.enable_nb_cores_detection + assert config.nb_cores == NbCoresConfig(min=2, default=10, max=20) + + def test_from_dict__auto_detect(self): + with mock.patch("multiprocessing.cpu_count", return_value=8): + config = LocalConfig.from_dict( + { + "binaries": {"860": Path("/bin/solver-860.exe")}, + "enable_nb_cores_detection": True, + } + ) + assert config.binaries == {"860": Path("/bin/solver-860.exe")} + assert config.enable_nb_cores_detection + assert config.nb_cores == NbCoresConfig(min=1, default=6, max=8) + + +class TestSlurmConfig: + def test_init__default_values(self): + config = SlurmConfig() + assert config.local_workspace == Path() + assert config.username == "" + assert config.hostname == "" + assert config.port == 0 + assert config.private_key_file == Path() + assert config.key_password == "" + assert config.password == "" + assert config.default_wait_time == 0 + assert config.default_time_limit == 0 + assert config.default_json_db_name == "" + assert config.slurm_script_path == "" + assert config.max_cores == 64 + assert config.antares_versions_on_remote_server == [], "solver versions should be empty by default" + assert not config.enable_nb_cores_detection, "nb cores auto-detection shouldn't be enabled by default" + assert config.nb_cores == NbCoresConfig() + + def test_from_dict(self): + config = SlurmConfig.from_dict( + { + "local_workspace": Path("/home/john/antares/workspace"), + "username": "john", + "hostname": "slurm-001", + "port": 22, + "private_key_file": Path("/home/john/.ssh/id_rsa"), + "key_password": "password", + "password": "password", + "default_wait_time": 10, + "default_time_limit": 20, + "default_json_db_name": "antares.db", + "slurm_script_path": "/path/to/slurm/launcher.sh", + "max_cores": 32, + "antares_versions_on_remote_server": ["860"], + "enable_nb_cores_detection": False, + "nb_cores": {"min": 2, "default": 10, "max": 20}, + } + ) + assert config.local_workspace == Path("/home/john/antares/workspace") + assert config.username == "john" + assert config.hostname == "slurm-001" + assert config.port == 22 + assert config.private_key_file == Path("/home/john/.ssh/id_rsa") + assert config.key_password == "password" + assert config.password == "password" + assert config.default_wait_time == 10 + assert config.default_time_limit == 20 + assert config.default_json_db_name == "antares.db" + assert config.slurm_script_path == "/path/to/slurm/launcher.sh" + assert config.max_cores == 32 + assert config.antares_versions_on_remote_server == ["860"] + assert not config.enable_nb_cores_detection + assert config.nb_cores == NbCoresConfig(min=2, default=10, max=20) + + def test_from_dict__default_n_cpu__backport(self): + config = SlurmConfig.from_dict( + { + "local_workspace": Path("/home/john/antares/workspace"), + "username": "john", + "hostname": "slurm-001", + "port": 22, + "private_key_file": Path("/home/john/.ssh/id_rsa"), + "key_password": "password", + "password": "password", + "default_wait_time": 10, + "default_time_limit": 20, + "default_json_db_name": "antares.db", + "slurm_script_path": "/path/to/slurm/launcher.sh", + "max_cores": 32, + "antares_versions_on_remote_server": ["860"], + "default_n_cpu": 15, + } + ) + assert config.nb_cores == NbCoresConfig(min=1, default=15, max=24) + + def test_from_dict__auto_detect(self): + with pytest.raises(NotImplementedError): + SlurmConfig.from_dict({"enable_nb_cores_detection": True}) + + +class TestLauncherConfig: + def test_init__default_values(self): + config = LauncherConfig() + assert config.default == "local", "default launcher should be local" + assert config.local is None + assert config.slurm is None + assert config.batch_size == 9999 + + def test_from_dict(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + assert config.default == "slurm" + assert config.local == LocalConfig( + binaries={"860": Path("/bin/solver-860.exe")}, + enable_nb_cores_detection=False, + nb_cores=NbCoresConfig(min=2, default=10, max=20), + ) + assert config.slurm == SlurmConfig( + local_workspace=Path("/home/john/antares/workspace"), + username="john", + hostname="slurm-001", + port=22, + private_key_file=Path("/home/john/.ssh/id_rsa"), + key_password="password", + password="password", + default_wait_time=10, + default_time_limit=20, + default_json_db_name="antares.db", + slurm_script_path="/path/to/slurm/launcher.sh", + max_cores=32, + antares_versions_on_remote_server=["860"], + enable_nb_cores_detection=False, + nb_cores=NbCoresConfig(min=1, default=34, max=36), + ) + assert config.batch_size == 100 + + def test_init__invalid_launcher(self): + with pytest.raises(ValueError): + LauncherConfig(default="invalid_launcher") + + def test_get_nb_cores__default(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + # default == "slurm" + assert config.get_nb_cores(launcher="default") == NbCoresConfig(min=1, default=34, max=36) + + def test_get_nb_cores__local(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + assert config.get_nb_cores(launcher="local") == NbCoresConfig(min=2, default=10, max=20) + + def test_get_nb_cores__slurm(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + assert config.get_nb_cores(launcher="slurm") == NbCoresConfig(min=1, default=34, max=36) + + def test_get_nb_cores__invalid_configuration(self): + config = LauncherConfig.from_dict(LAUNCHER_CONFIG) + with pytest.raises(InvalidConfigurationError): + config.get_nb_cores("invalid_launcher") + config = LauncherConfig.from_dict({}) + with pytest.raises(InvalidConfigurationError): + config.get_nb_cores("slurm") + + +class TestConfig: + @pytest.mark.parametrize("config_name", ["application-2.14.yaml", "application-2.15.yaml"]) + def test_from_yaml_file(self, config_name: str) -> None: + yaml_path = ASSETS_DIR.joinpath("config", config_name) + config = Config.from_yaml_file(yaml_path) + assert config.security.admin_pwd == "admin" + assert config.storage.workspaces["default"].path == Path("/home/john/antares_data/internal_studies") + assert not config.logging.json + assert config.logging.level == "INFO" diff --git a/tests/integration/assets/config.template.yml b/tests/integration/assets/config.template.yml index f3ad1d256f..71c58a1ba0 100644 --- a/tests/integration/assets/config.template.yml +++ b/tests/integration/assets/config.template.yml @@ -27,6 +27,7 @@ launcher: local: binaries: 700: {{launcher_mock}} + enable_nb_cores_detection: True debug: false diff --git a/tests/integration/launcher_blueprint/test_launcher_local.py b/tests/integration/launcher_blueprint/test_launcher_local.py new file mode 100644 index 0000000000..7244fba8ee --- /dev/null +++ b/tests/integration/launcher_blueprint/test_launcher_local.py @@ -0,0 +1,70 @@ +import http + +import pytest +from starlette.testclient import TestClient + +from antarest.core.config import LocalConfig + + +# noinspection SpellCheckingInspection +@pytest.mark.integration_test +class TestLauncherNbCores: + """ + The purpose of this unit test is to check the `/v1/launcher/nbcores` endpoint. + """ + + def test_get_launcher_nb_cores( + self, + client: TestClient, + user_access_token: str, + ) -> None: + # NOTE: we have `enable_nb_cores_detection: True` in `tests/integration/assets/config.template.yml`. + local_nb_cores = LocalConfig.from_dict({"enable_nb_cores_detection": True}).nb_cores + nb_cores_expected = local_nb_cores.to_json() + res = client.get( + "/v1/launcher/nbcores", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + res.raise_for_status() + actual = res.json() + assert actual == nb_cores_expected + + res = client.get( + "/v1/launcher/nbcores?launcher=default", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + res.raise_for_status() + actual = res.json() + assert actual == nb_cores_expected + + res = client.get( + "/v1/launcher/nbcores?launcher=local", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + res.raise_for_status() + actual = res.json() + assert actual == nb_cores_expected + + # Check that the endpoint raise an exception when the "slurm" launcher is requested. + res = client.get( + "/v1/launcher/nbcores?launcher=slurm", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY, res.json() + actual = res.json() + assert actual == { + "description": "Unknown solver configuration: 'slurm'", + "exception": "UnknownSolverConfig", + } + + # Check that the endpoint raise an exception when an unknown launcher is requested. + res = client.get( + "/v1/launcher/nbcores?launcher=unknown", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY, res.json() + actual = res.json() + assert actual == { + "description": "Unknown solver configuration: 'unknown'", + "exception": "UnknownSolverConfig", + } diff --git a/tests/launcher/test_local_launcher.py b/tests/launcher/test_local_launcher.py index 04741d319a..53adc03bf0 100644 --- a/tests/launcher/test_local_launcher.py +++ b/tests/launcher/test_local_launcher.py @@ -1,19 +1,28 @@ import os import textwrap +import uuid from pathlib import Path from unittest.mock import Mock, call -from uuid import uuid4 import pytest -from sqlalchemy import create_engine from antarest.core.config import Config, LauncherConfig, LocalConfig -from antarest.core.persistence import Base -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.launcher.adapters.abstractlauncher import LauncherInitException from antarest.launcher.adapters.local_launcher.local_launcher import LocalLauncher from antarest.launcher.model import JobStatus, LauncherParametersDTO +SOLVER_NAME = "solver.bat" if os.name == "nt" else "solver.sh" + + +@pytest.fixture +def launcher_config(tmp_path: Path) -> Config: + """ + Fixture to create a launcher config with a local launcher. + """ + solver_path = tmp_path.joinpath(SOLVER_NAME) + data = {"binaries": {"700": solver_path}, "enable_nb_cores_detection": True} + return Config(launcher=LauncherConfig(local=LocalConfig.from_dict(data))) + @pytest.mark.unit_test def test_local_launcher__launcher_init_exception(): @@ -30,21 +39,12 @@ def test_local_launcher__launcher_init_exception(): @pytest.mark.unit_test -def test_compute(tmp_path: Path): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - local_launcher = LocalLauncher(Config(), callbacks=Mock(), event_bus=Mock(), cache=Mock()) +def test_compute(tmp_path: Path, launcher_config: Config): + local_launcher = LocalLauncher(launcher_config, callbacks=Mock(), event_bus=Mock(), cache=Mock()) # prepare a dummy executable to simulate Antares Solver if os.name == "nt": - solver_name = "solver.bat" - solver_path = tmp_path.joinpath(solver_name) + solver_path = tmp_path.joinpath(SOLVER_NAME) solver_path.write_text( textwrap.dedent( """\ @@ -55,8 +55,7 @@ def test_compute(tmp_path: Path): ) ) else: - solver_name = "solver.sh" - solver_path = tmp_path.joinpath(solver_name) + solver_path = tmp_path.joinpath(SOLVER_NAME) solver_path.write_text( textwrap.dedent( """\ @@ -68,8 +67,8 @@ def test_compute(tmp_path: Path): ) solver_path.chmod(0o775) - uuid = uuid4() - local_launcher.job_id_to_study_id = {str(uuid): ("study-id", tmp_path / "run", Mock())} + study_id = uuid.uuid4() + local_launcher.job_id_to_study_id = {str(study_id): ("study-id", tmp_path / "run", Mock())} local_launcher.callbacks.import_output.return_value = "some output" launcher_parameters = LauncherParametersDTO( adequacy_patch=None, @@ -86,15 +85,15 @@ def test_compute(tmp_path: Path): local_launcher._compute( antares_solver_path=solver_path, study_uuid="study-id", - uuid=uuid, + uuid=study_id, launcher_parameters=launcher_parameters, ) # noinspection PyUnresolvedReferences local_launcher.callbacks.update_status.assert_has_calls( [ - call(str(uuid), JobStatus.RUNNING, None, None), - call(str(uuid), JobStatus.SUCCESS, None, "some output"), + call(str(study_id), JobStatus.RUNNING, None, None), + call(str(study_id), JobStatus.SUCCESS, None, "some output"), ] ) diff --git a/tests/launcher/test_service.py b/tests/launcher/test_service.py index 2c9f94d89c..a6177c5e61 100644 --- a/tests/launcher/test_service.py +++ b/tests/launcher/test_service.py @@ -3,15 +3,24 @@ import time from datetime import datetime, timedelta from pathlib import Path -from typing import Dict, List, Literal, Union +from typing import Dict, List, Union from unittest.mock import Mock, call, patch from uuid import uuid4 from zipfile import ZIP_DEFLATED, ZipFile import pytest from sqlalchemy import create_engine - -from antarest.core.config import Config, LauncherConfig, LocalConfig, SlurmConfig, StorageConfig +from typing_extensions import Literal + +from antarest.core.config import ( + Config, + InvalidConfigurationError, + LauncherConfig, + LocalConfig, + NbCoresConfig, + SlurmConfig, + StorageConfig, +) from antarest.core.exceptions import StudyNotFoundError from antarest.core.filetransfer.model import FileDownload, FileDownloadDTO, FileDownloadTaskDTO from antarest.core.interfaces.eventbus import Event, EventType @@ -20,7 +29,7 @@ from antarest.core.requests import RequestParameters, UserHasNotPermissionError from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.dbmodel import Base -from antarest.launcher.model import JobLog, JobLogType, JobResult, JobStatus, LogType +from antarest.launcher.model import JobLog, JobLogType, JobResult, JobStatus, LauncherParametersDTO, LogType from antarest.launcher.service import ( EXECUTION_INFO_FILE, LAUNCHER_PARAM_NAME_SUFFIX, @@ -33,780 +42,908 @@ from antarest.study.model import OwnerInfo, PublicMode, Study, StudyMetadataDTO -@pytest.mark.unit_test -@patch.object(Auth, "get_current_user") -def test_service_run_study(get_current_user_mock): - get_current_user_mock.return_value = None - storage_service_mock = Mock() - storage_service_mock.get_study_information.return_value = StudyMetadataDTO( - id="id", - name="name", - created=1, - updated=1, - type="rawstudy", - owner=OwnerInfo(id=0, name="author"), - groups=[], - public_mode=PublicMode.NONE, - version=42, - workspace="default", - managed=True, - archived=False, - ) - storage_service_mock.get_study_path.return_value = Path("path/to/study") - - uuid = uuid4() - launcher_mock = Mock() - factory_launcher_mock = Mock() - factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} - - event_bus = Mock() - - pending = JobResult( - id=str(uuid), - study_id="study_uuid", - job_status=JobStatus.PENDING, - launcher="local", - ) - repository = Mock() - repository.save.return_value = pending - - launcher_service = LauncherService( - config=Config(), - study_service=storage_service_mock, - job_result_repository=repository, - factory_launcher=factory_launcher_mock, - event_bus=event_bus, - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - launcher_service._generate_new_id = lambda: str(uuid) - - job_id = launcher_service.run_study( - "study_uuid", - "local", - None, - RequestParameters( - user=JWTUser( - id=0, - impersonator=0, - type="users", - ) - ), - ) - - assert job_id == str(uuid) - repository.save.assert_called_once_with(pending) - event_bus.push.assert_called_once_with( - Event( - type=EventType.STUDY_JOB_STARTED, - payload=pending.to_dto().dict(), - permissions=PermissionInfo(owner=0), +class TestLauncherService: + @pytest.mark.unit_test + @patch.object(Auth, "get_current_user") + def test_service_run_study(self, get_current_user_mock) -> None: + get_current_user_mock.return_value = None + storage_service_mock = Mock() + # noinspection SpellCheckingInspection + storage_service_mock.get_study_information.return_value = StudyMetadataDTO( + id="id", + name="name", + created="1", + updated="1", + type="rawstudy", + owner=OwnerInfo(id=0, name="author"), + groups=[], + public_mode=PublicMode.NONE, + version=42, + workspace="default", + managed=True, + archived=False, ) - ) + storage_service_mock.get_study_path.return_value = Path("path/to/study") + uuid = uuid4() + launcher_mock = Mock() + factory_launcher_mock = Mock() + factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} -@pytest.mark.unit_test -def test_service_get_result_from_launcher(): - launcher_mock = Mock() - fake_execution_result = JobResult( - id=str(uuid4()), - study_id="sid", - job_status=JobStatus.SUCCESS, - msg="Hello, World!", - exit_code=0, - launcher="local", - ) - factory_launcher_mock = Mock() - factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} - - repository = Mock() - repository.get.return_value = fake_execution_result - - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - - launcher_service = LauncherService( - config=Config(), - study_service=study_service, - job_result_repository=repository, - factory_launcher=factory_launcher_mock, - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - - job_id = uuid4() - assert ( - launcher_service.get_result(job_uuid=job_id, params=RequestParameters(user=DEFAULT_ADMIN_USER)) - == fake_execution_result - ) + event_bus = Mock() + pending = JobResult( + id=str(uuid), + study_id="study_uuid", + job_status=JobStatus.PENDING, + launcher="local", + launcher_params=LauncherParametersDTO().json(), + ) + repository = Mock() + repository.save.return_value = pending + + launcher_service = LauncherService( + config=Config(), + study_service=storage_service_mock, + job_result_repository=repository, + factory_launcher=factory_launcher_mock, + event_bus=event_bus, + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + launcher_service._generate_new_id = lambda: str(uuid) -@pytest.mark.unit_test -def test_service_get_result_from_database(): - launcher_mock = Mock() - fake_execution_result = JobResult( - id=str(uuid4()), - study_id="sid", - job_status=JobStatus.SUCCESS, - msg="Hello, World!", - exit_code=0, - ) - launcher_mock.get_result.return_value = None - factory_launcher_mock = Mock() - factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} - - repository = Mock() - repository.get.return_value = fake_execution_result - - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - - launcher_service = LauncherService( - config=Config(), - study_service=study_service, - job_result_repository=repository, - factory_launcher=factory_launcher_mock, - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - - assert ( - launcher_service.get_result(job_uuid=uuid4(), params=RequestParameters(user=DEFAULT_ADMIN_USER)) - == fake_execution_result - ) + job_id = launcher_service.run_study( + "study_uuid", + "local", + LauncherParametersDTO(), + RequestParameters( + user=JWTUser( + id=0, + impersonator=0, + type="users", + ) + ), + ) + assert job_id == str(uuid) + repository.save.assert_called_once_with(pending) + event_bus.push.assert_called_once_with( + Event( + type=EventType.STUDY_JOB_STARTED, + payload=pending.to_dto().dict(), + permissions=PermissionInfo(owner=0), + ) + ) -@pytest.mark.unit_test -def test_service_get_jobs_from_database(): - launcher_mock = Mock() - now = datetime.utcnow() - fake_execution_result = [ - JobResult( + @pytest.mark.unit_test + def test_service_get_result_from_launcher(self) -> None: + launcher_mock = Mock() + fake_execution_result = JobResult( id=str(uuid4()), - study_id="a", + study_id="sid", job_status=JobStatus.SUCCESS, msg="Hello, World!", exit_code=0, + launcher="local", ) - ] - returned_faked_execution_results = [ - JobResult( - id="1", - study_id="a", - job_status=JobStatus.SUCCESS, - msg="Hello, World!", - exit_code=0, - creation_date=now, - ), - JobResult( - id="2", - study_id="b", - job_status=JobStatus.SUCCESS, - msg="Hello, World!", - exit_code=0, - creation_date=now, - ), - ] - all_faked_execution_results = returned_faked_execution_results + [ - JobResult( - id="3", - study_id="c", + factory_launcher_mock = Mock() + factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} + + repository = Mock() + repository.get.return_value = fake_execution_result + + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Config(), + study_service=study_service, + job_result_repository=repository, + factory_launcher=factory_launcher_mock, + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + + job_id = uuid4() + assert ( + launcher_service.get_result(job_uuid=job_id, params=RequestParameters(user=DEFAULT_ADMIN_USER)) + == fake_execution_result + ) + + @pytest.mark.unit_test + def test_service_get_result_from_database(self) -> None: + launcher_mock = Mock() + fake_execution_result = JobResult( + id=str(uuid4()), + study_id="sid", job_status=JobStatus.SUCCESS, msg="Hello, World!", exit_code=0, - creation_date=now - timedelta(days=ORPHAN_JOBS_VISIBILITY_THRESHOLD + 1), - ) - ] - launcher_mock.get_result.return_value = None - factory_launcher_mock = Mock() - factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} - - repository = Mock() - repository.find_by_study.return_value = fake_execution_result - repository.get_all.return_value = all_faked_execution_results - - study_service = Mock() - study_service.repository = Mock() - study_service.repository.get_list.return_value = [ - Mock( - spec=Study, - id="b", - groups=[], - owner=User(id=2), - public_mode=PublicMode.NONE, ) - ] - - launcher_service = LauncherService( - config=Config(), - study_service=study_service, - job_result_repository=repository, - factory_launcher=factory_launcher_mock, - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) + launcher_mock.get_result.return_value = None + factory_launcher_mock = Mock() + factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} + + repository = Mock() + repository.get.return_value = fake_execution_result + + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Config(), + study_service=study_service, + job_result_repository=repository, + factory_launcher=factory_launcher_mock, + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - study_id = uuid4() - assert ( - launcher_service.get_jobs(str(study_id), params=RequestParameters(user=DEFAULT_ADMIN_USER)) - == fake_execution_result - ) - repository.find_by_study.assert_called_once_with(str(study_id)) - assert ( - launcher_service.get_jobs(None, params=RequestParameters(user=DEFAULT_ADMIN_USER)) - == all_faked_execution_results - ) - assert ( - launcher_service.get_jobs( - None, - params=RequestParameters( - user=JWTUser( - id=2, - impersonator=2, - type="users", - groups=[], - ) - ), + assert ( + launcher_service.get_result(job_uuid=uuid4(), params=RequestParameters(user=DEFAULT_ADMIN_USER)) + == fake_execution_result ) - == returned_faked_execution_results - ) - with pytest.raises(UserHasNotPermissionError): - launcher_service.remove_job( - "some job", - RequestParameters( - user=JWTUser( - id=2, - impersonator=2, - type="users", - groups=[], - ) + @pytest.mark.unit_test + def test_service_get_jobs_from_database(self) -> None: + launcher_mock = Mock() + now = datetime.utcnow() + fake_execution_result = [ + JobResult( + id=str(uuid4()), + study_id="a", + job_status=JobStatus.SUCCESS, + msg="Hello, World!", + exit_code=0, + ) + ] + returned_faked_execution_results = [ + JobResult( + id="1", + study_id="a", + job_status=JobStatus.SUCCESS, + msg="Hello, World!", + exit_code=0, + creation_date=now, + ), + JobResult( + id="2", + study_id="b", + job_status=JobStatus.SUCCESS, + msg="Hello, World!", + exit_code=0, + creation_date=now, ), + ] + all_faked_execution_results = returned_faked_execution_results + [ + JobResult( + id="3", + study_id="c", + job_status=JobStatus.SUCCESS, + msg="Hello, World!", + exit_code=0, + creation_date=now - timedelta(days=ORPHAN_JOBS_VISIBILITY_THRESHOLD + 1), + ) + ] + launcher_mock.get_result.return_value = None + factory_launcher_mock = Mock() + factory_launcher_mock.build_launcher.return_value = {"local": launcher_mock} + + repository = Mock() + repository.find_by_study.return_value = fake_execution_result + repository.get_all.return_value = all_faked_execution_results + + study_service = Mock() + study_service.repository = Mock() + study_service.repository.get_list.return_value = [ + Mock( + spec=Study, + id="b", + groups=[], + owner=User(id=2), + public_mode=PublicMode.NONE, + ) + ] + + launcher_service = LauncherService( + config=Config(), + study_service=study_service, + job_result_repository=repository, + factory_launcher=factory_launcher_mock, + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), ) - launcher_service.remove_job("some job", RequestParameters(user=DEFAULT_ADMIN_USER)) - repository.delete.assert_called_with("some job") + study_id = uuid4() + assert ( + launcher_service.get_jobs(str(study_id), params=RequestParameters(user=DEFAULT_ADMIN_USER)) + == fake_execution_result + ) + repository.find_by_study.assert_called_once_with(str(study_id)) + assert ( + launcher_service.get_jobs(None, params=RequestParameters(user=DEFAULT_ADMIN_USER)) + == all_faked_execution_results + ) + assert ( + launcher_service.get_jobs( + None, + params=RequestParameters( + user=JWTUser( + id=2, + impersonator=2, + type="users", + groups=[], + ) + ), + ) + == returned_faked_execution_results + ) + with pytest.raises(UserHasNotPermissionError): + launcher_service.remove_job( + "some job", + RequestParameters( + user=JWTUser( + id=2, + impersonator=2, + type="users", + groups=[], + ) + ), + ) -@pytest.mark.unit_test -@pytest.mark.parametrize( - "config, solver, expected", - [ - pytest.param( - { - "default": "local", - "local": [], - "slurm": [], - }, - "default", - [], - id="empty-config", - ), - pytest.param( - { - "default": "local", - "local": ["456", "123", "798"], - }, - "default", - ["123", "456", "798"], - id="local-config-default", - ), - pytest.param( - { - "default": "local", - "local": ["456", "123", "798"], - }, - "slurm", - [], - id="local-config-slurm", - ), - pytest.param( - { - "default": "local", - "local": ["456", "123", "798"], - }, - "unknown", - [], - id="local-config-unknown", - marks=pytest.mark.xfail( - reason="Unknown solver configuration: 'unknown'", - raises=KeyError, - strict=True, + launcher_service.remove_job("some job", RequestParameters(user=DEFAULT_ADMIN_USER)) + repository.delete.assert_called_with("some job") + + @pytest.mark.unit_test + @pytest.mark.parametrize( + "config, solver, expected", + [ + pytest.param( + { + "default": "local", + "local": [], + "slurm": [], + }, + "default", + [], + id="empty-config", ), - ), - pytest.param( - { - "default": "slurm", - "slurm": ["258", "147", "369"], - }, - "default", - ["147", "258", "369"], - id="slurm-config-default", - ), - pytest.param( - { - "default": "slurm", - "slurm": ["258", "147", "369"], - }, - "local", - [], - id="slurm-config-local", - ), - pytest.param( - { - "default": "slurm", - "slurm": ["258", "147", "369"], - }, - "unknown", - [], - id="slurm-config-unknown", - marks=pytest.mark.xfail( - reason="Unknown solver configuration: 'unknown'", - raises=KeyError, - strict=True, + pytest.param( + { + "default": "local", + "local": ["456", "123", "798"], + }, + "default", + ["123", "456", "798"], + id="local-config-default", ), - ), - pytest.param( - { - "default": "slurm", - "local": ["456", "123", "798"], - "slurm": ["258", "147", "369"], - }, - "local", - ["123", "456", "798"], - id="local+slurm-config-local", - ), - ], -) -def test_service_get_solver_versions( - config: Dict[str, Union[str, List[str]]], - solver: Literal["default", "local", "slurm", "unknown"], - expected: List[str], -) -> None: - # Prepare the configuration - default = config.get("default", "local") - local = LocalConfig(binaries={k: Path(f"solver-{k}.exe") for k in config.get("local", [])}) - slurm = SlurmConfig(antares_versions_on_remote_server=config.get("slurm", [])) - launcher_config = LauncherConfig( - default=default, - local=local if local else None, - slurm=slurm if slurm else None, - ) - config = Config(launcher=launcher_config) - launcher_service = LauncherService( - config=config, - study_service=Mock(), - job_result_repository=Mock(), - factory_launcher=Mock(), - event_bus=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), + pytest.param( + { + "default": "local", + "local": ["456", "123", "798"], + }, + "slurm", + [], + id="local-config-slurm", + ), + pytest.param( + { + "default": "local", + "local": ["456", "123", "798"], + }, + "unknown", + [], + id="local-config-unknown", + marks=pytest.mark.xfail( + reason="Unknown solver configuration: 'unknown'", + raises=KeyError, + strict=True, + ), + ), + pytest.param( + { + "default": "slurm", + "slurm": ["258", "147", "369"], + }, + "default", + ["147", "258", "369"], + id="slurm-config-default", + ), + pytest.param( + { + "default": "slurm", + "slurm": ["258", "147", "369"], + }, + "local", + [], + id="slurm-config-local", + ), + pytest.param( + { + "default": "slurm", + "slurm": ["258", "147", "369"], + }, + "unknown", + [], + id="slurm-config-unknown", + marks=pytest.mark.xfail( + reason="Unknown solver configuration: 'unknown'", + raises=KeyError, + strict=True, + ), + ), + pytest.param( + { + "default": "slurm", + "local": ["456", "123", "798"], + "slurm": ["258", "147", "369"], + }, + "local", + ["123", "456", "798"], + id="local+slurm-config-local", + ), + ], ) + def test_service_get_solver_versions( + self, + config: Dict[str, Union[str, List[str]]], + solver: Literal["default", "local", "slurm", "unknown"], + expected: List[str], + ) -> None: + # Prepare the configuration + # the default server version from the configuration file. + # the default server is initialised to local + default = config.get("default", "local") + local = LocalConfig(binaries={k: Path(f"solver-{k}.exe") for k in config.get("local", [])}) + slurm = SlurmConfig(antares_versions_on_remote_server=config.get("slurm", [])) + launcher_config = LauncherConfig( + default=default, + local=local if local else None, + slurm=slurm if slurm else None, + ) + config = Config(launcher=launcher_config) + launcher_service = LauncherService( + config=config, + study_service=Mock(), + job_result_repository=Mock(), + factory_launcher=Mock(), + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - # Fetch the solver versions - actual = launcher_service.get_solver_versions(solver) + # Fetch the solver versions + actual = launcher_service.get_solver_versions(solver) + assert actual == expected - # Check the result - assert actual == expected + @pytest.mark.unit_test + @pytest.mark.parametrize( + "config_map, solver, expected", + [ + pytest.param( + {"default": "local", "local": {}, "slurm": {}}, + "default", + {}, + id="empty-config", + ), + pytest.param( + { + "default": "local", + "local": {"min": 1, "default": 11, "max": 12}, + }, + "default", + {"min": 1, "default": 11, "max": 12}, + id="local-config-default", + ), + pytest.param( + { + "default": "local", + "local": {"min": 1, "default": 11, "max": 12}, + }, + "slurm", + {}, + id="local-config-slurm", + ), + pytest.param( + { + "default": "local", + "local": {"min": 1, "default": 11, "max": 12}, + }, + "unknown", + {}, + id="local-config-unknown", + marks=pytest.mark.xfail( + reason="Configuration is not available for the 'unknown' launcher", + raises=InvalidConfigurationError, + strict=True, + ), + ), + pytest.param( + { + "default": "slurm", + "slurm": {"min": 4, "default": 8, "max": 16}, + }, + "default", + {"min": 4, "default": 8, "max": 16}, + id="slurm-config-default", + ), + pytest.param( + { + "default": "slurm", + "slurm": {"min": 4, "default": 8, "max": 16}, + }, + "local", + {}, + id="slurm-config-local", + ), + pytest.param( + { + "default": "slurm", + "slurm": {"min": 4, "default": 8, "max": 16}, + }, + "unknown", + {}, + id="slurm-config-unknown", + marks=pytest.mark.xfail( + reason="Configuration is not available for the 'unknown' launcher", + raises=InvalidConfigurationError, + strict=True, + ), + ), + pytest.param( + { + "default": "slurm", + "local": {"min": 1, "default": 11, "max": 12}, + "slurm": {"min": 4, "default": 8, "max": 16}, + }, + "local", + {"min": 1, "default": 11, "max": 12}, + id="local+slurm-config-local", + ), + ], + ) + def test_get_nb_cores( + self, + config_map: Dict[str, Union[str, Dict[str, int]]], + solver: Literal["default", "local", "slurm", "unknown"], + expected: Dict[str, int], + ) -> None: + # Prepare the configuration + default = config_map.get("default", "local") + local_nb_cores = config_map.get("local", {}) + slurm_nb_cores = config_map.get("slurm", {}) + launcher_config = LauncherConfig( + default=default, + local=LocalConfig.from_dict({"enable_nb_cores_detection": False, "nb_cores": local_nb_cores}), + slurm=SlurmConfig.from_dict({"enable_nb_cores_detection": False, "nb_cores": slurm_nb_cores}), + ) + launcher_service = LauncherService( + config=Config(launcher=launcher_config), + study_service=Mock(), + job_result_repository=Mock(), + factory_launcher=Mock(), + event_bus=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + # Fetch the number of cores + actual = launcher_service.get_nb_cores(solver) + + # Check the result + assert actual == NbCoresConfig(**expected) + + @pytest.mark.unit_test + def test_service_kill_job(self, tmp_path: Path) -> None: + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Config(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + launcher = "slurm" + job_id = "job_id" + job_result_mock = Mock() + job_result_mock.id = job_id + job_result_mock.study_id = "study_id" + job_result_mock.launcher = launcher + launcher_service.job_result_repository.get.return_value = job_result_mock + launcher_service.launchers = {"slurm": Mock()} + + job_status = launcher_service.kill_job( + job_id=job_id, + params=RequestParameters(user=DEFAULT_ADMIN_USER), + ) -@pytest.mark.unit_test -def test_service_kill_job(tmp_path: Path): - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + launcher_service.launchers[launcher].kill_job.assert_called_once_with(job_id=job_id) - launcher_service = LauncherService( - config=Config(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - launcher = "slurm" - job_id = "job_id" - job_result_mock = Mock() - job_result_mock.id = job_id - job_result_mock.study_id = "study_id" - job_result_mock.launcher = launcher - launcher_service.job_result_repository.get.return_value = job_result_mock - launcher_service.launchers = {"slurm": Mock()} - - job_status = launcher_service.kill_job( - job_id=job_id, - params=RequestParameters(user=DEFAULT_ADMIN_USER), - ) + assert job_status.job_status == JobStatus.FAILED + launcher_service.job_result_repository.save.assert_called_once_with(job_status) - launcher_service.launchers[launcher].kill_job.assert_called_once_with(job_id=job_id) + def test_append_logs(self, tmp_path: Path) -> None: + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - assert job_status.job_status == JobStatus.FAILED - launcher_service.job_result_repository.save.assert_called_once_with(job_status) + launcher_service = LauncherService( + config=Config(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + launcher = "slurm" + job_id = "job_id" + job_result_mock = Mock() + job_result_mock.id = job_id + job_result_mock.study_id = "study_id" + job_result_mock.output_id = None + job_result_mock.launcher = launcher + job_result_mock.logs = [] + launcher_service.job_result_repository.get.return_value = job_result_mock + + engine = create_engine("sqlite:///:memory:", echo=False) + Base.metadata.create_all(engine) + # noinspection SpellCheckingInspection + DBSessionMiddleware( + None, + custom_engine=engine, + session_args={"autocommit": False, "autoflush": False}, + ) + launcher_service.append_log(job_id, "test", JobLogType.BEFORE) + launcher_service.job_result_repository.save.assert_called_with(job_result_mock) + assert job_result_mock.logs[0].message == "test" + assert job_result_mock.logs[0].job_id == "job_id" + assert job_result_mock.logs[0].log_type == str(JobLogType.BEFORE) + + def test_get_logs(self, tmp_path: Path) -> None: + study_service = Mock() + launcher_service = LauncherService( + config=Config(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + launcher = "slurm" + job_id = "job_id" + job_result_mock = Mock() + job_result_mock.id = job_id + job_result_mock.study_id = "study_id" + job_result_mock.output_id = None + job_result_mock.launcher = launcher + job_result_mock.logs = [ + JobLog(message="first message", log_type=str(JobLogType.BEFORE)), + JobLog(message="second message", log_type=str(JobLogType.BEFORE)), + JobLog(message="last message", log_type=str(JobLogType.AFTER)), + ] + job_result_mock.launcher_params = '{"archive_output": false}' + + launcher_service.job_result_repository.get.return_value = job_result_mock + slurm_launcher = Mock() + launcher_service.launchers = {"slurm": slurm_launcher} + slurm_launcher.get_log.return_value = "launcher logs" + + logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER)) + assert logs == "first message\nsecond message\nlauncher logs\nlast message" + logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER)) + assert logs == "launcher logs" + + study_service.get_logs.side_effect = ["some sim log", "error log"] + + job_result_mock.output_id = "some id" + logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER)) + assert logs == "first message\nsecond message\nsome sim log\nlast message" + + logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER)) + assert logs == "error log" + + study_service.get_logs.assert_has_calls( + [ + call( + "study_id", + "some id", + job_id, + False, + params=RequestParameters(DEFAULT_ADMIN_USER), + ), + call( + "study_id", + "some id", + job_id, + True, + params=RequestParameters(DEFAULT_ADMIN_USER), + ), + ] + ) + def test_manage_output(self, tmp_path: Path) -> None: + engine = create_engine("sqlite:///:memory:", echo=False) + Base.metadata.create_all(engine) + # noinspection SpellCheckingInspection + DBSessionMiddleware( + None, + custom_engine=engine, + session_args={"autocommit": False, "autoflush": False}, + ) -def test_append_logs(tmp_path: Path): - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Mock(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - launcher_service = LauncherService( - config=Config(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - launcher = "slurm" - job_id = "job_id" - job_result_mock = Mock() - job_result_mock.id = job_id - job_result_mock.study_id = "study_id" - job_result_mock.output_id = None - job_result_mock.launcher = launcher - job_result_mock.logs = [] - launcher_service.job_result_repository.get.return_value = job_result_mock - - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - launcher_service.append_log(job_id, "test", JobLogType.BEFORE) - launcher_service.job_result_repository.save.assert_called_with(job_result_mock) - assert job_result_mock.logs[0].message == "test" - assert job_result_mock.logs[0].job_id == "job_id" - assert job_result_mock.logs[0].log_type == str(JobLogType.BEFORE) - - -def test_get_logs(tmp_path: Path): - study_service = Mock() - launcher_service = LauncherService( - config=Config(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - launcher = "slurm" - job_id = "job_id" - job_result_mock = Mock() - job_result_mock.id = job_id - job_result_mock.study_id = "study_id" - job_result_mock.output_id = None - job_result_mock.launcher = launcher - job_result_mock.logs = [ - JobLog(message="first message", log_type=str(JobLogType.BEFORE)), - JobLog(message="second message", log_type=str(JobLogType.BEFORE)), - JobLog(message="last message", log_type=str(JobLogType.AFTER)), - ] - job_result_mock.launcher_params = '{"archive_output": false}' - - launcher_service.job_result_repository.get.return_value = job_result_mock - slurm_launcher = Mock() - launcher_service.launchers = {"slurm": slurm_launcher} - slurm_launcher.get_log.return_value = "launcher logs" - - logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER)) - assert logs == "first message\nsecond message\nlauncher logs\nlast message" - logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER)) - assert logs == "launcher logs" - - study_service.get_logs.side_effect = ["some sim log", "error log"] - - job_result_mock.output_id = "some id" - logs = launcher_service.get_log(job_id, LogType.STDOUT, RequestParameters(DEFAULT_ADMIN_USER)) - assert logs == "first message\nsecond message\nsome sim log\nlast message" - - logs = launcher_service.get_log(job_id, LogType.STDERR, RequestParameters(DEFAULT_ADMIN_USER)) - assert logs == "error log" - - study_service.get_logs.assert_has_calls( - [ - call( - "study_id", - "some id", - job_id, - False, - params=RequestParameters(DEFAULT_ADMIN_USER), + output_path = tmp_path / "output" + zipped_output_path = tmp_path / "zipped_output" + os.mkdir(output_path) + os.mkdir(zipped_output_path) + new_output_path = output_path / "new_output" + os.mkdir(new_output_path) + (new_output_path / "log").touch() + (new_output_path / "data").touch() + additional_log = tmp_path / "output.log" + additional_log.write_text("some log") + new_output_zipped_path = zipped_output_path / "test.zip" + with ZipFile(new_output_zipped_path, "w", ZIP_DEFLATED) as output_data: + output_data.writestr("some output", "0\n1") + job_id = "job_id" + zipped_job_id = "zipped_job_id" + study_id = "study_id" + launcher_service.job_result_repository.get.side_effect = [ + None, + JobResult(id=job_id, study_id=study_id), + JobResult(id=job_id, study_id=study_id, output_id="some id"), + JobResult(id=zipped_job_id, study_id=study_id), + JobResult( + id=job_id, + study_id=study_id, ), - call( - "study_id", - "some id", - job_id, - True, - params=RequestParameters(DEFAULT_ADMIN_USER), + JobResult( + id=job_id, + study_id=study_id, + launcher_params=json.dumps( + { + "archive_output": False, + f"{LAUNCHER_PARAM_NAME_SUFFIX}": "hello", + } + ), ), ] - ) + with pytest.raises(JobNotFound): + launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) - -def test_manage_output(tmp_path: Path): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) - - launcher_service = LauncherService( - config=Mock(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - - output_path = tmp_path / "output" - zipped_output_path = tmp_path / "zipped_output" - os.mkdir(output_path) - os.mkdir(zipped_output_path) - new_output_path = output_path / "new_output" - os.mkdir(new_output_path) - (new_output_path / "log").touch() - (new_output_path / "data").touch() - additional_log = tmp_path / "output.log" - additional_log.write_text("some log") - new_output_zipped_path = zipped_output_path / "test.zip" - with ZipFile(new_output_zipped_path, "w", ZIP_DEFLATED) as output_data: - output_data.writestr("some output", "0\n1") - job_id = "job_id" - zipped_job_id = "zipped_job_id" - study_id = "study_id" - launcher_service.job_result_repository.get.side_effect = [ - None, - JobResult(id=job_id, study_id=study_id), - JobResult(id=job_id, study_id=study_id, output_id="some id"), - JobResult(id=zipped_job_id, study_id=study_id), - JobResult( - id=job_id, - study_id=study_id, - ), - JobResult( - id=job_id, - study_id=study_id, - launcher_params=json.dumps( - { - "archive_output": False, - f"{LAUNCHER_PARAM_NAME_SUFFIX}": "hello", - } - ), - ), - ] - with pytest.raises(JobNotFound): launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) + assert not launcher_service._get_job_output_fallback_path(job_id).exists() + launcher_service.study_service.import_output.assert_called() - launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) - assert not launcher_service._get_job_output_fallback_path(job_id).exists() - launcher_service.study_service.import_output.assert_called() - - launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) - launcher_service.study_service.export_output.assert_called() - - launcher_service._import_output( - zipped_job_id, - zipped_output_path, - { - "out.log": [additional_log], - "antares-out": [additional_log], - "antares-err": [additional_log], - }, - ) - launcher_service.study_service.save_logs.has_calls( - [ - call(study_id, zipped_job_id, "out.log", "some log"), - call(study_id, zipped_job_id, "out", "some log"), - call(study_id, zipped_job_id, "err", "some log"), - ] - ) - - launcher_service.study_service.import_output.side_effect = [ - StudyNotFoundError(""), - StudyNotFoundError(""), - ] - - assert launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) is None - - (new_output_path / "info.antares-output").write_text(f"[general]\nmode=eco\nname=foo\ntimestamp={time.time()}") - output_name = launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) - assert output_name is not None - assert output_name.endswith("-hello") - assert launcher_service._get_job_output_fallback_path(job_id).exists() - assert (launcher_service._get_job_output_fallback_path(job_id) / output_name / "out.log").exists() - - launcher_service.job_result_repository.get.reset_mock() - launcher_service.job_result_repository.get.side_effect = [ - None, - JobResult(id=job_id, study_id=study_id, output_id=output_name), - ] - with pytest.raises(JobNotFound): launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) + launcher_service.study_service.export_output.assert_called() - study_service.get_study.reset_mock() - study_service.get_study.side_effect = StudyNotFoundError("") + launcher_service._import_output( + zipped_job_id, + zipped_output_path, + { + "out.log": [additional_log], + "antares-out": [additional_log], + "antares-err": [additional_log], + }, + ) + launcher_service.study_service.save_logs.has_calls( + [ + call(study_id, zipped_job_id, "out.log", "some log"), + call(study_id, zipped_job_id, "out", "some log"), + call(study_id, zipped_job_id, "err", "some log"), + ] + ) - export_file = FileDownloadDTO(id="a", name="a", filename="a", ready=True) - launcher_service.file_transfer_manager.request_download.return_value = FileDownload( - id="a", name="a", filename="a", ready=True, path="a" - ) - launcher_service.task_service.add_task.return_value = "some id" + launcher_service.study_service.import_output.side_effect = [ + StudyNotFoundError(""), + StudyNotFoundError(""), + ] - assert launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) == FileDownloadTaskDTO( - task="some id", file=export_file - ) + assert launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) is None - launcher_service.remove_job(job_id, RequestParameters(user=DEFAULT_ADMIN_USER)) - assert not launcher_service._get_job_output_fallback_path(job_id).exists() + (new_output_path / "info.antares-output").write_text(f"[general]\nmode=eco\nname=foo\ntimestamp={time.time()}") + output_name = launcher_service._import_output(job_id, output_path, {"out.log": [additional_log]}) + assert output_name is not None + assert output_name.endswith("-hello") + assert launcher_service._get_job_output_fallback_path(job_id).exists() + assert (launcher_service._get_job_output_fallback_path(job_id) / output_name / "out.log").exists() + launcher_service.job_result_repository.get.reset_mock() + launcher_service.job_result_repository.get.side_effect = [ + None, + JobResult(id=job_id, study_id=study_id, output_id=output_name), + ] + with pytest.raises(JobNotFound): + launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) -def test_save_stats(tmp_path: Path) -> None: - study_service = Mock() - study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + study_service.get_study.reset_mock() + study_service.get_study.side_effect = StudyNotFoundError("") - launcher_service = LauncherService( - config=Mock(storage=StorageConfig(tmp_dir=tmp_path)), - study_service=study_service, - job_result_repository=Mock(), - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) + export_file = FileDownloadDTO(id="a", name="a", filename="a", ready=True) + launcher_service.file_transfer_manager.request_download.return_value = FileDownload( + id="a", name="a", filename="a", ready=True, path="a" + ) + launcher_service.task_service.add_task.return_value = "some id" - job_id = "job_id" - study_id = "study_id" - job_result = JobResult(id=job_id, study_id=study_id, job_status=JobStatus.SUCCESS) - - output_path = tmp_path / "some-output" - output_path.mkdir() - - launcher_service._save_solver_stats(job_result, output_path) - launcher_service.job_result_repository.save.assert_not_called() - - expected_saved_stats = """#item duration_ms NbOccurences -mc_years 216328 1 -study_loading 4304 1 -survey_report 158 1 -total 244581 1 -tsgen_hydro 1683 1 -tsgen_load 2702 1 -tsgen_solar 21606 1 -tsgen_thermal 407 2 -tsgen_wind 2500 1 - """ - (output_path / EXECUTION_INFO_FILE).write_text(expected_saved_stats) - - launcher_service._save_solver_stats(job_result, output_path) - launcher_service.job_result_repository.save.assert_called_with( - JobResult( - id=job_id, - study_id=study_id, - job_status=JobStatus.SUCCESS, - solver_stats=expected_saved_stats, + assert launcher_service.download_output("job_id", RequestParameters(DEFAULT_ADMIN_USER)) == FileDownloadTaskDTO( + task="some id", file=export_file ) - ) - zip_file = tmp_path / "test.zip" - with ZipFile(zip_file, "w", ZIP_DEFLATED) as output_data: - output_data.writestr(EXECUTION_INFO_FILE, "0\n1") + launcher_service.remove_job(job_id, RequestParameters(user=DEFAULT_ADMIN_USER)) + assert not launcher_service._get_job_output_fallback_path(job_id).exists() + + def test_save_solver_stats(self, tmp_path: Path) -> None: + study_service = Mock() + study_service.get_study.return_value = Mock(spec=Study, groups=[], owner=None, public_mode=PublicMode.NONE) + + launcher_service = LauncherService( + config=Mock(storage=StorageConfig(tmp_dir=tmp_path)), + study_service=study_service, + job_result_repository=Mock(), + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) - launcher_service._save_solver_stats(job_result, zip_file) - launcher_service.job_result_repository.save.assert_called_with( - JobResult( - id=job_id, - study_id=study_id, - job_status=JobStatus.SUCCESS, - solver_stats="0\n1", + job_id = "job_id" + study_id = "study_id" + job_result = JobResult(id=job_id, study_id=study_id, job_status=JobStatus.SUCCESS) + + output_path = tmp_path / "some-output" + output_path.mkdir() + + launcher_service._save_solver_stats(job_result, output_path) + launcher_service.job_result_repository.save.assert_not_called() + + expected_saved_stats = """#item duration_ms NbOccurences + mc_years 216328 1 + study_loading 4304 1 + survey_report 158 1 + total 244581 1 + tsgen_hydro 1683 1 + tsgen_load 2702 1 + tsgen_solar 21606 1 + tsgen_thermal 407 2 + tsgen_wind 2500 1 + """ + (output_path / EXECUTION_INFO_FILE).write_text(expected_saved_stats) + + launcher_service._save_solver_stats(job_result, output_path) + launcher_service.job_result_repository.save.assert_called_with( + JobResult( + id=job_id, + study_id=study_id, + job_status=JobStatus.SUCCESS, + solver_stats=expected_saved_stats, + ) ) - ) + zip_file = tmp_path / "test.zip" + with ZipFile(zip_file, "w", ZIP_DEFLATED) as output_data: + output_data.writestr(EXECUTION_INFO_FILE, "0\n1") + + launcher_service._save_solver_stats(job_result, zip_file) + launcher_service.job_result_repository.save.assert_called_with( + JobResult( + id=job_id, + study_id=study_id, + job_status=JobStatus.SUCCESS, + solver_stats="0\n1", + ) + ) -def test_get_load(tmp_path: Path): - study_service = Mock() - job_repository = Mock() + def test_get_load(self, tmp_path: Path) -> None: + study_service = Mock() + job_repository = Mock() - launcher_service = LauncherService( - config=Mock( + config = Config( storage=StorageConfig(tmp_dir=tmp_path), - launcher=LauncherConfig(local=LocalConfig(), slurm=SlurmConfig(default_n_cpu=12)), - ), - study_service=study_service, - job_result_repository=job_repository, - event_bus=Mock(), - factory_launcher=Mock(), - file_transfer_manager=Mock(), - task_service=Mock(), - cache=Mock(), - ) - - job_repository.get_running.side_effect = [ - [], - [], - [ - Mock( - spec=JobResult, - launcher="slurm", - launcher_params=None, - ), - ], - [ - Mock( - spec=JobResult, - launcher="slurm", - launcher_params='{"nb_cpu": 18}', - ), - Mock( - spec=JobResult, - launcher="local", - launcher_params=None, - ), - Mock( - spec=JobResult, - launcher="slurm", - launcher_params=None, + launcher=LauncherConfig( + local=LocalConfig(), + slurm=SlurmConfig(nb_cores=NbCoresConfig(min=1, default=12, max=24)), ), - Mock( - spec=JobResult, - launcher="local", - launcher_params='{"nb_cpu": 7}', - ), - ], - ] - - with pytest.raises(NotImplementedError): - launcher_service.get_load(from_cluster=True) - - load = launcher_service.get_load() - assert load["slurm"] == 0 - assert load["local"] == 0 - load = launcher_service.get_load() - assert load["slurm"] == 12.0 / 64 - assert load["local"] == 0 - load = launcher_service.get_load() - assert load["slurm"] == 30.0 / 64 - assert load["local"] == 8.0 / os.cpu_count() + ) + launcher_service = LauncherService( + config=config, + study_service=study_service, + job_result_repository=job_repository, + event_bus=Mock(), + factory_launcher=Mock(), + file_transfer_manager=Mock(), + task_service=Mock(), + cache=Mock(), + ) + + job_repository.get_running.side_effect = [ + # call #1 + [], + # call #2 + [], + # call #3 + [ + Mock( + spec=JobResult, + launcher="slurm", + launcher_params=None, + ), + ], + # call #4 + [ + Mock( + spec=JobResult, + launcher="slurm", + launcher_params='{"nb_cpu": 18}', + ), + Mock( + spec=JobResult, + launcher="local", + launcher_params=None, + ), + Mock( + spec=JobResult, + launcher="slurm", + launcher_params=None, + ), + Mock( + spec=JobResult, + launcher="local", + launcher_params='{"nb_cpu": 7}', + ), + ], + ] + + # call #1 + with pytest.raises(NotImplementedError): + launcher_service.get_load(from_cluster=True) + + # call #2 + load = launcher_service.get_load() + assert load["slurm"] == 0 + assert load["local"] == 0 + + # call #3 + load = launcher_service.get_load() + slurm_config = config.launcher.slurm + assert load["slurm"] == slurm_config.nb_cores.default / slurm_config.max_cores + assert load["local"] == 0 + + # call #4 + load = launcher_service.get_load() + local_config = config.launcher.local + assert load["slurm"] == (18 + slurm_config.nb_cores.default) / slurm_config.max_cores + assert load["local"] == (7 + local_config.nb_cores.default) / local_config.nb_cores.max diff --git a/tests/launcher/test_slurm_launcher.py b/tests/launcher/test_slurm_launcher.py index 7820abcdea..dfb6846e89 100644 --- a/tests/launcher/test_slurm_launcher.py +++ b/tests/launcher/test_slurm_launcher.py @@ -10,11 +10,9 @@ from antareslauncher.data_repo.data_repo_tinydb import DataRepoTinydb from antareslauncher.main import MainParameters from antareslauncher.study_dto import StudyDTO -from sqlalchemy import create_engine +from sqlalchemy.orm import Session # type: ignore -from antarest.core.config import Config, LauncherConfig, SlurmConfig -from antarest.core.persistence import Base -from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware +from antarest.core.config import Config, LauncherConfig, NbCoresConfig, SlurmConfig from antarest.launcher.adapters.abstractlauncher import LauncherInitException from antarest.launcher.adapters.slurm_launcher.slurm_launcher import ( LOG_DIR_NAME, @@ -24,32 +22,34 @@ SlurmLauncher, VersionNotSupportedError, ) -from antarest.launcher.model import JobStatus, LauncherParametersDTO +from antarest.launcher.model import JobStatus, LauncherParametersDTO, XpansionParametersDTO from antarest.tools.admin_lib import clean_locks_from_config @pytest.fixture def launcher_config(tmp_path: Path) -> Config: - return Config( - launcher=LauncherConfig( - slurm=SlurmConfig( - local_workspace=tmp_path, - default_json_db_name="default_json_db_name", - slurm_script_path="slurm_script_path", - antares_versions_on_remote_server=["42", "45"], - username="username", - hostname="hostname", - port=42, - private_key_file=Path("private_key_file"), - key_password="key_password", - password="password", - ) - ) - ) + data = { + "local_workspace": tmp_path, + "username": "john", + "hostname": "slurm-001", + "port": 22, + "private_key_file": Path("/home/john/.ssh/id_rsa"), + "key_password": "password", + "password": "password", + "default_wait_time": 10, + "default_time_limit": 20, + "default_json_db_name": "antares.db", + "slurm_script_path": "/path/to/slurm/launcher.sh", + "max_cores": 32, + "antares_versions_on_remote_server": ["840", "850", "860"], + "enable_nb_cores_detection": False, + "nb_cores": {"min": 1, "default": 34, "max": 36}, + } + return Config(launcher=LauncherConfig(slurm=SlurmConfig.from_dict(data))) @pytest.mark.unit_test -def test_slurm_launcher__launcher_init_exception(): +def test_slurm_launcher__launcher_init_exception() -> None: with pytest.raises( LauncherInitException, match="Missing parameter 'launcher.slurm'", @@ -63,13 +63,13 @@ def test_slurm_launcher__launcher_init_exception(): @pytest.mark.unit_test -def test_init_slurm_launcher_arguments(tmp_path: Path): +def test_init_slurm_launcher_arguments(tmp_path: Path) -> None: config = Config( launcher=LauncherConfig( slurm=SlurmConfig( default_wait_time=42, default_time_limit=43, - default_n_cpu=44, + nb_cores=NbCoresConfig(min=1, default=30, max=36), local_workspace=tmp_path, ) ) @@ -88,13 +88,15 @@ def test_init_slurm_launcher_arguments(tmp_path: Path): assert not arguments.xpansion_mode assert not arguments.version assert not arguments.post_processing - assert Path(arguments.studies_in) == config.launcher.slurm.local_workspace / "STUDIES_IN" - assert Path(arguments.output_dir) == config.launcher.slurm.local_workspace / "OUTPUT" - assert Path(arguments.log_dir) == config.launcher.slurm.local_workspace / "LOGS" + slurm_config = config.launcher.slurm + assert slurm_config is not None + assert Path(arguments.studies_in) == slurm_config.local_workspace / "STUDIES_IN" + assert Path(arguments.output_dir) == slurm_config.local_workspace / "OUTPUT" + assert Path(arguments.log_dir) == slurm_config.local_workspace / "LOGS" @pytest.mark.unit_test -def test_init_slurm_launcher_parameters(tmp_path: Path): +def test_init_slurm_launcher_parameters(tmp_path: Path) -> None: config = Config( launcher=LauncherConfig( slurm=SlurmConfig( @@ -115,23 +117,25 @@ def test_init_slurm_launcher_parameters(tmp_path: Path): slurm_launcher = SlurmLauncher(config=config, callbacks=Mock(), event_bus=Mock(), cache=Mock()) main_parameters = slurm_launcher._init_launcher_parameters() - assert main_parameters.json_dir == config.launcher.slurm.local_workspace - assert main_parameters.default_json_db_name == config.launcher.slurm.default_json_db_name - assert main_parameters.slurm_script_path == config.launcher.slurm.slurm_script_path - assert main_parameters.antares_versions_on_remote_server == config.launcher.slurm.antares_versions_on_remote_server + slurm_config = config.launcher.slurm + assert slurm_config is not None + assert main_parameters.json_dir == slurm_config.local_workspace + assert main_parameters.default_json_db_name == slurm_config.default_json_db_name + assert main_parameters.slurm_script_path == slurm_config.slurm_script_path + assert main_parameters.antares_versions_on_remote_server == slurm_config.antares_versions_on_remote_server assert main_parameters.default_ssh_dict == { - "username": config.launcher.slurm.username, - "hostname": config.launcher.slurm.hostname, - "port": config.launcher.slurm.port, - "private_key_file": config.launcher.slurm.private_key_file, - "key_password": config.launcher.slurm.key_password, - "password": config.launcher.slurm.password, + "username": slurm_config.username, + "hostname": slurm_config.hostname, + "port": slurm_config.port, + "private_key_file": slurm_config.private_key_file, + "key_password": slurm_config.key_password, + "password": slurm_config.password, } assert main_parameters.db_primary_key == "name" @pytest.mark.unit_test -def test_slurm_launcher_delete_function(tmp_path: str): +def test_slurm_launcher_delete_function(tmp_path: str) -> None: config = Config(launcher=LauncherConfig(slurm=SlurmConfig(local_workspace=Path(tmp_path)))) slurm_launcher = SlurmLauncher( config=config, @@ -155,64 +159,104 @@ def test_slurm_launcher_delete_function(tmp_path: str): assert not file_path.exists() -def test_extra_parameters(launcher_config: Config): +def test_extra_parameters(launcher_config: Config) -> None: + """ + The goal of this unit test is to control the protected method `_check_and_apply_launcher_params`, + which is called by the `SlurmLauncher.run_study` function, in a separate thread. + + The `_check_and_apply_launcher_params` method extract the parameters from the configuration + and populate a `argparse.Namespace` which is used to launch a simulation using Antares Launcher. + + We want to make sure all the parameters are populated correctly. + """ slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), event_bus=Mock(), cache=Mock(), ) - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO()) - assert launcher_params.n_cpu == 1 - assert launcher_params.time_limit == 0 + + apply_params = slurm_launcher._apply_params + launcher_params = apply_params(LauncherParametersDTO()) + slurm_config = slurm_launcher.config.launcher.slurm + assert slurm_config is not None + assert launcher_params.n_cpu == slurm_config.nb_cores.default + assert launcher_params.time_limit == slurm_config.default_time_limit assert not launcher_params.xpansion_mode assert not launcher_params.post_processing - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(nb_cpu=12)) + launcher_params = apply_params(LauncherParametersDTO(other_options="")) + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(other_options="foo\tbar baz ")) + assert launcher_params.other_options == "foo bar baz" + + launcher_params = apply_params(LauncherParametersDTO(other_options="/foo?bar")) + assert launcher_params.other_options == "foobar" + + launcher_params = apply_params(LauncherParametersDTO(nb_cpu=12)) assert launcher_params.n_cpu == 12 - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(nb_cpu=48)) - assert launcher_params.n_cpu == 1 + launcher_params = apply_params(LauncherParametersDTO(nb_cpu=999)) + assert launcher_params.n_cpu == slurm_config.nb_cores.default # out of range - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(time_limit=10)) + launcher_params = apply_params(LauncherParametersDTO(time_limit=10)) assert launcher_params.time_limit == MIN_TIME_LIMIT - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(time_limit=999999999)) + launcher_params = apply_params(LauncherParametersDTO(time_limit=999999999)) assert launcher_params.time_limit == MAX_TIME_LIMIT - 3600 - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(time_limit=99999)) + launcher_params = apply_params(LauncherParametersDTO(time_limit=99999)) assert launcher_params.time_limit == 99999 - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(xpansion=True)) - assert launcher_params.xpansion_mode + launcher_params = apply_params(LauncherParametersDTO(xpansion=False)) + assert launcher_params.xpansion_mode is None + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(xpansion=True)) + assert launcher_params.xpansion_mode == "cpp" + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(xpansion=True, xpansion_r_version=True)) + assert launcher_params.xpansion_mode == "r" + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(xpansion=XpansionParametersDTO(sensitivity_mode=False))) + assert launcher_params.xpansion_mode == "cpp" + assert launcher_params.other_options == "" + + launcher_params = apply_params(LauncherParametersDTO(xpansion=XpansionParametersDTO(sensitivity_mode=True))) + assert launcher_params.xpansion_mode == "cpp" + assert launcher_params.other_options == "xpansion_sensitivity" + + launcher_params = apply_params(LauncherParametersDTO(post_processing=False)) + assert launcher_params.post_processing is False - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(post_processing=True)) - assert launcher_params.post_processing + launcher_params = apply_params(LauncherParametersDTO(post_processing=True)) + assert launcher_params.post_processing is True - launcher_params = slurm_launcher._check_and_apply_launcher_params(LauncherParametersDTO(adequacy_patch={})) - assert launcher_params.post_processing + launcher_params = apply_params(LauncherParametersDTO(adequacy_patch={})) + assert launcher_params.post_processing is True # noinspection PyUnresolvedReferences @pytest.mark.parametrize( - "version, job_status", - [(42, JobStatus.RUNNING), (99, JobStatus.FAILED), (45, JobStatus.FAILED)], + "version, launcher_called, job_status", + [ + (840, True, JobStatus.RUNNING), + (860, False, JobStatus.FAILED), + pytest.param( + 999, False, JobStatus.FAILED, marks=pytest.mark.xfail(raises=VersionNotSupportedError, strict=True) + ), + ], ) @pytest.mark.unit_test def test_run_study( - tmp_path: Path, launcher_config: Config, version: int, + launcher_called: bool, job_status: JobStatus, -): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) +) -> None: slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), @@ -231,7 +275,8 @@ def test_run_study( job_id = str(uuid.uuid4()) study_dir = argument.studies_in / job_id study_dir.mkdir(parents=True) - (study_dir / "study.antares").write_text( + study_antares_path = study_dir.joinpath("study.antares") + study_antares_path.write_text( textwrap.dedent( """\ [antares] @@ -242,22 +287,20 @@ def test_run_study( # noinspection PyUnusedLocal def call_launcher_mock(arguments: Namespace, parameters: MainParameters): - if version != 45: + if launcher_called: slurm_launcher.data_repo_tinydb.save_study(StudyDTO(job_id)) slurm_launcher._call_launcher = call_launcher_mock - if version == 99: - with pytest.raises(VersionNotSupportedError): - slurm_launcher._run_study(study_uuid, job_id, LauncherParametersDTO(), str(version)) - else: - slurm_launcher._run_study(study_uuid, job_id, LauncherParametersDTO(), str(version)) + # When the launcher is called + slurm_launcher._run_study(study_uuid, job_id, LauncherParametersDTO(), str(version)) + # Check the results assert ( version not in launcher_config.launcher.slurm.antares_versions_on_remote_server - or f"solver_version = {version}" in (study_dir / "study.antares").read_text(encoding="utf-8") + or f"solver_version = {version}" in study_antares_path.read_text(encoding="utf-8") ) - # slurm_launcher._clean_local_workspace.assert_called_once() + slurm_launcher.callbacks.export_study.assert_called_once() slurm_launcher.callbacks.update_status.assert_called_once_with(ANY, job_status, ANY, None) if job_status == JobStatus.RUNNING: @@ -266,7 +309,7 @@ def call_launcher_mock(arguments: Namespace, parameters: MainParameters): @pytest.mark.unit_test -def test_check_state(tmp_path: Path, launcher_config: Config): +def test_check_state(tmp_path: Path, launcher_config: Config) -> None: slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), @@ -308,16 +351,7 @@ def test_check_state(tmp_path: Path, launcher_config: Config): @pytest.mark.unit_test -def test_clean_local_workspace(tmp_path: Path, launcher_config: Config): - engine = create_engine("sqlite:///:memory:", echo=False) - Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - +def test_clean_local_workspace(tmp_path: Path, launcher_config: Config) -> None: slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), @@ -325,7 +359,6 @@ def test_clean_local_workspace(tmp_path: Path, launcher_config: Config): use_private_workspace=False, cache=Mock(), ) - (launcher_config.launcher.slurm.local_workspace / "machin.txt").touch() assert os.listdir(launcher_config.launcher.slurm.local_workspace) @@ -335,7 +368,7 @@ def test_clean_local_workspace(tmp_path: Path, launcher_config: Config): # noinspection PyUnresolvedReferences @pytest.mark.unit_test -def test_import_study_output(launcher_config, tmp_path): +def test_import_study_output(launcher_config, tmp_path) -> None: slurm_launcher = SlurmLauncher( config=launcher_config, callbacks=Mock(), @@ -399,7 +432,7 @@ def test_kill_job( run_with_mock, tmp_path: Path, launcher_config: Config, -): +) -> None: launch_id = "launch_id" mock_study = Mock() mock_study.name = launch_id @@ -419,35 +452,36 @@ def test_kill_job( slurm_launcher.kill_job(job_id=launch_id) + slurm_config = launcher_config.launcher.slurm launcher_arguments = Namespace( antares_version=0, check_queue=False, - job_id_to_kill=42, + job_id_to_kill=mock_study.job_id, json_ssh_config=None, log_dir=str(tmp_path / "LOGS"), - n_cpu=1, + n_cpu=slurm_config.nb_cores.default, output_dir=str(tmp_path / "OUTPUT"), post_processing=False, studies_in=str(tmp_path / "STUDIES_IN"), - time_limit=0, + time_limit=slurm_config.default_time_limit, version=False, wait_mode=False, - wait_time=0, + wait_time=slurm_config.default_wait_time, xpansion_mode=None, other_options=None, ) launcher_parameters = MainParameters( json_dir=Path(tmp_path), - default_json_db_name="default_json_db_name", - slurm_script_path="slurm_script_path", - antares_versions_on_remote_server=["42", "45"], + default_json_db_name=slurm_config.default_json_db_name, + slurm_script_path=slurm_config.slurm_script_path, + antares_versions_on_remote_server=slurm_config.antares_versions_on_remote_server, default_ssh_dict={ - "username": "username", - "hostname": "hostname", - "port": 42, - "private_key_file": Path("private_key_file"), - "key_password": "key_password", - "password": "password", + "username": slurm_config.username, + "hostname": slurm_config.hostname, + "port": slurm_config.port, + "private_key_file": slurm_config.private_key_file, + "key_password": slurm_config.key_password, + "password": slurm_config.password, }, db_primary_key="name", ) @@ -456,7 +490,7 @@ def test_kill_job( @patch("antarest.launcher.adapters.slurm_launcher.slurm_launcher.run_with") -def test_launcher_workspace_init(run_with_mock, tmp_path: Path, launcher_config: Config): +def test_launcher_workspace_init(run_with_mock, tmp_path: Path, launcher_config: Config) -> None: callbacks = Mock() (tmp_path / LOG_DIR_NAME).mkdir() @@ -474,11 +508,7 @@ def test_launcher_workspace_init(run_with_mock, tmp_path: Path, launcher_config: clean_locks_from_config(launcher_config) assert not (workspaces[0] / WORKSPACE_LOCK_FILE_NAME).exists() - slurm_launcher.data_repo_tinydb.save_study( - StudyDTO( - path="somepath", - ) - ) + slurm_launcher.data_repo_tinydb.save_study(StudyDTO(path="some_path")) run_with_mock.assert_not_called() # will use existing private workspace diff --git a/tests/study/business/conftest.py b/tests/study/business/conftest.py deleted file mode 100644 index 2638d47b3d..0000000000 --- a/tests/study/business/conftest.py +++ /dev/null @@ -1,22 +0,0 @@ -import contextlib - -import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -from antarest.dbmodel import Base - - -@pytest.fixture(scope="function", name="db_engine") -def db_engine_fixture(): - engine = create_engine("sqlite:///:memory:") - Base.metadata.create_all(engine) - yield engine - engine.dispose() - - -@pytest.fixture(scope="function", name="db_session") -def db_session_fixture(db_engine): - make_session = sessionmaker(bind=db_engine) - with contextlib.closing(make_session()) as session: - yield session diff --git a/tests/test_resources.py b/tests/test_resources.py index 2a0bf94677..330116e507 100644 --- a/tests/test_resources.py +++ b/tests/test_resources.py @@ -4,6 +4,8 @@ import pytest +from antarest.core.config import Config + HERE = pathlib.Path(__file__).parent.resolve() PROJECT_DIR = next(iter(p for p in HERE.parents if p.joinpath("antarest").exists())) RESOURCES_DIR = PROJECT_DIR.joinpath("resources") @@ -84,3 +86,17 @@ def test_empty_study_zip(filename: str, expected_list: Sequence[str]): with zipfile.ZipFile(resource_path) as myzip: actual = sorted(myzip.namelist()) assert actual == expected_list + + +def test_resources_config(): + """ + Check that the "resources/config.yaml" file is valid. + + The launcher section must be configured to use a local launcher + with NB Cores detection enabled. + """ + config_path = RESOURCES_DIR.joinpath("deploy/config.yaml") + config = Config.from_yaml_file(config_path, res=RESOURCES_DIR) + assert config.launcher.default == "local" + assert config.launcher.local is not None + assert config.launcher.local.enable_nb_cores_detection is True