diff --git a/antarest/core/exceptions.py b/antarest/core/exceptions.py index 87e7b04c38..405b91b48d 100644 --- a/antarest/core/exceptions.py +++ b/antarest/core/exceptions.py @@ -314,6 +314,17 @@ def __init__(self, uuid: str, message: t.Optional[str] = None) -> None: ) +class StudyVariantUpgradeError(HTTPException): + def __init__(self, is_variant: bool) -> None: + if is_variant: + super().__init__( + HTTPStatus.EXPECTATION_FAILED, + "Upgrade not supported for variant study", + ) + else: + super().__init__(HTTPStatus.EXPECTATION_FAILED, "Upgrade not supported for parent of variants") + + class UnsupportedStudyVersion(HTTPException): def __init__(self, version: str) -> None: super().__init__( diff --git a/antarest/study/repository.py b/antarest/study/repository.py index 81c2463c69..4a3965d9de 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -416,3 +416,16 @@ def list_duplicates(self) -> t.List[t.Tuple[str, str]]: subquery = session.query(Study.path).group_by(Study.path).having(func.count() > 1).subquery() query = session.query(Study.id, Study.path).filter(Study.path.in_(subquery)) return t.cast(t.List[t.Tuple[str, str]], query.all()) + + def has_children(self, uuid: str) -> bool: + """ + Check if a study has children. + + Args: + uuid: The `uuid` of the study to check. + + Returns: + True if the study has children, False otherwise. + """ + + return self.session.query(Study).filter(Study.parent_id == uuid).first() is not None diff --git a/antarest/study/service.py b/antarest/study/service.py index 16ff611022..0330992da2 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -27,6 +27,7 @@ StudyDeletionNotAllowed, StudyNotFoundError, StudyTypeUnsupported, + StudyVariantUpgradeError, TaskAlreadyRunning, UnsupportedOperationOnArchivedStudy, UnsupportedStudyVersion, @@ -2384,6 +2385,19 @@ def upgrade_study( assert_permission(params.user, study, StudyPermissionType.WRITE) self._assert_study_unarchived(study) + # The upgrade of a study variant requires the use of a command specifically dedicated to the upgrade. + # However, such a command does not currently exist. Moreover, upgrading a study (whether raw or variant) + # directly impacts its descendants, as it would necessitate upgrading all of them. + # It’s uncertain whether this would be an acceptable behavior. + # For this reason, upgrading a study is not possible if the study is a variant or if it has descendants. + + # First check if the study is a variant study, if so throw an error + if isinstance(study, VariantStudy): + raise StudyVariantUpgradeError(True) + # If the study is a parent raw study, throw an error + elif self.repository.has_children(study_id): + raise StudyVariantUpgradeError(False) + target_version = target_version or find_next_version(study.version) if not target_version: raise UnsupportedStudyVersion(study.version) diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index f4b342ad7b..1452071aae 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -377,7 +377,7 @@ def clear_snapshot(self, variant_study: Study) -> None: shutil.rmtree(self.get_study_path(variant_study), ignore_errors=True) def has_children(self, study: VariantStudy) -> bool: - return len(self.repository.get_children(parent_id=study.id)) > 0 + return self.repository.has_children(study.id) def get_all_variants_children( self, diff --git a/tests/integration/test_studies_upgrade.py b/tests/integration/test_studies_upgrade.py index 366ba88e77..f734223344 100644 --- a/tests/integration/test_studies_upgrade.py +++ b/tests/integration/test_studies_upgrade.py @@ -52,3 +52,53 @@ def test_upgrade_study__bad_target_version(self, client: TestClient, user_access task = wait_task_completion(client, user_access_token, task_id) assert task.status == TaskStatus.FAILED assert target_version in task.result.message, f"Version not in {task.result.message=}" + + def test_upgrade_study__unmet_requirements(self, client: TestClient, admin_access_token: str): + """ + Test that an upgrade with unmet requirements fails, corresponding to the two following cases: + - the study is a raw study with at least one child study + - the study is a variant study + """ + + # set the admin access token in the client headers + client.headers = {"Authorization": f"Bearer {admin_access_token}"} + + # create a raw study + res = client.post( + "/v1/studies", + params={"name": "My Study"}, + ) + assert res.status_code == 201, res.json() + uuid = res.json() + + # create a child variant study + res = client.post( + f"/v1/studies/{uuid}/variants", + params={"name": "foo"}, + ) + assert res.status_code == 200, res.json() + child_uuid = res.json() + + # upgrade the raw study + res = client.put( + f"/v1/studies/{uuid}/upgrade", + ) + + # check that the upgrade failed (HttpException:417, with the expected message) + assert res.status_code == 417, res.json() + assert res.json() == { + "description": "Upgrade not supported for parent of variants", + "exception": "StudyVariantUpgradeError", + } + + # upgrade the variant study + res = client.put( + f"/v1/studies/{child_uuid}/upgrade", + ) + + # check that the upgrade failed (HttpException:417, with the expected message) + assert res.status_code == 417, res.json() + assert res.json() == { + "description": "Upgrade not supported for variant study", + "exception": "StudyVariantUpgradeError", + } diff --git a/tests/storage/test_service.py b/tests/storage/test_service.py index 5e58f1d1ba..891099dbcc 100644 --- a/tests/storage/test_service.py +++ b/tests/storage/test_service.py @@ -11,7 +11,7 @@ from starlette.responses import Response from antarest.core.config import Config, StorageConfig, WorkspaceConfig -from antarest.core.exceptions import TaskAlreadyRunning +from antarest.core.exceptions import StudyVariantUpgradeError, TaskAlreadyRunning from antarest.core.filetransfer.model import FileDownload, FileDownloadTaskDTO from antarest.core.interfaces.cache import ICache from antarest.core.interfaces.eventbus import Event, EventType @@ -1104,9 +1104,9 @@ def test_delete_with_prefetch(tmp_path: Path) -> None: seal(study_mock) study_metadata_repository.get.return_value = study_mock - variant_study_repository.get_children.return_value = [] + variant_study_repository.has_children.return_value = False - # if this fails, it may means the study metadata mock is missing some attribute definition + # if this fails, it may mean the study metadata mock is missing some attribute definition # this test is here to prevent errors if we add attribute fetching from child classes # (attributes in polymorphism are lazy) # see the comment in the delete method for more information @@ -1133,7 +1133,7 @@ def test_delete_with_prefetch(tmp_path: Path) -> None: seal(study_mock) study_metadata_repository.get.return_value = study_mock - variant_study_repository.get_children.return_value = [] + variant_study_repository.has_children.return_value = False # if this fails, it may means the study metadata mock is missing some definition # this test is here to prevent errors if we add attribute fetching from child classes (attributes in polymorphism are lazy) @@ -1180,8 +1180,7 @@ def create_study_fs_mock(variant: bool = False) -> str: return str(_study_dir) study_path = create_study_fs_mock() - study_mock = Mock( - spec=RawStudy, + study_mock = RawStudy( archived=False, id="my_study", path=study_path, @@ -1191,30 +1190,72 @@ def create_study_fs_mock(variant: bool = False) -> str: workspace=DEFAULT_WORKSPACE_NAME, last_access=datetime.utcnow(), ) - study_mock.to_json_summary.return_value = {"id": "my_study", "name": "foo"} - - # it freezes the mock and raise Attribute error if anything else than defined is used - seal(study_mock) v1 = VariantStudy(id="variant_1", path=create_study_fs_mock(variant=True)) v2 = VariantStudy(id="variant_2", path=create_study_fs_mock(variant=True)) v3 = VariantStudy(id="sub_variant_1", path=create_study_fs_mock(variant=True)) - study_metadata_repository.get.side_effect = [study_mock, v3, v1, v2] - variant_study_repository.get_children.side_effect = [ - [v1, v2], - [v3], - [], - [], - [], - [], - [], - ] - variant_study_repository.get.side_effect = [ - VariantStudy(id="variant_1"), - VariantStudy(id="sub_variant_1"), - VariantStudy(id="variant_2"), - ] + def get_study(study_id: str) -> Study: + if study_id == "my_study": + return study_mock + elif study_id == "variant_1": + return v1 + elif study_id == "variant_2": + return v2 + elif study_id == "sub_variant_1": + return v3 + raise ValueError(f"Unexpected study id: {study_id}") + + class ChildrenProvider: + def __init__(self): + self.c0 = 0 + self.c1 = 0 + + def get_children(self, parent_id: str) -> t.List[Study]: + if parent_id == "my_study": + if self.c0 > 0: + return [] + self.c0 = 1 + return [v1, v2] + elif parent_id == "variant_1": + if self.c1 > 0: + return [] + self.c1 = 1 + return [v3] + elif parent_id == "variant_2": + return [] + elif parent_id == "sub_variant_1": + return [] + raise ValueError(f"Unexpected study id: {parent_id}") + + class HasChildrenProvider: + def __init__(self): + self.c1 = 0 + self.c2 = 0 + + def has_children(self, study_id: str) -> bool: + if study_id == "my_study": + if self.c1 > 0: + return False + self.c1 = 1 + return True + elif study_id == "variant_1": + if self.c2 > 0: + return False + self.c2 = 1 + return True + elif study_id == "variant_2": + return False + elif study_id == "sub_variant_1": + return False + raise ValueError(f"Unexpected study id: {study_id}") + + children_provider = ChildrenProvider() + hash_children_provider = HasChildrenProvider() + study_metadata_repository.get = get_study + variant_study_repository.get = get_study + variant_study_repository.get_children = children_provider.get_children + variant_study_repository.has_children = hash_children_provider.has_children service.delete_study( "my_study", @@ -1594,17 +1635,18 @@ def test_task_upgrade_study(tmp_path: Path) -> None: ) study_mock.name = "my_study" study_mock.to_json_summary.return_value = {"id": "my_study", "name": "foo"} + service.repository.has_children.return_value = False # type: ignore service.repository.get.return_value = study_mock # type: ignore study_id = "my_study" - service.task_service.reset_mock() + service.task_service.reset_mock() # type: ignore service.task_service.list_tasks.side_effect = [ [ TaskDTO( id="1", name=f"Upgrade study my_study ({study_id}) to version 800", status=TaskStatus.RUNNING, - creation_date_utc=str(datetime.utcnow()), + creation_date_utc=str(datetime.utcnow()), # type: ignore type=TaskType.UNARCHIVE, ref_id=study_id, ) @@ -1634,6 +1676,56 @@ def test_task_upgrade_study(tmp_path: Path) -> None: request_params=RequestParameters(user=DEFAULT_ADMIN_USER), ) + # check that a variant study or a raw study with children cannot be upgraded + parent_raw_study = Mock( + spec=RawStudy, + archived=False, + id="parent_raw_study", + name="parent_raw_study", + path=tmp_path, + version="720", + owner=None, + groups=[], + public_mode=PublicMode.NONE, + workspace="other_workspace", + ) + study_mock.name = "parent_raw_study" + study_mock.to_json_summary.return_value = {"id": "parent_raw_study", "name": "parent_raw_study"} + service.repository.has_children.return_value = True # type: ignore + service.repository.get.return_value = parent_raw_study # type: ignore + + with pytest.raises(StudyVariantUpgradeError): + service.upgrade_study( + "parent_raw_study", + target_version="", + params=RequestParameters(user=DEFAULT_ADMIN_USER), + ) + + variant_study = Mock( + spec=VariantStudy, + archived=False, + id="variant_study", + name="variant_study", + path=tmp_path, + version="720", + owner=None, + groups=[], + public_mode=PublicMode.NONE, + workspace="other_workspace", + ) + + study_mock.name = "variant_study" + study_mock.to_json_summary.return_value = {"id": "variant_study", "name": "variant_study"} + service.repository.has_children.return_value = True # type: ignore + service.repository.get.return_value = variant_study # type: ignore + + with pytest.raises(StudyVariantUpgradeError): + service.upgrade_study( + "variant_study", + target_version="", + params=RequestParameters(user=DEFAULT_ADMIN_USER), + ) + @with_db_context @patch("antarest.study.service.upgrade_study")