Skip to content

Commit

Permalink
refactor(bc): better handle optional values in BC commands
Browse files Browse the repository at this point in the history
  • Loading branch information
laurent-laporte-pro committed Apr 4, 2024
1 parent 1b20ea7 commit 9b9903d
Show file tree
Hide file tree
Showing 14 changed files with 264 additions and 184 deletions.
94 changes: 94 additions & 0 deletions antarest/study/business/all_optional_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import typing as t

import pydantic.fields
import pydantic.main
from pydantic import BaseModel

from antarest.core.utils.string import to_camel_case


class AllOptionalMetaclass(pydantic.main.ModelMetaclass):
    """
    Metaclass that makes all fields of a Pydantic model optional.

    Usage::

        class MyModel(BaseModel, metaclass=AllOptionalMetaclass):
            field1: str
            field2: int
            ...

    Instances of the model can be created even if not all fields are provided during
    initialization. Default values, when provided, are used unless ``use_none`` is set
    to ``True``, in which case every field defaults to ``None``.

    NOTE: this relies on pydantic **v1** internals (`pydantic.main.ModelMetaclass`,
    `pydantic.fields.ModelField`) and would need rework for pydantic v2.
    """

    def __new__(
        cls: t.Type["AllOptionalMetaclass"],
        name: str,
        bases: t.Tuple[t.Type[t.Any], ...],
        namespaces: t.Dict[str, t.Any],
        use_none: bool = False,
        **kwargs: t.Dict[str, t.Any],
    ) -> t.Any:
        """
        Create a new model class with all fields made optional.

        Args:
            name: Name of the class to create.
            bases: Base classes of the class to create (a Pydantic model).
            namespaces: Namespace of the class to create that defines the fields of the model.
            use_none: If `True`, the default value of the fields is set to `None`.
                Note that this keyword is consumed by the metaclass (via the class
                definition, e.g. ``class M(..., use_none=True)``) and is not part of
                the Pydantic model itself.
            **kwargs: Additional keyword arguments forwarded to the parent metaclass.

        Returns:
            The newly created model class, with every field non-required and nullable.
        """
        # Modify the annotations of the class (but not of the ancestor classes)
        # in order to make all fields optional.
        # If the current model inherits from another model, the annotations of the ancestor models
        # are not modified, because the fields are already converted to `ModelField`
        # (inherited fields are handled after class creation, below).
        annotations = namespaces.get("__annotations__", {})
        for field_name, field_type in annotations.items():
            if not field_name.startswith("__"):
                # Making already optional fields optional is not a problem (nothing is changed).
                annotations[field_name] = t.Optional[field_type]
        namespaces["__annotations__"] = annotations

        if use_none:
            # Modify the namespace fields to set their default value to `None`.
            # Only explicit `Field(...)` declarations appear as `FieldInfo` in the namespace;
            # clearing `default_factory` too, so it cannot override the `None` default.
            for field_name, field_info in namespaces.items():
                if isinstance(field_info, pydantic.fields.FieldInfo):
                    field_info.default = None
                    field_info.default_factory = None

        # Create the class: all annotations are converted into `ModelField`.
        instance = super().__new__(cls, name, bases, namespaces, **kwargs)

        # Modify the inherited fields of the class to make them optional
        # and set their default value to `None`. This second pass covers fields
        # declared on ancestor models, which the annotation rewrite above did not touch.
        model_field: pydantic.fields.ModelField
        for field_name, model_field in instance.__fields__.items():
            model_field.required = False
            model_field.allow_none = True
            if use_none:
                # Clear defaults at both the `ModelField` and `FieldInfo` levels so
                # validation and schema generation agree on the `None` default.
                model_field.default = None
                model_field.default_factory = None
                model_field.field_info.default = None

        return instance


MODEL = t.TypeVar("MODEL", bound=t.Type[BaseModel])


def camel_case_model(model: MODEL) -> MODEL:
    """
    Class decorator that switches a Pydantic model to camelCase aliases.

    The model's config `alias_generator` is set to `to_camel_case`, and every
    already-declared field has its alias regenerated accordingly (the generator
    alone only applies to fields created after it is set).

    Args:
        model: The pydantic model class to modify (mutated in place).

    Returns:
        The same model class, enabling use as a decorator.
    """
    model.__config__.alias_generator = to_camel_case
    for name in model.__fields__:
        model.__fields__[name].alias = to_camel_case(name)
    return model
3 changes: 2 additions & 1 deletion antarest/study/business/areas/renewable_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from pydantic import validator

from antarest.core.exceptions import DuplicateRenewableCluster, RenewableClusterConfigNotFound, RenewableClusterNotFound
from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model
from antarest.study.business.enum_ignore_case import EnumIgnoreCase
from antarest.study.business.utils import AllOptionalMetaclass, camel_case_model, execute_or_add_commands
from antarest.study.business.utils import execute_or_add_commands
from antarest.study.model import Study
from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id
from antarest.study.storage.rawstudy.model.filesystem.config.renewable import (
Expand Down
3 changes: 2 additions & 1 deletion antarest/study/business/areas/st_storage_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
STStorageMatrixNotFound,
STStorageNotFound,
)
from antarest.study.business.utils import AllOptionalMetaclass, camel_case_model, execute_or_add_commands
from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model
from antarest.study.business.utils import execute_or_add_commands
from antarest.study.model import Study
from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id
from antarest.study.storage.rawstudy.model.filesystem.config.st_storage import (
Expand Down
3 changes: 2 additions & 1 deletion antarest/study/business/areas/thermal_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from pydantic import validator

from antarest.core.exceptions import DuplicateThermalCluster, ThermalClusterConfigNotFound, ThermalClusterNotFound
from antarest.study.business.utils import AllOptionalMetaclass, camel_case_model, execute_or_add_commands
from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model
from antarest.study.business.utils import execute_or_add_commands
from antarest.study.model import Study
from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id
from antarest.study.storage.rawstudy.model.filesystem.config.thermal import (
Expand Down
119 changes: 69 additions & 50 deletions antarest/study/business/binding_constraint_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
NoConstraintError,
)
from antarest.core.utils.string import to_camel_case
from antarest.study.business.utils import AllOptionalMetaclass, camel_case_model, execute_or_add_commands
from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model
from antarest.study.business.utils import execute_or_add_commands
from antarest.study.model import Study
from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id
Expand Down Expand Up @@ -48,6 +49,8 @@
from antarest.study.storage.variantstudy.model.command.update_binding_constraint import UpdateBindingConstraint
from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy

_TERM_MATRICES = ["less_term_matrix", "equal_term_matrix", "greater_term_matrix"]

logger = logging.getLogger(__name__)

DEFAULT_GROUP = "default"
Expand Down Expand Up @@ -298,17 +301,23 @@ def check_matrices_dimensions(cls, values: Dict[str, Any]) -> Dict[str, Any]:
class ConstraintOutputBase(BindingConstraintProperties):
id: str
name: str
terms: MutableSequence[ConstraintTerm] = Field(
default_factory=lambda: [],
)
terms: MutableSequence[ConstraintTerm] = Field(default_factory=lambda: [])


@camel_case_model
class ConstraintOutput870(ConstraintOutputBase):
class ConstraintOutput830(ConstraintOutputBase):
filter_year_by_year: str = ""
filter_synthesis: str = ""


@camel_case_model
class ConstraintOutput870(ConstraintOutput830):
group: str = DEFAULT_GROUP


ConstraintOutput = Union[ConstraintOutputBase, ConstraintOutput870]
# WARNING: Do not change the order of the following line, it is used to determine
# the type of the output constraint in the FastAPI endpoint.
ConstraintOutput = Union[ConstraintOutputBase, ConstraintOutput830, ConstraintOutput870]


def _validate_binding_constraints(file_study: FileStudy, bcs: Sequence[ConstraintOutput]) -> bool:
Expand Down Expand Up @@ -355,9 +364,7 @@ def __init__(
self.storage_service = storage_service

@staticmethod
def parse_and_add_terms(
key: str, value: Any, adapted_constraint: Union[ConstraintOutputBase, ConstraintOutput870]
) -> None:
def parse_and_add_terms(key: str, value: Any, adapted_constraint: ConstraintOutput) -> None:
"""Parse a single term from the constraint dictionary and add it to the adapted_constraint model."""
if "%" in key or "." in key:
separator = "%" if "%" in key else "."
Expand Down Expand Up @@ -393,24 +400,24 @@ def parse_and_add_terms(
@staticmethod
def constraint_model_adapter(constraint: Mapping[str, Any], version: int) -> ConstraintOutput:
"""
Adapts a constraint configuration to the appropriate version-specific format.
Adapts a binding constraint configuration to the appropriate model version.
Parameters:
- constraint: A dictionary or model representing the constraint to be adapted.
This can either be a dictionary coming from client input or an existing
model that needs reformatting.
- version: An integer indicating the target version of the study configuration. This is used to
Args:
constraint: A dictionary or model representing the constraint to be adapted.
This can either be a dictionary coming from client input or an existing
model that needs reformatting.
version: An integer indicating the target version of the study configuration. This is used to
determine which model class to instantiate and which default values to apply.
Returns:
- A new instance of either `ConstraintOutputBase` or `ConstraintOutput870`,
populated with the adapted values from the input constraint, and conforming to the
structure expected by the specified version.
A new instance of either `ConstraintOutputBase`, `ConstraintOutput830`, or `ConstraintOutput870`,
populated with the adapted values from the input constraint, and conforming to the
structure expected by the specified version.
Note:
This method is crucial for ensuring backward compatibility and future-proofing the application
as it evolves. It allows client-side data to be accurately represented within the config and
ensures data integrity when storing or retrieving constraint configurations from the database.
This method is crucial for ensuring backward compatibility and future-proofing the application
as it evolves. It allows client-side data to be accurately represented within the config and
ensures data integrity when storing or retrieving constraint configurations from the database.
"""

constraint_output = {
Expand All @@ -423,19 +430,20 @@ def constraint_model_adapter(constraint: Mapping[str, Any], version: int) -> Con
"terms": constraint.get("terms", []),
}

# TODO: Implement a model for version-specific fields. Output filters are sent regardless of the version.
if version >= 840:
constraint_output["filter_year_by_year"] = constraint.get("filter_year_by_year") or constraint.get(
"filter-year-by-year", ""
)
constraint_output["filter_synthesis"] = constraint.get("filter_synthesis") or constraint.get(
"filter-synthesis", ""
)

adapted_constraint: Union[ConstraintOutputBase, ConstraintOutput870]
if version >= 830:
_filter_year_by_year = constraint.get("filter_year_by_year") or constraint.get("filter-year-by-year", "")
_filter_synthesis = constraint.get("filter_synthesis") or constraint.get("filter-synthesis", "")
constraint_output["filter_year_by_year"] = _filter_year_by_year
constraint_output["filter_synthesis"] = _filter_synthesis
if version >= 870:
constraint_output["group"] = constraint.get("group", DEFAULT_GROUP)

# Choose the right model according to the version
adapted_constraint: ConstraintOutput
if version >= 870:
adapted_constraint = ConstraintOutput870(**constraint_output)
elif version >= 830:
adapted_constraint = ConstraintOutput830(**constraint_output)
else:
adapted_constraint = ConstraintOutputBase(**constraint_output)

Expand Down Expand Up @@ -648,17 +656,20 @@ def create_binding_constraint(
"time_step": data.time_step,
"operator": data.operator,
"coeffs": self.terms_to_coeffs(data.terms),
"values": data.values,
"less_term_matrix": data.less_term_matrix,
"equal_term_matrix": data.equal_term_matrix,
"greater_term_matrix": data.greater_term_matrix,
"filter_year_by_year": data.filter_year_by_year,
"filter_synthesis": data.filter_synthesis,
"comments": data.comments or "",
}

if version >= 830:
new_constraint["filter_year_by_year"] = data.filter_year_by_year or ""
new_constraint["filter_synthesis"] = data.filter_synthesis or ""

if version >= 870:
new_constraint["group"] = data.group or DEFAULT_GROUP
new_constraint["less_term_matrix"] = data.less_term_matrix
new_constraint["equal_term_matrix"] = data.equal_term_matrix
new_constraint["greater_term_matrix"] = data.greater_term_matrix
else:
new_constraint["values"] = data.values

command = CreateBindingConstraint(
**new_constraint, command_context=self.storage_service.variant_study_service.command_factory.command_context
Expand Down Expand Up @@ -699,12 +710,12 @@ def update_binding_constraint(
"comments": data.comments or existing_constraint.comments,
}

if study_version >= 840:
if isinstance(existing_constraint, ConstraintOutput830):
upd_constraint["filter_year_by_year"] = data.filter_year_by_year or existing_constraint.filter_year_by_year
upd_constraint["filter_synthesis"] = data.filter_synthesis or existing_constraint.filter_synthesis

if study_version >= 870:
upd_constraint["group"] = data.group or existing_constraint.group # type: ignore
if isinstance(existing_constraint, ConstraintOutput870):
upd_constraint["group"] = data.group or existing_constraint.group

args = {
**upd_constraint,
Expand All @@ -723,9 +734,7 @@ def update_binding_constraint(

# Validates the matrices. Needed when the study is a variant because we only append the command to the list
if isinstance(study, VariantStudy):
updated_matrices = [
term for term in ["less_term_matrix", "equal_term_matrix", "greater_term_matrix"] if getattr(data, term)
]
updated_matrices = [term for term in _TERM_MATRICES if getattr(data, term)]
command.validates_and_fills_matrices(
specific_matrices=updated_matrices, version=study_version, create=False
)
Expand Down Expand Up @@ -794,15 +803,20 @@ def update_constraint_term(

coeffs = {term.id: [term.weight, term.offset] if term.offset else [term.weight] for term in constraint_terms}

filter_year_by_year = constraint.filter_year_by_year if isinstance(constraint, ConstraintOutput830) else None
filter_synthesis = constraint.filter_synthesis if isinstance(constraint, ConstraintOutput830) else None
group = constraint.group if isinstance(constraint, ConstraintOutput870) else None

command = UpdateBindingConstraint(
id=constraint.id,
enabled=constraint.enabled,
time_step=constraint.time_step,
operator=constraint.operator,
coeffs=coeffs,
filter_year_by_year=constraint.filter_year_by_year,
filter_synthesis=constraint.filter_synthesis,
comments=constraint.comments,
filter_year_by_year=filter_year_by_year,
filter_synthesis=filter_synthesis,
group=group,
coeffs=coeffs,
command_context=self.storage_service.variant_study_service.command_factory.command_context,
)
execute_or_add_commands(study, file_study, [command], self.storage_service)
Expand Down Expand Up @@ -840,15 +854,20 @@ def create_constraint_term(
if term.offset:
coeffs[term.id].append(term.offset)

filter_year_by_year = constraint.filter_year_by_year if isinstance(constraint, ConstraintOutput830) else None
filter_synthesis = constraint.filter_synthesis if isinstance(constraint, ConstraintOutput830) else None
group = constraint.group if isinstance(constraint, ConstraintOutput870) else None

command = UpdateBindingConstraint(
id=constraint.id,
enabled=constraint.enabled,
time_step=constraint.time_step,
operator=constraint.operator,
coeffs=coeffs,
comments=constraint.comments,
filter_year_by_year=constraint.filter_year_by_year,
filter_synthesis=constraint.filter_synthesis,
filter_year_by_year=filter_year_by_year,
filter_synthesis=filter_synthesis,
group=group,
coeffs=coeffs,
command_context=self.storage_service.variant_study_service.command_factory.command_context,
)
execute_or_add_commands(study, file_study, [command], self.storage_service)
Expand Down Expand Up @@ -880,7 +899,7 @@ def _replace_matrices_according_to_frequency_and_version(
BindingConstraintFrequency.DAILY.value: default_bc_weekly_daily_87,
BindingConstraintFrequency.WEEKLY.value: default_bc_weekly_daily_87,
}[data.time_step].tolist()
for term in ["less_term_matrix", "equal_term_matrix", "greater_term_matrix"]:
for term in _TERM_MATRICES:
if term not in args:
args[term] = matrix
return args
Expand Down
3 changes: 2 additions & 1 deletion antarest/study/business/thematic_trimming_field_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import typing as t

from antarest.study.business.utils import AllOptionalMetaclass, FormFieldsBaseModel
from antarest.study.business.all_optional_meta import AllOptionalMetaclass
from antarest.study.business.utils import FormFieldsBaseModel


class ThematicTrimmingFormFields(FormFieldsBaseModel, metaclass=AllOptionalMetaclass, use_none=True):
Expand Down
Loading

0 comments on commit 9b9903d

Please sign in to comment.