Skip to content

Commit

Permalink
fix(upgrade): raise an HTTP 417 exception when an upgrade has unmet r…
Browse files Browse the repository at this point in the history
…equirements (#2047)

Context: 
This PR addresses two potential issues that can occur when upgrading
studies:
- Upgrading a variant study could potentially corrupt it.
- Upgrading a parent study that has variant children could potentially
corrupt those children.

Issue: 
At present, our system does not support the upgrade of commands, which
complicates the resolution of these issues. Until we develop a solution
for upgrading commands, this PR implements measures to prevent the two
situations mentioned above.

Solution: 
The modifications in this PR will trigger a 417 HTTP ERROR if an attempt
is made to upgrade a study that either has children or is a variant.
This is a temporary measure until we implement a solution for upgrading
commands.
  • Loading branch information
mabw-rte authored Jun 5, 2024
1 parent 1283d80 commit 5545e6f
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 28 deletions.
11 changes: 11 additions & 0 deletions antarest/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,17 @@ def __init__(self, uuid: str, message: t.Optional[str] = None) -> None:
)


class StudyVariantUpgradeError(HTTPException):
def __init__(self, is_variant: bool) -> None:
if is_variant:
super().__init__(
HTTPStatus.EXPECTATION_FAILED,
"Upgrade not supported for variant study",
)
else:
super().__init__(HTTPStatus.EXPECTATION_FAILED, "Upgrade not supported for parent of variants")


class UnsupportedStudyVersion(HTTPException):
def __init__(self, version: str) -> None:
super().__init__(
Expand Down
13 changes: 13 additions & 0 deletions antarest/study/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,16 @@ def list_duplicates(self) -> t.List[t.Tuple[str, str]]:
subquery = session.query(Study.path).group_by(Study.path).having(func.count() > 1).subquery()
query = session.query(Study.id, Study.path).filter(Study.path.in_(subquery))
return t.cast(t.List[t.Tuple[str, str]], query.all())

def has_children(self, uuid: str) -> bool:
"""
Check if a study has children.
Args:
uuid: The `uuid` of the study to check.
Returns:
True if the study has children, False otherwise.
"""

return self.session.query(Study).filter(Study.parent_id == uuid).first() is not None
14 changes: 14 additions & 0 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
StudyDeletionNotAllowed,
StudyNotFoundError,
StudyTypeUnsupported,
StudyVariantUpgradeError,
TaskAlreadyRunning,
UnsupportedOperationOnArchivedStudy,
UnsupportedStudyVersion,
Expand Down Expand Up @@ -2384,6 +2385,19 @@ def upgrade_study(
assert_permission(params.user, study, StudyPermissionType.WRITE)
self._assert_study_unarchived(study)

# The upgrade of a study variant requires the use of a command specifically dedicated to the upgrade.
# However, such a command does not currently exist. Moreover, upgrading a study (whether raw or variant)
# directly impacts its descendants, as it would necessitate upgrading all of them.
# It’s uncertain whether this would be an acceptable behavior.
# For this reason, upgrading a study is not possible if the study is a variant or if it has descendants.

# First check if the study is a variant study, if so throw an error
if isinstance(study, VariantStudy):
raise StudyVariantUpgradeError(True)
# If the study is a parent raw study, throw an error
elif self.repository.has_children(study_id):
raise StudyVariantUpgradeError(False)

target_version = target_version or find_next_version(study.version)
if not target_version:
raise UnsupportedStudyVersion(study.version)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def clear_snapshot(self, variant_study: Study) -> None:
shutil.rmtree(self.get_study_path(variant_study), ignore_errors=True)

def has_children(self, study: VariantStudy) -> bool:
return len(self.repository.get_children(parent_id=study.id)) > 0
return self.repository.has_children(study.id)

def get_all_variants_children(
self,
Expand Down
50 changes: 50 additions & 0 deletions tests/integration/test_studies_upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,53 @@ def test_upgrade_study__bad_target_version(self, client: TestClient, user_access
task = wait_task_completion(client, user_access_token, task_id)
assert task.status == TaskStatus.FAILED
assert target_version in task.result.message, f"Version not in {task.result.message=}"

def test_upgrade_study__unmet_requirements(self, client: TestClient, admin_access_token: str):
"""
Test that an upgrade with unmet requirements fails, corresponding to the two following cases:
- the study is a raw study with at least one child study
- the study is a variant study
"""

# set the admin access token in the client headers
client.headers = {"Authorization": f"Bearer {admin_access_token}"}

# create a raw study
res = client.post(
"/v1/studies",
params={"name": "My Study"},
)
assert res.status_code == 201, res.json()
uuid = res.json()

# create a child variant study
res = client.post(
f"/v1/studies/{uuid}/variants",
params={"name": "foo"},
)
assert res.status_code == 200, res.json()
child_uuid = res.json()

# upgrade the raw study
res = client.put(
f"/v1/studies/{uuid}/upgrade",
)

# check that the upgrade failed (HttpException:417, with the expected message)
assert res.status_code == 417, res.json()
assert res.json() == {
"description": "Upgrade not supported for parent of variants",
"exception": "StudyVariantUpgradeError",
}

# upgrade the variant study
res = client.put(
f"/v1/studies/{child_uuid}/upgrade",
)

# check that the upgrade failed (HttpException:417, with the expected message)
assert res.status_code == 417, res.json()
assert res.json() == {
"description": "Upgrade not supported for variant study",
"exception": "StudyVariantUpgradeError",
}
146 changes: 119 additions & 27 deletions tests/storage/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from starlette.responses import Response

from antarest.core.config import Config, StorageConfig, WorkspaceConfig
from antarest.core.exceptions import TaskAlreadyRunning
from antarest.core.exceptions import StudyVariantUpgradeError, TaskAlreadyRunning
from antarest.core.filetransfer.model import FileDownload, FileDownloadTaskDTO
from antarest.core.interfaces.cache import ICache
from antarest.core.interfaces.eventbus import Event, EventType
Expand Down Expand Up @@ -1104,9 +1104,9 @@ def test_delete_with_prefetch(tmp_path: Path) -> None:
seal(study_mock)

study_metadata_repository.get.return_value = study_mock
variant_study_repository.get_children.return_value = []
variant_study_repository.has_children.return_value = False

# if this fails, it may means the study metadata mock is missing some attribute definition
# if this fails, it may mean the study metadata mock is missing some attribute definition
# this test is here to prevent errors if we add attribute fetching from child classes
# (attributes in polymorphism are lazy)
# see the comment in the delete method for more information
Expand All @@ -1133,7 +1133,7 @@ def test_delete_with_prefetch(tmp_path: Path) -> None:
seal(study_mock)

study_metadata_repository.get.return_value = study_mock
variant_study_repository.get_children.return_value = []
variant_study_repository.has_children.return_value = False

# if this fails, it may means the study metadata mock is missing some definition
# this test is here to prevent errors if we add attribute fetching from child classes (attributes in polymorphism are lazy)
Expand Down Expand Up @@ -1180,8 +1180,7 @@ def create_study_fs_mock(variant: bool = False) -> str:
return str(_study_dir)

study_path = create_study_fs_mock()
study_mock = Mock(
spec=RawStudy,
study_mock = RawStudy(
archived=False,
id="my_study",
path=study_path,
Expand All @@ -1191,30 +1190,72 @@ def create_study_fs_mock(variant: bool = False) -> str:
workspace=DEFAULT_WORKSPACE_NAME,
last_access=datetime.utcnow(),
)
study_mock.to_json_summary.return_value = {"id": "my_study", "name": "foo"}

# it freezes the mock and raise Attribute error if anything else than defined is used
seal(study_mock)

v1 = VariantStudy(id="variant_1", path=create_study_fs_mock(variant=True))
v2 = VariantStudy(id="variant_2", path=create_study_fs_mock(variant=True))
v3 = VariantStudy(id="sub_variant_1", path=create_study_fs_mock(variant=True))

study_metadata_repository.get.side_effect = [study_mock, v3, v1, v2]
variant_study_repository.get_children.side_effect = [
[v1, v2],
[v3],
[],
[],
[],
[],
[],
]
variant_study_repository.get.side_effect = [
VariantStudy(id="variant_1"),
VariantStudy(id="sub_variant_1"),
VariantStudy(id="variant_2"),
]
def get_study(study_id: str) -> Study:
if study_id == "my_study":
return study_mock
elif study_id == "variant_1":
return v1
elif study_id == "variant_2":
return v2
elif study_id == "sub_variant_1":
return v3
raise ValueError(f"Unexpected study id: {study_id}")

class ChildrenProvider:
def __init__(self):
self.c0 = 0
self.c1 = 0

def get_children(self, parent_id: str) -> t.List[Study]:
if parent_id == "my_study":
if self.c0 > 0:
return []
self.c0 = 1
return [v1, v2]
elif parent_id == "variant_1":
if self.c1 > 0:
return []
self.c1 = 1
return [v3]
elif parent_id == "variant_2":
return []
elif parent_id == "sub_variant_1":
return []
raise ValueError(f"Unexpected study id: {parent_id}")

class HasChildrenProvider:
def __init__(self):
self.c1 = 0
self.c2 = 0

def has_children(self, study_id: str) -> bool:
if study_id == "my_study":
if self.c1 > 0:
return False
self.c1 = 1
return True
elif study_id == "variant_1":
if self.c2 > 0:
return False
self.c2 = 1
return True
elif study_id == "variant_2":
return False
elif study_id == "sub_variant_1":
return False
raise ValueError(f"Unexpected study id: {study_id}")

children_provider = ChildrenProvider()
hash_children_provider = HasChildrenProvider()
study_metadata_repository.get = get_study
variant_study_repository.get = get_study
variant_study_repository.get_children = children_provider.get_children
variant_study_repository.has_children = hash_children_provider.has_children

service.delete_study(
"my_study",
Expand Down Expand Up @@ -1594,17 +1635,18 @@ def test_task_upgrade_study(tmp_path: Path) -> None:
)
study_mock.name = "my_study"
study_mock.to_json_summary.return_value = {"id": "my_study", "name": "foo"}
service.repository.has_children.return_value = False # type: ignore
service.repository.get.return_value = study_mock # type: ignore

study_id = "my_study"
service.task_service.reset_mock()
service.task_service.reset_mock() # type: ignore
service.task_service.list_tasks.side_effect = [
[
TaskDTO(
id="1",
name=f"Upgrade study my_study ({study_id}) to version 800",
status=TaskStatus.RUNNING,
creation_date_utc=str(datetime.utcnow()),
creation_date_utc=str(datetime.utcnow()), # type: ignore
type=TaskType.UNARCHIVE,
ref_id=study_id,
)
Expand Down Expand Up @@ -1634,6 +1676,56 @@ def test_task_upgrade_study(tmp_path: Path) -> None:
request_params=RequestParameters(user=DEFAULT_ADMIN_USER),
)

# check that a variant study or a raw study with children cannot be upgraded
parent_raw_study = Mock(
spec=RawStudy,
archived=False,
id="parent_raw_study",
name="parent_raw_study",
path=tmp_path,
version="720",
owner=None,
groups=[],
public_mode=PublicMode.NONE,
workspace="other_workspace",
)
study_mock.name = "parent_raw_study"
study_mock.to_json_summary.return_value = {"id": "parent_raw_study", "name": "parent_raw_study"}
service.repository.has_children.return_value = True # type: ignore
service.repository.get.return_value = parent_raw_study # type: ignore

with pytest.raises(StudyVariantUpgradeError):
service.upgrade_study(
"parent_raw_study",
target_version="",
params=RequestParameters(user=DEFAULT_ADMIN_USER),
)

variant_study = Mock(
spec=VariantStudy,
archived=False,
id="variant_study",
name="variant_study",
path=tmp_path,
version="720",
owner=None,
groups=[],
public_mode=PublicMode.NONE,
workspace="other_workspace",
)

study_mock.name = "variant_study"
study_mock.to_json_summary.return_value = {"id": "variant_study", "name": "variant_study"}
service.repository.has_children.return_value = True # type: ignore
service.repository.get.return_value = variant_study # type: ignore

with pytest.raises(StudyVariantUpgradeError):
service.upgrade_study(
"variant_study",
target_version="",
params=RequestParameters(user=DEFAULT_ADMIN_USER),
)


@with_db_context
@patch("antarest.study.service.upgrade_study")
Expand Down

0 comments on commit 5545e6f

Please sign in to comment.