From 3bb1cc1e5c139bc690d14a86163861266347c3d5 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 30 Jan 2024 11:44:27 +0100 Subject: [PATCH 01/12] fix(model): ensure validation constraints on inherited fields in `AllOptionalMetaclass` --- antarest/study/business/utils.py | 110 ++++-- .../business/test_all_optional_metaclass.py | 349 ++++++++++++++++++ 2 files changed, 420 insertions(+), 39 deletions(-) create mode 100644 tests/study/business/test_all_optional_metaclass.py diff --git a/antarest/study/business/utils.py b/antarest/study/business/utils.py index fbffbc310b..53596bf797 100644 --- a/antarest/study/business/utils.py +++ b/antarest/study/business/utils.py @@ -1,7 +1,8 @@ -from typing import Any, Callable, Dict, MutableSequence, Optional, Sequence, Tuple, Type, TypedDict, TypeVar +import typing as t -import pydantic -from pydantic import BaseModel, Extra +import pydantic.fields +import pydantic.main +from pydantic import BaseModel from antarest.core.exceptions import CommandApplicationError from antarest.core.jwt import DEFAULT_ADMIN_USER @@ -21,11 +22,11 @@ def execute_or_add_commands( study: Study, file_study: FileStudy, - commands: Sequence[ICommand], + commands: t.Sequence[ICommand], storage_service: StudyStorageService, ) -> None: if isinstance(study, RawStudy): - executed_commands: MutableSequence[ICommand] = [] + executed_commands: t.MutableSequence[ICommand] = [] for command in commands: result = command.apply(file_study) if not result.status: @@ -58,69 +59,100 @@ def execute_or_add_commands( ) -class FormFieldsBaseModel(BaseModel): +class FormFieldsBaseModel( + BaseModel, + alias_generator=to_camel_case, + extra="forbid", + validate_assignment=True, + allow_population_by_field_name=True, +): """ Pydantic Model for webapp form """ - class Config: - alias_generator = to_camel_case - extra = Extra.forbid - validate_assignment = True - allow_population_by_field_name = True - -class FieldInfo(TypedDict, total=False): +class FieldInfo(t.TypedDict, total=False): path: str - default_value: Any - start_version: Optional[int] - end_version: Optional[int] - # Workaround to replace Pydantic's computed values which are ignored by FastAPI. + default_value: t.Any + start_version: t.Optional[int] + end_version: t.Optional[int] + # Workaround to replace Pydantic computed values which are ignored by FastAPI. # TODO: check @computed_field available in Pydantic v2 to remove it # (value) -> encoded_value - encode: Optional[Callable[[Any], Any]] + encode: t.Optional[t.Callable[[t.Any], t.Any]] # (encoded_value, current_value) -> decoded_value - decode: Optional[Callable[[Any, Optional[Any]], Any]] + decode: t.Optional[t.Callable[[t.Any, t.Optional[t.Any]], t.Any]] class AllOptionalMetaclass(pydantic.main.ModelMetaclass): """ Metaclass that makes all fields of a Pydantic model optional. - This metaclass modifies the class's annotations to make all fields - optional by wrapping them with the `Optional` type. - Usage: class MyModel(BaseModel, metaclass=AllOptionalMetaclass): field1: str field2: int ... - The fields defined in the model will be automatically converted to optional - fields, allowing instances of the model to be created even if not all fields - are provided during initialization. + Instances of the model can be created even if not all fields are provided during initialization. + Default values, when provided, are used unless `use_none` is set to `True`. """ def __new__( - cls: Type["AllOptionalMetaclass"], + cls: t.Type["AllOptionalMetaclass"], name: str, - bases: Tuple[Type[Any], ...], - namespaces: Dict[str, Any], - **kwargs: Dict[str, Any], - ) -> Any: + bases: t.Tuple[t.Type[t.Any], ...], + namespaces: t.Dict[str, t.Any], + use_none: bool = False, + **kwargs: t.Dict[str, t.Any], + ) -> t.Any: + """ + Create a new instance of the metaclass. + + Args: + name: Name of the class to create. + bases: Base classes of the class to create (a Pydantic model). + namespaces: namespace of the class to create that defines the fields of the model. + use_none: If `True`, the default value of the fields is set to `None`. + Note that this field is not part of the Pydantic model, but it is an extension. + **kwargs: Additional keyword arguments used by the metaclass. + """ + # Modify the annotations of the class (but not of the ancestor classes) + # in order to make all fields optional. + # If the current model inherits from another model, the annotations of the ancestor models + # are not modified, because the fields are already converted to `ModelField`. annotations = namespaces.get("__annotations__", {}) - for base in bases: - for ancestor in reversed(base.__mro__): - annotations.update(getattr(ancestor, "__annotations__", {})) - for field, field_type in annotations.items(): - if not field.startswith("__"): - # Optional fields are correctly handled - annotations[field] = Optional[annotations[field]] + for field_name, field_type in annotations.items(): + if not field_name.startswith("__"): + # Making already optional fields optional is not a problem (nothing is changed). + annotations[field_name] = t.Optional[field_type] namespaces["__annotations__"] = annotations - return super().__new__(cls, name, bases, namespaces) + + if use_none: + # Modify the namespace fields to set their default value to `None`. + for field_name, field_info in namespaces.items(): + if isinstance(field_info, pydantic.fields.FieldInfo): + field_info.default = None + field_info.default_factory = None + + # Create the class: all annotations are converted into `ModelField`. + instance = super().__new__(cls, name, bases, namespaces, **kwargs) + + # Modify the inherited fields of the class to make them optional + # and set their default value to `None`. + model_field: pydantic.fields.ModelField + for field_name, model_field in instance.__fields__.items(): + model_field.required = False + model_field.allow_none = True + if use_none: + model_field.default = None + model_field.default_factory = None + model_field.field_info.default = None + + return instance -MODEL = TypeVar("MODEL", bound=Type[BaseModel]) +MODEL = t.TypeVar("MODEL", bound=t.Type[BaseModel]) def camel_case_model(model: MODEL) -> MODEL: diff --git a/tests/study/business/test_all_optional_metaclass.py b/tests/study/business/test_all_optional_metaclass.py new file mode 100644 index 0000000000..2e83a4e433 --- /dev/null +++ b/tests/study/business/test_all_optional_metaclass.py @@ -0,0 +1,349 @@ +import typing as t + +import pytest +from pydantic import BaseModel, Field, ValidationError + +from antarest.study.business.utils import AllOptionalMetaclass + +# ============================================== +# Classic way to use default and optional values +# ============================================== + + +class ClassicModel(BaseModel): + mandatory: float = Field(ge=0, le=1) + mandatory_with_default: float = Field(ge=0, le=1, default=0.2) + mandatory_with_none: float = Field(ge=0, le=1, default=None) + optional: t.Optional[float] = Field(ge=0, le=1) + optional_with_default: t.Optional[float] = Field(ge=0, le=1, default=0.2) + optional_with_none: t.Optional[float] = Field(ge=0, le=1, default=None) + + +class ClassicSubModel(ClassicModel): + pass + + +class TestClassicModel: + """ + Test that default and optional values work as expected. + """ + + @pytest.mark.parametrize("cls", [ClassicModel, ClassicSubModel]) + def test_classes(self, cls: t.Type[BaseModel]) -> None: + assert cls.__fields__["mandatory"].required is True + assert cls.__fields__["mandatory"].allow_none is False + assert cls.__fields__["mandatory"].default is None + assert cls.__fields__["mandatory"].default_factory is None # undefined + + assert cls.__fields__["mandatory_with_default"].required is False + assert cls.__fields__["mandatory_with_default"].allow_none is False + assert cls.__fields__["mandatory_with_default"].default == 0.2 + assert cls.__fields__["mandatory_with_default"].default_factory is None # undefined + + assert cls.__fields__["mandatory_with_none"].required is False + assert cls.__fields__["mandatory_with_none"].allow_none is True + assert cls.__fields__["mandatory_with_none"].default is None + assert cls.__fields__["mandatory_with_none"].default_factory is None # undefined + + assert cls.__fields__["optional"].required is False + assert cls.__fields__["optional"].allow_none is True + assert cls.__fields__["optional"].default is None + assert cls.__fields__["optional"].default_factory is None # undefined + + assert cls.__fields__["optional_with_default"].required is False + assert cls.__fields__["optional_with_default"].allow_none is True + assert cls.__fields__["optional_with_default"].default == 0.2 + assert cls.__fields__["optional_with_default"].default_factory is None # undefined + + assert cls.__fields__["optional_with_none"].required is False + assert cls.__fields__["optional_with_none"].allow_none is True + assert cls.__fields__["optional_with_none"].default is None + assert cls.__fields__["optional_with_none"].default_factory is None # undefined + + @pytest.mark.parametrize("cls", [ClassicModel, ClassicSubModel]) + def test_initialization(self, cls: t.Type[ClassicModel]) -> None: + # We can build a model without providing optional or default values. + # The initialized value will be the default value or `None` for optional values. + obj = cls(mandatory=0.5) + assert obj.mandatory == 0.5 + assert obj.mandatory_with_default == 0.2 + assert obj.mandatory_with_none is None + assert obj.optional is None + assert obj.optional_with_default == 0.2 + assert obj.optional_with_none is None + + # We must provide a value for mandatory fields. + with pytest.raises(ValidationError): + cls() + + @pytest.mark.parametrize("cls", [ClassicModel, ClassicSubModel]) + def test_validation(self, cls: t.Type[ClassicModel]) -> None: + # We CANNOT use `None` as a value for a field with a default value. + with pytest.raises(ValidationError): + cls(mandatory=0.5, mandatory_with_default=None) + + # We can use `None` as a value for optional fields with default value. + cls(mandatory=0.5, optional_with_default=None) + + # We can validate a model with valid values. + cls( + mandatory=0.5, + mandatory_with_default=0.2, + mandatory_with_none=0.2, + optional=0.5, + optional_with_default=0.2, + optional_with_none=0.2, + ) + + # We CANNOT validate a model with invalid values. + with pytest.raises(ValidationError): + cls(mandatory=2) + + with pytest.raises(ValidationError): + cls(mandatory=0.5, mandatory_with_default=2) + + with pytest.raises(ValidationError): + cls(mandatory=0.5, mandatory_with_none=2) + + with pytest.raises(ValidationError): + cls(mandatory=0.5, optional=2) + + with pytest.raises(ValidationError): + cls(mandatory=0.5, optional_with_default=2) + + with pytest.raises(ValidationError): + cls(mandatory=0.5, optional_with_none=2) + + +# ========================== +# Using AllOptionalMetaclass +# ========================== + + +class AllOptionalModel(BaseModel, metaclass=AllOptionalMetaclass): + mandatory: float = Field(ge=0, le=1) + mandatory_with_default: float = Field(ge=0, le=1, default=0.2) + mandatory_with_none: float = Field(ge=0, le=1, default=None) + optional: t.Optional[float] = Field(ge=0, le=1) + optional_with_default: t.Optional[float] = Field(ge=0, le=1, default=0.2) + optional_with_none: t.Optional[float] = Field(ge=0, le=1, default=None) + + +class AllOptionalSubModel(AllOptionalModel): + pass + + +class InheritedAllOptionalModel(ClassicModel, metaclass=AllOptionalMetaclass): + pass + + +class TestAllOptionalModel: + """ + Test that AllOptionalMetaclass works with base classes. + """ + + @pytest.mark.parametrize("cls", [AllOptionalModel, AllOptionalSubModel, InheritedAllOptionalModel]) + def test_classes(self, cls: t.Type[BaseModel]) -> None: + assert cls.__fields__["mandatory"].required is False + assert cls.__fields__["mandatory"].allow_none is True + assert cls.__fields__["mandatory"].default is None + assert cls.__fields__["mandatory"].default_factory is None # undefined + + assert cls.__fields__["mandatory_with_default"].required is False + assert cls.__fields__["mandatory_with_default"].allow_none is True + assert cls.__fields__["mandatory_with_default"].default == 0.2 + assert cls.__fields__["mandatory_with_default"].default_factory is None # undefined + + assert cls.__fields__["mandatory_with_none"].required is False + assert cls.__fields__["mandatory_with_none"].allow_none is True + assert cls.__fields__["mandatory_with_none"].default is None + assert cls.__fields__["mandatory_with_none"].default_factory is None # undefined + + assert cls.__fields__["optional"].required is False + assert cls.__fields__["optional"].allow_none is True + assert cls.__fields__["optional"].default is None + assert cls.__fields__["optional"].default_factory is None # undefined + + assert cls.__fields__["optional_with_default"].required is False + assert cls.__fields__["optional_with_default"].allow_none is True + assert cls.__fields__["optional_with_default"].default == 0.2 + assert cls.__fields__["optional_with_default"].default_factory is None # undefined + + assert cls.__fields__["optional_with_none"].required is False + assert cls.__fields__["optional_with_none"].allow_none is True + assert cls.__fields__["optional_with_none"].default is None + assert cls.__fields__["optional_with_none"].default_factory is None # undefined + + @pytest.mark.parametrize("cls", [AllOptionalModel, AllOptionalSubModel, InheritedAllOptionalModel]) + def test_initialization(self, cls: t.Type[AllOptionalModel]) -> None: + # We can build a model without providing values. + # The initialized value will be the default value or `None` for optional values. + # Note that the mandatory fields are not required anymore, and can be `None`. + obj = cls() + assert obj.mandatory is None + assert obj.mandatory_with_default == 0.2 + assert obj.mandatory_with_none is None + assert obj.optional is None + assert obj.optional_with_default == 0.2 + assert obj.optional_with_none is None + + # If we convert the model to a dictionary, without `None` values, + # we should have a dictionary with default values only. + actual = obj.dict(exclude_none=True) + expected = { + "mandatory_with_default": 0.2, + "optional_with_default": 0.2, + } + assert actual == expected + + @pytest.mark.parametrize("cls", [AllOptionalModel, AllOptionalSubModel, InheritedAllOptionalModel]) + def test_validation(self, cls: t.Type[AllOptionalModel]) -> None: + # We can use `None` as a value for all fields. + cls(mandatory=None) + cls(mandatory_with_default=None) + cls(mandatory_with_none=None) + cls(optional=None) + cls(optional_with_default=None) + cls(optional_with_none=None) + + # We can validate a model with valid values. + cls( + mandatory=0.5, + mandatory_with_default=0.2, + mandatory_with_none=0.2, + optional=0.5, + optional_with_default=0.2, + optional_with_none=0.2, + ) + + # We CANNOT validate a model with invalid values. + with pytest.raises(ValidationError): + cls(mandatory=2) + + with pytest.raises(ValidationError): + cls(mandatory_with_default=2) + + with pytest.raises(ValidationError): + cls(mandatory_with_none=2) + + with pytest.raises(ValidationError): + cls(optional=2) + + with pytest.raises(ValidationError): + cls(optional_with_default=2) + + with pytest.raises(ValidationError): + cls(optional_with_none=2) + + +# The `use_none` keyword argument is set to `True` to allow the use of `None` +# as a default value for the fields of the model. + + +class UseNoneModel(BaseModel, metaclass=AllOptionalMetaclass, use_none=True): + mandatory: float = Field(ge=0, le=1) + mandatory_with_default: float = Field(ge=0, le=1, default=0.2) + mandatory_with_none: float = Field(ge=0, le=1, default=None) + optional: t.Optional[float] = Field(ge=0, le=1) + optional_with_default: t.Optional[float] = Field(ge=0, le=1, default=0.2) + optional_with_none: t.Optional[float] = Field(ge=0, le=1, default=None) + + +class UseNoneSubModel(UseNoneModel): + pass + + +class InheritedUseNoneModel(ClassicModel, metaclass=AllOptionalMetaclass, use_none=True): + pass + + +class TestUseNoneModel: + @pytest.mark.parametrize("cls", [UseNoneModel, UseNoneSubModel, InheritedUseNoneModel]) + def test_classes(self, cls: t.Type[BaseModel]) -> None: + assert cls.__fields__["mandatory"].required is False + assert cls.__fields__["mandatory"].allow_none is True + assert cls.__fields__["mandatory"].default is None + assert cls.__fields__["mandatory"].default_factory is None # undefined + + assert cls.__fields__["mandatory_with_default"].required is False + assert cls.__fields__["mandatory_with_default"].allow_none is True + assert cls.__fields__["mandatory_with_default"].default is None + assert cls.__fields__["mandatory_with_default"].default_factory is None # undefined + + assert cls.__fields__["mandatory_with_none"].required is False + assert cls.__fields__["mandatory_with_none"].allow_none is True + assert cls.__fields__["mandatory_with_none"].default is None + assert cls.__fields__["mandatory_with_none"].default_factory is None # undefined + + assert cls.__fields__["optional"].required is False + assert cls.__fields__["optional"].allow_none is True + assert cls.__fields__["optional"].default is None + assert cls.__fields__["optional"].default_factory is None # undefined + + assert cls.__fields__["optional_with_default"].required is False + assert cls.__fields__["optional_with_default"].allow_none is True + assert cls.__fields__["optional_with_default"].default is None + assert cls.__fields__["optional_with_default"].default_factory is None # undefined + + assert cls.__fields__["optional_with_none"].required is False + assert cls.__fields__["optional_with_none"].allow_none is True + assert cls.__fields__["optional_with_none"].default is None + assert cls.__fields__["optional_with_none"].default_factory is None # undefined + + @pytest.mark.parametrize("cls", [UseNoneModel, UseNoneSubModel, InheritedUseNoneModel]) + def test_initialization(self, cls: t.Type[UseNoneModel]) -> None: + # We can build a model without providing values. + # The initialized value will be the default value or `None` for optional values. + # Note that the mandatory fields are not required anymore, and can be `None`. + obj = cls() + assert obj.mandatory is None + assert obj.mandatory_with_default is None + assert obj.mandatory_with_none is None + assert obj.optional is None + assert obj.optional_with_default is None + assert obj.optional_with_none is None + + # If we convert the model to a dictionary, without `None` values, + # we should have an empty dictionary. + actual = obj.dict(exclude_none=True) + expected = {} + assert actual == expected + + @pytest.mark.parametrize("cls", [UseNoneModel, UseNoneSubModel, InheritedUseNoneModel]) + def test_validation(self, cls: t.Type[UseNoneModel]) -> None: + # We can use `None` as a value for all fields. + cls(mandatory=None) + cls(mandatory_with_default=None) + cls(mandatory_with_none=None) + cls(optional=None) + cls(optional_with_default=None) + cls(optional_with_none=None) + + # We can validate a model with valid values. + cls( + mandatory=0.5, + mandatory_with_default=0.2, + mandatory_with_none=0.2, + optional=0.5, + optional_with_default=0.2, + optional_with_none=0.2, + ) + + # We CANNOT validate a model with invalid values. + with pytest.raises(ValidationError): + cls(mandatory=2) + + with pytest.raises(ValidationError): + cls(mandatory_with_default=2) + + with pytest.raises(ValidationError): + cls(mandatory_with_none=2) + + with pytest.raises(ValidationError): + cls(optional=2) + + with pytest.raises(ValidationError): + cls(optional_with_default=2) + + with pytest.raises(ValidationError): + cls(optional_with_none=2) From 47e92005f87d150e9f99cc5ceffc539502d4f357 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 5 Feb 2024 13:28:39 +0100 Subject: [PATCH 02/12] style(command): use `t` alias to import `typing` --- .../model/command/create_cluster.py | 30 ++++++++-------- .../command/create_renewables_cluster.py | 16 ++++----- .../model/command/create_st_storage.py | 36 +++++++++---------- .../model/command/remove_cluster.py | 8 ++--- .../model/command/remove_st_storage.py | 8 ++--- 5 files changed, 49 insertions(+), 49 deletions(-) diff --git a/antarest/study/storage/variantstudy/model/command/create_cluster.py b/antarest/study/storage/variantstudy/model/command/create_cluster.py index 392abb23a9..f9edfba949 100644 --- a/antarest/study/storage/variantstudy/model/command/create_cluster.py +++ b/antarest/study/storage/variantstudy/model/command/create_cluster.py @@ -1,6 +1,6 @@ -from typing import Any, Dict, List, Optional, Tuple, Union, cast +import typing as t -from pydantic import Extra, validator +from pydantic import validator from antarest.core.model import JSON from antarest.core.utils.utils import assert_this @@ -34,9 +34,9 @@ class CreateCluster(ICommand): area_id: str cluster_name: str - parameters: Dict[str, str] - prepro: Optional[Union[List[List[MatrixData]], str]] = None - modulation: Optional[Union[List[List[MatrixData]], str]] = None + parameters: t.Dict[str, str] + prepro: t.Optional[t.Union[t.List[t.List[MatrixData]], str]] = None + modulation: t.Optional[t.Union[t.List[t.List[MatrixData]], str]] = None @validator("cluster_name") def validate_cluster_name(cls, val: str) -> str: @@ -47,8 +47,8 @@ def validate_cluster_name(cls, val: str) -> str: @validator("prepro", always=True) def validate_prepro( - cls, v: Optional[Union[List[List[MatrixData]], str]], values: Any - ) -> Optional[Union[List[List[MatrixData]], str]]: + cls, v: t.Optional[t.Union[t.List[t.List[MatrixData]], str]], values: t.Any + ) -> t.Optional[t.Union[t.List[t.List[MatrixData]], str]]: if v is None: v = values["command_context"].generator_matrix_constants.get_thermal_prepro_data() return v @@ -58,8 +58,8 @@ def validate_prepro( @validator("modulation", always=True) def validate_modulation( - cls, v: Optional[Union[List[List[MatrixData]], str]], values: Any - ) -> Optional[Union[List[List[MatrixData]], str]]: + cls, v: t.Optional[t.Union[t.List[t.List[MatrixData]], str]], values: t.Any + ) -> t.Optional[t.Union[t.List[t.List[MatrixData]], str]]: if v is None: v = values["command_context"].generator_matrix_constants.get_thermal_prepro_modulation() return v @@ -67,7 +67,7 @@ def validate_modulation( else: return validate_matrix(v, values) - def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]: + def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: # Search the Area in the configuration if self.area_id not in study_data.areas: return ( @@ -173,14 +173,14 @@ def match(self, other: ICommand, equal: bool = False) -> bool: and self.modulation == other.modulation ) - def _create_diff(self, other: "ICommand") -> List["ICommand"]: - other = cast(CreateCluster, other) + def _create_diff(self, other: "ICommand") -> t.List["ICommand"]: + other = t.cast(CreateCluster, other) from antarest.study.storage.variantstudy.model.command.replace_matrix import ReplaceMatrix from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig # Series identifiers are in lower case. series_id = transform_name_to_id(self.cluster_name, lower=True) - commands: List[ICommand] = [] + commands: t.List[ICommand] = [] if self.prepro != other.prepro: commands.append( ReplaceMatrix( @@ -207,8 +207,8 @@ def _create_diff(self, other: "ICommand") -> List["ICommand"]: ) return commands - def get_inner_matrices(self) -> List[str]: - matrices: List[str] = [] + def get_inner_matrices(self) -> t.List[str]: + matrices: t.List[str] = [] if self.prepro: assert_this(isinstance(self.prepro, str)) matrices.append(strip_matrix_protocol(self.prepro)) diff --git a/antarest/study/storage/variantstudy/model/command/create_renewables_cluster.py b/antarest/study/storage/variantstudy/model/command/create_renewables_cluster.py index 0658de4076..ab61d8f710 100644 --- a/antarest/study/storage/variantstudy/model/command/create_renewables_cluster.py +++ b/antarest/study/storage/variantstudy/model/command/create_renewables_cluster.py @@ -1,6 +1,6 @@ -from typing import Any, Dict, List, Tuple, cast +import typing as t -from pydantic import Extra, validator +from pydantic import validator from antarest.core.model import JSON from antarest.study.storage.rawstudy.model.filesystem.config.model import ( @@ -32,7 +32,7 @@ class CreateRenewablesCluster(ICommand): area_id: str cluster_name: str - parameters: Dict[str, str] + parameters: t.Dict[str, str] @validator("cluster_name") def validate_cluster_name(cls, val: str) -> str: @@ -41,7 +41,7 @@ def validate_cluster_name(cls, val: str) -> str: raise ValueError("Area name must only contains [a-zA-Z0-9],&,-,_,(,) characters") return val - def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]: + def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: if study_data.enr_modelling != ENR_MODELLING.CLUSTERS.value: # Since version 8.1 of the solver, we can use renewable clusters # instead of "Load", "Wind" and "Solar" objects for modelling. @@ -147,11 +147,11 @@ def match(self, other: ICommand, equal: bool = False) -> bool: return simple_match return simple_match and self.parameters == other.parameters - def _create_diff(self, other: "ICommand") -> List["ICommand"]: - other = cast(CreateRenewablesCluster, other) + def _create_diff(self, other: "ICommand") -> t.List["ICommand"]: + other = t.cast(CreateRenewablesCluster, other) from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig - commands: List[ICommand] = [] + commands: t.List[ICommand] = [] if self.parameters != other.parameters: commands.append( UpdateConfig( @@ -162,5 +162,5 @@ def _create_diff(self, other: "ICommand") -> List["ICommand"]: ) return commands - def get_inner_matrices(self) -> List[str]: + def get_inner_matrices(self) -> t.List[str]: return [] diff --git a/antarest/study/storage/variantstudy/model/command/create_st_storage.py b/antarest/study/storage/variantstudy/model/command/create_st_storage.py index 771c2dd4b0..8342a92cb7 100644 --- a/antarest/study/storage/variantstudy/model/command/create_st_storage.py +++ b/antarest/study/storage/variantstudy/model/command/create_st_storage.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List, Optional, Tuple, Union, cast +import typing as t import numpy as np from pydantic import Field, validator @@ -28,7 +28,7 @@ # Minimum required version. REQUIRED_VERSION = 860 -MatrixType = List[List[MatrixData]] +MatrixType = t.List[t.List[MatrixData]] # noinspection SpellCheckingInspection @@ -48,23 +48,23 @@ class CreateSTStorage(ICommand): area_id: str = Field(description="Area ID", regex=r"[a-z0-9_(),& -]+") parameters: STStorageConfigType - pmax_injection: Optional[Union[MatrixType, str]] = Field( + pmax_injection: t.Optional[t.Union[MatrixType, str]] = Field( None, description="Charge capacity (modulation)", ) - pmax_withdrawal: Optional[Union[MatrixType, str]] = Field( + pmax_withdrawal: t.Optional[t.Union[MatrixType, str]] = Field( None, description="Discharge capacity (modulation)", ) - lower_rule_curve: Optional[Union[MatrixType, str]] = Field( + lower_rule_curve: t.Optional[t.Union[MatrixType, str]] = Field( None, description="Lower rule curve (coefficient)", ) - upper_rule_curve: Optional[Union[MatrixType, str]] = Field( + upper_rule_curve: t.Optional[t.Union[MatrixType, str]] = Field( None, description="Upper rule curve (coefficient)", ) - inflows: Optional[Union[MatrixType, str]] = Field( + inflows: t.Optional[t.Union[MatrixType, str]] = Field( None, description="Inflows (MW)", ) @@ -82,10 +82,10 @@ def storage_name(self) -> str: @validator(*_MATRIX_NAMES, always=True) def register_matrix( cls, - v: Optional[Union[MatrixType, str]], - values: Dict[str, Any], + v: t.Optional[t.Union[MatrixType, str]], + values: t.Dict[str, t.Any], field: ModelField, - ) -> Optional[Union[MatrixType, str]]: + ) -> t.Optional[t.Union[MatrixType, str]]: """ Validates a matrix array or link, and store the matrix array in the matrix repository. @@ -138,13 +138,13 @@ def register_matrix( constrained = set(_MATRIX_NAMES) - {"inflows"} if field.name in constrained and (np.any(array < 0) or np.any(array > 1)): raise ValueError("Matrix values should be between 0 and 1") - v = cast(MatrixType, array.tolist()) + v = t.cast(MatrixType, array.tolist()) return validate_matrix(v, values) # Invalid datatype # pragma: no cover raise TypeError(repr(v)) - def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]: + def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: """ Applies configuration changes to the study data: add the short-term storage in the storages list. @@ -284,7 +284,7 @@ def match(self, other: "ICommand", equal: bool = False) -> bool: else: return self.area_id == other.area_id and self.storage_id == other.storage_id - def _create_diff(self, other: "ICommand") -> List["ICommand"]: + def _create_diff(self, other: "ICommand") -> t.List["ICommand"]: """ Creates a list of commands representing the differences between the current instance and another `ICommand` object. @@ -299,8 +299,8 @@ def _create_diff(self, other: "ICommand") -> List["ICommand"]: from antarest.study.storage.variantstudy.model.command.replace_matrix import ReplaceMatrix from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig - other = cast(CreateSTStorage, other) - commands: List[ICommand] = [ + other = t.cast(CreateSTStorage, other) + commands: t.List[ICommand] = [ ReplaceMatrix( target=f"input/st-storage/series/{self.area_id}/{self.storage_id}/{attr}", matrix=strip_matrix_protocol(getattr(other, attr)), @@ -310,7 +310,7 @@ def _create_diff(self, other: "ICommand") -> List["ICommand"]: if getattr(self, attr) != getattr(other, attr) ] if self.parameters != other.parameters: - data: Dict[str, Any] = json.loads(other.parameters.json(by_alias=True, exclude={"id"})) + data: t.Dict[str, t.Any] = json.loads(other.parameters.json(by_alias=True, exclude={"id"})) commands.append( UpdateConfig( target=f"input/st-storage/clusters/{self.area_id}/list/{self.storage_id}", @@ -320,9 +320,9 @@ def _create_diff(self, other: "ICommand") -> List["ICommand"]: ) return commands - def get_inner_matrices(self) -> List[str]: + def get_inner_matrices(self) -> t.List[str]: """ Retrieves the list of matrix IDs. """ - matrices: List[str] = [strip_matrix_protocol(getattr(self, attr)) for attr in _MATRIX_NAMES] + matrices: t.List[str] = [strip_matrix_protocol(getattr(self, attr)) for attr in _MATRIX_NAMES] return matrices diff --git a/antarest/study/storage/variantstudy/model/command/remove_cluster.py b/antarest/study/storage/variantstudy/model/command/remove_cluster.py index ff1fdd4a73..095e62f526 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_cluster.py +++ b/antarest/study/storage/variantstudy/model/command/remove_cluster.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +import typing as t from antarest.study.storage.rawstudy.model.filesystem.config.model import Area, FileStudyTreeConfig from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy @@ -27,7 +27,7 @@ class RemoveCluster(ICommand): area_id: str cluster_id: str - def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]: + def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: """ Applies configuration changes to the study data: remove the thermal clusters from the storages list. @@ -128,10 +128,10 @@ def match(self, other: ICommand, equal: bool = False) -> bool: return False return self.cluster_id == other.cluster_id and self.area_id == other.area_id - def _create_diff(self, other: "ICommand") -> List["ICommand"]: + def _create_diff(self, other: "ICommand") -> t.List["ICommand"]: return [] - def get_inner_matrices(self) -> List[str]: + def get_inner_matrices(self) -> t.List[str]: return [] # noinspection SpellCheckingInspection diff --git a/antarest/study/storage/variantstudy/model/command/remove_st_storage.py b/antarest/study/storage/variantstudy/model/command/remove_st_storage.py index 0fd4bbb1de..116f402c08 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_st_storage.py +++ b/antarest/study/storage/variantstudy/model/command/remove_st_storage.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +import typing as t from pydantic import Field @@ -29,7 +29,7 @@ class RemoveSTStorage(ICommand): area_id: str = Field(description="Area ID", regex=r"[a-z0-9_(),& -]+") storage_id: str = Field(description="Short term storage ID", regex=r"[a-z0-9_(),& -]+") - def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]: + def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: """ Applies configuration changes to the study data: remove the storage from the storages list. @@ -143,8 +143,8 @@ def match(self, other: "ICommand", equal: bool = False) -> bool: # or matrices, so that shallow and deep comparisons are identical. return self.__eq__(other) - def _create_diff(self, other: "ICommand") -> List["ICommand"]: + def _create_diff(self, other: "ICommand") -> t.List["ICommand"]: return [] - def get_inner_matrices(self) -> List[str]: + def get_inner_matrices(self) -> t.List[str]: return [] From b84271206eab498877c0e23d65c2814718300b62 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 5 Feb 2024 13:29:05 +0100 Subject: [PATCH 03/12] chore: reorder imports --- antarest/launcher/ssh_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/antarest/launcher/ssh_client.py b/antarest/launcher/ssh_client.py index acace09c9b..7337a308ac 100644 --- a/antarest/launcher/ssh_client.py +++ b/antarest/launcher/ssh_client.py @@ -1,6 +1,6 @@ import contextlib -import socket import shlex +import socket from typing import Any, List, Tuple import paramiko From 06d2fad501ffde982c725290e8ce4724b9950404 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 5 Feb 2024 15:09:00 +0100 Subject: [PATCH 04/12] docs: correct docstring (invalid exception name) --- antarest/study/business/areas/hydro_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/antarest/study/business/areas/hydro_management.py b/antarest/study/business/areas/hydro_management.py index a6a8742168..e0a52ee4e1 100644 --- a/antarest/study/business/areas/hydro_management.py +++ b/antarest/study/business/areas/hydro_management.py @@ -163,7 +163,7 @@ def update_inflow_structure(self, study: Study, area_id: str, values: InflowStru values: The new inflow structure values to be updated. Raises: - RequestValidationError: If the provided `values` parameter is None or invalid. + ValidationError: If the provided `values` parameter is None or invalid. """ # NOTE: Updates only "intermonthly-correlation" due to current model scope. path = INFLOW_PATH.format(area_id=area_id) From c2661d1b799e50c8dd07b8d11bbbf53873cd557c Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 5 Feb 2024 16:17:25 +0100 Subject: [PATCH 05/12] chore(main): correct docstring and remove unused imports --- antarest/main.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/antarest/main.py b/antarest/main.py index 93be38ed2b..2187088d2b 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -6,7 +6,6 @@ from typing import Any, Dict, Optional, Sequence, Tuple, cast import pydantic -import sqlalchemy.ext.baked # type: ignore import uvicorn # type: ignore import uvicorn.config # type: ignore from fastapi import FastAPI, HTTPException @@ -56,18 +55,16 @@ class PathType: which specify whether the path argument must exist, whether it can be a file, and whether it can be a directory, respectively. - Example Usage: + Example Usage:: - ```python - import argparse - from antarest.main import PathType + import argparse + from antarest.main import PathType - parser = argparse.ArgumentParser() - parser.add_argument('--input', type=PathType(file_ok=True, exists=True)) - args = parser.parse_args() + parser = argparse.ArgumentParser() + parser.add_argument('--input', type=PathType(file_ok=True, exists=True)) + args = parser.parse_args() - print(args.input) - ``` + print(args.input) In the above example, `PathType` is used to specify the type of the `--input` argument for the `argparse` parser. The argument must be an existing file path. @@ -401,9 +398,11 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: application.add_middleware( RateLimitMiddleware, authenticate=auth_manager.create_auth_function(), - backend=RedisBackend(config.redis.host, config.redis.port, 1, config.redis.password) - if config.redis is not None - else MemoryBackend(), + backend=( + MemoryBackend() + if config.redis is None + else RedisBackend(config.redis.host, config.redis.port, 1, config.redis.password) + ), config=RATE_LIMIT_CONFIG, ) From 26180675bd3e0b24403ffd760f13f4a103bb4555 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 5 Feb 2024 18:03:42 +0100 Subject: [PATCH 06/12] fix(model): allow models to have `None` values --- .../study/business/areas/renewable_management.py | 4 ++-- .../study/business/areas/st_storage_management.py | 2 +- .../study/business/areas/thermal_management.py | 4 ++-- .../study/business/thematic_trimming_management.py | 2 +- antarest/study/business/xpansion_management.py | 2 +- .../rawstudy/model/filesystem/config/cluster.py | 7 ++++++- antarest/study/web/study_data_blueprint.py | 14 ++++++++++++-- .../study_data_blueprint/test_renewable.py | 2 +- .../study_data_blueprint/test_st_storage.py | 2 +- .../study_data_blueprint/test_thermal.py | 2 +- 10 files changed, 28 insertions(+), 13 deletions(-) diff --git a/antarest/study/business/areas/renewable_management.py b/antarest/study/business/areas/renewable_management.py index 56a4b44a8d..abb0d17503 100644 --- a/antarest/study/business/areas/renewable_management.py +++ b/antarest/study/business/areas/renewable_management.py @@ -37,7 +37,7 @@ class TimeSeriesInterpretation(EnumIgnoreCase): @camel_case_model -class RenewableClusterInput(RenewableProperties, metaclass=AllOptionalMetaclass): +class RenewableClusterInput(RenewableProperties, metaclass=AllOptionalMetaclass, use_none=True): """ Model representing the data structure required to edit an existing renewable cluster. """ @@ -76,7 +76,7 @@ def to_config(self, study_version: t.Union[str, int]) -> RenewableConfigType: @camel_case_model -class RenewableClusterOutput(RenewableConfig, metaclass=AllOptionalMetaclass): +class RenewableClusterOutput(RenewableConfig, metaclass=AllOptionalMetaclass, use_none=True): """ Model representing the output data structure to display the details of a renewable cluster. """ diff --git a/antarest/study/business/areas/st_storage_management.py b/antarest/study/business/areas/st_storage_management.py index 8d1edd59b1..e607edca58 100644 --- a/antarest/study/business/areas/st_storage_management.py +++ b/antarest/study/business/areas/st_storage_management.py @@ -37,7 +37,7 @@ @camel_case_model -class STStorageInput(STStorageProperties, metaclass=AllOptionalMetaclass): +class STStorageInput(STStorageProperties, metaclass=AllOptionalMetaclass, use_none=True): """ Model representing the form used to EDIT an existing short-term storage. """ diff --git a/antarest/study/business/areas/thermal_management.py b/antarest/study/business/areas/thermal_management.py index b421775557..6f24c41ee7 100644 --- a/antarest/study/business/areas/thermal_management.py +++ b/antarest/study/business/areas/thermal_management.py @@ -30,7 +30,7 @@ @camel_case_model -class ThermalClusterInput(Thermal860Properties, metaclass=AllOptionalMetaclass): +class ThermalClusterInput(Thermal860Properties, metaclass=AllOptionalMetaclass, use_none=True): """ Model representing the data structure required to edit an existing thermal cluster within a study. """ @@ -70,7 +70,7 @@ def to_config(self, study_version: t.Union[str, int]) -> ThermalConfigType: @camel_case_model -class ThermalClusterOutput(Thermal860Config, metaclass=AllOptionalMetaclass): +class ThermalClusterOutput(Thermal860Config, metaclass=AllOptionalMetaclass, use_none=True): """ Model representing the output data structure to display the details of a thermal cluster within a study. """ diff --git a/antarest/study/business/thematic_trimming_management.py b/antarest/study/business/thematic_trimming_management.py index 4259046701..1ebfeebe04 100644 --- a/antarest/study/business/thematic_trimming_management.py +++ b/antarest/study/business/thematic_trimming_management.py @@ -12,7 +12,7 @@ from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig -class ThematicTrimmingFormFields(FormFieldsBaseModel, metaclass=AllOptionalMetaclass): +class ThematicTrimmingFormFields(FormFieldsBaseModel, metaclass=AllOptionalMetaclass, use_none=True): """ This class manages the configuration of result filtering in a simulation. diff --git a/antarest/study/business/xpansion_management.py b/antarest/study/business/xpansion_management.py index dc595d5428..f3adadad32 100644 --- a/antarest/study/business/xpansion_management.py +++ b/antarest/study/business/xpansion_management.py @@ -196,7 +196,7 @@ def from_config(cls, config_obj: JSON) -> "GetXpansionSettings": return cls.construct(**config_obj) -class UpdateXpansionSettings(XpansionSettings, metaclass=AllOptionalMetaclass): +class UpdateXpansionSettings(XpansionSettings, metaclass=AllOptionalMetaclass, use_none=True): """ DTO object used to update the Xpansion settings. diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py b/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py index 2c7053e3ce..4563a0d217 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py @@ -84,9 +84,14 @@ class ClusterProperties(ItemProperties): @property def installed_capacity(self) -> float: - """""" + # fields may contain `None` values if they are turned into `Optional` fields + if self.unit_count is None or self.nominal_capacity is None: + return 0.0 return self.unit_count * self.nominal_capacity @property def enabled_capacity(self) -> float: + # fields may contain `None` values if they are turned into `Optional` fields + if self.enabled is None or self.installed_capacity is None: + return 0.0 return self.enabled * self.installed_capacity diff --git a/antarest/study/web/study_data_blueprint.py b/antarest/study/web/study_data_blueprint.py index 3bd394f0e6..c7e6ac17fd 100644 --- a/antarest/study/web/study_data_blueprint.py +++ b/antarest/study/web/study_data_blueprint.py @@ -25,8 +25,18 @@ RenewableClusterInput, RenewableClusterOutput, ) -from antarest.study.business.areas.st_storage_management import * # noqa -from antarest.study.business.areas.thermal_management import * # noqa +from antarest.study.business.areas.st_storage_management import ( + STStorageCreation, + STStorageInput, + STStorageMatrix, + STStorageOutput, + STStorageTimeSeries, +) +from antarest.study.business.areas.thermal_management import ( + ThermalClusterCreation, + ThermalClusterInput, + ThermalClusterOutput, +) from antarest.study.business.binding_constraint_management import ( BindingConstraintPropertiesWithName, ConstraintTermDTO, diff --git a/tests/integration/study_data_blueprint/test_renewable.py b/tests/integration/study_data_blueprint/test_renewable.py index c3bb7eaa79..14f1f4388a 100644 --- a/tests/integration/study_data_blueprint/test_renewable.py +++ b/tests/integration/study_data_blueprint/test_renewable.py @@ -201,7 +201,7 @@ def test_lifecycle( json=bad_properties, ) assert res.status_code == 422, res.json() - assert res.json()["exception"] == "ValidationError", res.json() + assert res.json()["exception"] == "RequestValidationError", res.json() # The renewable cluster properties should not have been updated. res = client.get( diff --git a/tests/integration/study_data_blueprint/test_st_storage.py b/tests/integration/study_data_blueprint/test_st_storage.py index ed3a45a360..1065a68e8b 100644 --- a/tests/integration/study_data_blueprint/test_st_storage.py +++ b/tests/integration/study_data_blueprint/test_st_storage.py @@ -223,7 +223,7 @@ def test_lifecycle__nominal( json=bad_properties, ) assert res.status_code == 422, res.json() - assert res.json()["exception"] == "ValidationError", res.json() + assert res.json()["exception"] == "RequestValidationError", res.json() # The short-term storage properties should not have been updated. res = client.get( diff --git a/tests/integration/study_data_blueprint/test_thermal.py b/tests/integration/study_data_blueprint/test_thermal.py index 587a5abcf5..1890d44acf 100644 --- a/tests/integration/study_data_blueprint/test_thermal.py +++ b/tests/integration/study_data_blueprint/test_thermal.py @@ -526,7 +526,7 @@ def test_lifecycle( json=bad_properties, ) assert res.status_code == 422, res.json() - assert res.json()["exception"] == "ValidationError", res.json() + assert res.json()["exception"] == "RequestValidationError", res.json() # The thermal cluster properties should not have been updated. res = client.get( From 9ecdddb821daf3864080f2bae33f227155696172 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Fri, 26 Jan 2024 08:39:54 +0100 Subject: [PATCH 07/12] chore(cluster): add a "fixme" to point a possible issue in Thermals and Renewables management --- antarest/study/business/areas/renewable_management.py | 1 + antarest/study/business/areas/thermal_management.py | 1 + 2 files changed, 2 insertions(+) diff --git a/antarest/study/business/areas/renewable_management.py b/antarest/study/business/areas/renewable_management.py index abb0d17503..48f349e527 100644 --- a/antarest/study/business/areas/renewable_management.py +++ b/antarest/study/business/areas/renewable_management.py @@ -248,6 +248,7 @@ def update_cluster( command_context=self.storage_service.variant_study_service.command_factory.command_context, ) + # fixme: The `file_study` is already retrieved at the beginning of the function. file_study = self.storage_service.get_storage(study).get_raw(study) execute_or_add_commands(study, file_study, [command], self.storage_service) return RenewableClusterOutput(**new_config.dict(by_alias=False)) diff --git a/antarest/study/business/areas/thermal_management.py b/antarest/study/business/areas/thermal_management.py index 6f24c41ee7..29d3e566d4 100644 --- a/antarest/study/business/areas/thermal_management.py +++ b/antarest/study/business/areas/thermal_management.py @@ -265,6 +265,7 @@ def update_cluster( # create the update config command with the modified data command_context = self.storage_service.variant_study_service.command_factory.command_context command = UpdateConfig(target=path, data=data, command_context=command_context) + # fixme: The `file_study` is already retrieved at the beginning of the function. file_study = self.storage_service.get_storage(study).get_raw(study) execute_or_add_commands(study, file_study, [command], self.storage_service) From e86b5c41f31de1164d7034f0867d84660c06c213 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Fri, 26 Jan 2024 23:02:58 +0100 Subject: [PATCH 08/12] chore: remove empty unit test file --- tests/core/test_config.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/core/test_config.py diff --git a/tests/core/test_config.py b/tests/core/test_config.py deleted file mode 100644 index e69de29bb2..0000000000 From d4ba1de05832fdd0e8fd8c36a67daf648af255f3 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Tue, 23 Jan 2024 15:13:06 +0100 Subject: [PATCH 09/12] fix(st-storage): correct the create and update commands --- .../business/areas/st_storage_management.py | 47 ++-- .../model/command/create_st_storage.py | 11 +- .../study_data_blueprint/test_st_storage.py | 242 ++++++++++++++---- .../areas/test_st_storage_management.py | 71 ++++- 4 files changed, 295 insertions(+), 76 deletions(-) diff --git a/antarest/study/business/areas/st_storage_management.py b/antarest/study/business/areas/st_storage_management.py index e607edca58..02f7af2e07 100644 --- a/antarest/study/business/areas/st_storage_management.py +++ b/antarest/study/business/areas/st_storage_management.py @@ -347,7 +347,20 @@ def update_storage( Updated form of short-term storage. """ study_version = study.version + + # review: reading the configuration poses a problem for variants, + # because it requires generating a snapshot, which takes time. + # This reading could be avoided if we don't need the previous values + # (no cross-field validation, no default values, etc.). + # In return, we won't be able to return a complete `STStorageOutput` object. + # So, we need to make sure the frontend doesn't need the missing fields. + # This missing information could also be a problem for the API users. + # The solution would be to avoid reading the configuration if the study is a variant + # (we then use the default values), otherwise, for a RAW study, we read the configuration + # and update the modified values. + file_study = self._get_file_study(study) + path = STORAGE_LIST_PATH.format(area_id=area_id, storage_id=storage_id) try: values = file_study.tree.get(path.split("/"), depth=1) @@ -357,31 +370,27 @@ def update_storage( old_config = create_st_storage_config(study_version, **values) # use Python values to synchronize Config and Form values - old_values = old_config.dict(exclude={"id"}) new_values = form.dict(by_alias=False, exclude_none=True) - updated = {**old_values, **new_values} - new_config = create_st_storage_config(study_version, **updated, id=storage_id) + new_config = old_config.copy(exclude={"id"}, update=new_values) new_data = json.loads(new_config.json(by_alias=True, exclude={"id"})) - # create the dict containing the old values (excluding defaults), - # the updated values (including defaults) - data: Dict[str, Any] = {} - for field_name, field in new_config.__fields__.items(): - if field_name in {"id"}: - continue - value = getattr(new_config, field_name) - if field_name in new_values or value != field.get_default(): - # use the JSON-converted value - data[field.alias] = new_data[field.alias] - - # create the update config command with the modified data + # create the dict containing the new values using aliases + data: Dict[str, Any] = { + field.alias: new_data[field.alias] + for field_name, field in new_config.__fields__.items() + if field_name in new_values + } + + # create the update config commands with the modified data command_context = self.storage_service.variant_study_service.command_factory.command_context - command = UpdateConfig(target=path, data=data, command_context=command_context) - file_study = self._get_file_study(study) - execute_or_add_commands(study, file_study, [command], self.storage_service) + commands = [ + UpdateConfig(target=f"{path}/{key}", data=value, command_context=command_context) + for key, value in data.items() + ] + execute_or_add_commands(study, file_study, commands, self.storage_service) values = new_config.dict(by_alias=False) - return STStorageOutput(**values) + return STStorageOutput(**values, id=storage_id) def delete_storages( self, diff --git a/antarest/study/storage/variantstudy/model/command/create_st_storage.py b/antarest/study/storage/variantstudy/model/command/create_st_storage.py index 8342a92cb7..90fb980c4b 100644 --- a/antarest/study/storage/variantstudy/model/command/create_st_storage.py +++ b/antarest/study/storage/variantstudy/model/command/create_st_storage.py @@ -215,15 +215,10 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: if not output.status: return output - # Fill-in the "list.ini" file with the parameters + # Fill-in the "list.ini" file with the parameters. + # On creation, it's better to write all the parameters in the file. config = study_data.tree.get(["input", "st-storage", "clusters", self.area_id, "list"]) - config[self.storage_id] = json.loads( - self.parameters.json( - by_alias=True, - exclude={"id"}, - exclude_defaults=True, - ) - ) + config[self.storage_id] = json.loads(self.parameters.json(by_alias=True, exclude={"id"})) new_data: JSON = { "input": { diff --git a/tests/integration/study_data_blueprint/test_st_storage.py b/tests/integration/study_data_blueprint/test_st_storage.py index 1065a68e8b..fdffe5efe1 100644 --- a/tests/integration/study_data_blueprint/test_st_storage.py +++ b/tests/integration/study_data_blueprint/test_st_storage.py @@ -1,25 +1,23 @@ +import json import re +from unittest.mock import ANY import numpy as np import pytest from starlette.testclient import TestClient from antarest.core.tasks.model import TaskStatus +from antarest.study.business.areas.st_storage_management import STStorageOutput from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id +from antarest.study.storage.rawstudy.model.filesystem.config.st_storage import STStorageConfig from tests.integration.utils import wait_task_completion -DEFAULT_PROPERTIES = { - # `name` field is required - "group": "Other1", - "injectionNominalCapacity": 0.0, - "withdrawalNominalCapacity": 0.0, - "reservoirCapacity": 0.0, - "efficiency": 1.0, - "initialLevel": 0.0, - "initialLevelOptim": False, -} +DEFAULT_CONFIG = json.loads(STStorageConfig(id="dummy", name="dummy").json(by_alias=True, exclude={"id", "name"})) +DEFAULT_PROPERTIES = json.loads(STStorageOutput(name="dummy").json(by_alias=True, exclude={"id", "name"})) + +# noinspection SpellCheckingInspection @pytest.mark.unit_test class TestSTStorage: # noinspection GrazieInspection @@ -395,16 +393,7 @@ def test_lifecycle__nominal( res = client.post( f"/v1/studies/{study_id}/areas/{bad_area_id}/storages", headers={"Authorization": f"Bearer {user_access_token}"}, - json={ - "name": siemens_battery, - "group": "Battery", - "initialLevel": 0.0, - "initialLevelOptim": False, - "injectionNominalCapacity": 0.0, - "reservoirCapacity": 0.0, - "withdrawalNominalCapacity": 0.0, - "efficiency": 1.0, - }, + json={"name": siemens_battery, "group": "Battery"}, ) assert res.status_code == 500, res.json() obj = res.json() @@ -428,15 +417,7 @@ def test_lifecycle__nominal( res = client.patch( f"/v1/studies/{study_id}/areas/{bad_area_id}/storages/{siemens_battery_id}", headers={"Authorization": f"Bearer {user_access_token}"}, - json={ - "efficiency": 1.0, - "initialLevel": 0.0, - "initialLevelOptim": True, - "injectionNominalCapacity": 2450, - "name": "New Siemens Battery", - "reservoirCapacity": 2500, - "withdrawalNominalCapacity": 2350, - }, + json={"efficiency": 1.0}, ) assert res.status_code == 404, res.json() obj = res.json() @@ -449,15 +430,7 @@ def test_lifecycle__nominal( res = client.patch( f"/v1/studies/{study_id}/areas/{area_id}/storages/{bad_storage_id}", headers={"Authorization": f"Bearer {user_access_token}"}, - json={ - "efficiency": 1.0, - "initialLevel": 0.0, - "initialLevelOptim": True, - "injectionNominalCapacity": 2450, - "name": "New Siemens Battery", - "reservoirCapacity": 2500, - "withdrawalNominalCapacity": 2350, - }, + json={"efficiency": 1.0}, ) assert res.status_code == 404, res.json() obj = res.json() @@ -470,17 +443,192 @@ def test_lifecycle__nominal( res = client.patch( f"/v1/studies/{bad_study_id}/areas/{area_id}/storages/{siemens_battery_id}", headers={"Authorization": f"Bearer {user_access_token}"}, - json={ - "efficiency": 1.0, - "initialLevel": 0.0, - "initialLevelOptim": True, - "injectionNominalCapacity": 2450, - "name": "New Siemens Battery", - "reservoirCapacity": 2500, - "withdrawalNominalCapacity": 2350, - }, + json={"efficiency": 1.0}, ) assert res.status_code == 404, res.json() obj = res.json() description = obj["description"] assert bad_study_id in description + + def test__default_values( + self, + client: TestClient, + user_access_token: str, + ) -> None: + """ + The purpose of this integration test is to test the default values of + the properties of a short-term storage. + + Given a new study with an area "FR", at least in version 860, + When I create a short-term storage with a name "Tesla Battery", with the default values, + Then the short-term storage is created with initialLevel = 0.0, and initialLevelOptim = False. + """ + # Create a new study in version 860 (or higher) + res = client.post( + "/v1/studies", + headers={"Authorization": f"Bearer {user_access_token}"}, + params={"name": "MyStudy", "version": 860}, + ) + assert res.status_code in {200, 201}, res.json() + study_id = res.json() + + # Create a new area named "FR" + res = client.post( + f"/v1/studies/{study_id}/areas", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"name": "FR", "type": "AREA"}, + ) + assert res.status_code in {200, 201}, res.json() + area_id = res.json()["id"] + + # Create a new short-term storage named "Tesla Battery" + tesla_battery = "Tesla Battery" + res = client.post( + f"/v1/studies/{study_id}/areas/{area_id}/storages", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"name": tesla_battery, "group": "Battery"}, + ) + assert res.status_code == 200, res.json() + tesla_battery_id = res.json()["id"] + tesla_config = {**DEFAULT_PROPERTIES, "id": tesla_battery_id, "name": tesla_battery, "group": "Battery"} + assert res.json() == tesla_config + + # Use the Debug mode to make sure that the initialLevel and initialLevelOptim properties + # are properly set in the configuration file. + res = client.get( + f"/v1/studies/{study_id}/raw", + headers={"Authorization": f"Bearer {user_access_token}"}, + params={"path": f"input/st-storage/clusters/{area_id}/list/{tesla_battery_id}"}, + ) + assert res.status_code == 200, res.json() + actual = res.json() + expected = {**DEFAULT_CONFIG, "name": tesla_battery, "group": "Battery"} + assert actual == expected + + # We want to make sure that the default properties are applied to a study variant. + # We want to make sure that updating the initialLevel property is taken into account + # in the variant commands. + + # Create a variant of the study + res = client.post( + f"/v1/studies/{study_id}/variants", + headers={"Authorization": f"Bearer {user_access_token}"}, + params={"name": "MyVariant"}, + ) + assert res.status_code in {200, 201}, res.json() + variant_id = res.json() + + # Create a new short-term storage named "Siemens Battery" + siemens_battery = "Siemens Battery" + res = client.post( + f"/v1/studies/{variant_id}/areas/{area_id}/storages", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"name": siemens_battery, "group": "Battery"}, + ) + assert res.status_code == 200, res.json() + + # Check the variant commands + res = client.get( + f"/v1/studies/{variant_id}/commands", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == 200, res.json() + commands = res.json() + assert len(commands) == 1 + actual = commands[0] + expected = { + "id": ANY, + "action": "create_st_storage", + "args": { + "area_id": "fr", + "parameters": {**DEFAULT_CONFIG, "name": siemens_battery, "group": "Battery"}, + "pmax_injection": ANY, + "pmax_withdrawal": ANY, + "lower_rule_curve": ANY, + "upper_rule_curve": ANY, + "inflows": ANY, + }, + "version": 1, + } + assert actual == expected + + # Update the initialLevel property of the "Siemens Battery" short-term storage to 0.5 + siemens_battery_id = transform_name_to_id(siemens_battery) + res = client.patch( + f"/v1/studies/{variant_id}/areas/{area_id}/storages/{siemens_battery_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"initialLevel": 0.5}, + ) + assert res.status_code == 200, res.json() + + # Check the variant commands + res = client.get( + f"/v1/studies/{variant_id}/commands", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == 200, res.json() + commands = res.json() + assert len(commands) == 2 + actual = commands[1] + expected = { + "id": ANY, + "action": "update_config", + "args": { + "data": "0.5", + "target": "input/st-storage/clusters/fr/list/siemens battery/initiallevel", + }, + "version": 1, + } + assert actual == expected + + # Update the initialLevel property of the "Siemens Battery" short-term storage back to 0 + res = client.patch( + f"/v1/studies/{variant_id}/areas/{area_id}/storages/{siemens_battery_id}", + headers={"Authorization": f"Bearer {user_access_token}"}, + json={"initialLevel": 0.0, "injectionNominalCapacity": 1600}, + ) + assert res.status_code == 200, res.json() + + # Check the variant commands + res = client.get( + f"/v1/studies/{variant_id}/commands", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == 200, res.json() + commands = res.json() + assert len(commands) == 3 + actual = commands[2] + expected = { + "id": ANY, + "action": "update_config", + "args": [ + { + "data": "1600.0", + "target": "input/st-storage/clusters/fr/list/siemens battery/injectionnominalcapacity", + }, + { + "data": "0.0", + "target": "input/st-storage/clusters/fr/list/siemens battery/initiallevel", + }, + ], + "version": 1, + } + assert actual == expected + + # Use the Debug mode to make sure that the initialLevel and initialLevelOptim properties + # are properly set in the configuration file. + res = client.get( + f"/v1/studies/{variant_id}/raw", + headers={"Authorization": f"Bearer {user_access_token}"}, + params={"path": f"input/st-storage/clusters/{area_id}/list/{siemens_battery_id}"}, + ) + assert res.status_code == 200, res.json() + actual = res.json() + expected = { + **DEFAULT_CONFIG, + "name": siemens_battery, + "group": "Battery", + "injectionnominalcapacity": 1600, + "initiallevel": 0.0, + } + assert actual == expected diff --git a/tests/study/business/areas/test_st_storage_management.py b/tests/study/business/areas/test_st_storage_management.py index 61731718b5..646dc26c78 100644 --- a/tests/study/business/areas/test_st_storage_management.py +++ b/tests/study/business/areas/test_st_storage_management.py @@ -17,11 +17,12 @@ ) from antarest.core.model import PublicMode from antarest.login.model import Group, User -from antarest.study.business.areas.st_storage_management import STStorageManager +from antarest.study.business.areas.st_storage_management import STStorageInput, STStorageManager from antarest.study.model import RawStudy, Study, StudyContentStatus from antarest.study.storage.rawstudy.ini_reader import IniReader from antarest.study.storage.rawstudy.model.filesystem.config.st_storage import STStorageGroup from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy +from antarest.study.storage.rawstudy.model.filesystem.ini_file_node import IniFileNode from antarest.study.storage.rawstudy.model.filesystem.root.filestudytree import FileStudyTree from antarest.study.storage.rawstudy.raw_study_service import RawStudyService from antarest.study.storage.storage_service import StudyStorageService @@ -86,7 +87,7 @@ def study_uuid_fixture(self, db_session: Session) -> str: raw_study = RawStudy( id=str(uuid.uuid4()), name="Dummy", - version="850", + version="860", # version 860 is required for the storage feature author="John Smith", created_at=datetime.datetime.now(datetime.timezone.utc), updated_at=datetime.datetime.now(datetime.timezone.utc), @@ -254,6 +255,72 @@ def test_get_st_storage__nominal_case( } assert actual == expected + # noinspection SpellCheckingInspection + def test_update_storage__nominal_case( + self, + db_session: Session, + study_storage_service: StudyStorageService, + study_uuid: str, + ) -> None: + """ + Test the `update_st_storage` method of the `STStorageManager` class under nominal conditions. + + This test verifies that the `update_st_storage` method correctly updates the storage fields + for a specific study, area, and storage ID combination. + + Args: + db_session: A database session fixture. + study_storage_service: A study storage service fixture. + study_uuid: The UUID of the study to be tested. + """ + + # The study must be fetched from the database + study: RawStudy = db_session.query(Study).get(study_uuid) + + # Prepare the mocks + storage = study_storage_service.get_storage(study) + file_study = storage.get_raw(study) + ini_file_node = IniFileNode(context=Mock(), config=Mock()) + file_study.tree = Mock( + spec=FileStudyTree, + get=Mock(return_value=LIST_CFG["storage1"]), + get_node=Mock(return_value=ini_file_node), + ) + + # Given the following arguments + manager = STStorageManager(study_storage_service) + + # Run the method being tested + edit_form = STStorageInput(initial_level=0, initial_level_optim=False) + manager.update_storage(study, area_id="West", storage_id="storage1", form=edit_form) + + # Assert that the storage fields have been updated + # + # Currently, the method used to update the fields is the `UpdateConfig` command + # which only does a partial update of the configuration file: only the fields + # that are explicitly mentioned in the form are updated. The other fields are left unchanged. + # + # The effective update of the fields is done by the `save` method of the `IniFileNode` class. + # The signature of the `save` method is: `save(self, value: Any, path: Sequence[str]) -> None` + + assert file_study.tree.save.call_count == 2 + + # Fields "initiallevel" and "initialleveloptim" could be updated in any order. + # We construct a *set* of the actual calls to the `save` method and compare it + # to the expected set of calls. + actual = {(call_args[0][0], tuple(call_args[0][1])) for call_args in file_study.tree.save.call_args_list} + expected = { + ( + str(0.0), + ("input", "st-storage", "clusters", "West", "list", "storage1", "initiallevel"), + ), + ( + str(False), + ("input", "st-storage", "clusters", "West", "list", "storage1", "initialleveloptim"), + ), + } + assert actual == expected + def test_get_st_storage__config_not_found( self, db_session: Session, From f9e7bd84aa955e86e318baf9518e0d0c1b8200fe Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 5 Feb 2024 22:45:51 +0100 Subject: [PATCH 10/12] style(st-storage): use `t` alias to import `typing` --- .../business/areas/st_storage_management.py | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/antarest/study/business/areas/st_storage_management.py b/antarest/study/business/areas/st_storage_management.py index 02f7af2e07..d18dce9f9c 100644 --- a/antarest/study/business/areas/st_storage_management.py +++ b/antarest/study/business/areas/st_storage_management.py @@ -1,7 +1,7 @@ import functools import json import operator -from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Sequence +import typing as t import numpy as np from pydantic import BaseModel, Extra, root_validator, validator @@ -44,7 +44,7 @@ class STStorageInput(STStorageProperties, metaclass=AllOptionalMetaclass, use_no class Config: @staticmethod - def schema_extra(schema: MutableMapping[str, Any]) -> None: + def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: schema["example"] = STStorageInput( name="Siemens Battery", group=STStorageGroup.BATTERY, @@ -64,7 +64,7 @@ class STStorageCreation(STStorageInput): # noinspection Pydantic @validator("name", pre=True) - def validate_name(cls, name: Optional[str]) -> str: + def validate_name(cls, name: t.Optional[str]) -> str: """ Validator to check if the name is not empty. """ @@ -86,7 +86,7 @@ class STStorageOutput(STStorageConfig): class Config: @staticmethod - def schema_extra(schema: MutableMapping[str, Any]) -> None: + def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: schema["example"] = STStorageOutput( id="siemens_battery", name="Siemens Battery", @@ -99,7 +99,7 @@ def schema_extra(schema: MutableMapping[str, Any]) -> None: ) @classmethod - def from_config(cls, storage_id: str, config: Mapping[str, Any]) -> "STStorageOutput": + def from_config(cls, storage_id: str, config: t.Mapping[str, t.Any]) -> "STStorageOutput": storage = STStorageConfig(**config, id=storage_id) values = storage.dict(by_alias=False) return cls(**values) @@ -126,12 +126,12 @@ class STStorageMatrix(BaseModel): class Config: extra = Extra.forbid - data: List[List[float]] - index: List[int] - columns: List[int] + data: t.List[t.List[float]] + index: t.List[int] + columns: t.List[int] @validator("data") - def validate_time_series(cls, data: List[List[float]]) -> List[List[float]]: + def validate_time_series(cls, data: t.List[t.List[float]]) -> t.List[t.List[float]]: """ Validator to check the integrity of the time series data. @@ -189,7 +189,9 @@ def validate_time_series(cls, matrix: STStorageMatrix) -> STStorageMatrix: return matrix @root_validator() - def validate_rule_curve(cls, values: MutableMapping[str, STStorageMatrix]) -> MutableMapping[str, STStorageMatrix]: + def validate_rule_curve( + cls, values: t.MutableMapping[str, STStorageMatrix] + ) -> t.MutableMapping[str, STStorageMatrix]: """ Validator to ensure 'lower_rule_curve' values are less than or equal to 'upper_rule_curve' values. @@ -275,7 +277,7 @@ def get_storages( self, study: Study, area_id: str, - ) -> Sequence[STStorageOutput]: + ) -> t.Sequence[STStorageOutput]: """ Get the list of short-term storage configurations for the given `study`, and `area_id`. @@ -375,7 +377,7 @@ def update_storage( new_data = json.loads(new_config.json(by_alias=True, exclude={"id"})) # create the dict containing the new values using aliases - data: Dict[str, Any] = { + data: t.Dict[str, t.Any] = { field.alias: new_data[field.alias] for field_name, field in new_config.__fields__.items() if field_name in new_values @@ -396,7 +398,7 @@ def delete_storages( self, study: Study, area_id: str, - storage_ids: Sequence[str], + storage_ids: t.Sequence[str], ) -> None: """ Delete short-term storage configurations form the given study and area_id. @@ -444,7 +446,7 @@ def _get_matrix_obj( area_id: str, storage_id: str, ts_name: STStorageTimeSeries, - ) -> MutableMapping[str, Any]: + ) -> t.MutableMapping[str, t.Any]: file_study = self._get_file_study(study) path = STORAGE_SERIES_PATH.format(area_id=area_id, storage_id=storage_id, ts_name=ts_name) try: @@ -480,7 +482,7 @@ def _save_matrix_obj( area_id: str, storage_id: str, ts_name: STStorageTimeSeries, - matrix_obj: Dict[str, Any], + matrix_obj: t.Dict[str, t.Any], ) -> None: file_study = self._get_file_study(study) path = STORAGE_SERIES_PATH.format(area_id=area_id, storage_id=storage_id, ts_name=ts_name) From c0f4373f6fb51ed409d3dfc96cdd83568116d40c Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 5 Feb 2024 23:05:34 +0100 Subject: [PATCH 11/12] fix(thermal,renewable): correct the update command --- .../business/areas/renewable_management.py | 42 ++++++++++--------- .../business/areas/thermal_management.py | 35 +++++++--------- 2 files changed, 37 insertions(+), 40 deletions(-) diff --git a/antarest/study/business/areas/renewable_management.py b/antarest/study/business/areas/renewable_management.py index 48f349e527..ab9a2e9802 100644 --- a/antarest/study/business/areas/renewable_management.py +++ b/antarest/study/business/areas/renewable_management.py @@ -199,7 +199,11 @@ def get_cluster(self, study: Study, area_id: str, cluster_id: str) -> RenewableC return create_renewable_output(study.version, cluster_id, cluster) def update_cluster( - self, study: Study, area_id: str, cluster_id: str, cluster_data: RenewableClusterInput + self, + study: Study, + area_id: str, + cluster_id: str, + cluster_data: RenewableClusterInput, ) -> RenewableClusterOutput: """ Updates the configuration of an existing cluster within an area in the study. @@ -225,33 +229,31 @@ def update_cluster( values = file_study.tree.get(path.split("/"), depth=1) except KeyError: raise ClusterNotFound(cluster_id) from None + else: + old_config = create_renewable_config(study_version, **values) - # merge old and new values - updated_values = { - **create_renewable_config(study_version, **values).dict(exclude={"id"}), - **cluster_data.dict(by_alias=False, exclude_none=True), - "id": cluster_id, - } - new_config = create_renewable_config(study_version, **updated_values) + # use Python values to synchronize Config and Form values + new_values = cluster_data.dict(by_alias=False, exclude_none=True) + new_config = old_config.copy(exclude={"id"}, update=new_values) new_data = json.loads(new_config.json(by_alias=True, exclude={"id"})) - data = { + # create the dict containing the new values using aliases + data: t.Dict[str, t.Any] = { field.alias: new_data[field.alias] for field_name, field in new_config.__fields__.items() - if field_name not in {"id"} - and (field_name in updated_values or getattr(new_config, field_name) != field.get_default()) + if field_name in new_values } - command = UpdateConfig( - target=path, - data=data, - command_context=self.storage_service.variant_study_service.command_factory.command_context, - ) + # create the update config commands with the modified data + command_context = self.storage_service.variant_study_service.command_factory.command_context + commands = [ + UpdateConfig(target=f"{path}/{key}", data=value, command_context=command_context) + for key, value in data.items() + ] + execute_or_add_commands(study, file_study, commands, self.storage_service) - # fixme: The `file_study` is already retrieved at the beginning of the function. - file_study = self.storage_service.get_storage(study).get_raw(study) - execute_or_add_commands(study, file_study, [command], self.storage_service) - return RenewableClusterOutput(**new_config.dict(by_alias=False)) + values = new_config.dict(by_alias=False) + return RenewableClusterOutput(**values, id=cluster_id) def delete_clusters(self, study: Study, area_id: str, cluster_ids: t.Sequence[str]) -> None: """ diff --git a/antarest/study/business/areas/thermal_management.py b/antarest/study/business/areas/thermal_management.py index 29d3e566d4..dfcc52a2a0 100644 --- a/antarest/study/business/areas/thermal_management.py +++ b/antarest/study/business/areas/thermal_management.py @@ -245,32 +245,27 @@ def update_cluster( old_config = create_thermal_config(study_version, **values) # Use Python values to synchronize Config and Form values - old_values = old_config.dict(exclude={"id"}) new_values = cluster_data.dict(by_alias=False, exclude_none=True) - updated = {**old_values, **new_values} - new_config = create_thermal_config(study_version, **updated, id=cluster_id) + new_config = old_config.copy(exclude={"id"}, update=new_values) new_data = json.loads(new_config.json(by_alias=True, exclude={"id"})) - # Create the dict containing the old values (excluding defaults), - # and the updated values (including defaults) - data: t.Dict[str, t.Any] = {} - for field_name, field in new_config.__fields__.items(): - if field_name in {"id"}: - continue - value = getattr(new_config, field_name) - if field_name in new_values or value != field.get_default(): - # use the JSON-converted value - data[field.alias] = new_data[field.alias] - - # create the update config command with the modified data + # create the dict containing the new values using aliases + data: t.Dict[str, t.Any] = { + field.alias: new_data[field.alias] + for field_name, field in new_config.__fields__.items() + if field_name in new_values + } + + # create the update config commands with the modified data command_context = self.storage_service.variant_study_service.command_factory.command_context - command = UpdateConfig(target=path, data=data, command_context=command_context) - # fixme: The `file_study` is already retrieved at the beginning of the function. - file_study = self.storage_service.get_storage(study).get_raw(study) - execute_or_add_commands(study, file_study, [command], self.storage_service) + commands = [ + UpdateConfig(target=f"{path}/{key}", data=value, command_context=command_context) + for key, value in data.items() + ] + execute_or_add_commands(study, file_study, commands, self.storage_service) values = new_config.dict(by_alias=False) - return ThermalClusterOutput(**values) + return ThermalClusterOutput(**values, id=cluster_id) def delete_clusters(self, study: Study, area_id: str, cluster_ids: t.Sequence[str]) -> None: """ From f54bb94cf08ddc1136dfd6ed4338c7f6b0ee9360 Mon Sep 17 00:00:00 2001 From: Laurent LAPORTE Date: Mon, 5 Feb 2024 23:44:47 +0100 Subject: [PATCH 12/12] test(st-storage): correct default value for `initial_level` --- tests/variantstudy/model/command/test_create_st_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/variantstudy/model/command/test_create_st_storage.py b/tests/variantstudy/model/command/test_create_st_storage.py index c430b92eeb..c35ea37665 100644 --- a/tests/variantstudy/model/command/test_create_st_storage.py +++ b/tests/variantstudy/model/command/test_create_st_storage.py @@ -311,7 +311,7 @@ def test_apply__nominal_case(self, recent_study: FileStudy, command_context: Com "storage1": { "efficiency": 0.94, "group": "Battery", - # "initiallevel": 0, # default value is 0 + "initiallevel": 0.5, "initialleveloptim": True, "injectionnominalcapacity": 1500, "name": "Storage1",