Skip to content

Commit

Permalink
fix(comments): use command to update comments on a variant
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle committed Mar 5, 2024
1 parent c482184 commit 9d94c2a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 37 deletions.
34 changes: 15 additions & 19 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@
MatrixIndex,
PatchArea,
PatchCluster,
PatchStudy,
RawStudy,
Study,
StudyAdditionalData,
Expand Down Expand Up @@ -121,6 +120,7 @@
upgrade_study,
)
from antarest.study.storage.utils import assert_permission, get_start_date, is_managed, remove_from_cache
from antarest.study.storage.variantstudy.business.utils import transform_command_to_dto
from antarest.study.storage.variantstudy.model.command.icommand import ICommand
from antarest.study.storage.variantstudy.model.command.replace_matrix import ReplaceMatrix
from antarest.study.storage.variantstudy.model.command.update_comments import UpdateComments
Expand Down Expand Up @@ -395,17 +395,7 @@ def get_comments(self, study_id: str, params: RequestParameters) -> t.Union[str,
study = self.get_study(study_id)
assert_permission(params.user, study, StudyPermissionType.READ)

output: t.Union[str, JSON]
raw_study_service = self.storage_service.raw_study_service
variant_study_service = self.storage_service.variant_study_service
if isinstance(study, RawStudy):
output = raw_study_service.get(metadata=study, url="/settings/comments")
elif isinstance(study, VariantStudy):
patch = raw_study_service.patch_service.get(study)
patch_study = PatchStudy() if patch.study is None else patch.study
output = patch_study.comments or variant_study_service.get(metadata=study, url="/settings/comments")
else:
raise StudyTypeUnsupported(study.id, study.type)
output = self.storage_service.get_storage(study).get(metadata=study, url="/settings/comments")

with contextlib.suppress(AttributeError, UnicodeDecodeError):
output = output.decode("utf-8") # type: ignore
Expand Down Expand Up @@ -440,14 +430,20 @@ def edit_comments(
new=bytes(data.comments, "utf-8"),
params=params,
)
elif isinstance(study, VariantStudy):
patch = self.storage_service.raw_study_service.patch_service.get(study)
patch_study = patch.study or PatchStudy()
patch_study.comments = data.comments
patch.study = patch_study
self.storage_service.raw_study_service.patch_service.save(study, patch)
else:
raise StudyTypeUnsupported(study.id, study.type)
variant_study_service = self.storage_service.variant_study_service
command = [
UpdateRawFile(
target="settings/comments",
b64Data=base64.b64encode(bytes(data.comments, "utf-8")),
command_context=variant_study_service.command_factory.command_context,
)
]
variant_study_service.append_commands(
study.id,
transform_command_to_dto(command, force_aggregate=True),
RequestParameters(user=params.user),
)

def get_studies_information(
self,
Expand Down
11 changes: 1 addition & 10 deletions antarest/study/storage/storage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing import Union

from antarest.core.exceptions import StudyTypeUnsupported
from antarest.study.common.studystorage import IStudyStorageService
from antarest.study.model import RawStudy, Study
from antarest.study.storage.rawstudy.raw_study_service import RawStudyService
Expand Down Expand Up @@ -49,13 +48,5 @@ def get_storage(self, study: Study) -> IStudyStorageService[Union[RawStudy, Vari
Returns:
The study storage service associated with the study type.
Raises:
StudyTypeUnsupported: If the study type is not supported by the available storage services.
"""
if isinstance(study, RawStudy):
return self.raw_study_service
elif isinstance(study, VariantStudy):
return self.variant_study_service
else:
raise StudyTypeUnsupported(study.id, study.type)
return self.raw_study_service if isinstance(study, RawStudy) else self.variant_study_service
46 changes: 38 additions & 8 deletions tests/integration/variant_blueprint/test_variant_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@
from antarest.core.tasks.model import TaskDTO, TaskStatus


def test_variant_manager(client: TestClient, admin_access_token: str, study_id: str, caplog) -> None:
with caplog.at_level(level=logging.WARNING):
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}
def _set_up_variant_manager(client: TestClient, admin_access_token: str) -> str:
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}
base_study_res = client.post("/v1/studies?name=foo", headers=admin_headers)
base_study_id = base_study_res.json()
res = client.post(f"/v1/studies/{base_study_id}/variants?name=foo", headers=admin_headers)
variant_id = res.json()
return admin_headers, base_study_id, variant_id

base_study_res = client.post("/v1/studies?name=foo", headers=admin_headers)

base_study_id = base_study_res.json()

res = client.post(f"/v1/studies/{base_study_id}/variants?name=foo", headers=admin_headers)
variant_id = res.json()
def test_variant_manager(client: TestClient, admin_access_token: str, study_id: str, caplog) -> None:
with caplog.at_level(level=logging.WARNING):
admin_headers, base_study_id, variant_id = _set_up_variant_manager(client, admin_access_token)

client.post(f"/v1/launcher/run/{variant_id}", headers=admin_headers)

Expand Down Expand Up @@ -186,3 +188,31 @@ def test_variant_manager(client: TestClient, admin_access_token: str, study_id:

res = client.get(f"/v1/studies/{variant_id}", headers=admin_headers)
assert res.status_code == 404


def test_comments(client: TestClient, admin_access_token: str, tmp_path: str) -> None:
admin_headers, _, variant_id = _set_up_variant_manager(client, admin_access_token)

# Put comments
comment = "updated comment"
res = client.put(f"/v1/studies/{variant_id}/comments", json={"comments": comment}, headers=admin_headers)
assert res.status_code == 204

# Asserts comments are updated
res = client.get(f"/v1/studies/{variant_id}/comments", headers=admin_headers)
assert res.json() == comment

# Generates the study
res = client.put(f"/v1/studies/{variant_id}/generate?denormalize=false&from_scratch=true", headers=admin_headers)
task_id = res.json()
# Wait for task completion
res = client.get(f"/v1/tasks/{task_id}", headers=admin_headers, params={"wait_for_completion": True})
assert res.status_code == 200
task_result = TaskDTO.parse_obj(res.json())
assert task_result.status == TaskStatus.COMPLETED
assert task_result.result is not None
assert task_result.result.success

# Asserts comments did not disappear
res = client.get(f"/v1/studies/{variant_id}/comments", headers=admin_headers)
assert res.json() == comment

0 comments on commit 9d94c2a

Please sign in to comment.