diff --git a/antarest/core/exceptions.py b/antarest/core/exceptions.py index cada3a5f5d..ab081c267b 100644 --- a/antarest/core/exceptions.py +++ b/antarest/core/exceptions.py @@ -41,6 +41,19 @@ def __str__(self) -> str: return self.detail +class STStorageNotFoundError(HTTPException): + def __init__(self, message: str) -> None: + super().__init__(HTTPStatus.NOT_FOUND, message) + + +class DuplicateSTStorageId(HTTPException): + """Exception raised when trying to create a short term storage with an already existing id.""" + + def __init__(self, st_storage_id: str) -> None: + msg = f"Short term storage '{st_storage_id}' already exists and could not be created" + super().__init__(HTTPStatus.CONFLICT, msg) + + class UnknownModuleError(Exception): def __init__(self, message: str) -> None: super(UnknownModuleError, self).__init__(message) diff --git a/antarest/study/business/areas/st_storage_management.py b/antarest/study/business/areas/st_storage_management.py index d18dce9f9c..47dbaa41bd 100644 --- a/antarest/study/business/areas/st_storage_management.py +++ b/antarest/study/business/areas/st_storage_management.py @@ -8,9 +8,12 @@ from typing_extensions import Literal from antarest.core.exceptions import ( + AreaNotFound, + DuplicateSTStorageId, STStorageConfigNotFoundError, STStorageFieldsNotFoundError, STStorageMatrixNotFoundError, + STStorageNotFoundError, ) from antarest.study.business.utils import AllOptionalMetaclass, camel_case_model, execute_or_add_commands from antarest.study.model import Study @@ -24,6 +27,7 @@ from antarest.study.storage.storage_service import StudyStorageService from antarest.study.storage.variantstudy.model.command.create_st_storage import CreateSTStorage from antarest.study.storage.variantstudy.model.command.remove_st_storage import RemoveSTStorage +from antarest.study.storage.variantstudy.model.command.replace_matrix import ReplaceMatrix from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig __all__ = ( @@ -258,6 +262,10 @@ def create_storage( The ID of the newly created short-term storage. """ storage = form.to_config + + file_study = self._get_file_study(study) + _check_creation_feasibility(file_study, area_id, storage.id) + command = CreateSTStorage( area_id=area_id, parameters=storage, @@ -350,18 +358,11 @@ def update_storage( """ study_version = study.version - # review: reading the configuration poses a problem for variants, - # because it requires generating a snapshot, which takes time. - # This reading could be avoided if we don't need the previous values - # (no cross-field validation, no default values, etc.). - # In return, we won't be able to return a complete `STStorageOutput` object. - # So, we need to make sure the frontend doesn't need the missing fields. - # This missing information could also be a problem for the API users. - # The solution would be to avoid reading the configuration if the study is a variant - # (we then use the default values), otherwise, for a RAW study, we read the configuration - # and update the modified values. + # For variants, this method requires generating a snapshot, which takes time. + # But sadly, there's no other way to prevent creating wrong commands. file_study = self._get_file_study(study) + _check_update_feasibility(file_study, area_id, storage_id) path = STORAGE_LIST_PATH.format(area_id=area_id, storage_id=storage_id) try: @@ -408,6 +409,9 @@ def delete_storages( area_id: The area ID of the short-term storage. storage_ids: IDs list of short-term storages to remove. 
""" + file_study = self._get_file_study(study) + _check_deletion_feasibility(file_study, area_id, storage_ids) + command_context = self.storage_service.variant_study_service.command_factory.command_context for storage_id in storage_ids: command = RemoveSTStorage( @@ -415,7 +419,6 @@ def delete_storages( storage_id=storage_id, command_context=command_context, ) - file_study = self._get_file_study(study) execute_or_add_commands(study, file_study, [command], self.storage_service) def get_matrix( @@ -473,8 +476,7 @@ def update_matrix( ts_name: Name of the time series to update. ts: Matrix of the time series to update. """ - matrix_object = ts.dict() - self._save_matrix_obj(study, area_id, storage_id, ts_name, matrix_object) + self._save_matrix_obj(study, area_id, storage_id, ts_name, ts.data) def _save_matrix_obj( self, @@ -482,14 +484,13 @@ def _save_matrix_obj( area_id: str, storage_id: str, ts_name: STStorageTimeSeries, - matrix_obj: t.Dict[str, t.Any], + matrix_data: t.List[t.List[float]], ) -> None: file_study = self._get_file_study(study) + command_context = self.storage_service.variant_study_service.command_factory.command_context path = STORAGE_SERIES_PATH.format(area_id=area_id, storage_id=storage_id, ts_name=ts_name) - try: - file_study.tree.save(matrix_obj, path.split("/")) - except KeyError: - raise STStorageMatrixNotFoundError(study.id, area_id, storage_id, ts_name) from None + command = ReplaceMatrix(target=path, matrix=matrix_data, command_context=command_context) + execute_or_add_commands(study, file_study, [command], self.storage_service) def validate_matrices( self, @@ -534,3 +535,29 @@ def validate_matrices( # Validation successful return True + + +def _get_existing_storage_ids(file_study: FileStudy, area_id: str) -> t.List[str]: + areas = file_study.config.areas + if area_id not in areas: + raise AreaNotFound(f"Area {area_id} does not exist in study") + return [existing_storage.id for existing_storage in areas.get(area_id).st_storages] # type: ignore + + +def _check_deletion_feasibility(file_study: FileStudy, area_id: str, storage_ids: t.Sequence[str]) -> None: + existing_ids = _get_existing_storage_ids(file_study, area_id) + for storage_id in storage_ids: + if storage_id not in existing_ids: + raise STStorageNotFoundError(f"Short term storage {storage_id} does not exist in study") + + +def _check_update_feasibility(file_study: FileStudy, area_id: str, storage_id: str) -> None: + existing_ids = _get_existing_storage_ids(file_study, area_id) + if storage_id not in existing_ids: + raise STStorageNotFoundError(f"Short term storage {storage_id} does not exist in study") + + +def _check_creation_feasibility(file_study: FileStudy, area_id: str, storage_id: str) -> None: + existing_ids = _get_existing_storage_ids(file_study, area_id) + if storage_id in existing_ids: + raise DuplicateSTStorageId(storage_id) diff --git a/tests/integration/study_data_blueprint/test_st_storage.py b/tests/integration/study_data_blueprint/test_st_storage.py index fdffe5efe1..ccb762919c 100644 --- a/tests/integration/study_data_blueprint/test_st_storage.py +++ b/tests/integration/study_data_blueprint/test_st_storage.py @@ -28,11 +28,9 @@ class TestSTStorage: which contains the following areas: ["de", "es", "fr", "it"]. 
""" + @pytest.mark.parametrize("study_type", ["raw", "variant"]) def test_lifecycle__nominal( - self, - client: TestClient, - user_access_token: str, - study_id: str, + self, client: TestClient, user_access_token: str, study_id: str, study_type: str ) -> None: """ The purpose of this integration test is to test the endpoints @@ -58,10 +56,15 @@ def test_lifecycle__nominal( We will test the deletion of short-term storages. """ + # ============================= + # SET UP + # ============================= + user_headers = {"Authorization": f"Bearer {user_access_token}"} + # Upgrade study to version 860 res = client.put( f"/v1/studies/{study_id}/upgrade", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, params={"target_version": 860}, ) res.raise_for_status() @@ -69,6 +72,25 @@ def test_lifecycle__nominal( task = wait_task_completion(client, user_access_token, task_id) assert task.status == TaskStatus.COMPLETED, task + # Copies the study, to convert it into a managed one. + res = client.post( + f"/v1/studies/{study_id}/copy", + headers={"Authorization": f"Bearer {user_access_token}"}, + params={"dest": "default", "with_outputs": False, "use_task": False}, # type: ignore + ) + assert res.status_code == 201, res.json() + study_id = res.json() + + if study_type == "variant": + # Create Variant + res = client.post( + f"/v1/studies/{study_id}/variants", + headers=user_headers, + params={"name": "Variant 1"}, + ) + assert res.status_code in {200, 201}, res.json() + study_id = res.json() + # ============================= # SHORT-TERM STORAGE CREATION # ============================= @@ -84,7 +106,7 @@ def test_lifecycle__nominal( for attempt in attempts: res = client.post( f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json=attempt, ) assert res.status_code == 422, res.json() @@ -101,7 +123,7 @@ def test_lifecycle__nominal( } res = client.post( f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json=siemens_properties, ) assert res.status_code == 200, res.json() @@ -113,7 +135,7 @@ def test_lifecycle__nominal( # reading the properties of a short-term storage res = client.get( f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) assert res.status_code == 200, res.json() assert res.json() == siemens_config @@ -126,7 +148,7 @@ def test_lifecycle__nominal( array = np.random.rand(8760, 1) * 1000 res = client.put( f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}/series/inflows", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={ "index": list(range(array.shape[0])), "columns": list(range(array.shape[1])), @@ -139,7 +161,7 @@ def test_lifecycle__nominal( # reading the matrix of a short-term storage res = client.get( f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}/series/inflows", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) assert res.status_code == 200, res.json() matrix = res.json() @@ -149,7 +171,7 @@ def test_lifecycle__nominal( # validating the matrices of a short-term storage res = client.get( f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}/validate", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) assert 
res.status_code == 200, res.json() assert res.json() is True @@ -161,7 +183,7 @@ def test_lifecycle__nominal( # Reading the list of short-term storages res = client.get( f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) assert res.status_code == 200, res.json() assert res.json() == [siemens_config] @@ -169,7 +191,7 @@ def test_lifecycle__nominal( # updating properties res = client.patch( f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={ "name": "New Siemens Battery", "reservoirCapacity": 2500, @@ -185,7 +207,7 @@ def test_lifecycle__nominal( res = client.get( f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) assert res.status_code == 200, res.json() assert res.json() == siemens_config @@ -197,7 +219,7 @@ def test_lifecycle__nominal( # updating properties res = client.patch( f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={ "initialLevel": 0.59, "reservoirCapacity": 0, @@ -217,7 +239,7 @@ def test_lifecycle__nominal( bad_properties = {"efficiency": 2.0} res = client.patch( f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json=bad_properties, ) assert res.status_code == 422, res.json() @@ -226,7 +248,7 @@ def test_lifecycle__nominal( # The short-term storage properties should not have been updated. res = client.get( f"/v1/studies/{study_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) assert res.status_code == 200, res.json() assert res.json() == siemens_config @@ -239,7 +261,7 @@ def test_lifecycle__nominal( res = client.request( "DELETE", f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json=[siemens_battery_id], ) assert res.status_code == 204, res.json() @@ -249,7 +271,7 @@ def test_lifecycle__nominal( res = client.request( "DELETE", f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json=[], ) assert res.status_code == 204, res.json() @@ -269,7 +291,7 @@ def test_lifecycle__nominal( } res = client.post( f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json=siemens_properties, ) assert res.status_code == 200, res.json() @@ -288,7 +310,7 @@ def test_lifecycle__nominal( } res = client.post( f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json=grand_maison_properties, ) assert res.status_code == 200, res.json() @@ -298,7 +320,7 @@ def test_lifecycle__nominal( # Reading the list of short-term storages res = client.get( f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) assert res.status_code == 200, res.json() siemens_config = {**DEFAULT_PROPERTIES, **siemens_properties, "id": siemens_battery_id} @@ -309,7 +331,7 @@ def test_lifecycle__nominal( res = 
client.request( "DELETE", f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json=[siemens_battery_id, grand_maison_id], ) assert res.status_code == 204, res.json() @@ -318,7 +340,7 @@ def test_lifecycle__nominal( # The list of short-term storages should be empty. res = client.get( f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) assert res.status_code == 200, res.json() assert res.json() == [] @@ -332,25 +354,20 @@ def test_lifecycle__nominal( res = client.request( "DELETE", f"/v1/studies/{study_id}/areas/{bad_area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json=[siemens_battery_id], ) - assert res.status_code == 500, res.json() + assert res.status_code == 404 obj = res.json() - description = obj["description"] - assert bad_area_id in description - assert re.search( - r"CommandName.REMOVE_ST_STORAGE", - description, - flags=re.IGNORECASE, - ) + assert obj["description"] == f"Area is not found: 'Area {bad_area_id} does not exist in study'" + assert obj["exception"] == "AreaNotFound" # Check delete with the wrong value of `study_id` bad_study_id = "bad_study" res = client.request( "DELETE", f"/v1/studies/{bad_study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json=[siemens_battery_id], ) obj = res.json() @@ -361,7 +378,7 @@ def test_lifecycle__nominal( # Check get with wrong `area_id` res = client.get( f"/v1/studies/{study_id}/areas/{bad_area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) obj = res.json() description = obj["description"] @@ -371,7 +388,7 @@ def test_lifecycle__nominal( # Check get with wrong `study_id` res = client.get( f"/v1/studies/{bad_study_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) obj = res.json() description = obj["description"] @@ -381,7 +398,7 @@ def test_lifecycle__nominal( # Check POST with wrong `study_id` res = client.post( f"/v1/studies/{bad_study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={"name": siemens_battery, "group": "Battery"}, ) obj = res.json() @@ -392,20 +409,18 @@ def test_lifecycle__nominal( # Check POST with wrong `area_id` res = client.post( f"/v1/studies/{study_id}/areas/{bad_area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={"name": siemens_battery, "group": "Battery"}, ) - assert res.status_code == 500, res.json() + assert res.status_code == 404 obj = res.json() - description = obj["description"] - assert bad_area_id in description - assert re.search(r"Area ", description, flags=re.IGNORECASE) - assert re.search(r"does not exist ", description, flags=re.IGNORECASE) + assert obj["description"] == f"Area is not found: 'Area {bad_area_id} does not exist in study'" + assert obj["exception"] == "AreaNotFound" # Check POST with wrong `group` res = client.post( f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={"name": siemens_battery, "group": "GroupFoo"}, ) assert res.status_code == 422, res.json() @@ -416,33 +431,30 @@ def test_lifecycle__nominal( # Check PATCH 
with the wrong `area_id` res = client.patch( f"/v1/studies/{study_id}/areas/{bad_area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={"efficiency": 1.0}, ) - assert res.status_code == 404, res.json() + assert res.status_code == 404 obj = res.json() - description = obj["description"] - assert bad_area_id in description - assert re.search(r"not a child of ", description, flags=re.IGNORECASE) + assert obj["description"] == f"Area is not found: 'Area {bad_area_id} does not exist in study'" + assert obj["exception"] == "AreaNotFound" # Check PATCH with the wrong `storage_id` bad_storage_id = "bad_storage" res = client.patch( f"/v1/studies/{study_id}/areas/{area_id}/storages/{bad_storage_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={"efficiency": 1.0}, ) - assert res.status_code == 404, res.json() + assert res.status_code == 404 obj = res.json() - description = obj["description"] - assert bad_storage_id in description - assert re.search(r"fields of storage", description, flags=re.IGNORECASE) - assert re.search(r"not found", description, flags=re.IGNORECASE) + assert obj["description"] == f"Short term storage {bad_storage_id} does not exist in study" + assert obj["exception"] == "STStorageNotFoundError" # Check PATCH with the wrong `study_id` res = client.patch( f"/v1/studies/{bad_study_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={"efficiency": 1.0}, ) assert res.status_code == 404, res.json() @@ -450,11 +462,8 @@ def test_lifecycle__nominal( description = obj["description"] assert bad_study_id in description - def test__default_values( - self, - client: TestClient, - user_access_token: str, - ) -> None: + @pytest.mark.parametrize("study_type", ["raw", "variant"]) + def test__default_values(self, client: TestClient, user_access_token: str, study_type: str) -> None: """ The purpose of this integration test is to test the default values of the properties of a short-term storage. @@ -464,18 +473,29 @@ def test__default_values( Then the short-term storage is created with initialLevel = 0.0, and initialLevelOptim = False. """ # Create a new study in version 860 (or higher) + user_headers = {"Authorization": f"Bearer {user_access_token}"} res = client.post( "/v1/studies", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, params={"name": "MyStudy", "version": 860}, ) assert res.status_code in {200, 201}, res.json() study_id = res.json() + if study_type == "variant": + # Create Variant + res = client.post( + f"/v1/studies/{study_id}/variants", + headers=user_headers, + params={"name": "Variant 1"}, + ) + assert res.status_code in {200, 201}, res.json() + study_id = res.json() + # Create a new area named "FR" res = client.post( f"/v1/studies/{study_id}/areas", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={"name": "FR", "type": "AREA"}, ) assert res.status_code in {200, 201}, res.json() @@ -485,7 +505,7 @@ def test__default_values( tesla_battery = "Tesla Battery" res = client.post( f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={"name": tesla_battery, "group": "Battery"}, ) assert res.status_code == 200, res.json() @@ -497,7 +517,7 @@ def test__default_values( # are properly set in the configuration file. 
res = client.get( f"/v1/studies/{study_id}/raw", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, params={"path": f"input/st-storage/clusters/{area_id}/list/{tesla_battery_id}"}, ) assert res.status_code == 200, res.json() @@ -512,7 +532,7 @@ def test__default_values( # Create a variant of the study res = client.post( f"/v1/studies/{study_id}/variants", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, params={"name": "MyVariant"}, ) assert res.status_code in {200, 201}, res.json() @@ -522,7 +542,7 @@ def test__default_values( siemens_battery = "Siemens Battery" res = client.post( f"/v1/studies/{variant_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={"name": siemens_battery, "group": "Battery"}, ) assert res.status_code == 200, res.json() @@ -530,7 +550,7 @@ def test__default_values( # Check the variant commands res = client.get( f"/v1/studies/{variant_id}/commands", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) assert res.status_code == 200, res.json() commands = res.json() @@ -556,7 +576,7 @@ def test__default_values( siemens_battery_id = transform_name_to_id(siemens_battery) res = client.patch( f"/v1/studies/{variant_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={"initialLevel": 0.5}, ) assert res.status_code == 200, res.json() @@ -564,7 +584,7 @@ def test__default_values( # Check the variant commands res = client.get( f"/v1/studies/{variant_id}/commands", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) assert res.status_code == 200, res.json() commands = res.json() @@ -584,7 +604,7 @@ def test__default_values( # Update the initialLevel property of the "Siemens Battery" short-term storage back to 0 res = client.patch( f"/v1/studies/{variant_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, json={"initialLevel": 0.0, "injectionNominalCapacity": 1600}, ) assert res.status_code == 200, res.json() @@ -592,7 +612,7 @@ def test__default_values( # Check the variant commands res = client.get( f"/v1/studies/{variant_id}/commands", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, ) assert res.status_code == 200, res.json() commands = res.json() @@ -619,7 +639,7 @@ def test__default_values( # are properly set in the configuration file. 
res = client.get( f"/v1/studies/{variant_id}/raw", - headers={"Authorization": f"Bearer {user_access_token}"}, + headers=user_headers, params={"path": f"input/st-storage/clusters/{area_id}/list/{siemens_battery_id}"}, ) assert res.status_code == 200, res.json() diff --git a/tests/study/business/areas/test_st_storage_management.py b/tests/study/business/areas/test_st_storage_management.py index 646dc26c78..c86ea8dea5 100644 --- a/tests/study/business/areas/test_st_storage_management.py +++ b/tests/study/business/areas/test_st_storage_management.py @@ -3,7 +3,7 @@ import re import uuid from typing import Any, MutableMapping, Sequence, cast -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock import numpy as np import pytest @@ -11,16 +11,19 @@ from sqlalchemy.orm.session import Session # type: ignore from antarest.core.exceptions import ( + AreaNotFound, STStorageConfigNotFoundError, STStorageFieldsNotFoundError, STStorageMatrixNotFoundError, + STStorageNotFoundError, ) from antarest.core.model import PublicMode from antarest.login.model import Group, User from antarest.study.business.areas.st_storage_management import STStorageInput, STStorageManager from antarest.study.model import RawStudy, Study, StudyContentStatus from antarest.study.storage.rawstudy.ini_reader import IniReader -from antarest.study.storage.rawstudy.model.filesystem.config.st_storage import STStorageGroup +from antarest.study.storage.rawstudy.model.filesystem.config.model import Area, FileStudyTreeConfig +from antarest.study.storage.rawstudy.model.filesystem.config.st_storage import STStorageConfig, STStorageGroup from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy from antarest.study.storage.rawstudy.model.filesystem.ini_file_node import IniFileNode from antarest.study.storage.rawstudy.model.filesystem.root.filestudytree import FileStudyTree @@ -287,11 +290,27 @@ def test_update_storage__nominal_case( get_node=Mock(return_value=ini_file_node), ) + area = Mock(spec=Area) + mock_config = Mock(spec=FileStudyTreeConfig) + file_study.config = mock_config + # Given the following arguments manager = STStorageManager(study_storage_service) - - # Run the method being tested edit_form = STStorageInput(initial_level=0, initial_level_optim=False) + + # Test behavior for area not in study + mock_config.areas = {"fake_area": area} + with pytest.raises(AreaNotFound, match=f"Area West does not exist in study"): + manager.update_storage(study, area_id="West", storage_id="storage1", form=edit_form) + + # Test behavior for st_storage not in study + mock_config.areas = {"West": area} + area.st_storages = [STStorageConfig(name="fake_name", group="battery")] + with pytest.raises(STStorageNotFoundError, match=f"Short term storage storage1 does not exist in study"): + manager.update_storage(study, area_id="West", storage_id="storage1", form=edit_form) + + # Test behavior for nominal case + area.st_storages = [STStorageConfig(name="storage1", group="battery")] manager.update_storage(study, area_id="West", storage_id="storage1", form=edit_form) # Assert that the storage fields have been updated
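
The feasibility helpers added to st_storage_management.py are small enough to reason about in isolation. Below is a minimal standalone sketch that mirrors their logic; FakeArea, FakeStorage and the plain Exception subclasses are illustrative stand-ins (not the real antarest classes), so the snippet runs without the antarest package:

    from dataclasses import dataclass, field
    from typing import Dict, List, Sequence


    class AreaNotFound(Exception):
        """Stand-in for antarest.core.exceptions.AreaNotFound (404 in the API)."""


    class STStorageNotFoundError(Exception):
        """Stand-in for the new STStorageNotFoundError (404 in the API)."""


    class DuplicateSTStorageId(Exception):
        """Stand-in for the new DuplicateSTStorageId (409 in the API)."""


    @dataclass
    class FakeStorage:
        id: str


    @dataclass
    class FakeArea:
        st_storages: List[FakeStorage] = field(default_factory=list)


    def _existing_storage_ids(areas: Dict[str, FakeArea], area_id: str) -> List[str]:
        # Same shape as _get_existing_storage_ids: unknown area -> AreaNotFound.
        if area_id not in areas:
            raise AreaNotFound(f"Area {area_id} does not exist in study")
        return [storage.id for storage in areas[area_id].st_storages]


    def check_creation(areas: Dict[str, FakeArea], area_id: str, storage_id: str) -> None:
        # Creating a storage whose id already exists is a conflict.
        if storage_id in _existing_storage_ids(areas, area_id):
            raise DuplicateSTStorageId(storage_id)


    def check_deletion(areas: Dict[str, FakeArea], area_id: str, storage_ids: Sequence[str]) -> None:
        # Updating or deleting requires every target id to exist.
        existing = _existing_storage_ids(areas, area_id)
        for storage_id in storage_ids:
            if storage_id not in existing:
                raise STStorageNotFoundError(storage_id)


    areas = {"fr": FakeArea(st_storages=[FakeStorage("siemens battery")])}
    check_creation(areas, "fr", "grand maison")         # OK: id is new
    check_deletion(areas, "fr", ["siemens battery"])    # OK: id exists
    try:
        check_creation(areas, "fr", "siemens battery")  # duplicate id
    except DuplicateSTStorageId as exc:
        print(f"rejected duplicate: {exc}")

Each public operation in STStorageManager now runs the matching check against file_study.config before building a command, so an invalid request fails fast with a 404 (AreaNotFound, STStorageNotFoundError) or 409 (DuplicateSTStorageId) instead of enqueuing a command that would fail later.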
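
The integration tests above exercise the new 404 paths but not the duplicate-creation path. A hypothetical follow-up test, not part of this diff, could slot into the TestSTStorage class; it reuses the same client / user_access_token / study_id fixtures and assumes the study has already been upgraded to version 860, as in the set-up of test_lifecycle__nominal:

    def test_create_duplicate_storage(
        self, client: TestClient, user_access_token: str, study_id: str
    ) -> None:
        user_headers = {"Authorization": f"Bearer {user_access_token}"}
        area_id = "fr"
        siemens_properties = {"name": "Siemens Battery", "group": "Battery"}

        # First creation succeeds.
        res = client.post(
            f"/v1/studies/{study_id}/areas/{area_id}/storages",
            headers=user_headers,
            json=siemens_properties,
        )
        assert res.status_code == 200, res.json()

        # Creating the same name again maps to the same id and should now be
        # rejected with 409 CONFLICT by DuplicateSTStorageId.
        res = client.post(
            f"/v1/studies/{study_id}/areas/{area_id}/storages",
            headers=user_headers,
            json=siemens_properties,
        )
        assert res.status_code == 409, res.json()
        obj = res.json()
        assert obj["exception"] == "DuplicateSTStorageId"
        assert "already exists" in obj["description"]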