Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(comments): use a command to update comments on variants #1959

Merged
merged 3 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 15 additions & 19 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@
MatrixIndex,
PatchArea,
PatchCluster,
PatchStudy,
RawStudy,
Study,
StudyAdditionalData,
Expand Down Expand Up @@ -121,6 +120,7 @@
upgrade_study,
)
from antarest.study.storage.utils import assert_permission, get_start_date, is_managed, remove_from_cache
from antarest.study.storage.variantstudy.business.utils import transform_command_to_dto
from antarest.study.storage.variantstudy.model.command.icommand import ICommand
from antarest.study.storage.variantstudy.model.command.replace_matrix import ReplaceMatrix
from antarest.study.storage.variantstudy.model.command.update_comments import UpdateComments
Expand Down Expand Up @@ -395,17 +395,7 @@ def get_comments(self, study_id: str, params: RequestParameters) -> t.Union[str,
study = self.get_study(study_id)
assert_permission(params.user, study, StudyPermissionType.READ)

output: t.Union[str, JSON]
raw_study_service = self.storage_service.raw_study_service
variant_study_service = self.storage_service.variant_study_service
if isinstance(study, RawStudy):
output = raw_study_service.get(metadata=study, url="/settings/comments")
elif isinstance(study, VariantStudy):
patch = raw_study_service.patch_service.get(study)
patch_study = PatchStudy() if patch.study is None else patch.study
output = patch_study.comments or variant_study_service.get(metadata=study, url="/settings/comments")
else:
raise StudyTypeUnsupported(study.id, study.type)
output = self.storage_service.get_storage(study).get(metadata=study, url="/settings/comments")

with contextlib.suppress(AttributeError, UnicodeDecodeError):
output = output.decode("utf-8") # type: ignore
Expand Down Expand Up @@ -440,14 +430,20 @@ def edit_comments(
new=bytes(data.comments, "utf-8"),
params=params,
)
elif isinstance(study, VariantStudy):
patch = self.storage_service.raw_study_service.patch_service.get(study)
patch_study = patch.study or PatchStudy()
patch_study.comments = data.comments
patch.study = patch_study
self.storage_service.raw_study_service.patch_service.save(study, patch)
else:
raise StudyTypeUnsupported(study.id, study.type)
variant_study_service = self.storage_service.variant_study_service
command = [
UpdateRawFile(
target="settings/comments",
b64Data=base64.b64encode(data.comments.encode("utf-8")).decode("utf-8"),
command_context=variant_study_service.command_factory.command_context,
)
]
variant_study_service.append_commands(
study.id,
transform_command_to_dto(command, force_aggregate=True),
RequestParameters(user=params.user),
)

def get_studies_information(
self,
Expand Down
11 changes: 1 addition & 10 deletions antarest/study/storage/storage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing import Union

from antarest.core.exceptions import StudyTypeUnsupported
from antarest.study.common.studystorage import IStudyStorageService
from antarest.study.model import RawStudy, Study
from antarest.study.storage.rawstudy.raw_study_service import RawStudyService
Expand Down Expand Up @@ -49,13 +48,5 @@ def get_storage(self, study: Study) -> IStudyStorageService[Union[RawStudy, Vari

Returns:
The study storage service associated with the study type.

Raises:
StudyTypeUnsupported: If the study type is not supported by the available storage services.
"""
if isinstance(study, RawStudy):
return self.raw_study_service
elif isinstance(study, VariantStudy):
return self.variant_study_service
else:
raise StudyTypeUnsupported(study.id, study.type)
return self.raw_study_service if isinstance(study, RawStudy) else self.variant_study_service
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ class UpdateRawFile(ICommand):
target: str
b64Data: str

def __repr__(self) -> str:
cls = self.__class__.__name__
target = self.target
try:
data = base64.decodebytes(self.b64Data.encode("utf-8")).decode("utf-8")
return f"{cls}(target={target!r}, data={data!r})"
except (ValueError, TypeError):
return f"{cls}(target={target!r}, b64Data={self.b64Data!r})"

def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]:
return CommandOutput(status=True, message="ok"), {}

Expand Down
72 changes: 62 additions & 10 deletions tests/integration/variant_blueprint/test_variant_manager.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,45 @@
import logging
import typing as t

import pytest
from starlette.testclient import TestClient

from antarest.core.tasks.model import TaskDTO, TaskStatus


def test_variant_manager(client: TestClient, admin_access_token: str, study_id: str, caplog) -> None:
@pytest.fixture(name="base_study_id")
def base_study_id_fixture(client: TestClient, admin_access_token: str, caplog: t.Any) -> str:
"""Create a base study and return its ID."""
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}
with caplog.at_level(level=logging.WARNING):
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}

base_study_res = client.post("/v1/studies?name=foo", headers=admin_headers)
res = client.post("/v1/studies?name=Base1", headers=admin_headers)
return t.cast(str, res.json())


@pytest.fixture(name="variant_id")
def variant_id_fixture(
client: TestClient,
admin_access_token: str,
base_study_id: str,
caplog: t.Any,
) -> str:
"""Create a variant and return its ID."""
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}
with caplog.at_level(level=logging.WARNING):
res = client.post(f"/v1/studies/{base_study_id}/variants?name=Variant1", headers=admin_headers)
return t.cast(str, res.json())

base_study_id = base_study_res.json()

res = client.post(f"/v1/studies/{base_study_id}/variants?name=foo", headers=admin_headers)
variant_id = res.json()
def test_variant_manager(
client: TestClient,
admin_access_token: str,
base_study_id: str,
variant_id: str,
caplog: t.Any,
) -> None:
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}

with caplog.at_level(level=logging.WARNING):
client.post(f"/v1/launcher/run/{variant_id}", headers=admin_headers)

res = client.get(f"v1/studies/{variant_id}/synthesis", headers=admin_headers)
Expand All @@ -26,9 +50,9 @@ def test_variant_manager(client: TestClient, admin_access_token: str, study_id:
client.post(f"/v1/studies/{variant_id}/variants?name=baz", headers=admin_headers)
res = client.get(f"/v1/studies/{base_study_id}/variants", headers=admin_headers)
children = res.json()
assert children["node"]["name"] == "foo"
assert children["node"]["name"] == "Base1"
assert len(children["children"]) == 1
assert children["children"][0]["node"]["name"] == "foo"
assert children["children"][0]["node"]["name"] == "Variant1"
assert len(children["children"][0]["children"]) == 2
assert children["children"][0]["children"][0]["node"]["name"] == "bar"
assert children["children"][0]["children"][1]["node"]["name"] == "baz"
Expand Down Expand Up @@ -169,7 +193,7 @@ def test_variant_manager(client: TestClient, admin_access_token: str, study_id:
res = client.post(f"/v1/studies/{variant_id}/freeze?name=bar", headers=admin_headers)
assert res.status_code == 500

new_study_id = "newid"
new_study_id = "new_id"

res = client.get(f"/v1/studies/{new_study_id}", headers=admin_headers)
assert res.status_code == 404
Expand All @@ -186,3 +210,31 @@ def test_variant_manager(client: TestClient, admin_access_token: str, study_id:

res = client.get(f"/v1/studies/{variant_id}", headers=admin_headers)
assert res.status_code == 404


def test_comments(client: TestClient, admin_access_token: str, variant_id: str) -> None:
admin_headers = {"Authorization": f"Bearer {admin_access_token}"}

# Put comments
comment = "updated comment"
res = client.put(f"/v1/studies/{variant_id}/comments", json={"comments": comment}, headers=admin_headers)
assert res.status_code == 204

# Asserts comments are updated
res = client.get(f"/v1/studies/{variant_id}/comments", headers=admin_headers)
assert res.json() == comment

# Generates the study
res = client.put(f"/v1/studies/{variant_id}/generate?denormalize=false&from_scratch=true", headers=admin_headers)
task_id = res.json()
# Wait for task completion
res = client.get(f"/v1/tasks/{task_id}", headers=admin_headers, params={"wait_for_completion": True})
assert res.status_code == 200
task_result = TaskDTO.parse_obj(res.json())
assert task_result.status == TaskStatus.COMPLETED
assert task_result.result is not None
assert task_result.result.success

# Asserts comments did not disappear
res = client.get(f"/v1/studies/{variant_id}/comments", headers=admin_headers)
assert res.json() == comment
Loading