Commit d11b7a5

fix(api-model): correct AllOptionalMetaclass for field validation in form models (#1924)

Merge pull request #1924 from AntaresSimulatorTeam/bugfix/ensure-validation-in-AllOptionalMetaclass
laurent-laporte-pro authored Feb 5, 2024
2 parents 938e603 + f54bb94 commit d11b7a5
Showing 23 changed files with 860 additions and 245 deletions.
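Not shown in this diff: the `AllOptionalMetaclass` itself. The fix turns on a new `use_none=True` keyword that the form models below now pass to the metaclass, so that every inherited field becomes optional (an edit form may omit any field) while field validators keep running on the values that are submitted. As orientation, here is a minimal sketch of that behaviour, assuming pydantic v1's `ModelMetaclass`; the names and details are illustrative, not the repository's exact implementation:

```python
import typing as t

from pydantic.main import ModelMetaclass


class AllOptionalMetaclass(ModelMetaclass):
    """Hypothetical sketch: wrap every field annotation in Optional[...] and,
    when use_none=True, default omitted fields to None. The fields stay
    declared, so their validators still run on submitted values."""

    def __new__(cls, name, bases, namespace, use_none: bool = False, **kwargs):
        annotations: t.Dict[str, t.Any] = {}
        for base in reversed(bases):  # inherit annotations from parent models
            annotations.update(getattr(base, "__annotations__", {}))
        annotations.update(namespace.get("__annotations__", {}))
        for field, annotation in annotations.items():
            if not field.startswith("_"):
                annotations[field] = t.Optional[annotation]
                if use_none:
                    # None becomes the default, so a later exclude_none can
                    # filter out everything the user did not send
                    namespace.setdefault(field, None)
        namespace["__annotations__"] = annotations
        return super().__new__(cls, name, bases, namespace, **kwargs)
```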
2 changes: 1 addition & 1 deletion antarest/launcher/ssh_client.py
@@ -1,6 +1,6 @@
 import contextlib
-import socket
 import shlex
+import socket
 from typing import Any, List, Tuple
 
 import paramiko
25 changes: 12 additions & 13 deletions antarest/main.py
@@ -6,7 +6,6 @@
 from typing import Any, Dict, Optional, Sequence, Tuple, cast
 
 import pydantic
-import sqlalchemy.ext.baked  # type: ignore
 import uvicorn  # type: ignore
 import uvicorn.config  # type: ignore
 from fastapi import FastAPI, HTTPException
@@ -56,18 +55,16 @@ class PathType:
     which specify whether the path argument must exist, whether it can be a file,
     and whether it can be a directory, respectively.
 
-    Example Usage:
+    Example Usage::
 
-    ```python
-    import argparse
-    from antarest.main import PathType
+        import argparse
+        from antarest.main import PathType
 
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--input', type=PathType(file_ok=True, exists=True))
-    args = parser.parse_args()
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--input', type=PathType(file_ok=True, exists=True))
+        args = parser.parse_args()
 
-    print(args.input)
-    ```
+        print(args.input)
 
     In the above example, `PathType` is used to specify the type of the `--input`
     argument for the `argparse` parser. The argument must be an existing file path.
@@ -401,9 +398,11 @@ def handle_all_exception(request: Request, exc: Exception) -> Any:
     application.add_middleware(
         RateLimitMiddleware,
         authenticate=auth_manager.create_auth_function(),
-        backend=RedisBackend(config.redis.host, config.redis.port, 1, config.redis.password)
-        if config.redis is not None
-        else MemoryBackend(),
+        backend=(
+            MemoryBackend()
+            if config.redis is None
+            else RedisBackend(config.redis.host, config.redis.port, 1, config.redis.password)
+        ),
         config=RATE_LIMIT_CONFIG,
     )

2 changes: 1 addition & 1 deletion antarest/study/business/areas/hydro_management.py
@@ -163,7 +163,7 @@ def update_inflow_structure(self, study: Study, area_id: str, values: InflowStru
             values: The new inflow structure values to be updated.
 
         Raises:
-            RequestValidationError: If the provided `values` parameter is None or invalid.
+            ValidationError: If the provided `values` parameter is None or invalid.
         """
         # NOTE: Updates only "intermonthly-correlation" due to current model scope.
         path = INFLOW_PATH.format(area_id=area_id)
45 changes: 24 additions & 21 deletions antarest/study/business/areas/renewable_management.py
@@ -37,7 +37,7 @@ class TimeSeriesInterpretation(EnumIgnoreCase):
 
 
 @camel_case_model
-class RenewableClusterInput(RenewableProperties, metaclass=AllOptionalMetaclass):
+class RenewableClusterInput(RenewableProperties, metaclass=AllOptionalMetaclass, use_none=True):
     """
     Model representing the data structure required to edit an existing renewable cluster.
     """
@@ -76,7 +76,7 @@ def to_config(self, study_version: t.Union[str, int]) -> RenewableConfigType:
 
 
 @camel_case_model
-class RenewableClusterOutput(RenewableConfig, metaclass=AllOptionalMetaclass):
+class RenewableClusterOutput(RenewableConfig, metaclass=AllOptionalMetaclass, use_none=True):
     """
     Model representing the output data structure to display the details of a renewable cluster.
     """
@@ -199,7 +199,11 @@ def get_cluster(self, study: Study, area_id: str, cluster_id: str) -> RenewableC
         return create_renewable_output(study.version, cluster_id, cluster)
 
     def update_cluster(
-        self, study: Study, area_id: str, cluster_id: str, cluster_data: RenewableClusterInput
+        self,
+        study: Study,
+        area_id: str,
+        cluster_id: str,
+        cluster_data: RenewableClusterInput,
     ) -> RenewableClusterOutput:
         """
         Updates the configuration of an existing cluster within an area in the study.
@@ -225,32 +229,31 @@ def update_cluster(
             values = file_study.tree.get(path.split("/"), depth=1)
         except KeyError:
             raise ClusterNotFound(cluster_id) from None
+        else:
+            old_config = create_renewable_config(study_version, **values)
 
-        # merge old and new values
-        updated_values = {
-            **create_renewable_config(study_version, **values).dict(exclude={"id"}),
-            **cluster_data.dict(by_alias=False, exclude_none=True),
-            "id": cluster_id,
-        }
-        new_config = create_renewable_config(study_version, **updated_values)
+        # use Python values to synchronize Config and Form values
+        new_values = cluster_data.dict(by_alias=False, exclude_none=True)
+        new_config = old_config.copy(exclude={"id"}, update=new_values)
         new_data = json.loads(new_config.json(by_alias=True, exclude={"id"}))
 
-        data = {
+        # create the dict containing the new values using aliases
+        data: t.Dict[str, t.Any] = {
             field.alias: new_data[field.alias]
             for field_name, field in new_config.__fields__.items()
-            if field_name not in {"id"}
-            and (field_name in updated_values or getattr(new_config, field_name) != field.get_default())
+            if field_name in new_values
         }
 
-        command = UpdateConfig(
-            target=path,
-            data=data,
-            command_context=self.storage_service.variant_study_service.command_factory.command_context,
-        )
+        # create the update config commands with the modified data
+        command_context = self.storage_service.variant_study_service.command_factory.command_context
+        commands = [
+            UpdateConfig(target=f"{path}/{key}", data=value, command_context=command_context)
+            for key, value in data.items()
+        ]
+        execute_or_add_commands(study, file_study, commands, self.storage_service)
 
-        file_study = self.storage_service.get_storage(study).get_raw(study)
-        execute_or_add_commands(study, file_study, [command], self.storage_service)
-        return RenewableClusterOutput(**new_config.dict(by_alias=False))
+        values = new_config.dict(by_alias=False)
+        return RenewableClusterOutput(**values, id=cluster_id)
 
     def delete_clusters(self, study: Study, area_id: str, cluster_ids: t.Sequence[str]) -> None:
         """
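Two behavioural changes above recur in the short-term storage and thermal services below. The first concerns the merge of old and new values: instead of re-parsing a merged dict with `create_renewable_config` and then guessing which fields to persist by comparing against defaults, the new code applies exactly the submitted fields to the parsed `old_config` with pydantic v1's `BaseModel.copy(update=...)`. Note that `copy(update=...)` performs no validation, so this is sound only because the form model (`RenewableClusterInput` with `use_none=True`) has already validated the input, which is the point of this PR. A small illustration with a hypothetical model:

```python
from pydantic import BaseModel


class ClusterForm(BaseModel):
    # hypothetical stand-in for RenewableProperties
    enabled: bool = True
    unit_count: int = 1
    nominal_capacity: float = 0.0


old_config = ClusterForm(unit_count=4)
new_values = {"nominal_capacity": 350.0}  # only the fields the user submitted

# copy(update=...) keeps every untouched field and does NOT re-validate,
# so new_values must come from an already-validated form model
new_config = old_config.copy(update=new_values)
assert new_config.unit_count == 4
assert new_config.nominal_capacity == 350.0
```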
79 changes: 45 additions & 34 deletions antarest/study/business/areas/st_storage_management.py
@@ -1,7 +1,7 @@
 import functools
 import json
 import operator
-from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Sequence
+import typing as t
 
 import numpy as np
 from pydantic import BaseModel, Extra, root_validator, validator
@@ -37,14 +37,14 @@
 
 
 @camel_case_model
-class STStorageInput(STStorageProperties, metaclass=AllOptionalMetaclass):
+class STStorageInput(STStorageProperties, metaclass=AllOptionalMetaclass, use_none=True):
     """
     Model representing the form used to EDIT an existing short-term storage.
     """
 
     class Config:
         @staticmethod
-        def schema_extra(schema: MutableMapping[str, Any]) -> None:
+        def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None:
             schema["example"] = STStorageInput(
                 name="Siemens Battery",
                 group=STStorageGroup.BATTERY,
@@ -64,7 +64,7 @@ class STStorageCreation(STStorageInput):
 
     # noinspection Pydantic
     @validator("name", pre=True)
-    def validate_name(cls, name: Optional[str]) -> str:
+    def validate_name(cls, name: t.Optional[str]) -> str:
        """
        Validator to check if the name is not empty.
        """
@@ -86,7 +86,7 @@ class STStorageOutput(STStorageConfig):
 
     class Config:
         @staticmethod
-        def schema_extra(schema: MutableMapping[str, Any]) -> None:
+        def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None:
             schema["example"] = STStorageOutput(
                 id="siemens_battery",
                 name="Siemens Battery",
@@ -99,7 +99,7 @@ def schema_extra(schema: MutableMapping[str, Any]) -> None:
             )
 
     @classmethod
-    def from_config(cls, storage_id: str, config: Mapping[str, Any]) -> "STStorageOutput":
+    def from_config(cls, storage_id: str, config: t.Mapping[str, t.Any]) -> "STStorageOutput":
         storage = STStorageConfig(**config, id=storage_id)
         values = storage.dict(by_alias=False)
         return cls(**values)
@@ -126,12 +126,12 @@ class STStorageMatrix(BaseModel):
     class Config:
         extra = Extra.forbid
 
-    data: List[List[float]]
-    index: List[int]
-    columns: List[int]
+    data: t.List[t.List[float]]
+    index: t.List[int]
+    columns: t.List[int]
 
     @validator("data")
-    def validate_time_series(cls, data: List[List[float]]) -> List[List[float]]:
+    def validate_time_series(cls, data: t.List[t.List[float]]) -> t.List[t.List[float]]:
         """
         Validator to check the integrity of the time series data.
@@ -189,7 +189,9 @@ def validate_time_series(cls, matrix: STStorageMatrix) -> STStorageMatrix:
         return matrix
 
     @root_validator()
-    def validate_rule_curve(cls, values: MutableMapping[str, STStorageMatrix]) -> MutableMapping[str, STStorageMatrix]:
+    def validate_rule_curve(
+        cls, values: t.MutableMapping[str, STStorageMatrix]
+    ) -> t.MutableMapping[str, STStorageMatrix]:
         """
         Validator to ensure 'lower_rule_curve' values are less than
         or equal to 'upper_rule_curve' values.
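Only the reformatted signature is visible here; the validator body is collapsed. For orientation, a cross-field check of this shape can be written with a pydantic v1 `root_validator` roughly as follows (a sketch over a simplified model, not the repository's exact implementation):

```python
import typing as t

from pydantic import BaseModel, root_validator


class RuleCurves(BaseModel):
    # hypothetical minimal model: two aligned series that must stay ordered
    lower_rule_curve: t.List[float]
    upper_rule_curve: t.List[float]

    @root_validator()
    def validate_rule_curve(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
        lower = values.get("lower_rule_curve")
        upper = values.get("upper_rule_curve")
        if lower is not None and upper is not None:
            if any(lo > up for lo, up in zip(lower, upper)):
                raise ValueError("lower_rule_curve must be <= upper_rule_curve")
        return values
```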
@@ -275,7 +277,7 @@ def get_storages(
         self,
         study: Study,
         area_id: str,
-    ) -> Sequence[STStorageOutput]:
+    ) -> t.Sequence[STStorageOutput]:
         """
         Get the list of short-term storage configurations for the given `study`, and `area_id`.
@@ -347,7 +349,20 @@ def update_storage(
             Updated form of short-term storage.
         """
         study_version = study.version
+
+        # review: reading the configuration poses a problem for variants,
+        # because it requires generating a snapshot, which takes time.
+        # This reading could be avoided if we don't need the previous values
+        # (no cross-field validation, no default values, etc.).
+        # In return, we won't be able to return a complete `STStorageOutput` object.
+        # So, we need to make sure the frontend doesn't need the missing fields.
+        # This missing information could also be a problem for the API users.
+        # The solution would be to avoid reading the configuration if the study is a variant
+        # (we then use the default values), otherwise, for a RAW study, we read the configuration
+        # and update the modified values.
+
         file_study = self._get_file_study(study)
+
         path = STORAGE_LIST_PATH.format(area_id=area_id, storage_id=storage_id)
         try:
             values = file_study.tree.get(path.split("/"), depth=1)
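The `# review:` note in the hunk above records a trade-off rather than a change: reading the current configuration forces variant studies to generate a snapshot, which is slow, but skipping the read means the response can no longer be a complete `STStorageOutput`. A rough sketch of the branching the note proposes, with the variant check and the fallback purely illustrative:

```python
import typing as t


def read_current_values(study: t.Any, file_study: t.Any, path: str) -> t.Dict[str, t.Any]:
    # Hypothetical sketch of the approach proposed in the review note.
    if getattr(study, "type", "") == "variantstudy":  # illustrative variant check
        # variants: skip the snapshot-generating read and fall back to model
        # defaults, at the price of returning a partial view of the storage
        return {}
    # raw studies: read the existing section so only modified values are written
    return file_study.tree.get(path.split("/"), depth=1)
```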
@@ -357,37 +372,33 @@ def update_storage(
         old_config = create_st_storage_config(study_version, **values)
 
         # use Python values to synchronize Config and Form values
-        old_values = old_config.dict(exclude={"id"})
         new_values = form.dict(by_alias=False, exclude_none=True)
-        updated = {**old_values, **new_values}
-        new_config = create_st_storage_config(study_version, **updated, id=storage_id)
+        new_config = old_config.copy(exclude={"id"}, update=new_values)
         new_data = json.loads(new_config.json(by_alias=True, exclude={"id"}))
 
-        # create the dict containing the old values (excluding defaults),
-        # the updated values (including defaults)
-        data: Dict[str, Any] = {}
-        for field_name, field in new_config.__fields__.items():
-            if field_name in {"id"}:
-                continue
-            value = getattr(new_config, field_name)
-            if field_name in new_values or value != field.get_default():
-                # use the JSON-converted value
-                data[field.alias] = new_data[field.alias]
-
-        # create the update config command with the modified data
+        # create the dict containing the new values using aliases
+        data: t.Dict[str, t.Any] = {
+            field.alias: new_data[field.alias]
+            for field_name, field in new_config.__fields__.items()
+            if field_name in new_values
+        }
+
+        # create the update config commands with the modified data
         command_context = self.storage_service.variant_study_service.command_factory.command_context
-        command = UpdateConfig(target=path, data=data, command_context=command_context)
-        file_study = self._get_file_study(study)
-        execute_or_add_commands(study, file_study, [command], self.storage_service)
+        commands = [
+            UpdateConfig(target=f"{path}/{key}", data=value, command_context=command_context)
+            for key, value in data.items()
+        ]
+        execute_or_add_commands(study, file_study, commands, self.storage_service)
 
         values = new_config.dict(by_alias=False)
-        return STStorageOutput(**values)
+        return STStorageOutput(**values, id=storage_id)
 
     def delete_storages(
         self,
         study: Study,
         area_id: str,
-        storage_ids: Sequence[str],
+        storage_ids: t.Sequence[str],
     ) -> None:
         """
         Delete short-term storage configurations form the given study and area_id.
@@ -435,7 +446,7 @@ def _get_matrix_obj(
         area_id: str,
         storage_id: str,
         ts_name: STStorageTimeSeries,
-    ) -> MutableMapping[str, Any]:
+    ) -> t.MutableMapping[str, t.Any]:
         file_study = self._get_file_study(study)
         path = STORAGE_SERIES_PATH.format(area_id=area_id, storage_id=storage_id, ts_name=ts_name)
         try:
@@ -471,7 +482,7 @@ def _save_matrix_obj(
         area_id: str,
         storage_id: str,
         ts_name: STStorageTimeSeries,
-        matrix_obj: Dict[str, Any],
+        matrix_obj: t.Dict[str, t.Any],
     ) -> None:
         file_study = self._get_file_study(study)
         path = STORAGE_SERIES_PATH.format(area_id=area_id, storage_id=storage_id, ts_name=ts_name)
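The second recurring change (also applied in the renewable service above and the thermal service below) is visible in the update hunks: instead of a single `UpdateConfig` that rewrote the whole storage section, the code now emits one command per modified field, targeting `f"{path}/{key}"`, so keys the user did not touch are left alone in the configuration tree. Assuming a list path of the form `input/st-storage/clusters/<area>/list/<storage>` (illustrative), the generated targets would look like this:

```python
# Illustrative only: mirrors the per-field command construction above.
path = "input/st-storage/clusters/fr/list/storage1"  # assumed path shape
data = {"efficiency": 0.94, "reservoircapacity": 600.0}

targets = [f"{path}/{key}" for key in data]
assert targets == [
    "input/st-storage/clusters/fr/list/storage1/efficiency",
    "input/st-storage/clusters/fr/list/storage1/reservoircapacity",
]
```

One command per field also makes variant command logs more granular: inspecting or reverting a change touches only the field that actually moved.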
38 changes: 17 additions & 21 deletions antarest/study/business/areas/thermal_management.py
@@ -30,7 +30,7 @@
 
 
 @camel_case_model
-class ThermalClusterInput(Thermal860Properties, metaclass=AllOptionalMetaclass):
+class ThermalClusterInput(Thermal860Properties, metaclass=AllOptionalMetaclass, use_none=True):
     """
     Model representing the data structure required to edit an existing thermal cluster within a study.
     """
@@ -70,7 +70,7 @@ def to_config(self, study_version: t.Union[str, int]) -> ThermalConfigType:
 
 
 @camel_case_model
-class ThermalClusterOutput(Thermal860Config, metaclass=AllOptionalMetaclass):
+class ThermalClusterOutput(Thermal860Config, metaclass=AllOptionalMetaclass, use_none=True):
     """
     Model representing the output data structure to display the details of a thermal cluster within a study.
     """
@@ -245,31 +245,27 @@ def update_cluster(
         old_config = create_thermal_config(study_version, **values)
 
         # Use Python values to synchronize Config and Form values
-        old_values = old_config.dict(exclude={"id"})
         new_values = cluster_data.dict(by_alias=False, exclude_none=True)
-        updated = {**old_values, **new_values}
-        new_config = create_thermal_config(study_version, **updated, id=cluster_id)
+        new_config = old_config.copy(exclude={"id"}, update=new_values)
         new_data = json.loads(new_config.json(by_alias=True, exclude={"id"}))
 
-        # Create the dict containing the old values (excluding defaults),
-        # and the updated values (including defaults)
-        data: t.Dict[str, t.Any] = {}
-        for field_name, field in new_config.__fields__.items():
-            if field_name in {"id"}:
-                continue
-            value = getattr(new_config, field_name)
-            if field_name in new_values or value != field.get_default():
-                # use the JSON-converted value
-                data[field.alias] = new_data[field.alias]
-
-        # create the update config command with the modified data
+        # create the dict containing the new values using aliases
+        data: t.Dict[str, t.Any] = {
+            field.alias: new_data[field.alias]
+            for field_name, field in new_config.__fields__.items()
+            if field_name in new_values
+        }
+
+        # create the update config commands with the modified data
         command_context = self.storage_service.variant_study_service.command_factory.command_context
-        command = UpdateConfig(target=path, data=data, command_context=command_context)
-        file_study = self.storage_service.get_storage(study).get_raw(study)
-        execute_or_add_commands(study, file_study, [command], self.storage_service)
+        commands = [
+            UpdateConfig(target=f"{path}/{key}", data=value, command_context=command_context)
+            for key, value in data.items()
+        ]
+        execute_or_add_commands(study, file_study, commands, self.storage_service)
 
         values = new_config.dict(by_alias=False)
-        return ThermalClusterOutput(**values)
+        return ThermalClusterOutput(**values, id=cluster_id)
 
     def delete_clusters(self, study: Study, area_id: str, cluster_ids: t.Sequence[str]) -> None:
         """