diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 18f22fa..5e79ce7 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -27,7 +27,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install -e .[test] + python -m pip install -r requirements-test.txt - name: Test with pytest run: | pytest diff --git a/antareslauncher/data_repo/data_repo_tinydb.py b/antareslauncher/data_repo/data_repo_tinydb.py index 8231aff..96e8aa4 100644 --- a/antareslauncher/data_repo/data_repo_tinydb.py +++ b/antareslauncher/data_repo/data_repo_tinydb.py @@ -1,3 +1,4 @@ +import copy import logging import typing as t @@ -83,7 +84,9 @@ def save_study(self, study: StudyDTO): pk_name = self.db_primary_key pk_value = getattr(study, pk_name) old = self.db.get(tinydb.where(pk_name) == pk_value) - new = vars(study) + study_dict = vars(study) + new = copy.deepcopy(study_dict) # to avoid modifying the study object + new["antares_version"] = f"{new['antares_version']:2d}" if old: diff = _calc_diff(old, new) logger.info(f"Updating study '{pk_value}' in database: {diff!r}") diff --git a/antareslauncher/main.py b/antareslauncher/main.py index 02b3226..8eda4a6 100644 --- a/antareslauncher/main.py +++ b/antareslauncher/main.py @@ -20,6 +20,7 @@ from antareslauncher.use_cases.retrieve.retrieve_controller import RetrieveController from antareslauncher.use_cases.retrieve.state_updater import StateUpdater from antareslauncher.use_cases.wait_loop_controller.wait_controller import WaitController +from antares.study.version import SolverMinorVersion class NoJsonConfigFileError(Exception): @@ -67,7 +68,7 @@ class MainParameters: json_dir: Path default_json_db_name: str slurm_script_path: str - antares_versions_on_remote_server: t.Sequence[str] + antares_versions_on_remote_server: t.Sequence[SolverMinorVersion] default_ssh_dict: t.Mapping[str, t.Any] db_primary_key: str partition: str = "" @@ -120,7 +121,7 @@ def run_with(arguments: argparse.Namespace, parameters: MainParameters, show_ban post_processing=arguments.post_processing, antares_versions_on_remote_server=parameters.antares_versions_on_remote_server, other_options=arguments.other_options or "", - antares_version=arguments.antares_version, + antares_version=SolverMinorVersion.parse(arguments.antares_version), ), ) launch_controller = LaunchController(repo=data_repo, env=environment, display=display) diff --git a/antareslauncher/parameters_reader.py b/antareslauncher/parameters_reader.py index 6e04777..d25849e 100644 --- a/antareslauncher/parameters_reader.py +++ b/antareslauncher/parameters_reader.py @@ -8,6 +8,7 @@ from antareslauncher.main import MainParameters from antareslauncher.main_option_parser import ParserParameters +from antares.study.version import SolverMinorVersion ALT2_PARENT = Path.home() / "antares_launcher_settings" ALT1_PARENT = Path.cwd() @@ -51,7 +52,7 @@ def __init__(self, json_ssh_conf: Path, yaml_filepath: Path): self.remote_slurm_script_path = obj["SLURM_SCRIPT_PATH"] self.partition = obj.get("PARTITION", "") self.quality_of_service = obj.get("QUALITY_OF_SERVICE", "") - self.antares_versions = obj["ANTARES_VERSIONS_ON_REMOTE_SERVER"] + self.antares_versions = [SolverMinorVersion.parse(v) for v in obj["ANTARES_VERSIONS_ON_REMOTE_SERVER"]] self.db_primary_key = obj["DB_PRIMARY_KEY"] self.json_dir = Path(obj["JSON_DIR"]).expanduser() self.json_db_name = obj.get("DEFAULT_JSON_DB_NAME", DEFAULT_JSON_DB_NAME) diff --git a/antareslauncher/remote_environnement/remote_environment_with_slurm.py b/antareslauncher/remote_environnement/remote_environment_with_slurm.py index 71d75ba..7045070 100644 --- a/antareslauncher/remote_environnement/remote_environment_with_slurm.py +++ b/antareslauncher/remote_environnement/remote_environment_with_slurm.py @@ -12,6 +12,7 @@ from antareslauncher.remote_environnement.slurm_script_features import ScriptParametersDTO, SlurmScriptFeatures from antareslauncher.remote_environnement.ssh_connection import SshConnection from antareslauncher.study_dto import StudyDTO +from antares.study.version import SolverMinorVersion logger = logging.getLogger(__name__) @@ -206,7 +207,7 @@ def submit_job(self, my_study: StudyDTO): input_zipfile_name=Path(my_study.zipfile_path).name, time_limit=time_limit, n_cpu=my_study.n_cpu, - antares_version=my_study.antares_version, + antares_version=SolverMinorVersion.parse(my_study.antares_version), run_mode=my_study.run_mode, post_processing=my_study.post_processing, other_options=my_study.other_options or "", diff --git a/antareslauncher/remote_environnement/slurm_script_features.py b/antareslauncher/remote_environnement/slurm_script_features.py index e741b04..41c2817 100644 --- a/antareslauncher/remote_environnement/slurm_script_features.py +++ b/antareslauncher/remote_environnement/slurm_script_features.py @@ -2,6 +2,7 @@ import shlex from antareslauncher.study_dto import Modes +from antares.study.version import SolverMinorVersion @dataclasses.dataclass @@ -10,7 +11,7 @@ class ScriptParametersDTO: input_zipfile_name: str time_limit: int n_cpu: int - antares_version: int + antares_version: SolverMinorVersion run_mode: Modes post_processing: bool other_options: str @@ -81,7 +82,7 @@ def compose_launch_command( for arg in [ self.solver_script_path, script_params.input_zipfile_name, - str(script_params.antares_version), + f"{script_params.antares_version:2d}", _job_type, str(script_params.post_processing), script_params.other_options, diff --git a/antareslauncher/study_dto.py b/antareslauncher/study_dto.py index 7202686..65fdf93 100644 --- a/antareslauncher/study_dto.py +++ b/antareslauncher/study_dto.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from enum import IntEnum from pathlib import Path - +from antares.study.version import StudyVersion class Modes(IntEnum): antares = 1 @@ -43,7 +43,7 @@ class StudyDTO: # Simulation stage data time_limit: t.Optional[int] = None n_cpu: int = 1 - antares_version: int = 0 + antares_version: StudyVersion = StudyVersion.parse(0) xpansion_mode: str = "" # "", "r", "cpp" run_mode: Modes = Modes.antares post_processing: bool = False @@ -59,4 +59,5 @@ def from_dict(cls, doc: t.Mapping[str, t.Any]) -> "StudyDTO": """ attrs = dict(**doc) attrs.pop("name", None) # calculated + attrs["antares_version"] = StudyVersion.parse(attrs["antares_version"]) return cls(**attrs) diff --git a/antareslauncher/use_cases/create_list/study_list_composer.py b/antareslauncher/use_cases/create_list/study_list_composer.py index b874ea7..040f35b 100644 --- a/antareslauncher/use_cases/create_list/study_list_composer.py +++ b/antareslauncher/use_cases/create_list/study_list_composer.py @@ -6,9 +6,11 @@ from antareslauncher.data_repo.data_repo_tinydb import DataRepoTinydb from antareslauncher.display.display_terminal import DisplayTerminal from antareslauncher.study_dto import Modes, StudyDTO +from antares.study.version import SolverMinorVersion, StudyVersion +DEFAULT_VERSION = SolverMinorVersion.parse(0) -def get_solver_version(study_dir: Path, *, default: int = 0) -> int: +def get_solver_version(study_dir: Path, *, default: SolverMinorVersion = DEFAULT_VERSION) -> SolverMinorVersion: """ Retrieve the solver version number or else the study version number from the "study.antares" file. @@ -28,7 +30,7 @@ def get_solver_version(study_dir: Path, *, default: int = 0) -> int: section = config["antares"] for key in "solver_version", "version": if key in section: - return int(section[key]) + return SolverMinorVersion.parse(section[key]) return default @@ -41,9 +43,9 @@ class StudyListComposerParameters: xpansion_mode: str # "", "r", "cpp" output_dir: str post_processing: bool - antares_versions_on_remote_server: t.Sequence[str] + antares_versions_on_remote_server: t.Sequence[SolverMinorVersion] other_options: str - antares_version: int = 0 + antares_version: SolverMinorVersion = DEFAULT_VERSION class StudyListComposer: @@ -66,7 +68,7 @@ def __init__( self.antares_version = parameters.antares_version self._new_study_added = False self.DEFAULT_JOB_LOG_DIR_PATH = str(Path(self.log_dir) / "JOB_LOGS") - self.ANTARES_VERSIONS_ON_REMOTE_SERVER = [int(v) for v in parameters.antares_versions_on_remote_server] + self.ANTARES_VERSIONS_ON_REMOTE_SERVER = parameters.antares_versions_on_remote_server def get_list_of_studies(self): """Retrieve the list of studies from the repo @@ -76,7 +78,7 @@ def get_list_of_studies(self): """ return self._repo.get_list_of_studies() - def _create_study(self, path: Path, antares_version: int, xpansion_mode: str) -> StudyDTO: + def _create_study(self, path: Path, antares_version: SolverMinorVersion, xpansion_mode: str) -> StudyDTO: run_mode = { "": Modes.antares, "r": Modes.xpansion_r, @@ -86,7 +88,7 @@ def _create_study(self, path: Path, antares_version: int, xpansion_mode: str) -> path=str(path), n_cpu=self.n_cpu, time_limit=self.time_limit, - antares_version=antares_version, + antares_version=StudyVersion.parse(antares_version), job_log_dir=self.DEFAULT_JOB_LOG_DIR_PATH, output_dir=str(self.output_dir), xpansion_mode=xpansion_mode, @@ -120,7 +122,7 @@ def update_study_database(self): def _update_database_with_directory(self, directory_path: Path): solver_version = get_solver_version(directory_path) - antares_version = self.antares_version or solver_version + antares_version = self.antares_version if self.antares_version != DEFAULT_VERSION else solver_version if not antares_version: self._display.show_message( "... not a valid Antares study", diff --git a/requirements.txt b/requirements.txt index a69fdbf..79ec340 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +antares-study-version~=1.0.7 bcrypt~=3.2.2 cffi~=1.15.1 cryptography~=39.0.1 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 5a42243..208b9b1 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -8,6 +8,7 @@ from antareslauncher.display.display_terminal import DisplayTerminal from antareslauncher.use_cases.create_list.study_list_composer import StudyListComposer, StudyListComposerParameters from tests.unit.assets import ASSETS_DIR +from antares.study.version import SolverMinorVersion @pytest.fixture(name="studies_in_dir") @@ -44,15 +45,16 @@ def study_list_composer_fixture( xpansion_mode="", output_dir=str(tmp_path.joinpath("FINISHED")), post_processing=False, - antares_versions_on_remote_server=[ + antares_versions_on_remote_server=[SolverMinorVersion.parse(v) for v in [ "800", "810", "820", "830", "840", "850", - ], + ]], other_options="", + ), ) return composer diff --git a/tests/unit/test_data_repo_tinydb.py b/tests/unit/test_data_repo_tinydb.py index 52ad688..ee46549 100644 --- a/tests/unit/test_data_repo_tinydb.py +++ b/tests/unit/test_data_repo_tinydb.py @@ -1,10 +1,7 @@ import random from pathlib import Path -from unittest import mock -from uuid import uuid4 import pytest -import tinydb from antareslauncher.data_repo.data_repo_tinydb import DataRepoTinydb from antareslauncher.study_dto import StudyDTO diff --git a/tests/unit/test_remote_environment_with_slurm.py b/tests/unit/test_remote_environment_with_slurm.py index b472eed..edc281a 100644 --- a/tests/unit/test_remote_environment_with_slurm.py +++ b/tests/unit/test_remote_environment_with_slurm.py @@ -19,6 +19,7 @@ ) from antareslauncher.remote_environnement.slurm_script_features import ScriptParametersDTO, SlurmScriptFeatures from antareslauncher.study_dto import Modes, StudyDTO +from antares.study.version import StudyVersion class TestRemoteEnvironmentWithSlurm: @@ -50,7 +51,7 @@ def study(self) -> StudyDTO: path="path/to/study/91f1f911-4f4a-426f-b127-d0c2a2465b5f", n_cpu=42, zipfile_path="path/to/study/91f1f911-4f4a-426f-b127-d0c2a2465b5f-foo.zip", - antares_version=700, + antares_version=StudyVersion.parse(700), local_final_zipfile_path="local_final_zipfile_path", run_mode=Modes.antares, ) @@ -689,7 +690,7 @@ def test_compose_launch_command( f" --cpus-per-task={study.n_cpu}" f" {filename_launch_script}" f" {Path(study.zipfile_path).name}" - f" {study.antares_version}" + f" {study.antares_version:2d}" f" {job_type}" f" {post_processing}" f" ''" diff --git a/tests/unit/test_study_dto.py b/tests/unit/test_study_dto.py new file mode 100644 index 0000000..5f33287 --- /dev/null +++ b/tests/unit/test_study_dto.py @@ -0,0 +1,23 @@ +from antares.study.version import StudyVersion + +from antareslauncher.study_dto import StudyDTO + + +def test_study_dto_from_dict_old_version_syntax(): + + study_dict = { + "path": "/path/to/study", + "antares_version": 880 + } + + study_dto = StudyDTO.from_dict(study_dict) + assert study_dto.antares_version == StudyVersion.parse("8.8") + + +def test_study_dto_from_dict(): + study_dict = { + "path": "/path/to/study", + "antares_version": "9.0" + } + study_dto = StudyDTO.from_dict(study_dict) + assert study_dto.antares_version == StudyVersion.parse("9.0") diff --git a/tests/unit/test_study_list_composer.py b/tests/unit/test_study_list_composer.py index 427bb28..6aa5ac4 100644 --- a/tests/unit/test_study_list_composer.py +++ b/tests/unit/test_study_list_composer.py @@ -3,6 +3,7 @@ import pytest from antareslauncher.use_cases.create_list.study_list_composer import StudyListComposer, get_solver_version +from antares.study.version import SolverMinorVersion CONFIG_NOMINAL_VERSION = """\ [antares] @@ -93,7 +94,8 @@ def test_update_study_database__antares_version( study_list_composer: StudyListComposer, antares_version: int, ): - study_list_composer.antares_version = antares_version + parsed_version = SolverMinorVersion.parse(antares_version) + study_list_composer.antares_version = parsed_version study_list_composer.update_study_database() studies = study_list_composer.get_list_of_studies() @@ -101,9 +103,9 @@ def test_update_study_database__antares_version( actual_versions = {s.name: s.antares_version for s in studies} if antares_version == 0: expected_versions = { - "013 TS Generation - Solar power": 850, # solver_version - "024 Hurdle costs - 1": 840, # versions - "SMTA-case": 810, # version + "013 TS Generation - Solar power": "8.5", # solver_version + "024 Hurdle costs - 1": "8.4", # versions + "SMTA-case": "8.1", # version } elif antares_version in study_list_composer.ANTARES_VERSIONS_ON_REMOTE_SERVER: study_names = { @@ -114,7 +116,7 @@ def test_update_study_database__antares_version( "MISSING Study version", "SMTA-case", } - expected_versions = dict.fromkeys(study_names, antares_version) + expected_versions = dict.fromkeys(study_names, parsed_version) else: expected_versions = {} assert actual_versions == {n: expected_versions[n] for n in actual_versions}