Skip to content

Commit

Permalink
fix(BCs-matrices): update following code review
Browse files Browse the repository at this point in the history
  • Loading branch information
mabw-rte committed Jul 22, 2024
1 parent 294666a commit a01cedd
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 202 deletions.
1 change: 0 additions & 1 deletion antarest/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import typing as t
from http import HTTPStatus

from fastapi import HTTPException
from fastapi.exceptions import HTTPException


Expand Down
46 changes: 25 additions & 21 deletions antarest/study/business/binding_constraint_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@
from antarest.study.storage.variantstudy.model.command.create_binding_constraint import (
DEFAULT_GROUP,
EXPECTED_MATRIX_SHAPES,
TERM_MATRICES,
BindingConstraintMatrices,
BindingConstraintPropertiesBase,
CreateBindingConstraint,
OptionalProperties,
TermMatrices,
)
from antarest.study.storage.variantstudy.model.command.remove_binding_constraint import RemoveBindingConstraint
from antarest.study.storage.variantstudy.model.command.update_binding_constraint import UpdateBindingConstraint
Expand All @@ -60,10 +60,10 @@


# Maps each binding-constraint operator to the term-matrix field names that
# must NOT be supplied together with it (e.g. an EQUAL constraint has no
# "less"/"greater" bound matrices).  Post-review version: keys come from the
# TermMatrices enum instead of duplicated string literals.
OPERATOR_CONFLICT_MAP = {
    BindingConstraintOperator.EQUAL: [TermMatrices.LESS.value, TermMatrices.GREATER.value],
    BindingConstraintOperator.GREATER: [TermMatrices.LESS.value, TermMatrices.EQUAL.value],
    BindingConstraintOperator.LESS: [TermMatrices.EQUAL.value, TermMatrices.GREATER.value],
    BindingConstraintOperator.BOTH: [TermMatrices.EQUAL.value],
}


Expand Down Expand Up @@ -254,7 +254,7 @@ class ConstraintCreation(ConstraintInput):

@root_validator(pre=True)
def check_matrices_dimensions(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
for _key in ["time_step"] + TERM_MATRICES:
for _key in ["time_step"] + [m.value for m in TermMatrices]:
_camel = to_camel_case(_key)
values[_key] = values.pop(_camel, values.get(_key))

Expand All @@ -272,7 +272,7 @@ def check_matrices_dimensions(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.

# Collect the matrix shapes
matrix_shapes = {}
for _field_name in ["values"] + TERM_MATRICES:
for _field_name in ["values"] + [m.value for m in TermMatrices]:
if _matrix := values.get(_field_name):
_array = np.array(_matrix)
# We only store the shape if the array is not empty
Expand Down Expand Up @@ -353,7 +353,7 @@ def _get_references_by_widths(
references_by_width: t.Dict[int, t.List[t.Tuple[str, str]]] = {}
_total = len(bcs) * len(matrix_id_fmts)
for _index, (bc, fmt) in enumerate(itertools.product(bcs, matrix_id_fmts), 1):
if int(file_study.config.version) >= 870 and fmt not in operator_matrix_file_map.get(bc.operator, []):
if int(file_study.config.version) >= 870 and fmt not in operator_matrix_file_map[bc.operator]:
continue
bc_id = bc.id
matrix_id = fmt.format(bc_id=bc.id)
Expand Down Expand Up @@ -758,7 +758,7 @@ def update_binding_constraint(

# Validates the matrices. Needed when the study is a variant because we only append the command to the list
if isinstance(study, VariantStudy):
updated_matrices = [term for term in TERM_MATRICES if getattr(data, term)]
updated_matrices = [term for term in [m.value for m in TermMatrices] if getattr(data, term)]
time_step = data.time_step or existing_constraint.time_step
command.validates_and_fills_matrices(
time_step=time_step, specific_matrices=updated_matrices, version=study_version, create=False
Expand Down Expand Up @@ -930,7 +930,7 @@ def _replace_matrices_according_to_frequency_and_version(
BindingConstraintFrequency.DAILY.value: default_bc_weekly_daily_87,
BindingConstraintFrequency.WEEKLY.value: default_bc_weekly_daily_87,
}[data.time_step].tolist()
for term in TERM_MATRICES:
for term in [m.value for m in TermMatrices]:
if term not in args:
args[term] = matrix
return args
Expand All @@ -941,6 +941,8 @@ def check_attributes_coherence(
study_version: int,
existing_operator: t.Optional[BindingConstraintOperator] = None,
) -> None:
update_operator = data.operator or existing_operator

if study_version < 870:
if data.group:
raise InvalidFieldForVersionError(
Expand All @@ -950,19 +952,21 @@ def check_attributes_coherence(
raise InvalidFieldForVersionError("You cannot fill a 'matrix_term' as these values refer to v8.7+ studies")
elif data.values:
raise InvalidFieldForVersionError("You cannot fill 'values' as it refers to the matrix before v8.7")
elif data.operator:
elif update_operator:
conflicting_matrices = [
getattr(data, matrix) for matrix in OPERATOR_CONFLICT_MAP[data.operator] if getattr(data, matrix)
getattr(data, matrix) for matrix in OPERATOR_CONFLICT_MAP[update_operator] if getattr(data, matrix)
]
if conflicting_matrices:
raise InvalidFieldForVersionError(
f"You cannot fill matrices '{conflicting_matrices}' while using the operator '{data.operator}'"
)
elif existing_operator:
conflicting_matrices = [
getattr(data, matrix) for matrix in OPERATOR_CONFLICT_MAP[existing_operator] if getattr(data, matrix)
]
if conflicting_matrices:
raise InvalidFieldForVersionError(
f"You cannot fill matrices '{conflicting_matrices}' while using the operator '{existing_operator}'"
f"You cannot fill matrices '{OPERATOR_CONFLICT_MAP[update_operator]}' while using the operator '{update_operator}'"
)
# TODO: the default operator should be fixed somewhere so this condition can be consistent
elif [
getattr(data, matrix)
for matrix in OPERATOR_CONFLICT_MAP[BindingConstraintOperator.EQUAL]
if getattr(data, matrix)
]:
raise InvalidFieldForVersionError(
f"You cannot fill one of the matrices '{OPERATOR_CONFLICT_MAP[BindingConstraintOperator.EQUAL]}' "
"while using the operator '{BindingConstraintOperator.EQUAL}'"
)
54 changes: 37 additions & 17 deletions antarest/study/storage/rawstudy/model/filesystem/lazy_node.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import shutil
import typing as t
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, Generic, List, Optional, Tuple, Union, cast
from zipfile import ZipFile

from antarest.core.exceptions import ChildNotFoundError
Expand All @@ -13,16 +14,16 @@

@dataclass
class SimpleCache:
    """A cached value paired with the moment it stops being valid."""

    # Post-review annotation style: `t.Any` (module imported as `typing as t`).
    # The stripped diff had duplicated this field with the old `Any` spelling.
    value: t.Any
    expiration_date: datetime


class LazyNode(INode, ABC, Generic[G, S, V]): # type: ignore
class LazyNode(INode, ABC, t.Generic[G, S, V]): # type: ignore
"""
Abstract left with implemented a lazy loading for its daughter implementation.
"""

ZIP_FILELIST_CACHE: Dict[str, SimpleCache] = {}
ZIP_FILELIST_CACHE: t.Dict[str, SimpleCache] = {}

def __init__(
self,
Expand All @@ -34,7 +35,7 @@ def __init__(

def _get_real_file_path(
self,
) -> Tuple[Path, Any]:
) -> t.Tuple[Path, t.Any]:
tmp_dir = None
if self.config.zip_path:
path, tmp_dir = self._extract_file_to_tmp_dir()
Expand All @@ -59,12 +60,12 @@ def file_exists(self) -> bool:

def _get(
self,
url: Optional[List[str]] = None,
url: t.Optional[t.List[str]] = None,
depth: int = -1,
expanded: bool = False,
formatted: bool = True,
get_node: bool = False,
) -> Union[Union[str, G], INode[G, S, V]]:
) -> t.Union[t.Union[str, G], INode[G, S, V]]:
self._assert_url_end(url)

if get_node:
Expand All @@ -75,7 +76,7 @@ def _get(
if expanded:
return link
else:
return cast(G, self.context.resolver.resolve(link, formatted))
return t.cast(G, self.context.resolver.resolve(link, formatted))

if expanded:
return self.get_lazy_content()
Expand All @@ -84,24 +85,24 @@ def _get(

def get(
self,
url: Optional[List[str]] = None,
url: t.Optional[t.List[str]] = None,
depth: int = -1,
expanded: bool = False,
formatted: bool = True,
) -> Union[str, G]:
) -> t.Union[str, G]:
output = self._get(url, depth, expanded, formatted, get_node=False)
assert not isinstance(output, INode)
return output

def get_node(
self,
url: Optional[List[str]] = None,
url: t.Optional[t.List[str]] = None,
) -> INode[G, S, V]:
output = self._get(url, get_node=True)
assert isinstance(output, INode)
return output

def delete(self, url: Optional[List[str]] = None) -> None:
def delete(self, url: t.Optional[t.List[str]] = None) -> None:
self._assert_url_end(url)
if self.get_link_path().exists():
self.get_link_path().unlink()
Expand Down Expand Up @@ -131,7 +132,7 @@ def get_link_path(self) -> Path:
path = self.config.path.parent / (self.config.path.name + ".link")
return path

def save(self, data: Union[str, bytes, S], url: Optional[List[str]] = None) -> None:
def save(self, data: t.Union[str, bytes, S], url: t.Optional[t.List[str]] = None) -> None:
self._assert_not_in_zipped_file()
self._assert_url_end(url)

Expand All @@ -141,14 +142,33 @@ def save(self, data: Union[str, bytes, S], url: Optional[List[str]] = None) -> N
self.config.path.unlink()
return None

self.dump(cast(S, data), url)
self.dump(t.cast(S, data), url)
if self.get_link_path().exists():
self.get_link_path().unlink()
return None

def move_file_to(
    self, target: t.Union[Path, "LazyNode[t.Any, t.Any, t.Any]"], force: bool = False, copy: bool = False
) -> None:
    """Move (or copy, when *copy* is true) this node's file to *target*.

    A plain ``Path`` target is overwritten unconditionally; a ``LazyNode``
    target is resolved via ``infer_target_path`` and refuses to overwrite an
    existing file unless *force* is set.
    """
    if isinstance(target, Path):
        destination = target
        destination.unlink(missing_ok=True)
    else:
        destination = target.infer_target_path(self.infer_is_link_path())
        if destination.exists() and not force:
            raise FileExistsError(f"Target path {destination} already exists")
        destination.unlink(missing_ok=True)
    # Common tail: copy2 preserves metadata; rename moves in place.
    if copy:
        shutil.copy2(self.infer_path(), destination)
    else:
        self.infer_path().rename(destination)

def get_lazy_content(
self,
url: Optional[List[str]] = None,
url: t.Optional[t.List[str]] = None,
depth: int = -1,
expanded: bool = False,
) -> str:
Expand All @@ -157,7 +177,7 @@ def get_lazy_content(
@abstractmethod
def load(
self,
url: Optional[List[str]] = None,
url: t.Optional[t.List[str]] = None,
depth: int = -1,
expanded: bool = False,
formatted: bool = True,
Expand All @@ -177,7 +197,7 @@ def load(
raise NotImplementedError()

@abstractmethod
def dump(self, data: S, url: Optional[List[str]] = None) -> None:
def dump(self, data: S, url: t.Optional[t.List[str]] = None) -> None:
"""
Store data on tree.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from antarest.study.storage.variantstudy.model.command.common import CommandName
from antarest.study.storage.variantstudy.model.command.create_area import CreateArea
from antarest.study.storage.variantstudy.model.command.create_binding_constraint import (
TERM_MATRICES,
CreateBindingConstraint,
TermMatrices,
)
from antarest.study.storage.variantstudy.model.command.create_cluster import CreateCluster
from antarest.study.storage.variantstudy.model.command.create_district import CreateDistrict
Expand Down Expand Up @@ -115,7 +115,7 @@ def _revert_update_binding_constraint(
}

matrix_service = command.command_context.matrix_service
for matrix_name in ["values"] + TERM_MATRICES:
for matrix_name in ["values"] + [m.value for m in TermMatrices]:
matrix = getattr(command, matrix_name)
if matrix is not None:
args[matrix_name] = matrix_service.get_matrix_id(matrix)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import typing as t
from abc import ABCMeta
from enum import Enum

import numpy as np
from pydantic import BaseModel, Extra, Field, root_validator, validator
Expand All @@ -23,7 +24,6 @@
from antarest.study.storage.variantstudy.model.command.icommand import MATCH_SIGNATURE_SEPARATOR, ICommand
from antarest.study.storage.variantstudy.model.model import CommandDTO

TERM_MATRICES = ["less_term_matrix", "equal_term_matrix", "greater_term_matrix"]
DEFAULT_GROUP = "default"

MatrixType = t.List[t.List[MatrixData]]
Expand All @@ -35,6 +35,12 @@
}


class TermMatrices(str, Enum):
    """Field names of the binding-constraint term matrices.

    Each value is the snake_case attribute/argument name used elsewhere in the
    code (e.g. ``OPERATOR_CONFLICT_MAP`` and the ``["values"] + [m.value for m
    in TermMatrices]`` loops).  Declaration order matters: iteration over the
    enum drives those loops, so do not reorder members.
    """

    LESS = "less_term_matrix"
    GREATER = "greater_term_matrix"
    EQUAL = "equal_term_matrix"

def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixType, version: int) -> None:
"""
Check the binding constraint's matrix values for the specified time step.
Expand Down Expand Up @@ -216,7 +222,7 @@ def to_dto(self) -> CommandDTO:
args["group"] = self.group

matrix_service = self.command_context.matrix_service
for matrix_name in TERM_MATRICES + ["values"]:
for matrix_name in [m.value for m in TermMatrices] + ["values"]:
matrix_attr = getattr(self, matrix_name, None)
if matrix_attr is not None:
args[matrix_name] = matrix_service.get_matrix_id(matrix_attr)
Expand Down Expand Up @@ -363,11 +369,9 @@ def apply_binding_constraint(
BindingConstraintOperator.BOTH: [(self.less_term_matrix, "lt"), (self.greater_term_matrix, "gt")],
}

current_operator = self.operator or BindingConstraintOperator(
[bc for bc in binding_constraints.values() if bc.get("id") == bd_id][0].get("operator")
)
current_operator = self.operator or BindingConstraintOperator(binding_constraints[new_key]["operator"])

for matrix_term, matrix_alias in operator_matrices_map.get(current_operator, []):
for matrix_term, matrix_alias in operator_matrices_map[current_operator]:
if matrix_term:
if not isinstance(matrix_term, str): # pragma: no cover
raise TypeError(repr(matrix_term))
Expand Down Expand Up @@ -449,7 +453,7 @@ def _create_diff(self, other: "ICommand") -> t.List["ICommand"]:
args[prop] = other_command[prop]

matrix_service = self.command_context.matrix_service
for matrix_name in ["values"] + TERM_MATRICES:
for matrix_name in ["values"] + [m.value for m in TermMatrices]:
self_matrix = getattr(self, matrix_name) # matrix, ID or `None`
other_matrix = getattr(other, matrix_name) # matrix, ID or `None`
self_matrix_id = None if self_matrix is None else matrix_service.get_matrix_id(self_matrix)
Expand Down
Loading

0 comments on commit a01cedd

Please sign in to comment.