From 59f0c6680223b272d58458fbd1f8436105de120c Mon Sep 17 00:00:00 2001 From: belthlemar Date: Wed, 24 Jan 2024 09:52:38 +0100 Subject: [PATCH] refactor code --- antarest/core/exceptions.py | 2 +- .../business/binding_constraint_management.py | 170 +++++++++--------- .../command/create_binding_constraint.py | 25 +-- .../command/update_binding_constraint.py | 2 +- .../test_binding_constraints.py | 36 ++-- 5 files changed, 115 insertions(+), 120 deletions(-) diff --git a/antarest/core/exceptions.py b/antarest/core/exceptions.py index 06a7c55a16..292d1c27cd 100644 --- a/antarest/core/exceptions.py +++ b/antarest/core/exceptions.py @@ -201,7 +201,7 @@ def __init__(self, message: str) -> None: class InvalidFieldForVersionError(HTTPException): def __init__(self, message: str) -> None: - super().__init__(HTTPStatus.BAD_REQUEST, message) + super().__init__(HTTPStatus.UNPROCESSABLE_ENTITY, message) class MissingDataError(HTTPException): diff --git a/antarest/study/business/binding_constraint_management.py b/antarest/study/business/binding_constraint_management.py index a4b89271ad..90f5a2f01d 100644 --- a/antarest/study/business/binding_constraint_management.py +++ b/antarest/study/business/binding_constraint_management.py @@ -233,7 +233,7 @@ def create_binding_constraint( # Validates the matrices. Needed when the study is a variant because we only append the command to the list if isinstance(study, VariantStudy): - command.validates_and_fills_matrices(version, True) + command.validates_and_fills_matrices(version=version, create=True) file_study = self.storage_service.get_storage(study).get_raw(study) execute_or_add_commands(study, file_study, [command], self.storage_service) @@ -264,19 +264,17 @@ def update_binding_constraint( } study_version = int(study.version) - args = BindingConstraintManager.fill_group_value(data, constraint, study_version, args) - args = BindingConstraintManager.fill_matrices_according_to_version(data, study_version, args) + args = _fill_group_value(data, constraint, study_version, args) + args = _fill_matrices_according_to_version(data, study_version, args) if data.key == "time_step" and data.value != constraint.time_step: # The user changed the time step, we need to update the matrix accordingly - args = BindingConstraintManager.replace_matrices_according_to_frequency_and_version( - data, study_version, args - ) + args = _replace_matrices_according_to_frequency_and_version(data, study_version, args) command = UpdateBindingConstraint(**args) # Validates the matrices. Needed when the study is a variant because we only append the command to the list if isinstance(study, VariantStudy): - command.validates_and_fills_matrices(study_version, False) + command.validates_and_fills_matrices(version=study_version, create=False) execute_or_add_commands(study, file_study, [command], self.storage_service) @@ -293,80 +291,6 @@ def remove_binding_constraint(self, study: Study, binding_constraint_id: str) -> execute_or_add_commands(study, file_study, [command], self.storage_service) - @staticmethod - def fill_group_value( - data: UpdateBindingConstProps, constraint: BindingConstraintConfigType, version: int, args: Dict[str, Any] - ) -> Dict[str, Any]: - if version < 870: - if data.key == "group": - raise InvalidFieldForVersionError( - f"You cannot specify a group as your study version is older than v8.7: {data.value}" - ) - else: - # cast to 870 to use the attribute group - constraint = cast(BindingConstraintConfig870, constraint) - args["group"] = data.value if data.key == "group" else constraint.group - return args - - @staticmethod - def fill_matrices_according_to_version( - data: UpdateBindingConstProps, version: int, args: Dict[str, Any] - ) -> Dict[str, Any]: - if data.key == "values": - if version >= 870: - raise InvalidFieldForVersionError("You cannot fill 'values' as it refers to the matrix before v8.7") - args["values"] = data.value - return args - for matrix in ["less_term_matrix", "equal_term_matrix", "greater_term_matrix"]: - if data.key == matrix: - if version < 870: - raise InvalidFieldForVersionError( - "You cannot fill a 'matrix_term' as these values refer to v8.7+ studies" - ) - args[matrix] = data.value - return args - return args - - @staticmethod - def replace_matrices_according_to_frequency_and_version( - data: UpdateBindingConstProps, version: int, args: Dict[str, Any] - ) -> Dict[str, Any]: - if version < 870: - matrix = { - BindingConstraintFrequency.HOURLY.value: default_bc_hourly_86, - BindingConstraintFrequency.DAILY.value: default_bc_weekly_daily_86, - BindingConstraintFrequency.WEEKLY.value: default_bc_weekly_daily_86, - }[data.value].tolist() - args["values"] = matrix - else: - matrix = { - BindingConstraintFrequency.HOURLY.value: default_bc_hourly_87, - BindingConstraintFrequency.DAILY.value: default_bc_weekly_daily_87, - BindingConstraintFrequency.WEEKLY.value: default_bc_weekly_daily_87, - }[data.value].tolist() - args["less_term_matrix"] = matrix - args["equal_term_matrix"] = matrix - args["greater_term_matrix"] = matrix - return args - - @staticmethod - def find_constraint_term_id(constraints_term: List[ConstraintTermDTO], constraint_term_id: str) -> int: - try: - index = [elm.id for elm in constraints_term].index(constraint_term_id) - return index - except ValueError: - return -1 - - @staticmethod - def get_constraint_id(data: Union[LinkInfoDTO, ClusterInfoDTO]) -> str: - if isinstance(data, ClusterInfoDTO): - constraint_id = f"{data.area}.{data.cluster}" - else: - area1 = data.area1 if data.area1 < data.area2 else data.area2 - area2 = data.area2 if area1 == data.area1 else data.area1 - constraint_id = f"{area1}%{area2}" - return constraint_id - def add_new_constraint_term( self, study: Study, @@ -383,9 +307,9 @@ def add_new_constraint_term( if constraint_term.data is None: raise MissingDataError("Add new constraint term : data is missing") - constraint_id = BindingConstraintManager.get_constraint_id(constraint_term.data) + constraint_id = _get_constraint_id(constraint_term.data) constraints_term = constraint.constraints or [] - if BindingConstraintManager.find_constraint_term_id(constraints_term, constraint_id) >= 0: + if _find_constraint_term_id(constraints_term, constraint_id) >= 0: raise ConstraintAlreadyExistError(study.id) constraints_term.append( @@ -436,12 +360,12 @@ def update_constraint_term( if data_id is None: raise ConstraintIdNotFoundError(study.id) - data_term_index = BindingConstraintManager.find_constraint_term_id(constraints, data_id) + data_term_index = _find_constraint_term_id(constraints, data_id) if data_term_index < 0: raise ConstraintIdNotFoundError(study.id) if isinstance(data, ConstraintTermDTO): - constraint_id = BindingConstraintManager.get_constraint_id(data.data) if data.data is not None else data_id + constraint_id = _get_constraint_id(data.data) if data.data is not None else data_id current_constraint = constraints[data_term_index] constraints.append( ConstraintTermDTO( @@ -451,9 +375,7 @@ def update_constraint_term( data=data.data if data.data is not None else current_constraint.data, ) ) - del constraints[data_term_index] - else: - del constraints[data_term_index] + del constraints[data_term_index] coeffs = {} for term in constraints: @@ -481,3 +403,75 @@ def remove_constraint_term( term_id: str, ) -> None: return self.update_constraint_term(study, binding_constraint_id, term_id) + + +def _fill_group_value( + data: UpdateBindingConstProps, constraint: BindingConstraintConfigType, version: int, args: Dict[str, Any] +) -> Dict[str, Any]: + if version < 870: + if data.key == "group": + raise InvalidFieldForVersionError( + f"You cannot specify a group as your study version is older than v8.7: {data.value}" + ) + else: + # cast to 870 to use the attribute group + constraint = cast(BindingConstraintConfig870, constraint) + args["group"] = data.value if data.key == "group" else constraint.group + return args + + +def _fill_matrices_according_to_version( + data: UpdateBindingConstProps, version: int, args: Dict[str, Any] +) -> Dict[str, Any]: + if data.key == "values": + if version >= 870: + raise InvalidFieldForVersionError("You cannot fill 'values' as it refers to the matrix before v8.7") + args["values"] = data.value + return args + for matrix in ["less_term_matrix", "equal_term_matrix", "greater_term_matrix"]: + if data.key == matrix: + if version < 870: + raise InvalidFieldForVersionError( + "You cannot fill a 'matrix_term' as these values refer to v8.7+ studies" + ) + args[matrix] = data.value + return args + return args + + +def _replace_matrices_according_to_frequency_and_version( + data: UpdateBindingConstProps, version: int, args: Dict[str, Any] +) -> Dict[str, Any]: + if version < 870: + matrix = { + BindingConstraintFrequency.HOURLY.value: default_bc_hourly_86, + BindingConstraintFrequency.DAILY.value: default_bc_weekly_daily_86, + BindingConstraintFrequency.WEEKLY.value: default_bc_weekly_daily_86, + }[data.value].tolist() + args["values"] = matrix + else: + matrix = { + BindingConstraintFrequency.HOURLY.value: default_bc_hourly_87, + BindingConstraintFrequency.DAILY.value: default_bc_weekly_daily_87, + BindingConstraintFrequency.WEEKLY.value: default_bc_weekly_daily_87, + }[data.value].tolist() + args["less_term_matrix"] = matrix + args["equal_term_matrix"] = matrix + args["greater_term_matrix"] = matrix + return args + + +def _find_constraint_term_id(constraints_term: List[ConstraintTermDTO], constraint_term_id: str) -> int: + try: + index = [elm.id for elm in constraints_term].index(constraint_term_id) + return index + except ValueError: + return -1 + + +def _get_constraint_id(data: Union[LinkInfoDTO, ClusterInfoDTO]) -> str: + if isinstance(data, ClusterInfoDTO): + return f"{data.area}.{data.cluster}" + area1 = data.area1 if data.area1 < data.area2 else data.area2 + area2 = data.area2 if area1 == data.area1 else data.area1 + return f"{area1}%{area2}" diff --git a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py index 92eb6c5f1f..52b02155ef 100644 --- a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast import numpy as np -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel, Extra, Field, root_validator from antarest.matrixstore.model import MatrixData from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency @@ -92,18 +92,17 @@ class BindingConstraintMatrices(BaseModel, extra=Extra.forbid): ) equal_term_matrix: Optional[Union[MatrixType, str]] = Field(None, description="equal term matrix for v8.7+ studies") - def __init__(self, **data: Any) -> None: - super().__init__(**data) - if data.get("values") is not None: + @root_validator() + def check_matrices( + cls, values: Dict[str, Optional[Union[MatrixType, str]]] + ) -> Dict[str, Optional[Union[MatrixType, str]]]: + if values["values"]: for term in ["less_term_matrix", "greater_term_matrix", "equal_term_matrix"]: - if data.get(term) is not None: + if values[term]: raise ValueError( f"You cannot fill 'values' (matrix before v8.7) and a matrix term: {term} (matrices since v8.7)" ) - self.values = data.get("values") - self.less_term_matrix = data.get("less_term_matrix") - self.greater_term_matrix = data.get("greater_term_matrix") - self.equal_term_matrix = data.get("equal_term_matrix") + return values class AbstractBindingConstraintCommand( @@ -145,7 +144,9 @@ def get_inner_matrices(self) -> List[str]: if matrix is not None ] - def get_corresponding_matrices(self, v: Optional[Union[MatrixType, str]], old: bool, create: bool) -> Optional[str]: + def get_corresponding_matrices( + self, v: Optional[Union[MatrixType, str]], *, old: bool, create: bool + ) -> Optional[str]: constants: GeneratorMatrixConstants constants = self.command_context.generator_matrix_constants time_step = self.time_step @@ -178,7 +179,7 @@ def get_corresponding_matrices(self, v: Optional[Union[MatrixType, str]], old: b # pragma: no cover raise TypeError(repr(v)) - def validates_and_fills_matrices(self, version: int, create: bool) -> None: + def validates_and_fills_matrices(self, *, version: int, create: bool) -> None: if version < 870: self.values = self.get_corresponding_matrices(self.values, old=True, create=create) else: @@ -209,7 +210,7 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: binding_constraints = study_data.tree.get(["input", "bindingconstraints", "bindingconstraints"]) new_key = len(binding_constraints) bd_id = transform_name_to_id(self.name) - self.validates_and_fills_matrices(study_data.config.version, True) + self.validates_and_fills_matrices(version=study_data.config.version, create=True) return apply_binding_constraint( study_data, diff --git a/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py index 985734ec7f..6ac2210f57 100644 --- a/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py @@ -53,7 +53,7 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: message="Failed to retrieve existing binding constraint", ) - self.validates_and_fills_matrices(study_data.config.version, False) + self.validates_and_fills_matrices(version=study_data.config.version, create=False) return apply_binding_constraint( study_data, diff --git a/tests/integration/study_data_blueprint/test_binding_constraints.py b/tests/integration/study_data_blueprint/test_binding_constraints.py index 6d9db10ccc..a37afc365e 100644 --- a/tests/integration/study_data_blueprint/test_binding_constraints.py +++ b/tests/integration/study_data_blueprint/test_binding_constraints.py @@ -38,7 +38,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st "metadata": {"country": "FR"}, }, ) - assert res.status_code == 200, res.json() + assert res.status_code in {200, 201}, res.json() res = client.post( f"/v1/studies/{study_id}/areas", @@ -49,7 +49,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st "metadata": {"country": "DE"}, }, ) - assert res.status_code == 200, res.json() + assert res.status_code in {200, 201}, res.json() res = client.post( f"/v1/studies/{study_id}/links", @@ -59,7 +59,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st "area2": area2_name, }, ) - assert res.status_code == 200, res.json() + assert res.status_code in {200, 201}, res.json() if study_type == "variant": # Create Variant @@ -68,7 +68,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st headers=user_headers, params={"name": "Variant 1"}, ) - assert res.status_code == 200, res.json() + assert res.status_code in {200, 201}, res.json() study_id = res.json() # ============================= @@ -93,7 +93,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st ], headers=user_headers, ) - assert res.status_code == 200, res.json() + assert res.status_code in {200, 201}, res.json() res = client.post( f"/v1/studies/{study_id}/commands", @@ -112,7 +112,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st ], headers=user_headers, ) - assert res.status_code == 200, res.json() + assert res.status_code in {200, 201}, res.json() # Creates a binding constraint with the new API res = client.post( @@ -127,7 +127,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st }, headers=user_headers, ) - assert res.status_code == 200, res.json() + assert res.status_code in {200, 201}, res.json() # Get Binding Constraint list res = client.get(f"/v1/studies/{study_id}/bindingconstraints", headers=user_headers) @@ -180,7 +180,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st }, headers=user_headers, ) - assert res.status_code == 200, res.json() + assert res.status_code in {200, 201}, res.json() # Get Binding Constraint res = client.get( @@ -295,7 +295,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st }, headers=user_headers, ) - assert res.status_code == 200, res.json() + assert res.status_code in {200, 201}, res.json() # Check that the matrix is a daily/weekly matrix res = client.get( @@ -402,7 +402,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st }, headers=user_headers, ) - assert res.status_code == 400 + assert res.status_code == 422 assert res.json()["description"] == "You cannot fill a 'matrix_term' as these values refer to v8.7+ studies" # Wrong matrix shape @@ -440,7 +440,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st json={"key": "group", "value": grp_name}, headers=user_headers, ) - assert res.status_code == 400 + assert res.status_code == 422 assert res.json()["exception"] == "InvalidFieldForVersionError" assert ( res.json()["description"] @@ -453,7 +453,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st json={"key": "less_term_matrix", "value": [[]]}, headers=user_headers, ) - assert res.status_code == 400 + assert res.status_code == 422 assert res.json()["exception"] == "InvalidFieldForVersionError" assert res.json()["description"] == "You cannot fill a 'matrix_term' as these values refer to v8.7+ studies" @@ -493,7 +493,7 @@ def test_for_version_870(self, client: TestClient, admin_access_token: str, stud headers=admin_headers, params={"name": "Variant 1"}, ) - assert res.status_code == 200 + assert res.status_code in {200, 201} study_id = res.json() # ============================= @@ -508,7 +508,7 @@ def test_for_version_870(self, client: TestClient, admin_access_token: str, stud json={"name": bc_id_wo_group, **args}, headers=admin_headers, ) - assert res.status_code == 200, res.json() + assert res.status_code in {200, 201}, res.json() res = client.get(f"/v1/studies/{study_id}/bindingconstraints/{bc_id_wo_group}", headers=admin_headers) assert res.json()["group"] == "default" @@ -520,7 +520,7 @@ def test_for_version_870(self, client: TestClient, admin_access_token: str, stud json={"name": bc_id_w_group, "group": "specific_grp", **args}, headers=admin_headers, ) - assert res.status_code == 200, res.json() + assert res.status_code in {200, 201}, res.json() res = client.get(f"/v1/studies/{study_id}/bindingconstraints/{bc_id_w_group}", headers=admin_headers) assert res.json()["group"] == "specific_grp" @@ -534,7 +534,7 @@ def test_for_version_870(self, client: TestClient, admin_access_token: str, stud json={"name": bc_id_w_matrix, "less_term_matrix": matrix_to_list, **args}, headers=admin_headers, ) - assert res.status_code == 200, res.json() + assert res.status_code in {200, 201}, res.json() if study_type == "variant": res = client.get(f"/v1/studies/{study_id}/commands", headers=admin_headers) @@ -653,7 +653,7 @@ def test_for_version_870(self, client: TestClient, admin_access_token: str, stud }, headers=admin_headers, ) - assert res.status_code == 400 + assert res.status_code == 422 assert res.json()["description"] == "You cannot fill 'values' as it refers to the matrix before v8.7" # Update with old matrices @@ -662,6 +662,6 @@ def test_for_version_870(self, client: TestClient, admin_access_token: str, stud json={"key": "values", "value": [[]]}, headers=admin_headers, ) - assert res.status_code == 400 + assert res.status_code == 422 assert res.json()["exception"] == "InvalidFieldForVersionError" assert res.json()["description"] == "You cannot fill 'values' as it refers to the matrix before v8.7"