diff --git a/antarest/study/business/areas/st_storage_management.py b/antarest/study/business/areas/st_storage_management.py index 0afdfe7e34..ca498c030a 100644 --- a/antarest/study/business/areas/st_storage_management.py +++ b/antarest/study/business/areas/st_storage_management.py @@ -26,6 +26,7 @@ from antarest.study.storage.storage_service import StudyStorageService from antarest.study.storage.variantstudy.model.command.create_st_storage import CreateSTStorage from antarest.study.storage.variantstudy.model.command.remove_st_storage import RemoveSTStorage +from antarest.study.storage.variantstudy.model.command.replace_matrix import ReplaceMatrix from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig __all__ = ( @@ -74,8 +75,8 @@ def validate_name(cls, name: t.Optional[str]) -> str: raise ValueError("'name' must not be empty") return name - @property - def to_config(self) -> STStorageConfig: + # noinspection PyUnusedLocal + def to_config(self, study_version: t.Union[str, int]) -> STStorageConfig: values = self.dict(by_alias=False, exclude_none=True) return STStorageConfig(**values) @@ -259,21 +260,25 @@ def create_storage( Returns: The ID of the newly created short-term storage. """ - storage = form.to_config - command = CreateSTStorage( - area_id=area_id, - parameters=storage, - command_context=self.storage_service.variant_study_service.command_factory.command_context, - ) file_study = self._get_file_study(study) + storage = form.to_config(study.version) + command = self._make_create_cluster_cmd(area_id, storage) execute_or_add_commands( study, file_study, [command], self.storage_service, ) + output = self.get_storage(study, area_id, storage_id=storage.id) + return output - return self.get_storage(study, area_id, storage_id=storage.id) + def _make_create_cluster_cmd(self, area_id: str, cluster: STStorageConfig) -> CreateSTStorage: + command = CreateSTStorage( + area_id=area_id, + parameters=cluster, + command_context=self.storage_service.variant_study_service.command_factory.command_context, + ) + return command def get_storages( self, @@ -420,7 +425,7 @@ def delete_storages( file_study = self._get_file_study(study) execute_or_add_commands(study, file_study, [command], self.storage_service) - def duplicate_cluster(self, study: Study, area_id: str, source_id: str, new_name: str) -> STStorageOutput: + def duplicate_cluster(self, study: Study, area_id: str, source_id: str, new_cluster_name: str) -> STStorageOutput: """ Creates a duplicate cluster within the study area with a new name. @@ -428,7 +433,7 @@ def duplicate_cluster(self, study: Study, area_id: str, source_id: str, new_name study: The study in which the cluster will be duplicated. area_id: The identifier of the area where the cluster will be duplicated. source_id: The identifier of the cluster to be duplicated. - new_name: The new name for the duplicated cluster. + new_cluster_name: The new name for the duplicated cluster. Returns: The duplicated cluster configuration. @@ -436,22 +441,42 @@ def duplicate_cluster(self, study: Study, area_id: str, source_id: str, new_name Raises: ClusterAlreadyExists: If a cluster with the new name already exists in the area. """ - new_id = transform_name_to_id(new_name) - if any(new_id.lower() == storage.id.lower() for storage in self.get_storages(study, area_id)): + new_id = transform_name_to_id(new_cluster_name) + lower_new_id = new_id.lower() + if any(lower_new_id == storage.id.lower() for storage in self.get_storages(study, area_id)): raise ClusterAlreadyExists("Short-term storage", new_id) # Cluster duplication current_cluster = self.get_storage(study, area_id, source_id) - current_cluster.name = new_name + current_cluster.name = new_cluster_name creation_form = STStorageCreation(**current_cluster.dict(by_alias=False, exclude={"id"})) - new_storage = self.create_storage(study, area_id, creation_form) + new_config = creation_form.to_config(study.version) + create_cluster_cmd = self._make_create_cluster_cmd(area_id, new_config) # Matrix edition - for ts_name in STStorageTimeSeries.__args__: # type: ignore - ts = self.get_matrix(study, area_id, source_id, ts_name) - self.update_matrix(study, area_id, new_id.lower(), ts_name, ts) + lower_source_id = source_id.lower() + ts_names = ["pmax_injection", "pmax_withdrawal", "lower_rule_curve", "upper_rule_curve", "inflows"] + source_paths = [ + STORAGE_SERIES_PATH.format(area_id=area_id, storage_id=lower_source_id, ts_name=ts_name) + for ts_name in ts_names + ] + new_paths = [ + STORAGE_SERIES_PATH.format(area_id=area_id, storage_id=lower_new_id, ts_name=ts_name) + for ts_name in ts_names + ] + + # Prepare and execute commands + commands: t.List[t.Union[CreateSTStorage, ReplaceMatrix]] = [create_cluster_cmd] + storage_service = self.storage_service.get_storage(study) + command_context = self.storage_service.variant_study_service.command_factory.command_context + for source_path, new_path in zip(source_paths, new_paths): + current_matrix = storage_service.get(study, source_path)["data"] + command = ReplaceMatrix(target=new_path, matrix=current_matrix, command_context=command_context) + commands.append(command) - return new_storage + execute_or_add_commands(study, self._get_file_study(study), commands, self.storage_service) + + return STStorageOutput(**new_config.dict(by_alias=False)) def get_matrix( self, @@ -519,12 +544,11 @@ def _save_matrix_obj( ts_name: STStorageTimeSeries, matrix_obj: t.Dict[str, t.Any], ) -> None: - file_study = self._get_file_study(study) path = STORAGE_SERIES_PATH.format(area_id=area_id, storage_id=storage_id, ts_name=ts_name) - try: - file_study.tree.save(matrix_obj, path.split("/")) - except KeyError: - raise STStorageMatrixNotFoundError(study.id, area_id, storage_id, ts_name) from None + matrix = matrix_obj["data"] + command_context = self.storage_service.variant_study_service.command_factory.command_context + command = ReplaceMatrix(target=path, matrix=matrix, command_context=command_context) + execute_or_add_commands(study, self._get_file_study(study), [command], self.storage_service) def validate_matrices( self, diff --git a/tests/integration/study_data_blueprint/test_st_storage.py b/tests/integration/study_data_blueprint/test_st_storage.py index 322c4669ae..5f2421d911 100644 --- a/tests/integration/study_data_blueprint/test_st_storage.py +++ b/tests/integration/study_data_blueprint/test_st_storage.py @@ -1,5 +1,6 @@ import json import re +import typing as t from unittest.mock import ANY import numpy as np @@ -683,3 +684,146 @@ def test__default_values( "initiallevel": 0.0, } assert actual == expected + + @pytest.fixture(name="base_study_id") + def base_study_id_fixture(self, request: t.Any, client: TestClient, user_access_token: str) -> str: + """Prepare a managed study for the variant study tests.""" + params = request.param + res = client.post( + "/v1/studies", + headers={"Authorization": f"Bearer {user_access_token}"}, + params=params, + ) + assert res.status_code in {200, 201}, res.json() + study_id: str = res.json() + return study_id + + @pytest.fixture(name="variant_id") + def variant_id_fixture(self, request: t.Any, client: TestClient, user_access_token: str, base_study_id: str) -> str: + """Prepare a variant study for the variant study tests.""" + name = request.param + res = client.post( + f"/v1/studies/{base_study_id}/variants", + headers={"Authorization": f"Bearer {user_access_token}"}, + params={"name": name}, + ) + assert res.status_code in {200, 201}, res.json() + study_id: str = res.json() + return study_id + + # noinspection PyTestParametrized + @pytest.mark.parametrize("base_study_id", [{"name": "Base Study", "version": 860}], indirect=True) + @pytest.mark.parametrize("variant_id", ["Variant Study"], indirect=True) + def test_variant_lifecycle(self, client: TestClient, user_access_token: str, variant_id: str) -> None: + """ + In this test, we want to check that short-term storages can be managed + in the context of a "variant" study. + """ + # Create an area + area_name = "France" + res = client.post( + f"/v1/studies/{variant_id}/areas", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"name": area_name, "type": "AREA"}, + ) + assert res.status_code in {200, 201}, res.json() + area_cfg = res.json() + area_id = area_cfg["id"] + + # Create a short-term storage + cluster_name = "Tesla1" + res = client.post( + f"/v1/studies/{variant_id}/areas/{area_id}/storages", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={ + "name": cluster_name, + "group": "Battery", + "injectionNominalCapacity": 4500, + "withdrawalNominalCapacity": 4230, + "reservoirCapacity": 5700, + }, + ) + assert res.status_code in {200, 201}, res.json() + cluster_id: str = res.json()["id"] + + # Update the short-term storage + res = client.patch( + f"/v1/studies/{variant_id}/areas/{area_id}/storages/{cluster_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"reservoirCapacity": 5600}, + ) + assert res.status_code == 200, res.json() + cluster_cfg = res.json() + assert cluster_cfg["reservoirCapacity"] == 5600 + + # Update the series matrix + matrix = np.random.randint(0, 2, size=(8760, 1)).tolist() + matrix_path = f"input/st-storage/series/{area_id}/{cluster_id.lower()}/pmax_injection" + args = {"target": matrix_path, "matrix": matrix} + res = client.post( + f"/v1/studies/{variant_id}/commands", + json=[{"action": "replace_matrix", "args": args}], + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code in {200, 201}, res.json() + + # Duplicate the short-term storage + new_name = "Tesla2" + res = client.post( + f"/v1/studies/{variant_id}/areas/{area_id}/storages/{cluster_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + params={"newName": new_name}, + ) + assert res.status_code in {200, 201}, res.json() + cluster_cfg = res.json() + assert cluster_cfg["name"] == new_name + new_id = cluster_cfg["id"] + + # Check that the duplicate has the right properties + res = client.get( + f"/v1/studies/{variant_id}/areas/{area_id}/storages/{new_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == 200, res.json() + cluster_cfg = res.json() + assert cluster_cfg["group"] == "Battery" + assert cluster_cfg["injectionNominalCapacity"] == 4500 + assert cluster_cfg["withdrawalNominalCapacity"] == 4230 + assert cluster_cfg["reservoirCapacity"] == 5600 + + # Check that the duplicate has the right matrix + new_cluster_matrix_path = f"input/st-storage/series/{area_id}/{new_id.lower()}/pmax_injection" + res = client.get( + f"/v1/studies/{variant_id}/raw", + params={"path": new_cluster_matrix_path}, + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == 200 + assert res.json()["data"] == matrix + + # Delete the short-term storage + res = client.delete( + f"/v1/studies/{variant_id}/areas/{area_id}/storages", + headers={"Authorization": f"Bearer {user_access_token}"}, + json=[cluster_id], + ) + assert res.status_code == 204, res.json() + + # Check the list of variant commands + res = client.get( + f"/v1/studies/{variant_id}/commands", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == 200, res.json() + commands = res.json() + assert len(commands) == 7 + actions = [command["action"] for command in commands] + assert actions == [ + "create_area", + "create_st_storage", + "update_config", + "replace_matrix", + "create_st_storage", + "replace_matrix", + "remove_st_storage", + ]