Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle committed Jan 24, 2024
1 parent bcf072d commit 59f0c66
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 120 deletions.
2 changes: 1 addition & 1 deletion antarest/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __init__(self, message: str) -> None:

class InvalidFieldForVersionError(HTTPException):
def __init__(self, message: str) -> None:
super().__init__(HTTPStatus.BAD_REQUEST, message)
super().__init__(HTTPStatus.UNPROCESSABLE_ENTITY, message)


class MissingDataError(HTTPException):
Expand Down
170 changes: 82 additions & 88 deletions antarest/study/business/binding_constraint_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def create_binding_constraint(

# Validates the matrices. Needed when the study is a variant because we only append the command to the list
if isinstance(study, VariantStudy):
command.validates_and_fills_matrices(version, True)
command.validates_and_fills_matrices(version=version, create=True)

file_study = self.storage_service.get_storage(study).get_raw(study)
execute_or_add_commands(study, file_study, [command], self.storage_service)
Expand Down Expand Up @@ -264,19 +264,17 @@ def update_binding_constraint(
}

study_version = int(study.version)
args = BindingConstraintManager.fill_group_value(data, constraint, study_version, args)
args = BindingConstraintManager.fill_matrices_according_to_version(data, study_version, args)
args = _fill_group_value(data, constraint, study_version, args)
args = _fill_matrices_according_to_version(data, study_version, args)

if data.key == "time_step" and data.value != constraint.time_step:
# The user changed the time step, we need to update the matrix accordingly
args = BindingConstraintManager.replace_matrices_according_to_frequency_and_version(
data, study_version, args
)
args = _replace_matrices_according_to_frequency_and_version(data, study_version, args)

command = UpdateBindingConstraint(**args)
# Validates the matrices. Needed when the study is a variant because we only append the command to the list
if isinstance(study, VariantStudy):
command.validates_and_fills_matrices(study_version, False)
command.validates_and_fills_matrices(version=study_version, create=False)

execute_or_add_commands(study, file_study, [command], self.storage_service)

Expand All @@ -293,80 +291,6 @@ def remove_binding_constraint(self, study: Study, binding_constraint_id: str) ->

execute_or_add_commands(study, file_study, [command], self.storage_service)

@staticmethod
def fill_group_value(
data: UpdateBindingConstProps, constraint: BindingConstraintConfigType, version: int, args: Dict[str, Any]
) -> Dict[str, Any]:
if version < 870:
if data.key == "group":
raise InvalidFieldForVersionError(
f"You cannot specify a group as your study version is older than v8.7: {data.value}"
)
else:
# cast to 870 to use the attribute group
constraint = cast(BindingConstraintConfig870, constraint)
args["group"] = data.value if data.key == "group" else constraint.group
return args

@staticmethod
def fill_matrices_according_to_version(
data: UpdateBindingConstProps, version: int, args: Dict[str, Any]
) -> Dict[str, Any]:
if data.key == "values":
if version >= 870:
raise InvalidFieldForVersionError("You cannot fill 'values' as it refers to the matrix before v8.7")
args["values"] = data.value
return args
for matrix in ["less_term_matrix", "equal_term_matrix", "greater_term_matrix"]:
if data.key == matrix:
if version < 870:
raise InvalidFieldForVersionError(
"You cannot fill a 'matrix_term' as these values refer to v8.7+ studies"
)
args[matrix] = data.value
return args
return args

@staticmethod
def replace_matrices_according_to_frequency_and_version(
data: UpdateBindingConstProps, version: int, args: Dict[str, Any]
) -> Dict[str, Any]:
if version < 870:
matrix = {
BindingConstraintFrequency.HOURLY.value: default_bc_hourly_86,
BindingConstraintFrequency.DAILY.value: default_bc_weekly_daily_86,
BindingConstraintFrequency.WEEKLY.value: default_bc_weekly_daily_86,
}[data.value].tolist()
args["values"] = matrix
else:
matrix = {
BindingConstraintFrequency.HOURLY.value: default_bc_hourly_87,
BindingConstraintFrequency.DAILY.value: default_bc_weekly_daily_87,
BindingConstraintFrequency.WEEKLY.value: default_bc_weekly_daily_87,
}[data.value].tolist()
args["less_term_matrix"] = matrix
args["equal_term_matrix"] = matrix
args["greater_term_matrix"] = matrix
return args

@staticmethod
def find_constraint_term_id(constraints_term: List[ConstraintTermDTO], constraint_term_id: str) -> int:
try:
index = [elm.id for elm in constraints_term].index(constraint_term_id)
return index
except ValueError:
return -1

@staticmethod
def get_constraint_id(data: Union[LinkInfoDTO, ClusterInfoDTO]) -> str:
if isinstance(data, ClusterInfoDTO):
constraint_id = f"{data.area}.{data.cluster}"
else:
area1 = data.area1 if data.area1 < data.area2 else data.area2
area2 = data.area2 if area1 == data.area1 else data.area1
constraint_id = f"{area1}%{area2}"
return constraint_id

def add_new_constraint_term(
self,
study: Study,
Expand All @@ -383,9 +307,9 @@ def add_new_constraint_term(
if constraint_term.data is None:
raise MissingDataError("Add new constraint term : data is missing")

constraint_id = BindingConstraintManager.get_constraint_id(constraint_term.data)
constraint_id = _get_constraint_id(constraint_term.data)
constraints_term = constraint.constraints or []
if BindingConstraintManager.find_constraint_term_id(constraints_term, constraint_id) >= 0:
if _find_constraint_term_id(constraints_term, constraint_id) >= 0:
raise ConstraintAlreadyExistError(study.id)

constraints_term.append(
Expand Down Expand Up @@ -436,12 +360,12 @@ def update_constraint_term(
if data_id is None:
raise ConstraintIdNotFoundError(study.id)

data_term_index = BindingConstraintManager.find_constraint_term_id(constraints, data_id)
data_term_index = _find_constraint_term_id(constraints, data_id)
if data_term_index < 0:
raise ConstraintIdNotFoundError(study.id)

if isinstance(data, ConstraintTermDTO):
constraint_id = BindingConstraintManager.get_constraint_id(data.data) if data.data is not None else data_id
constraint_id = _get_constraint_id(data.data) if data.data is not None else data_id
current_constraint = constraints[data_term_index]
constraints.append(
ConstraintTermDTO(
Expand All @@ -451,9 +375,7 @@ def update_constraint_term(
data=data.data if data.data is not None else current_constraint.data,
)
)
del constraints[data_term_index]
else:
del constraints[data_term_index]
del constraints[data_term_index]

coeffs = {}
for term in constraints:
Expand Down Expand Up @@ -481,3 +403,75 @@ def remove_constraint_term(
term_id: str,
) -> None:
return self.update_constraint_term(study, binding_constraint_id, term_id)


def _fill_group_value(
data: UpdateBindingConstProps, constraint: BindingConstraintConfigType, version: int, args: Dict[str, Any]
) -> Dict[str, Any]:
if version < 870:
if data.key == "group":
raise InvalidFieldForVersionError(
f"You cannot specify a group as your study version is older than v8.7: {data.value}"
)
else:
# cast to 870 to use the attribute group
constraint = cast(BindingConstraintConfig870, constraint)
args["group"] = data.value if data.key == "group" else constraint.group
return args


def _fill_matrices_according_to_version(
data: UpdateBindingConstProps, version: int, args: Dict[str, Any]
) -> Dict[str, Any]:
if data.key == "values":
if version >= 870:
raise InvalidFieldForVersionError("You cannot fill 'values' as it refers to the matrix before v8.7")
args["values"] = data.value
return args
for matrix in ["less_term_matrix", "equal_term_matrix", "greater_term_matrix"]:
if data.key == matrix:
if version < 870:
raise InvalidFieldForVersionError(
"You cannot fill a 'matrix_term' as these values refer to v8.7+ studies"
)
args[matrix] = data.value
return args
return args


def _replace_matrices_according_to_frequency_and_version(
data: UpdateBindingConstProps, version: int, args: Dict[str, Any]
) -> Dict[str, Any]:
if version < 870:
matrix = {
BindingConstraintFrequency.HOURLY.value: default_bc_hourly_86,
BindingConstraintFrequency.DAILY.value: default_bc_weekly_daily_86,
BindingConstraintFrequency.WEEKLY.value: default_bc_weekly_daily_86,
}[data.value].tolist()
args["values"] = matrix
else:
matrix = {
BindingConstraintFrequency.HOURLY.value: default_bc_hourly_87,
BindingConstraintFrequency.DAILY.value: default_bc_weekly_daily_87,
BindingConstraintFrequency.WEEKLY.value: default_bc_weekly_daily_87,
}[data.value].tolist()
args["less_term_matrix"] = matrix
args["equal_term_matrix"] = matrix
args["greater_term_matrix"] = matrix
return args


def _find_constraint_term_id(constraints_term: List[ConstraintTermDTO], constraint_term_id: str) -> int:
try:
index = [elm.id for elm in constraints_term].index(constraint_term_id)
return index
except ValueError:
return -1


def _get_constraint_id(data: Union[LinkInfoDTO, ClusterInfoDTO]) -> str:
if isinstance(data, ClusterInfoDTO):
return f"{data.area}.{data.cluster}"
area1 = data.area1 if data.area1 < data.area2 else data.area2
area2 = data.area2 if area1 == data.area1 else data.area1
return f"{area1}%{area2}"
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import numpy as np
from pydantic import BaseModel, Extra, Field
from pydantic import BaseModel, Extra, Field, root_validator

from antarest.matrixstore.model import MatrixData
from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
Expand Down Expand Up @@ -92,18 +92,17 @@ class BindingConstraintMatrices(BaseModel, extra=Extra.forbid):
)
equal_term_matrix: Optional[Union[MatrixType, str]] = Field(None, description="equal term matrix for v8.7+ studies")

def __init__(self, **data: Any) -> None:
super().__init__(**data)
if data.get("values") is not None:
@root_validator()
def check_matrices(
cls, values: Dict[str, Optional[Union[MatrixType, str]]]
) -> Dict[str, Optional[Union[MatrixType, str]]]:
if values["values"]:
for term in ["less_term_matrix", "greater_term_matrix", "equal_term_matrix"]:
if data.get(term) is not None:
if values[term]:
raise ValueError(
f"You cannot fill 'values' (matrix before v8.7) and a matrix term: {term} (matrices since v8.7)"
)
self.values = data.get("values")
self.less_term_matrix = data.get("less_term_matrix")
self.greater_term_matrix = data.get("greater_term_matrix")
self.equal_term_matrix = data.get("equal_term_matrix")
return values


class AbstractBindingConstraintCommand(
Expand Down Expand Up @@ -145,7 +144,9 @@ def get_inner_matrices(self) -> List[str]:
if matrix is not None
]

def get_corresponding_matrices(self, v: Optional[Union[MatrixType, str]], old: bool, create: bool) -> Optional[str]:
def get_corresponding_matrices(
self, v: Optional[Union[MatrixType, str]], *, old: bool, create: bool
) -> Optional[str]:
constants: GeneratorMatrixConstants
constants = self.command_context.generator_matrix_constants
time_step = self.time_step
Expand Down Expand Up @@ -178,7 +179,7 @@ def get_corresponding_matrices(self, v: Optional[Union[MatrixType, str]], old: b
# pragma: no cover
raise TypeError(repr(v))

def validates_and_fills_matrices(self, version: int, create: bool) -> None:
def validates_and_fills_matrices(self, *, version: int, create: bool) -> None:
if version < 870:
self.values = self.get_corresponding_matrices(self.values, old=True, create=create)
else:
Expand Down Expand Up @@ -209,7 +210,7 @@ def _apply(self, study_data: FileStudy) -> CommandOutput:
binding_constraints = study_data.tree.get(["input", "bindingconstraints", "bindingconstraints"])
new_key = len(binding_constraints)
bd_id = transform_name_to_id(self.name)
self.validates_and_fills_matrices(study_data.config.version, True)
self.validates_and_fills_matrices(version=study_data.config.version, create=True)

return apply_binding_constraint(
study_data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _apply(self, study_data: FileStudy) -> CommandOutput:
message="Failed to retrieve existing binding constraint",
)

self.validates_and_fills_matrices(study_data.config.version, False)
self.validates_and_fills_matrices(version=study_data.config.version, create=False)

return apply_binding_constraint(
study_data,
Expand Down
Loading

0 comments on commit 59f0c66

Please sign in to comment.