Skip to content

Commit

Permalink
feat(api-bc): add endpoints to manage and validate BC groups
Browse files Browse the repository at this point in the history
  • Loading branch information
laurent-laporte-pro committed Mar 22, 2024
1 parent 1395e9c commit d5eecde
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 100 deletions.
6 changes: 3 additions & 3 deletions antarest/core/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from http import HTTPStatus
from typing import Optional
from typing import Any, Optional

from fastapi.exceptions import HTTPException

Expand Down Expand Up @@ -386,8 +386,8 @@ def __init__(self, message: str) -> None:


class IncoherenceBetweenMatricesLength(HTTPException):
def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.UNPROCESSABLE_ENTITY, message)
def __init__(self, detail: Any) -> None:
super().__init__(HTTPStatus.UNPROCESSABLE_ENTITY, detail)


class MissingDataError(HTTPException):
Expand Down
158 changes: 106 additions & 52 deletions antarest/study/business/binding_constraint_management.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Any, Dict, List, Optional, Union
import collections
import itertools
import logging
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import numpy as np
from pydantic import BaseModel, validator, root_validator
from pydantic import BaseModel, root_validator, validator
from requests.utils import CaseInsensitiveDict

from antarest.core.exceptions import (
BindingConstraintNotFoundError,
Expand Down Expand Up @@ -44,6 +48,11 @@
from antarest.study.storage.variantstudy.model.command.update_binding_constraint import UpdateBindingConstraint
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy

logger = logging.getLogger(__name__)

DEFAULT_GROUP = "default"
"""Default group name for binding constraints if missing or empty."""


class AreaLinkDTO(BaseModel):
"""
Expand Down Expand Up @@ -162,7 +171,7 @@ def accept(self, constraint: "BindingConstraintConfigType") -> bool:
if self.comments.upper() not in comments.upper():
return False
if self.group:
group = getattr(constraint, "group") or ""
group = getattr(constraint, "group", DEFAULT_GROUP)
if self.group.upper() != group.upper():
return False
if self.time_step is not None and self.time_step != constraint.time_step:
Expand Down Expand Up @@ -294,6 +303,42 @@ class BindingConstraintConfig870(BindingConstraintConfig):
BindingConstraintConfigType = Union[BindingConstraintConfig870, BindingConstraintConfig]


def _validate_binding_constraints(file_study: FileStudy, bcs: Sequence[BindingConstraintConfigType]) -> bool:
if int(file_study.config.version) < 870:
matrix_id_fmts = {"{bc_id}"}
else:
matrix_id_fmts = {"{bc_id}_eq", "{bc_id}_lt", "{bc_id}_gt"}

references_by_shapes = collections.defaultdict(list)
_total = len(bcs) * len(matrix_id_fmts)
for _index, (bc, fmt) in enumerate(itertools.product(bcs, matrix_id_fmts), 1):
matrix_id = fmt.format(bc_id=bc.id)
logger.info(f"⏲ Validating BC '{bc.id}': {matrix_id=} [{_index}/{_total}]")
_obj = file_study.tree.get(url=["input", "bindingconstraints", matrix_id])
_array = np.array(_obj["data"], dtype=float)
if _array.size == 0 or _array.shape[1] == 1:
continue
references_by_shapes[_array.shape].append((bc.id, matrix_id))
del _obj
del _array

if len(references_by_shapes) > 1:
most_common = collections.Counter(references_by_shapes.keys()).most_common()
invalid_constraints = collections.defaultdict(list)
for shape, _ in most_common[1:]:
references = references_by_shapes[shape]
for bc_id, matrix_id in references:
invalid_constraints[bc_id].append(f"'{matrix_id}' {shape}")
expected_shape = most_common[0][0]
detail = {
"msg": f"Matrix shapes mismatch in binding constraints group. Expected shape: {expected_shape}",
"invalid_constraints": dict(invalid_constraints),
}
raise IncoherenceBetweenMatricesLength(detail)

return True


class BindingConstraintManager:
def __init__(
self,
Expand Down Expand Up @@ -400,20 +445,64 @@ def get_binding_constraint(
# Else we return all the matching elements
return list(result.values())

def validate_binding_constraint(self, study: Study, constraint_id: str) -> None:
if int(study.version) < 870:
return # There's nothing to check for constraints before v8.7
file_study = self.storage_service.get_storage(study).get_raw(study)
def get_binding_constraint_groups(self, study: Study) -> Mapping[str, Sequence[BindingConstraintConfigType]]:
"""
Get all binding constraints grouped by group name.
Args:
study: the study
Returns:
A dictionary with group names as keys and lists of binding constraints as values.
"""
storage_service = self.storage_service.get_storage(study)
file_study = storage_service.get_raw(study)
config = file_study.tree.get(["input", "bindingconstraints", "bindingconstraints"])
group = next((value["group"] for value in config.values() if value["id"] == constraint_id), None)
if not group:
raise BindingConstraintNotFoundError(study.id)
matrix_terms = {
"eq": get_matrix_data(file_study, constraint_id, "eq"),
"lt": get_matrix_data(file_study, constraint_id, "lt"),
"gt": get_matrix_data(file_study, constraint_id, "gt"),
}
check_matrices_coherence(file_study, group, constraint_id, matrix_terms)
bcs_by_group = CaseInsensitiveDict() # type: ignore
for value in config.values():
_bc_config = BindingConstraintManager.process_constraint(value, int(study.version))
_group = getattr(_bc_config, "group", DEFAULT_GROUP)
bcs_by_group.setdefault(_group, []).append(_bc_config)
return bcs_by_group

def get_binding_constraint_group(self, study: Study, group_name: str) -> Sequence[BindingConstraintConfigType]:
"""
Get all binding constraints from a given group.
Args:
study: the study.
group_name: the group name (case-insensitive).
Returns:
A list of binding constraints from the group.
"""
groups = self.get_binding_constraint_groups(study)
if group_name not in groups:
raise BindingConstraintNotFoundError(f"Group '{group_name}' not found")
return groups[group_name]

def validate_binding_constraint_group(self, study: Study, group_name: str) -> bool:
storage_service = self.storage_service.get_storage(study)
file_study = storage_service.get_raw(study)
bcs_by_group = self.get_binding_constraint_groups(study)
if group_name not in bcs_by_group:
raise BindingConstraintNotFoundError(f"Group '{group_name}' not found")
bcs = bcs_by_group[group_name]
return _validate_binding_constraints(file_study, bcs)

def validate_binding_constraint_groups(self, study: Study) -> bool:
storage_service = self.storage_service.get_storage(study)
file_study = storage_service.get_raw(study)
bcs_by_group = self.get_binding_constraint_groups(study)
invalid_groups = {}
for group_name, bcs in bcs_by_group.items():
try:
_validate_binding_constraints(file_study, bcs)
except IncoherenceBetweenMatricesLength as e:
invalid_groups[group_name] = e.detail
if invalid_groups:
raise IncoherenceBetweenMatricesLength(invalid_groups)
return True

def create_binding_constraint(
self,
Expand Down Expand Up @@ -446,7 +535,7 @@ def create_binding_constraint(
"comments": data.comments or "",
}
if version >= 870:
args["group"] = data.group or "default"
args["group"] = data.group or DEFAULT_GROUP

command = CreateBindingConstraint(
**args, command_context=self.storage_service.variant_study_service.command_factory.command_context
Expand Down Expand Up @@ -678,37 +767,6 @@ def find_constraint_term_id(constraints_term: List[ConstraintTermDTO], constrain
return -1


def get_binding_constraint_of_a_given_group(file_study: FileStudy, group_id: str) -> List[str]:
config = file_study.tree.get(["input", "bindingconstraints", "bindingconstraints"])
config_values = list(config.values())
return [bd["id"] for bd in config_values if bd["group"] == group_id]


def check_matrices_coherence(
file_study: FileStudy, group_id: str, binding_constraint_id: str, matrix_terms: Dict[str, Any]
) -> None:
given_number_of_cols = set()
for term_str, term_data in matrix_terms.items():
if term_data:
nb_cols = len(term_data[0])
if nb_cols > 1:
given_number_of_cols.add(nb_cols)
if len(given_number_of_cols) > 1:
raise IncoherenceBetweenMatricesLength(
f"The matrices of {binding_constraint_id} must have the same number of columns, currently {given_number_of_cols}"
)
if len(given_number_of_cols) == 1:
given_size = list(given_number_of_cols)[0]
for bd_id in get_binding_constraint_of_a_given_group(file_study, group_id):
for term in list(matrix_terms.keys()):
matrix_file = file_study.tree.get(url=["input", "bindingconstraints", f"{bd_id}_{term}"])
column_size = len(matrix_file["data"][0])
if column_size > 1 and column_size != given_size:
raise IncoherenceBetweenMatricesLength(
f"The matrices of the group {group_id} do not have the same number of columns"
)


def check_attributes_coherence(
data: Union[BindingConstraintCreation, BindingConstraintEdition], study_version: int
) -> None:
Expand All @@ -721,7 +779,3 @@ def check_attributes_coherence(
raise InvalidFieldForVersionError("You cannot fill a 'matrix_term' as these values refer to v8.7+ studies")
elif data.values:
raise InvalidFieldForVersionError("You cannot fill 'values' as it refers to the matrix before v8.7")


def get_matrix_data(file_study: FileStudy, binding_constraint_id: str, keyword: str) -> List[Any]:
return file_study.tree.get(url=["input", "bindingconstraints", f"{binding_constraint_id}_{keyword}"])["data"] # type: ignore
Loading

0 comments on commit d5eecde

Please sign in to comment.