From da13c45878f7904c82f105a15a5bbed6036b5f7e Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 23 Oct 2023 15:43:09 +0200 Subject: [PATCH 01/22] chore(typing): simplify typing in pytest fixture --- tests/conftest_db.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/conftest_db.py b/tests/conftest_db.py index f22ab16d9b..925e0fe639 100644 --- a/tests/conftest_db.py +++ b/tests/conftest_db.py @@ -1,5 +1,5 @@ import contextlib -from typing import Any, Generator +import typing as t import pytest from sqlalchemy import create_engine # type: ignore @@ -13,7 +13,7 @@ @pytest.fixture(name="db_engine") -def db_engine_fixture() -> Generator[Engine, None, None]: +def db_engine_fixture() -> t.Generator[Engine, None, None]: """ Fixture that creates an in-memory SQLite database engine for testing. @@ -28,7 +28,7 @@ def db_engine_fixture() -> Generator[Engine, None, None]: @pytest.fixture(name="db_session") -def db_session_fixture(db_engine: Engine) -> Generator[Session, None, None]: +def db_session_fixture(db_engine: Engine) -> t.Generator[Session, None, None]: """ Fixture that creates a database session for testing purposes. @@ -49,7 +49,7 @@ def db_session_fixture(db_engine: Engine) -> Generator[Session, None, None]: @pytest.fixture(name="db_middleware", autouse=True) def db_middleware_fixture( db_engine: Engine, -) -> Generator[DBSessionMiddleware, None, None]: +) -> t.Generator[DBSessionMiddleware, None, None]: """ Fixture that sets up a database session middleware with custom engine settings. 
From 1e2e6176ba0943c7c53d31f3fbc8e950c50b6188 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 23 Oct 2023 17:28:18 +0200 Subject: [PATCH 02/22] chore(variant-study-service): correct the `timeout` parameter to use a default value --- antarest/core/tasks/service.py | 10 +++++----- .../variantstudy/variant_study_service.py | 19 +++++++++---------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py index bfb4c7230a..32ac8d0da0 100644 --- a/antarest/core/tasks/service.py +++ b/antarest/core/tasks/service.py @@ -34,6 +34,9 @@ TaskUpdateNotifier = Callable[[str], None] Task = Callable[[TaskUpdateNotifier], TaskResult] +DEFAULT_AWAIT_MAX_TIMEOUT = 172800 # 48 hours +"""Default timeout for `await_task` in seconds.""" + class ITaskService(ABC): @abstractmethod @@ -74,7 +77,7 @@ def list_tasks(self, task_filter: TaskListFilter, request_params: RequestParamet raise NotImplementedError() @abstractmethod - def await_task(self, task_id: str, timeout_sec: Optional[int] = None) -> None: + def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT) -> None: raise NotImplementedError() @@ -83,9 +86,6 @@ def noop_notifier(message: str) -> None: """This function is used in tasks when no notification is required.""" -DEFAULT_AWAIT_MAX_TIMEOUT = 172800 - - class TaskJobService(ITaskService): def __init__( self, @@ -283,7 +283,7 @@ def list_db_tasks(self, task_filter: TaskListFilter, request_params: RequestPara user = None if request_params.user.is_site_admin() else request_params.user.impersonator return self.repo.list(task_filter, user) - def await_task(self, task_id: str, timeout_sec: Optional[int] = None) -> None: + def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT) -> None: logger.info(f"Awaiting task {task_id}") if task_id in self.tasks: self.tasks[task_id].result(timeout_sec or DEFAULT_AWAIT_MAX_TIMEOUT) diff --git 
a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index f956564656..de7c0e5042 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -30,7 +30,7 @@ from antarest.core.model import JSON, PermissionInfo, PublicMode, StudyPermissionType from antarest.core.requests import RequestParameters, UserHasNotPermissionError from antarest.core.tasks.model import CustomTaskEventMessages, TaskDTO, TaskResult, TaskType -from antarest.core.tasks.service import ITaskService, TaskUpdateNotifier, noop_notifier +from antarest.core.tasks.service import DEFAULT_AWAIT_MAX_TIMEOUT, ITaskService, TaskUpdateNotifier, noop_notifier from antarest.core.utils.utils import assert_this, suppress_exception from antarest.matrixstore.service import MatrixService from antarest.study.model import RawStudy, Study, StudyAdditionalData, StudyMetadataDTO, StudySimResultDTO @@ -483,7 +483,6 @@ def get( use_cache: indicate if cache should be used Returns: study data formatted in json - """ self._safe_generation(metadata, timeout=60) self.repository.refresh(metadata) @@ -882,19 +881,19 @@ def get_study_task(self, study_id: str, params: RequestParameters) -> TaskDTO: def create(self, study: VariantStudy) -> VariantStudy: """ - Create empty new study + Create an empty new study. Args: - study: study information - Returns: new study information + study: Study information. + Returns: New study information. """ raise NotImplementedError() def exists(self, metadata: VariantStudy) -> bool: """ - Check study exist. + Check if study exists. Args: - metadata: study - Returns: true if study presents in disk, false else. + metadata: Study metadata. + Returns: `True` if the study is present on disk, `False` otherwise. 
""" return ( (metadata.snapshot is not None) @@ -961,13 +960,13 @@ def copy( return dst_meta - def _wait_for_generation(self, metadata: VariantStudy, timeout: Optional[int] = None) -> bool: + def _wait_for_generation(self, metadata: VariantStudy, timeout: int) -> bool: task_id = self.generate_task(metadata) self.task_service.await_task(task_id, timeout) result = self.task_service.status_task(task_id, RequestParameters(DEFAULT_ADMIN_USER)) return (result.result is not None) and result.result.success - def _safe_generation(self, metadata: VariantStudy, timeout: Optional[int] = None) -> None: + def _safe_generation(self, metadata: VariantStudy, timeout: int = DEFAULT_AWAIT_MAX_TIMEOUT) -> None: try: if not self.exists(metadata) and not self._wait_for_generation(metadata, timeout): raise ValueError() From 38addea8ceb246b77bd49793147bec4ebe6baa8e Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 23 Oct 2023 19:04:17 +0200 Subject: [PATCH 03/22] feat(repository): add a method to update a task status to TIMEOUT in database --- antarest/core/tasks/repository.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/antarest/core/tasks/repository.py b/antarest/core/tasks/repository.py index c295f7ef28..1994c55fab 100644 --- a/antarest/core/tasks/repository.py +++ b/antarest/core/tasks/repository.py @@ -5,7 +5,7 @@ from fastapi import HTTPException -from antarest.core.tasks.model import TaskJob, TaskListFilter +from antarest.core.tasks.model import TaskJob, TaskListFilter, TaskStatus from antarest.core.utils.fastapi_sqlalchemy import db from antarest.core.utils.utils import assert_this @@ -82,3 +82,11 @@ def delete(self, tid: str) -> None: if task: db.session.delete(task) db.session.commit() + + def update_timeout(self, task_id: str, timeout: int) -> None: + """Update task status to TIMEOUT.""" + task: TaskJob = db.session.get(TaskJob, task_id) + task.status = TaskStatus.TIMEOUT + task.result_msg = f"Task '{task_id}' timeout after {timeout} 
seconds" + task.result_status = False + db.session.commit() From ff357bc6a862c5409ccf31deb0f478ebbb075388 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 23 Oct 2023 19:06:19 +0200 Subject: [PATCH 04/22] fix(task-job-service): update the task status to TIMEOUT if the processing is too long --- antarest/core/tasks/service.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py index 32ac8d0da0..8861e4c8fa 100644 --- a/antarest/core/tasks/service.py +++ b/antarest/core/tasks/service.py @@ -284,21 +284,24 @@ def list_db_tasks(self, task_filter: TaskListFilter, request_params: RequestPara return self.repo.list(task_filter, user) def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT) -> None: - logger.info(f"Awaiting task {task_id}") + logger.info(f"Awaiting task '{task_id}'...") if task_id in self.tasks: - self.tasks[task_id].result(timeout_sec or DEFAULT_AWAIT_MAX_TIMEOUT) + self.tasks[task_id].result(timeout_sec) else: - logger.warning(f"Task {task_id} not handled by this worker, will poll for task completion from db") - end = time.time() + (timeout_sec or DEFAULT_AWAIT_MAX_TIMEOUT) + logger.warning(f"Task '{task_id}' not handled by this worker, will poll for task completion from db") + end = time.time() + timeout_sec while time.time() < end: with db(): task = self.repo.get(task_id) - if not task: - logger.error(f"Awaited task {task_id} was not found") - break + if task is None: + logger.error(f"Awaited task '{task_id}' was not found") + return if TaskStatus(task.status).is_final(): - break - time.sleep(2) + return + time.sleep(2) + logger.error(f"Timeout while awaiting task '{task_id}'") + with db(): + self.repo.update_timeout(task_id, timeout_sec) def _run_task( self, From 959b2c021c3e708f0f0b4d5e77e147ac3a71d7c5 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 23 Oct 2023 19:07:20 +0200 Subject: [PATCH 05/22] 
refactor(variant-study-service): improve implementation of safe generation --- .../variantstudy/variant_study_service.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index de7c0e5042..2e9974b44c 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -960,16 +960,20 @@ def copy( return dst_meta - def _wait_for_generation(self, metadata: VariantStudy, timeout: int) -> bool: - task_id = self.generate_task(metadata) - self.task_service.await_task(task_id, timeout) - result = self.task_service.status_task(task_id, RequestParameters(DEFAULT_ADMIN_USER)) - return (result.result is not None) and result.result.success - def _safe_generation(self, metadata: VariantStudy, timeout: int = DEFAULT_AWAIT_MAX_TIMEOUT) -> None: try: - if not self.exists(metadata) and not self._wait_for_generation(metadata, timeout): - raise ValueError() + if self.exists(metadata): + # The study is already present on disk => nothing to do + return + + task_id = self.generate_task(metadata) + self.task_service.await_task(task_id, timeout) + result = self.task_service.status_task(task_id, RequestParameters(DEFAULT_ADMIN_USER)) + if result.result and result.result.success: + # OK, the study has been generated + return + raise ValueError("Fail to generate variant study") + except Exception as e: logger.error(f"Fail to generate variant study {metadata.id}", exc_info=e) raise VariantGenerationError(f"Error while generating {metadata.id}") from None From eb87bb8782a3eb5fca213a64da24bc64e916ed9d Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 24 Oct 2023 13:27:41 +0200 Subject: [PATCH 06/22] chore: improve log message readability by adding prominent emojis --- antarest/core/tasks/service.py | 11 +++++++++-- antarest/core/utils/utils.py | 1 + 
.../storage/variantstudy/variant_study_service.py | 15 ++++++++++----- antarest/study/web/raw_studies_blueprint.py | 2 +- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py index 8861e4c8fa..d1c339a692 100644 --- a/antarest/core/tasks/service.py +++ b/antarest/core/tasks/service.py @@ -141,6 +141,7 @@ def _send_worker_task(logger_: TaskUpdateNotifier) -> TaskResult: task_type, ) while not task_result_wrapper: + logger.info("💤 Sleeping 1 second...") time.sleep(1) self.event_bus.remove_listener(listener_id) return task_result_wrapper[0] @@ -284,9 +285,14 @@ def list_db_tasks(self, task_filter: TaskListFilter, request_params: RequestPara return self.repo.list(task_filter, user) def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT) -> None: - logger.info(f"Awaiting task '{task_id}'...") if task_id in self.tasks: - self.tasks[task_id].result(timeout_sec) + try: + logger.info(f"🤔 Awaiting task '{task_id}'...") + self.tasks[task_id].result(timeout_sec) + logger.info(f"📌 Task '{task_id}' done.") + except Exception as exc: + logger.critical(f"🤕 Task '{task_id}' failed: {exc}.") + raise else: logger.warning(f"Task '{task_id}' not handled by this worker, will poll for task completion from db") end = time.time() + timeout_sec @@ -298,6 +304,7 @@ def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT) return if TaskStatus(task.status).is_final(): return + logger.info("💤 Sleeping 2 seconds...") time.sleep(2) logger.error(f"Timeout while awaiting task '{task_id}'") with db(): diff --git a/antarest/core/utils/utils.py b/antarest/core/utils/utils.py index de0894416c..9154c023a3 100644 --- a/antarest/core/utils/utils.py +++ b/antarest/core/utils/utils.py @@ -114,6 +114,7 @@ def retry(func: Callable[[], T], attempts: int = 10, interval: float = 0.5) -> T attempt += 1 return func() except Exception as e: + logger.info(f"💤 Sleeping {interval} second(s)...") 
time.sleep(interval) caught_exception = e raise caught_exception or ShouldNotHappenException() diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index 2e9974b44c..ecb580e0de 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -690,9 +690,11 @@ def _generate( ) last_executed_command_index = ( None - if is_parent_newer - or from_scratch - or (isinstance(parent_study, VariantStudy) and not self.exists(parent_study)) + if ( + is_parent_newer + or from_scratch + or (isinstance(parent_study, VariantStudy) and not self.exists(parent_study)) + ) else last_executed_command_index ) @@ -966,6 +968,8 @@ def _safe_generation(self, metadata: VariantStudy, timeout: int = DEFAULT_AWAIT_ # The study is already present on disk => nothing to do return + logger.info("🔹 Starting variant study generation...") + # Create and run the generation task in a thread pool. 
task_id = self.generate_task(metadata) self.task_service.await_task(task_id, timeout) result = self.task_service.status_task(task_id, RequestParameters(DEFAULT_ADMIN_USER)) @@ -975,8 +979,9 @@ def _safe_generation(self, metadata: VariantStudy, timeout: int = DEFAULT_AWAIT_ raise ValueError("Fail to generate variant study") except Exception as e: - logger.error(f"Fail to generate variant study {metadata.id}", exc_info=e) - raise VariantGenerationError(f"Error while generating {metadata.id}") from None + logger.error(f"⚡ Fail to generate variant study {metadata.id}", exc_info=e) + # raise VariantGenerationError(f"Error while generating {metadata.id}") from None + raise @staticmethod def _get_snapshot_last_executed_command_index( diff --git a/antarest/study/web/raw_studies_blueprint.py b/antarest/study/web/raw_studies_blueprint.py index edf09aff37..c25af6e8a4 100644 --- a/antarest/study/web/raw_studies_blueprint.py +++ b/antarest/study/web/raw_studies_blueprint.py @@ -92,7 +92,7 @@ def get_study( or a file attachment (Microsoft Office document, CSV/TSV file...). 
""" logger.info( - f"Fetching data at {path} (depth={depth}) from study {uuid}", + f"📘 Fetching data at {path} (depth={depth}) from study {uuid}", extra={"user": current_user.id}, ) parameters = RequestParameters(user=current_user) From 979ac6a4e5cf1cd6c74ee7b5474c389c8cad510f Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 24 Oct 2023 13:56:59 +0200 Subject: [PATCH 07/22] feat(variant-generation): improve error handling for timeout tasks --- antarest/core/exceptions.py | 5 +++++ .../variantstudy/variant_study_service.py | 17 ++++++++++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/antarest/core/exceptions.py b/antarest/core/exceptions.py index a3623e3ed1..878e1c7d6f 100644 --- a/antarest/core/exceptions.py +++ b/antarest/core/exceptions.py @@ -56,6 +56,11 @@ def __init__(self, message: str) -> None: super().__init__(HTTPStatus.EXPECTATION_FAILED, message) +class VariantGenerationTimeoutError(HTTPException): + def __init__(self, message: str) -> None: + super().__init__(HTTPStatus.REQUEST_TIMEOUT, message) + + class NoParentStudyError(HTTPException): def __init__(self, message: str) -> None: super().__init__(HTTPStatus.NOT_FOUND, message) diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index ecb580e0de..4adeb0f1aa 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -1,3 +1,4 @@ +import concurrent.futures import json import logging import re @@ -21,6 +22,7 @@ StudyNotFoundError, StudyTypeUnsupported, VariantGenerationError, + VariantGenerationTimeoutError, VariantStudyParentNotValid, ) from antarest.core.filetransfer.model import FileDownloadTaskDTO @@ -818,7 +820,7 @@ def _get_commands_and_notifier( from_index: int = 0, ) -> Tuple[List[List[ICommand]], Callable[[int, bool, str], None]]: # Generate - commands: List[List[ICommand]] = 
self._to_icommand(variant_study, from_index) + commands: List[List[ICommand]] = self._to_commands(variant_study, from_index) def notify(command_index: int, command_result: bool, command_message: str) -> None: try: @@ -845,7 +847,7 @@ def notify(command_index: int, command_result: bool, command_message: str) -> No return commands, notify - def _to_icommand(self, metadata: VariantStudy, from_index: int = 0) -> List[List[ICommand]]: + def _to_commands(self, metadata: VariantStudy, from_index: int = 0) -> List[List[ICommand]]: commands: List[List[ICommand]] = [ self.command_factory.to_command(command_block.to_dto()) for index, command_block in enumerate(metadata.commands) @@ -976,12 +978,17 @@ def _safe_generation(self, metadata: VariantStudy, timeout: int = DEFAULT_AWAIT_ if result.result and result.result.success: # OK, the study has been generated return - raise ValueError("Fail to generate variant study") + raise ValueError("No task result or result failed") + + except concurrent.futures.TimeoutError as e: + # Raise a REQUEST_TIMEOUT error (408) + logger.error(f"⚡ Timeout while generating variant study {metadata.id}", exc_info=e) + raise VariantGenerationTimeoutError(f"Timeout while generating {metadata.id}") from None except Exception as e: + # raise a EXPECTATION_FAILED error (417) logger.error(f"⚡ Fail to generate variant study {metadata.id}", exc_info=e) - # raise VariantGenerationError(f"Error while generating {metadata.id}") from None - raise + raise VariantGenerationError(f"Error while generating {metadata.id}") from None @staticmethod def _get_snapshot_last_executed_command_index( From 0ba63200d12032adbce7795b9da0ce033f87121b Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 24 Oct 2023 13:32:57 +0200 Subject: [PATCH 08/22] test(variant-generation): add a unit test to check that the variant generation is not blocking the main thread --- .../variant_blueprint/test_thermal_cluster.py | 147 ++++++++++++++++++ 1 file changed, 147 insertions(+) create 
mode 100644 tests/integration/variant_blueprint/test_thermal_cluster.py diff --git a/tests/integration/variant_blueprint/test_thermal_cluster.py b/tests/integration/variant_blueprint/test_thermal_cluster.py new file mode 100644 index 0000000000..70e2f8dcba --- /dev/null +++ b/tests/integration/variant_blueprint/test_thermal_cluster.py @@ -0,0 +1,147 @@ +import http +import random +import typing as t +from unittest import mock + +import numpy as np +import pytest +from starlette.testclient import TestClient + +from antarest.core.tasks.model import TaskDTO, TaskStatus +from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id + + +def _create_thermal_params(cluster_name: str) -> t.Mapping[str, t.Any]: + # noinspection SpellCheckingInspection + return { + "name": cluster_name, + "group": "Gas", + "unitcount": random.randint(1, 10), + "nominalcapacity": random.random() * 1000, + "min-stable-power": random.random() * 1000, + "min-up-time": random.randint(1, 168), + "min-down-time": random.randint(1, 168), + "co2": random.random() * 10, + "marginal-cost": random.random() * 100, + "spread-cost": random.random(), + "startup-cost": random.random() * 100000, + "market-bid-cost": random.random() * 100, + } + + +@pytest.mark.integration_test +class TestThermalCluster: + """ + The goal of this test is to check the performance of the update + of the properties and matrices of the thermal clusters in the case where we + have a cascade of study variants. + We want also to check that the variant snapshots are correctly created without + blocking thread (no timeout). + """ + + def test_cascade_update( + self, + client: TestClient, + user_access_token: str, + study_id: str, + ) -> None: + """ + This test is based on the study "STA-mini.zip", which is a RAW study. + We will first convert this study to a managed study, and then we will + create a cascade of _N_ variants (more than `max_workers`). 
+ Finally, we will read the thermal clusters of all areas of the last variant + to check that the variant generation is not blocking. + """ + # First, we create a copy of the study, and we convert it to a managed study. + res = client.post( + f"/v1/studies/{study_id}/copy", + headers={"Authorization": f"Bearer {user_access_token}"}, + params={"dest": "default", "with_outputs": False, "use_task": False}, + ) + assert res.status_code == http.HTTPStatus.CREATED, res.json() + base_study_id = res.json() + assert base_study_id is not None + + # Store the variant IDs in a list. + cascade_ids = [base_study_id] + total_count = 6 # `max_workers` is set to 5 in the configuration (default value). + for count in range(1, total_count + 1): + # Create a new variant of the last study in the cascade. + prev_id = cascade_ids[-1] + res = client.post( + f"/v1/studies/{prev_id}/variants", + headers={"Authorization": f"Bearer {user_access_token}"}, + params={"name": f"Variant {count}"}, + ) + assert res.status_code == http.HTTPStatus.OK, res.json() # should be CREATED + variant_id = res.json() + assert variant_id is not None + cascade_ids.append(variant_id) + + # Create a thermal cluster in an area (randomly chosen). + area_id = "fr" + cluster_name = f"Cluster {count}" + cmd_args = { + "action": "create_cluster", + "args": { + "area_id": area_id, + "cluster_name": transform_name_to_id(cluster_name, lower=False), + "parameters": _create_thermal_params(cluster_name), + "prepro": np.random.rand(8760, 6).tolist(), + "modulation": np.random.rand(8760, 4).tolist(), + }, + } + res = client.post( + f"/v1/studies/{variant_id}/commands", + headers={"Authorization": f"Bearer {user_access_token}"}, + json=[cmd_args], + ) + assert res.status_code == http.HTTPStatus.OK, res.json() # should be CREATED + + # At this point, we have a base study copied in the default workspace, + # and a certain number of variants stored in the database. + # But no variant is physically stored in the default workspace. 
+ res = client.get( + "/v1/studies", + headers={"Authorization": f"Bearer {user_access_token}"}, + params={"managed": True}, + ) + assert res.status_code == http.HTTPStatus.OK, res.json() + study_map = res.json() # dict of study properties, indexed by study ID + assert set(study_map) | {base_study_id} == set(cascade_ids) + + # Now, we will generate the last variant + variant_id = cascade_ids[-1] + res = client.put( + f"/v1/studies/{variant_id}/generate", + headers={"Authorization": f"Bearer {user_access_token}"}, + params={"denormalize": False, "from_scratch": True}, + ) + assert res.status_code == http.HTTPStatus.OK, res.json() + task_id = res.json() + + # wait for task completion + res = client.get( + f"/v1/tasks/{task_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + params={"wait_for_completion": True}, + ) + assert res.status_code == http.HTTPStatus.OK, res.json() + task = TaskDTO(**res.json()) + assert task.dict() == { + "completion_date_utc": mock.ANY, + "creation_date_utc": mock.ANY, + "id": task_id, + "logs": None, + "name": f"Generation of {variant_id} study", + "owner": 1, + "ref_id": variant_id, + "result": { + "message": f"{variant_id} generated successfully", + "return_value": mock.ANY, + "success": True, + }, + "status": TaskStatus.COMPLETED, + "type": "VARIANT_GENERATION", + } + From 0cac11c4233a80655323a7eabfb9cac62c22e1df Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 6 Nov 2023 23:25:56 +0100 Subject: [PATCH 09/22] feat(tasks): add timeout parameter to control task completion wait time Added the 'timeout' parameter to the /v1/tasks/{task_id} endpoint to enable users to control the waiting time for task completion. 
--- antarest/core/tasks/service.py | 2 +- antarest/core/tasks/web.py | 45 ++++++++++++++++--- .../variant_blueprint/test_thermal_cluster.py | 3 +- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py index d1c339a692..28038e2d4c 100644 --- a/antarest/core/tasks/service.py +++ b/antarest/core/tasks/service.py @@ -287,7 +287,7 @@ def list_db_tasks(self, task_filter: TaskListFilter, request_params: RequestPara def await_task(self, task_id: str, timeout_sec: int = DEFAULT_AWAIT_MAX_TIMEOUT) -> None: if task_id in self.tasks: try: - logger.info(f"🤔 Awaiting task '{task_id}'...") + logger.info(f"🤔 Awaiting task '{task_id}' {timeout_sec}s...") self.tasks[task_id].result(timeout_sec) logger.info(f"📌 Task '{task_id}' done.") except Exception as exc: diff --git a/antarest/core/tasks/web.py b/antarest/core/tasks/web.py index 2d303a4f7b..7ce33e6740 100644 --- a/antarest/core/tasks/web.py +++ b/antarest/core/tasks/web.py @@ -1,13 +1,15 @@ +import concurrent.futures +import http import logging from typing import Any -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from antarest.core.config import Config from antarest.core.jwt import JWTUser from antarest.core.requests import RequestParameters -from antarest.core.tasks.model import TaskListFilter -from antarest.core.tasks.service import TaskJobService +from antarest.core.tasks.model import TaskDTO, TaskListFilter +from antarest.core.tasks.service import DEFAULT_AWAIT_MAX_TIMEOUT, TaskJobService from antarest.core.utils.web import APITag from antarest.login.auth import Auth @@ -17,13 +19,13 @@ def create_tasks_api(service: TaskJobService, config: Config) -> APIRouter: """ Endpoints login implementation + Args: service: login facade service config: server config - jwt: jwt manager Returns: - + API router """ bp = APIRouter(prefix="/v1") auth = Auth(config) @@ -36,17 +38,46 @@ def list_tasks( request_params = 
RequestParameters(user=current_user) return service.list_tasks(filter, request_params) - @bp.get("/tasks/{task_id}", tags=[APITag.tasks]) + @bp.get("/tasks/{task_id}", tags=[APITag.tasks], response_model=TaskDTO) def get_task( task_id: str, wait_for_completion: bool = False, with_logs: bool = False, + timeout: int = DEFAULT_AWAIT_MAX_TIMEOUT, current_user: JWTUser = Depends(auth.get_current_user), ) -> Any: + """ + Retrieve information about a specific task. + + Args: + - `task_id`: Unique identifier of the task. + - `wait_for_completion`: Set to `True` to wait for task completion. + - `with_logs`: Set to `True` to retrieve the job logs (Antares Solver logs). + - `timeout`: Maximum time in seconds to wait for task completion. + + Raises: + - 408 REQUEST_TIMEOUT: when the request times out while waiting for task completion. + + Returns: + TaskDTO: Information about the specified task. + """ request_params = RequestParameters(user=current_user) task_status = service.status_task(task_id, request_params, with_logs) + if wait_for_completion and not task_status.status.is_final(): - service.await_task(task_id) + # Ensure 0 <= timeout <= 48 h + timeout = min(max(0, timeout), DEFAULT_AWAIT_MAX_TIMEOUT) + try: + service.await_task(task_id, timeout_sec=timeout) + except concurrent.futures.TimeoutError as exc: # pragma: no cover + # Note that if the task does not complete within the specified time, + # the task will continue running but the user will receive a timeout. + # In this case, it is the user's responsibility to cancel the task. 
+ raise HTTPException( + status_code=http.HTTPStatus.REQUEST_TIMEOUT, + detail="The request timed out while waiting for task completion.", + ) from exc + return service.status_task(task_id, request_params, with_logs) @bp.put("/tasks/{task_id}/cancel", tags=[APITag.tasks]) diff --git a/tests/integration/variant_blueprint/test_thermal_cluster.py b/tests/integration/variant_blueprint/test_thermal_cluster.py index 70e2f8dcba..567b974a80 100644 --- a/tests/integration/variant_blueprint/test_thermal_cluster.py +++ b/tests/integration/variant_blueprint/test_thermal_cluster.py @@ -124,7 +124,7 @@ def test_cascade_update( res = client.get( f"/v1/tasks/{task_id}", headers={"Authorization": f"Bearer {user_access_token}"}, - params={"wait_for_completion": True}, + params={"wait_for_completion": True, "timeout": 10}, ) assert res.status_code == http.HTTPStatus.OK, res.json() task = TaskDTO(**res.json()) @@ -144,4 +144,3 @@ def test_cascade_update( "status": TaskStatus.COMPLETED, "type": "VARIANT_GENERATION", } - From cd15eec62302844aed3e740cb3679f8f92da3846 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 7 Nov 2023 01:19:55 +0100 Subject: [PATCH 10/22] refactor(variant-study-repo): optionally use a user-defined SqlAlchemy session instead of the global `db.session` and improve the documentation --- .../matrixstore/matrix_garbage_collector.py | 2 +- .../study/storage/variantstudy/repository.py | 62 +++++++++++++++---- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/antarest/matrixstore/matrix_garbage_collector.py b/antarest/matrixstore/matrix_garbage_collector.py index 31458692ff..340d038f15 100644 --- a/antarest/matrixstore/matrix_garbage_collector.py +++ b/antarest/matrixstore/matrix_garbage_collector.py @@ -55,7 +55,7 @@ def _get_raw_studies_matrices(self) -> Set[str]: def _get_variant_studies_matrices(self) -> Set[str]: logger.info("Getting all matrices used in variant studies") - command_blocks: List[CommandBlock] = 
self.variant_study_service.repository.get_all_commandblocks() + command_blocks: List[CommandBlock] = self.variant_study_service.repository.get_all_command_blocks() def transform_to_command(command_dto: CommandDTO, study_ref: str) -> List[ICommand]: try: diff --git a/antarest/study/storage/variantstudy/repository.py b/antarest/study/storage/variantstudy/repository.py index 28dbdfdc3c..b00ab387c9 100644 --- a/antarest/study/storage/variantstudy/repository.py +++ b/antarest/study/storage/variantstudy/repository.py @@ -1,29 +1,65 @@ -from typing import List +import typing as t + +from sqlalchemy.orm import Session # type: ignore from antarest.core.interfaces.cache import ICache from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.study.model import Study from antarest.study.repository import StudyMetadataRepository from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock, VariantStudy class VariantStudyRepository(StudyMetadataRepository): """ - Variant study repository + Variant study repository """ - def __init__(self, cache_service: ICache): + def __init__(self, cache_service: ICache, session: t.Optional[Session] = None): + """ + Initialize the variant study repository. + + Args: + cache_service: Cache service for the repository. + session: Optional SQLAlchemy session to be used. + """ super().__init__(cache_service) + self._session = session - def get_children(self, parent_id: str) -> List[VariantStudy]: - studies: List[VariantStudy] = db.session.query(VariantStudy).filter(VariantStudy.parent_id == parent_id).all() - return studies + @property + def session(self) -> Session: + """ + Get the SQLAlchemy session for the repository. + + Returns: + SQLAlchemy session. 
+ """ + if self._session is None: + # Get or create the session from a context variable (thread local variable) + return db.session + # Get the user-defined session + return self._session - def get_all_commandblocks(self) -> List[CommandBlock]: - outputs = db.session.query(CommandBlock).all() + def get_children(self, parent_id: str) -> t.List[VariantStudy]: + """ + Get the children of a variant study. + + Args: + parent_id: Identifier of the parent study. + + Returns: + List of `VariantStudy` objects. + """ + studies: t.List[VariantStudy] = ( + self.session.query(VariantStudy).filter(VariantStudy.parent_id == parent_id).all() + ) + return studies - # for mypy - assert isinstance(outputs, list) - for output in outputs: - assert isinstance(output, CommandBlock) + def get_all_command_blocks(self) -> t.List[CommandBlock]: + """ + Get all command blocks. - return outputs + Returns: + List of `CommandBlock` objects. + """ + cmd_blocks: t.List[CommandBlock] = self.session.query(CommandBlock).all() + return cmd_blocks From 8262e14897dab55f9dad9bd0a21b289f9f481002 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 7 Nov 2023 01:20:49 +0100 Subject: [PATCH 11/22] feat(variant-study-repo): add the `get_ancestor_or_self_ids` to retrieve all ancestor IDs of a variant study --- .../study/storage/variantstudy/repository.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/antarest/study/storage/variantstudy/repository.py b/antarest/study/storage/variantstudy/repository.py index b00ab387c9..8a20108021 100644 --- a/antarest/study/storage/variantstudy/repository.py +++ b/antarest/study/storage/variantstudy/repository.py @@ -54,6 +54,30 @@ def get_children(self, parent_id: str) -> t.List[VariantStudy]: ) return studies + def get_ancestor_or_self_ids(self, variant_id: str) -> t.Sequence[str]: + """ + Retrieve the list of ancestor variant identifiers, including the `variant_id`, + its parent, and all predecessors of the parent, up to and including the ID + 
of the root study (`RawStudy`). + + Args: + variant_id: Unique identifier of the child variant. + + Returns: + Ordered list of study identifiers. + """ + # see: [Recursive Queries](https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE) + top_q = self.session.query(Study.id, Study.parent_id) + top_q = top_q.filter(Study.id == variant_id) + top_q = top_q.cte("study_cte", recursive=True) + + bot_q = self.session.query(Study.id, Study.parent_id) + bot_q = bot_q.join(top_q, Study.id == top_q.c.parent_id) + + recursive_q = top_q.union_all(bot_q) + q = self.session.query(recursive_q) + return [r[0] for r in q] + def get_all_command_blocks(self) -> t.List[CommandBlock]: """ Get all command blocks. From 4c9c081da03bde9ce3c1b202562d485099cb7f35 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Wed, 8 Nov 2023 15:33:40 +0100 Subject: [PATCH 12/22] style(db): add type hints in database model classes --- .../storage/variantstudy/model/dbmodel.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/antarest/study/storage/variantstudy/model/dbmodel.py b/antarest/study/storage/variantstudy/model/dbmodel.py index 235cd28472..3026115ce4 100644 --- a/antarest/study/storage/variantstudy/model/dbmodel.py +++ b/antarest/study/storage/variantstudy/model/dbmodel.py @@ -1,4 +1,6 @@ +import datetime import json +import typing as t import uuid from dataclasses import dataclass @@ -18,13 +20,13 @@ class VariantStudySnapshot(Base): # type: ignore __tablename__ = "variant_study_snapshot" - id = Column( + id: str = Column( String(36), ForeignKey("variantstudy.id"), primary_key=True, ) - created_at = Column(DateTime) - last_executed_command = Column(String(), nullable=True) + created_at: datetime.date = Column(DateTime) + last_executed_command: t.Optional[str] = Column(String(), nullable=True) __mapper_args__ = { "polymorphic_identity": "variant_study_snapshot", @@ -42,17 +44,17 @@ class CommandBlock(Base): # type: ignore 
__tablename__ = "commandblock" - id = Column( + id: str = Column( String(36), primary_key=True, default=lambda: str(uuid.uuid4()), unique=True, ) - study_id = Column(String(36), ForeignKey("variantstudy.id")) - index = Column(Integer) - command = Column(String(255)) - version = Column(Integer) - args = Column(String()) + study_id: str = Column(String(36), ForeignKey("variantstudy.id")) + index: int = Column(Integer) + command: str = Column(String(255)) + version: int = Column(Integer) + args: str = Column(String()) def to_dto(self) -> CommandDTO: return CommandDTO(id=self.id, action=self.command, args=json.loads(self.args)) @@ -66,12 +68,12 @@ class VariantStudy(Study): __tablename__ = "variantstudy" - id = Column( + id: str = Column( String(36), ForeignKey("study.id"), primary_key=True, ) - generation_task = Column(String(), nullable=True) + generation_task: t.Optional[str] = Column(String(), nullable=True) __mapper_args__ = { "polymorphic_identity": "variantstudy", From 21823bae235a8659aa1d2452df76b751ff4b0b05 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Wed, 8 Nov 2023 15:41:08 +0100 Subject: [PATCH 13/22] test(hydro): improves testing of hydraulic allocation variants --- .../test_hydro_allocation.py | 82 +++++++++---------- 1 file changed, 40 insertions(+), 42 deletions(-) diff --git a/tests/integration/study_data_blueprint/test_hydro_allocation.py b/tests/integration/study_data_blueprint/test_hydro_allocation.py index a2d128141b..d4b04dd332 100644 --- a/tests/integration/study_data_blueprint/test_hydro_allocation.py +++ b/tests/integration/study_data_blueprint/test_hydro_allocation.py @@ -1,11 +1,10 @@ -from http import HTTPStatus -from typing import List +import http +import typing as t import pytest from starlette.testclient import TestClient from antarest.study.business.area_management import AreaInfoDTO, AreaType -from tests.integration.utils import wait_for @pytest.mark.unit_test @@ -22,14 +21,14 @@ def test_get_allocation_form_values( client: 
TestClient, user_access_token: str, study_id: str, - ): + ) -> None: """Check `get_allocation_form_values` end point""" area_id = "de" res = client.get( f"/v1/studies/{study_id}/areas/{area_id}/hydro/allocation/form", headers={"Authorization": f"Bearer {user_access_token}"}, ) - assert res.status_code == HTTPStatus.OK, res.json() + assert res.status_code == http.HTTPStatus.OK, res.json() actual = res.json() expected = {"allocation": [{"areaId": "de", "coefficient": 1.0}]} assert actual == expected @@ -39,45 +38,42 @@ def test_get_allocation_form_values__variant( client: TestClient, user_access_token: str, study_id: str, - ): + ) -> None: """ The purpose of this test is to check that we can get the form parameters from a study variant. To prepare this test, we start from a RAW study, copy it to the managed study workspace and then create a variant from this managed workspace. """ - # Execute the job to copy the study to the workspace + # Create a managed study from the RAW study. res = client.post( - f"/v1/studies/{study_id}/copy?dest=Clone&with_outputs=false", + f"/v1/studies/{study_id}/copy", headers={"Authorization": f"Bearer {user_access_token}"}, + params={"dest": "Clone", "with_outputs": False, "use_task": False}, ) - res.raise_for_status() - task_id = res.json() + assert res.status_code == http.HTTPStatus.CREATED, res.json() + managed_id = res.json() + assert managed_id is not None - # wait for the job to finish - def copy_task_done() -> bool: - r = client.get( - f"/v1/tasks/{task_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) - return r.json()["status"] == 3 - - wait_for(copy_task_done, sleep_time=0.2) - - # Get the job result to retrieve the study ID + # Ensure the managed study has the same allocation form as the RAW study. 
+ area_id = "de" res = client.get( - f"/v1/tasks/{task_id}", + f"/v1/studies/{managed_id}/areas/{area_id}/hydro/allocation/form", headers={"Authorization": f"Bearer {user_access_token}"}, ) - res.raise_for_status() - managed_id = res.json()["result"]["return_value"] + assert res.status_code == http.HTTPStatus.OK, res.json() + actual = res.json() + expected = {"allocation": [{"areaId": "de", "coefficient": 1.0}]} + assert actual == expected # create a variant study from the managed study res = client.post( - f"/v1/studies/{managed_id}/variants?name=foo", + f"/v1/studies/{managed_id}/variants", headers={"Authorization": f"Bearer {user_access_token}"}, + params={"name": "foo"}, ) - res.raise_for_status() + assert res.status_code == http.HTTPStatus.OK, res.json() # should be CREATED variant_id = res.json() + assert variant_id is not None # get allocation form area_id = "de" @@ -85,7 +81,7 @@ def copy_task_done() -> bool: f"/v1/studies/{variant_id}/areas/{area_id}/hydro/allocation/form", headers={"Authorization": f"Bearer {user_access_token}"}, ) - res.raise_for_status() + assert res.status_code == http.HTTPStatus.OK, res.json() actual = res.json() expected = {"allocation": [{"areaId": "de", "coefficient": 1.0}]} assert actual == expected @@ -115,14 +111,14 @@ def test_get_allocation_matrix( user_access_token: str, study_id: str, area_id: str, - expected: List[List[float]], - ): + expected: t.List[t.List[float]], + ) -> None: """Check `get_allocation_matrix` end point""" res = client.get( f"/v1/studies/{study_id}/areas/hydro/allocation/matrix", headers={"Authorization": f"Bearer {user_access_token}"}, ) - assert res.status_code == HTTPStatus.OK, res.json() + assert res.status_code == http.HTTPStatus.OK, res.json() actual = res.json() assert actual == expected @@ -131,7 +127,7 @@ def test_set_allocation_form_values( client: TestClient, user_access_token: str, study_id: str, - ): + ) -> None: """Check `set_allocation_form_values` end point""" area_id = "de" expected = { 
@@ -145,16 +141,17 @@ def test_set_allocation_form_values( headers={"Authorization": f"Bearer {user_access_token}"}, json=expected, ) - assert res.status_code == HTTPStatus.OK, res.json() + assert res.status_code == http.HTTPStatus.OK, res.json() actual = res.json() assert actual == expected # check that the values are updated res = client.get( - f"/v1/studies/{study_id}/raw?path=input/hydro/allocation&depth=3", + f"/v1/studies/{study_id}/raw", headers={"Authorization": f"Bearer {user_access_token}"}, + params={"path": "input/hydro/allocation", "depth": 3}, ) - assert res.status_code == HTTPStatus.OK, res.json() + assert res.status_code == http.HTTPStatus.OK, res.json() actual = res.json() expected = { "de": {"[allocation]": {"de": 3.0, "es": 1.0}}, @@ -164,7 +161,7 @@ def test_set_allocation_form_values( } assert actual == expected - def test_create_area(self, client: TestClient, user_access_token: str, study_id: str): + def test_create_area(self, client: TestClient, user_access_token: str, study_id: str) -> None: """ Given a study, when an area is created, the hydraulic allocation column for this area must be updated with the following values: @@ -178,13 +175,13 @@ def test_create_area(self, client: TestClient, user_access_token: str, study_id: headers={"Authorization": f"Bearer {user_access_token}"}, data=area_info.json(), ) - assert res.status_code == HTTPStatus.OK, res.json() + assert res.status_code == http.HTTPStatus.OK, res.json() res = client.get( f"/v1/studies/{study_id}/areas/hydro/allocation/matrix", headers={"Authorization": f"Bearer {user_access_token}"}, ) - assert res.status_code == HTTPStatus.OK + assert res.status_code == http.HTTPStatus.OK actual = res.json() expected = { "columns": ["de", "es", "fr", "it", "north"], @@ -199,7 +196,7 @@ def test_create_area(self, client: TestClient, user_access_token: str, study_id: } assert actual == expected - def test_delete_area(self, client: TestClient, user_access_token: str, study_id: str): + def 
test_delete_area(self, client: TestClient, user_access_token: str, study_id: str) -> None: """ Given a study, when an area is deleted, the hydraulic allocation column for this area must be removed. @@ -214,11 +211,12 @@ def test_delete_area(self, client: TestClient, user_access_token: str, study_id: } for prod_area, allocation_cfg in obj.items(): res = client.post( - f"/v1/studies/{study_id}/raw?path=input/hydro/allocation/{prod_area}", + f"/v1/studies/{study_id}/raw", headers={"Authorization": f"Bearer {user_access_token}"}, + params={"path": f"input/hydro/allocation/{prod_area}"}, json=allocation_cfg, ) - assert res.status_code == HTTPStatus.NO_CONTENT, res.json() + assert res.status_code == http.HTTPStatus.NO_CONTENT, res.json() # Then we remove the "fr" zone. # The deletion should update the allocation matrix of all other zones. @@ -226,7 +224,7 @@ def test_delete_area(self, client: TestClient, user_access_token: str, study_id: f"/v1/studies/{study_id}/areas/fr", headers={"Authorization": f"Bearer {user_access_token}"}, ) - assert res.status_code == HTTPStatus.OK, res.json() + assert res.status_code == http.HTTPStatus.OK, res.json() # Check that the "fr" column is removed from the hydraulic allocation matrix. # The row corresponding to "fr" must also be deleted. 
@@ -234,7 +232,7 @@ def test_delete_area(self, client: TestClient, user_access_token: str, study_id: f"/v1/studies/{study_id}/areas/hydro/allocation/matrix", headers={"Authorization": f"Bearer {user_access_token}"}, ) - assert res.status_code == HTTPStatus.OK, res.json() + assert res.status_code == http.HTTPStatus.OK, res.json() actual = res.json() expected = { "columns": ["de", "es", "it"], From 94d3cdb15b1fe8e3b12612ba832eda6c07366525 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Thu, 9 Nov 2023 10:10:59 +0100 Subject: [PATCH 14/22] docs(cache): add the documentation of cache constants --- antarest/core/interfaces/cache.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/antarest/core/interfaces/cache.py b/antarest/core/interfaces/cache.py index a85fd5525f..33d5fa541e 100644 --- a/antarest/core/interfaces/cache.py +++ b/antarest/core/interfaces/cache.py @@ -6,6 +6,23 @@ class CacheConstants(Enum): + """ + Constants used to identify cache entries. + + - `RAW_STUDY`: variable used to store JSON (or bytes) objects. + This cache is used by the `RawStudyService` or `VariantStudyService` to store + values that are retrieved from the filesystem. + Note: that it is unlikely that this cache is used, because it is only used + to fetch data inside a study when the URL is "" and the depth is -1. + + - `STUDY_FACTORY`: variable used to store objects of type `FileStudyTreeConfigDTO`. + This cache is used by the `create_from_fs` function when retrieving the configuration + of a study from the data on the disk. + + - `STUDY_LISTING`: variable used to store objects of type `StudyMetadataDTO`. + This cache is used by the `get_studies_information` function to store the list of studies. 
+ """ + RAW_STUDY = "RAW_STUDY" STUDY_FACTORY = "STUDY_FACTORY" STUDY_LISTING = "STUDY_LISTING" From ea69c38574816999046faf0fda8122c559457470 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Thu, 9 Nov 2023 14:11:35 +0100 Subject: [PATCH 15/22] feat(permission): add the `assert_permission_on_studies` to check user permission of several studies --- antarest/study/storage/utils.py | 99 ++++++++++++++++++++------------- tests/storage/test_service.py | 92 +++++++++++++++++++++++++++++- 2 files changed, 151 insertions(+), 40 deletions(-) diff --git a/antarest/study/storage/utils.py b/antarest/study/storage/utils.py index 6197180f50..1c058c8c6f 100644 --- a/antarest/study/storage/utils.py +++ b/antarest/study/storage/utils.py @@ -1,13 +1,13 @@ import calendar import logging +import math import os import shutil import tempfile import time +import typing as t from datetime import datetime, timedelta -from math import ceil from pathlib import Path -from typing import Callable, List, Optional, Union, cast from uuid import uuid4 from zipfile import ZipFile @@ -123,7 +123,7 @@ def find_single_output_path(all_output_path: Path) -> Path: return all_output_path -def extract_output_name(path_output: Path, new_suffix_name: Optional[str] = None) -> str: +def extract_output_name(path_output: Path, new_suffix_name: t.Optional[str] = None) -> str: ini_reader = IniReader() is_output_archived = path_output.suffix == ".zip" if is_output_archived: @@ -171,7 +171,7 @@ def remove_from_cache(cache: ICache, root_id: str) -> None: def create_new_empty_study(version: str, path_study: Path, path_resources: Path) -> None: - version_template: Optional[str] = STUDY_REFERENCE_TEMPLATES.get(version, None) + version_template: t.Optional[str] = STUDY_REFERENCE_TEMPLATES.get(version, None) if version_template is None: raise UnsupportedStudyVersion(version) @@ -182,8 +182,8 @@ def create_new_empty_study(version: str, path_study: Path, path_resources: Path) def study_matcher( - name: Optional[str], 
workspace: Optional[str], folder: Optional[str] -) -> Callable[[StudyMetadataDTO], bool]: + name: t.Optional[str], workspace: t.Optional[str], folder: t.Optional[str] +) -> t.Callable[[StudyMetadataDTO], bool]: def study_match(study: StudyMetadataDTO) -> bool: if name and not study.name.startswith(name): return False @@ -196,14 +196,55 @@ def study_match(study: StudyMetadataDTO) -> bool: return study_match +def assert_permission_on_studies( + user: t.Optional[JWTUser], + studies: t.Sequence[t.Union[Study, StudyMetadataDTO]], + permission_type: StudyPermissionType, + *, + raising: bool = True, +) -> bool: + """ + Asserts whether the provided user has the required permissions on the given studies. + + Args: + user: The user whose permissions need to be verified. + studies: The studies for which permissions need to be verified. + permission_type: The type of permission to be checked for the user. + raising: If set to `True`, raises `UserHasNotPermissionError` when the permission check fails. + + Returns: + `True` if the user has the required permissions, `False` otherwise. + + Raises: + `UserHasNotPermissionError`: If the raising parameter is set to `True` + and the user does not have the required permissions. 
+ """ + if not user: + logger.error("FAIL permission: user is not logged") + raise UserHasNotPermissionError() + msg = { + 0: f"FAIL permissions: user '{user}' has no access to any study", + 1: f"FAIL permissions: user '{user}' does not have {permission_type.value} permission on {studies[0].id}", + 2: f"FAIL permissions: user '{user}' does not have {permission_type.value} permission on all studies", + }[min(len(studies), 2)] + infos = (PermissionInfo.from_study(study) for study in studies) + if any(not check_permission(user, permission_info, permission_type) for permission_info in infos): + logger.error(msg) + if raising: + raise UserHasNotPermissionError(msg) + return False + return True + + def assert_permission( - user: Optional[JWTUser], - study: Optional[Union[Study, StudyMetadataDTO]], + user: t.Optional[JWTUser], + study: t.Optional[t.Union[Study, StudyMetadataDTO]], permission_type: StudyPermissionType, raising: bool = True, ) -> bool: """ Assert user has permission to edit or read study. + Args: user: user logged study: study asked @@ -211,27 +252,9 @@ def assert_permission( raising: raise error if permission not matched Returns: true if permission match, false if not raising. 
- """ - if not user: - logger.error("FAIL permission: user is not logged") - raise UserHasNotPermissionError() - - if not study: - logger.error("FAIL permission: study not exist") - raise ValueError("Metadata is None") - - permission_info = PermissionInfo.from_study(study) - ok = check_permission(user, permission_info, permission_type) - if raising and not ok: - logger.error( - "FAIL permission: user %d has no permission on study %s", - user.id, - study.id, - ) - raise UserHasNotPermissionError() - - return ok + studies = [study] if study else [] + return assert_permission_on_studies(user, studies, permission_type, raising=raising) MATRIX_INPUT_DAYS_COUNT = 365 @@ -264,7 +287,7 @@ def assert_permission( def get_start_date( file_study: FileStudy, - output_id: Optional[str] = None, + output_id: t.Optional[str] = None, level: StudyDownloadLevelDTO = StudyDownloadLevelDTO.HOURLY, ) -> MatrixIndex: """ @@ -277,12 +300,12 @@ def get_start_date( """ config = FileStudyHelpers.get_config(file_study, output_id)["general"] - starting_month = cast(str, config.get("first-month-in-year")) - starting_day = cast(str, config.get("january.1st")) - leapyear = cast(bool, config.get("leapyear")) - first_week_day = cast(str, config.get("first.weekday")) - start_offset = cast(int, config.get("simulation.start")) - end = cast(int, config.get("simulation.end")) + starting_month = t.cast(str, config.get("first-month-in-year")) + starting_day = t.cast(str, config.get("january.1st")) + leapyear = t.cast(bool, config.get("leapyear")) + first_week_day = t.cast(str, config.get("first.weekday")) + start_offset = t.cast(int, config.get("simulation.start")) + end = t.cast(int, config.get("simulation.end")) starting_month_index = MONTHS.index(starting_month.title()) + 1 starting_day_index = DAY_NAMES.index(starting_day.title()) @@ -303,7 +326,7 @@ def get_start_date( elif level == StudyDownloadLevelDTO.ANNUAL: steps = 1 elif level == StudyDownloadLevelDTO.WEEKLY: - steps = ceil(steps / 7) + steps = 
math.ceil(steps / 7) elif level == StudyDownloadLevelDTO.MONTHLY: end_date = start_date + timedelta(days=steps) same_year = end_date.year == start_date.year @@ -333,9 +356,9 @@ def export_study_flat( dest: Path, study_factory: StudyFactory, outputs: bool = True, - output_list_filter: Optional[List[str]] = None, + output_list_filter: t.Optional[t.List[str]] = None, denormalize: bool = True, - output_src_path: Optional[Path] = None, + output_src_path: t.Optional[Path] = None, ) -> None: start_time = time.time() diff --git a/tests/storage/test_service.py b/tests/storage/test_service.py index fbb36e04ea..bc5dac87d1 100644 --- a/tests/storage/test_service.py +++ b/tests/storage/test_service.py @@ -8,6 +8,7 @@ from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from antarest.core.config import Config, StorageConfig, WorkspaceConfig from antarest.core.exceptions import TaskAlreadyRunning @@ -21,7 +22,7 @@ from antarest.core.roles import RoleType from antarest.core.tasks.model import TaskDTO, TaskStatus, TaskType from antarest.core.utils.fastapi_sqlalchemy import db -from antarest.login.model import Group, GroupDTO, User +from antarest.login.model import Group, GroupDTO, Role, User from antarest.login.service import LoginService from antarest.matrixstore.service import MatrixService from antarest.study.model import ( @@ -60,7 +61,7 @@ from antarest.study.storage.rawstudy.model.filesystem.raw_file_node import RawFileNode from antarest.study.storage.rawstudy.model.filesystem.root.filestudytree import FileStudyTree from antarest.study.storage.rawstudy.raw_study_service import RawStudyService -from antarest.study.storage.utils import assert_permission, study_matcher +from antarest.study.storage.utils import assert_permission, assert_permission_on_studies, study_matcher from antarest.study.storage.variantstudy.business.matrix_constants_generator import GeneratorMatrixConstants from antarest.study.storage.variantstudy.model.command_context import 
CommandContext from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy @@ -873,6 +874,93 @@ def test_assert_permission() -> None: assert assert_permission(jwt_2, study, StudyPermissionType.READ) +def test_assert_permission_on_studies(db_session: Session) -> None: + # Given the following user groups : + user_groups = [ + { + "name": "admin", + "role": RoleType.ADMIN, + "users": ["admin"], + }, + { + "name": "Writers", + "role": RoleType.WRITER, + "users": ["John", "Jane", "Jack"], + }, + { + "name": "Readers", + "role": RoleType.READER, + "users": ["Rita", "Ralph"], + }, + ] + + # Create the JWTGroup and JWTUser objects + jwt_groups = {} + jwt_users = {} + users_sequence = 2 # first non-admin user ID + for group in user_groups: + group_name = group["name"] + jwt_groups[group_name] = JWTGroup(id=group_name, name=group_name, role=group["role"]) + for user_name in group["users"]: + if user_name == "admin": + user_id = 1 + else: + user_id = users_sequence + users_sequence += 1 + jwt_users[user_name] = JWTUser( + id=user_id, + impersonator=user_id, + type="users", + groups=[jwt_groups[group_name]], + ) + + # Create the users and groups in the database + with db_session: + for group_name, jwt_group in jwt_groups.items(): + db_session.add(Group(id=jwt_group.id, name=group_name)) + for user_name, jwt_user in jwt_users.items(): + db_session.add(User(id=jwt_user.id, name=user_name)) + db_session.commit() + + for user in db_session.query(User): + user_jwt_groups = jwt_users[user.name].groups + for user_jwt_group in user_jwt_groups: + db_session.add(Role(type=user_jwt_group.role, identity_id=user.id, group_id=user_jwt_group.id)) + db_session.commit() + + # John creates a main study and Jane creates two variant studies. + # They all belong to the same group. 
+ writers = db_session.query(Group).filter(Group.name == "Writers").one() + studies = [ + Study(id=uuid4(), name="Main Study", owner_id=jwt_users["John"].id, groups=[writers]), + Study(id=uuid4(), name="Variant Study 1", owner_id=jwt_users["Jane"].id, groups=[writers]), + Study(id=uuid4(), name="Variant Study 2", owner_id=jwt_users["Jane"].id, groups=[writers]), + ] + + # All admin and writers should have WRITE access to the studies. + # Other members of the group should have no access. + for user_name, jwt_user in jwt_users.items(): + has_access = any(jwt_group.name in {"admin", "Writers"} for jwt_group in jwt_user.groups) + actual = assert_permission_on_studies(jwt_user, studies, StudyPermissionType.WRITE, raising=False) + assert actual == has_access + + # Jack creates an additional variant study and adds it to the readers and writers groups. + readers = db_session.query(Group).filter(Group.name == "Readers").one() + studies.append(Study(id=uuid4(), name="Variant Study 3", owner_id=jwt_users["Jack"].id, groups=[readers, writers])) + + # All admin and writers should have READ access to the studies. + # Other members of the group should have no access, because they don't have access to the writers-only studies. + for user_name, jwt_user in jwt_users.items(): + has_access = any(jwt_group.name in {"admin", "Writers"} for jwt_group in jwt_user.groups) + actual = assert_permission_on_studies(jwt_user, studies, StudyPermissionType.READ, raising=False) + assert actual == has_access + + # Everybody should have access to the last study, because it is in the readers and writers group.
+ for user_name, jwt_user in jwt_users.items(): + actual = assert_permission_on_studies(jwt_user, studies[-1:], StudyPermissionType.READ, raising=False) + assert actual + + @pytest.mark.unit_test def test_delete_study_calls_callback(tmp_path: Path): study_uuid = "my_study" From e4f7756a871efe763bb559ba7524fb258aa5532f Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Thu, 9 Nov 2023 22:17:39 +0100 Subject: [PATCH 16/22] feat(db): add methods to check the variant study snapshot status --- .../storage/variantstudy/model/dbmodel.py | 16 +- .../storage/variantstudy/model/__init__.py | 0 .../variantstudy/model/test_dbmodel.py | 297 ++++++++++++++++++ 3 files changed, 312 insertions(+), 1 deletion(-) create mode 100644 tests/study/storage/variantstudy/model/__init__.py create mode 100644 tests/study/storage/variantstudy/model/test_dbmodel.py diff --git a/antarest/study/storage/variantstudy/model/dbmodel.py b/antarest/study/storage/variantstudy/model/dbmodel.py index 3026115ce4..3e547bce13 100644 --- a/antarest/study/storage/variantstudy/model/dbmodel.py +++ b/antarest/study/storage/variantstudy/model/dbmodel.py @@ -3,6 +3,7 @@ import typing as t import uuid from dataclasses import dataclass +from pathlib import Path from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table # type: ignore from sqlalchemy.orm import relationship # type: ignore @@ -33,7 +34,7 @@ class VariantStudySnapshot(Base): # type: ignore } def __str__(self) -> str: - return f"[Snapshot]: id={self.id}, created_at={self.created_at}" + return f"[Snapshot] id={self.id}, created_at={self.created_at}" @dataclass @@ -92,3 +93,16 @@ class VariantStudy(Study): def __str__(self) -> str: return super().__str__() + f", snapshot={self.snapshot}" + + @property + def snapshot_dir(self) -> Path: + """Get the path of the snapshot directory.""" + return Path(self.path) / "snapshot" + + def is_snapshot_recent(self) -> bool: + """Check if the snapshot exists and is up-to-date.""" + return ( + 
(self.snapshot is not None) + and (self.snapshot.created_at >= self.updated_at) + and (self.snapshot_dir / "study.antares").is_file() + ) diff --git a/tests/study/storage/variantstudy/model/__init__.py b/tests/study/storage/variantstudy/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/study/storage/variantstudy/model/test_dbmodel.py b/tests/study/storage/variantstudy/model/test_dbmodel.py new file mode 100644 index 0000000000..0bcd107518 --- /dev/null +++ b/tests/study/storage/variantstudy/model/test_dbmodel.py @@ -0,0 +1,297 @@ +import datetime +import json +import typing as t +import uuid +from pathlib import Path + +import pytest +from sqlalchemy.orm import Session # type: ignore + +from antarest.core.model import PublicMode +from antarest.core.roles import RoleType +from antarest.login.model import Group, Role, User +from antarest.study.model import RawStudy, StudyAdditionalData +from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock, VariantStudy, VariantStudySnapshot + + +@pytest.fixture(name="user_id") +def fixture_user_id(db_session: Session) -> int: + with db_session: + user_id = 0o007 + james = User(id=user_id, name="James Bond") + role = Role( + type=RoleType.WRITER, + identity=james, + group=Group(id="writers"), + ) + db_session.add(role) + db_session.commit() + return user_id + + +@pytest.fixture(name="raw_study_id") +def fixture_raw_study_id(tmp_path: Path, db_session: Session, user_id: int) -> str: + with db_session: + root_study_id = str(uuid.uuid4()) + root_study = RawStudy( + id=root_study_id, + workspace="default", + path=str(tmp_path.joinpath("root_study")), + version="860", + created_at=datetime.datetime.utcnow(), + updated_at=datetime.datetime.utcnow(), + additional_data=StudyAdditionalData(author="john.doe"), + owner_id=user_id, + ) + db_session.add(root_study) + db_session.commit() + return root_study_id + + +@pytest.fixture(name="variant_study_id") +def 
fixture_variant_study_id(tmp_path: Path, db_session: Session, raw_study_id: str, user_id: int) -> str: + with db_session: + variant_study_id = str(uuid.uuid4()) + variant = VariantStudy( + id=variant_study_id, + name="Variant Study", + version="860", + author="John DOE", + parent_id=raw_study_id, + created_at=datetime.datetime.utcnow() - datetime.timedelta(days=1), + updated_at=datetime.datetime.utcnow(), + last_access=datetime.datetime.utcnow(), + path=str(tmp_path.joinpath("variant_study")), + owner_id=user_id, + ) + db_session.add(variant) + db_session.commit() + return variant_study_id + + +class TestVariantStudySnapshot: + def test_init__without_command(self, db_session: Session, variant_study_id: str) -> None: + """ + Check the creation of an instance of VariantStudySnapshot + """ + now = datetime.datetime.utcnow() + + with db_session: + snap = VariantStudySnapshot(id=variant_study_id, created_at=now) + db_session.add(snap) + db_session.commit() + + obj: VariantStudySnapshot = ( + db_session.query(VariantStudySnapshot).filter(VariantStudySnapshot.id == variant_study_id).one() + ) + + # check Study representation + assert str(obj).startswith(f"[Snapshot] id={variant_study_id}") + + # check Study fields + assert obj.id == variant_study_id + assert obj.created_at == now + assert obj.last_executed_command is None + + def test_init__with_command(self, db_session: Session, variant_study_id: str) -> None: + """ + Check the creation of an instance of VariantStudySnapshot + """ + now = datetime.datetime.utcnow() + command_id = str(uuid.uuid4()) + + with db_session: + snap = VariantStudySnapshot(id=variant_study_id, created_at=now, last_executed_command=command_id) + db_session.add(snap) + db_session.commit() + + obj: VariantStudySnapshot = ( + db_session.query(VariantStudySnapshot).filter(VariantStudySnapshot.id == variant_study_id).one() + ) + assert obj.id == variant_study_id + assert obj.created_at == now + assert obj.last_executed_command == command_id + + +class 
TestCommandBlock: + def test_init(self, db_session: Session, variant_study_id: str) -> None: + """ + Check the creation of an instance of CommandBlock + """ + command_id = str(uuid.uuid4()) + index = 7 + command = "dummy command" + version = 42 + args = '{"foo": "bar"}' + + with db_session: + block = CommandBlock( + id=command_id, + study_id=variant_study_id, + index=index, + command=command, + version=version, + args=args, + ) + db_session.add(block) + db_session.commit() + + obj: CommandBlock = db_session.query(CommandBlock).filter(CommandBlock.id == command_id).one() + + # check CommandBlock representation + assert str(obj).startswith(f"CommandBlock(id={command_id!r}") + + # check CommandBlock fields + assert obj.id == command_id + assert obj.study_id == variant_study_id + assert obj.index == index + assert obj.command == command + assert obj.version == version + assert obj.args == args + + # check CommandBlock.to_dto() + dto = obj.to_dto() + # note: it is easier to compare the dict representation of the DTO + assert dto.dict() == { + "id": command_id, + "action": command, + "args": json.loads(args), + "version": 1, + } + + +class TestVariantStudy: + def test_init__without_snapshot(self, db_session: Session, raw_study_id: str, user_id: int) -> None: + """ + Check the creation of an instance of variant study without snapshot + """ + now = datetime.datetime.utcnow() + variant_study_id = str(uuid.uuid4()) + variant_study_path = "path/to/variant" + + with db_session: + variant = VariantStudy( + id=variant_study_id, + name="Variant Study", + version="860", + author="John DOE", + parent_id=raw_study_id, + created_at=now - datetime.timedelta(days=1), + updated_at=now, + last_access=now, + path=variant_study_path, + owner_id=user_id, + ) + db_session.add(variant) + db_session.commit() + + obj: VariantStudy = db_session.query(VariantStudy).filter(VariantStudy.id == variant_study_id).one() + + # check Study representation + assert str(obj).startswith(f"[Study] 
id={variant_study_id}") + + # check Study fields + assert obj.id == variant_study_id + assert obj.name == "Variant Study" + assert obj.type == "variantstudy" + assert obj.version == "860" + assert obj.author == "John DOE" + assert obj.created_at == now - datetime.timedelta(days=1) + assert obj.updated_at == now + assert obj.last_access == now + assert obj.path == variant_study_path + assert obj.folder is None + assert obj.parent_id == raw_study_id + assert obj.public_mode == PublicMode.NONE + assert obj.owner_id == user_id + assert obj.archived is False + assert obj.groups == [] + assert obj.additional_data is None + + # check Variant-specific fields + assert obj.generation_task is None + assert obj.snapshot is None + assert obj.commands == [] + + # check Variant-specific properties + assert obj.snapshot_dir == Path(variant_study_path).joinpath("snapshot") + assert obj.is_snapshot_recent() is False + + @pytest.mark.parametrize( + "created_at, updated_at, study_antares_file, expected", + [ + pytest.param( + datetime.datetime(2023, 11, 9), + datetime.datetime(2023, 11, 8), + "study.antares", + True, + id="with-recent-snapshot", + ), + pytest.param( + datetime.datetime(2023, 11, 7), + datetime.datetime(2023, 11, 8), + "study.antares", + False, + id="with-old-snapshot", + ), + pytest.param( + datetime.datetime(2023, 11, 9), + datetime.datetime(2023, 11, 8), + "dirty.antares", + False, + id="with-dirty-snapshot", + ), + pytest.param( + None, + datetime.datetime(2023, 11, 8), + "study.antares", + False, + id="without-snapshot", + ), + ], + ) + def test_is_snapshot_recent( + self, + db_session: Session, + tmp_path: Path, + raw_study_id: int, + user_id: int, + created_at: t.Optional[datetime.datetime], + updated_at: datetime.datetime, + study_antares_file: str, + expected: bool, + ) -> None: + """ + Check the snapshot_uptodate() method + """ + with db_session: + # Given a variant study (referencing the raw study) + # with optionally a snapshot and a snapshot directory + 
variant_id = str(uuid.uuid4()) + variant = VariantStudy( + id=variant_id, + name="Study 3.0", + author="Sandrine", + parent_id=raw_study_id, + updated_at=updated_at, + path=str(tmp_path.joinpath("variant")), + owner_id=user_id, + ) + + # If the snapshot creation date is given, we create a snapshot + # and a snapshot directory. + if created_at: + variant.snapshot = VariantStudySnapshot(created_at=created_at) + variant.snapshot_dir.mkdir(parents=True, exist_ok=True) + + # If the "study.antares" file is given, we create it in the snapshot directory. + if study_antares_file: + variant.snapshot_dir.mkdir(parents=True, exist_ok=True) + (variant.snapshot_dir / study_antares_file).touch() + + db_session.add(variant) + db_session.commit() + + # Check the snapshot_uptodate() method + obj: VariantStudy = db_session.query(VariantStudy).filter(VariantStudy.id == variant_id).one() + assert obj.is_snapshot_recent() == expected From 70595c982f7b647815b3efe5a78938235d59c60d Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Fri, 10 Nov 2023 17:24:14 +0100 Subject: [PATCH 17/22] feat(variant): add the `SnapshotGenerator` class helper class used to generate snapshots for variant studies. --- .../variantstudy/snapshot_generator.py | 275 +++++ .../variantstudy/test_snapshot_generator.py | 1046 +++++++++++++++++ 2 files changed, 1321 insertions(+) create mode 100644 antarest/study/storage/variantstudy/snapshot_generator.py create mode 100644 tests/study/storage/variantstudy/test_snapshot_generator.py diff --git a/antarest/study/storage/variantstudy/snapshot_generator.py b/antarest/study/storage/variantstudy/snapshot_generator.py new file mode 100644 index 0000000000..168b9add96 --- /dev/null +++ b/antarest/study/storage/variantstudy/snapshot_generator.py @@ -0,0 +1,275 @@ +""" +This module dedicated to variant snapshot generation. 
+""" +import datetime +import logging +import shutil +import tempfile +import typing as t +from pathlib import Path + +from antarest.core.exceptions import VariantGenerationError +from antarest.core.interfaces.cache import CacheConstants, ICache +from antarest.core.jwt import JWTUser +from antarest.core.model import StudyPermissionType +from antarest.core.tasks.service import TaskUpdateNotifier, noop_notifier +from antarest.study.model import RawStudy, StudyAdditionalData +from antarest.study.storage.patch_service import PatchService +from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfigDTO +from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy, StudyFactory +from antarest.study.storage.rawstudy.raw_study_service import RawStudyService +from antarest.study.storage.utils import assert_permission_on_studies, export_study_flat +from antarest.study.storage.variantstudy.command_factory import CommandFactory +from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock, VariantStudy, VariantStudySnapshot +from antarest.study.storage.variantstudy.model.model import GenerationResultInfoDTO +from antarest.study.storage.variantstudy.repository import VariantStudyRepository +from antarest.study.storage.variantstudy.variant_command_generator import VariantCommandGenerator + +logger = logging.getLogger(__name__) + + +OUTPUT_RELATIVE_PATH = "output" + + +class SnapshotGenerator: + """ + Helper class used to generate snapshots for variant studies. 
+ """ + + def __init__( + self, + cache: ICache, + raw_study_service: RawStudyService, + command_factory: CommandFactory, + study_factory: StudyFactory, + patch_service: PatchService, + repository: VariantStudyRepository, + ): + self.cache = cache + self.raw_study_service = raw_study_service + self.command_factory = command_factory + self.study_factory = study_factory + self.patch_service = patch_service + self.repository = repository + # Temporary directory used to generate the snapshot + self._tmp_dir: Path = Path() + + def generate_snapshot( + self, + variant_study_id: str, + jwt_user: JWTUser, + *, + denormalize: bool = True, + from_scratch: bool = False, + notifier: TaskUpdateNotifier = noop_notifier, + ) -> GenerationResultInfoDTO: + # ATTENTION: since we are making changes to disk, a file lock is needed. + # The locking is currently done in the `VariantStudyService.generate_task` function + # when starting the task. However, it is not enough, because the snapshot generation + # need to read the root study or a snapshot of a variant study which may be modified + # during the task. Ideally, we should lock the root study and all its descendants, + # but it is not currently possible to lock studies. + # The locking done at the task level nevertheless makes it possible to limit the risks. + + logger.info(f"Generating variant study snapshot for '{variant_study_id}'") + + root_study, descendants = self._retrieve_descendants(variant_study_id) + assert_permission_on_studies(jwt_user, [root_study, *descendants], StudyPermissionType.READ, raising=True) + ref_study, cmd_blocks = search_ref_study(root_study, descendants, from_scratch=from_scratch) + + # We are going to generate the snapshot in a temporary directory which will be renamed + # at the end of the process. This prevents incomplete snapshots in case of error. + + # Get snapshot directory and prepare a temporary directory next to it. 
+ variant_study = descendants[-1] + snapshot_dir = variant_study.snapshot_dir + snapshot_dir.parent.mkdir(parents=True, exist_ok=True) + self._tmp_dir = Path(tempfile.mkdtemp(dir=snapshot_dir.parent, prefix=f"~{snapshot_dir.name}", suffix=".tmp")) + try: + logger.info(f"Exporting the reference study '{ref_study.id}' to '{self._tmp_dir.name}'...") + self._export_ref_study(ref_study) + + logger.info(f"Applying commands to the reference study '{ref_study.id}'...") + results = self._apply_commands(variant_study, ref_study, cmd_blocks) + + if (snapshot_dir / "user").exists(): + logger.info("Keeping previous unmanaged user config...") + shutil.copytree(snapshot_dir / "user", self._tmp_dir / "user", dirs_exist_ok=True) + + # The snapshot is generated, we also need to de-normalize the matrices. + file_study = self.study_factory.create_from_fs( + self._tmp_dir, + study_id=variant_study_id, + output_path=self._tmp_dir / OUTPUT_RELATIVE_PATH, + use_cache=False, # Avoid saving the study config in the cache + ) + if denormalize: + logger.info(f"Denormalizing variant study {variant_study_id}") + file_study.tree.denormalize() + + # Finally, we can update the database. + logger.info(f"Saving new snapshot for study {variant_study_id}") + variant_study.snapshot = VariantStudySnapshot( + id=variant_study_id, + created_at=datetime.datetime.utcnow(), + last_executed_command=cmd_blocks[-1].id if cmd_blocks else None, + ) + + logger.info(f"Reading additional data from files for study {file_study.config.study_id}") + variant_study.additional_data = self._read_additional_data(file_study) + self.repository.save(variant_study) + + # Store the study config in the cache (with adjusted paths). 
+ file_study.config.study_path = file_study.config.path = snapshot_dir + file_study.config.output_path = snapshot_dir / OUTPUT_RELATIVE_PATH + self._update_cache(file_study) + + except Exception: + shutil.rmtree(self._tmp_dir, ignore_errors=True) + raise + + else: + # Rename the temporary directory to the final snapshot directory + shutil.rmtree(snapshot_dir, ignore_errors=True) + self._tmp_dir.rename(snapshot_dir) + try: + notifier(results.json()) + except Exception as exc: + # This exception is ignored, because it is not critical. + logger.warning(f"Error while sending notification: {exc}", exc_info=True) + + return results + + def _retrieve_descendants(self, variant_study_id: str) -> t.Tuple[RawStudy, t.Sequence[VariantStudy]]: + # Get all ancestors of the current study from bottom to top + # The first IDs are variant IDs, the last is the root study ID. + ancestor_ids = self.repository.get_ancestor_or_self_ids(variant_study_id) + descendant_ids = ancestor_ids[::-1] + descendants = self.repository.find_variants(descendant_ids) + root_study = self.repository.one(descendant_ids[0]) + return root_study, descendants + + def _export_ref_study(self, ref_study: t.Union[RawStudy, VariantStudy]) -> None: + self._tmp_dir.rmdir() # remove the temporary directory for shutil.copytree + if isinstance(ref_study, VariantStudy): + export_study_flat( + ref_study.snapshot_dir, + self._tmp_dir, + self.study_factory, + denormalize=False, # de-normalization is done at the end + ) + elif isinstance(ref_study, RawStudy): + self.raw_study_service.export_study_flat( + ref_study, + self._tmp_dir, + denormalize=False, # de-normalization is done at the end + ) + else: # pragma: no cover + raise TypeError(repr(type(ref_study))) + + def _apply_commands( + self, + variant_study: VariantStudy, + ref_study: t.Union[RawStudy, VariantStudy], + cmd_blocks: t.Sequence[CommandBlock], + ) -> GenerationResultInfoDTO: + commands = [self.command_factory.to_command(cb.to_dto()) for cb in cmd_blocks] + 
generator = VariantCommandGenerator(self.study_factory) + results = generator.generate( + commands, + self._tmp_dir, + variant_study, + delete_on_failure=False, # Not needed, because we are using a temporary directory + notifier=None, + ) + if not results.success: + message = f"Failed to generate variant study {variant_study.id}" + if results.details: + detail: t.Tuple[str, bool, str] = results.details[-1] + message += f": {detail[2]}" + raise VariantGenerationError(message) + return results + + def _read_additional_data(self, file_study: FileStudy) -> StudyAdditionalData: + horizon = file_study.tree.get(url=["settings", "generaldata", "general", "horizon"]) + author = file_study.tree.get(url=["study", "antares", "author"]) + patch = self.patch_service.get_from_filestudy(file_study) + study_additional_data = StudyAdditionalData(horizon=horizon, author=author, patch=patch.json()) + return study_additional_data + + def _update_cache(self, file_study: FileStudy) -> None: + # The study configuration is changed, so we update the cache. + self.cache.invalidate(f"{CacheConstants.RAW_STUDY}/{file_study.config.study_id}") + self.cache.put( + f"{CacheConstants.STUDY_FACTORY}/{file_study.config.study_id}", + FileStudyTreeConfigDTO.from_build_config(file_study.config).dict(), + ) + + +def search_ref_study( + root_study: t.Union[RawStudy, VariantStudy], + descendants: t.Sequence[VariantStudy], + *, + from_scratch: bool = False, +) -> t.Tuple[t.Union[RawStudy, VariantStudy], t.Sequence[CommandBlock]]: + """ + Search for the reference study and the commands to use for snapshot generation. + + Args: + root_study: The root study from which the descendants of variants are derived. + descendants: The list of descendants of variants from top to bottom. + from_scratch: Whether to generate the snapshot from scratch or not. + + Returns: + The reference study and the commands to use for snapshot generation. 
+ """ + + # The reference study is the root study or a variant study with a valid snapshot + ref_study: t.Union[RawStudy, VariantStudy] + + # The commands to apply on the reference study to generate the current variant + cmd_blocks: t.List[CommandBlock] + + if from_scratch: + # In the case of a from scratch generation, the root study will be used as the reference study. + # We need to retrieve all commands from the descendants of variants in order to apply them + # on the reference study. + ref_study = root_study + cmd_blocks = [c for v in descendants for c in v.commands] + + else: + # To generate the last variant of a descendant of variants, we must search for + # the most recent snapshot in order to use it as a reference study. + # If no snapshot is found, we use the root study as a reference study. + + snapshot_vars = [v for v in descendants if v.is_snapshot_recent()] + + if snapshot_vars: + # We use the most recent snapshot as a reference study + ref_study = max(snapshot_vars, key=lambda v: v.snapshot.created_at) + + # This variant's snapshot corresponds to the commands actually generated + # at the time of the snapshot. However, we need to retrieve the remaining commands, + # because the snapshot generation may be incomplete. + last_exec_cmd = ref_study.snapshot.last_executed_command # ID of the command + if not last_exec_cmd: + # It is unlikely that this case will occur, but it means that + # the snapshot is not correctly generated (corrupted database). + # It better to use all commands to force snapshot re-generation. + cmd_blocks = ref_study.commands[:] + else: + command_ids = [c.id for c in ref_study.commands] + last_exec_index = command_ids.index(last_exec_cmd) + cmd_blocks = ref_study.commands[last_exec_index + 1 :] + + # We need to add all commands from the descendants of variants + # starting at the first descendant of reference study. 
+ index = descendants.index(ref_study) + cmd_blocks.extend([c for v in descendants[index + 1 :] for c in v.commands]) + + else: + # We use the root study as a reference study + ref_study = root_study + cmd_blocks = [c for v in descendants for c in v.commands] + + return ref_study, cmd_blocks diff --git a/tests/study/storage/variantstudy/test_snapshot_generator.py b/tests/study/storage/variantstudy/test_snapshot_generator.py new file mode 100644 index 0000000000..21f32b944f --- /dev/null +++ b/tests/study/storage/variantstudy/test_snapshot_generator.py @@ -0,0 +1,1046 @@ +import configparser +import datetime +import json +import logging +import re +import typing as t +import uuid +from pathlib import Path +from unittest.mock import Mock + +import numpy as np +import pytest +from sqlalchemy import event # type: ignore + +from antarest.core.exceptions import VariantGenerationError +from antarest.core.interfaces.cache import CacheConstants +from antarest.core.jwt import JWTGroup, JWTUser +from antarest.core.requests import RequestParameters +from antarest.core.roles import RoleType +from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.login.model import Group, Role, User +from antarest.study.model import RawStudy, Study, StudyAdditionalData +from antarest.study.storage.rawstudy.raw_study_service import RawStudyService +from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock, VariantStudy, VariantStudySnapshot +from antarest.study.storage.variantstudy.model.model import CommandDTO, GenerationResultInfoDTO +from antarest.study.storage.variantstudy.snapshot_generator import SnapshotGenerator, search_ref_study +from antarest.study.storage.variantstudy.variant_study_service import VariantStudyService +from tests.helpers import with_db_context + + +def _create_variant( + tmp_path: Path, + variant_name: str, + parent_id: str, + updated_at: datetime.datetime, + snapshot_created_at: t.Optional[datetime.datetime], +) -> VariantStudy: + """ + 
Create a variant study with a snapshot (if snapshot_created_at is provided). + """ + variant_dir = tmp_path.joinpath(f"some_place/{variant_name}") + variant_dir.mkdir(parents=True, exist_ok=True) + variant = VariantStudy( + id=str(uuid.uuid4()), + name=variant_name, + updated_at=updated_at, + parent_id=parent_id, + path=str(variant_dir), + ) + + if snapshot_created_at: + snapshot_dir = variant_dir.joinpath("snapshot") + snapshot_dir.mkdir(parents=True, exist_ok=True) + (snapshot_dir / "study.antares").touch() + variant.snapshot = VariantStudySnapshot( + id=variant.id, + created_at=snapshot_created_at, + last_executed_command=None, + ) + + return variant + + +class TestSearchRefStudy: + """ + Test the `search_ref_study` method of the `SnapshotGenerator` class. + + We need to test several cases: + + Cases where we expect to have the root study and a list of `CommandBlock` + for all variants in the order of the list. + + - The edge case where the list of studies is empty. + Note: This case is unlikely, but the function should be able to handle it. + + - The case where the list of studies contains variants with or without snapshots, + but a search is requested from scratch. + + - The case where the list of studies contains variants with obsolete snapshots, meaning that: + - either there is no snapshot, + - or the snapshot's creation date is earlier than the variant's last modification date. + Note: The situation where the "snapshot/study.antares" file does not exist is not considered. + + Cases where we expect to have a different reference study than the root study + and corresponding to a variant with an up-to-date snapshot. + + - The case where the list of studies contains two variants with up-to-date snapshots and + where the first is older than the second. + We expect to have a reference study corresponding to the second variant + and a list of commands for the second variant. 
+ + - The case where the list of studies contains two variants with up-to-date snapshots and + where the first is more recent than the second. + We expect to have a reference study corresponding to the first variant + and a list of commands for both variants in order. + + - The case where the list of studies contains a variant with an up-to-date snapshot and + corresponds to the generation of all commands for the variant. + We expect to have an empty list of commands because the snapshot is already completely up-to-date. + + - The case where the list of studies contains a variant with an up-to-date snapshot and + corresponds to a partial generation of commands for the variant. + We expect to have a list of commands corresponding to the remaining commands. + """ + + def test_search_ref_study__empty_descendants(self) -> None: + """ + Edge case where the list of studies is empty. + We expect to have the root study and a list of `CommandBlock` for all variants + in the order of the list. + + Note: This case is unlikely, but the function should be able to handle it. + + Given an empty list of descendants, + When calling search_ref_study, + Then the root study is returned as reference study, + and an empty list of commands is returned. + """ + root_study = Study(id=str(uuid.uuid4()), name="root") + references: t.Sequence[VariantStudy] = [] + ref_study, cmd_blocks = search_ref_study(root_study, references) + assert ref_study == root_study + assert cmd_blocks == [] + + def test_search_ref_study__from_scratch(self, tmp_path: Path) -> None: + """ + Case where the list of studies contains variants with or without snapshots, + but a search is requested from scratch. + We expect to have the root study and a list of `CommandBlock` for all variants + in the order of the list. 
+ + Given a list of descendants with some variants with snapshots, + When calling search_ref_study with the flag from_scratch=True, + Then the root study is returned as reference study, + and all commands of all variants are returned. + """ + root_study = Study(id=str(uuid.uuid4()), name="root") + + # Prepare some variants with snapshots + variant1 = _create_variant( + tmp_path, + "variant1", + root_study.id, + datetime.datetime(year=2023, month=1, day=1), + snapshot_created_at=datetime.datetime(year=2023, month=1, day=1), + ) + variant2 = _create_variant( + tmp_path, + "variant2", + variant1.id, + datetime.datetime(year=2023, month=1, day=2), + snapshot_created_at=datetime.datetime(year=2023, month=1, day=2), + ) + variant3 = _create_variant( + tmp_path, + "variant3", + variant2.id, + datetime.datetime(year=2023, month=1, day=1), + None, + ) + + # Add some variant commands + variant1.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=0, + command="create_area", + version=1, + args='{"area_name": "DE"}', + ), + ] + variant1.snapshot.last_executed_command = variant1.commands[0].id + variant2.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant2.id, + index=0, + command="create_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "cluster_type": "thermal"}', + ), + ] + variant2.snapshot.last_executed_command = variant2.commands[0].id + variant3.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant3.id, + index=0, + command="update_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "capacity": 1500}', + ), + ] + + # Check the variants + references = [variant1, variant2, variant3] + ref_study, cmd_blocks = search_ref_study(root_study, references, from_scratch=True) + assert ref_study == root_study + assert cmd_blocks == [c for v in [variant1, variant2, variant3] for c in v.commands] + + def test_search_ref_study__obsolete_snapshots(self, 
tmp_path: Path) -> None: + """ + Case where the list of studies contains variants with obsolete snapshots, meaning that: + - either there is no snapshot, + - or the snapshot's creation date is earlier than the variant's last modification date. + Note: The situation where the "snapshot/study.antares" file does not exist is not considered. + We expect to have the root study and a list of `CommandBlock` for all variants. + + Given a list of descendants with some variants with obsolete snapshots, + When calling search_ref_study with the flag from_scratch=False, + Then the root study is returned as reference study, + and all commands of all variants are returned. + """ + root_study = Study(id=str(uuid.uuid4()), name="root") + + # Prepare some variants with snapshots: + # Variant 1 has no snapshot. + variant1 = _create_variant( + tmp_path, + "variant1", + root_study.id, + datetime.datetime(year=2023, month=1, day=1), + snapshot_created_at=None, + ) + # Variant 2 has an obsolete snapshot. + variant2 = _create_variant( + tmp_path, + "variant2", + variant1.id, + datetime.datetime(year=2023, month=1, day=2), + snapshot_created_at=datetime.datetime(year=2023, month=1, day=1), + ) + + # Add some variant commands + variant1.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=0, + command="create_area", + version=1, + args='{"area_name": "DE"}', + ), + ] + variant2.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant2.id, + index=0, + command="create_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "cluster_type": "thermal"}', + ), + ] + variant2.snapshot.last_executed_command = variant2.commands[0].id + + # Check the variants + references = [variant1, variant2] + ref_study, cmd_blocks = search_ref_study(root_study, references) + assert ref_study == root_study + assert cmd_blocks == [c for v in [variant1, variant2] for c in v.commands] + + def test_search_ref_study__old_recent_snapshot(self, 
tmp_path: Path) -> None: + """ + Case where the list of studies contains a variant with up-to-date snapshots and + where the first is older than the second. + We expect to have a reference study corresponding to the second variant + and an empty list of commands, because the snapshot is already completely up-to-date. + + Given a list of descendants with some variants with up-to-date snapshots, + When calling search_ref_study with the flag from_scratch=False, + Then the second variant is returned as reference study, and no commands are returned. + """ + root_study = Study(id=str(uuid.uuid4()), name="root") + + # Prepare some variants with snapshots: + # Variant 1 has an up-to-date snapshot. + variant1 = _create_variant( + tmp_path, + "variant1", + root_study.id, + datetime.datetime(year=2023, month=1, day=1), + snapshot_created_at=datetime.datetime(year=2023, month=1, day=1), + ) + # Variant 2 has an up-to-date snapshot, but is more recent than variant 1. + variant2 = _create_variant( + tmp_path, + "variant2", + variant1.id, + datetime.datetime(year=2023, month=1, day=2), + snapshot_created_at=datetime.datetime(year=2023, month=1, day=3), + ) + + # Add some variant commands + variant1.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=0, + command="create_area", + version=1, + args='{"area_name": "DE"}', + ), + ] + variant1.snapshot.last_executed_command = variant1.commands[0].id + variant2.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant2.id, + index=0, + command="create_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "cluster_type": "thermal"}', + ), + ] + variant2.snapshot.last_executed_command = variant2.commands[0].id + + # Check the variants + references = [variant1, variant2] + ref_study, cmd_blocks = search_ref_study(root_study, references) + assert ref_study == variant2 + assert cmd_blocks == [] + + def test_search_ref_study__recent_old_snapshot(self, tmp_path: Path) 
-> None: + """ + Case where the list of studies contains a variant with up-to-date snapshots and + where the second is older than the first. + We expect to have a reference study corresponding to the first variant + and the list of commands of the second variant, because the first is completely up-to-date. + + Given a list of descendants with some variants with up-to-date snapshots, + When calling search_ref_study with the flag from_scratch=False, + Then the first variant is returned as reference study, + and the commands of the second variant are returned. + """ + root_study = Study(id=str(uuid.uuid4()), name="root") + + # Prepare some variants with snapshots: + # Variant 1 has an up-to-date snapshot, but is more recent than variant 2. + variant1 = _create_variant( + tmp_path, + "variant1", + root_study.id, + datetime.datetime(year=2023, month=1, day=1), + snapshot_created_at=datetime.datetime(year=2023, month=1, day=3), + ) + # Variant 2 has an up-to-date snapshot, but is older that variant 1. 
+ variant2 = _create_variant( + tmp_path, + "variant2", + variant1.id, + datetime.datetime(year=2023, month=1, day=2), + snapshot_created_at=datetime.datetime(year=2023, month=1, day=2), + ) + + # Add some variant commands + variant1.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=0, + command="create_area", + version=1, + args='{"area_name": "DE"}', + ), + ] + variant1.snapshot.last_executed_command = variant1.commands[0].id + variant2.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant2.id, + index=0, + command="create_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "cluster_type": "thermal"}', + ), + ] + variant2.snapshot.last_executed_command = variant2.commands[0].id + + # Check the variants + references = [variant1, variant2] + ref_study, cmd_blocks = search_ref_study(root_study, references) + assert ref_study == variant1 + assert cmd_blocks == variant2.commands + + def test_search_ref_study__one_variant_completely_uptodate(self, tmp_path: Path) -> None: + """ + Case where the list of studies contains a variant with an up-to-date snapshot and + corresponds to the generation of all commands for the variant (completely up-to-date) + We expect to have an empty list of commands because the snapshot is already completely up-to-date. + + Given a list of descendants with some variants with up-to-date snapshots, + When calling search_ref_study with the flag from_scratch=False, + Then the variant is returned as reference study, and no commands are returned. 
+ """ + root_study = Study(id=str(uuid.uuid4()), name="root") + + # Prepare some variants with snapshots: + variant1 = _create_variant( + tmp_path, + "variant1", + root_study.id, + datetime.datetime(year=2023, month=1, day=1), + snapshot_created_at=datetime.datetime(year=2023, month=1, day=2), + ) + + # Add some variant commands + variant1.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=0, + command="create_area", + version=1, + args='{"area_name": "DE"}', + ), + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=1, + command="create_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "cluster_type": "thermal"}', + ), + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=2, + command="update_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "capacity": 1500}', + ), + ] + + # The last executed command is the last item of the commands list. + variant1.snapshot.last_executed_command = variant1.commands[-1].id + + # Check the variants + references = [variant1] + ref_study, cmd_blocks = search_ref_study(root_study, references) + assert ref_study == variant1 + assert cmd_blocks == [] + + def test_search_ref_study__one_variant_partially_uptodate(self, tmp_path: Path) -> None: + """ + Case where the list of studies contains a variant with an up-to-date snapshot and + corresponds to a partial generation of commands for the variant (partially up-to-date) + We expect to have a list of commands corresponding to the remaining commands. + + Given a list of descendants with some variants with up-to-date snapshots, + When calling search_ref_study with the flag from_scratch=False, + Then the variant is returned as reference study, and the remaining commands are returned. 
+ """ + root_study = Study(id=str(uuid.uuid4()), name="root") + + # Prepare some variants with snapshots: + variant1 = _create_variant( + tmp_path, + "variant1", + root_study.id, + datetime.datetime(year=2023, month=1, day=1), + snapshot_created_at=datetime.datetime(year=2023, month=1, day=2), + ) + + # Add some variant commands + variant1.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=0, + command="create_area", + version=1, + args='{"area_name": "DE"}', + ), + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=1, + command="create_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "cluster_type": "thermal"}', + ), + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=2, + command="update_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "capacity": 1500}', + ), + ] + + # The last executed command is the NOT last item of the commands list. + variant1.snapshot.last_executed_command = variant1.commands[0].id + + # Check the variants + references = [variant1] + ref_study, cmd_blocks = search_ref_study(root_study, references) + assert ref_study == variant1 + assert cmd_blocks == variant1.commands[1:] + + def test_search_ref_study__missing_last_command(self, tmp_path: Path) -> None: + """ + Case where the list of studies contains a variant with an up-to-date snapshot, + but the last executed command is missing (probably caused by a bug). + We expect to have the list of all variant commands, so that the snapshot can be re-generated. 
+ """ + root_study = Study(id=str(uuid.uuid4()), name="root") + + # Prepare some variants with snapshots: + variant1 = _create_variant( + tmp_path, + "variant1", + root_study.id, + datetime.datetime(year=2023, month=1, day=1), + snapshot_created_at=datetime.datetime(year=2023, month=1, day=2), + ) + + # Add some variant commands + variant1.commands = [ + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=0, + command="create_area", + version=1, + args='{"area_name": "DE"}', + ), + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=1, + command="create_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "cluster_type": "thermal"}', + ), + CommandBlock( + id=str(uuid.uuid4()), + study_id=variant1.id, + index=2, + command="update_thermal_cluster", + version=1, + args='{"area_name": "DE", "cluster_name": "DE", "capacity": 1500}', + ), + ] + + # The last executed command is missing. + variant1.snapshot.last_executed_command = None + + # Check the variants + references = [variant1] + ref_study, cmd_blocks = search_ref_study(root_study, references) + assert ref_study == variant1 + assert cmd_blocks == variant1.commands + + +class RegisterNotification: + """ + Callable used to register notifications. + """ + + def __init__(self) -> None: + self.notifications: t.MutableSequence[str] = [] + + def __call__(self, notification: str) -> None: + self.notifications.append(json.loads(notification)) + + +class TestSnapshotGenerator: + """ + Test the `SnapshotGenerator` class. + """ + + @pytest.fixture(name="jwt_user") + def jwt_user_fixture(self) -> JWTUser: + # Create a user in a "Writers" group: + jwt_user = JWTUser( + id=7, + impersonator=7, + type="users", + groups=[JWTGroup(id="writers", name="Writers", role=RoleType.WRITER)], + ) + # Ensure the user is in database. 
+ with db(): + role = Role( + type=RoleType.WRITER, + identity=User(id=jwt_user.id, name="john.doe"), + group=Group(id="writers"), + ) + db.session.add(role) + db.session.commit() + return jwt_user + + @pytest.fixture(name="root_study_id") + def root_study_id_fixture( + self, + tmp_path: Path, + raw_study_service: RawStudyService, + variant_study_service: VariantStudyService, + jwt_user: JWTUser, + ) -> str: + # Prepare a RAW study in the temporary folder + study_dir = tmp_path / "my-study" + root_study_id = str(uuid.uuid4()) + root_study = RawStudy( + id=root_study_id, + workspace="default", + path=str(study_dir), + version="860", + created_at=datetime.datetime.utcnow(), + updated_at=datetime.datetime.utcnow(), + additional_data=StudyAdditionalData(author="john.doe"), + owner_id=jwt_user.id, + ) + root_study = raw_study_service.create(root_study) + with db(): + # Save the root study in database + variant_study_service.repository.save(root_study) + return root_study_id + + @pytest.fixture(name="variant_study") + def variant_study_fixture( + self, + root_study_id: str, + variant_study_service: VariantStudyService, + jwt_user: JWTUser, + ) -> VariantStudy: + with db(): + # Create un new variant + name = "my-variant" + params = RequestParameters(user=jwt_user) + variant_study = variant_study_service.create_variant_study(root_study_id, name, params=params) + + # Append some commands + variant_study_service.append_commands( + variant_study.id, + [ + CommandDTO(action="create_area", args={"area_name": "North"}), + CommandDTO(action="create_area", args={"area_name": "South"}), + CommandDTO(action="create_link", args={"area1": "north", "area2": "south"}), + CommandDTO( + action="create_cluster", + args={ + "area_id": "south", + "cluster_name": "gas_cluster", + "parameters": {"group": "Gas", "unitcount": 1, "nominalcapacity": 500}, + }, + ), + ], + params=params, + ) + return variant_study + + def test_init(self, variant_study_service: VariantStudyService) -> None: + """ + 
Test the initialization of the `SnapshotGenerator` class. + """ + generator = SnapshotGenerator( + cache=variant_study_service.cache, + raw_study_service=variant_study_service.raw_study_service, + command_factory=variant_study_service.command_factory, + study_factory=variant_study_service.study_factory, + patch_service=variant_study_service.patch_service, + repository=variant_study_service.repository, + ) + assert generator.cache == variant_study_service.cache + assert generator.raw_study_service == variant_study_service.raw_study_service + assert generator.command_factory == variant_study_service.command_factory + assert generator.study_factory == variant_study_service.study_factory + assert generator.patch_service == variant_study_service.patch_service + assert generator.repository == variant_study_service.repository + + @with_db_context + def test_generate__nominal_case( + self, + variant_study: VariantStudy, + variant_study_service: VariantStudyService, + jwt_user: JWTUser, + ) -> None: + """ + Test the generation of a variant study based on a raw study. + + Given a raw study and a single variant study, + When calling generate with: + - `denormalize` set to False, + - `from_scratch` set to False, + - `notifier` set to a callback function used to register de notifications, + Then the variant generation must succeed. + We must check that: + - the number of database queries is kept as low as possible, + - the variant is correctly generated in the "snapshot" directory and all commands are applied, + - the matrices are not denormalized (we should have links to matrices), + - the variant is updated in the database (snapshot and additional_data), + - the cache is updated with the new variant configuration, + - the temporary directory is correctly removed. + - the notifications are correctly registered. 
+ """ + generator = SnapshotGenerator( + cache=variant_study_service.cache, + raw_study_service=variant_study_service.raw_study_service, + command_factory=variant_study_service.command_factory, + study_factory=variant_study_service.study_factory, + patch_service=variant_study_service.patch_service, + repository=variant_study_service.repository, + ) + + sql_statements = [] + notifier = RegisterNotification() + + @event.listens_for(db.session.bind, "before_cursor_execute") # type: ignore + def before_cursor_execute(conn, cursor, statement: str, parameters, context, executemany) -> None: + # note: add a breakpoint here to debug the SQL statements. + sql_statements.append(statement) + + try: + results = generator.generate_snapshot( + variant_study.id, + jwt_user, + denormalize=False, + from_scratch=False, + notifier=notifier, + ) + finally: + event.remove(db.session.bind, "before_cursor_execute", before_cursor_execute) + + # Check: the number of database queries is kept as low as possible. + # We expect 5 queries: + # - 1 query to fetch the ancestors of a variant study, + # - 1 query to fetch the root study (with owner and groups for permission check), + # - 1 query to fetch the list of variants with snapshot, commands, etc., + # - 1 query to update the variant study additional_data, + # - 1 query to insert the variant study snapshot. + assert len(sql_statements) == 5, "\n-------\n".join(sql_statements) + + # Check: the variant generation must succeed. + assert results == GenerationResultInfoDTO( + success=True, + details=[ + ("create_area", True, "Area 'North' created"), + ("create_area", True, "Area 'South' created"), + ("create_link", True, "Link between 'north' and 'south' created"), + ("create_cluster", True, "Thermal cluster 'gas_cluster' added to area 'south'."), + ], + ) + + # Check: the variant is correctly generated and all commands are applied. 
+ snapshot_dir = variant_study.snapshot_dir + assert snapshot_dir.exists() + assert (snapshot_dir / "study.antares").exists() + assert (snapshot_dir / "input/areas/list.txt").read_text().splitlines(keepends=False) == ["North", "South"] + config = configparser.RawConfigParser() + config.read(snapshot_dir / "input/links/north/properties.ini") + assert config.sections() == ["south"] + assert config["south"], "The 'south' section must exist in the 'properties.ini' file." + config = configparser.RawConfigParser() + config.read(snapshot_dir / "input/thermal/clusters/south/list.ini") + assert config.sections() == ["gas_cluster"] + assert config["gas_cluster"] == { # type: ignore + "group": "Gas", + "unitcount": "1", + "nominalcapacity": "500", + "name": "gas_cluster", + } + + # Check: the matrices are not denormalized (we should have links to matrices). + assert (snapshot_dir / "input/links/north/south_parameters.txt.link").exists() + assert (snapshot_dir / "input/thermal/series/south/gas_cluster/series.txt.link").exists() + + # Check: the variant is updated in the database (snapshot and additional_data). + with db(): + study = variant_study_service.repository.get(variant_study.id) + assert study is not None + assert study.snapshot is not None + assert study.snapshot.last_executed_command == study.commands[-1].id + assert study.additional_data.author == "john.doe" + + # Check: the cache is updated with the new variant configuration. + # The cache is a mock created in the session's scope, so it is shared between all tests. + cache: Mock = generator.cache # type: ignore + # So, the number of calls to the `put` method is at least equal to 2. + assert cache.put.call_count >= 2 + # The last call to the `put` method is for the variant study. 
+ put_variant = cache.put.call_args_list[-1] + assert put_variant[0][0] == f"{CacheConstants.STUDY_FACTORY}/{variant_study.id}" + variant_study_config = put_variant[0][1] + assert variant_study_config["study_id"] == variant_study.id + + # Check: the temporary directory is correctly removed. + assert list(snapshot_dir.parent.iterdir()) == [snapshot_dir] + + # Check: the notifications are correctly registered. + assert notifier.notifications == [ # type: ignore + { + "details": [ + ["create_area", True, "Area 'North' created"], + ["create_area", True, "Area 'South' created"], + ["create_link", True, "Link between 'north' and 'south' created"], + ["create_cluster", True, "Thermal cluster 'gas_cluster' added to area 'south'."], + ], + "success": True, + } + ] + + @with_db_context + def test_generate__with_user_dir( + self, + variant_study: VariantStudy, + variant_study_service: VariantStudyService, + jwt_user: JWTUser, + ) -> None: + """ + Test the generation of a variant study containing a user directory. + We expect that the user directory is correctly preserved. + """ + # Add a user directory to the variant study. 
+ user_dir = Path(variant_study.snapshot_dir) / "user" + user_dir.mkdir(parents=True, exist_ok=True) + user_dir.joinpath("user_file.txt").touch() + + generator = SnapshotGenerator( + cache=variant_study_service.cache, + raw_study_service=variant_study_service.raw_study_service, + command_factory=variant_study_service.command_factory, + study_factory=variant_study_service.study_factory, + patch_service=variant_study_service.patch_service, + repository=variant_study_service.repository, + ) + + results = generator.generate_snapshot( + variant_study.id, + jwt_user, + denormalize=False, + from_scratch=False, + ) + + # Check the results + assert results == GenerationResultInfoDTO( + success=True, + details=[ + ("create_area", True, "Area 'North' created"), + ("create_area", True, "Area 'South' created"), + ("create_link", True, "Link between 'north' and 'south' created"), + ("create_cluster", True, "Thermal cluster 'gas_cluster' added to area 'south'."), + ], + ) + + # Check that the user directory is correctly preserved. + user_dir = Path(variant_study.snapshot_dir) / "user" + assert user_dir.is_dir() + assert user_dir.joinpath("user_file.txt").exists() + + @with_db_context + def test_generate__with_denormalize_true( + self, + variant_study: VariantStudy, + variant_study_service: VariantStudyService, + jwt_user: JWTUser, + ) -> None: + """ + Test the generation of a variant study with matrices de-normalization. + We expect that all matrices are correctly denormalized (no link). 
+ """ + generator = SnapshotGenerator( + cache=variant_study_service.cache, + raw_study_service=variant_study_service.raw_study_service, + command_factory=variant_study_service.command_factory, + study_factory=variant_study_service.study_factory, + patch_service=variant_study_service.patch_service, + repository=variant_study_service.repository, + ) + + results = generator.generate_snapshot( + variant_study.id, + jwt_user, + denormalize=True, + from_scratch=False, + ) + + # Check the results + assert results == GenerationResultInfoDTO( + success=True, + details=[ + ("create_area", True, "Area 'North' created"), + ("create_area", True, "Area 'South' created"), + ("create_link", True, "Link between 'north' and 'south' created"), + ("create_cluster", True, "Thermal cluster 'gas_cluster' added to area 'south'."), + ], + ) + + # Check: the matrices are denormalized (we should have TSV files). + snapshot_dir = variant_study.snapshot_dir + assert (snapshot_dir / "input/links/north/south_parameters.txt").exists() + array = np.loadtxt(snapshot_dir / "input/links/north/south_parameters.txt", delimiter="\t") + assert array.shape == (8760, 6) + + assert (snapshot_dir / "input/thermal/series/south/gas_cluster/series.txt").exists() + array = np.loadtxt(snapshot_dir / "input/thermal/series/south/gas_cluster/series.txt", delimiter="\t") + assert array.size == 0 + + @with_db_context + def test_generate__with_invalid_command( + self, + variant_study: VariantStudy, + variant_study_service: VariantStudyService, + jwt_user: JWTUser, + ) -> None: + """ + Test the generation of a variant study with an invalid command. + We expect to have a clear error message explaining which command fails. + The snapshot directory must be removed (and no temporary directory must be left). + """ + # Append an invalid command to the variant study. 
+ params = RequestParameters(user=jwt_user) + variant_study_service.append_commands( + variant_study.id, + [ + CommandDTO(action="create_area", args={"area_name": "North"}), # duplicate + ], + params=params, + ) + + generator = SnapshotGenerator( + cache=variant_study_service.cache, + raw_study_service=variant_study_service.raw_study_service, + command_factory=variant_study_service.command_factory, + study_factory=variant_study_service.study_factory, + patch_service=variant_study_service.patch_service, + repository=variant_study_service.repository, + ) + + err_msg = ( + f"Failed to generate variant study {variant_study.id}:" + f" Area 'North' already exists and could not be created" + ) + with pytest.raises(VariantGenerationError, match=re.escape(err_msg)): + generator.generate_snapshot( + variant_study.id, + jwt_user, + denormalize=False, + from_scratch=False, + ) + + # Check: the snapshot directory is removed. + snapshot_dir = variant_study.snapshot_dir + assert not snapshot_dir.exists() + + # Check: no temporary directory is left. + assert list(snapshot_dir.parent.iterdir()) == [] + + @with_db_context + def test_generate__notification_failure( + self, + variant_study: VariantStudy, + variant_study_service: VariantStudyService, + jwt_user: JWTUser, + caplog: pytest.LogCaptureFixture, + ) -> None: + """ + Test the generation of a variant study with a notification that fails. + Since the notification is not critical, we expect to have no exception. 
+ """ + generator = SnapshotGenerator( + cache=variant_study_service.cache, + raw_study_service=variant_study_service.raw_study_service, + command_factory=variant_study_service.command_factory, + study_factory=variant_study_service.study_factory, + patch_service=variant_study_service.patch_service, + repository=variant_study_service.repository, + ) + + notifier = Mock(side_effect=Exception("Something went wrong")) + + with caplog.at_level(logging.WARNING): + results = generator.generate_snapshot( + variant_study.id, + jwt_user, + denormalize=False, + from_scratch=False, + notifier=notifier, + ) + + # Check the results + assert results == GenerationResultInfoDTO( + success=True, + details=[ + ("create_area", True, "Area 'North' created"), + ("create_area", True, "Area 'South' created"), + ("create_link", True, "Link between 'north' and 'south' created"), + ("create_cluster", True, "Thermal cluster 'gas_cluster' added to area 'south'."), + ], + ) + + # Check th logs + assert "Something went wrong" in caplog.text + + @with_db_context + def test_generate__variant_of_variant( + self, + variant_study: VariantStudy, + variant_study_service: VariantStudyService, + jwt_user: JWTUser, + ) -> None: + """ + Test the generation of a variant study of a variant study. + """ + generator = SnapshotGenerator( + cache=variant_study_service.cache, + raw_study_service=variant_study_service.raw_study_service, + command_factory=variant_study_service.command_factory, + study_factory=variant_study_service.study_factory, + patch_service=variant_study_service.patch_service, + repository=variant_study_service.repository, + ) + + # Generate the variant once. + generator.generate_snapshot( + variant_study.id, + jwt_user, + denormalize=False, + from_scratch=False, + ) + + # Create a new variant of the variant study. 
+ params = RequestParameters(user=jwt_user) + new_variant = variant_study_service.create_variant_study(variant_study.id, "my-variant", params=params) + + # Append some commands to the new variant. + variant_study_service.append_commands( + new_variant.id, + [ + CommandDTO(action="create_area", args={"area_name": "East"}), + ], + params=params, + ) + + # Generate the variant again. + results = generator.generate_snapshot( + new_variant.id, + jwt_user, + denormalize=False, + from_scratch=False, + ) + + # Check the results + assert results == GenerationResultInfoDTO(success=True, details=[("create_area", True, "Area 'East' created")]) From be805dc62da46cc49dd2cc52bad2ac75d59d9f89 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Fri, 10 Nov 2023 17:32:31 +0100 Subject: [PATCH 18/22] feat(variant): use the `SnapshotGenerator` class in the `VariantStudyService` --- .../study/storage/variantstudy/repository.py | 27 ++- .../variantstudy/variant_study_service.py | 189 +++++------------- 2 files changed, 74 insertions(+), 142 deletions(-) diff --git a/antarest/study/storage/variantstudy/repository.py b/antarest/study/storage/variantstudy/repository.py index 8a20108021..bf2c979de1 100644 --- a/antarest/study/storage/variantstudy/repository.py +++ b/antarest/study/storage/variantstudy/repository.py @@ -1,6 +1,6 @@ import typing as t -from sqlalchemy.orm import Session # type: ignore +from sqlalchemy.orm import Session, joinedload, subqueryload # type: ignore from antarest.core.interfaces.cache import ICache from antarest.core.utils.fastapi_sqlalchemy import db @@ -87,3 +87,28 @@ def get_all_command_blocks(self) -> t.List[CommandBlock]: """ cmd_blocks: t.List[CommandBlock] = self.session.query(CommandBlock).all() return cmd_blocks + + def find_variants(self, variant_ids: t.Sequence[str]) -> t.Sequence[VariantStudy]: + """ + Find a list of variants by IDs + + Args: + variant_ids: list of variant IDs. 
+ + Returns: + List of variants (and attached snapshot) ordered by IDs + """ + # When we fetch the list of variants, we also need to fetch the associated snapshots, + # the list of commands, the additional data, etc. + # We use a SQL query with joins to fetch all these data efficiently. + q = ( + self.session.query(VariantStudy) + .options(joinedload(VariantStudy.snapshot)) + .options(joinedload(VariantStudy.commands)) + .options(joinedload(VariantStudy.additional_data)) + .options(joinedload(VariantStudy.owner)) + .options(joinedload(VariantStudy.groups)) + .filter(VariantStudy.id.in_(variant_ids)) # type: ignore + ) + index = {id_: i for i, id_ in enumerate(variant_ids)} + return sorted(q, key=lambda v: index[v.id]) diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index 4adeb0f1aa..de6fa1651c 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -3,7 +3,6 @@ import logging import re import shutil -import tempfile from datetime import datetime from functools import reduce from pathlib import Path @@ -21,6 +20,7 @@ NoParentStudyError, StudyNotFoundError, StudyTypeUnsupported, + StudyValidationError, VariantGenerationError, VariantGenerationTimeoutError, VariantStudyParentNotValid, @@ -51,7 +51,7 @@ from antarest.study.storage.variantstudy.business.utils import transform_command_to_dto from antarest.study.storage.variantstudy.command_factory import CommandFactory from antarest.study.storage.variantstudy.model.command.icommand import ICommand -from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock, VariantStudy, VariantStudySnapshot +from antarest.study.storage.variantstudy.model.dbmodel import CommandBlock, VariantStudy from antarest.study.storage.variantstudy.model.model import ( CommandDTO, CommandResultDTO, @@ -59,6 +59,7 @@ VariantTreeDTO, ) from 
antarest.study.storage.variantstudy.repository import VariantStudyRepository +from antarest.study.storage.variantstudy.snapshot_generator import SnapshotGenerator from antarest.study.storage.variantstudy.variant_command_generator import VariantCommandGenerator logger = logging.getLogger(__name__) @@ -342,11 +343,19 @@ def _get_variant_study( raw_study_accepted: bool = False, ) -> VariantStudy: """ - Get variant study and check permissions + Get variant study (or RAW study if `raw_study_accepted` is `True`), and check READ permissions. + Args: - study_id: study id - params: request parameters - Returns: None + study_id: The study identifier. + params: request parameters used for permission check. + + Returns: + The variant study. + + Raises: + StudyNotFoundError: If the study does not exist (HTTP status 404). + MustBeAuthenticatedError: If the user is not authenticated (HTTP status 403). + StudyTypeUnsupported: If the study is not a variant study (HTTP status 422). """ study = self.repository.get(study_id) @@ -597,11 +606,19 @@ def generate_task( study_id = metadata.id def callback(notifier: TaskUpdateNotifier) -> TaskResult: - generate_result = self._generate( - variant_study_id=study_id, + generator = SnapshotGenerator( + cache=self.cache, + raw_study_service=self.raw_study_service, + command_factory=self.command_factory, + study_factory=self.study_factory, + patch_service=self.patch_service, + repository=self.repository, + ) + generate_result = generator.generate_snapshot( + study_id, + DEFAULT_ADMIN_USER, denormalize=denormalize, from_scratch=from_scratch, - params=RequestParameters(DEFAULT_ADMIN_USER), notifier=notifier, ) return TaskResult( @@ -653,136 +670,6 @@ def generate_study_config( return self._generate_study_config(variant_study, variant_study, None) - def _generate( - self, - variant_study_id: str, - params: RequestParameters, - denormalize: bool = True, - from_scratch: bool = False, - notifier: TaskUpdateNotifier = noop_notifier, - ) -> 
GenerationResultInfoDTO: - logger.info(f"Generating variant study {variant_study_id}") - - # Get variant study - variant_study = self._get_variant_study(variant_study_id, params) - - # Get parent study - if variant_study.parent_id is None: - raise NoParentStudyError(variant_study_id) - - parent_study = self.repository.get(variant_study.parent_id) - if parent_study is None: - raise StudyNotFoundError(variant_study.parent_id) - - # Check parent study permission - assert_permission(params.user, parent_study, StudyPermissionType.READ) - - # Remove from cache - remove_from_cache(self.cache, variant_study.id) - - # Get snapshot directory - dst_path = self.get_study_path(variant_study) - - # this indicates that the current snapshot is up-to-date, - # and we can only generate from the next command - last_executed_command_index = VariantStudyService._get_snapshot_last_executed_command_index(variant_study) - - is_parent_newer = ( - parent_study.updated_at > variant_study.snapshot.created_at if variant_study.snapshot else True - ) - last_executed_command_index = ( - None - if ( - is_parent_newer - or from_scratch - or (isinstance(parent_study, VariantStudy) and not self.exists(parent_study)) - ) - else last_executed_command_index - ) - - variant_study.snapshot = None - self.repository.save(variant_study, update_modification_date=False) - - unmanaged_user_config: Optional[Path] = None - if dst_path.is_dir(): - # Remove snapshot directory if it exists and last snapshot is out of sync - if last_executed_command_index is None: - logger.info("Removing previous snapshot data") - if (dst_path / "user").exists(): - logger.info("Keeping previous unmanaged user config") - tmp_dir = tempfile.TemporaryDirectory(dir=self.config.storage.tmp_dir) - shutil.copytree(dst_path / "user", tmp_dir.name, dirs_exist_ok=True) - unmanaged_user_config = Path(tmp_dir.name) - shutil.rmtree(dst_path) - else: - logger.info("Using previous snapshot data") - elif last_executed_command_index is not None: - # 
there is no snapshot so last_command_index should be None - logger.warning("Previous snapshot with last_executed_command found, but no data found") - last_executed_command_index = None - - if last_executed_command_index is None: - # Copy parent study to destination - if isinstance(parent_study, VariantStudy): - self._safe_generation(parent_study) - self.export_study_flat( - metadata=parent_study, - dst_path=dst_path, - outputs=False, - denormalize=False, - ) - else: - self.raw_study_service.export_study_flat( - metadata=parent_study, - dst_path=dst_path, - outputs=False, - denormalize=False, - ) - - command_start_index = last_executed_command_index + 1 if last_executed_command_index is not None else 0 - logger.info(f"Generating study snapshot from command index {command_start_index}") - results = self._generate_snapshot( - variant_study=variant_study, - dst_path=dst_path, - notifier=notifier, - from_command_index=command_start_index, - ) - - if unmanaged_user_config: - logger.info("Restoring previous unmanaged user config") - if dst_path.exists(): - if (dst_path / "user").exists(): - logger.warning("Existing unmanaged user config. 
It will be overwritten.") - shutil.rmtree((dst_path / "user")) - shutil.copytree(unmanaged_user_config, dst_path / "user") - else: - logger.warning("Destination snapshot doesn't exist !") - shutil.rmtree(unmanaged_user_config, ignore_errors=True) - - if results.success: - # sourcery skip: extract-method - last_command_index = len(variant_study.commands) - 1 - # noinspection PyArgumentList - variant_study.snapshot = VariantStudySnapshot( - id=variant_study.id, - created_at=datetime.utcnow(), - last_executed_command=( - variant_study.commands[last_command_index].id if last_command_index >= 0 else None - ), - ) - study = self.study_factory.create_from_fs( - self.get_study_path(variant_study), - study_id=variant_study.id, - output_path=Path(variant_study.path) / OUTPUT_RELATIVE_PATH, - ) - variant_study.additional_data = self._read_additional_data_from_files(study) - self.repository.save(variant_study) - logger.info(f"Saving new snapshot for study {variant_study.id}") - if denormalize: - logger.info(f"Denormalizing variant study {variant_study.id}") - study.tree.denormalize() - return results - def _generate_study_config( self, original_study: VariantStudy, @@ -879,9 +766,27 @@ def _generate_snapshot( return self.generator.generate(commands, dst_path, variant_study, notifier=notify) def get_study_task(self, study_id: str, params: RequestParameters) -> TaskDTO: + """ + Get the generation task ID of a variant study. + + Args: + study_id: The ID of the variant study. + params: The request parameters used to check permissions. + + Returns: + The generation task ID. + + Raises: + StudyNotFoundError: If the study does not exist (HTTP status 404). + MustBeAuthenticatedError: If the user is not authenticated (HTTP status 403). + StudyTypeUnsupported: If the study is not a variant study (HTTP status 422). + StudyValidationError: If the study has no generation task (HTTP status 422). 
+ """ variant_study = self._get_variant_study(study_id, params) task_id = variant_study.generation_task - return self.task_service.status_task(task_id=task_id, request_params=params, with_logs=True) + if task_id: + return self.task_service.status_task(task_id=task_id, request_params=params, with_logs=True) + raise StudyValidationError(f"Variant study '{study_id}' has no generation task") def create(self, study: VariantStudy) -> VariantStudy: """ @@ -894,9 +799,11 @@ def create(self, study: VariantStudy) -> VariantStudy: def exists(self, metadata: VariantStudy) -> bool: """ - Check if study exists. + Check if the study snapshot exists and is up-to-date. + Args: metadata: Study metadata. + Returns: `True` if the study is present on disk, `False` otherwise. """ return ( From 340881e629bfbdfa05ceb1761f0ebe98279c9044 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Fri, 10 Nov 2023 17:33:57 +0100 Subject: [PATCH 19/22] test(variant): correct the `test_variant_model` unit test to use fixtures --- .../variantstudy/model/test_variant_model.py | 334 +++++++----------- 1 file changed, 123 insertions(+), 211 deletions(-) diff --git a/tests/variantstudy/model/test_variant_model.py b/tests/variantstudy/model/test_variant_model.py index 89e4b96832..7acbce4530 100644 --- a/tests/variantstudy/model/test_variant_model.py +++ b/tests/variantstudy/model/test_variant_model.py @@ -1,242 +1,154 @@ import datetime +import uuid from pathlib import Path -from unittest.mock import ANY, Mock -import numpy as np -from sqlalchemy import create_engine -from sqlalchemy.engine.base import Engine # type: ignore +import pytest -from antarest.core.cache.business.local_chache import LocalCache -from antarest.core.config import Config, StorageConfig, WorkspaceConfig from antarest.core.jwt import JWTGroup, JWTUser -from antarest.core.persistence import Base from antarest.core.requests import RequestParameters from antarest.core.roles import RoleType -from antarest.core.utils.fastapi_sqlalchemy 
import DBSessionMiddleware, db -from antarest.login.model import User -from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, StudyAdditionalData -from antarest.study.storage.variantstudy.command_factory import CommandFactory -from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy +from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.login.model import Group, Role, User +from antarest.study.model import RawStudy, StudyAdditionalData +from antarest.study.storage.rawstudy.raw_study_service import RawStudyService +from antarest.study.storage.variantstudy.business.matrix_constants_generator import GeneratorMatrixConstants from antarest.study.storage.variantstudy.model.model import CommandDTO, GenerationResultInfoDTO -from antarest.study.storage.variantstudy.repository import VariantStudyRepository -from antarest.study.storage.variantstudy.variant_study_service import SNAPSHOT_RELATIVE_PATH, VariantStudyService - -# noinspection SpellCheckingInspection -SADMIN = RequestParameters( - user=JWTUser( - id=0, - impersonator=0, - type="users", - groups=[JWTGroup(id="admin", name="admin", role=RoleType.ADMIN)], - ) -) - - -def test_commands_service(tmp_path: Path, db_engine: Engine, command_factory: CommandFactory): - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=db_engine, - session_args={"autocommit": False, "autoflush": False}, - ) - repository = VariantStudyRepository(LocalCache()) - service = VariantStudyService( - raw_study_service=Mock(), - cache=Mock(), - task_service=Mock(), - command_factory=command_factory, - study_factory=Mock(), - config=Config(storage=StorageConfig(workspaces={DEFAULT_WORKSPACE_NAME: WorkspaceConfig(path=tmp_path)})), - repository=repository, - event_bus=Mock(), - patch_service=Mock(), - ) - - with db(): - # Add the admin user in the database - db.session.add(User(id=SADMIN.user.id)) - - # sourcery skip: extract-method, inline-variable - # Save a study - 
origin_id = "origin-id" - # noinspection PyArgumentList - origin_study = RawStudy( - id=origin_id, - name="my-study", - additional_data=StudyAdditionalData(), +from antarest.study.storage.variantstudy.snapshot_generator import SnapshotGenerator +from antarest.study.storage.variantstudy.variant_study_service import VariantStudyService +from tests.helpers import with_db_context + + +class TestVariantStudyService: + @pytest.fixture(name="jwt_user") + def jwt_user_fixture(self) -> JWTUser: + # Create a user in a "Writers" group: + jwt_user = JWTUser( + id=7, + impersonator=7, + type="users", + groups=[JWTGroup(id="writers", name="Writers", role=RoleType.WRITER)], ) - repository.save(origin_study) + # Ensure the user is in database. + with db(): + role = Role( + type=RoleType.WRITER, + identity=User(id=jwt_user.id, name="john.doe"), + group=Group(id="writers"), + ) + db.session.add(role) + db.session.commit() + return jwt_user + + @pytest.fixture(name="root_study_id") + def root_study_id_fixture( + self, + tmp_path: Path, + raw_study_service: RawStudyService, + variant_study_service: VariantStudyService, + jwt_user: JWTUser, + ) -> str: + # Prepare a RAW study in the temporary folder + study_dir = tmp_path / "my-study" + root_study_id = str(uuid.uuid4()) + root_study = RawStudy( + id=root_study_id, + workspace="default", + path=str(study_dir), + version="860", + created_at=datetime.datetime.utcnow(), + updated_at=datetime.datetime.utcnow(), + additional_data=StudyAdditionalData(author="john.doe"), + owner_id=jwt_user.id, + ) + root_study = raw_study_service.create(root_study) + with db(): + # Save the root study in database + variant_study_service.repository.save(root_study) + return root_study_id + + @with_db_context + def test_commands_service( + self, + root_study_id: str, + generator_matrix_constants: GeneratorMatrixConstants, + jwt_user: JWTUser, + variant_study_service: VariantStudyService, + ) -> None: + # Initialize the default matrix constants + # noinspection 
PyProtectedMember + generator_matrix_constants._init() + + params = RequestParameters(user=jwt_user) # Create un new variant - name = "my-variant" - variant_study = service.create_variant_study(origin_id, name, SADMIN) + variant_study = variant_study_service.create_variant_study(root_study_id, "my-variant", params=params) saved_id = variant_study.id - study = repository.get(saved_id) + study = variant_study_service.repository.get(saved_id) + assert study is not None assert study.id == saved_id - assert study.parent_id == origin_id + assert study.parent_id == root_study_id # Append command + command_count = 0 command_1 = CommandDTO(action="create_area", args={"area_name": "Yes"}) - service.append_command(saved_id, command_1, SADMIN) + variant_study_service.append_command(saved_id, command_1, params=params) + command_count += 1 + command_2 = CommandDTO(action="create_area", args={"area_name": "No"}) - service.append_command(saved_id, command_2, SADMIN) - commands = service.get_commands(saved_id, SADMIN) - assert len(commands) == 2 + variant_study_service.append_command(saved_id, command_2, params=params) + command_count += 1 + + commands = variant_study_service.get_commands(saved_id, params=params) + assert len(commands) == command_count # Append multiple commands command_3 = CommandDTO(action="create_area", args={"area_name": "Maybe"}) - command_4 = CommandDTO(action="create_link", args={"area1": "No", "area2": "Yes"}) - service.append_commands(saved_id, [command_3, command_4], SADMIN) - commands = service.get_commands(saved_id, SADMIN) - assert len(commands) == 4 + command_4 = CommandDTO(action="create_link", args={"area1": "no", "area2": "yes"}) + variant_study_service.append_commands(saved_id, [command_3, command_4], params=params) + command_count += 2 + + commands = variant_study_service.get_commands(saved_id, params=params) + assert len(commands) == command_count # Get command - assert commands[0] == service.get_command(saved_id, commands[0].id, SADMIN) + 
assert commands[0] == variant_study_service.get_command(saved_id, commands[0].id, params=params) - # Remove command - service.remove_command(saved_id, commands[2].id, SADMIN) - commands = service.get_commands(saved_id, SADMIN) - assert len(commands) == 3 + # Remove command (area "Maybe") + variant_study_service.remove_command(saved_id, commands[2].id, params=params) + command_count -= 1 - # Update command - prepro = np.random.rand(365, 6).tolist() - prepro_id = command_factory.command_context.matrix_service.create(prepro) + # Create a thermal cluster in the area "Yes" command_5 = CommandDTO( - action="replace_matrix", + action="create_cluster", args={ - "target": "some/matrix/path", - "matrix": prepro_id, + "area_id": "yes", + "cluster_name": "cl1", + "parameters": {"group": "Gas", "unitcount": 1, "nominalcapacity": 500}, }, ) - service.update_command( - study_id=saved_id, - command_id=commands[2].id, - command=command_5, - params=SADMIN, - ) - commands = service.get_commands(saved_id, SADMIN) - assert commands[2].action == "replace_matrix" - assert commands[2].args["matrix"] == prepro_id - - # Move command - service.move_command( - study_id=saved_id, - command_id=commands[2].id, - new_index=0, - params=SADMIN, - ) - commands = service.get_commands(saved_id, SADMIN) - assert commands[0].action == "replace_matrix" - - # Generate - service._generate_snapshot = Mock() - service._read_additional_data_from_files = Mock() - service._read_additional_data_from_files.return_value = StudyAdditionalData() - expected_result = GenerationResultInfoDTO(success=True, details=[]) - service._generate_snapshot.return_value = expected_result - results = service._generate(saved_id, SADMIN, False) - assert results == expected_result - assert study.snapshot.id == study.id - - -def test_smart_generation(tmp_path: Path, command_factory: CommandFactory) -> None: - engine = create_engine( - "sqlite:///:memory:", - echo=False, - connect_args={"check_same_thread": False}, - ) - 
Base.metadata.create_all(engine) - # noinspection SpellCheckingInspection - DBSessionMiddleware( - None, - custom_engine=engine, - session_args={"autocommit": False, "autoflush": False}, - ) - repository = VariantStudyRepository(LocalCache()) - service = VariantStudyService( - raw_study_service=Mock(), - cache=Mock(), - task_service=Mock(), - command_factory=command_factory, - study_factory=Mock(), - config=Config(storage=StorageConfig(workspaces={DEFAULT_WORKSPACE_NAME: WorkspaceConfig(path=tmp_path)})), - repository=repository, - event_bus=Mock(), - patch_service=Mock(), - ) - service.generator = Mock() - service.generator.generate.side_effect = [ - GenerationResultInfoDTO(success=True, details=[]), - GenerationResultInfoDTO(success=True, details=[]), - GenerationResultInfoDTO(success=True, details=[]), - GenerationResultInfoDTO(success=True, details=[]), - ] - - # noinspection PyUnusedLocal - def export_flat( - metadata: VariantStudy, - dst_path: Path, - outputs: bool = True, - denormalize: bool = True, - ) -> None: - dst_path.mkdir(parents=True) - (dst_path / "user").mkdir() - (dst_path / "user" / "some_unmanaged_config").touch() - - service.raw_study_service.export_study_flat.side_effect = export_flat - - with db(): - origin_id = "base-study" - # noinspection PyArgumentList - origin_study = RawStudy( - id=origin_id, - name="my-study", - folder=f"some_place/{origin_id}", - workspace=DEFAULT_WORKSPACE_NAME, - additional_data=StudyAdditionalData(), - updated_at=datetime.datetime(year=2000, month=1, day=1), + variant_study_service.append_command(saved_id, command_5, params=params) + command_count += 1 + + commands = variant_study_service.get_commands(saved_id, params=params) + assert len(commands) == command_count + + # Generate using the SnapshotGenerator + generator = SnapshotGenerator( + cache=variant_study_service.cache, + raw_study_service=variant_study_service.raw_study_service, + command_factory=variant_study_service.command_factory, + 
study_factory=variant_study_service.study_factory, + patch_service=variant_study_service.patch_service, + repository=variant_study_service.repository, ) - repository.save(origin_study) - - variant_study = service.create_variant_study(origin_id, "my variant", SADMIN) - variant_id = variant_study.id - assert service._get_variant_study(variant_id, SADMIN).folder == "some_place" - unmanaged_user_config_path = tmp_path / variant_id / SNAPSHOT_RELATIVE_PATH / "user" / "some_unmanaged_config" - assert not unmanaged_user_config_path.exists() - - service.append_command( - variant_id, - CommandDTO(action="create_area", args={"area_name": "a"}), - SADMIN, - ) - service._read_additional_data_from_files = Mock() - service._read_additional_data_from_files.return_value = StudyAdditionalData() - service._generate(variant_id, SADMIN, False) - service.generator.generate.assert_called_with([ANY], ANY, ANY, notifier=ANY) - - service._generate(variant_id, SADMIN, False) - service.generator.generate.assert_called_with([], ANY, ANY, notifier=ANY) - - service.append_command( - variant_id, - CommandDTO(action="create_area", args={"area_name": "b"}), - SADMIN, - ) - assert service._get_variant_study(variant_id, SADMIN).snapshot.last_executed_command is not None - service._generate(variant_id, SADMIN, False) - service.generator.generate.assert_called_with([ANY], ANY, ANY, notifier=ANY) - - service.replace_commands( - variant_id, - [ - CommandDTO(action="create_area", args={"area_name": "c"}), - CommandDTO(action="create_area", args={"area_name": "d"}), + results = generator.generate_snapshot(saved_id, jwt_user, denormalize=False) + assert results == GenerationResultInfoDTO( + success=True, + details=[ + ("create_area", True, "Area 'Yes' created"), + ("create_area", True, "Area 'No' created"), + ("create_link", True, "Link between 'no' and 'yes' created"), + ("create_cluster", True, "Thermal cluster 'cl1' added to area 'yes'."), ], - SADMIN, ) - - assert unmanaged_user_config_path.exists() - 
unmanaged_user_config_path.write_text("hello") - service._generate(variant_id, SADMIN, False) - service.generator.generate.assert_called_with([ANY, ANY], ANY, ANY, notifier=ANY) - assert unmanaged_user_config_path.read_text() == "hello" + assert study.snapshot.id == study.id From 16cf54ff8ef54ce1c7c735a45566ff9e35f96e53 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Sun, 12 Nov 2023 19:09:13 +0100 Subject: [PATCH 20/22] perf(db): improved study query performance using owner and groups preloading --- antarest/study/repository.py | 77 +++++++++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 15 deletions(-) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 3a6447ec9b..598433acb4 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -1,8 +1,8 @@ +import datetime import logging -from datetime import datetime -from typing import List, Optional +import typing as t -from sqlalchemy.orm import with_polymorphic # type: ignore +from sqlalchemy.orm import Session, joinedload, with_polymorphic # type: ignore from antarest.core.interfaces.cache import CacheConstants, ICache from antarest.core.utils.fastapi_sqlalchemy import db @@ -17,8 +17,30 @@ class StudyMetadataRepository: Database connector to manage Study entity """ - def __init__(self, cache_service: ICache): + def __init__(self, cache_service: ICache, session: t.Optional[Session] = None): + """ + Initialize the repository. + + Args: + cache_service: Cache service for the repository. + session: Optional SQLAlchemy session to be used. + """ self.cache_service = cache_service + self._session = session + + @property + def session(self) -> Session: + """ + Get the SQLAlchemy session for the repository. + + Returns: + SQLAlchemy session. 
+ """ + if self._session is None: + # Get or create the session from a context variable (thread local variable) + return db.session + # Get the user-defined session + return self._session def save( self, @@ -29,7 +51,7 @@ def save( metadata_id = metadata.id or metadata.name logger.debug(f"Saving study {metadata_id}") if update_modification_date: - metadata.updated_at = datetime.utcnow() + metadata.updated_at = datetime.datetime.utcnow() metadata.groups = [db.session.merge(g) for g in metadata.groups] if metadata.owner: @@ -44,34 +66,59 @@ def save( def refresh(self, metadata: Study) -> None: db.session.refresh(metadata) - def get(self, id: str) -> Optional[Study]: + def get(self, id: str) -> t.Optional[Study]: """Get the study by ID or return `None` if not found in database.""" - metadata: Study = db.session.query(Study).get(id) + # When we fetch a study, we also need to fetch the associated owner and groups + # to check the permissions of the current user efficiently. + metadata: Study = ( + # fmt: off + db.session.query(Study) + .options(joinedload(Study.owner)) + .options(joinedload(Study.groups)) + .get(id) + # fmt: on + ) return metadata def one(self, id: str) -> Study: """Get the study by ID or raise `sqlalchemy.exc.NoResultFound` if not found in database.""" - study: Study = db.session.query(Study).filter_by(id=id).one() + # When we fetch a study, we also need to fetch the associated owner and groups + # to check the permissions of the current user efficiently. 
+ study: Study = ( + db.session.query(Study) + .options(joinedload(Study.owner)) + .options(joinedload(Study.groups)) + .filter_by(id=id) + .one() + ) return study - def get_list(self, study_id: List[str]) -> List[Study]: - studies: List[Study] = db.session.query(Study).where(Study.id.in_(study_id)).all() + def get_list(self, study_id: t.List[str]) -> t.List[Study]: + # When we fetch a study, we also need to fetch the associated owner and groups + # to check the permissions of the current user efficiently. + studies: t.List[Study] = ( + db.session.query(Study) + .options(joinedload(Study.owner)) + .options(joinedload(Study.groups)) + .where(Study.id.in_(study_id)) + .all() + ) return studies - def get_additional_data(self, study_id: str) -> Optional[StudyAdditionalData]: + def get_additional_data(self, study_id: str) -> t.Optional[StudyAdditionalData]: metadata: StudyAdditionalData = db.session.query(StudyAdditionalData).get(study_id) return metadata - def get_all(self) -> List[Study]: + def get_all(self) -> t.List[Study]: entity = with_polymorphic(Study, "*") - metadatas: List[Study] = db.session.query(entity).filter(RawStudy.missing.is_(None)).all() + metadatas: t.List[Study] = db.session.query(entity).filter(RawStudy.missing.is_(None)).all() return metadatas - def get_all_raw(self, show_missing: bool = True) -> List[RawStudy]: + def get_all_raw(self, show_missing: bool = True) -> t.List[RawStudy]: query = db.session.query(RawStudy) if not show_missing: query = query.filter(RawStudy.missing.is_(None)) - metadatas: List[RawStudy] = query.all() + metadatas: t.List[RawStudy] = query.all() return metadatas def delete(self, id: str) -> None: From 7dff31affa0bb1d2c07ae4336d4e16d55cfcba94 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 14 Nov 2023 17:12:10 +0100 Subject: [PATCH 21/22] style: correct return type of the get task endpoint --- antarest/core/tasks/web.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/antarest/core/tasks/web.py 
b/antarest/core/tasks/web.py index 7ce33e6740..e2dfb39638 100644 --- a/antarest/core/tasks/web.py +++ b/antarest/core/tasks/web.py @@ -45,7 +45,7 @@ def get_task( with_logs: bool = False, timeout: int = DEFAULT_AWAIT_MAX_TIMEOUT, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> TaskDTO: """ Retrieve information about a specific task. From 8d42c1bd05b6cc1496f60f47eb11e5c99111c1f1 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 14 Nov 2023 17:18:39 +0100 Subject: [PATCH 22/22] style: correct variable naming --- antarest/study/repository.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 598433acb4..94a0220e37 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -70,7 +70,7 @@ def get(self, id: str) -> t.Optional[Study]: """Get the study by ID or return `None` if not found in database.""" # When we fetch a study, we also need to fetch the associated owner and groups # to check the permissions of the current user efficiently. 
- metadata: Study = ( + study: Study = ( # fmt: off db.session.query(Study) .options(joinedload(Study.owner)) @@ -78,7 +78,7 @@ def get(self, id: str) -> t.Optional[Study]: .get(id) # fmt: on ) - return metadata + return study def one(self, id: str) -> Study: """Get the study by ID or raise `sqlalchemy.exc.NoResultFound` if not found in database.""" @@ -106,20 +106,20 @@ def get_list(self, study_id: t.List[str]) -> t.List[Study]: return studies def get_additional_data(self, study_id: str) -> t.Optional[StudyAdditionalData]: - metadata: StudyAdditionalData = db.session.query(StudyAdditionalData).get(study_id) - return metadata + study: StudyAdditionalData = db.session.query(StudyAdditionalData).get(study_id) + return study def get_all(self) -> t.List[Study]: entity = with_polymorphic(Study, "*") - metadatas: t.List[Study] = db.session.query(entity).filter(RawStudy.missing.is_(None)).all() - return metadatas + studies: t.List[Study] = db.session.query(entity).filter(RawStudy.missing.is_(None)).all() + return studies def get_all_raw(self, show_missing: bool = True) -> t.List[RawStudy]: query = db.session.query(RawStudy) if not show_missing: query = query.filter(RawStudy.missing.is_(None)) - metadatas: t.List[RawStudy] = query.all() - return metadatas + studies: t.List[RawStudy] = query.all() + return studies def delete(self, id: str) -> None: logger.debug(f"Deleting study {id}")