diff --git a/antarest/study/service.py b/antarest/study/service.py index dc3288b4e2..c13f755aad 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -84,7 +84,6 @@ MatrixIndex, PatchArea, PatchCluster, - PatchStudy, RawStudy, Study, StudyAdditionalData, @@ -121,6 +120,7 @@ upgrade_study, ) from antarest.study.storage.utils import assert_permission, get_start_date, is_managed, remove_from_cache +from antarest.study.storage.variantstudy.business.utils import transform_command_to_dto from antarest.study.storage.variantstudy.model.command.icommand import ICommand from antarest.study.storage.variantstudy.model.command.replace_matrix import ReplaceMatrix from antarest.study.storage.variantstudy.model.command.update_comments import UpdateComments @@ -395,17 +395,7 @@ def get_comments(self, study_id: str, params: RequestParameters) -> t.Union[str, study = self.get_study(study_id) assert_permission(params.user, study, StudyPermissionType.READ) - output: t.Union[str, JSON] - raw_study_service = self.storage_service.raw_study_service - variant_study_service = self.storage_service.variant_study_service - if isinstance(study, RawStudy): - output = raw_study_service.get(metadata=study, url="/settings/comments") - elif isinstance(study, VariantStudy): - patch = raw_study_service.patch_service.get(study) - patch_study = PatchStudy() if patch.study is None else patch.study - output = patch_study.comments or variant_study_service.get(metadata=study, url="/settings/comments") - else: - raise StudyTypeUnsupported(study.id, study.type) + output = self.storage_service.get_storage(study).get(metadata=study, url="/settings/comments") with contextlib.suppress(AttributeError, UnicodeDecodeError): output = output.decode("utf-8") # type: ignore @@ -440,14 +430,20 @@ def edit_comments( new=bytes(data.comments, "utf-8"), params=params, ) - elif isinstance(study, VariantStudy): - patch = self.storage_service.raw_study_service.patch_service.get(study) - patch_study = patch.study or PatchStudy() - patch_study.comments = data.comments - patch.study = patch_study - self.storage_service.raw_study_service.patch_service.save(study, patch) else: - raise StudyTypeUnsupported(study.id, study.type) + variant_study_service = self.storage_service.variant_study_service + command = [ + UpdateRawFile( + target="settings/comments", + b64Data=base64.b64encode(data.comments.encode("utf-8")).decode("utf-8"), + command_context=variant_study_service.command_factory.command_context, + ) + ] + variant_study_service.append_commands( + study.id, + transform_command_to_dto(command, force_aggregate=True), + RequestParameters(user=params.user), + ) def get_studies_information( self, diff --git a/antarest/study/storage/storage_service.py b/antarest/study/storage/storage_service.py index affe97eae1..599e948948 100644 --- a/antarest/study/storage/storage_service.py +++ b/antarest/study/storage/storage_service.py @@ -5,7 +5,6 @@ from typing import Union -from antarest.core.exceptions import StudyTypeUnsupported from antarest.study.common.studystorage import IStudyStorageService from antarest.study.model import RawStudy, Study from antarest.study.storage.rawstudy.raw_study_service import RawStudyService @@ -49,13 +48,5 @@ def get_storage(self, study: Study) -> IStudyStorageService[Union[RawStudy, Vari Returns: The study storage service associated with the study type. - - Raises: - StudyTypeUnsupported: If the study type is not supported by the available storage services. """ - if isinstance(study, RawStudy): - return self.raw_study_service - elif isinstance(study, VariantStudy): - return self.variant_study_service - else: - raise StudyTypeUnsupported(study.id, study.type) + return self.raw_study_service if isinstance(study, RawStudy) else self.variant_study_service diff --git a/antarest/study/storage/variantstudy/model/command/update_raw_file.py b/antarest/study/storage/variantstudy/model/command/update_raw_file.py index c4b6cfb46b..3e7b3b8759 100644 --- a/antarest/study/storage/variantstudy/model/command/update_raw_file.py +++ b/antarest/study/storage/variantstudy/model/command/update_raw_file.py @@ -26,6 +26,15 @@ class UpdateRawFile(ICommand): target: str b64Data: str + def __repr__(self) -> str: + cls = self.__class__.__name__ + target = self.target + try: + data = base64.decodebytes(self.b64Data.encode("utf-8")).decode("utf-8") + return f"{cls}(target={target!r}, data={data!r})" + except (ValueError, TypeError): + return f"{cls}(target={target!r}, b64Data={self.b64Data!r})" + def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]: return CommandOutput(status=True, message="ok"), {} diff --git a/tests/integration/variant_blueprint/test_variant_manager.py b/tests/integration/variant_blueprint/test_variant_manager.py index 5af256dbbe..df3cf590e4 100644 --- a/tests/integration/variant_blueprint/test_variant_manager.py +++ b/tests/integration/variant_blueprint/test_variant_manager.py @@ -1,21 +1,45 @@ import logging +import typing as t +import pytest from starlette.testclient import TestClient from antarest.core.tasks.model import TaskDTO, TaskStatus -def test_variant_manager(client: TestClient, admin_access_token: str, study_id: str, caplog) -> None: +@pytest.fixture(name="base_study_id") +def base_study_id_fixture(client: TestClient, admin_access_token: str, caplog: t.Any) -> str: + """Create a base study and return its ID.""" + admin_headers = {"Authorization": f"Bearer {admin_access_token}"} with caplog.at_level(level=logging.WARNING): - admin_headers = {"Authorization": f"Bearer {admin_access_token}"} - - base_study_res = client.post("/v1/studies?name=foo", headers=admin_headers) + res = client.post("/v1/studies?name=Base1", headers=admin_headers) + return t.cast(str, res.json()) + + +@pytest.fixture(name="variant_id") +def variant_id_fixture( + client: TestClient, + admin_access_token: str, + base_study_id: str, + caplog: t.Any, +) -> str: + """Create a variant and return its ID.""" + admin_headers = {"Authorization": f"Bearer {admin_access_token}"} + with caplog.at_level(level=logging.WARNING): + res = client.post(f"/v1/studies/{base_study_id}/variants?name=Variant1", headers=admin_headers) + return t.cast(str, res.json()) - base_study_id = base_study_res.json() - res = client.post(f"/v1/studies/{base_study_id}/variants?name=foo", headers=admin_headers) - variant_id = res.json() +def test_variant_manager( + client: TestClient, + admin_access_token: str, + base_study_id: str, + variant_id: str, + caplog: t.Any, +) -> None: + admin_headers = {"Authorization": f"Bearer {admin_access_token}"} + with caplog.at_level(level=logging.WARNING): client.post(f"/v1/launcher/run/{variant_id}", headers=admin_headers) res = client.get(f"v1/studies/{variant_id}/synthesis", headers=admin_headers) @@ -26,9 +50,9 @@ def test_variant_manager(client: TestClient, admin_access_token: str, study_id: client.post(f"/v1/studies/{variant_id}/variants?name=baz", headers=admin_headers) res = client.get(f"/v1/studies/{base_study_id}/variants", headers=admin_headers) children = res.json() - assert children["node"]["name"] == "foo" + assert children["node"]["name"] == "Base1" assert len(children["children"]) == 1 - assert children["children"][0]["node"]["name"] == "foo" + assert children["children"][0]["node"]["name"] == "Variant1" assert len(children["children"][0]["children"]) == 2 assert children["children"][0]["children"][0]["node"]["name"] == "bar" assert children["children"][0]["children"][1]["node"]["name"] == "baz" @@ -169,7 +193,7 @@ def test_variant_manager(client: TestClient, admin_access_token: str, study_id: res = client.post(f"/v1/studies/{variant_id}/freeze?name=bar", headers=admin_headers) assert res.status_code == 500 - new_study_id = "newid" + new_study_id = "new_id" res = client.get(f"/v1/studies/{new_study_id}", headers=admin_headers) assert res.status_code == 404 @@ -186,3 +210,31 @@ def test_variant_manager(client: TestClient, admin_access_token: str, study_id: res = client.get(f"/v1/studies/{variant_id}", headers=admin_headers) assert res.status_code == 404 + + +def test_comments(client: TestClient, admin_access_token: str, variant_id: str) -> None: + admin_headers = {"Authorization": f"Bearer {admin_access_token}"} + + # Put comments + comment = "updated comment" + res = client.put(f"/v1/studies/{variant_id}/comments", json={"comments": comment}, headers=admin_headers) + assert res.status_code == 204 + + # Asserts comments are updated + res = client.get(f"/v1/studies/{variant_id}/comments", headers=admin_headers) + assert res.json() == comment + + # Generates the study + res = client.put(f"/v1/studies/{variant_id}/generate?denormalize=false&from_scratch=true", headers=admin_headers) + task_id = res.json() + # Wait for task completion + res = client.get(f"/v1/tasks/{task_id}", headers=admin_headers, params={"wait_for_completion": True}) + assert res.status_code == 200 + task_result = TaskDTO.parse_obj(res.json()) + assert task_result.status == TaskStatus.COMPLETED + assert task_result.result is not None + assert task_result.result.success + + # Asserts comments did not disappear + res = client.get(f"/v1/studies/{variant_id}/comments", headers=admin_headers) + assert res.json() == comment