diff --git a/antarest/core/exceptions.py b/antarest/core/exceptions.py
index 3414d6477b..9a2230c1d1 100644
--- a/antarest/core/exceptions.py
+++ b/antarest/core/exceptions.py
@@ -174,7 +174,7 @@ def __init__(self, message: str) -> None:
         super().__init__(HTTPStatus.BAD_REQUEST, message)
 
 
-class NoBindingConstraintError(HTTPException):
+class BindingConstraintNotFoundError(HTTPException):
     def __init__(self, message: str) -> None:
         super().__init__(HTTPStatus.NOT_FOUND, message)
 
diff --git a/antarest/study/business/binding_constraint_management.py b/antarest/study/business/binding_constraint_management.py
index 4816666c66..55112da1da 100644
--- a/antarest/study/business/binding_constraint_management.py
+++ b/antarest/study/business/binding_constraint_management.py
@@ -1,14 +1,14 @@
 from typing import Any, Dict, List, Optional, Union
 
-from pydantic import BaseModel
+from pydantic import BaseModel, validator
 
 from antarest.core.exceptions import (
+    BindingConstraintNotFoundError,
     ConstraintAlreadyExistError,
     ConstraintIdNotFoundError,
     DuplicateConstraintName,
     InvalidConstraintName,
     MissingDataError,
-    NoBindingConstraintError,
     NoConstraintError,
 )
 from antarest.matrixstore.model import MatrixData
@@ -29,21 +29,72 @@
 from antarest.study.storage.variantstudy.model.command.update_binding_constraint import UpdateBindingConstraint
 
 
-class LinkInfoDTO(BaseModel):
+class AreaLinkDTO(BaseModel):
+    """
+    DTO for a constraint term on a link between two areas.
+
+    Attributes:
+        area1: the first area ID
+        area2: the second area ID
+    """
+
     area1: str
     area2: str
 
+    def generate_id(self) -> str:
+        """Return the constraint term ID for this link, of the form "area1%area2"."""
+        # Ensure IDs are in alphabetical order and lower case
+        ids = sorted((self.area1.lower(), self.area2.lower()))
+        return "%".join(ids)
+
+
+class AreaClusterDTO(BaseModel):
+    """
+    DTO for a constraint term on a cluster in an area.
+
+    Attributes:
+        area: the area ID
+        cluster: the cluster ID
+    """
 
-class ClusterInfoDTO(BaseModel):
     area: str
     cluster: str
 
+    def generate_id(self) -> str:
+        """Return the constraint term ID for this Area/cluster constraint, of the form "area.cluster"."""
+        # Ensure IDs are in lower case
+        ids = [self.area.lower(), self.cluster.lower()]
+        return ".".join(ids)
+
 
 class ConstraintTermDTO(BaseModel):
+    """
+    DTO for a constraint term.
+
+    Attributes:
+        id: the constraint term ID, of the form "area1%area2" or "area.cluster".
+        weight: the constraint term weight, if any.
+        offset: the constraint term offset, if any.
+        data: the constraint term data (link or cluster), if any.
+    """
+
     id: Optional[str]
     weight: Optional[float]
     offset: Optional[float]
-    data: Optional[Union[LinkInfoDTO, ClusterInfoDTO]]
+    data: Optional[Union[AreaLinkDTO, AreaClusterDTO]]
+
+    @validator("id")
+    def id_to_lower(cls, v: Optional[str]) -> Optional[str]:
+        """Ensure the ID is lower case."""
+        if v is None:
+            return None
+        return v.lower()
+
+    def generate_id(self) -> str:
+        """Return the constraint term ID for this term based on its data."""
+        if self.data is None:
+            return self.id or ""
+        return self.data.generate_id()
 
 
 class UpdateBindingConstProps(BaseModel):
@@ -97,12 +148,12 @@ def parse_constraint(key: str, value: str, char: str, new_config: BindingConstra
                 id=key,
                 weight=weight,
                 offset=offset if offset is not None else None,
-                data=LinkInfoDTO(
+                data=AreaLinkDTO(
                     area1=value1,
                     area2=value2,
                 )
                 if char == "%"
-                else ClusterInfoDTO(
+                else AreaClusterDTO(
                     area=value1,
                     cluster=value2,
                 ),
@@ -208,7 +259,7 @@ def update_binding_constraint(
         file_study = self.storage_service.get_storage(study).get_raw(study)
         constraint = self.get_binding_constraint(study, binding_constraint_id)
         if not isinstance(constraint, BindingConstraintDTO):
-            raise NoBindingConstraintError(study.id)
+            raise BindingConstraintNotFoundError(study.id)
 
         if data.key == "time_step" and data.value != constraint.time_step:
             # The user changed the time step, we need to update the matrix accordingly
@@ -243,16 +294,6 @@ def find_constraint_term_id(constraints_term: List[ConstraintTermDTO], constrain
         except ValueError:
             return -1
 
-    @staticmethod
-    def get_constraint_id(data: Union[LinkInfoDTO, ClusterInfoDTO]) -> str:
-        if isinstance(data, ClusterInfoDTO):
-            constraint_id = f"{data.area}.{data.cluster}"
-        else:
-            area1 = data.area1 if data.area1 < data.area2 else data.area2
-            area2 = data.area2 if area1 == data.area1 else data.area1
-            constraint_id = f"{area1}%{area2}"
-        return constraint_id
-
     def add_new_constraint_term(
         self,
         study: Study,
@@ -262,12 +303,12 @@ def add_new_constraint_term(
         file_study = self.storage_service.get_storage(study).get_raw(study)
         constraint = self.get_binding_constraint(study, binding_constraint_id)
         if not isinstance(constraint, BindingConstraintDTO):
-            raise NoBindingConstraintError(study.id)
+            raise BindingConstraintNotFoundError(study.id)
 
         if constraint_term.data is None:
             raise MissingDataError("Add new constraint term : data is missing")
 
-        constraint_id = BindingConstraintManager.get_constraint_id(constraint_term.data)
+        constraint_id = constraint_term.data.generate_id()
         constraints_term = constraint.constraints or []
         if BindingConstraintManager.find_constraint_term_id(constraints_term, constraint_id) >= 0:
             raise ConstraintAlreadyExistError(study.id)
@@ -310,7 +351,7 @@ def update_constraint_term(
 
         constraint = self.get_binding_constraint(study, binding_constraint_id)
         if not isinstance(constraint, BindingConstraintDTO):
-            raise NoBindingConstraintError(study.id)
+            raise BindingConstraintNotFoundError(study.id)
 
         constraint_terms = constraint.constraints  # existing constraint terms
         if constraint_terms is None:
@@ -318,14 +359,14 @@
 
         term_id = term.id if isinstance(term, ConstraintTermDTO) else term
         if term_id is None:
-            raise NoConstraintError(study.id)
+            raise ConstraintIdNotFoundError(study.id)
 
         term_id_index = BindingConstraintManager.find_constraint_term_id(constraint_terms, term_id)
         if term_id_index < 0:
             raise ConstraintIdNotFoundError(study.id)
 
         if isinstance(term, ConstraintTermDTO):
-            updated_term_id = BindingConstraintManager.get_constraint_id(term.data) if term.data else term_id
+            updated_term_id = term.data.generate_id() if term.data else term_id
 
             current_constraint = constraint_terms[term_id_index]
             constraint_terms[term_id_index] = ConstraintTermDTO(
@@ -353,6 +394,7 @@ def update_constraint_term(
             )
             execute_or_add_commands(study, file_study, [command], self.storage_service)
 
+    # FIXME create a dedicated delete service
     def remove_constraint_term(
         self,
         study: Study,
diff --git a/tests/integration/study_data_blueprint/test_binding_constraints.py b/tests/integration/study_data_blueprint/test_binding_constraints.py
index 627575097e..93b5237f7f 100644
--- a/tests/integration/study_data_blueprint/test_binding_constraints.py
+++ b/tests/integration/study_data_blueprint/test_binding_constraints.py
@@ -1,6 +1,65 @@
 import pytest
 from starlette.testclient import TestClient
 
+from antarest.study.business.binding_constraint_management import AreaClusterDTO, AreaLinkDTO, ConstraintTermDTO
+
+
+class TestAreaLinkDTO:
+    @pytest.mark.parametrize(
+        "area1, area2, expected",
+        [
+            ("Area 1", "Area 2", "area 1%area 2"),
+            ("de", "fr", "de%fr"),
+            ("fr", "de", "de%fr"),
+            ("FR", "de", "de%fr"),
+        ],
+    )
+    def test_constraint_id(self, area1: str, area2: str, expected: str) -> None:
+        info = AreaLinkDTO(area1=area1, area2=area2)
+        assert info.generate_id() == expected
+
+
+class TestAreaClusterDTO:
+    @pytest.mark.parametrize(
+        "area, cluster, expected",
+        [
+            ("Area 1", "Cluster X", "area 1.cluster x"),
+            ("de", "Nuclear", "de.nuclear"),
+            ("GB", "Gas", "gb.gas"),
+        ],
+    )
+    def test_constraint_id(self, area: str, cluster: str, expected: str) -> None:
+        info = AreaClusterDTO(area=area, cluster=cluster)
+        assert info.generate_id() == expected
+
+
+class TestConstraintTermDTO:
+    def test_constraint_id__link(self):
+        term = ConstraintTermDTO(
+            id="foo",
+            weight=3.14,
+            offset=123,
+            data=AreaLinkDTO(area1="Area 1", area2="Area 2"),
+        )
+        assert term.generate_id() == term.data.generate_id()
+
+    def test_constraint_id__cluster(self):
+        term = ConstraintTermDTO(
+            id="foo",
+            weight=3.14,
+            offset=123,
+            data=AreaClusterDTO(area="Area 1", cluster="Cluster X"),
+        )
+        assert term.generate_id() == term.data.generate_id()
+
+    def test_constraint_id__other(self):
+        term = ConstraintTermDTO(
+            id="foo",
+            weight=3.14,
+            offset=123,
+        )
+        assert term.generate_id() == "foo"
+
 
 @pytest.mark.unit_test
 class TestBindingConstraints:
@@ -173,7 +232,10 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str) ->
             json={
                 "weight": 1,
                 "offset": 2,
-                "data": {"area": area1_id, "cluster": cluster_id},
+                "data": {
+                    "area": area1_id,
+                    "cluster": cluster_id,
+                },  # NOTE: cluster_id in term data can be uppercase, but it must be lowercase in the returned ini configuration file
             },
             headers=user_headers,
         )
@@ -195,15 +257,15 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str) ->
                 "weight": 1.0,
             },
             {
-                "data": {"area": area1_id, "cluster": cluster_id},
-                "id": f"{area1_id}.{cluster_id}",
+                "data": {"area": area1_id, "cluster": cluster_id.lower()},
+                "id": f"{area1_id}.{cluster_id.lower()}",
                 "offset": 2.0,
                 "weight": 1.0,
             },
         ]
         assert constraint_terms == expected
 
-        # Update constraint cluster term
+        # Update constraint cluster term with uppercase cluster_id
         res = client.put(
             f"/v1/studies/{study_id}/bindingconstraints/{bc_id}/term",
             json={
@@ -214,7 +276,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str) ->
         )
         assert res.status_code == 200, res.json()
 
-        # Get binding constraints list to check updated term
+        # Check updated terms, cluster_id should be lowercase in the returned configuration
         res = client.get(
             f"/v1/studies/{study_id}/bindingconstraints/{bc_id}",
             headers=user_headers,
@@ -230,49 +292,14 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str) ->
                 "weight": 1.0,
             },
             {
-                "data": {"area": area1_id, "cluster": cluster_id},
-                "id": f"{area1_id}.{cluster_id}",
+                "data": {"area": area1_id, "cluster": cluster_id.lower()},
+                "id": f"{area1_id}.{cluster_id.lower()}",
                 "offset": None,  # updated
                 "weight": 3.0,  # updated
             },
         ]
         assert constraint_terms == expected
 
-        # Update constraint term regardless of the case of the cluster id
-        res = client.put(
-            f"/v1/studies/{study_id}/bindingconstraints/{bc_id}/term",
-            json={
-                "id": f"{area1_id}.Cluster 1",
-                "weight": 4,
-            },
-            headers=user_headers,
-        )
-        assert res.status_code == 200, res.json()
-
-        # The term should be successfully updated
-        res = client.get(
-            f"/v1/studies/{study_id}/bindingconstraints/{bc_id}",
-            headers=user_headers,
-        )
-        assert res.status_code == 200, res.json()
-        binding_constraint = res.json()
-        constraint_terms = binding_constraint["constraints"]
-        expected = [
-            {
-                "data": {"area1": area1_id, "area2": area2_id},
-                "id": f"{area1_id}%{area2_id}",
-                "offset": 2.0,
-                "weight": 1.0,
-            },
-            {
-                "data": {"area": area1_id, "cluster": cluster_id},
-                "id": f"{area1_id}.{cluster_id}",
-                "offset": None,  # updated
-                "weight": 4.0,  # updated
-            },
-        ]
-        assert constraint_terms == expected
-
         # Update constraint cluster term with invalid id
         res = client.put(
             f"/v1/studies/{study_id}/bindingconstraints/{bc_id}/term",
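
For reference, a minimal usage sketch (not part of the patch) of the renamed DTOs and their generate_id() helpers, mirroring the unit tests added above:

    from antarest.study.business.binding_constraint_management import (
        AreaClusterDTO,
        AreaLinkDTO,
        ConstraintTermDTO,
    )

    # Link term IDs are lower-cased, sorted alphabetically and joined by "%".
    assert AreaLinkDTO(area1="FR", area2="de").generate_id() == "de%fr"

    # Cluster term IDs are lower-cased and joined by ".".
    assert AreaClusterDTO(area="Area 1", cluster="Cluster X").generate_id() == "area 1.cluster x"

    # A term with data delegates to its data; without data it falls back to its own id.
    term = ConstraintTermDTO(id="foo", weight=3.14, offset=123)
    assert term.generate_id() == "foo"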