Commit

feat(api-bc): improve the calculation of BC terms IDs
hdinia committed Jan 30, 2024
1 parent 8fbcfa7 commit b42b2c1
Showing 3 changed files with 137 additions and 66 deletions.
2 changes: 1 addition & 1 deletion antarest/core/exceptions.py
@@ -174,7 +174,7 @@ def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.BAD_REQUEST, message)


class NoBindingConstraintError(HTTPException):
class BindingConstraintNotFoundError(HTTPException):
def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.NOT_FOUND, message)

88 changes: 65 additions & 23 deletions antarest/study/business/binding_constraint_management.py
@@ -1,14 +1,14 @@
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, validator

from antarest.core.exceptions import (
ConstraintAlreadyExistError,
ConstraintIdNotFoundError,
DuplicateConstraintName,
InvalidConstraintName,
MissingDataError,
NoBindingConstraintError,
BindingConstraintNotFoundError,
NoConstraintError,
)
from antarest.matrixstore.model import MatrixData
@@ -29,21 +29,72 @@
from antarest.study.storage.variantstudy.model.command.update_binding_constraint import UpdateBindingConstraint


class LinkInfoDTO(BaseModel):
class AreaLinkDTO(BaseModel):
"""
DTO for a constraint term on a link between two areas.
Attributes:
area1: the first area ID
area2: the second area ID
"""

area1: str
area2: str

def generate_id(self) -> str:
"""Return the constraint term ID for this link, of the form "area1%area2"."""
# Ensure IDs are in alphabetical order and lower case
ids = sorted((self.area1.lower(), self.area2.lower()))
return "%".join(ids)


class AreaClusterDTO(BaseModel):
"""
DTO for a constraint term on a cluster in an area.
Attributes:
area: the area ID
cluster: the cluster ID
"""

class ClusterInfoDTO(BaseModel):
area: str
cluster: str

def generate_id(self) -> str:
"""Return the constraint term ID for this Area/cluster constraint, of the form "area.cluster"."""
# Ensure IDs are in lower case
ids = [self.area.lower(), self.cluster.lower()]
return ".".join(ids)


class ConstraintTermDTO(BaseModel):
"""
DTO for a constraint term.
Attributes:
id: the constraint term ID, of the form "area1%area2" or "area.cluster".
weight: the constraint term weight, if any.
offset: the constraint term offset, if any.
data: the constraint term data (link or cluster), if any.
"""

id: Optional[str]
weight: Optional[float]
offset: Optional[float]
data: Optional[Union[LinkInfoDTO, ClusterInfoDTO]]
data: Optional[Union[AreaLinkDTO, AreaClusterDTO]]

@validator("id")
def id_to_lower(cls, v: Optional[str]) -> Optional[str]:
"""Ensure the ID is lower case."""
if v is None:
return None
return v.lower()

def generate_id(self) -> str:
"""Return the constraint term ID for this term based on its data."""
if self.data is None:
return self.id or ""
return self.data.generate_id()
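
A short sketch of how the validator and generate_id interact, assuming pydantic v1 semantics (Optional fields default to None): an explicit id is normalized to lower case, but when term data is present the ID is always derived from the data.

    from antarest.study.business.binding_constraint_management import AreaLinkDTO, ConstraintTermDTO

    term = ConstraintTermDTO(id="FOO", weight=3.14, data=AreaLinkDTO(area1="Area 1", area2="Area 2"))
    assert term.id == "foo"                       # the validator lower-cases the raw ID
    assert term.generate_id() == "area 1%area 2"  # data takes precedence over the stored ID

    bare_term = ConstraintTermDTO(id="foo", weight=3.14)
    assert bare_term.generate_id() == "foo"       # no data: falls back to the (lower-cased) ID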


class UpdateBindingConstProps(BaseModel):
@@ -97,12 +148,12 @@ def parse_constraint(key: str, value: str, char: str, new_config: BindingConstra
id=key,
weight=weight,
offset=offset if offset is not None else None,
data=LinkInfoDTO(
data=AreaLinkDTO(
area1=value1,
area2=value2,
)
if char == "%"
else ClusterInfoDTO(
else AreaClusterDTO(
area=value1,
cluster=value2,
),
@@ -208,7 +259,7 @@ def update_binding_constraint(
file_study = self.storage_service.get_storage(study).get_raw(study)
constraint = self.get_binding_constraint(study, binding_constraint_id)
if not isinstance(constraint, BindingConstraintDTO):
raise NoBindingConstraintError(study.id)
raise BindingConstraintNotFoundError(study.id)

if data.key == "time_step" and data.value != constraint.time_step:
# The user changed the time step, we need to update the matrix accordingly
@@ -243,16 +294,6 @@ def find_constraint_term_id(constraints_term: List[ConstraintTermDTO], constrain
except ValueError:
return -1

@staticmethod
def get_constraint_id(data: Union[LinkInfoDTO, ClusterInfoDTO]) -> str:
if isinstance(data, ClusterInfoDTO):
constraint_id = f"{data.area}.{data.cluster}"
else:
area1 = data.area1 if data.area1 < data.area2 else data.area2
area2 = data.area2 if area1 == data.area1 else data.area1
constraint_id = f"{area1}%{area2}"
return constraint_id
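
Call sites that previously went through this removed static helper now ask the DTO for its own ID; unlike the old helper, the DTO methods also lower-case the parts, so mixed-case input such as "FR"/"de" resolves to the same term ID. A minimal sketch of the replacement pattern (here, term stands for any ConstraintTermDTO whose data is set):

    # before (case-sensitive): BindingConstraintManager.get_constraint_id(term.data)
    # after (case-insensitive): the DTO computes its own, normalized ID
    constraint_id = term.data.generate_id()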

def add_new_constraint_term(
self,
study: Study,
@@ -262,12 +303,12 @@ def add_new_constraint_term(
file_study = self.storage_service.get_storage(study).get_raw(study)
constraint = self.get_binding_constraint(study, binding_constraint_id)
if not isinstance(constraint, BindingConstraintDTO):
raise NoBindingConstraintError(study.id)
raise BindingConstraintNotFoundError(study.id)

if constraint_term.data is None:
raise MissingDataError("Add new constraint term : data is missing")

constraint_id = BindingConstraintManager.get_constraint_id(constraint_term.data)
constraint_id = constraint_term.data.generate_id()
constraints_term = constraint.constraints or []
if BindingConstraintManager.find_constraint_term_id(constraints_term, constraint_id) >= 0:
raise ConstraintAlreadyExistError(study.id)
@@ -310,22 +351,22 @@ def update_constraint_term(
constraint = self.get_binding_constraint(study, binding_constraint_id)

if not isinstance(constraint, BindingConstraintDTO):
raise NoBindingConstraintError(study.id)
raise BindingConstraintNotFoundError(study.id)

constraint_terms = constraint.constraints # existing constraint terms
if constraint_terms is None:
raise NoConstraintError(study.id)

term_id = term.id if isinstance(term, ConstraintTermDTO) else term
if term_id is None:
raise NoConstraintError(study.id)
raise ConstraintIdNotFoundError(study.id)

term_id_index = BindingConstraintManager.find_constraint_term_id(constraint_terms, term_id)
if term_id_index < 0:
raise ConstraintIdNotFoundError(study.id)

if isinstance(term, ConstraintTermDTO):
updated_term_id = BindingConstraintManager.get_constraint_id(term.data) if term.data else term_id
updated_term_id = term.data.generate_id() if term.data else term_id
current_constraint = constraint_terms[term_id_index]

constraint_terms[term_id_index] = ConstraintTermDTO(
@@ -353,6 +394,7 @@ def update_constraint_term(
)
execute_or_add_commands(study, file_study, [command], self.storage_service)

# FIXME create a dedicated delete service
def remove_constraint_term(
self,
study: Study,
113 changes: 71 additions & 42 deletions tests/integration/study_data_blueprint/test_binding_constraints.py
@@ -1,6 +1,67 @@
import pytest

GitHub Actions / python-lint check failure on line 1: Imports are incorrectly sorted and/or formatted.
from starlette.testclient import TestClient

from antarest.study.business.binding_constraint_management import (
AreaClusterDTO, AreaLinkDTO, ConstraintTermDTO
)


class TestAreaLinkDTO:
@pytest.mark.parametrize(
"area1, area2, expected",
[
("Area 1", "Area 2", "area 1%area 2"),
("de", "fr", "de%fr"),
("fr", "de", "de%fr"),
("FR", "de", "de%fr"),
],
)
def test_constraint_id(self, area1: str, area2: str, expected: str) -> None:
info = AreaLinkDTO(area1=area1, area2=area2)
assert info.generate_id() == expected


class TestAreaClusterDTO:
@pytest.mark.parametrize(
"area, cluster, expected",
[
("Area 1", "Cluster X", "area 1.cluster x"),
("de", "Nuclear", "de.nuclear"),
("GB", "Gas", "gb.gas"),
],
)
def test_constraint_id(self, area: str, cluster: str, expected: str) -> None:
info = AreaClusterDTO(area=area, cluster=cluster)
assert info.generate_id() == expected


class TestConstraintTermDTO:
def test_constraint_id__link(self):
term = ConstraintTermDTO(
id="foo",
weight=3.14,
offset=123,
data=AreaLinkDTO(area1="Area 1", area2="Area 2"),
)
assert term.generate_id() == term.data.generate_id()

def test_constraint_id__cluster(self):
term = ConstraintTermDTO(
id="foo",
weight=3.14,
offset=123,
data=AreaClusterDTO(area="Area 1", cluster="Cluster X"),
)
assert term.generate_id() == term.data.generate_id()

def test_constraint_id__other(self):
term = ConstraintTermDTO(
id="foo",
weight=3.14,
offset=123,
)
assert term.generate_id() == "foo"


@pytest.mark.unit_test
class TestBindingConstraints:
@@ -173,7 +234,10 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str) ->
json={
"weight": 1,
"offset": 2,
"data": {"area": area1_id, "cluster": cluster_id},
"data": {
"area": area1_id,
"cluster": cluster_id,
}, # NOTE: cluster_id in term data can be uppercase, but it must be lowercase in the returned ini configuration file
},
headers=user_headers,
)
@@ -195,15 +259,15 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str) ->
"weight": 1.0,
},
{
"data": {"area": area1_id, "cluster": cluster_id},
"id": f"{area1_id}.{cluster_id}",
"data": {"area": area1_id, "cluster": cluster_id.lower()},
"id": f"{area1_id}.{cluster_id.lower()}",
"offset": 2.0,
"weight": 1.0,
},
]
assert constraint_terms == expected

# Update constraint cluster term
# Update constraint cluster term with uppercase cluster_id
res = client.put(
f"/v1/studies/{study_id}/bindingconstraints/{bc_id}/term",
json={
@@ -214,7 +278,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str) ->
)
assert res.status_code == 200, res.json()

# Get binding constraints list to check updated term
# Check updated terms, cluster_id should be lowercase in the returned configuration
res = client.get(
f"/v1/studies/{study_id}/bindingconstraints/{bc_id}",
headers=user_headers,
@@ -230,49 +294,14 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str) ->
"weight": 1.0,
},
{
"data": {"area": area1_id, "cluster": cluster_id},
"id": f"{area1_id}.{cluster_id}",
"data": {"area": area1_id, "cluster": cluster_id.lower()},
"id": f"{area1_id}.{cluster_id.lower()}",
"offset": None, # updated
"weight": 3.0, # updated
},
]
assert constraint_terms == expected

# Update constraint term regardless of the case of the cluster id
res = client.put(
f"/v1/studies/{study_id}/bindingconstraints/{bc_id}/term",
json={
"id": f"{area1_id}.Cluster 1",
"weight": 4,
},
headers=user_headers,
)
assert res.status_code == 200, res.json()

# The term should be successfully updated
res = client.get(
f"/v1/studies/{study_id}/bindingconstraints/{bc_id}",
headers=user_headers,
)
assert res.status_code == 200, res.json()
binding_constraint = res.json()
constraint_terms = binding_constraint["constraints"]
expected = [
{
"data": {"area1": area1_id, "area2": area2_id},
"id": f"{area1_id}%{area2_id}",
"offset": 2.0,
"weight": 1.0,
},
{
"data": {"area": area1_id, "cluster": cluster_id},
"id": f"{area1_id}.{cluster_id}",
"offset": None, # updated
"weight": 4.0, # updated
},
]
assert constraint_terms == expected

# Update constraint cluster term with invalid id
res = client.put(
f"/v1/studies/{study_id}/bindingconstraints/{bc_id}/term",