Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle committed Jan 24, 2024
1 parent 59f0c66 commit c2c1bc7
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 20 deletions.
7 changes: 5 additions & 2 deletions antarest/study/business/binding_constraint_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def create_binding_constraint(

# Validates the matrices. Needed when the study is a variant because we only append the command to the list
if isinstance(study, VariantStudy):
command.validates_and_fills_matrices(version=version, create=True)
command.validates_and_fills_matrices(specific_matrices=None, version=version, create=True)

file_study = self.storage_service.get_storage(study).get_raw(study)
execute_or_add_commands(study, file_study, [command], self.storage_service)
Expand Down Expand Up @@ -274,7 +274,10 @@ def update_binding_constraint(
command = UpdateBindingConstraint(**args)
# Validates the matrices. Needed when the study is a variant because we only append the command to the list
if isinstance(study, VariantStudy):
command.validates_and_fills_matrices(version=study_version, create=False)
updated_matrix = None
if data.key in ["less_term_matrix", "equal_term_matrix", "greater_term_matrix"]:
updated_matrix = [data.key]
command.validates_and_fills_matrices(specific_matrices=updated_matrix, version=study_version, create=False)

execute_or_add_commands(study, file_study, [command], self.storage_service)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
MatrixType = List[List[MatrixData]]


def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixType, old: bool) -> None:
def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixType, version: int) -> None:
"""
Check the binding constraint's matrix values for the specified time step.
Args:
time_step: The frequency of the binding constraint: "hourly", "daily" or "weekly".
values: The binding constraint's 2nd member matrix.
old: True if the study's version is prior to v8.7, otherwise False.
version: Study version.
Raises:
ValueError:
Expand All @@ -62,7 +62,7 @@ def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixTyp
array = np.array(values, dtype=np.float64)
expected_shape = shapes[time_step]
actual_shape = array.shape
if old:
if version < 870:
if actual_shape != expected_shape:
raise ValueError(f"Invalid matrix shape {actual_shape}, expected {expected_shape}")
elif actual_shape[0] != expected_shape[0]:
Expand Down Expand Up @@ -145,7 +145,7 @@ def get_inner_matrices(self) -> List[str]:
]

def get_corresponding_matrices(
self, v: Optional[Union[MatrixType, str]], *, old: bool, create: bool
self, v: Optional[Union[MatrixType, str]], version: int, create: bool
) -> Optional[str]:
constants: GeneratorMatrixConstants
constants = self.command_context.generator_matrix_constants
Expand All @@ -168,26 +168,31 @@ def get_corresponding_matrices(
BindingConstraintFrequency.WEEKLY: constants.get_binding_constraint_daily_weekly_87,
},
}
return methods["before_v87"][time_step]() if old else methods["after_v87"][time_step]()
return methods["before_v87"][time_step]() if version < 870 else methods["after_v87"][time_step]()
if isinstance(v, str):
# Check the matrix link
return validate_matrix(v, {"command_context": self.command_context})
if isinstance(v, list):
check_matrix_values(time_step, v, old=old)
check_matrix_values(time_step, v, version)
return validate_matrix(v, {"command_context": self.command_context})
# Invalid datatype
# pragma: no cover
raise TypeError(repr(v))

def validates_and_fills_matrices(self, *, version: int, create: bool) -> None:
def validates_and_fills_matrices(
self, *, specific_matrices: Optional[List[str]], version: int, create: bool
) -> None:
if version < 870:
self.values = self.get_corresponding_matrices(self.values, old=True, create=create)
self.values = self.get_corresponding_matrices(self.values, version, create)
elif specific_matrices:
for matrix in specific_matrices:
self.__setattr__(
matrix, self.get_corresponding_matrices(self.__getattribute__(matrix), version, create)
)
else:
self.less_term_matrix = self.get_corresponding_matrices(self.less_term_matrix, old=False, create=create)
self.greater_term_matrix = self.get_corresponding_matrices(
self.greater_term_matrix, old=False, create=create
)
self.equal_term_matrix = self.get_corresponding_matrices(self.equal_term_matrix, old=False, create=create)
self.less_term_matrix = self.get_corresponding_matrices(self.less_term_matrix, version, create)
self.greater_term_matrix = self.get_corresponding_matrices(self.greater_term_matrix, version, create)
self.equal_term_matrix = self.get_corresponding_matrices(self.equal_term_matrix, version, create)


class CreateBindingConstraint(AbstractBindingConstraintCommand):
Expand All @@ -210,7 +215,7 @@ def _apply(self, study_data: FileStudy) -> CommandOutput:
binding_constraints = study_data.tree.get(["input", "bindingconstraints", "bindingconstraints"])
new_key = len(binding_constraints)
bd_id = transform_name_to_id(self.name)
self.validates_and_fills_matrices(version=study_data.config.version, create=True)
self.validates_and_fills_matrices(specific_matrices=None, version=study_data.config.version, create=True)

return apply_binding_constraint(
study_data,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import validator
from typing import Any, Dict, List, Optional, Tuple

from antarest.core.model import JSON
from antarest.matrixstore.model import MatrixData
Expand Down Expand Up @@ -53,7 +51,10 @@ def _apply(self, study_data: FileStudy) -> CommandOutput:
message="Failed to retrieve existing binding constraint",
)

self.validates_and_fills_matrices(version=study_data.config.version, create=False)
# fmt: off
updated_matrices = [term for term in ["less_term_matrix", "equal_term_matrix", "greater_term_matrix"] if self.__getattribute__(term)]
self.validates_and_fills_matrices(specific_matrices=updated_matrices or None, version=study_data.config.version, create=False)
# fmt: on

return apply_binding_constraint(
study_data,
Expand Down

0 comments on commit c2c1bc7

Please sign in to comment.