From ab420104ab310199f58412e5a16c90ea94d4a1a2 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Thu, 25 Jan 2024 16:11:33 +0100 Subject: [PATCH] feat(api-bc): improve the calculation of binding constraints IDs --- .../business/binding_constraint_management.py | 68 ++++++++++++++++--- .../test_binding_constraint_management.py | 60 ++++++++++++++++ 2 files changed, 118 insertions(+), 10 deletions(-) create mode 100644 tests/study/business/test_binding_constraint_management.py diff --git a/antarest/study/business/binding_constraint_management.py b/antarest/study/business/binding_constraint_management.py index c9eb01d9fa..149f9bd81b 100644 --- a/antarest/study/business/binding_constraint_management.py +++ b/antarest/study/business/binding_constraint_management.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, validator from antarest.core.exceptions import ( ConstraintAlreadyExistError, @@ -30,21 +30,72 @@ class LinkInfoDTO(BaseModel): + """ + DTO for a constraint term on a link between two areas. + + Attributes: + area1: the first area ID + area2: the second area ID + """ + area1: str area2: str + def calc_term_id(self) -> str: + """Return the constraint term ID for this link, of the form "area1%area2".""" + # Ensure IDs are in alphabetical order and lower case + ids = sorted((self.area1.lower(), self.area2.lower())) + return "%".join(ids) + class ClusterInfoDTO(BaseModel): + """ + DTO for a constraint term on a cluster in an area. + + Attributes: + area: the area ID + cluster: the cluster ID + """ + area: str cluster: str + def calc_term_id(self) -> str: + """Return the constraint term ID for this Area/cluster constraint, of the form "area.cluster".""" + # Ensure IDs are in lower case + ids = [self.area.lower(), self.cluster.lower()] + return ".".join(ids) + class ConstraintTermDTO(BaseModel): + """ + DTO for a constraint term. + + Attributes: + id: the constraint term ID, of the form "area1%area2" or "area.cluster". + weight: the constraint term weight, if any. + offset: the constraint term offset, if any. + data: the constraint term data (link or cluster), if any. + """ + id: Optional[str] weight: Optional[float] offset: Optional[float] data: Optional[Union[LinkInfoDTO, ClusterInfoDTO]] + @validator("id") + def id_to_lower(cls, v: Optional[str]) -> Optional[str]: + """Ensure the ID is lower case.""" + if v is None: + return None + return v.lower() + + def calc_term_id(self) -> str: + """Return the constraint term ID for this term based on its data.""" + if self.data is None: + return self.id or "" + return self.data.calc_term_id() + class UpdateBindingConstProps(BaseModel): key: str @@ -245,13 +296,7 @@ def find_constraint_term_id(constraints_term: List[ConstraintTermDTO], constrain @staticmethod def get_constraint_id(data: Union[LinkInfoDTO, ClusterInfoDTO]) -> str: - if isinstance(data, ClusterInfoDTO): - constraint_id = f"{data.area}.{data.cluster}" - else: - area1 = data.area1 if data.area1 < data.area2 else data.area2 - area2 = data.area2 if area1 == data.area1 else data.area1 - constraint_id = f"{area1}%{area2}" - return constraint_id + return data.calc_term_id() def add_new_constraint_term( self, @@ -262,12 +307,13 @@ def add_new_constraint_term( file_study = self.storage_service.get_storage(study).get_raw(study) constraint = self.get_binding_constraint(study, binding_constraint_id) if not isinstance(constraint, BindingConstraintDTO): + # todo: should be renamed 'BindingConstraintNotFoundError' raise NoBindingConstraintError(study.id) if constraint_term.data is None: raise MissingDataError("Add new constraint term : data is missing") - constraint_id = BindingConstraintManager.get_constraint_id(constraint_term.data) + constraint_id = constraint_term.data.calc_term_id() constraints_term = constraint.constraints or [] if BindingConstraintManager.find_constraint_term_id(constraints_term, constraint_id) >= 0: raise ConstraintAlreadyExistError(study.id) @@ -309,10 +355,12 @@ def update_constraint_term( file_study = self.storage_service.get_storage(study).get_raw(study) constraint = self.get_binding_constraint(study, binding_constraint_id) if not isinstance(constraint, BindingConstraintDTO): + # todo: should be renamed 'BindingConstraintNotFoundError' raise NoBindingConstraintError(study.id) constraints = constraint.constraints if constraints is None: + # todo: should be renamed 'ConstraintTermNotFoundError' raise NoConstraintError(study.id) data_id = data.id if isinstance(data, ConstraintTermDTO) else data @@ -324,7 +372,7 @@ def update_constraint_term( raise ConstraintIdNotFoundError(study.id) if isinstance(data, ConstraintTermDTO): - constraint_id = BindingConstraintManager.get_constraint_id(data.data) if data.data is not None else data_id + constraint_id = data_id if data.data is None else data.data.calc_term_id() current_constraint = constraints[data_term_index] constraints.append( ConstraintTermDTO( diff --git a/tests/study/business/test_binding_constraint_management.py b/tests/study/business/test_binding_constraint_management.py new file mode 100644 index 0000000000..270c642bd6 --- /dev/null +++ b/tests/study/business/test_binding_constraint_management.py @@ -0,0 +1,60 @@ +import pytest + +from antarest.study.business.binding_constraint_management import ClusterInfoDTO, ConstraintTermDTO, LinkInfoDTO + + +class TestLinkInfoDTO: + @pytest.mark.parametrize( + "area1, area2, expected", + [ + ("Area 1", "Area 2", "area 1%area 2"), + ("de", "fr", "de%fr"), + ("fr", "de", "de%fr"), + ("FR", "de", "de%fr"), + ], + ) + def test_constraint_id(self, area1: str, area2: str, expected: str) -> None: + info = LinkInfoDTO(area1=area1, area2=area2) + assert info.calc_term_id() == expected + + +class TestClusterInfoDTO: + @pytest.mark.parametrize( + "area, cluster, expected", + [ + ("Area 1", "Cluster X", "area 1.cluster x"), + ("de", "Nuclear", "de.nuclear"), + ("GB", "Gas", "gb.gas"), + ], + ) + def test_constraint_id(self, area: str, cluster: str, expected: str) -> None: + info = ClusterInfoDTO(area=area, cluster=cluster) + assert info.calc_term_id() == expected + + +class TestConstraintTermDTO: + def test_constraint_id__link(self): + term = ConstraintTermDTO( + id="foo", + weight=3.14, + offset=123, + data=LinkInfoDTO(area1="Area 1", area2="Area 2"), + ) + assert term.calc_term_id() == term.data.calc_term_id() + + def test_constraint_id__cluster(self): + term = ConstraintTermDTO( + id="foo", + weight=3.14, + offset=123, + data=ClusterInfoDTO(area="Area 1", cluster="Cluster X"), + ) + assert term.calc_term_id() == term.data.calc_term_id() + + def test_constraint_id__other(self): + term = ConstraintTermDTO( + id="foo", + weight=3.14, + offset=123, + ) + assert term.calc_term_id() == "foo"