From d5eecdec5277921da94c4ace85f89d88aa2c38d1 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Fri, 22 Mar 2024 16:06:16 +0100 Subject: [PATCH] feat(api-bc): add endpoints to manage and validate BC groups --- antarest/core/exceptions.py | 6 +- .../business/binding_constraint_management.py | 158 ++++++++++++------ antarest/study/web/study_data_blueprint.py | 138 ++++++++++++--- .../test_binding_constraints.py | 79 ++++++--- 4 files changed, 281 insertions(+), 100 deletions(-) diff --git a/antarest/core/exceptions.py b/antarest/core/exceptions.py index 079263ae91..9094d322be 100644 --- a/antarest/core/exceptions.py +++ b/antarest/core/exceptions.py @@ -1,6 +1,6 @@ import re from http import HTTPStatus -from typing import Optional +from typing import Any, Optional from fastapi.exceptions import HTTPException @@ -386,8 +386,8 @@ def __init__(self, message: str) -> None: class IncoherenceBetweenMatricesLength(HTTPException): - def __init__(self, message: str) -> None: - super().__init__(HTTPStatus.UNPROCESSABLE_ENTITY, message) + def __init__(self, detail: Any) -> None: + super().__init__(HTTPStatus.UNPROCESSABLE_ENTITY, detail) class MissingDataError(HTTPException): diff --git a/antarest/study/business/binding_constraint_management.py b/antarest/study/business/binding_constraint_management.py index 241d0cf98a..e29d4cad5e 100644 --- a/antarest/study/business/binding_constraint_management.py +++ b/antarest/study/business/binding_constraint_management.py @@ -1,7 +1,11 @@ -from typing import Any, Dict, List, Optional, Union +import collections +import itertools +import logging +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union import numpy as np -from pydantic import BaseModel, validator, root_validator +from pydantic import BaseModel, root_validator, validator +from requests.utils import CaseInsensitiveDict from antarest.core.exceptions import ( BindingConstraintNotFoundError, @@ -44,6 +48,11 @@ from antarest.study.storage.variantstudy.model.command.update_binding_constraint import UpdateBindingConstraint from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy +logger = logging.getLogger(__name__) + +DEFAULT_GROUP = "default" +"""Default group name for binding constraints if missing or empty.""" + class AreaLinkDTO(BaseModel): """ @@ -162,7 +171,7 @@ def accept(self, constraint: "BindingConstraintConfigType") -> bool: if self.comments.upper() not in comments.upper(): return False if self.group: - group = getattr(constraint, "group") or "" + group = getattr(constraint, "group", DEFAULT_GROUP) if self.group.upper() != group.upper(): return False if self.time_step is not None and self.time_step != constraint.time_step: @@ -294,6 +303,42 @@ class BindingConstraintConfig870(BindingConstraintConfig): BindingConstraintConfigType = Union[BindingConstraintConfig870, BindingConstraintConfig] +def _validate_binding_constraints(file_study: FileStudy, bcs: Sequence[BindingConstraintConfigType]) -> bool: + if int(file_study.config.version) < 870: + matrix_id_fmts = {"{bc_id}"} + else: + matrix_id_fmts = {"{bc_id}_eq", "{bc_id}_lt", "{bc_id}_gt"} + + references_by_shapes = collections.defaultdict(list) + _total = len(bcs) * len(matrix_id_fmts) + for _index, (bc, fmt) in enumerate(itertools.product(bcs, matrix_id_fmts), 1): + matrix_id = fmt.format(bc_id=bc.id) + logger.info(f"⏲ Validating BC '{bc.id}': {matrix_id=} [{_index}/{_total}]") + _obj = file_study.tree.get(url=["input", "bindingconstraints", matrix_id]) + _array = np.array(_obj["data"], dtype=float) + if _array.size == 0 or _array.shape[1] == 1: + continue + references_by_shapes[_array.shape].append((bc.id, matrix_id)) + del _obj + del _array + + if len(references_by_shapes) > 1: + most_common = collections.Counter(references_by_shapes.keys()).most_common() + invalid_constraints = collections.defaultdict(list) + for shape, _ in most_common[1:]: + references = references_by_shapes[shape] + for bc_id, matrix_id in references: + invalid_constraints[bc_id].append(f"'{matrix_id}' {shape}") + expected_shape = most_common[0][0] + detail = { + "msg": f"Matrix shapes mismatch in binding constraints group. Expected shape: {expected_shape}", + "invalid_constraints": dict(invalid_constraints), + } + raise IncoherenceBetweenMatricesLength(detail) + + return True + + class BindingConstraintManager: def __init__( self, @@ -400,20 +445,64 @@ def get_binding_constraint( # Else we return all the matching elements return list(result.values()) - def validate_binding_constraint(self, study: Study, constraint_id: str) -> None: - if int(study.version) < 870: - return # There's nothing to check for constraints before v8.7 - file_study = self.storage_service.get_storage(study).get_raw(study) + def get_binding_constraint_groups(self, study: Study) -> Mapping[str, Sequence[BindingConstraintConfigType]]: + """ + Get all binding constraints grouped by group name. + + Args: + study: the study + + Returns: + A dictionary with group names as keys and lists of binding constraints as values. + """ + storage_service = self.storage_service.get_storage(study) + file_study = storage_service.get_raw(study) config = file_study.tree.get(["input", "bindingconstraints", "bindingconstraints"]) - group = next((value["group"] for value in config.values() if value["id"] == constraint_id), None) - if not group: - raise BindingConstraintNotFoundError(study.id) - matrix_terms = { - "eq": get_matrix_data(file_study, constraint_id, "eq"), - "lt": get_matrix_data(file_study, constraint_id, "lt"), - "gt": get_matrix_data(file_study, constraint_id, "gt"), - } - check_matrices_coherence(file_study, group, constraint_id, matrix_terms) + bcs_by_group = CaseInsensitiveDict() # type: ignore + for value in config.values(): + _bc_config = BindingConstraintManager.process_constraint(value, int(study.version)) + _group = getattr(_bc_config, "group", DEFAULT_GROUP) + bcs_by_group.setdefault(_group, []).append(_bc_config) + return bcs_by_group + + def get_binding_constraint_group(self, study: Study, group_name: str) -> Sequence[BindingConstraintConfigType]: + """ + Get all binding constraints from a given group. + + Args: + study: the study. + group_name: the group name (case-insensitive). + + Returns: + A list of binding constraints from the group. + """ + groups = self.get_binding_constraint_groups(study) + if group_name not in groups: + raise BindingConstraintNotFoundError(f"Group '{group_name}' not found") + return groups[group_name] + + def validate_binding_constraint_group(self, study: Study, group_name: str) -> bool: + storage_service = self.storage_service.get_storage(study) + file_study = storage_service.get_raw(study) + bcs_by_group = self.get_binding_constraint_groups(study) + if group_name not in bcs_by_group: + raise BindingConstraintNotFoundError(f"Group '{group_name}' not found") + bcs = bcs_by_group[group_name] + return _validate_binding_constraints(file_study, bcs) + + def validate_binding_constraint_groups(self, study: Study) -> bool: + storage_service = self.storage_service.get_storage(study) + file_study = storage_service.get_raw(study) + bcs_by_group = self.get_binding_constraint_groups(study) + invalid_groups = {} + for group_name, bcs in bcs_by_group.items(): + try: + _validate_binding_constraints(file_study, bcs) + except IncoherenceBetweenMatricesLength as e: + invalid_groups[group_name] = e.detail + if invalid_groups: + raise IncoherenceBetweenMatricesLength(invalid_groups) + return True def create_binding_constraint( self, @@ -446,7 +535,7 @@ def create_binding_constraint( "comments": data.comments or "", } if version >= 870: - args["group"] = data.group or "default" + args["group"] = data.group or DEFAULT_GROUP command = CreateBindingConstraint( **args, command_context=self.storage_service.variant_study_service.command_factory.command_context @@ -678,37 +767,6 @@ def find_constraint_term_id(constraints_term: List[ConstraintTermDTO], constrain return -1 -def get_binding_constraint_of_a_given_group(file_study: FileStudy, group_id: str) -> List[str]: - config = file_study.tree.get(["input", "bindingconstraints", "bindingconstraints"]) - config_values = list(config.values()) - return [bd["id"] for bd in config_values if bd["group"] == group_id] - - -def check_matrices_coherence( - file_study: FileStudy, group_id: str, binding_constraint_id: str, matrix_terms: Dict[str, Any] -) -> None: - given_number_of_cols = set() - for term_str, term_data in matrix_terms.items(): - if term_data: - nb_cols = len(term_data[0]) - if nb_cols > 1: - given_number_of_cols.add(nb_cols) - if len(given_number_of_cols) > 1: - raise IncoherenceBetweenMatricesLength( - f"The matrices of {binding_constraint_id} must have the same number of columns, currently {given_number_of_cols}" - ) - if len(given_number_of_cols) == 1: - given_size = list(given_number_of_cols)[0] - for bd_id in get_binding_constraint_of_a_given_group(file_study, group_id): - for term in list(matrix_terms.keys()): - matrix_file = file_study.tree.get(url=["input", "bindingconstraints", f"{bd_id}_{term}"]) - column_size = len(matrix_file["data"][0]) - if column_size > 1 and column_size != given_size: - raise IncoherenceBetweenMatricesLength( - f"The matrices of the group {group_id} do not have the same number of columns" - ) - - def check_attributes_coherence( data: Union[BindingConstraintCreation, BindingConstraintEdition], study_version: int ) -> None: @@ -721,7 +779,3 @@ def check_attributes_coherence( raise InvalidFieldForVersionError("You cannot fill a 'matrix_term' as these values refer to v8.7+ studies") elif data.values: raise InvalidFieldForVersionError("You cannot fill 'values' as it refers to the matrix before v8.7") - - -def get_matrix_data(file_study: FileStudy, binding_constraint_id: str, keyword: str) -> List[Any]: - return file_study.tree.get(url=["input", "bindingconstraints", f"{binding_constraint_id}_{keyword}"])["data"] # type: ignore diff --git a/antarest/study/web/study_data_blueprint.py b/antarest/study/web/study_data_blueprint.py index cbd70f58b2..2bcb56100f 100644 --- a/antarest/study/web/study_data_blueprint.py +++ b/antarest/study/web/study_data_blueprint.py @@ -1,7 +1,7 @@ import enum import logging from http import HTTPStatus -from typing import Any, Dict, List, Optional, Sequence, Union, cast +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, cast from fastapi import APIRouter, Body, Depends, Query from starlette.responses import RedirectResponse @@ -984,35 +984,129 @@ def update_binding_constraint( return study_service.binding_constraint_manager.update_binding_constraint(study, binding_constraint_id, data) @bp.get( - "/studies/{uuid}/bindingconstraints/{binding_constraint_id}/validate", + "/studies/{uuid}/constraint-groups", tags=[APITag.study_data], - summary="Validate binding constraint configuration", + summary="Get the list of binding constraint groups", + ) + def get_binding_constraint_groups( + uuid: str, + current_user: JWTUser = Depends(auth.get_current_user), + ) -> Mapping[str, Sequence[BindingConstraintConfigType]]: + """ + Get the list of binding constraint groups for the study. + + Args: + - `uuid`: The UUID of the study. + + Returns: + - The list of binding constraints for each group. + """ + logger.info( + f"Fetching binding constraint groups for study {uuid}", + extra={"user": current_user.id}, + ) + params = RequestParameters(user=current_user) + study = study_service.check_study_access(uuid, StudyPermissionType.READ, params) + result = study_service.binding_constraint_manager.get_binding_constraint_groups(study) + return result + + @bp.get( + # We use "validate-all" because it is unlikely to conflict with a group name. + "/studies/{uuid}/constraint-groups/validate-all", + tags=[APITag.study_data], + summary="Validate all binding constraint groups", response_model=None, ) - def validate_binding_constraint( + def validate_binding_constraint_groups( uuid: str, - binding_constraint_id: str, current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: + ) -> bool: """ - Validates the binding constraint configuration. + Checks if the dimensions of the right-hand side matrices are consistent with + the dimensions of the binding constraint matrices within the same group. - Parameters: + Args: - `uuid`: The study UUID. - - `binding_constraint_id`: The binding constraint id to validate - For studies with versions prior to v8.7, no validation is performed. - For studies with version 8.7 or later, the endpoint checks if the dimensions - of the right-hand side matrices are consistent with the dimensions of the - binding constraint matrices within the same group. + Returns: + - `true` if all groups are valid. + + Raises: + - HTTPException(422) if any group is invalid. """ logger.info( - f"Validating binding constraint {binding_constraint_id} for study {uuid}", + f"Validating all binding constraint groups for study {uuid}", extra={"user": current_user.id}, ) params = RequestParameters(user=current_user) study = study_service.check_study_access(uuid, StudyPermissionType.READ, params) - return study_service.binding_constraint_manager.validate_binding_constraint(study, binding_constraint_id) + return study_service.binding_constraint_manager.validate_binding_constraint_groups(study) + + @bp.get( + "/studies/{uuid}/constraint-groups/{group}", + tags=[APITag.study_data], + summary="Get the binding constraint group", + ) + def get_binding_constraint_group( + uuid: str, + group: str, + current_user: JWTUser = Depends(auth.get_current_user), + ) -> Sequence[BindingConstraintConfigType]: + """ + Get the binding constraint group for the study. + + Args: + - `uuid`: The UUID of the study. + - `group`: The name of the binding constraint group (case-insensitive). + + Returns: + - The list of binding constraints in the group. + + Raises: + - HTTPException(404) if the group does not exist. + """ + logger.info( + f"Fetching binding constraint group '{group}' for study {uuid}", + extra={"user": current_user.id}, + ) + params = RequestParameters(user=current_user) + study = study_service.check_study_access(uuid, StudyPermissionType.READ, params) + result = study_service.binding_constraint_manager.get_binding_constraint_group(study, group) + return result + + @bp.get( + "/studies/{uuid}/constraint-groups/{group}/validate", + tags=[APITag.study_data], + summary="Validate the binding constraint group", + response_model=None, + ) + def validate_binding_constraint_group( + uuid: str, + group: str, + current_user: JWTUser = Depends(auth.get_current_user), + ) -> bool: + """ + Checks if the dimensions of the right-hand side matrices are consistent with + the dimensions of the binding constraint matrices within the same group. + + Args: + - `uuid`: The study UUID. + - `group`: The name of the binding constraint group (case-insensitive). + + Returns: + - `true` if the group is valid. + + Raises: + - HTTPException(404) if the group does not exist. + - HTTPException(422) if the group is invalid. + """ + logger.info( + f"Validating binding constraint group '{group}' for study {uuid}", + extra={"user": current_user.id}, + ) + params = RequestParameters(user=current_user) + study = study_service.check_study_access(uuid, StudyPermissionType.READ, params) + return study_service.binding_constraint_manager.validate_binding_constraint_group(study, group) @bp.post("/studies/{uuid}/bindingconstraints", tags=[APITag.study_data], summary="Create a binding constraint") def create_binding_constraint( @@ -1115,7 +1209,7 @@ def get_allocation_matrix( """ Get the hydraulic allocation matrix for all areas. - Parameters: + Args: - `uuid`: The study UUID. Returns the data frame matrix, where: @@ -1145,7 +1239,7 @@ def get_allocation_form_fields( """ Get the form fields used for the allocation form. - Parameters: + Args: - `uuid`: The study UUID, - `area_id`: the area ID. @@ -1183,7 +1277,7 @@ def set_allocation_form_fields( """ Update the hydraulic allocation of a given area. - Parameters: + Args: - `uuid`: The study UUID, - `area_id`: the area ID. @@ -1227,7 +1321,7 @@ def get_correlation_matrix( """ Get the hydraulic/load/solar/wind correlation matrix of a study. - Parameters: + Args: - `uuid`: The UUID of the study. - `columns`: a filter on the area identifiers: - Use no parameter to select all areas. @@ -1279,7 +1373,7 @@ def set_correlation_matrix( """ Set the hydraulic/load/solar/wind correlation matrix of a study. - Parameters: + Args: - `uuid`: The UUID of the study. - `index`: a list of all study areas. - `columns`: a list of selected production areas. @@ -1310,7 +1404,7 @@ def get_correlation_form_fields( """ Get the form fields used for the correlation form. - Parameters: + Args: - `uuid`: The UUID of the study. - `area_id`: the area ID. @@ -1349,7 +1443,7 @@ def set_correlation_form_fields( """ Update the hydraulic/load/solar/wind correlation of a given area. - Parameters: + Args: - `uuid`: The UUID of the study. - `area_id`: the area ID. diff --git a/tests/integration/study_data_blueprint/test_binding_constraints.py b/tests/integration/study_data_blueprint/test_binding_constraints.py index f1c9b80ad1..5950a8b121 100644 --- a/tests/integration/study_data_blueprint/test_binding_constraints.py +++ b/tests/integration/study_data_blueprint/test_binding_constraints.py @@ -258,8 +258,8 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st bc_id = binding_constraints_list[0]["id"] - # Asserts binding constraint configuration is always valid. - res = client.get(f"/v1/studies/{study_id}/bindingconstraints/{bc_id}/validate", headers=user_headers) + # Asserts binding constraint configuration is valid. + res = client.get(f"/v1/studies/{study_id}/constraint-groups", headers=user_headers) assert res.status_code == 200, res.json() # ============================= @@ -834,11 +834,15 @@ def test_for_version_870(self, client: TestClient, admin_access_token: str, stud assert res.status_code in {200, 201}, res.json() second_bc_id = res.json()["id"] - # todo: validate the BC group "Group 1" - # res = client.get(f"/v1/studies/{study_id}/bindingconstraints/Group 1/validate", headers=admin_headers) - # assert res.status_code == 422 - # assert res.json()["exception"] == "IncoherenceBetweenMatricesLength" - # assert res.json()["description"] == "Mismatched column count in 'Group 1'". + # validate the BC group "Group 1" + res = client.get(f"/v1/studies/{study_id}/constraint-groups/Group 1/validate", headers=admin_headers) + assert res.status_code == 422 + assert res.json()["exception"] == "IncoherenceBetweenMatricesLength" + description = res.json()["description"] + assert description == { + "invalid_constraints": {"second bc": ["'second bc_gt' (8784, 4)"]}, + "msg": "Matrix shapes mismatch in binding constraints group. Expected shape: (8784, 3)", + } # So, we correct the shape of the matrix of the Second BC res = client.put( @@ -877,11 +881,15 @@ def test_for_version_870(self, client: TestClient, admin_access_token: str, stud # This should succeed but cause the validation endpoint to fail. assert res.status_code in {200, 201}, res.json() - # todo: validate the BC group "Group 1" - # res = client.get(f"/v1/studies/{study_id}/bindingconstraints/Group 1/validate", headers=admin_headers) - # assert res.status_code == 422 - # assert res.json()["exception"] == "IncoherenceBetweenMatricesLength" - # assert res.json()["description"] == "Mismatched column count in 'Group 1'". + # validate the BC group "Group 1" + res = client.get(f"/v1/studies/{study_id}/constraint-groups/Group 1/validate", headers=admin_headers) + assert res.status_code == 422 + assert res.json()["exception"] == "IncoherenceBetweenMatricesLength" + description = res.json()["description"] + assert description == { + "invalid_constraints": {"third bc": ["'third bc_lt' (8784, 4)"]}, + "msg": "Matrix shapes mismatch in binding constraints group. Expected shape: (8784, 3)", + } # So, we correct the shape of the matrix of the Second BC res = client.put( @@ -904,22 +912,47 @@ def test_for_version_870(self, client: TestClient, admin_access_token: str, stud ) assert res.status_code in {200, 201}, res.json() - # todo: validate the "Group 2" - # # For the moment the bc is valid - # res = client.get(f"/v1/studies/{study_id}/bindingconstraints/Group 2/validate", headers=admin_headers) - # assert res.status_code in {200, 201}, res.json() + # validate the "Group 2": for the moment the BC is valid + res = client.get(f"/v1/studies/{study_id}/constraint-groups/Group 2/validate", headers=admin_headers) + assert res.status_code in {200, 201}, res.json() res = client.put( f"v1/studies/{study_id}/bindingconstraints/{second_bc_id}", - json={"greater_term_matrix": matrix_lt3.tolist()}, + json={"greater_term_matrix": matrix_gt4.tolist()}, headers=admin_headers, ) # This should succeed but cause the validation endpoint to fail. assert res.status_code in {200, 201}, res.json() - # For the moment the "Group 2" is valid - # todo: validate the "Group 2" - # res = client.get(f"/v1/studies/{study_id}/bindingconstraints/Group 2/validate", headers=admin_headers) - # assert res.status_code == 422 - # assert res.json()["exception"] == "IncoherenceBetweenMatricesLength" - # assert res.json()["description"] == "Mismatched column count in 'Group 2'". + # Collect all the binding constraints groups + res = client.get(f"/v1/studies/{study_id}/constraint-groups", headers=admin_headers) + assert res.status_code in {200, 201}, res.json() + groups = res.json() + assert set(groups) == {"default", "random_grp", "Group 1", "Group 2"} + assert groups["Group 2"] == [ + { + "comments": "New API", + "constraints": None, + "enabled": True, + "filter_synthesis": "", + "filter_year_by_year": "", + "group": "Group 2", + "id": "second bc", + "name": "Second BC", + "operator": "less", + "time_step": "hourly", + } + ] + + # Validate all binding constraints groups + res = client.get(f"/v1/studies/{study_id}/constraint-groups/validate-all", headers=admin_headers) + assert res.status_code == 422, res.json() + exception = res.json()["exception"] + description = res.json()["description"] + assert exception == "IncoherenceBetweenMatricesLength" + assert description == { + "Group 1": { + "msg": "Matrix shapes mismatch in binding constraints group. Expected shape: (8784, 3)", + "invalid_constraints": {"third bc": ["'third bc_lt' (8784, 4)"]}, + } + }