Skip to content

Commit

Permalink
feat(api-bc): improve the calculation of binding constraints IDs
Browse files Browse the repository at this point in the history
  • Loading branch information
laurent-laporte-pro committed Jan 25, 2024
1 parent b4e3fd8 commit ab42010
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 10 deletions.
68 changes: 58 additions & 10 deletions antarest/study/business/binding_constraint_management.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, validator

from antarest.core.exceptions import (
ConstraintAlreadyExistError,
Expand Down Expand Up @@ -30,21 +30,72 @@


class LinkInfoDTO(BaseModel):
"""
DTO for a constraint term on a link between two areas.
Attributes:
area1: the first area ID
area2: the second area ID
"""

area1: str
area2: str

def calc_term_id(self) -> str:
"""Return the constraint term ID for this link, of the form "area1%area2"."""
# Ensure IDs are in alphabetical order and lower case
ids = sorted((self.area1.lower(), self.area2.lower()))
return "%".join(ids)


class ClusterInfoDTO(BaseModel):
"""
DTO for a constraint term on a cluster in an area.
Attributes:
area: the area ID
cluster: the cluster ID
"""

area: str
cluster: str

def calc_term_id(self) -> str:
"""Return the constraint term ID for this Area/cluster constraint, of the form "area.cluster"."""
# Ensure IDs are in lower case
ids = [self.area.lower(), self.cluster.lower()]
return ".".join(ids)


class ConstraintTermDTO(BaseModel):
"""
DTO for a constraint term.
Attributes:
id: the constraint term ID, of the form "area1%area2" or "area.cluster".
weight: the constraint term weight, if any.
offset: the constraint term offset, if any.
data: the constraint term data (link or cluster), if any.
"""

id: Optional[str]
weight: Optional[float]
offset: Optional[float]
data: Optional[Union[LinkInfoDTO, ClusterInfoDTO]]

@validator("id")
def id_to_lower(cls, v: Optional[str]) -> Optional[str]:
"""Ensure the ID is lower case."""
if v is None:
return None
return v.lower()

def calc_term_id(self) -> str:
"""Return the constraint term ID for this term based on its data."""
if self.data is None:
return self.id or ""
return self.data.calc_term_id()


class UpdateBindingConstProps(BaseModel):
key: str
Expand Down Expand Up @@ -245,13 +296,7 @@ def find_constraint_term_id(constraints_term: List[ConstraintTermDTO], constrain

@staticmethod
def get_constraint_id(data: Union[LinkInfoDTO, ClusterInfoDTO]) -> str:
if isinstance(data, ClusterInfoDTO):
constraint_id = f"{data.area}.{data.cluster}"
else:
area1 = data.area1 if data.area1 < data.area2 else data.area2
area2 = data.area2 if area1 == data.area1 else data.area1
constraint_id = f"{area1}%{area2}"
return constraint_id
return data.calc_term_id()

def add_new_constraint_term(
self,
Expand All @@ -262,12 +307,13 @@ def add_new_constraint_term(
file_study = self.storage_service.get_storage(study).get_raw(study)
constraint = self.get_binding_constraint(study, binding_constraint_id)
if not isinstance(constraint, BindingConstraintDTO):
# todo: should be renamed 'BindingConstraintNotFoundError'
raise NoBindingConstraintError(study.id)

if constraint_term.data is None:
raise MissingDataError("Add new constraint term : data is missing")

constraint_id = BindingConstraintManager.get_constraint_id(constraint_term.data)
constraint_id = constraint_term.data.calc_term_id()
constraints_term = constraint.constraints or []
if BindingConstraintManager.find_constraint_term_id(constraints_term, constraint_id) >= 0:
raise ConstraintAlreadyExistError(study.id)
Expand Down Expand Up @@ -309,10 +355,12 @@ def update_constraint_term(
file_study = self.storage_service.get_storage(study).get_raw(study)
constraint = self.get_binding_constraint(study, binding_constraint_id)
if not isinstance(constraint, BindingConstraintDTO):
# todo: should be renamed 'BindingConstraintNotFoundError'
raise NoBindingConstraintError(study.id)

constraints = constraint.constraints
if constraints is None:
# todo: should be renamed 'ConstraintTermNotFoundError'
raise NoConstraintError(study.id)

data_id = data.id if isinstance(data, ConstraintTermDTO) else data
Expand All @@ -324,7 +372,7 @@ def update_constraint_term(
raise ConstraintIdNotFoundError(study.id)

if isinstance(data, ConstraintTermDTO):
constraint_id = BindingConstraintManager.get_constraint_id(data.data) if data.data is not None else data_id
constraint_id = data_id if data.data is None else data.data.calc_term_id()
current_constraint = constraints[data_term_index]
constraints.append(
ConstraintTermDTO(
Expand Down
60 changes: 60 additions & 0 deletions tests/study/business/test_binding_constraint_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import pytest

from antarest.study.business.binding_constraint_management import ClusterInfoDTO, ConstraintTermDTO, LinkInfoDTO


class TestLinkInfoDTO:
@pytest.mark.parametrize(
"area1, area2, expected",
[
("Area 1", "Area 2", "area 1%area 2"),
("de", "fr", "de%fr"),
("fr", "de", "de%fr"),
("FR", "de", "de%fr"),
],
)
def test_constraint_id(self, area1: str, area2: str, expected: str) -> None:
info = LinkInfoDTO(area1=area1, area2=area2)
assert info.calc_term_id() == expected


class TestClusterInfoDTO:
@pytest.mark.parametrize(
"area, cluster, expected",
[
("Area 1", "Cluster X", "area 1.cluster x"),
("de", "Nuclear", "de.nuclear"),
("GB", "Gas", "gb.gas"),
],
)
def test_constraint_id(self, area: str, cluster: str, expected: str) -> None:
info = ClusterInfoDTO(area=area, cluster=cluster)
assert info.calc_term_id() == expected


class TestConstraintTermDTO:
def test_constraint_id__link(self):
term = ConstraintTermDTO(
id="foo",
weight=3.14,
offset=123,
data=LinkInfoDTO(area1="Area 1", area2="Area 2"),
)
assert term.calc_term_id() == term.data.calc_term_id()

def test_constraint_id__cluster(self):
term = ConstraintTermDTO(
id="foo",
weight=3.14,
offset=123,
data=ClusterInfoDTO(area="Area 1", cluster="Cluster X"),
)
assert term.calc_term_id() == term.data.calc_term_id()

def test_constraint_id__other(self):
term = ConstraintTermDTO(
id="foo",
weight=3.14,
offset=123,
)
assert term.calc_term_id() == "foo"

0 comments on commit ab42010

Please sign in to comment.