Skip to content

Commit

Permalink
fix(BCs-matrices): update following code review
Browse files Browse the repository at this point in the history
  • Loading branch information
mabw-rte committed Jul 18, 2024
1 parent 294666a commit 1a487dc
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 34 deletions.
1 change: 0 additions & 1 deletion antarest/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import typing as t
from http import HTTPStatus

from fastapi import HTTPException
from fastapi.exceptions import HTTPException


Expand Down
46 changes: 25 additions & 21 deletions antarest/study/business/binding_constraint_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@
from antarest.study.storage.variantstudy.model.command.create_binding_constraint import (
DEFAULT_GROUP,
EXPECTED_MATRIX_SHAPES,
TERM_MATRICES,
BindingConstraintMatrices,
BindingConstraintPropertiesBase,
CreateBindingConstraint,
OptionalProperties,
TermMatrices,
)
from antarest.study.storage.variantstudy.model.command.remove_binding_constraint import RemoveBindingConstraint
from antarest.study.storage.variantstudy.model.command.update_binding_constraint import UpdateBindingConstraint
Expand All @@ -60,10 +60,10 @@


OPERATOR_CONFLICT_MAP = {
BindingConstraintOperator.EQUAL: ["less_term_matrix", "greater_term_matrix"],
BindingConstraintOperator.GREATER: ["less_term_matrix", "equal_term_matrix"],
BindingConstraintOperator.LESS: ["equal_term_matrix", "greater_term_matrix"],
BindingConstraintOperator.BOTH: ["equal_term_matrix"],
BindingConstraintOperator.EQUAL: [TermMatrices.LESS.value, TermMatrices.GREATER.value],
BindingConstraintOperator.GREATER: [TermMatrices.LESS.value, TermMatrices.EQUAL.value],
BindingConstraintOperator.LESS: [TermMatrices.EQUAL.value, TermMatrices.GREATER.value],
BindingConstraintOperator.BOTH: [TermMatrices.EQUAL.value],
}


Expand Down Expand Up @@ -254,7 +254,7 @@ class ConstraintCreation(ConstraintInput):

@root_validator(pre=True)
def check_matrices_dimensions(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
for _key in ["time_step"] + TERM_MATRICES:
for _key in ["time_step"] + [m.value for m in TermMatrices]:
_camel = to_camel_case(_key)
values[_key] = values.pop(_camel, values.get(_key))

Expand All @@ -272,7 +272,7 @@ def check_matrices_dimensions(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.

# Collect the matrix shapes
matrix_shapes = {}
for _field_name in ["values"] + TERM_MATRICES:
for _field_name in ["values"] + [m.value for m in TermMatrices]:
if _matrix := values.get(_field_name):
_array = np.array(_matrix)
# We only store the shape if the array is not empty
Expand Down Expand Up @@ -353,7 +353,7 @@ def _get_references_by_widths(
references_by_width: t.Dict[int, t.List[t.Tuple[str, str]]] = {}
_total = len(bcs) * len(matrix_id_fmts)
for _index, (bc, fmt) in enumerate(itertools.product(bcs, matrix_id_fmts), 1):
if int(file_study.config.version) >= 870 and fmt not in operator_matrix_file_map.get(bc.operator, []):
if int(file_study.config.version) >= 870 and fmt not in operator_matrix_file_map[bc.operator]:
continue
bc_id = bc.id
matrix_id = fmt.format(bc_id=bc.id)
Expand Down Expand Up @@ -758,7 +758,7 @@ def update_binding_constraint(

# Validates the matrices. Needed when the study is a variant because we only append the command to the list
if isinstance(study, VariantStudy):
updated_matrices = [term for term in TERM_MATRICES if getattr(data, term)]
updated_matrices = [term for term in [m.value for m in TermMatrices] if getattr(data, term)]
time_step = data.time_step or existing_constraint.time_step
command.validates_and_fills_matrices(
time_step=time_step, specific_matrices=updated_matrices, version=study_version, create=False
Expand Down Expand Up @@ -930,7 +930,7 @@ def _replace_matrices_according_to_frequency_and_version(
BindingConstraintFrequency.DAILY.value: default_bc_weekly_daily_87,
BindingConstraintFrequency.WEEKLY.value: default_bc_weekly_daily_87,
}[data.time_step].tolist()
for term in TERM_MATRICES:
for term in [m.value for m in TermMatrices]:
if term not in args:
args[term] = matrix
return args
Expand All @@ -941,6 +941,8 @@ def check_attributes_coherence(
study_version: int,
existing_operator: t.Optional[BindingConstraintOperator] = None,
) -> None:
update_operator = data.operator or existing_operator

if study_version < 870:
if data.group:
raise InvalidFieldForVersionError(
Expand All @@ -950,19 +952,21 @@ def check_attributes_coherence(
raise InvalidFieldForVersionError("You cannot fill a 'matrix_term' as these values refer to v8.7+ studies")
elif data.values:
raise InvalidFieldForVersionError("You cannot fill 'values' as it refers to the matrix before v8.7")
elif data.operator:
elif update_operator:
conflicting_matrices = [
getattr(data, matrix) for matrix in OPERATOR_CONFLICT_MAP[data.operator] if getattr(data, matrix)
getattr(data, matrix) for matrix in OPERATOR_CONFLICT_MAP[update_operator] if getattr(data, matrix)
]
if conflicting_matrices:
raise InvalidFieldForVersionError(
f"You cannot fill matrices '{conflicting_matrices}' while using the operator '{data.operator}'"
)
elif existing_operator:
conflicting_matrices = [
getattr(data, matrix) for matrix in OPERATOR_CONFLICT_MAP[existing_operator] if getattr(data, matrix)
]
if conflicting_matrices:
raise InvalidFieldForVersionError(
f"You cannot fill matrices '{conflicting_matrices}' while using the operator '{existing_operator}'"
f"You cannot fill matrices '{conflicting_matrices}' while using the operator '{update_operator}'"
)
# TODO: the default operator should be fixed somewhere so this condition can be consistent
elif [
getattr(data, matrix)
for matrix in OPERATOR_CONFLICT_MAP[BindingConstraintOperator.EQUAL]
if getattr(data, matrix)
]:
raise InvalidFieldForVersionError(
f"You cannot fill one of the matrices '{OPERATOR_CONFLICT_MAP[BindingConstraintOperator.EQUAL]}' "
"while using the operator '{BindingConstraintOperator.EQUAL}'"
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from antarest.study.storage.variantstudy.model.command.common import CommandName
from antarest.study.storage.variantstudy.model.command.create_area import CreateArea
from antarest.study.storage.variantstudy.model.command.create_binding_constraint import (
TERM_MATRICES,
CreateBindingConstraint,
TermMatrices,
)
from antarest.study.storage.variantstudy.model.command.create_cluster import CreateCluster
from antarest.study.storage.variantstudy.model.command.create_district import CreateDistrict
Expand Down Expand Up @@ -115,7 +115,7 @@ def _revert_update_binding_constraint(
}

matrix_service = command.command_context.matrix_service
for matrix_name in ["values"] + TERM_MATRICES:
for matrix_name in ["values"] + [m.value for m in TermMatrices]:
matrix = getattr(command, matrix_name)
if matrix is not None:
args[matrix_name] = matrix_service.get_matrix_id(matrix)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import typing as t
from abc import ABCMeta
from enum import Enum

import numpy as np
from pydantic import BaseModel, Extra, Field, root_validator, validator
Expand All @@ -23,7 +24,6 @@
from antarest.study.storage.variantstudy.model.command.icommand import MATCH_SIGNATURE_SEPARATOR, ICommand
from antarest.study.storage.variantstudy.model.model import CommandDTO

TERM_MATRICES = ["less_term_matrix", "equal_term_matrix", "greater_term_matrix"]
DEFAULT_GROUP = "default"

MatrixType = t.List[t.List[MatrixData]]
Expand All @@ -35,6 +35,12 @@
}


class TermMatrices(str, Enum):
LESS = "less_term_matrix"
GREATER = "greater_term_matrix"
EQUAL = "equal_term_matrix"


def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixType, version: int) -> None:
"""
Check the binding constraint's matrix values for the specified time step.
Expand Down Expand Up @@ -216,7 +222,7 @@ def to_dto(self) -> CommandDTO:
args["group"] = self.group

matrix_service = self.command_context.matrix_service
for matrix_name in TERM_MATRICES + ["values"]:
for matrix_name in [m.value for m in TermMatrices] + ["values"]:
matrix_attr = getattr(self, matrix_name, None)
if matrix_attr is not None:
args[matrix_name] = matrix_service.get_matrix_id(matrix_attr)
Expand Down Expand Up @@ -363,11 +369,9 @@ def apply_binding_constraint(
BindingConstraintOperator.BOTH: [(self.less_term_matrix, "lt"), (self.greater_term_matrix, "gt")],
}

current_operator = self.operator or BindingConstraintOperator(
[bc for bc in binding_constraints.values() if bc.get("id") == bd_id][0].get("operator")
)
current_operator = self.operator or BindingConstraintOperator(binding_constraints[new_key]["operator"])

for matrix_term, matrix_alias in operator_matrices_map.get(current_operator, []):
for matrix_term, matrix_alias in operator_matrices_map[current_operator]:
if matrix_term:
if not isinstance(matrix_term, str): # pragma: no cover
raise TypeError(repr(matrix_term))
Expand Down Expand Up @@ -449,7 +453,7 @@ def _create_diff(self, other: "ICommand") -> t.List["ICommand"]:
args[prop] = other_command[prop]

matrix_service = self.command_context.matrix_service
for matrix_name in ["values"] + TERM_MATRICES:
for matrix_name in ["values"] + [m.value for m in TermMatrices]:
self_matrix = getattr(self, matrix_name) # matrix, ID or `None`
other_matrix = getattr(other, matrix_name) # matrix, ID or `None`
self_matrix_id = None if self_matrix is None else matrix_service.get_matrix_id(self_matrix)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from antarest.study.storage.variantstudy.model.command.common import CommandName, CommandOutput
from antarest.study.storage.variantstudy.model.command.create_binding_constraint import (
DEFAULT_GROUP,
TERM_MATRICES,
AbstractBindingConstraintCommand,
TermMatrices,
create_binding_constraint_config,
)
from antarest.study.storage.variantstudy.model.command.icommand import MATCH_SIGNATURE_SEPARATOR, ICommand
Expand Down Expand Up @@ -262,7 +262,9 @@ def _apply(self, study_data: FileStudy) -> CommandOutput:
new_operator = BindingConstraintOperator(self.operator)
_update_matrices_names(study_data, self.id, existing_operator, new_operator)

updated_matrices = [term for term in TERM_MATRICES if hasattr(self, term) and getattr(self, term)]
updated_matrices = [
term for term in [m.value for m in TermMatrices] if hasattr(self, term) and getattr(self, term)
]
study_version = study_data.config.version
time_step = self.time_step or BindingConstraintFrequency(actual_cfg.get("type"))
self.validates_and_fills_matrices(
Expand All @@ -287,7 +289,7 @@ def _apply(self, study_data: FileStudy) -> CommandOutput:
return super().apply_binding_constraint(study_data, binding_constraints, index, self.id, old_groups=old_groups)

def to_dto(self) -> CommandDTO:
matrices = ["values"] + TERM_MATRICES
matrices = ["values"] + [m.value for m in TermMatrices]
matrix_service = self.command_context.matrix_service

excluded_fields = frozenset(ICommand.__fields__)
Expand Down

0 comments on commit 1a487dc

Please sign in to comment.