Skip to content

Commit

Permalink
fix(comments): use a command to update comments on a variant (#1959)
Browse files Browse the repository at this point in the history
Co-authored-by: Laurent LAPORTE <[email protected]>
  • Loading branch information
MartinBelthle and laurent-laporte-pro authored Mar 5, 2024
1 parent c482184 commit f7f082a
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 39 deletions.
34 changes: 15 additions & 19 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@
MatrixIndex,
PatchArea,
PatchCluster,
PatchStudy,
RawStudy,
Study,
StudyAdditionalData,
Expand Down Expand Up @@ -121,6 +120,7 @@
upgrade_study,
)
from antarest.study.storage.utils import assert_permission, get_start_date, is_managed, remove_from_cache
from antarest.study.storage.variantstudy.business.utils import transform_command_to_dto
from antarest.study.storage.variantstudy.model.command.icommand import ICommand
from antarest.study.storage.variantstudy.model.command.replace_matrix import ReplaceMatrix
from antarest.study.storage.variantstudy.model.command.update_comments import UpdateComments
Expand Down Expand Up @@ -395,17 +395,7 @@ def get_comments(self, study_id: str, params: RequestParameters) -> t.Union[str,
study = self.get_study(study_id)
assert_permission(params.user, study, StudyPermissionType.READ)

output: t.Union[str, JSON]
raw_study_service = self.storage_service.raw_study_service
variant_study_service = self.storage_service.variant_study_service
if isinstance(study, RawStudy):
output = raw_study_service.get(metadata=study, url="/settings/comments")
elif isinstance(study, VariantStudy):
patch = raw_study_service.patch_service.get(study)
patch_study = PatchStudy() if patch.study is None else patch.study
output = patch_study.comments or variant_study_service.get(metadata=study, url="/settings/comments")
else:
raise StudyTypeUnsupported(study.id, study.type)
output = self.storage_service.get_storage(study).get(metadata=study, url="/settings/comments")

with contextlib.suppress(AttributeError, UnicodeDecodeError):
output = output.decode("utf-8") # type: ignore
Expand Down Expand Up @@ -440,14 +430,20 @@ def edit_comments(
new=bytes(data.comments, "utf-8"),
params=params,
)
elif isinstance(study, VariantStudy):
patch = self.storage_service.raw_study_service.patch_service.get(study)
patch_study = patch.study or PatchStudy()
patch_study.comments = data.comments
patch.study = patch_study
self.storage_service.raw_study_service.patch_service.save(study, patch)
else:
raise StudyTypeUnsupported(study.id, study.type)
variant_study_service = self.storage_service.variant_study_service
command = [
UpdateRawFile(
target="settings/comments",
b64Data=base64.b64encode(data.comments.encode("utf-8")).decode("utf-8"),
command_context=variant_study_service.command_factory.command_context,
)
]
variant_study_service.append_commands(
study.id,
transform_command_to_dto(command, force_aggregate=True),
RequestParameters(user=params.user),
)

def get_studies_information(
self,
Expand Down
11 changes: 1 addition & 10 deletions antarest/study/storage/storage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing import Union

from antarest.core.exceptions import StudyTypeUnsupported
from antarest.study.common.studystorage import IStudyStorageService
from antarest.study.model import RawStudy, Study
from antarest.study.storage.rawstudy.raw_study_service import RawStudyService
Expand Down Expand Up @@ -49,13 +48,5 @@ def get_storage(self, study: Study) -> IStudyStorageService[Union[RawStudy, Vari
Returns:
The study storage service associated with the study type.
Raises:
StudyTypeUnsupported: If the study type is not supported by the available storage services.
"""
if isinstance(study, RawStudy):
return self.raw_study_service
elif isinstance(study, VariantStudy):
return self.variant_study_service
else:
raise StudyTypeUnsupported(study.id, study.type)
return self.raw_study_service if isinstance(study, RawStudy) else self.variant_study_service
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ class UpdateRawFile(ICommand):
target: str
b64Data: str

def __repr__(self) -> str:
cls = self.__class__.__name__
target = self.target
try:
data = base64.decodebytes(self.b64Data.encode("utf-8")).decode("utf-8")
return f"{cls}(target={target!r}, data={data!r})"
except (ValueError, TypeError):
return f"{cls}(target={target!r}, b64Data={self.b64Data!r})"

def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]:
return CommandOutput(status=True, message="ok"), {}

Expand Down
72 changes: 62 additions & 10 deletions tests/integration/variant_blueprint/test_variant_manager.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,45 @@
import logging
import typing as t

import pytest
from starlette.testclient import TestClient

from antarest.core.tasks.model import TaskDTO, TaskStatus


def test_variant_manager(client: TestClient, admin_access_token: str, study_id: str, caplog) -> None:
@pytest.fixture(name="base_study_id")
def base_study_id_fixture(client: TestClient, admin_access_token: str, caplog: t.Any) -> str:
"""Create a base study and return its ID."""
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}
with caplog.at_level(level=logging.WARNING):
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}

base_study_res = client.post("/v1/studies?name=foo", headers=admin_headers)
res = client.post("/v1/studies?name=Base1", headers=admin_headers)
return t.cast(str, res.json())


@pytest.fixture(name="variant_id")
def variant_id_fixture(
client: TestClient,
admin_access_token: str,
base_study_id: str,
caplog: t.Any,
) -> str:
"""Create a variant and return its ID."""
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}
with caplog.at_level(level=logging.WARNING):
res = client.post(f"/v1/studies/{base_study_id}/variants?name=Variant1", headers=admin_headers)
return t.cast(str, res.json())

base_study_id = base_study_res.json()

res = client.post(f"/v1/studies/{base_study_id}/variants?name=foo", headers=admin_headers)
variant_id = res.json()
def test_variant_manager(
client: TestClient,
admin_access_token: str,
base_study_id: str,
variant_id: str,
caplog: t.Any,
) -> None:
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}

with caplog.at_level(level=logging.WARNING):
client.post(f"/v1/launcher/run/{variant_id}", headers=admin_headers)

res = client.get(f"v1/studies/{variant_id}/synthesis", headers=admin_headers)
Expand All @@ -26,9 +50,9 @@ def test_variant_manager(client: TestClient, admin_access_token: str, study_id:
client.post(f"/v1/studies/{variant_id}/variants?name=baz", headers=admin_headers)
res = client.get(f"/v1/studies/{base_study_id}/variants", headers=admin_headers)
children = res.json()
assert children["node"]["name"] == "foo"
assert children["node"]["name"] == "Base1"
assert len(children["children"]) == 1
assert children["children"][0]["node"]["name"] == "foo"
assert children["children"][0]["node"]["name"] == "Variant1"
assert len(children["children"][0]["children"]) == 2
assert children["children"][0]["children"][0]["node"]["name"] == "bar"
assert children["children"][0]["children"][1]["node"]["name"] == "baz"
Expand Down Expand Up @@ -169,7 +193,7 @@ def test_variant_manager(client: TestClient, admin_access_token: str, study_id:
res = client.post(f"/v1/studies/{variant_id}/freeze?name=bar", headers=admin_headers)
assert res.status_code == 500

new_study_id = "newid"
new_study_id = "new_id"

res = client.get(f"/v1/studies/{new_study_id}", headers=admin_headers)
assert res.status_code == 404
Expand All @@ -186,3 +210,31 @@ def test_variant_manager(client: TestClient, admin_access_token: str, study_id:

res = client.get(f"/v1/studies/{variant_id}", headers=admin_headers)
assert res.status_code == 404


def test_comments(client: TestClient, admin_access_token: str, variant_id: str) -> None:
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}

# Put comments
comment = "updated comment"
res = client.put(f"/v1/studies/{variant_id}/comments", json={"comments": comment}, headers=admin_headers)
assert res.status_code == 204

# Asserts comments are updated
res = client.get(f"/v1/studies/{variant_id}/comments", headers=admin_headers)
assert res.json() == comment

# Generates the study
res = client.put(f"/v1/studies/{variant_id}/generate?denormalize=false&from_scratch=true", headers=admin_headers)
task_id = res.json()
# Wait for task completion
res = client.get(f"/v1/tasks/{task_id}", headers=admin_headers, params={"wait_for_completion": True})
assert res.status_code == 200
task_result = TaskDTO.parse_obj(res.json())
assert task_result.status == TaskStatus.COMPLETED
assert task_result.result is not None
assert task_result.result.success

# Asserts comments did not disappear
res = client.get(f"/v1/studies/{variant_id}/comments", headers=admin_headers)
assert res.json() == comment

0 comments on commit f7f082a

Please sign in to comment.