Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(api-model): correct AllOptionalMetaclass for field validation in form models #1924

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion antarest/launcher/ssh_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import socket
import shlex
import socket
from typing import Any, List, Tuple

import paramiko
Expand Down
25 changes: 12 additions & 13 deletions antarest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any, Dict, Optional, Sequence, Tuple, cast

import pydantic
import sqlalchemy.ext.baked # type: ignore
import uvicorn # type: ignore
import uvicorn.config # type: ignore
from fastapi import FastAPI, HTTPException
Expand Down Expand Up @@ -56,18 +55,16 @@ class PathType:
which specify whether the path argument must exist, whether it can be a file,
and whether it can be a directory, respectively.

Example Usage:
Example Usage::

```python
import argparse
from antarest.main import PathType
import argparse
from antarest.main import PathType

parser = argparse.ArgumentParser()
parser.add_argument('--input', type=PathType(file_ok=True, exists=True))
args = parser.parse_args()
parser = argparse.ArgumentParser()
parser.add_argument('--input', type=PathType(file_ok=True, exists=True))
args = parser.parse_args()

print(args.input)
```
print(args.input)

In the above example, `PathType` is used to specify the type of the `--input`
argument for the `argparse` parser. The argument must be an existing file path.
Expand Down Expand Up @@ -401,9 +398,11 @@ def handle_all_exception(request: Request, exc: Exception) -> Any:
application.add_middleware(
RateLimitMiddleware,
authenticate=auth_manager.create_auth_function(),
backend=RedisBackend(config.redis.host, config.redis.port, 1, config.redis.password)
if config.redis is not None
else MemoryBackend(),
backend=(
MemoryBackend()
if config.redis is None
else RedisBackend(config.redis.host, config.redis.port, 1, config.redis.password)
),
config=RATE_LIMIT_CONFIG,
)

Expand Down
2 changes: 1 addition & 1 deletion antarest/study/business/areas/hydro_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def update_inflow_structure(self, study: Study, area_id: str, values: InflowStru
values: The new inflow structure values to be updated.

Raises:
RequestValidationError: If the provided `values` parameter is None or invalid.
ValidationError: If the provided `values` parameter is None or invalid.
"""
# NOTE: Updates only "intermonthly-correlation" due to current model scope.
path = INFLOW_PATH.format(area_id=area_id)
Expand Down
45 changes: 24 additions & 21 deletions antarest/study/business/areas/renewable_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class TimeSeriesInterpretation(EnumIgnoreCase):


@camel_case_model
class RenewableClusterInput(RenewableProperties, metaclass=AllOptionalMetaclass):
class RenewableClusterInput(RenewableProperties, metaclass=AllOptionalMetaclass, use_none=True):
"""
Model representing the data structure required to edit an existing renewable cluster.
"""
Expand Down Expand Up @@ -76,7 +76,7 @@ def to_config(self, study_version: t.Union[str, int]) -> RenewableConfigType:


@camel_case_model
class RenewableClusterOutput(RenewableConfig, metaclass=AllOptionalMetaclass):
class RenewableClusterOutput(RenewableConfig, metaclass=AllOptionalMetaclass, use_none=True):
"""
Model representing the output data structure to display the details of a renewable cluster.
"""
Expand Down Expand Up @@ -199,7 +199,11 @@ def get_cluster(self, study: Study, area_id: str, cluster_id: str) -> RenewableC
return create_renewable_output(study.version, cluster_id, cluster)

def update_cluster(
self, study: Study, area_id: str, cluster_id: str, cluster_data: RenewableClusterInput
self,
study: Study,
area_id: str,
cluster_id: str,
cluster_data: RenewableClusterInput,
) -> RenewableClusterOutput:
"""
Updates the configuration of an existing cluster within an area in the study.
Expand All @@ -225,32 +229,31 @@ def update_cluster(
values = file_study.tree.get(path.split("/"), depth=1)
except KeyError:
raise ClusterNotFound(cluster_id) from None
else:
old_config = create_renewable_config(study_version, **values)

# merge old and new values
updated_values = {
**create_renewable_config(study_version, **values).dict(exclude={"id"}),
**cluster_data.dict(by_alias=False, exclude_none=True),
"id": cluster_id,
}
new_config = create_renewable_config(study_version, **updated_values)
# use Python values to synchronize Config and Form values
new_values = cluster_data.dict(by_alias=False, exclude_none=True)
new_config = old_config.copy(exclude={"id"}, update=new_values)
new_data = json.loads(new_config.json(by_alias=True, exclude={"id"}))

data = {
# create the dict containing the new values using aliases
data: t.Dict[str, t.Any] = {
field.alias: new_data[field.alias]
for field_name, field in new_config.__fields__.items()
if field_name not in {"id"}
and (field_name in updated_values or getattr(new_config, field_name) != field.get_default())
if field_name in new_values
}

command = UpdateConfig(
target=path,
data=data,
command_context=self.storage_service.variant_study_service.command_factory.command_context,
)
# create the update config commands with the modified data
command_context = self.storage_service.variant_study_service.command_factory.command_context
commands = [
UpdateConfig(target=f"{path}/{key}", data=value, command_context=command_context)
for key, value in data.items()
]
execute_or_add_commands(study, file_study, commands, self.storage_service)

file_study = self.storage_service.get_storage(study).get_raw(study)
execute_or_add_commands(study, file_study, [command], self.storage_service)
return RenewableClusterOutput(**new_config.dict(by_alias=False))
values = new_config.dict(by_alias=False)
return RenewableClusterOutput(**values, id=cluster_id)

def delete_clusters(self, study: Study, area_id: str, cluster_ids: t.Sequence[str]) -> None:
"""
Expand Down
79 changes: 45 additions & 34 deletions antarest/study/business/areas/st_storage_management.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import json
import operator
from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Sequence
import typing as t

import numpy as np
from pydantic import BaseModel, Extra, root_validator, validator
Expand Down Expand Up @@ -37,14 +37,14 @@


@camel_case_model
class STStorageInput(STStorageProperties, metaclass=AllOptionalMetaclass):
class STStorageInput(STStorageProperties, metaclass=AllOptionalMetaclass, use_none=True):
"""
Model representing the form used to EDIT an existing short-term storage.
"""

class Config:
@staticmethod
def schema_extra(schema: MutableMapping[str, Any]) -> None:
def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None:
schema["example"] = STStorageInput(
name="Siemens Battery",
group=STStorageGroup.BATTERY,
Expand All @@ -64,7 +64,7 @@ class STStorageCreation(STStorageInput):

# noinspection Pydantic
@validator("name", pre=True)
def validate_name(cls, name: Optional[str]) -> str:
def validate_name(cls, name: t.Optional[str]) -> str:
"""
Validator to check if the name is not empty.
"""
Expand All @@ -86,7 +86,7 @@ class STStorageOutput(STStorageConfig):

class Config:
@staticmethod
def schema_extra(schema: MutableMapping[str, Any]) -> None:
def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None:
schema["example"] = STStorageOutput(
id="siemens_battery",
name="Siemens Battery",
Expand All @@ -99,7 +99,7 @@ def schema_extra(schema: MutableMapping[str, Any]) -> None:
)

@classmethod
def from_config(cls, storage_id: str, config: Mapping[str, Any]) -> "STStorageOutput":
def from_config(cls, storage_id: str, config: t.Mapping[str, t.Any]) -> "STStorageOutput":
storage = STStorageConfig(**config, id=storage_id)
values = storage.dict(by_alias=False)
return cls(**values)
Expand All @@ -126,12 +126,12 @@ class STStorageMatrix(BaseModel):
class Config:
extra = Extra.forbid

data: List[List[float]]
index: List[int]
columns: List[int]
data: t.List[t.List[float]]
index: t.List[int]
columns: t.List[int]

@validator("data")
def validate_time_series(cls, data: List[List[float]]) -> List[List[float]]:
def validate_time_series(cls, data: t.List[t.List[float]]) -> t.List[t.List[float]]:
"""
Validator to check the integrity of the time series data.

Expand Down Expand Up @@ -189,7 +189,9 @@ def validate_time_series(cls, matrix: STStorageMatrix) -> STStorageMatrix:
return matrix

@root_validator()
def validate_rule_curve(cls, values: MutableMapping[str, STStorageMatrix]) -> MutableMapping[str, STStorageMatrix]:
def validate_rule_curve(
cls, values: t.MutableMapping[str, STStorageMatrix]
) -> t.MutableMapping[str, STStorageMatrix]:
"""
Validator to ensure 'lower_rule_curve' values are less than
or equal to 'upper_rule_curve' values.
Expand Down Expand Up @@ -275,7 +277,7 @@ def get_storages(
self,
study: Study,
area_id: str,
) -> Sequence[STStorageOutput]:
) -> t.Sequence[STStorageOutput]:
"""
Get the list of short-term storage configurations for the given `study`, and `area_id`.

Expand Down Expand Up @@ -347,7 +349,20 @@ def update_storage(
Updated form of short-term storage.
"""
study_version = study.version

# review: reading the configuration poses a problem for variants,
# because it requires generating a snapshot, which takes time.
# This reading could be avoided if we don't need the previous values
# (no cross-field validation, no default values, etc.).
# In return, we won't be able to return a complete `STStorageOutput` object.
# So, we need to make sure the frontend doesn't need the missing fields.
# This missing information could also be a problem for the API users.
# The solution would be to avoid reading the configuration if the study is a variant
# (we then use the default values), otherwise, for a RAW study, we read the configuration
# and update the modified values.

file_study = self._get_file_study(study)

path = STORAGE_LIST_PATH.format(area_id=area_id, storage_id=storage_id)
try:
values = file_study.tree.get(path.split("/"), depth=1)
Expand All @@ -357,37 +372,33 @@ def update_storage(
old_config = create_st_storage_config(study_version, **values)

# use Python values to synchronize Config and Form values
old_values = old_config.dict(exclude={"id"})
new_values = form.dict(by_alias=False, exclude_none=True)
updated = {**old_values, **new_values}
new_config = create_st_storage_config(study_version, **updated, id=storage_id)
new_config = old_config.copy(exclude={"id"}, update=new_values)
new_data = json.loads(new_config.json(by_alias=True, exclude={"id"}))

# create the dict containing the old values (excluding defaults),
# the updated values (including defaults)
data: Dict[str, Any] = {}
for field_name, field in new_config.__fields__.items():
if field_name in {"id"}:
continue
value = getattr(new_config, field_name)
if field_name in new_values or value != field.get_default():
# use the JSON-converted value
data[field.alias] = new_data[field.alias]

# create the update config command with the modified data
# create the dict containing the new values using aliases
data: t.Dict[str, t.Any] = {
field.alias: new_data[field.alias]
for field_name, field in new_config.__fields__.items()
if field_name in new_values
}

# create the update config commands with the modified data
command_context = self.storage_service.variant_study_service.command_factory.command_context
command = UpdateConfig(target=path, data=data, command_context=command_context)
file_study = self._get_file_study(study)
execute_or_add_commands(study, file_study, [command], self.storage_service)
commands = [
UpdateConfig(target=f"{path}/{key}", data=value, command_context=command_context)
for key, value in data.items()
]
execute_or_add_commands(study, file_study, commands, self.storage_service)

values = new_config.dict(by_alias=False)
return STStorageOutput(**values)
return STStorageOutput(**values, id=storage_id)

def delete_storages(
self,
study: Study,
area_id: str,
storage_ids: Sequence[str],
storage_ids: t.Sequence[str],
) -> None:
"""
Delete short-term storage configurations form the given study and area_id.
Expand Down Expand Up @@ -435,7 +446,7 @@ def _get_matrix_obj(
area_id: str,
storage_id: str,
ts_name: STStorageTimeSeries,
) -> MutableMapping[str, Any]:
) -> t.MutableMapping[str, t.Any]:
file_study = self._get_file_study(study)
path = STORAGE_SERIES_PATH.format(area_id=area_id, storage_id=storage_id, ts_name=ts_name)
try:
Expand Down Expand Up @@ -471,7 +482,7 @@ def _save_matrix_obj(
area_id: str,
storage_id: str,
ts_name: STStorageTimeSeries,
matrix_obj: Dict[str, Any],
matrix_obj: t.Dict[str, t.Any],
) -> None:
file_study = self._get_file_study(study)
path = STORAGE_SERIES_PATH.format(area_id=area_id, storage_id=storage_id, ts_name=ts_name)
Expand Down
38 changes: 17 additions & 21 deletions antarest/study/business/areas/thermal_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


@camel_case_model
class ThermalClusterInput(Thermal860Properties, metaclass=AllOptionalMetaclass):
class ThermalClusterInput(Thermal860Properties, metaclass=AllOptionalMetaclass, use_none=True):
"""
Model representing the data structure required to edit an existing thermal cluster within a study.
"""
Expand Down Expand Up @@ -70,7 +70,7 @@ def to_config(self, study_version: t.Union[str, int]) -> ThermalConfigType:


@camel_case_model
class ThermalClusterOutput(Thermal860Config, metaclass=AllOptionalMetaclass):
class ThermalClusterOutput(Thermal860Config, metaclass=AllOptionalMetaclass, use_none=True):
"""
Model representing the output data structure to display the details of a thermal cluster within a study.
"""
Expand Down Expand Up @@ -245,31 +245,27 @@ def update_cluster(
old_config = create_thermal_config(study_version, **values)

# Use Python values to synchronize Config and Form values
old_values = old_config.dict(exclude={"id"})
new_values = cluster_data.dict(by_alias=False, exclude_none=True)
updated = {**old_values, **new_values}
new_config = create_thermal_config(study_version, **updated, id=cluster_id)
new_config = old_config.copy(exclude={"id"}, update=new_values)
new_data = json.loads(new_config.json(by_alias=True, exclude={"id"}))

# Create the dict containing the old values (excluding defaults),
# and the updated values (including defaults)
data: t.Dict[str, t.Any] = {}
for field_name, field in new_config.__fields__.items():
if field_name in {"id"}:
continue
value = getattr(new_config, field_name)
if field_name in new_values or value != field.get_default():
# use the JSON-converted value
data[field.alias] = new_data[field.alias]

# create the update config command with the modified data
# create the dict containing the new values using aliases
data: t.Dict[str, t.Any] = {
field.alias: new_data[field.alias]
for field_name, field in new_config.__fields__.items()
if field_name in new_values
}

# create the update config commands with the modified data
command_context = self.storage_service.variant_study_service.command_factory.command_context
command = UpdateConfig(target=path, data=data, command_context=command_context)
file_study = self.storage_service.get_storage(study).get_raw(study)
execute_or_add_commands(study, file_study, [command], self.storage_service)
commands = [
UpdateConfig(target=f"{path}/{key}", data=value, command_context=command_context)
for key, value in data.items()
]
execute_or_add_commands(study, file_study, commands, self.storage_service)

values = new_config.dict(by_alias=False)
return ThermalClusterOutput(**values)
return ThermalClusterOutput(**values, id=cluster_id)

def delete_clusters(self, study: Study, area_id: str, cluster_ids: t.Sequence[str]) -> None:
"""
Expand Down
Loading
Loading