diff --git a/antarest/core/exceptions.py b/antarest/core/exceptions.py index e2519a284b..6fce8f0213 100644 --- a/antarest/core/exceptions.py +++ b/antarest/core/exceptions.py @@ -2,7 +2,6 @@ import typing as t from http import HTTPStatus -from fastapi import HTTPException from fastapi.exceptions import HTTPException diff --git a/antarest/study/business/binding_constraint_management.py b/antarest/study/business/binding_constraint_management.py index e9366c4510..d74f4e0010 100644 --- a/antarest/study/business/binding_constraint_management.py +++ b/antarest/study/business/binding_constraint_management.py @@ -46,11 +46,11 @@ from antarest.study.storage.variantstudy.model.command.create_binding_constraint import ( DEFAULT_GROUP, EXPECTED_MATRIX_SHAPES, - TERM_MATRICES, BindingConstraintMatrices, BindingConstraintPropertiesBase, CreateBindingConstraint, OptionalProperties, + TermMatrices, ) from antarest.study.storage.variantstudy.model.command.remove_binding_constraint import RemoveBindingConstraint from antarest.study.storage.variantstudy.model.command.update_binding_constraint import UpdateBindingConstraint @@ -60,10 +60,10 @@ OPERATOR_CONFLICT_MAP = { - BindingConstraintOperator.EQUAL: ["less_term_matrix", "greater_term_matrix"], - BindingConstraintOperator.GREATER: ["less_term_matrix", "equal_term_matrix"], - BindingConstraintOperator.LESS: ["equal_term_matrix", "greater_term_matrix"], - BindingConstraintOperator.BOTH: ["equal_term_matrix"], + BindingConstraintOperator.EQUAL: [TermMatrices.LESS.value, TermMatrices.GREATER.value], + BindingConstraintOperator.GREATER: [TermMatrices.LESS.value, TermMatrices.EQUAL.value], + BindingConstraintOperator.LESS: [TermMatrices.EQUAL.value, TermMatrices.GREATER.value], + BindingConstraintOperator.BOTH: [TermMatrices.EQUAL.value], } @@ -254,7 +254,7 @@ class ConstraintCreation(ConstraintInput): @root_validator(pre=True) def check_matrices_dimensions(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - for _key in 
["time_step"] + TERM_MATRICES: + for _key in ["time_step"] + [m.value for m in TermMatrices]: _camel = to_camel_case(_key) values[_key] = values.pop(_camel, values.get(_key)) @@ -272,7 +272,7 @@ def check_matrices_dimensions(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t. # Collect the matrix shapes matrix_shapes = {} - for _field_name in ["values"] + TERM_MATRICES: + for _field_name in ["values"] + [m.value for m in TermMatrices]: if _matrix := values.get(_field_name): _array = np.array(_matrix) # We only store the shape if the array is not empty @@ -353,7 +353,7 @@ def _get_references_by_widths( references_by_width: t.Dict[int, t.List[t.Tuple[str, str]]] = {} _total = len(bcs) * len(matrix_id_fmts) for _index, (bc, fmt) in enumerate(itertools.product(bcs, matrix_id_fmts), 1): - if int(file_study.config.version) >= 870 and fmt not in operator_matrix_file_map.get(bc.operator, []): + if int(file_study.config.version) >= 870 and fmt not in operator_matrix_file_map[bc.operator]: continue bc_id = bc.id matrix_id = fmt.format(bc_id=bc.id) @@ -758,7 +758,7 @@ def update_binding_constraint( # Validates the matrices. 
Needed when the study is a variant because we only append the command to the list if isinstance(study, VariantStudy): - updated_matrices = [term for term in TERM_MATRICES if getattr(data, term)] + updated_matrices = [term for term in [m.value for m in TermMatrices] if getattr(data, term)] time_step = data.time_step or existing_constraint.time_step command.validates_and_fills_matrices( time_step=time_step, specific_matrices=updated_matrices, version=study_version, create=False @@ -930,7 +930,7 @@ def _replace_matrices_according_to_frequency_and_version( BindingConstraintFrequency.DAILY.value: default_bc_weekly_daily_87, BindingConstraintFrequency.WEEKLY.value: default_bc_weekly_daily_87, }[data.time_step].tolist() - for term in TERM_MATRICES: + for term in [m.value for m in TermMatrices]: if term not in args: args[term] = matrix return args @@ -941,6 +941,8 @@ def check_attributes_coherence( study_version: int, existing_operator: t.Optional[BindingConstraintOperator] = None, ) -> None: + update_operator = data.operator or existing_operator + if study_version < 870: if data.group: raise InvalidFieldForVersionError( @@ -950,19 +952,21 @@ def check_attributes_coherence( raise InvalidFieldForVersionError("You cannot fill a 'matrix_term' as these values refer to v8.7+ studies") elif data.values: raise InvalidFieldForVersionError("You cannot fill 'values' as it refers to the matrix before v8.7") - elif data.operator: + elif update_operator: conflicting_matrices = [ - getattr(data, matrix) for matrix in OPERATOR_CONFLICT_MAP[data.operator] if getattr(data, matrix) + getattr(data, matrix) for matrix in OPERATOR_CONFLICT_MAP[update_operator] if getattr(data, matrix) ] if conflicting_matrices: raise InvalidFieldForVersionError( - f"You cannot fill matrices '{conflicting_matrices}' while using the operator '{data.operator}'" - ) - elif existing_operator: - conflicting_matrices = [ - getattr(data, matrix) for matrix in OPERATOR_CONFLICT_MAP[existing_operator] if getattr(data, 
matrix) - ] - if conflicting_matrices: - raise InvalidFieldForVersionError( - f"You cannot fill matrices '{conflicting_matrices}' while using the operator '{existing_operator}'" + f"You cannot fill matrices '{OPERATOR_CONFLICT_MAP[update_operator]}' while using the operator '{update_operator}'" ) + # TODO: the default operator should be fixed somewhere so this condition can be consistent + elif [ + getattr(data, matrix) + for matrix in OPERATOR_CONFLICT_MAP[BindingConstraintOperator.EQUAL] + if getattr(data, matrix) + ]: + raise InvalidFieldForVersionError( + f"You cannot fill one of the matrices '{OPERATOR_CONFLICT_MAP[BindingConstraintOperator.EQUAL]}' " + f"while using the operator '{BindingConstraintOperator.EQUAL}'" + ) diff --git a/antarest/study/storage/rawstudy/model/filesystem/lazy_node.py b/antarest/study/storage/rawstudy/model/filesystem/lazy_node.py index d2a7e3189f..978c0e4250 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/lazy_node.py +++ b/antarest/study/storage/rawstudy/model/filesystem/lazy_node.py @@ -1,8 +1,9 @@ +import shutil +import typing as t from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timedelta from pathlib import Path -from typing import Any, Dict, Generic, List, Optional, Tuple, Union, cast from zipfile import ZipFile from antarest.core.exceptions import ChildNotFoundError @@ -13,16 +14,16 @@ @dataclass class SimpleCache: - value: Any + value: t.Any expiration_date: datetime -class LazyNode(INode, ABC, Generic[G, S, V]): # type: ignore +class LazyNode(INode, ABC, t.Generic[G, S, V]): # type: ignore """ Abstract left with implemented a lazy loading for its daughter implementation.
""" - ZIP_FILELIST_CACHE: Dict[str, SimpleCache] = {} + ZIP_FILELIST_CACHE: t.Dict[str, SimpleCache] = {} def __init__( self, @@ -34,7 +35,7 @@ def __init__( def _get_real_file_path( self, - ) -> Tuple[Path, Any]: + ) -> t.Tuple[Path, t.Any]: tmp_dir = None if self.config.zip_path: path, tmp_dir = self._extract_file_to_tmp_dir() @@ -59,12 +60,12 @@ def file_exists(self) -> bool: def _get( self, - url: Optional[List[str]] = None, + url: t.Optional[t.List[str]] = None, depth: int = -1, expanded: bool = False, formatted: bool = True, get_node: bool = False, - ) -> Union[Union[str, G], INode[G, S, V]]: + ) -> t.Union[t.Union[str, G], INode[G, S, V]]: self._assert_url_end(url) if get_node: @@ -75,7 +76,7 @@ def _get( if expanded: return link else: - return cast(G, self.context.resolver.resolve(link, formatted)) + return t.cast(G, self.context.resolver.resolve(link, formatted)) if expanded: return self.get_lazy_content() @@ -84,24 +85,24 @@ def _get( def get( self, - url: Optional[List[str]] = None, + url: t.Optional[t.List[str]] = None, depth: int = -1, expanded: bool = False, formatted: bool = True, - ) -> Union[str, G]: + ) -> t.Union[str, G]: output = self._get(url, depth, expanded, formatted, get_node=False) assert not isinstance(output, INode) return output def get_node( self, - url: Optional[List[str]] = None, + url: t.Optional[t.List[str]] = None, ) -> INode[G, S, V]: output = self._get(url, get_node=True) assert isinstance(output, INode) return output - def delete(self, url: Optional[List[str]] = None) -> None: + def delete(self, url: t.Optional[t.List[str]] = None) -> None: self._assert_url_end(url) if self.get_link_path().exists(): self.get_link_path().unlink() @@ -131,7 +132,7 @@ def get_link_path(self) -> Path: path = self.config.path.parent / (self.config.path.name + ".link") return path - def save(self, data: Union[str, bytes, S], url: Optional[List[str]] = None) -> None: + def save(self, data: t.Union[str, bytes, S], url: t.Optional[t.List[str]] = None) 
-> None: self._assert_not_in_zipped_file() self._assert_url_end(url) @@ -141,14 +142,33 @@ def save(self, data: Union[str, bytes, S], url: Optional[List[str]] = None) -> N self.config.path.unlink() return None - self.dump(cast(S, data), url) + self.dump(t.cast(S, data), url) if self.get_link_path().exists(): self.get_link_path().unlink() return None + def move_file_to( + self, target: t.Union[Path, "LazyNode[t.Any, t.Any, t.Any]"], force: bool = False, copy: bool = False + ) -> None: + if isinstance(target, Path): + target.unlink(missing_ok=True) + if copy: + shutil.copy2(self.infer_path(), target) + else: + self.infer_path().rename(target) + else: + target_path = target.infer_target_path(self.infer_is_link_path()) + if not force and target_path.exists(): + raise FileExistsError(f"Target path {target_path} already exists") + target_path.unlink(missing_ok=True) + if copy: + shutil.copy2(self.infer_path(), target_path) + else: + self.infer_path().rename(target_path) + def get_lazy_content( self, - url: Optional[List[str]] = None, + url: t.Optional[t.List[str]] = None, depth: int = -1, expanded: bool = False, ) -> str: @@ -157,7 +177,7 @@ def get_lazy_content( @abstractmethod def load( self, - url: Optional[List[str]] = None, + url: t.Optional[t.List[str]] = None, depth: int = -1, expanded: bool = False, formatted: bool = True, @@ -177,7 +197,7 @@ def load( raise NotImplementedError() @abstractmethod - def dump(self, data: S, url: Optional[List[str]] = None) -> None: + def dump(self, data: S, url: t.Optional[t.List[str]] = None) -> None: """ Store data on tree. 
diff --git a/antarest/study/storage/variantstudy/business/command_reverter.py b/antarest/study/storage/variantstudy/business/command_reverter.py index 5025e2a654..c60cfad601 100644 --- a/antarest/study/storage/variantstudy/business/command_reverter.py +++ b/antarest/study/storage/variantstudy/business/command_reverter.py @@ -8,8 +8,8 @@ from antarest.study.storage.variantstudy.model.command.common import CommandName from antarest.study.storage.variantstudy.model.command.create_area import CreateArea from antarest.study.storage.variantstudy.model.command.create_binding_constraint import ( - TERM_MATRICES, CreateBindingConstraint, + TermMatrices, ) from antarest.study.storage.variantstudy.model.command.create_cluster import CreateCluster from antarest.study.storage.variantstudy.model.command.create_district import CreateDistrict @@ -115,7 +115,7 @@ def _revert_update_binding_constraint( } matrix_service = command.command_context.matrix_service - for matrix_name in ["values"] + TERM_MATRICES: + for matrix_name in ["values"] + [m.value for m in TermMatrices]: matrix = getattr(command, matrix_name) if matrix is not None: args[matrix_name] = matrix_service.get_matrix_id(matrix) diff --git a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py index 78f5afe4a9..d9744af6bb 100644 --- a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py @@ -1,6 +1,7 @@ import json import typing as t from abc import ABCMeta +from enum import Enum import numpy as np from pydantic import BaseModel, Extra, Field, root_validator, validator @@ -23,7 +24,6 @@ from antarest.study.storage.variantstudy.model.command.icommand import MATCH_SIGNATURE_SEPARATOR, ICommand from antarest.study.storage.variantstudy.model.model import CommandDTO -TERM_MATRICES = ["less_term_matrix", 
"equal_term_matrix", "greater_term_matrix"] DEFAULT_GROUP = "default" MatrixType = t.List[t.List[MatrixData]] @@ -35,6 +35,12 @@ } +class TermMatrices(str, Enum): + LESS = "less_term_matrix" + GREATER = "greater_term_matrix" + EQUAL = "equal_term_matrix" + + def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixType, version: int) -> None: """ Check the binding constraint's matrix values for the specified time step. @@ -216,7 +222,7 @@ def to_dto(self) -> CommandDTO: args["group"] = self.group matrix_service = self.command_context.matrix_service - for matrix_name in TERM_MATRICES + ["values"]: + for matrix_name in [m.value for m in TermMatrices] + ["values"]: matrix_attr = getattr(self, matrix_name, None) if matrix_attr is not None: args[matrix_name] = matrix_service.get_matrix_id(matrix_attr) @@ -363,11 +369,9 @@ def apply_binding_constraint( BindingConstraintOperator.BOTH: [(self.less_term_matrix, "lt"), (self.greater_term_matrix, "gt")], } - current_operator = self.operator or BindingConstraintOperator( - [bc for bc in binding_constraints.values() if bc.get("id") == bd_id][0].get("operator") - ) + current_operator = self.operator or BindingConstraintOperator(binding_constraints[new_key]["operator"]) - for matrix_term, matrix_alias in operator_matrices_map.get(current_operator, []): + for matrix_term, matrix_alias in operator_matrices_map[current_operator]: if matrix_term: if not isinstance(matrix_term, str): # pragma: no cover raise TypeError(repr(matrix_term)) @@ -449,7 +453,7 @@ def _create_diff(self, other: "ICommand") -> t.List["ICommand"]: args[prop] = other_command[prop] matrix_service = self.command_context.matrix_service - for matrix_name in ["values"] + TERM_MATRICES: + for matrix_name in ["values"] + [m.value for m in TermMatrices]: self_matrix = getattr(self, matrix_name) # matrix, ID or `None` other_matrix = getattr(other, matrix_name) # matrix, ID or `None` self_matrix_id = None if self_matrix is None else 
matrix_service.get_matrix_id(self_matrix) diff --git a/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py index 0c8cf959e4..9cb23d91f8 100644 --- a/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py @@ -1,9 +1,5 @@ import json -import shutil import typing as t -from pathlib import Path - -from pydantic import BaseModel, Field from antarest.core.model import JSON from antarest.matrixstore.model import MatrixData @@ -13,13 +9,12 @@ ) from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy -from antarest.study.storage.rawstudy.model.filesystem.folder_node import FolderNode from antarest.study.storage.rawstudy.model.filesystem.lazy_node import LazyNode from antarest.study.storage.variantstudy.model.command.common import CommandName, CommandOutput from antarest.study.storage.variantstudy.model.command.create_binding_constraint import ( DEFAULT_GROUP, - TERM_MATRICES, AbstractBindingConstraintCommand, + TermMatrices, create_binding_constraint_config, ) from antarest.study.storage.variantstudy.model.command.icommand import MATCH_SIGNATURE_SEPARATOR, ICommand @@ -33,115 +28,12 @@ BindingConstraintOperator.GREATER: "gt", } - -class MatrixInputTargetPaths(BaseModel, frozen=True, extra="forbid"): - """ - Model used to store the input and target paths for matrices. 
- """ - - matrix_input_paths: t.List[Path] = Field(..., min_items=1, max_items=2) - matrix_target_paths: t.List[Path] = Field(..., min_items=1, max_items=2) - - -def _infer_input_and_target_paths( - parent_node: FolderNode, - binding_constraint_id: str, - existing_operator: BindingConstraintOperator, - new_operator: BindingConstraintOperator, -) -> MatrixInputTargetPaths: - """ - Infer the target paths of the matrices to update according to the existing and new operators. - - Args: - parent_node: the parent folder node - binding_constraint_id: the binding constraint ID - existing_operator: the existing operator - new_operator: the new operator - - Returns: - the matrix input and target paths - """ - # new and existing operators should be different - assert existing_operator != new_operator, "Existing and new operators should be different" - - matrix_node_lt = parent_node.get_node([f"{binding_constraint_id}_lt"]) - assert isinstance( - matrix_node_lt, LazyNode - ), f"Node type not handled yet: LazyNode expected, got {type(matrix_node_lt)}" - matrix_node_eq = parent_node.get_node([f"{binding_constraint_id}_eq"]) - assert isinstance( - matrix_node_eq, LazyNode - ), f"Node type not handled yet: LazyNode expected, got {type(matrix_node_eq)}" - matrix_node_gt = parent_node.get_node([f"{binding_constraint_id}_gt"]) - assert isinstance( - matrix_node_gt, LazyNode - ), f"Node type not handled yet: LazyNode expected, got {type(matrix_node_gt)}" - - is_link_eq = matrix_node_eq.infer_is_link_path() - is_link_lt = matrix_node_lt.infer_is_link_path() - is_link_gt = matrix_node_gt.infer_is_link_path() - - if existing_operator != BindingConstraintOperator.BOTH and new_operator != BindingConstraintOperator.BOTH: - matrix_node = parent_node.get_node([f"{binding_constraint_id}_{ALIAS_OPERATOR_MAP[existing_operator]}"]) - assert isinstance( - matrix_node, LazyNode - ), f"Node type not handled yet: LazyNode expected, got {type(matrix_node)}" - new_matrix_node = 
parent_node.get_node([f"{binding_constraint_id}_{ALIAS_OPERATOR_MAP[new_operator]}"]) - assert isinstance( - new_matrix_node, LazyNode - ), f"Node type not handled yet: LazyNode expected, got {type(new_matrix_node)}" - is_link = matrix_node.infer_is_link_path() - return MatrixInputTargetPaths( - matrix_input_paths=[matrix_node.infer_path()], - matrix_target_paths=[new_matrix_node.infer_target_path(is_link)], - ) - elif new_operator == BindingConstraintOperator.BOTH: - if existing_operator == BindingConstraintOperator.EQUAL: - return MatrixInputTargetPaths( - matrix_input_paths=[matrix_node_eq.infer_path()], - matrix_target_paths=[ - matrix_node_lt.infer_target_path(is_link_eq), - matrix_node_gt.infer_target_path(is_link_eq), - ], - ) - elif existing_operator == BindingConstraintOperator.LESS: - return MatrixInputTargetPaths( - matrix_input_paths=[matrix_node_lt.infer_path()], - matrix_target_paths=[matrix_node_lt.infer_path(), matrix_node_gt.infer_target_path(is_link_lt)], - ) - elif existing_operator == BindingConstraintOperator.GREATER: - return MatrixInputTargetPaths( - matrix_input_paths=[matrix_node_gt.infer_path()], - matrix_target_paths=[matrix_node_lt.infer_target_path(is_link_gt), matrix_node_gt.infer_path()], - ) - else: - raise NotImplementedError( - f"Case not handled yet: existing_operator={existing_operator}, new_operator={new_operator}" - ) - elif existing_operator == BindingConstraintOperator.BOTH: - if new_operator == BindingConstraintOperator.EQUAL: - return MatrixInputTargetPaths( - matrix_input_paths=[matrix_node_lt.infer_path(), matrix_node_gt.infer_path()], - matrix_target_paths=[matrix_node_eq.infer_target_path(is_link_lt)], - ) - elif new_operator == BindingConstraintOperator.LESS: - return MatrixInputTargetPaths( - matrix_input_paths=[matrix_node_lt.infer_path(), matrix_node_gt.infer_path()], - matrix_target_paths=[matrix_node_lt.infer_target_path(is_link_lt)], - ) - elif new_operator == BindingConstraintOperator.GREATER: - return 
MatrixInputTargetPaths( - matrix_input_paths=[matrix_node_lt.infer_path(), matrix_node_gt.infer_path()], - matrix_target_paths=[matrix_node_gt.infer_target_path(is_link_gt)], - ) - else: - raise NotImplementedError( - f"Case not handled yet: existing_operator={existing_operator}, new_operator={new_operator}" - ) - else: - raise NotImplementedError( - f"Case not handled yet: existing_operator={existing_operator}, new_operator={new_operator}" - ) +HANDLED_OPERATORS = [ + BindingConstraintOperator.EQUAL, + BindingConstraintOperator.LESS, + BindingConstraintOperator.GREATER, + BindingConstraintOperator.BOTH, +] def _update_matrices_names( @@ -163,51 +55,54 @@ def _update_matrices_names( NotImplementedError: if the case is not handled """ - if existing_operator == new_operator: - return - parent_folder_node = file_study.tree.get_node(["input", "bindingconstraints"]) - assert isinstance(parent_folder_node, FolderNode), f"Node type not handled yet: {type(parent_folder_node)}" - - matrix_paths = _infer_input_and_target_paths( - parent_folder_node, binding_constraint_id, existing_operator, new_operator - ) + matrix_lt = parent_folder_node.get_node([f"{binding_constraint_id}_lt"]) + assert isinstance(matrix_lt, LazyNode), f"Node type not handled yet: LazyNode expected, got {type(matrix_lt)}" + matrix_eq = parent_folder_node.get_node([f"{binding_constraint_id}_eq"]) + assert isinstance(matrix_eq, LazyNode), f"Node type not handled yet: LazyNode expected, got {type(matrix_eq)}" + matrix_gt = parent_folder_node.get_node([f"{binding_constraint_id}_gt"]) + assert isinstance(matrix_gt, LazyNode), f"Node type not handled yet: LazyNode expected, got {type(matrix_gt)}" # TODO: due to legacy matrices generation, we need to check if the new matrix file already exists # and if it does, we need to first remove it before renaming the existing matrix file - if existing_operator != BindingConstraintOperator.BOTH and new_operator != BindingConstraintOperator.BOTH: - (matrix_path,) = 
matrix_paths.matrix_input_paths - (new_matrix_path,) = matrix_paths.matrix_target_paths - new_matrix_path.unlink(missing_ok=True) - matrix_path.rename(new_matrix_path) + if (existing_operator not in HANDLED_OPERATORS) or (new_operator not in HANDLED_OPERATORS): + raise NotImplementedError( + f"Case not handled yet: existing_operator={existing_operator}, new_operator={new_operator}" + ) + elif existing_operator == new_operator: + return # nothing to do + elif existing_operator != BindingConstraintOperator.BOTH and new_operator != BindingConstraintOperator.BOTH: + matrix_node = parent_folder_node.get_node([f"{binding_constraint_id}_{ALIAS_OPERATOR_MAP[existing_operator]}"]) + assert isinstance( + matrix_node, LazyNode + ), f"Node type not handled yet: LazyNode expected, got {type(matrix_node)}" + new_matrix_node = parent_folder_node.get_node([f"{binding_constraint_id}_{ALIAS_OPERATOR_MAP[new_operator]}"]) + assert isinstance( + new_matrix_node, LazyNode + ), f"Node type not handled yet: LazyNode expected, got {type(new_matrix_node)}" + matrix_node.move_file_to(new_matrix_node, force=True) elif new_operator == BindingConstraintOperator.BOTH: - matrix_path_lt, matrix_path_gt = matrix_paths.matrix_target_paths if existing_operator == BindingConstraintOperator.EQUAL: - (matrix_path_eq,) = matrix_paths.matrix_input_paths - matrix_path_lt.unlink(missing_ok=True) - matrix_path_gt.unlink(missing_ok=True) - matrix_path_eq.rename(matrix_path_lt) + matrix_eq.move_file_to(matrix_lt, force=True) + matrix_gt.delete() # copy the matrix lt to gt - shutil.copy(matrix_path_lt, matrix_path_gt) + matrix_lt.move_file_to(matrix_gt, force=True, copy=True) elif existing_operator == BindingConstraintOperator.LESS: - matrix_path_gt.unlink(missing_ok=True) - shutil.copy(matrix_path_lt, matrix_path_gt) - elif existing_operator == BindingConstraintOperator.GREATER: - matrix_path_lt.unlink(missing_ok=True) - shutil.copy(matrix_path_gt, matrix_path_lt) - elif existing_operator == 
BindingConstraintOperator.BOTH: - matrix_path_lt, matrix_path_gt = matrix_paths.matrix_input_paths + matrix_gt.delete() + matrix_lt.move_file_to(matrix_gt, force=True, copy=True) + else: + matrix_lt.delete() + matrix_gt.move_file_to(matrix_lt, force=True, copy=True) + else: if new_operator == BindingConstraintOperator.EQUAL: # TODO: we may retrieve the mean of the two matrices, but here we just copy the lt matrix - (matrix_path_eq,) = matrix_paths.matrix_target_paths - shutil.copy(matrix_path_lt, matrix_path_eq) - matrix_path_gt.unlink(missing_ok=True) - matrix_path_lt.unlink(missing_ok=True) + matrix_lt.move_file_to(matrix_eq, force=True) + matrix_gt.delete() elif new_operator == BindingConstraintOperator.LESS: - matrix_path_gt.unlink(missing_ok=True) - elif new_operator == BindingConstraintOperator.GREATER: - matrix_path_lt.unlink(missing_ok=True) + matrix_gt.delete() + else: + matrix_lt.delete() class UpdateBindingConstraint(AbstractBindingConstraintCommand): @@ -262,7 +157,9 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: new_operator = BindingConstraintOperator(self.operator) _update_matrices_names(study_data, self.id, existing_operator, new_operator) - updated_matrices = [term for term in TERM_MATRICES if hasattr(self, term) and getattr(self, term)] + updated_matrices = [ + term for term in [m.value for m in TermMatrices] if hasattr(self, term) and getattr(self, term) + ] study_version = study_data.config.version time_step = self.time_step or BindingConstraintFrequency(actual_cfg.get("type")) self.validates_and_fills_matrices( @@ -287,7 +184,7 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: return super().apply_binding_constraint(study_data, binding_constraints, index, self.id, old_groups=old_groups) def to_dto(self) -> CommandDTO: - matrices = ["values"] + TERM_MATRICES + matrices = ["values"] + [m.value for m in TermMatrices] matrix_service = self.command_context.matrix_service excluded_fields = frozenset(ICommand.__fields__) diff 
--git a/tests/integration/study_data_blueprint/test_binding_constraints.py b/tests/integration/study_data_blueprint/test_binding_constraints.py index fc61db7c7b..aba3d397ac 100644 --- a/tests/integration/study_data_blueprint/test_binding_constraints.py +++ b/tests/integration/study_data_blueprint/test_binding_constraints.py @@ -513,10 +513,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st assert res.json()["exception"] == "InvalidFieldForVersionError" assert res.json()["description"] == "You cannot fill a 'matrix_term' as these values refer to v8.7+ studies" - @pytest.mark.parametrize( - "study_type", - ["raw", "variant"], - ) + @pytest.mark.parametrize("study_type", ["raw", "variant"]) def test_for_version_870(self, client: TestClient, user_access_token: str, study_type: str) -> None: client.headers = {"Authorization": f"Bearer {user_access_token}"} # type: ignore @@ -725,6 +722,9 @@ def test_for_version_870(self, client: TestClient, user_access_token: str, study json={"greater_term_matrix": matrix_lt3.tolist()}, ) assert res.status_code == 422, res.json() + assert "greater_term_matrix" in res.json()["description"] + assert "equal" in res.json()["description"] + assert res.json()["exception"] == "InvalidFieldForVersionError" # update the binding constraint operator first res = client.put(