From 92f9e9636c89af94fba3e268a51bcae78c1d073a Mon Sep 17 00:00:00 2001 From: belthlemar Date: Thu, 29 Aug 2024 16:13:14 +0200 Subject: [PATCH] build(python): bump project dependencies --- antarest/core/cache/business/redis_cache.py | 2 +- antarest/core/config.py | 17 +- antarest/core/core_blueprint.py | 16 +- antarest/core/filesystem_blueprint.py | 28 +- antarest/core/filetransfer/model.py | 2 +- antarest/core/model.py | 2 +- antarest/core/permissions.py | 15 +- antarest/core/requests.py | 45 +- antarest/core/tasks/model.py | 41 +- antarest/core/tasks/service.py | 10 +- .../utils/fastapi_sqlalchemy/middleware.py | 4 +- antarest/core/version_info.py | 2 +- antarest/eventbus/business/redis_eventbus.py | 4 +- antarest/eventbus/web.py | 4 +- antarest/fastapi_jwt_auth/LICENSE | 21 + antarest/fastapi_jwt_auth/__init__.py | 7 + antarest/fastapi_jwt_auth/auth_config.py | 114 +++ antarest/fastapi_jwt_auth/auth_jwt.py | 836 ++++++++++++++++++ antarest/fastapi_jwt_auth/config.py | 90 ++ antarest/fastapi_jwt_auth/exceptions.py | 89 ++ antarest/gui.py | 16 +- .../launcher/adapters/abstractlauncher.py | 4 +- antarest/launcher/model.py | 33 +- antarest/launcher/service.py | 19 +- antarest/launcher/ssh_client.py | 2 +- antarest/launcher/ssh_config.py | 4 +- antarest/launcher/web.py | 31 +- antarest/login/auth.py | 6 +- antarest/login/ldap.py | 9 +- antarest/login/main.py | 4 +- antarest/login/service.py | 2 +- antarest/login/web.py | 24 +- antarest/main.py | 10 +- antarest/matrixstore/matrix_editor.py | 27 +- antarest/matrixstore/service.py | 3 +- .../business/adequacy_patch_management.py | 22 +- .../advanced_parameters_management.py | 52 +- antarest/study/business/all_optional_meta.py | 92 +- .../study/business/allocation_management.py | 40 +- antarest/study/business/area_management.py | 49 +- .../study/business/areas/hydro_management.py | 59 +- .../business/areas/properties_management.py | 24 +- .../business/areas/renewable_management.py | 70 +- .../business/areas/st_storage_management.py | 119 +-- .../business/areas/thermal_management.py | 87 +- .../business/binding_constraint_management.py | 68 +- antarest/study/business/config_management.py | 2 +- .../study/business/correlation_management.py | 37 +- antarest/study/business/district_manager.py | 27 +- antarest/study/business/general_management.py | 57 +- antarest/study/business/link_management.py | 11 +- .../study/business/optimization_management.py | 33 +- .../study/business/table_mode_management.py | 27 +- .../business/thematic_trimming_field_infos.py | 8 +- .../business/thematic_trimming_management.py | 2 +- .../business/timeseries_config_management.py | 49 +- antarest/study/business/utils.py | 6 +- .../study/business/xpansion_management.py | 68 +- antarest/study/model.py | 37 +- antarest/study/service.py | 9 +- .../study/storage/abstract_storage_service.py | 4 +- antarest/study/storage/patch_service.py | 6 +- .../rawstudy/model/filesystem/config/area.py | 81 +- .../model/filesystem/config/cluster.py | 8 +- .../filesystem/config/field_validators.py | 2 +- .../rawstudy/model/filesystem/config/files.py | 1 - .../model/filesystem/config/identifier.py | 39 +- .../model/filesystem/config/ini_properties.py | 18 +- .../rawstudy/model/filesystem/config/links.py | 12 +- .../rawstudy/model/filesystem/config/model.py | 6 +- .../model/filesystem/config/st_storage.py | 2 +- .../model/filesystem/config/thermal.py | 2 +- .../rawstudy/model/filesystem/factory.py | 4 +- .../storage/rawstudy/raw_study_service.py | 2 +- .../study/storage/study_download_utils.py | 2 +- .../business/command_extractor.py | 11 +- .../variantstudy/business/command_reverter.py | 2 +- .../storage/variantstudy/business/utils.py | 4 +- .../business/utils_binding_constraint.py | 1 - .../storage/variantstudy/command_factory.py | 10 +- .../variantstudy/model/command/create_area.py | 4 +- .../command/create_binding_constraint.py | 74 +- .../model/command/create_cluster.py | 37 +- .../model/command/create_district.py | 8 +- .../variantstudy/model/command/create_link.py | 23 +- .../command/create_renewables_cluster.py | 10 +- .../model/command/create_st_storage.py | 65 +- .../generate_thermal_cluster_timeseries.py | 4 +- .../variantstudy/model/command/icommand.py | 10 +- .../variantstudy/model/command/remove_area.py | 2 +- .../command/remove_binding_constraint.py | 2 +- .../model/command/remove_cluster.py | 4 +- .../model/command/remove_district.py | 4 +- .../variantstudy/model/command/remove_link.py | 22 +- .../command/remove_renewables_cluster.py | 4 +- .../model/command/remove_st_storage.py | 8 +- .../model/command/replace_matrix.py | 12 +- .../command/update_binding_constraint.py | 15 +- .../model/command/update_comments.py | 4 +- .../model/command/update_config.py | 4 +- .../model/command/update_district.py | 12 +- .../model/command/update_playlist.py | 4 +- .../model/command/update_raw_file.py | 4 +- .../model/command/update_scenario_builder.py | 8 +- .../variantstudy/model/command_context.py | 1 - .../study/storage/variantstudy/model/model.py | 2 +- .../variantstudy/snapshot_generator.py | 6 +- .../variantstudy/variant_study_service.py | 4 +- antarest/study/web/studies_blueprint.py | 2 +- antarest/study/web/study_data_blueprint.py | 27 +- antarest/study/web/variant_blueprint.py | 7 +- .../study/web/xpansion_studies_blueprint.py | 2 +- antarest/tools/lib.py | 33 +- antarest/utils.py | 9 +- antarest/worker/archive_worker.py | 6 +- antarest/worker/simulator_worker.py | 2 +- antarest/worker/worker.py | 6 +- pyproject.toml | 9 +- requirements-dev.txt | 2 +- requirements-test.txt | 3 +- requirements.txt | 32 +- tests/cache/test_local_cache.py | 6 +- tests/cache/test_redis_cache.py | 6 +- tests/core/test_tasks.py | 2 +- tests/eventbus/test_redis_event_bus.py | 2 +- tests/eventbus/test_websocket_manager.py | 4 +- .../filesystem_blueprint/test_model.py | 6 +- .../launcher_blueprint/test_launcher_local.py | 12 +- .../test_aggregate_raw_data.py | 211 +++-- .../test_download_matrices.py | 2 +- .../test_synthesis/raw_study.synthesis.json | 612 +------------ .../variant_study.synthesis.json | 612 +------------ .../studies_blueprint/test_comments.py | 42 +- .../studies_blueprint/test_disk_usage.py | 2 +- .../studies_blueprint/test_get_studies.py | 18 +- .../studies_blueprint/test_update_tags.py | 39 +- .../test_advanced_parameters.py | 2 +- .../test_binding_constraints.py | 4 +- .../test_config_general.py | 2 +- .../study_data_blueprint/test_renewable.py | 48 +- .../study_data_blueprint/test_st_storage.py | 124 +-- .../study_data_blueprint/test_thermal.py | 78 +- tests/integration/test_apidoc.py | 15 +- tests/integration/test_core_blueprint.py | 21 - tests/integration/test_integration.py | 28 +- .../variant_blueprint/test_st_storage.py | 17 +- .../variant_blueprint/test_thermal_cluster.py | 2 +- .../variant_blueprint/test_variant_manager.py | 6 +- .../test_integration_xpansion.py | 96 +- tests/launcher/test_service.py | 35 +- tests/launcher/test_web.py | 21 +- tests/login/test_login_service.py | 36 +- tests/login/test_web.py | 36 +- tests/matrixstore/test_matrix_editor.py | 8 +- tests/matrixstore/test_service.py | 14 +- tests/matrixstore/test_web.py | 8 +- .../storage/business/test_arealink_manager.py | 29 +- tests/storage/business/test_config_manager.py | 12 +- tests/storage/business/test_patch_service.py | 2 +- .../test_timeseries_config_manager.py | 2 +- .../storage/business/test_xpansion_manager.py | 32 +- tests/storage/rawstudies/test_factory.py | 4 +- .../filesystem/config/test_config_files.py | 2 +- tests/storage/test_model.py | 2 +- tests/storage/test_service.py | 10 +- tests/storage/web/test_studies_bp.py | 13 +- .../areas/test_st_storage_management.py | 106 +-- .../business/areas/test_thermal_management.py | 8 +- .../business/test_all_optional_metaclass.py | 375 +------- .../study/business/test_allocation_manager.py | 2 +- .../variantstudy/model/test_dbmodel.py | 2 +- .../variantstudy/test_snapshot_generator.py | 8 +- tests/study/test_repository.py | 44 +- .../model/command/test_create_cluster.py | 32 +- .../model/command/test_create_link.py | 17 +- .../command/test_create_renewables_cluster.py | 6 +- .../model/command/test_create_st_storage.py | 101 +-- .../model/command/test_manage_district.py | 8 +- .../model/command/test_remove_area.py | 2 +- .../model/command/test_remove_link.py | 2 +- .../model/command/test_remove_st_storage.py | 6 +- .../model/command/test_replace_matrix.py | 6 +- .../model/command/test_update_config.py | 2 +- .../variantstudy/model/test_variant_model.py | 2 +- 184 files changed, 2859 insertions(+), 3467 deletions(-) create mode 100644 antarest/fastapi_jwt_auth/LICENSE create mode 100644 antarest/fastapi_jwt_auth/__init__.py create mode 100644 antarest/fastapi_jwt_auth/auth_config.py create mode 100644 antarest/fastapi_jwt_auth/auth_jwt.py create mode 100644 antarest/fastapi_jwt_auth/config.py create mode 100644 antarest/fastapi_jwt_auth/exceptions.py diff --git a/antarest/core/cache/business/redis_cache.py b/antarest/core/cache/business/redis_cache.py index 176e583b76..514bc81fdd 100644 --- a/antarest/core/cache/business/redis_cache.py +++ b/antarest/core/cache/business/redis_cache.py @@ -28,7 +28,7 @@ def put(self, id: str, data: JSON, duration: int = 3600) -> None: redis_element = RedisCacheElement(duration=duration, data=data) redis_key = f"cache:{id}" logger.info(f"Adding cache key {id}") - self.redis.set(redis_key, redis_element.json()) + self.redis.set(redis_key, redis_element.model_dump_json()) self.redis.expire(redis_key, duration) def get(self, id: str, refresh_timeout: Optional[int] = None) -> Optional[JSON]: diff --git a/antarest/core/config.py b/antarest/core/config.py index d7b7ed1243..494ad19c47 100644 --- a/antarest/core/config.py +++ b/antarest/core/config.py @@ -1,6 +1,7 @@ import multiprocessing import tempfile from dataclasses import asdict, dataclass, field +from enum import Enum from pathlib import Path from typing import Dict, List, Optional @@ -12,6 +13,12 @@ DEFAULT_WORKSPACE_NAME = "default" +class Launcher(str, Enum): + SLURM = "slurm" + LOCAL = "local" + DEFAULT = "default" + + @dataclass(frozen=True) class ExternalAuthConfig: """ @@ -387,7 +394,7 @@ def __post_init__(self) -> None: msg = f"Invalid configuration: {self.default=} must be one of {possible!r}" raise ValueError(msg) - def get_nb_cores(self, launcher: str) -> "NbCoresConfig": + def get_nb_cores(self, launcher: Launcher) -> "NbCoresConfig": """ Retrieve the number of cores configuration for a given launcher: "local" or "slurm". If "default" is specified, retrieve the configuration of the default launcher. @@ -404,12 +411,12 @@ def get_nb_cores(self, launcher: str) -> "NbCoresConfig": """ config_map = {"local": self.local, "slurm": self.slurm} config_map["default"] = config_map[self.default] - launcher_config = config_map.get(launcher) + launcher_config = config_map.get(launcher.value) if launcher_config is None: - raise InvalidConfigurationError(launcher) + raise InvalidConfigurationError(launcher.value) return launcher_config.nb_cores - def get_time_limit(self, launcher: str) -> TimeLimitConfig: + def get_time_limit(self, launcher: Launcher) -> TimeLimitConfig: """ Retrieve the time limit for a job of the given launcher: "local" or "slurm". If "default" is specified, retrieve the configuration of the default launcher. @@ -426,7 +433,7 @@ def get_time_limit(self, launcher: str) -> TimeLimitConfig: """ config_map = {"local": self.local, "slurm": self.slurm} config_map["default"] = config_map[self.default] - launcher_config = config_map.get(launcher) + launcher_config = config_map.get(launcher.value) if launcher_config is None: raise InvalidConfigurationError(launcher) return launcher_config.time_limit diff --git a/antarest/core/core_blueprint.py b/antarest/core/core_blueprint.py index 5eac42ce3b..d308ed60ca 100644 --- a/antarest/core/core_blueprint.py +++ b/antarest/core/core_blueprint.py @@ -1,12 +1,9 @@ -import logging from typing import Any -from fastapi import APIRouter, Depends +from fastapi import APIRouter from pydantic import BaseModel from antarest.core.config import Config -from antarest.core.jwt import JWTUser -from antarest.core.requests import UserHasNotPermissionError from antarest.core.utils.web import APITag from antarest.core.version_info import VersionInfoDTO, get_commit_id, get_dependencies from antarest.login.auth import Auth @@ -54,15 +51,4 @@ def version_info() -> Any: dependencies=get_dependencies(), ) - @bp.get("/kill", include_in_schema=False) - def kill_worker( - current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: - if not current_user.is_site_admin(): - raise UserHasNotPermissionError() - logging.getLogger(__name__).critical("Killing the worker") - # PyInstaller modifies the behavior of built-in functions, such as `exit`. - # It is advisable to use `sys.exit` or raise the `SystemExit` exception instead. - raise SystemExit(f"Worker killed by the user #{current_user.id}") - return bp diff --git a/antarest/core/filesystem_blueprint.py b/antarest/core/filesystem_blueprint.py index bf247978b2..4d6869023a 100644 --- a/antarest/core/filesystem_blueprint.py +++ b/antarest/core/filesystem_blueprint.py @@ -11,21 +11,21 @@ import typing_extensions as te from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel, Field from starlette.responses import PlainTextResponse, StreamingResponse from antarest.core.config import Config from antarest.core.utils.web import APITag from antarest.login.auth import Auth -FilesystemName = te.Annotated[str, Field(regex=r"^\w+$", description="Filesystem name")] -MountPointName = te.Annotated[str, Field(regex=r"^\w+$", description="Mount point name")] +FilesystemName = te.Annotated[str, Field(pattern=r"^\w+$", description="Filesystem name")] +MountPointName = te.Annotated[str, Field(pattern=r"^\w+$", description="Mount point name")] class FilesystemDTO( BaseModel, - extra=Extra.forbid, - schema_extra={ + extra="forbid", + json_schema_extra={ "example": { "name": "ws", "mount_dirs": { @@ -50,8 +50,8 @@ class FilesystemDTO( class MountPointDTO( BaseModel, - extra=Extra.forbid, - schema_extra={ + extra="forbid", + json_schema_extra={ "example": { "name": "default", "path": "/path/to/workspaces/internal_studies", @@ -77,10 +77,10 @@ class MountPointDTO( name: MountPointName path: Path = Field(description="Full path of the mount point in Antares Web Server") - total_bytes: int = Field(0, description="Total size of the mount point in bytes") - used_bytes: int = Field(0, description="Used size of the mount point in bytes") - free_bytes: int = Field(0, description="Free size of the mount point in bytes") - message: str = Field("", description="A message describing the status of the mount point") + total_bytes: t.Optional[int] = 0 # Total size of the mount point in bytes + used_bytes: t.Optional[int] = 0 # Used size of the mount point in bytes + free_bytes: t.Optional[int] = 0 # Free size of the mount point in bytes + message: t.Optional[str] = "" # A message describing the status of the mount point @classmethod async def from_path(cls, name: str, path: Path) -> "MountPointDTO": @@ -98,8 +98,8 @@ async def from_path(cls, name: str, path: Path) -> "MountPointDTO": class FileInfoDTO( BaseModel, - extra=Extra.forbid, - schema_extra={ + extra="forbid", + json_schema_extra={ "example": { "path": "/path/to/workspaces/internal_studies/5a503c20-24a3-4734-9cf8-89565c9db5ec/study.antares", "file_type": "file", @@ -148,6 +148,7 @@ async def from_path(cls, full_path: Path, *, details: bool = False) -> "FileInfo path=full_path, file_type="unknown", file_count=0, # missing + size_bytes=0, # missing created=datetime.datetime.min, modified=datetime.datetime.min, accessed=datetime.datetime.min, @@ -162,6 +163,7 @@ async def from_path(cls, full_path: Path, *, details: bool = False) -> "FileInfo created=datetime.datetime.fromtimestamp(file_stat.st_ctime), modified=datetime.datetime.fromtimestamp(file_stat.st_mtime), accessed=datetime.datetime.fromtimestamp(file_stat.st_atime), + message="OK", ) if stat.S_ISDIR(file_stat.st_mode): diff --git a/antarest/core/filetransfer/model.py b/antarest/core/filetransfer/model.py index bbb61c00b6..3e11e61ada 100644 --- a/antarest/core/filetransfer/model.py +++ b/antarest/core/filetransfer/model.py @@ -29,7 +29,7 @@ class FileDownloadDTO(BaseModel): id: str name: str filename: str - expiration_date: Optional[str] + expiration_date: Optional[str] = None ready: bool failed: bool = False error_message: str = "" diff --git a/antarest/core/model.py b/antarest/core/model.py index 4c8c0d5f0e..6e80f2de8e 100644 --- a/antarest/core/model.py +++ b/antarest/core/model.py @@ -9,7 +9,7 @@ JSON = Dict[str, Any] ELEMENT = Union[str, int, float, bool, bytes] -SUB_JSON = Union[ELEMENT, JSON, List, None] +SUB_JSON = Union[ELEMENT, JSON, List[Any], None] class PublicMode(str, enum.Enum): diff --git a/antarest/core/permissions.py b/antarest/core/permissions.py index 8ecba4d85f..08d3e4e456 100644 --- a/antarest/core/permissions.py +++ b/antarest/core/permissions.py @@ -1,4 +1,5 @@ import logging +import typing as t from antarest.core.jwt import JWTUser from antarest.core.model import PermissionInfo, PublicMode, StudyPermissionType @@ -7,8 +8,8 @@ logger = logging.getLogger(__name__) -permission_matrix = { - StudyPermissionType.READ: { +permission_matrix: t.Dict[str, t.Dict[str, t.Sequence[t.Union[RoleType, PublicMode]]]] = { + StudyPermissionType.READ.value: { "roles": [ RoleType.ADMIN, RoleType.RUNNER, @@ -22,15 +23,15 @@ PublicMode.READ, ], }, - StudyPermissionType.RUN: { + StudyPermissionType.RUN.value: { "roles": [RoleType.ADMIN, RoleType.RUNNER, RoleType.WRITER], "public_modes": [PublicMode.FULL, PublicMode.EDIT, PublicMode.EXECUTE], }, - StudyPermissionType.WRITE: { + StudyPermissionType.WRITE.value: { "roles": [RoleType.ADMIN, RoleType.WRITER], "public_modes": [PublicMode.FULL, PublicMode.EDIT], }, - StudyPermissionType.MANAGE_PERMISSIONS: { + StudyPermissionType.MANAGE_PERMISSIONS.value: { "roles": [RoleType.ADMIN], "public_modes": [], }, @@ -65,11 +66,11 @@ def check_permission( allowed_roles = permission_matrix[permission]["roles"] group_permission = any( - role in allowed_roles # type: ignore + role in allowed_roles for role in [group.role for group in (user.groups or []) if group.id in permission_info.groups] ) if group_permission: return True allowed_public_modes = permission_matrix[permission]["public_modes"] - return permission_info.public_mode in allowed_public_modes # type: ignore + return permission_info.public_mode in allowed_public_modes diff --git a/antarest/core/requests.py b/antarest/core/requests.py index d33285c7ab..682e65827f 100644 --- a/antarest/core/requests.py +++ b/antarest/core/requests.py @@ -1,5 +1,7 @@ +import typing as t +from collections import OrderedDict from dataclasses import dataclass -from typing import Optional +from typing import Any, Generator, Tuple from fastapi import HTTPException from markupsafe import escape @@ -17,13 +19,52 @@ } +class CaseInsensitiveDict(t.MutableMapping[str, t.Any]): # copy of the requests class to avoid importing the package + def __init__(self, data=None, **kwargs) -> None: # type: ignore + self._store: OrderedDict[str, t.Any] = OrderedDict() + if data is None: + data = {} + self.update(data, **kwargs) + + def __setitem__(self, key: str, value: t.Any) -> None: + self._store[key.lower()] = (key, value) + + def __getitem__(self, key: str) -> t.Any: + return self._store[key.lower()][1] + + def __delitem__(self, key: str) -> None: + del self._store[key.lower()] + + def __iter__(self) -> t.Any: + return (casedkey for casedkey, mappedvalue in self._store.values()) + + def __len__(self) -> int: + return len(self._store) + + def lower_items(self) -> Generator[Tuple[Any, Any], Any, None]: + return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items()) + + def __eq__(self, other: t.Any) -> bool: + if isinstance(other, t.Mapping): + other = CaseInsensitiveDict(other) + else: + return NotImplemented + return dict(self.lower_items()) == dict(other.lower_items()) + + def copy(self) -> "CaseInsensitiveDict": + return CaseInsensitiveDict(self._store.values()) + + def __repr__(self) -> str: + return str(dict(self.items())) + + @dataclass class RequestParameters: """ DTO object to handle data inside request to send to service """ - user: Optional[JWTUser] = None + user: t.Optional[JWTUser] = None def get_user_id(self) -> str: return str(escape(str(self.user.id))) if self.user else "Unknown" diff --git a/antarest/core/tasks/model.py b/antarest/core/tasks/model.py index 601e0db8ce..996f4557d1 100644 --- a/antarest/core/tasks/model.py +++ b/antarest/core/tasks/model.py @@ -3,7 +3,7 @@ from datetime import datetime from enum import Enum -from pydantic import BaseModel, Extra +from pydantic import BaseModel from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore from sqlalchemy.engine.base import Engine # type: ignore from sqlalchemy.orm import relationship, sessionmaker # type: ignore @@ -44,43 +44,43 @@ def is_final(self) -> bool: ] -class TaskResult(BaseModel, extra=Extra.forbid): +class TaskResult(BaseModel, extra="forbid"): success: bool message: str # Can be used to store json serialized result - return_value: t.Optional[str] + return_value: t.Optional[str] = None -class TaskLogDTO(BaseModel, extra=Extra.forbid): +class TaskLogDTO(BaseModel, extra="forbid"): id: str message: str -class CustomTaskEventMessages(BaseModel, extra=Extra.forbid): +class CustomTaskEventMessages(BaseModel, extra="forbid"): start: str running: str end: str -class TaskEventPayload(BaseModel, extra=Extra.forbid): +class TaskEventPayload(BaseModel, extra="forbid"): id: str message: str -class TaskDTO(BaseModel, extra=Extra.forbid): +class TaskDTO(BaseModel, extra="forbid"): id: str name: str - owner: t.Optional[int] + owner: t.Optional[int] = None status: TaskStatus creation_date_utc: str - completion_date_utc: t.Optional[str] - result: t.Optional[TaskResult] - logs: t.Optional[t.List[TaskLogDTO]] + completion_date_utc: t.Optional[str] = None + result: t.Optional[TaskResult] = None + logs: t.Optional[t.List[TaskLogDTO]] = None type: t.Optional[str] = None ref_id: t.Optional[str] = None -class TaskListFilter(BaseModel, extra=Extra.forbid): +class TaskListFilter(BaseModel, extra="forbid"): status: t.List[TaskStatus] = [] name: t.Optional[str] = None type: t.List[TaskType] = [] @@ -158,6 +158,15 @@ class TaskJob(Base): # type: ignore study: "Study" = relationship("Study", back_populates="jobs", uselist=False) def to_dto(self, with_logs: bool = False) -> TaskDTO: + result = None + if self.completion_date: + assert self.result_status is not None + assert self.result_msg is not None + result = TaskResult( + success=self.result_status, + message=self.result_msg, + return_value=self.result, + ) return TaskDTO( id=self.id, owner=self.owner_id, @@ -165,13 +174,7 @@ def to_dto(self, with_logs: bool = False) -> TaskDTO: completion_date_utc=str(self.completion_date) if self.completion_date else None, name=self.name, status=TaskStatus(self.status), - result=TaskResult( - success=self.result_status, - message=self.result_msg, - return_value=self.result, - ) - if self.completion_date - else None, + result=result, logs=sorted([log.to_dto() for log in self.logs], key=lambda log: log.id) if with_logs else None, type=self.type, ref_id=self.ref_id, diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py index 07832e7365..0ac0e9d875 100644 --- a/antarest/core/tasks/service.py +++ b/antarest/core/tasks/service.py @@ -132,7 +132,7 @@ def _create_awaiter( res_wrapper: t.List[TaskResult], ) -> t.Callable[[Event], t.Awaitable[None]]: async def _await_task_end(event: Event) -> None: - task_event = WorkerTaskResult.parse_obj(event.payload) + task_event = WorkerTaskResult.model_validate(event.payload) if task_event.task_id == task_id: res_wrapper.append(task_event.task_result) @@ -240,7 +240,7 @@ def _launch_task( message=custom_event_messages.start if custom_event_messages is not None else f"Task {task.id} added", - ).dict(), + ).model_dump(), permissions=PermissionInfo(owner=request_params.user.impersonator), ) ) @@ -349,7 +349,7 @@ def _run_task( message=custom_event_messages.running if custom_event_messages is not None else f"Task {task_id} is running", - ).dict(), + ).model_dump(), permissions=PermissionInfo(public_mode=PublicMode.READ), channel=EventChannelDirectory.TASK + task_id, ) @@ -395,7 +395,7 @@ def _run_task( if custom_event_messages is not None else f"Task {task_id} {event_msg}" ), - ).dict(), + ).model_dump(), permissions=PermissionInfo(public_mode=PublicMode.READ), channel=EventChannelDirectory.TASK + task_id, ) @@ -420,7 +420,7 @@ def _run_task( self.event_bus.push( Event( type=EventType.TASK_FAILED, - payload=TaskEventPayload(id=task_id, message=message).dict(), + payload=TaskEventPayload(id=task_id, message=message).model_dump(), permissions=PermissionInfo(public_mode=PublicMode.READ), channel=EventChannelDirectory.TASK + task_id, ) diff --git a/antarest/core/utils/fastapi_sqlalchemy/middleware.py b/antarest/core/utils/fastapi_sqlalchemy/middleware.py index 9a98b4ef1b..dcc1f95b25 100644 --- a/antarest/core/utils/fastapi_sqlalchemy/middleware.py +++ b/antarest/core/utils/fastapi_sqlalchemy/middleware.py @@ -12,7 +12,7 @@ from antarest.core.utils.fastapi_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError -_Session: sessionmaker = None +_Session: Optional[sessionmaker] = None _session: ContextVar[Optional[Session]] = ContextVar("_session", default=None) @@ -93,4 +93,4 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: _session.reset(self.token) -db: DBSessionMeta = DBSession +db: Type[DBSession] = DBSession diff --git a/antarest/core/version_info.py b/antarest/core/version_info.py index 8b1b50d84b..7614c64574 100644 --- a/antarest/core/version_info.py +++ b/antarest/core/version_info.py @@ -16,7 +16,7 @@ class VersionInfoDTO(BaseModel): dependencies: Dict[str, str] class Config: - schema_extra = { + json_schema_extra = { "example": { "name": "AntaREST", "version": "2.13.2", diff --git a/antarest/eventbus/business/redis_eventbus.py b/antarest/eventbus/business/redis_eventbus.py index 94d7f5a36c..e4b541f891 100644 --- a/antarest/eventbus/business/redis_eventbus.py +++ b/antarest/eventbus/business/redis_eventbus.py @@ -17,10 +17,10 @@ def __init__(self, redis_client: Redis) -> None: # type: ignore self.pubsub.subscribe(REDIS_STORE_KEY) def push_event(self, event: Event) -> None: - self.redis.publish(REDIS_STORE_KEY, event.json()) + self.redis.publish(REDIS_STORE_KEY, event.model_dump_json()) def queue_event(self, event: Event, queue: str) -> None: - self.redis.rpush(queue, event.json()) + self.redis.rpush(queue, event.model_dump_json()) def pull_queue(self, queue: str) -> Optional[Event]: event = self.redis.lpop(queue) diff --git a/antarest/eventbus/web.py b/antarest/eventbus/web.py index cae0ffb99f..3b757ce4a2 100644 --- a/antarest/eventbus/web.py +++ b/antarest/eventbus/web.py @@ -6,7 +6,6 @@ from typing import List, Optional from fastapi import Depends, FastAPI, HTTPException, Query -from fastapi_jwt_auth import AuthJWT # type: ignore from pydantic import BaseModel from starlette.websockets import WebSocket, WebSocketDisconnect @@ -15,6 +14,7 @@ from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTUser from antarest.core.model import PermissionInfo, StudyPermissionType from antarest.core.permissions import check_permission +from antarest.fastapi_jwt_auth import AuthJWT from antarest.login.auth import Auth logger = logging.getLogger(__name__) @@ -83,7 +83,7 @@ def configure_websockets(application: FastAPI, config: Config, event_bus: IEvent manager = ConnectionManager() async def send_event_to_ws(event: Event) -> None: - event_data = event.dict() + event_data = event.model_dump() del event_data["permissions"] del event_data["channel"] await manager.broadcast(json.dumps(event_data), event.permissions, event.channel) diff --git a/antarest/fastapi_jwt_auth/LICENSE b/antarest/fastapi_jwt_auth/LICENSE new file mode 100644 index 0000000000..ad5e3c8004 --- /dev/null +++ b/antarest/fastapi_jwt_auth/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Nyoman Pradipta Dewantara + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/antarest/fastapi_jwt_auth/__init__.py b/antarest/fastapi_jwt_auth/__init__.py new file mode 100644 index 0000000000..a3748ffe33 --- /dev/null +++ b/antarest/fastapi_jwt_auth/__init__.py @@ -0,0 +1,7 @@ +"""FastAPI extension that provides JWT Auth support (secure, easy to use and lightweight)""" + +__version__ = "0.5.0" + +__all__ = ["AuthJWT"] + +from .auth_jwt import AuthJWT diff --git a/antarest/fastapi_jwt_auth/auth_config.py b/antarest/fastapi_jwt_auth/auth_config.py new file mode 100644 index 0000000000..55ca04a25d --- /dev/null +++ b/antarest/fastapi_jwt_auth/auth_config.py @@ -0,0 +1,114 @@ +from datetime import timedelta +from typing import Callable, List + +from pydantic import ValidationError + +from .config import LoadConfig + + +class AuthConfig: + _token = None + _token_location = {"headers"} + + _secret_key = None + _public_key = None + _private_key = None + _algorithm = "HS256" + _decode_algorithms = None + _decode_leeway = 0 + _encode_issuer = None + _decode_issuer = None + _decode_audience = None + _denylist_enabled = False + _denylist_token_checks = {"access", "refresh"} + _header_name = "Authorization" + _header_type = "Bearer" + _token_in_denylist_callback = None + _access_token_expires = timedelta(minutes=15) + _refresh_token_expires = timedelta(days=30) + + # option for create cookies + _access_cookie_key = "access_token_cookie" + _refresh_cookie_key = "refresh_token_cookie" + _access_cookie_path = "/" + _refresh_cookie_path = "/" + _cookie_max_age = None + _cookie_domain = None + _cookie_secure = False + _cookie_samesite = None + + # option for double submit csrf protection + _cookie_csrf_protect = True + _access_csrf_cookie_key = "csrf_access_token" + _refresh_csrf_cookie_key = "csrf_refresh_token" + _access_csrf_cookie_path = "/" + _refresh_csrf_cookie_path = "/" + _access_csrf_header_name = "X-CSRF-Token" + _refresh_csrf_header_name = "X-CSRF-Token" + _csrf_methods = {"POST", "PUT", "PATCH", "DELETE"} + + @property + def jwt_in_cookies(self) -> bool: + return "cookies" in self._token_location + + @property + def jwt_in_headers(self) -> bool: + return "headers" in self._token_location + + @classmethod + def load_config(cls, settings: Callable[..., List[tuple]]) -> "AuthConfig": + try: + config = LoadConfig(**{key.lower(): value for key, value in settings()}) + + cls._token_location = config.authjwt_token_location + cls._secret_key = config.authjwt_secret_key + cls._public_key = config.authjwt_public_key + cls._private_key = config.authjwt_private_key + cls._algorithm = config.authjwt_algorithm + cls._decode_algorithms = config.authjwt_decode_algorithms + cls._decode_leeway = config.authjwt_decode_leeway + cls._encode_issuer = config.authjwt_encode_issuer + cls._decode_issuer = config.authjwt_decode_issuer + cls._decode_audience = config.authjwt_decode_audience + cls._denylist_enabled = config.authjwt_denylist_enabled + cls._denylist_token_checks = config.authjwt_denylist_token_checks + cls._header_name = config.authjwt_header_name + cls._header_type = config.authjwt_header_type + cls._access_token_expires = config.authjwt_access_token_expires + cls._refresh_token_expires = config.authjwt_refresh_token_expires + # option for create cookies + cls._access_cookie_key = config.authjwt_access_cookie_key + cls._refresh_cookie_key = config.authjwt_refresh_cookie_key + cls._access_cookie_path = config.authjwt_access_cookie_path + cls._refresh_cookie_path = config.authjwt_refresh_cookie_path + cls._cookie_max_age = config.authjwt_cookie_max_age + cls._cookie_domain = config.authjwt_cookie_domain + cls._cookie_secure = config.authjwt_cookie_secure + cls._cookie_samesite = config.authjwt_cookie_samesite + # option for double submit csrf protection + cls._cookie_csrf_protect = config.authjwt_cookie_csrf_protect + cls._access_csrf_cookie_key = config.authjwt_access_csrf_cookie_key + cls._refresh_csrf_cookie_key = config.authjwt_refresh_csrf_cookie_key + cls._access_csrf_cookie_path = config.authjwt_access_csrf_cookie_path + cls._refresh_csrf_cookie_path = config.authjwt_refresh_csrf_cookie_path + cls._access_csrf_header_name = config.authjwt_access_csrf_header_name + cls._refresh_csrf_header_name = config.authjwt_refresh_csrf_header_name + cls._csrf_methods = config.authjwt_csrf_methods + except ValidationError as e: + raise e + except Exception: + raise TypeError("Config must be pydantic 'BaseSettings' or list of tuple") + + @classmethod + def token_in_denylist_loader(cls, callback: Callable[..., bool]) -> "AuthConfig": + """ + This decorator sets the callback function that will be called when + a protected endpoint is accessed and will check if the JWT has been + been revoked. By default, this callback is not used. + + *HINT*: The callback must be a function that takes decrypted_token argument, + args for object AuthJWT and this is not used, decrypted_token is decode + JWT (python dictionary) and returns *`True`* if the token has been deny, + or *`False`* otherwise. + """ + cls._token_in_denylist_callback = callback diff --git a/antarest/fastapi_jwt_auth/auth_jwt.py b/antarest/fastapi_jwt_auth/auth_jwt.py new file mode 100644 index 0000000000..4eb5949d43 --- /dev/null +++ b/antarest/fastapi_jwt_auth/auth_jwt.py @@ -0,0 +1,836 @@ +import hmac +import re +import uuid +from datetime import datetime, timedelta, timezone +from typing import Dict, Optional, Sequence, Union + +import jwt +from fastapi import Request, Response, WebSocket +from jwt.algorithms import has_crypto, requires_cryptography + +from .auth_config import AuthConfig +from .exceptions import ( + AccessTokenRequired, + CSRFError, + FreshTokenRequired, + InvalidHeaderError, + JWTDecodeError, + MissingTokenError, + RefreshTokenRequired, + RevokedTokenError, +) + + +class AuthJWT(AuthConfig): + def __init__(self, req: Request = None, res: Response = None): + """ + Get jwt header from incoming request or get + request and response object if jwt in the cookie + + :param req: all incoming request + :param res: response from endpoint + """ + if res and self.jwt_in_cookies: + self._response = res + + if req: + # get request object when cookies in token location + if self.jwt_in_cookies: + self._request = req + # get jwt in headers when headers in token location + if self.jwt_in_headers: + auth = req.headers.get(self._header_name.lower()) + if auth: + self._get_jwt_from_headers(auth) + + def _get_jwt_from_headers(self, auth: str) -> "AuthJWT": + """ + Get token from the headers + + :param auth: value from HeaderName + """ + header_name, header_type = self._header_name, self._header_type + + parts = auth.split() + + # Make sure the header is in a valid format that we are expecting, ie + if not header_type: + # : + if len(parts) != 1: + msg = "Bad {} header. Expected value ''".format(header_name) + raise InvalidHeaderError(status_code=422, message=msg) + self._token = parts[0] + else: + # : + if not re.match(r"{}\s".format(header_type), auth) or len(parts) != 2: + msg = "Bad {} header. Expected value '{} '".format(header_name, header_type) + raise InvalidHeaderError(status_code=422, message=msg) + self._token = parts[1] + + def _get_jwt_identifier(self) -> str: + return str(uuid.uuid4()) + + def _get_int_from_datetime(self, value: datetime) -> int: + """ + :param value: datetime with or without timezone, if don't contains timezone + it will managed as it is UTC + :return: Seconds since the Epoch + """ + if not isinstance(value, datetime): # pragma: no cover + raise TypeError("a datetime is required") + return int(value.timestamp()) + + def _get_secret_key(self, algorithm: str, process: str) -> str: + """ + Get key with a different algorithm + + :param algorithm: algorithm for decode and encode token + :param process: for indicating get key for encode or decode token + + :return: plain text or RSA depends on algorithm + """ + symmetric_algorithms, asymmetric_algorithms = {"HS256", "HS384", "HS512"}, requires_cryptography + + if algorithm not in symmetric_algorithms and algorithm not in asymmetric_algorithms: + raise ValueError("Algorithm {} could not be found".format(algorithm)) + + if algorithm in symmetric_algorithms: + if not self._secret_key: + raise RuntimeError("authjwt_secret_key must be set when using symmetric algorithm {}".format(algorithm)) + + return self._secret_key + + if algorithm in asymmetric_algorithms and not has_crypto: + raise RuntimeError( + "Missing dependencies for using asymmetric algorithms. run 'pip install fastapi-jwt-auth[asymmetric]'" + ) + + if process == "encode": + if not self._private_key: + raise RuntimeError( + "authjwt_private_key must be set when using asymmetric algorithm {}".format(algorithm) + ) + + return self._private_key + + if process == "decode": + if not self._public_key: + raise RuntimeError( + "authjwt_public_key must be set when using asymmetric algorithm {}".format(algorithm) + ) + + return self._public_key + + def _create_token( + self, + subject: Union[str, int], + type_token: str, + exp_time: Optional[int], + fresh: Optional[bool] = False, + algorithm: Optional[str] = None, + headers: Optional[Dict] = None, + issuer: Optional[str] = None, + audience: Optional[Union[str, Sequence[str]]] = None, + user_claims: Optional[Dict] = {}, + ) -> str: + """ + Create token for access_token and refresh_token (utf-8) + + :param subject: Identifier for who this token is for example id or username from database. + :param type_token: indicate token is access_token or refresh_token + :param exp_time: Set the duration of the JWT + :param fresh: Optional when token is access_token this param required + :param algorithm: algorithm allowed to encode the token + :param headers: valid dict for specifying additional headers in JWT header section + :param issuer: expected issuer in the JWT + :param audience: expected audience in the JWT + :param user_claims: Custom claims to include in this token. This data must be dictionary + + :return: Encoded token + """ + # Validation type data + if not isinstance(subject, (str, int)): + raise TypeError("subject must be a string or integer") + if not isinstance(fresh, bool): + raise TypeError("fresh must be a boolean") + if audience and not isinstance(audience, (str, list, tuple, set, frozenset)): + raise TypeError("audience must be a string or sequence") + if algorithm and not isinstance(algorithm, str): + raise TypeError("algorithm must be a string") + if user_claims and not isinstance(user_claims, dict): + raise TypeError("user_claims must be a dictionary") + + # Data section + reserved_claims = { + "sub": subject, + "iat": self._get_int_from_datetime(datetime.now(timezone.utc)), + "nbf": self._get_int_from_datetime(datetime.now(timezone.utc)), + "jti": self._get_jwt_identifier(), + } + + custom_claims = {"type": type_token} + + # for access_token only fresh needed + if type_token == "access": + custom_claims["fresh"] = fresh + # if cookie in token location and csrf protection enabled + if self.jwt_in_cookies and self._cookie_csrf_protect: + custom_claims["csrf"] = self._get_jwt_identifier() + + if exp_time: + reserved_claims["exp"] = exp_time + if issuer: + reserved_claims["iss"] = issuer + if audience: + reserved_claims["aud"] = audience + + algorithm = algorithm or self._algorithm + + secret_key = self._get_secret_key(algorithm, "encode") + + return jwt.encode( + {**reserved_claims, **custom_claims, **user_claims}, secret_key, algorithm=algorithm, headers=headers + ) + + def _has_token_in_denylist_callback(self) -> bool: + """ + Return True if token denylist callback set + """ + return self._token_in_denylist_callback is not None + + def _check_token_is_revoked(self, raw_token: Dict[str, Union[str, int, bool]]) -> None: + """ + Ensure that AUTHJWT_DENYLIST_ENABLED is true and callback regulated, and then + call function denylist callback with passing decode JWT, if true + raise exception Token has been revoked + """ + if not self._denylist_enabled: + return + + if not self._has_token_in_denylist_callback(): + raise RuntimeError( + "A token_in_denylist_callback must be provided via " + "the '@AuthJWT.token_in_denylist_loader' if " + "authjwt_denylist_enabled is 'True'" + ) + + if self._token_in_denylist_callback.__func__(raw_token): + raise RevokedTokenError(status_code=401, message="Token has been revoked") + + def _get_expired_time( + self, type_token: str, expires_time: Optional[Union[timedelta, int, bool]] = None + ) -> Union[None, int]: + """ + Dynamic token expired, if expires_time is False exp claim not created + + :param type_token: indicate token is access_token or refresh_token + :param expires_time: duration expired jwt + + :return: duration exp claim jwt + """ + if expires_time and not isinstance(expires_time, (timedelta, int, bool)): + raise TypeError("expires_time must be between timedelta, int, bool") + + if expires_time is not False: + if type_token == "access": + expires_time = expires_time or self._access_token_expires + if type_token == "refresh": + expires_time = expires_time or self._refresh_token_expires + + if expires_time is not False: + if isinstance(expires_time, bool): + if type_token == "access": + expires_time = self._access_token_expires + if type_token == "refresh": + expires_time = self._refresh_token_expires + if isinstance(expires_time, timedelta): + expires_time = int(expires_time.total_seconds()) + + return self._get_int_from_datetime(datetime.now(timezone.utc)) + expires_time + else: + return None + + def create_access_token( + self, + subject: Union[str, int], + fresh: Optional[bool] = False, + algorithm: Optional[str] = None, + headers: Optional[Dict] = None, + expires_time: Optional[Union[timedelta, int, bool]] = None, + audience: Optional[Union[str, Sequence[str]]] = None, + user_claims: Optional[Dict] = {}, + ) -> str: + """ + Create a access token with 15 minutes for expired time (default), + info for param and return check to function create token + + :return: hash token + """ + return self._create_token( + subject=subject, + type_token="access", + exp_time=self._get_expired_time("access", expires_time), + fresh=fresh, + algorithm=algorithm, + headers=headers, + audience=audience, + user_claims=user_claims, + issuer=self._encode_issuer, + ) + + def create_refresh_token( + self, + subject: Union[str, int], + algorithm: Optional[str] = None, + headers: Optional[Dict] = None, + expires_time: Optional[Union[timedelta, int, bool]] = None, + audience: Optional[Union[str, Sequence[str]]] = None, + user_claims: Optional[Dict] = {}, + ) -> str: + """ + Create a refresh token with 30 days for expired time (default), + info for param and return check to function create token + + :return: hash token + """ + return self._create_token( + subject=subject, + type_token="refresh", + exp_time=self._get_expired_time("refresh", expires_time), + algorithm=algorithm, + headers=headers, + audience=audience, + user_claims=user_claims, + ) + + def _get_csrf_token(self, encoded_token: str) -> str: + """ + Returns the CSRF double submit token from an encoded JWT. + + :param encoded_token: The encoded JWT + :return: The CSRF double submit token + """ + return self._verified_token(encoded_token)["csrf"] + + def set_access_cookies( + self, encoded_access_token: str, response: Optional[Response] = None, max_age: Optional[int] = None + ) -> None: + """ + Configures the response to set access token in a cookie. + this will also set the CSRF double submit values in a separate cookie + + :param encoded_access_token: The encoded access token to set in the cookies + :param response: The FastAPI response object to set the access cookies in + :param max_age: The max age of the cookie value should be the number of seconds (integer) + """ + if not self.jwt_in_cookies: + raise RuntimeWarning( + "set_access_cookies() called without 'authjwt_token_location' configured to use cookies" + ) + + if max_age and not isinstance(max_age, int): + raise TypeError("max_age must be a integer") + if response and not isinstance(response, Response): + raise TypeError("The response must be an object response FastAPI") + + response = response or self._response + + # Set the access JWT in the cookie + response.set_cookie( + self._access_cookie_key, + encoded_access_token, + max_age=max_age or self._cookie_max_age, + path=self._access_cookie_path, + domain=self._cookie_domain, + secure=self._cookie_secure, + httponly=True, + samesite=self._cookie_samesite, + ) + + # If enabled, set the csrf double submit access cookie + if self._cookie_csrf_protect: + response.set_cookie( + self._access_csrf_cookie_key, + self._get_csrf_token(encoded_access_token), + max_age=max_age or self._cookie_max_age, + path=self._access_csrf_cookie_path, + domain=self._cookie_domain, + secure=self._cookie_secure, + httponly=False, + samesite=self._cookie_samesite, + ) + + def set_refresh_cookies( + self, encoded_refresh_token: str, response: Optional[Response] = None, max_age: Optional[int] = None + ) -> None: + """ + Configures the response to set refresh token in a cookie. + this will also set the CSRF double submit values in a separate cookie + + :param encoded_refresh_token: The encoded refresh token to set in the cookies + :param response: The FastAPI response object to set the refresh cookies in + :param max_age: The max age of the cookie value should be the number of seconds (integer) + """ + if not self.jwt_in_cookies: + raise RuntimeWarning( + "set_refresh_cookies() called without 'authjwt_token_location' configured to use cookies" + ) + + if max_age and not isinstance(max_age, int): + raise TypeError("max_age must be a integer") + if response and not isinstance(response, Response): + raise TypeError("The response must be an object response FastAPI") + + response = response or self._response + + # Set the refresh JWT in the cookie + response.set_cookie( + self._refresh_cookie_key, + encoded_refresh_token, + max_age=max_age or self._cookie_max_age, + path=self._refresh_cookie_path, + domain=self._cookie_domain, + secure=self._cookie_secure, + httponly=True, + samesite=self._cookie_samesite, + ) + + # If enabled, set the csrf double submit refresh cookie + if self._cookie_csrf_protect: + response.set_cookie( + self._refresh_csrf_cookie_key, + self._get_csrf_token(encoded_refresh_token), + max_age=max_age or self._cookie_max_age, + path=self._refresh_csrf_cookie_path, + domain=self._cookie_domain, + secure=self._cookie_secure, + httponly=False, + samesite=self._cookie_samesite, + ) + + def unset_jwt_cookies(self, response: Optional[Response] = None) -> None: + """ + Unset (delete) all jwt stored in a cookie + + :param response: The FastAPI response object to delete the JWT cookies in. + """ + self.unset_access_cookies(response) + self.unset_refresh_cookies(response) + + def unset_access_cookies(self, response: Optional[Response] = None) -> None: + """ + Remove access token and access CSRF double submit from the response cookies + + :param response: The FastAPI response object to delete the access cookies in. + """ + if not self.jwt_in_cookies: + raise RuntimeWarning( + "unset_access_cookies() called without 'authjwt_token_location' configured to use cookies" + ) + + if response and not isinstance(response, Response): + raise TypeError("The response must be an object response FastAPI") + + response = response or self._response + + response.delete_cookie(self._access_cookie_key, path=self._access_cookie_path, domain=self._cookie_domain) + + if self._cookie_csrf_protect: + response.delete_cookie( + self._access_csrf_cookie_key, path=self._access_csrf_cookie_path, domain=self._cookie_domain + ) + + def unset_refresh_cookies(self, response: Optional[Response] = None) -> None: + """ + Remove refresh token and refresh CSRF double submit from the response cookies + + :param response: The FastAPI response object to delete the refresh cookies in. + """ + if not self.jwt_in_cookies: + raise RuntimeWarning( + "unset_refresh_cookies() called without 'authjwt_token_location' configured to use cookies" + ) + + if response and not isinstance(response, Response): + raise TypeError("The response must be an object response FastAPI") + + response = response or self._response + + response.delete_cookie(self._refresh_cookie_key, path=self._refresh_cookie_path, domain=self._cookie_domain) + + if self._cookie_csrf_protect: + response.delete_cookie( + self._refresh_csrf_cookie_key, path=self._refresh_csrf_cookie_path, domain=self._cookie_domain + ) + + def _verify_and_get_jwt_optional_in_cookies( + self, + request: Union[Request, WebSocket], + csrf_token: Optional[str] = None, + ) -> "AuthJWT": + """ + Optionally check if cookies have a valid access token. if an access token present in + cookies, self._token will set. raises exception error when an access token is invalid + or doesn't match with CSRF token double submit + + :param request: for identity get cookies from HTTP or WebSocket + :param csrf_token: the CSRF double submit token + """ + if not isinstance(request, (Request, WebSocket)): + raise TypeError("request must be an instance of 'Request' or 'WebSocket'") + + cookie_key = self._access_cookie_key + cookie = request.cookies.get(cookie_key) + if not isinstance(request, WebSocket): + csrf_token = request.headers.get(self._access_csrf_header_name) + + if cookie and self._cookie_csrf_protect and not csrf_token: + if isinstance(request, WebSocket) or request.method in self._csrf_methods: + raise CSRFError(status_code=401, message="Missing CSRF Token") + + # set token from cookie and verify jwt + self._token = cookie + self._verify_jwt_optional_in_request(self._token) + + decoded_token = self.get_raw_jwt() + + if decoded_token and self._cookie_csrf_protect and csrf_token: + if isinstance(request, WebSocket) or request.method in self._csrf_methods: + if "csrf" not in decoded_token: + raise JWTDecodeError(status_code=422, message="Missing claim: csrf") + if not hmac.compare_digest(csrf_token, decoded_token["csrf"]): + raise CSRFError(status_code=401, message="CSRF double submit tokens do not match") + + def _verify_and_get_jwt_in_cookies( + self, + type_token: str, + request: Union[Request, WebSocket], + csrf_token: Optional[str] = None, + fresh: Optional[bool] = False, + ) -> "AuthJWT": + """ + Check if cookies have a valid access or refresh token. if an token present in + cookies, self._token will set. raises exception error when an access or refresh token + is invalid or doesn't match with CSRF token double submit + + :param type_token: indicate token is access or refresh token + :param request: for identity get cookies from HTTP or WebSocket + :param csrf_token: the CSRF double submit token + :param fresh: check freshness token if True + """ + if type_token not in ["access", "refresh"]: + raise ValueError("type_token must be between 'access' or 'refresh'") + if not isinstance(request, (Request, WebSocket)): + raise TypeError("request must be an instance of 'Request' or 'WebSocket'") + + if type_token == "access": + cookie_key = self._access_cookie_key + cookie = request.cookies.get(cookie_key) + if not isinstance(request, WebSocket): + csrf_token = request.headers.get(self._access_csrf_header_name) + if type_token == "refresh": + cookie_key = self._refresh_cookie_key + cookie = request.cookies.get(cookie_key) + if not isinstance(request, WebSocket): + csrf_token = request.headers.get(self._refresh_csrf_header_name) + + if not cookie: + raise MissingTokenError(status_code=401, message="Missing cookie {}".format(cookie_key)) + + if self._cookie_csrf_protect and not csrf_token: + if isinstance(request, WebSocket) or request.method in self._csrf_methods: + raise CSRFError(status_code=401, message="Missing CSRF Token") + + # set token from cookie and verify jwt + self._token = cookie + self._verify_jwt_in_request(self._token, type_token, "cookies", fresh) + + decoded_token = self.get_raw_jwt() + + if self._cookie_csrf_protect and csrf_token: + if isinstance(request, WebSocket) or request.method in self._csrf_methods: + if "csrf" not in decoded_token: + raise JWTDecodeError(status_code=422, message="Missing claim: csrf") + if not hmac.compare_digest(csrf_token, decoded_token["csrf"]): + raise CSRFError(status_code=401, message="CSRF double submit tokens do not match") + + def _verify_jwt_optional_in_request(self, token: str) -> None: + """ + Optionally check if this request has a valid access token + + :param token: The encoded JWT + """ + if token: + self._verifying_token(token) + + if token and self.get_raw_jwt(token)["type"] != "access": + raise AccessTokenRequired(status_code=422, message="Only access tokens are allowed") + + def _verify_jwt_in_request( + self, token: str, type_token: str, token_from: str, fresh: Optional[bool] = False + ) -> None: + """ + Ensure that the requester has a valid token. this also check the freshness of the access token + + :param token: The encoded JWT + :param type_token: indicate token is access or refresh token + :param token_from: indicate token from headers cookies, websocket + :param fresh: check freshness token if True + """ + if type_token not in ["access", "refresh"]: + raise ValueError("type_token must be between 'access' or 'refresh'") + if token_from not in ["headers", "cookies", "websocket"]: + raise ValueError("token_from must be between 'headers', 'cookies', 'websocket'") + + if not token: + if token_from == "headers": + raise MissingTokenError(status_code=401, message="Missing {} Header".format(self._header_name)) + if token_from == "websocket": + raise MissingTokenError( + status_code=1008, message="Missing {} token from Query or Path".format(type_token) + ) + + # verify jwt + issuer = self._decode_issuer if type_token == "access" else None + self._verifying_token(token, issuer) + + if self.get_raw_jwt(token)["type"] != type_token: + msg = "Only {} tokens are allowed".format(type_token) + if type_token == "access": + raise AccessTokenRequired(status_code=422, message=msg) + if type_token == "refresh": + raise RefreshTokenRequired(status_code=422, message=msg) + + if fresh and not self.get_raw_jwt(token)["fresh"]: + raise FreshTokenRequired(status_code=401, message="Fresh token required") + + def _verifying_token(self, encoded_token: str, issuer: Optional[str] = None) -> None: + """ + Verified token and check if token is revoked + + :param encoded_token: token hash + :param issuer: expected issuer in the JWT + """ + raw_token = self._verified_token(encoded_token, issuer) + if raw_token["type"] in self._denylist_token_checks: + self._check_token_is_revoked(raw_token) + + def _verified_token(self, encoded_token: str, issuer: Optional[str] = None) -> Dict[str, Union[str, int, bool]]: + """ + Verified token and catch all error from jwt package and return decode token + + :param encoded_token: token hash + :param issuer: expected issuer in the JWT + + :return: raw data from the hash token in the form of a dictionary + """ + algorithms = self._decode_algorithms or [self._algorithm] + + try: + unverified_headers = self.get_unverified_jwt_headers(encoded_token) + except Exception as err: + raise InvalidHeaderError(status_code=422, message=str(err)) + + try: + secret_key = self._get_secret_key(unverified_headers["alg"], "decode") + except Exception: + raise + + try: + return jwt.decode( + encoded_token, + secret_key, + issuer=issuer, + audience=self._decode_audience, + leeway=self._decode_leeway, + algorithms=algorithms, + ) + except Exception as err: + raise JWTDecodeError(status_code=422, message=str(err)) + + def jwt_required( + self, + auth_from: str = "request", + token: Optional[str] = None, + websocket: Optional[WebSocket] = None, + csrf_token: Optional[str] = None, + ) -> None: + """ + Only access token can access this function + + :param auth_from: for identity get token from HTTP or WebSocket + :param token: the encoded JWT, it's required if the protected endpoint use WebSocket to + authorization and get token from Query Url or Path + :param websocket: an instance of WebSocket, it's required if protected endpoint use a cookie to authorization + :param csrf_token: the CSRF double submit token. since WebSocket cannot add specifying additional headers + its must be passing csrf_token manually and can achieve by Query Url or Path + """ + if auth_from == "websocket": + if websocket: + self._verify_and_get_jwt_in_cookies("access", websocket, csrf_token) + else: + self._verify_jwt_in_request(token, "access", "websocket") + + if auth_from == "request": + if len(self._token_location) == 2: + if self._token and self.jwt_in_headers: + self._verify_jwt_in_request(self._token, "access", "headers") + if not self._token and self.jwt_in_cookies: + self._verify_and_get_jwt_in_cookies("access", self._request) + else: + if self.jwt_in_headers: + self._verify_jwt_in_request(self._token, "access", "headers") + if self.jwt_in_cookies: + self._verify_and_get_jwt_in_cookies("access", self._request) + + def jwt_optional( + self, + auth_from: str = "request", + token: Optional[str] = None, + websocket: Optional[WebSocket] = None, + csrf_token: Optional[str] = None, + ) -> None: + """ + If an access token in present in the request you can get data from get_raw_jwt() or get_jwt_subject(), + If no access token is present in the request, this endpoint will still be called, but + get_raw_jwt() or get_jwt_subject() will return None + + :param auth_from: for identity get token from HTTP or WebSocket + :param token: the encoded JWT, it's required if the protected endpoint use WebSocket to + authorization and get token from Query Url or Path + :param websocket: an instance of WebSocket, it's required if protected endpoint use a cookie to authorization + :param csrf_token: the CSRF double submit token. since WebSocket cannot add specifying additional headers + its must be passing csrf_token manually and can achieve by Query Url or Path + """ + if auth_from == "websocket": + if websocket: + self._verify_and_get_jwt_optional_in_cookies(websocket, csrf_token) + else: + self._verify_jwt_optional_in_request(token) + + if auth_from == "request": + if len(self._token_location) == 2: + if self._token and self.jwt_in_headers: + self._verify_jwt_optional_in_request(self._token) + if not self._token and self.jwt_in_cookies: + self._verify_and_get_jwt_optional_in_cookies(self._request) + else: + if self.jwt_in_headers: + self._verify_jwt_optional_in_request(self._token) + if self.jwt_in_cookies: + self._verify_and_get_jwt_optional_in_cookies(self._request) + + def jwt_refresh_token_required( + self, + auth_from: str = "request", + token: Optional[str] = None, + websocket: Optional[WebSocket] = None, + csrf_token: Optional[str] = None, + ) -> None: + """ + This function will ensure that the requester has a valid refresh token + + :param auth_from: for identity get token from HTTP or WebSocket + :param token: the encoded JWT, it's required if the protected endpoint use WebSocket to + authorization and get token from Query Url or Path + :param websocket: an instance of WebSocket, it's required if protected endpoint use a cookie to authorization + :param csrf_token: the CSRF double submit token. since WebSocket cannot add specifying additional headers + its must be passing csrf_token manually and can achieve by Query Url or Path + """ + if auth_from == "websocket": + if websocket: + self._verify_and_get_jwt_in_cookies("refresh", websocket, csrf_token) + else: + self._verify_jwt_in_request(token, "refresh", "websocket") + + if auth_from == "request": + if len(self._token_location) == 2: + if self._token and self.jwt_in_headers: + self._verify_jwt_in_request(self._token, "refresh", "headers") + if not self._token and self.jwt_in_cookies: + self._verify_and_get_jwt_in_cookies("refresh", self._request) + else: + if self.jwt_in_headers: + self._verify_jwt_in_request(self._token, "refresh", "headers") + if self.jwt_in_cookies: + self._verify_and_get_jwt_in_cookies("refresh", self._request) + + def fresh_jwt_required( + self, + auth_from: str = "request", + token: Optional[str] = None, + websocket: Optional[WebSocket] = None, + csrf_token: Optional[str] = None, + ) -> None: + """ + This function will ensure that the requester has a valid access token and fresh token + + :param auth_from: for identity get token from HTTP or WebSocket + :param token: the encoded JWT, it's required if the protected endpoint use WebSocket to + authorization and get token from Query Url or Path + :param websocket: an instance of WebSocket, it's required if protected endpoint use a cookie to authorization + :param csrf_token: the CSRF double submit token. since WebSocket cannot add specifying additional headers + its must be passing csrf_token manually and can achieve by Query Url or Path + """ + if auth_from == "websocket": + if websocket: + self._verify_and_get_jwt_in_cookies("access", websocket, csrf_token, True) + else: + self._verify_jwt_in_request(token, "access", "websocket", True) + + if auth_from == "request": + if len(self._token_location) == 2: + if self._token and self.jwt_in_headers: + self._verify_jwt_in_request(self._token, "access", "headers", True) + if not self._token and self.jwt_in_cookies: + self._verify_and_get_jwt_in_cookies("access", self._request, fresh=True) + else: + if self.jwt_in_headers: + self._verify_jwt_in_request(self._token, "access", "headers", True) + if self.jwt_in_cookies: + self._verify_and_get_jwt_in_cookies("access", self._request, fresh=True) + + def get_raw_jwt(self, encoded_token: Optional[str] = None) -> Optional[Dict[str, Union[str, int, bool]]]: + """ + this will return the python dictionary which has all of the claims of the JWT that is accessing the endpoint. + If no JWT is currently present, return None instead + + :param encoded_token: The encoded JWT from parameter + :return: claims of JWT + """ + token = encoded_token or self._token + + if token: + return self._verified_token(token) + return None + + def get_jti(self, encoded_token: str) -> str: + """ + Returns the JTI (unique identifier) of an encoded JWT + + :param encoded_token: The encoded JWT from parameter + :return: string of JTI + """ + return self._verified_token(encoded_token)["jti"] + + def get_jwt_subject(self) -> Optional[Union[str, int]]: + """ + this will return the subject of the JWT that is accessing this endpoint. + If no JWT is present, `None` is returned instead. + + :return: sub of JWT + """ + if self._token: + return self._verified_token(self._token)["sub"] + return None + + def get_unverified_jwt_headers(self, encoded_token: Optional[str] = None) -> dict: + """ + Returns the Headers of an encoded JWT without verifying the actual signature of JWT + + :param encoded_token: The encoded JWT to get the Header from + :return: JWT header parameters as a dictionary + """ + encoded_token = encoded_token or self._token + + return jwt.get_unverified_header(encoded_token) diff --git a/antarest/fastapi_jwt_auth/config.py b/antarest/fastapi_jwt_auth/config.py new file mode 100644 index 0000000000..f04ace9cd7 --- /dev/null +++ b/antarest/fastapi_jwt_auth/config.py @@ -0,0 +1,90 @@ +import typing as t +from datetime import timedelta +from typing import List, Optional, Sequence, Union + +from pydantic import BaseModel, StrictBool, StrictInt, StrictStr, ValidationError, field_validator, model_validator + + +class LoadConfig(BaseModel): + authjwt_token_location: Optional[t.Set[StrictStr]] = {"headers"} + authjwt_secret_key: Optional[StrictStr] = None + authjwt_public_key: Optional[StrictStr] = None + authjwt_private_key: Optional[StrictStr] = None + authjwt_algorithm: Optional[StrictStr] = "HS256" + authjwt_decode_algorithms: Optional[List[StrictStr]] = None + authjwt_decode_leeway: Optional[Union[StrictInt, timedelta]] = 0 + authjwt_encode_issuer: Optional[StrictStr] = None + authjwt_decode_issuer: Optional[StrictStr] = None + authjwt_decode_audience: Optional[Union[StrictStr, Sequence[StrictStr]]] = None + authjwt_denylist_enabled: Optional[StrictBool] = False + authjwt_denylist_token_checks: Optional[t.Set[StrictStr]] = {"access", "refresh"} + authjwt_header_name: Optional[StrictStr] = "Authorization" + authjwt_header_type: Optional[StrictStr] = "Bearer" + authjwt_access_token_expires: Optional[Union[StrictBool, StrictInt, timedelta]] = timedelta(minutes=15) + authjwt_refresh_token_expires: Optional[Union[StrictBool, StrictInt, timedelta]] = timedelta(days=30) + # option for create cookies + authjwt_access_cookie_key: Optional[StrictStr] = "access_token_cookie" + authjwt_refresh_cookie_key: Optional[StrictStr] = "refresh_token_cookie" + authjwt_access_cookie_path: Optional[StrictStr] = "/" + authjwt_refresh_cookie_path: Optional[StrictStr] = "/" + authjwt_cookie_max_age: Optional[StrictInt] = None + authjwt_cookie_domain: Optional[StrictStr] = None + authjwt_cookie_secure: Optional[StrictBool] = False + authjwt_cookie_samesite: Optional[StrictStr] = None + # option for double submit csrf protection + authjwt_cookie_csrf_protect: Optional[StrictBool] = True + authjwt_access_csrf_cookie_key: Optional[StrictStr] = "csrf_access_token" + authjwt_refresh_csrf_cookie_key: Optional[StrictStr] = "csrf_refresh_token" + authjwt_access_csrf_cookie_path: Optional[StrictStr] = "/" + authjwt_refresh_csrf_cookie_path: Optional[StrictStr] = "/" + authjwt_access_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token" + authjwt_refresh_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token" + authjwt_csrf_methods: Optional[t.Set[StrictStr]] = {"POST", "PUT", "PATCH", "DELETE"} + + @field_validator("authjwt_access_token_expires") + def validate_access_token_expires( + cls, v: Optional[Union[StrictBool, StrictInt, timedelta]] + ) -> Optional[Union[StrictBool, StrictInt, timedelta]]: + if v is True: + raise ValueError("The 'authjwt_access_token_expires' only accept value False (bool)") + return v + + @field_validator("authjwt_refresh_token_expires") + def validate_refresh_token_expires( + cls, v: Optional[Union[StrictBool, StrictInt, timedelta]] + ) -> Optional[Union[StrictBool, StrictInt, timedelta]]: + if v is True: + raise ValueError("The 'authjwt_refresh_token_expires' only accept value False (bool)") + return v + + @field_validator("authjwt_cookie_samesite") + def validate_cookie_samesite(cls, v: Optional[StrictStr]) -> Optional[StrictStr]: + if v not in ["strict", "lax", "none"]: + raise ValueError("The 'authjwt_cookie_samesite' must be between 'strict', 'lax', 'none'") + return v + + @model_validator(mode="before") + def check_type_validity(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + for _ in values.get("authjwt_csrf_methods", []): + if _.upper() not in ["POST", "PUT", "PATCH", "DELETE"]: + raise ValidationError( + f"The 'authjwt_csrf_methods' must be between http request methods and it's {_.upper()}" + ) + + for _ in values.get("authjwt_cookie_samesite", []): + if _ not in ["strict", "lax", "none"]: + raise ValidationError( + f"The 'authjwt_cookie_samesite' must be between 'strict', 'lax', 'none' and it's {_}" + ) + + for _ in values.get("authjwt_token_location", []): + if _ not in ["headers", "cookies"]: + raise ValidationError( + f"The 'authjwt_token_location' must be between 'headers' or 'cookies' and it's {_}" + ) + + return values + + class Config: + str_min_length = 1 + str_strip_whitespace = True diff --git a/antarest/fastapi_jwt_auth/exceptions.py b/antarest/fastapi_jwt_auth/exceptions.py new file mode 100644 index 0000000000..1057571c0a --- /dev/null +++ b/antarest/fastapi_jwt_auth/exceptions.py @@ -0,0 +1,89 @@ +class AuthJWTException(Exception): + """ + Base except which all fastapi_jwt_auth errors extend + """ + + pass + + +class InvalidHeaderError(AuthJWTException): + """ + An error getting jwt in header or jwt header information from a request + """ + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + + +class JWTDecodeError(AuthJWTException): + """ + An error decoding a JWT + """ + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + + +class CSRFError(AuthJWTException): + """ + An error with CSRF protection + """ + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + + +class MissingTokenError(AuthJWTException): + """ + Error raised when token not found + """ + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + + +class RevokedTokenError(AuthJWTException): + """ + Error raised when a revoked token attempt to access a protected endpoint + """ + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + + +class AccessTokenRequired(AuthJWTException): + """ + Error raised when a valid, non-access JWT attempt to access an endpoint + protected by jwt_required, jwt_optional, fresh_jwt_required + """ + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + + +class RefreshTokenRequired(AuthJWTException): + """ + Error raised when a valid, non-refresh JWT attempt to access an endpoint + protected by jwt_refresh_token_required + """ + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message + + +class FreshTokenRequired(AuthJWTException): + """ + Error raised when a valid, non-fresh JWT attempt to access an endpoint + protected by fresh_jwt_required + """ + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + self.message = message diff --git a/antarest/gui.py b/antarest/gui.py index 399b0a2518..3d1d2d40e3 100644 --- a/antarest/gui.py +++ b/antarest/gui.py @@ -6,16 +6,8 @@ from multiprocessing import Process from pathlib import Path -try: - # `httpx` is a modern alternative to the `requests` library - import httpx as requests - from httpx import ConnectError as ConnectionError -except ImportError: - # noinspection PyUnresolvedReferences, PyPackageRequirements - import requests - from requests import ConnectionError - -import uvicorn # type: ignore +import httpx +import uvicorn from PyQt5.QtGui import QIcon from PyQt5.QtWidgets import QAction, QApplication, QMenu, QSystemTrayIcon @@ -90,8 +82,8 @@ def main() -> None: ) server.start() for _ in range(30, 0, -1): - with contextlib.suppress(ConnectionError): - res = requests.get("http://localhost:8080") + with contextlib.suppress(httpx.ConnectError): + res = httpx.get("http://localhost:8080") if res.status_code == 200: break time.sleep(1) diff --git a/antarest/launcher/adapters/abstractlauncher.py b/antarest/launcher/adapters/abstractlauncher.py index 121f952899..78e4f0dc7b 100644 --- a/antarest/launcher/adapters/abstractlauncher.py +++ b/antarest/launcher/adapters/abstractlauncher.py @@ -88,7 +88,7 @@ def update_log(log_line: str) -> None: ) launch_progress_json = self.cache.get(id=f"Launch_Progress_{job_id}") or {} - launch_progress_dto = LaunchProgressDTO.parse_obj(launch_progress_json) + launch_progress_dto = LaunchProgressDTO.model_validate(launch_progress_json) if launch_progress_dto.parse_log_lines(log_line.splitlines()): self.event_bus.push( Event( @@ -102,6 +102,6 @@ def update_log(log_line: str) -> None: channel=EventChannelDirectory.JOB_STATUS + job_id, ) ) - self.cache.put(f"Launch_Progress_{job_id}", launch_progress_dto.dict()) + self.cache.put(f"Launch_Progress_{job_id}", launch_progress_dto.model_dump()) return update_log diff --git a/antarest/launcher/model.py b/antarest/launcher/model.py index 3bd3427a07..35400a05b0 100644 --- a/antarest/launcher/model.py +++ b/antarest/launcher/model.py @@ -8,12 +8,12 @@ from sqlalchemy.orm import relationship # type: ignore from antarest.core.persistence import Base -from antarest.core.utils.string import to_camel_case from antarest.login.model import Identity, UserInfo +from antarest.study.business.all_optional_meta import camel_case_model class XpansionParametersDTO(BaseModel): - output_id: t.Optional[str] + output_id: t.Optional[str] = None sensitivity_mode: bool = False enabled: bool = True @@ -42,7 +42,7 @@ def from_launcher_params(cls, params: t.Optional[str]) -> "LauncherParametersDTO """ if params is None: return cls() - return cls.parse_obj(json.loads(params)) + return cls.model_validate(json.loads(params)) class LogType(str, enum.Enum): @@ -111,24 +111,6 @@ class JobResultDTO(BaseModel): solver_stats: t.Optional[str] owner: t.Optional[UserInfo] - class Config: - @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: - schema["example"] = JobResultDTO( - id="b2a9f6a7-7f8f-4f7a-9a8b-1f9b4c5d6e7f", - study_id="b2a9f6a7-7f8f-4f7a-9a8b-1f9b4c5d6e7f", - launcher="slurm", - launcher_params='{"nb_cpu": 4, "time_limit": 3600}', - status=JobStatus.SUCCESS, - creation_date="2023-11-25 12:00:00", - completion_date="2023-11-25 12:27:31", - msg="Study successfully executed", - output_id="20231125-1227eco", - exit_code=0, - solver_stats="time: 1651s, call_count: 1, optimization_issues: []", - owner=UserInfo(id=0o007, name="James BOND"), - ) - class JobLog(Base): # type: ignore __tablename__ = "launcherjoblog" @@ -228,13 +210,8 @@ class LauncherEnginesDTO(BaseModel): engines: t.List[str] -class LauncherLoadDTO( - BaseModel, - extra="forbid", - validate_assignment=True, - allow_population_by_field_name=True, - alias_generator=to_camel_case, -): +@camel_case_model +class LauncherLoadDTO(BaseModel, extra="forbid", validate_assignment=True, populate_by_name=True): """ DTO representing the load of the SLURM cluster or local machine. diff --git a/antarest/launcher/service.py b/antarest/launcher/service.py index d3a439e17c..d3780de19a 100644 --- a/antarest/launcher/service.py +++ b/antarest/launcher/service.py @@ -10,7 +10,7 @@ from fastapi import HTTPException -from antarest.core.config import Config, NbCoresConfig +from antarest.core.config import Config, Launcher, NbCoresConfig from antarest.core.exceptions import StudyNotFoundError from antarest.core.filetransfer.model import FileDownloadTaskDTO from antarest.core.filetransfer.service import FileTransferManager @@ -103,7 +103,7 @@ def _init_extensions(self) -> Dict[str, ILauncherExtension]: def get_launchers(self) -> List[str]: return list(self.launchers.keys()) - def get_nb_cores(self, launcher: str) -> NbCoresConfig: + def get_nb_cores(self, launcher: Launcher) -> NbCoresConfig: """ Retrieve the configuration of the launcher's nb of cores. @@ -112,9 +112,6 @@ def get_nb_cores(self, launcher: str) -> NbCoresConfig: Returns: Number of cores of the launcher - - Raises: - InvalidConfigurationError: if the launcher configuration is not available """ return self.config.launcher.get_nb_cores(launcher) @@ -175,7 +172,7 @@ def update( self.event_bus.push( Event( type=EventType.STUDY_JOB_COMPLETED if final_status else EventType.STUDY_JOB_STATUS_UPDATE, - payload=job_result.to_dto().dict(), + payload=job_result.to_dto().model_dump(), permissions=PermissionInfo(public_mode=PublicMode.READ), channel=EventChannelDirectory.JOB_STATUS + job_result.id, ) @@ -236,7 +233,7 @@ def run_study( study_id=study_uuid, job_status=JobStatus.PENDING, launcher=launcher, - launcher_params=launcher_parameters.json() if launcher_parameters else None, + launcher_params=launcher_parameters.model_dump_json() if launcher_parameters else None, owner_id=(owner_id or None), ) self.job_result_repository.save(job_status) @@ -252,7 +249,7 @@ def run_study( self.event_bus.push( Event( type=EventType.STUDY_JOB_STARTED, - payload=job_status.to_dto().dict(), + payload=job_status.to_dto().model_dump(), permissions=PermissionInfo.from_study(study_info), ) ) @@ -293,7 +290,7 @@ def kill_job(self, job_id: str, params: RequestParameters) -> JobResult: self.event_bus.push( Event( type=EventType.STUDY_JOB_CANCELLED, - payload=job_status.to_dto().dict(), + payload=job_status.to_dto().model_dump(), permissions=PermissionInfo.from_study(study), channel=EventChannelDirectory.JOB_STATUS + job_result.id, ) @@ -710,5 +707,7 @@ def get_launch_progress(self, job_id: str, params: RequestParameters) -> float: if launcher is None: raise ValueError(f"Job {job_id} has no launcher") - launch_progress_json = self.launchers[launcher].cache.get(id=f"Launch_Progress_{job_id}") or {"progress": 0} + launch_progress_json: Dict[str, float] = self.launchers[launcher].cache.get(id=f"Launch_Progress_{job_id}") or { + "progress": 0 + } return launch_progress_json.get("progress", 0) diff --git a/antarest/launcher/ssh_client.py b/antarest/launcher/ssh_client.py index e52cb0072c..450fd6a92a 100644 --- a/antarest/launcher/ssh_client.py +++ b/antarest/launcher/ssh_client.py @@ -31,7 +31,7 @@ class SlurmError(Exception): def execute_command(ssh_config: SSHConfigDTO, args: List[str]) -> Any: command = " ".join(args) try: - with ssh_client(ssh_config) as client: # type: ignore + with ssh_client(ssh_config) as client: # type: paramiko.SSHClient _, stdout, stderr = client.exec_command(command, timeout=10) output = stdout.read().decode("utf-8").strip() error = stderr.read().decode("utf-8").strip() diff --git a/antarest/launcher/ssh_config.py b/antarest/launcher/ssh_config.py index 1fa4a4393c..3661751921 100644 --- a/antarest/launcher/ssh_config.py +++ b/antarest/launcher/ssh_config.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Optional import paramiko -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, model_validator class SSHConfigDTO(BaseModel): @@ -14,7 +14,7 @@ class SSHConfigDTO(BaseModel): key_password: Optional[str] = "" password: Optional[str] = "" - @root_validator() + @model_validator(mode="before") def validate_connection_information(cls, values: Dict[str, Any]) -> Dict[str, Any]: if "private_key_file" not in values and "password" not in values: raise paramiko.AuthenticationException("SSH config needs at least a private key or a password") diff --git a/antarest/launcher/web.py b/antarest/launcher/web.py index 9635c664f1..337dd3aba0 100644 --- a/antarest/launcher/web.py +++ b/antarest/launcher/web.py @@ -3,10 +3,10 @@ from typing import Any, Dict, List, Optional from uuid import UUID -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends from fastapi.exceptions import HTTPException -from antarest.core.config import Config, InvalidConfigurationError +from antarest.core.config import Config, InvalidConfigurationError, Launcher from antarest.core.filetransfer.model import FileDownloadTaskDTO from antarest.core.jwt import JWTUser from antarest.core.requests import RequestParameters @@ -41,25 +41,6 @@ def __init__(self, solver: str) -> None: ) -LauncherQuery = Query( - "default", - examples={ - "Default launcher": { - "description": "Default solver (auto-detected)", - "value": "default", - }, - "SLURM launcher": { - "description": "SLURM solver configuration", - "value": "slurm", - }, - "Local launcher": { - "description": "Local solver configuration", - "value": "local", - }, - }, -) - - def create_launcher_api(service: LauncherService, config: Config) -> APIRouter: bp = APIRouter(prefix="/v1/launcher") @@ -233,7 +214,7 @@ def get_load() -> LauncherLoadDTO: summary="Get list of supported solver versions", response_model=List[str], ) - def get_solver_versions(solver: str = LauncherQuery) -> List[str]: + def get_solver_versions(solver: Launcher = Launcher.DEFAULT) -> List[str]: """ Get list of supported solver versions defined in the configuration. @@ -241,8 +222,6 @@ def get_solver_versions(solver: str = LauncherQuery) -> List[str]: - `solver`: name of the configuration to read: "default", "slurm" or "local". """ logger.info(f"Fetching the list of solver versions for the '{solver}' configuration") - if solver not in {"default", "slurm", "local"}: - raise UnknownSolverConfig(solver) return service.get_solver_versions(solver) # noinspection SpellCheckingInspection @@ -252,7 +231,7 @@ def get_solver_versions(solver: str = LauncherQuery) -> List[str]: summary="Retrieving Min, Default, and Max Core Count", response_model=Dict[str, int], ) - def get_nb_cores(launcher: str = LauncherQuery) -> Dict[str, int]: + def get_nb_cores(launcher: Launcher = Launcher.DEFAULT) -> Dict[str, int]: """ Retrieve the numer of cores of the launcher. @@ -277,7 +256,7 @@ def get_nb_cores(launcher: str = LauncherQuery) -> Dict[str, int]: tags=[APITag.launcher], summary="Retrieve the time limit for a job (in hours)", ) - def get_time_limit(launcher: str = LauncherQuery) -> Dict[str, int]: + def get_time_limit(launcher: Launcher = Launcher.DEFAULT) -> Dict[str, int]: """ Retrieve the time limit for a job (in hours) of the given launcher: "local" or "slurm". diff --git a/antarest/login/auth.py b/antarest/login/auth.py index cc7638b728..9c08f0fe99 100644 --- a/antarest/login/auth.py +++ b/antarest/login/auth.py @@ -4,13 +4,13 @@ from typing import Any, Callable, Coroutine, Dict, Optional, Tuple, Union from fastapi import Depends -from fastapi_jwt_auth import AuthJWT # type: ignore from pydantic import BaseModel from ratelimit.types import Scope # type: ignore from starlette.requests import Request from antarest.core.config import Config from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTUser +from antarest.fastapi_jwt_auth import AuthJWT logger = logging.getLogger(__name__) @@ -54,14 +54,14 @@ def get_current_user(self, auth_jwt: AuthJWT = Depends()) -> JWTUser: auth_jwt.jwt_required() - user = JWTUser.parse_obj(json.loads(auth_jwt.get_jwt_subject())) + user = JWTUser.model_validate(json.loads(auth_jwt.get_jwt_subject())) return user @staticmethod def get_user_from_token(token: str, jwt_manager: AuthJWT) -> Optional[JWTUser]: try: token_data = jwt_manager._verified_token(token) - return JWTUser.parse_obj(json.loads(token_data["sub"])) + return JWTUser.model_validate(json.loads(token_data["sub"])) except Exception as e: logger.debug("Failed to retrieve user from token", exc_info=e) return None diff --git a/antarest/login/ldap.py b/antarest/login/ldap.py index d558238288..9524714da1 100644 --- a/antarest/login/ldap.py +++ b/antarest/login/ldap.py @@ -2,12 +2,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional -try: - # `httpx` is a modern alternative to the `requests` library - import httpx as requests -except ImportError: - # noinspection PyUnresolvedReferences, PyPackageRequirements - import requests +import httpx from antarest.core.config import Config from antarest.core.model import JSON @@ -98,7 +93,7 @@ def _fetch(self, name: str, password: str) -> Optional[ExternalUser]: auth = AuthDTO(user=name, password=password) try: - res = requests.post(url=f"{self.url}/auth", json=auth.to_json()) + res = httpx.post(url=f"{self.url}/auth", json=auth.to_json()) except Exception as e: logger.warning( "Failed to retrieve user from external auth service", diff --git a/antarest/login/main.py b/antarest/login/main.py index d87a082abd..c7936ee7a2 100644 --- a/antarest/login/main.py +++ b/antarest/login/main.py @@ -3,14 +3,14 @@ from typing import Any, Optional from fastapi import FastAPI -from fastapi_jwt_auth import AuthJWT # type: ignore -from fastapi_jwt_auth.exceptions import AuthJWTException # type: ignore from starlette.requests import Request from starlette.responses import JSONResponse from antarest.core.config import Config from antarest.core.interfaces.eventbus import DummyEventBusService, IEventBus from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.fastapi_jwt_auth import AuthJWT +from antarest.fastapi_jwt_auth.exceptions import AuthJWTException from antarest.login.ldap import LdapService from antarest.login.repository import BotRepository, GroupRepository, RoleRepository, UserLdapRepository, UserRepository from antarest.login.service import LoginService diff --git a/antarest/login/service.py b/antarest/login/service.py index 55af0ab12c..01e66fdd7a 100644 --- a/antarest/login/service.py +++ b/antarest/login/service.py @@ -473,7 +473,7 @@ def authenticate(self, name: str, pwd: str) -> Optional[JWTUser]: """ intern: Optional[User] = self.users.get_by_name(name) - if intern and intern.password.check(pwd): # type: ignore + if intern and intern.password.check(pwd): logger.info("successful login from intern user %s", name) return self.get_jwt(intern.id) diff --git a/antarest/login/web.py b/antarest/login/web.py index 801325df7c..660f2753db 100644 --- a/antarest/login/web.py +++ b/antarest/login/web.py @@ -4,7 +4,6 @@ from typing import Any, List, Optional, Union from fastapi import APIRouter, Depends, HTTPException -from fastapi_jwt_auth import AuthJWT # type: ignore from markupsafe import escape from pydantic import BaseModel @@ -13,6 +12,7 @@ from antarest.core.requests import RequestParameters, UserHasNotPermissionError from antarest.core.roles import RoleType from antarest.core.utils.web import APITag +from antarest.fastapi_jwt_auth import AuthJWT from antarest.login.auth import Auth from antarest.login.model import ( BotCreateDTO, @@ -55,8 +55,8 @@ def create_login_api(service: LoginService, config: Config) -> APIRouter: auth = Auth(config) def generate_tokens(user: JWTUser, jwt_manager: AuthJWT, expire: Optional[timedelta] = None) -> CredentialsDTO: - access_token = jwt_manager.create_access_token(subject=user.json(), expires_time=expire) - refresh_token = jwt_manager.create_refresh_token(subject=user.json()) + access_token = jwt_manager.create_access_token(subject=user.model_dump_json(), expires_time=expire) + refresh_token = jwt_manager.create_refresh_token(subject=user.model_dump_json()) return CredentialsDTO( user=user.id, access_token=access_token.decode() if isinstance(access_token, bytes) else access_token, @@ -114,11 +114,7 @@ def users_get_all( params = RequestParameters(user=current_user) return service.get_all_users(params, details) - @bp.get( - "/users/{id}", - tags=[APITag.users], - response_model=Union[IdentityDTO, UserInfo], # type: ignore - ) + @bp.get("/users/{id}", tags=[APITag.users], response_model=Union[IdentityDTO, UserInfo]) def users_get_id( id: int, details: bool = False, @@ -192,11 +188,7 @@ def groups_get_all( params = RequestParameters(user=current_user) return service.get_all_groups(params, details) - @bp.get( - "/groups/{id}", - tags=[APITag.users], - response_model=Union[GroupDetailDTO, GroupDTO], # type: ignore - ) + @bp.get("/groups/{id}", tags=[APITag.users], response_model=Union[GroupDetailDTO, GroupDTO]) def groups_get_id( id: str, details: bool = False, @@ -314,11 +306,7 @@ def bots_create( tokens = generate_tokens(jwt, jwt_manager, expire=timedelta(days=368 * 200)) return tokens.access_token - @bp.get( - "/bots/{id}", - tags=[APITag.users], - response_model=Union[BotIdentityDTO, BotDTO], # type: ignore - ) + @bp.get("/bots/{id}", tags=[APITag.users], response_model=Union[BotIdentityDTO, BotDTO]) def get_bot( id: int, verbose: Optional[int] = None, diff --git a/antarest/main.py b/antarest/main.py index 3973c79ace..cf7d42a392 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -6,12 +6,11 @@ from typing import Any, Dict, Optional, Sequence, Tuple, cast import pydantic -import uvicorn # type: ignore -import uvicorn.config # type: ignore +import uvicorn +import uvicorn.config from fastapi import FastAPI, HTTPException from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError -from fastapi_jwt_auth import AuthJWT # type: ignore from ratelimit import RateLimitMiddleware # type: ignore from ratelimit.backends.redis import RedisBackend # type: ignore from ratelimit.backends.simple import MemoryBackend # type: ignore @@ -34,6 +33,7 @@ from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.core.utils.utils import get_local_path from antarest.core.utils.web import tags_metadata +from antarest.fastapi_jwt_auth import AuthJWT from antarest.login.auth import Auth, JwtSettings from antarest.login.model import init_admin_user from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector @@ -249,6 +249,10 @@ def fastapi_app( # Database engine = init_db_engine(config_file, config, auto_upgrade_db) application.add_middleware(DBSessionMiddleware, custom_engine=engine, session_args=SESSION_ARGS) + # Since Starlette Version 0.24.0, the middlewares are lazily build inside this function + # But we need to instantiate this middleware as it's needed for the study service. + # So we manually instantiate it here. + DBSessionMiddleware(None, custom_engine=engine, session_args=cast(Dict[str, bool], SESSION_ARGS)) application.add_middleware(LoggingMiddleware) diff --git a/antarest/matrixstore/matrix_editor.py b/antarest/matrixstore/matrix_editor.py index ffaf619322..d5b3271dee 100644 --- a/antarest/matrixstore/matrix_editor.py +++ b/antarest/matrixstore/matrix_editor.py @@ -2,7 +2,7 @@ import operator from typing import Any, Dict, List, Optional, Tuple -from pydantic import BaseModel, Extra, Field, root_validator, validator +from pydantic import BaseModel, Field, field_validator, model_validator class MatrixSlice(BaseModel): @@ -23,8 +23,8 @@ class MatrixSlice(BaseModel): column_to: int class Config: - extra = Extra.forbid - schema_extra = { + extra = "forbid" + json_schema_extra = { "example": { "column_from": 5, "column_to": 8, @@ -33,7 +33,7 @@ class Config: } } - @root_validator(pre=True) + @model_validator(mode="before") def check_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ Converts and validates the slice coordinates. @@ -95,12 +95,12 @@ class Operation(BaseModel): - `value`: The value associated with the operation. """ - operation: str = Field(regex=r"[=/*+-]|ABS") + operation: str = Field(pattern=r"[=/*+-]|ABS") value: float class Config: - extra = Extra.forbid - schema_extra = {"example": {"operation": "=", "value": 120.0}} + extra = "forbid" + json_schema_extra = {"example": {"operation": "=", "value": 120.0}} # noinspection SpellCheckingInspection def compute(self, x: Any, use_coords: bool = False) -> Any: @@ -145,16 +145,9 @@ class MatrixEditInstruction(BaseModel): operation: Operation class Config: - extra = Extra.forbid + extra = "forbid" - @staticmethod - def schema_extra(schema: Dict[str, Any]) -> None: - schema["example"] = MatrixEditInstruction( - coordinates=[(0, 10), (0, 11), (0, 12)], - operation=Operation(operation="=", value=120.0), - ) - - @root_validator(pre=True) + @model_validator(mode="before") def check_slice_coordinates(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ Validates the 'slices' and 'coordinates' fields. @@ -179,7 +172,7 @@ def check_slice_coordinates(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values - @validator("coordinates") + @field_validator("coordinates") def validate_coordinates(cls, coordinates: Optional[List[Tuple[int, int]]]) -> Optional[List[Tuple[int, int]]]: """ Validates the `coordinates` field. diff --git a/antarest/matrixstore/service.py b/antarest/matrixstore/service.py index 33623c7fb4..666a7e1b71 100644 --- a/antarest/matrixstore/service.py +++ b/antarest/matrixstore/service.py @@ -152,7 +152,7 @@ def _from_dto(dto: MatrixDTO) -> t.Tuple[Matrix, MatrixContent]: created_at=datetime.fromtimestamp(dto.created_at), ) - content = MatrixContent(data=dto.data, index=dto.index, columns=dto.columns) + content = MatrixContent(data=dto.data, index=dto.index, columns=dto.columns) # type: ignore return matrix, content @@ -213,6 +213,7 @@ def create_by_importation(self, file: UploadFile, is_json: bool = False) -> t.Li A list of `MatrixInfoDTO` objects containing the SHA256 hash of the imported matrices. """ with file.file as f: + assert file.filename is not None if file.content_type == "application/zip": with contextlib.closing(f): buffer = io.BytesIO(f.read()) diff --git a/antarest/study/business/adequacy_patch_management.py b/antarest/study/business/adequacy_patch_management.py index 7d837df7b6..34313f1239 100644 --- a/antarest/study/business/adequacy_patch_management.py +++ b/antarest/study/business/adequacy_patch_management.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from pydantic.types import StrictBool, confloat, conint +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study @@ -17,18 +18,19 @@ class PriceTakingOrder(EnumIgnoreCase): ThresholdType = confloat(ge=0) +@all_optional_model class AdequacyPatchFormFields(FormFieldsBaseModel): # version 830 - enable_adequacy_patch: Optional[StrictBool] - ntc_from_physical_areas_out_to_physical_areas_in_adequacy_patch: Optional[StrictBool] - ntc_between_physical_areas_out_adequacy_patch: Optional[StrictBool] + enable_adequacy_patch: StrictBool + ntc_from_physical_areas_out_to_physical_areas_in_adequacy_patch: StrictBool + ntc_between_physical_areas_out_adequacy_patch: StrictBool # version 850 - price_taking_order: Optional[PriceTakingOrder] - include_hurdle_cost_csr: Optional[StrictBool] - check_csr_cost_function: Optional[StrictBool] - threshold_initiate_curtailment_sharing_rule: Optional[ThresholdType] # type: ignore - threshold_display_local_matching_rule_violations: Optional[ThresholdType] # type: ignore - threshold_csr_variable_bounds_relaxation: Optional[conint(ge=0, strict=True)] # type: ignore + price_taking_order: PriceTakingOrder + include_hurdle_cost_csr: StrictBool + check_csr_cost_function: StrictBool + threshold_initiate_curtailment_sharing_rule: ThresholdType # type: ignore + threshold_display_local_matching_rule_violations: ThresholdType # type: ignore + threshold_csr_variable_bounds_relaxation: conint(ge=0, strict=True) # type: ignore ADEQUACY_PATCH_PATH = f"{GENERAL_DATA_PATH}/adequacy patch" diff --git a/antarest/study/business/advanced_parameters_management.py b/antarest/study/business/advanced_parameters_management.py index 46c58f36d6..23251cf164 100644 --- a/antarest/study/business/advanced_parameters_management.py +++ b/antarest/study/business/advanced_parameters_management.py @@ -1,9 +1,10 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List -from pydantic import validator +from pydantic import field_validator from pydantic.types import StrictInt, StrictStr from antarest.core.exceptions import InvalidFieldForVersionError +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study @@ -60,33 +61,34 @@ class RenewableGenerationModeling(EnumIgnoreCase): CLUSTERS = "clusters" +@all_optional_model class AdvancedParamsFormFields(FormFieldsBaseModel): # Advanced parameters - accuracy_on_correlation: Optional[StrictStr] + accuracy_on_correlation: StrictStr # Other preferences - initial_reservoir_levels: Optional[InitialReservoirLevel] - power_fluctuations: Optional[PowerFluctuation] - shedding_policy: Optional[SheddingPolicy] - hydro_pricing_mode: Optional[HydroPricingMode] - hydro_heuristic_policy: Optional[HydroHeuristicPolicy] - unit_commitment_mode: Optional[UnitCommitmentMode] - number_of_cores_mode: Optional[SimulationCore] - day_ahead_reserve_management: Optional[ReserveManagement] - renewable_generation_modelling: Optional[RenewableGenerationModeling] + initial_reservoir_levels: InitialReservoirLevel + power_fluctuations: PowerFluctuation + shedding_policy: SheddingPolicy + hydro_pricing_mode: HydroPricingMode + hydro_heuristic_policy: HydroHeuristicPolicy + unit_commitment_mode: UnitCommitmentMode + number_of_cores_mode: SimulationCore + day_ahead_reserve_management: ReserveManagement + renewable_generation_modelling: RenewableGenerationModeling # Seeds - seed_tsgen_wind: Optional[StrictInt] - seed_tsgen_load: Optional[StrictInt] - seed_tsgen_hydro: Optional[StrictInt] - seed_tsgen_thermal: Optional[StrictInt] - seed_tsgen_solar: Optional[StrictInt] - seed_tsnumbers: Optional[StrictInt] - seed_unsupplied_energy_costs: Optional[StrictInt] - seed_spilled_energy_costs: Optional[StrictInt] - seed_thermal_costs: Optional[StrictInt] - seed_hydro_costs: Optional[StrictInt] - seed_initial_reservoir_levels: Optional[StrictInt] - - @validator("accuracy_on_correlation") + seed_tsgen_wind: StrictInt + seed_tsgen_load: StrictInt + seed_tsgen_hydro: StrictInt + seed_tsgen_thermal: StrictInt + seed_tsgen_solar: StrictInt + seed_tsnumbers: StrictInt + seed_unsupplied_energy_costs: StrictInt + seed_spilled_energy_costs: StrictInt + seed_thermal_costs: StrictInt + seed_hydro_costs: StrictInt + seed_initial_reservoir_levels: StrictInt + + @field_validator("accuracy_on_correlation") def check_accuracy_on_correlation(cls, v: str) -> str: sanitized_v = v.strip().replace(" ", "") if not sanitized_v: diff --git a/antarest/study/business/all_optional_meta.py b/antarest/study/business/all_optional_meta.py index 06ddc012d8..bd7e31d04e 100644 --- a/antarest/study/business/all_optional_meta.py +++ b/antarest/study/business/all_optional_meta.py @@ -1,84 +1,23 @@ +import copy import typing as t -import pydantic.fields -import pydantic.main -from pydantic import BaseModel +from pydantic import BaseModel, create_model from antarest.core.utils.string import to_camel_case -class AllOptionalMetaclass(pydantic.main.ModelMetaclass): - """ - Metaclass that makes all fields of a Pydantic model optional. - - Usage: - class MyModel(BaseModel, metaclass=AllOptionalMetaclass): - field1: str - field2: int - ... - - Instances of the model can be created even if not all fields are provided during initialization. - Default values, when provided, are used unless `use_none` is set to `True`. - """ - - def __new__( - cls: t.Type["AllOptionalMetaclass"], - name: str, - bases: t.Tuple[t.Type[t.Any], ...], - namespaces: t.Dict[str, t.Any], - use_none: bool = False, - **kwargs: t.Dict[str, t.Any], - ) -> t.Any: - """ - Create a new instance of the metaclass. - - Args: - name: Name of the class to create. - bases: Base classes of the class to create (a Pydantic model). - namespaces: namespace of the class to create that defines the fields of the model. - use_none: If `True`, the default value of the fields is set to `None`. - Note that this field is not part of the Pydantic model, but it is an extension. - **kwargs: Additional keyword arguments used by the metaclass. - """ - # Modify the annotations of the class (but not of the ancestor classes) - # in order to make all fields optional. - # If the current model inherits from another model, the annotations of the ancestor models - # are not modified, because the fields are already converted to `ModelField`. - annotations = namespaces.get("__annotations__", {}) - for field_name, field_type in annotations.items(): - if not field_name.startswith("__"): - # Making already optional fields optional is not a problem (nothing is changed). - annotations[field_name] = t.Optional[field_type] - namespaces["__annotations__"] = annotations - - if use_none: - # Modify the namespace fields to set their default value to `None`. - for field_name, field_info in namespaces.items(): - if isinstance(field_info, pydantic.fields.FieldInfo): - field_info.default = None - field_info.default_factory = None - - # Create the class: all annotations are converted into `ModelField`. - instance = super().__new__(cls, name, bases, namespaces, **kwargs) - - # Modify the inherited fields of the class to make them optional - # and set their default value to `None`. - model_field: pydantic.fields.ModelField - for field_name, model_field in instance.__fields__.items(): - model_field.required = False - model_field.allow_none = True - if use_none: - model_field.default = None - model_field.default_factory = None - model_field.field_info.default = None - - return instance - +def all_optional_model(model: t.Type[BaseModel]) -> t.Type[BaseModel]: + kwargs = {} + for field_name, field_info in model.model_fields.items(): + new = copy.deepcopy(field_info) + new.default = None + new.annotation = t.Optional[field_info.annotation] # type: ignore + kwargs[field_name] = (new.annotation, new) -MODEL = t.TypeVar("MODEL", bound=t.Type[BaseModel]) + return create_model(f"Partial{model.__name__}", __base__=model, __module__=model.__module__, **kwargs) # type: ignore -def camel_case_model(model: MODEL) -> MODEL: +def camel_case_model(model: t.Type[BaseModel]) -> t.Type[BaseModel]: """ This decorator can be used to modify a model to use camel case aliases. @@ -88,7 +27,10 @@ def camel_case_model(model: MODEL) -> MODEL: Returns: The modified model. """ - model.__config__.alias_generator = to_camel_case - for field_name, field in model.__fields__.items(): - field.alias = to_camel_case(field_name) + model.model_config["alias_generator"] = to_camel_case + for field_name, field in model.model_fields.items(): + new_alias = to_camel_case(field_name) + field.alias = new_alias + field.validation_alias = new_alias + field.serialization_alias = new_alias return model diff --git a/antarest/study/business/allocation_management.py b/antarest/study/business/allocation_management.py index 8c539ceac3..a5e0809daf 100644 --- a/antarest/study/business/allocation_management.py +++ b/antarest/study/business/allocation_management.py @@ -1,8 +1,10 @@ -from typing import Dict, List +from typing import Dict, List, Union import numpy import numpy as np -from pydantic import conlist, root_validator, validator +from annotated_types import Len +from pydantic import ValidationInfo, field_validator, model_validator +from typing_extensions import Annotated from antarest.core.exceptions import AllocationDataNotFound, AreaNotFound from antarest.study.business.area_management import AreaInfoDTO @@ -24,9 +26,9 @@ class AllocationFormFields(FormFieldsBaseModel): allocation: List[AllocationField] - @root_validator - def check_allocation(cls, values: Dict[str, List[AllocationField]]) -> Dict[str, List[AllocationField]]: - allocation = values.get("allocation", []) + @model_validator(mode="after") + def check_allocation(self) -> "AllocationFormFields": + allocation = self.allocation if not allocation: raise ValueError("allocation must not be empty") @@ -44,7 +46,7 @@ def check_allocation(cls, values: Dict[str, List[AllocationField]]) -> Dict[str, if sum(a.coefficient for a in allocation) <= 0: raise ValueError("sum of allocation coefficients must be positive") - return values + return self class AllocationMatrix(FormFieldsBaseModel): @@ -55,14 +57,14 @@ class AllocationMatrix(FormFieldsBaseModel): data: 2D-array matrix of consumption coefficients """ - index: conlist(str, min_items=1) # type: ignore - columns: conlist(str, min_items=1) # type: ignore + index: Annotated[List[str], Len(min_length=1)] + columns: Annotated[List[str], Len(min_length=1)] data: List[List[float]] # NonNegativeFloat not necessary # noinspection PyMethodParameters - @validator("data") + @field_validator("data", mode="before") def validate_hydro_allocation_matrix( - cls, data: List[List[float]], values: Dict[str, List[str]] + cls, data: List[List[float]], values: Union[Dict[str, List[str]], ValidationInfo] ) -> List[List[float]]: """ Validate the hydraulic allocation matrix. @@ -77,8 +79,9 @@ def validate_hydro_allocation_matrix( """ array = np.array(data) - rows = len(values.get("index", [])) - cols = len(values.get("columns", [])) + new_values = values if isinstance(values, dict) else values.data + rows = len(new_values.get("index", [])) + cols = len(new_values.get("columns", [])) if array.size == 0: raise ValueError("allocation matrix must not be empty") @@ -124,7 +127,7 @@ def get_allocation_data(self, study: Study, area_id: str) -> Dict[str, List[Allo if not allocation_data: raise AllocationDataNotFound(area_id) - return allocation_data.get("[allocation]", {}) + return allocation_data.get("[allocation]", {}) # type: ignore def get_allocation_form_fields( self, all_areas: List[AreaInfoDTO], study: Study, area_id: str @@ -148,13 +151,10 @@ def get_allocation_form_fields( allocations = self.get_allocation_data(study, area_id) filtered_allocations = {area: value for area, value in allocations.items() if area in areas_ids} - - return AllocationFormFields.construct( - allocation=[ - AllocationField.construct(area_id=area, coefficient=value) - for area, value in filtered_allocations.items() - ] - ) + final_allocations = [ + AllocationField.construct(area_id=area, coefficient=value) for area, value in filtered_allocations.items() + ] + return AllocationFormFields.model_validate({"allocation": final_allocations}) def set_allocation_form_fields( self, diff --git a/antarest/study/business/area_management.py b/antarest/study/business/area_management.py index 6c87d8cb80..188720aea7 100644 --- a/antarest/study/business/area_management.py +++ b/antarest/study/business/area_management.py @@ -3,11 +3,11 @@ import re import typing as t -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel, Field from antarest.core.exceptions import ConfigFileNotFound, DuplicateAreaName, LayerNotAllowedToBeDeleted, LayerNotFound from antarest.core.model import JSON -from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Patch, PatchArea, PatchCluster, RawStudy, Study from antarest.study.repository import StudyMetadataRepository @@ -38,8 +38,8 @@ class AreaType(enum.Enum): class AreaCreationDTO(BaseModel): name: str type: AreaType - metadata: t.Optional[PatchArea] - set: t.Optional[t.List[str]] + metadata: t.Optional[PatchArea] = None + set: t.Optional[t.List[str]] = None # review: is this class necessary? @@ -70,7 +70,7 @@ class LayerInfoDTO(BaseModel): areas: t.List[str] -class UpdateAreaUi(BaseModel, extra="forbid", allow_population_by_field_name=True): +class UpdateAreaUi(BaseModel, extra="forbid", populate_by_name=True): """ DTO for updating area UI @@ -95,7 +95,7 @@ class UpdateAreaUi(BaseModel, extra="forbid", allow_population_by_field_name=Tru ... } >>> model = UpdateAreaUi(**obj) - >>> pprint(model.dict(by_alias=True), width=80) + >>> pprint(model.model_dump(by_alias=True), width=80) {'colorRgb': [230, 108, 44], 'layerColor': {0: '230, 108, 44', 4: '230, 108, 44', @@ -167,9 +167,9 @@ class _BaseAreaDTO( OptimizationProperties.FilteringSection, OptimizationProperties.ModalOptimizationSection, AdequacyPathProperties.AdequacyPathSection, - extra=Extra.forbid, + extra="forbid", validate_assignment=True, - allow_population_by_field_name=True, + populate_by_name=True, ): """ Represents an area output. @@ -188,8 +188,9 @@ class _BaseAreaDTO( # noinspection SpellCheckingInspection +@all_optional_model @camel_case_model -class AreaOutput(_BaseAreaDTO, metaclass=AllOptionalMetaclass, use_none=True): +class AreaOutput(_BaseAreaDTO): """ DTO object use to get the area information using a flat structure. """ @@ -215,30 +216,32 @@ def from_model( obj = { "average_unsupplied_energy_cost": average_unsupplied_energy_cost, "average_spilled_energy_cost": average_spilled_energy_cost, - **area_folder.optimization.filtering.dict(by_alias=False), - **area_folder.optimization.nodal_optimization.dict(by_alias=False), + **area_folder.optimization.filtering.model_dump(by_alias=False), + **area_folder.optimization.nodal_optimization.model_dump(by_alias=False), # adequacy_patch is only available if study version >= 830. - **(area_folder.adequacy_patch.adequacy_patch.dict(by_alias=False) if area_folder.adequacy_patch else {}), + **( + area_folder.adequacy_patch.adequacy_patch.model_dump(by_alias=False) + if area_folder.adequacy_patch + else {} + ), } return cls(**obj) def _to_optimization(self) -> OptimizationProperties: - obj = {name: getattr(self, name) for name in OptimizationProperties.FilteringSection.__fields__} + obj = {name: getattr(self, name) for name in OptimizationProperties.FilteringSection.model_fields} filtering_section = OptimizationProperties.FilteringSection(**obj) - obj = {name: getattr(self, name) for name in OptimizationProperties.ModalOptimizationSection.__fields__} + obj = {name: getattr(self, name) for name in OptimizationProperties.ModalOptimizationSection.model_fields} nodal_optimization_section = OptimizationProperties.ModalOptimizationSection(**obj) - return OptimizationProperties( - filtering=filtering_section, - nodal_optimization=nodal_optimization_section, - ) + args = {"filtering": filtering_section, "nodal_optimization": nodal_optimization_section} + return OptimizationProperties.model_validate(args) def _to_adequacy_patch(self) -> t.Optional[AdequacyPathProperties]: - obj = {name: getattr(self, name) for name in AdequacyPathProperties.AdequacyPathSection.__fields__} + obj = {name: getattr(self, name) for name in AdequacyPathProperties.AdequacyPathSection.model_fields} # If all fields are `None`, the object is empty. if all(value is None for value in obj.values()): return None adequacy_path_section = AdequacyPathProperties.AdequacyPathSection(**obj) - return AdequacyPathProperties(adequacy_patch=adequacy_path_section) + return AdequacyPathProperties.model_validate({"adequacy_patch": adequacy_path_section}) @property def area_folder(self) -> AreaFolder: @@ -347,7 +350,7 @@ def update_areas_props( for area_id, update_area in update_areas_by_ids.items(): # Update the area properties. old_area = old_areas_by_ids[area_id] - new_area = old_area.copy(update=update_area.dict(by_alias=False, exclude_none=True)) + new_area = old_area.copy(update=update_area.model_dump(by_alias=False, exclude_none=True)) new_areas_by_ids[area_id] = new_area # Convert the DTO to a configuration object and update the configuration file. @@ -645,7 +648,7 @@ def update_area_metadata( self.patch_service.save(study, patch) return AreaInfoDTO( id=area_id, - name=area_or_set.name if area_or_set is not None else area_id, + name=area_or_set.name if area_or_set is not None else area_id, # type: ignore type=AreaType.AREA if isinstance(area_or_set, Area) else AreaType.DISTRICT, metadata=patch.areas.get(area_id), set=area_or_set.get_areas(list(file_study.config.areas)) if isinstance(area_or_set, DistrictSet) else [], @@ -728,7 +731,7 @@ def update_thermal_cluster_metadata( id=area_id, name=file_study.config.areas[area_id].name, type=AreaType.AREA, - metadata=patch.areas.get(area_id, PatchArea()).dict(), + metadata=patch.areas.get(area_id, PatchArea()), thermals=self._get_clusters(file_study, area_id, patch), set=None, ) diff --git a/antarest/study/business/areas/hydro_management.py b/antarest/study/business/areas/hydro_management.py index e0a52ee4e1..b29abf107d 100644 --- a/antarest/study/business/areas/hydro_management.py +++ b/antarest/study/business/areas/hydro_management.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union -from pydantic import Field +from pydantic import Field, model_validator +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.utils import FieldInfo, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study from antarest.study.storage.storage_service import StudyStorageService @@ -24,22 +25,46 @@ class InflowStructure(FormFieldsBaseModel): ) +@all_optional_model class ManagementOptionsFormFields(FormFieldsBaseModel): - inter_daily_breakdown: Optional[float] = Field(ge=0) - intra_daily_modulation: Optional[float] = Field(ge=1) - inter_monthly_breakdown: Optional[float] = Field(ge=0) - reservoir: Optional[bool] - reservoir_capacity: Optional[float] = Field(ge=0) - follow_load: Optional[bool] - use_water: Optional[bool] - hard_bounds: Optional[bool] - initialize_reservoir_date: Optional[int] = Field(ge=0, le=11) - use_heuristic: Optional[bool] - power_to_level: Optional[bool] - use_leeway: Optional[bool] - leeway_low: Optional[float] = Field(ge=0) - leeway_up: Optional[float] = Field(ge=0) - pumping_efficiency: Optional[float] = Field(ge=0) + inter_daily_breakdown: float + intra_daily_modulation: float + inter_monthly_breakdown: float + reservoir: bool + reservoir_capacity: float + follow_load: bool + use_water: bool + hard_bounds: bool + initialize_reservoir_date: int + use_heuristic: bool + power_to_level: bool + use_leeway: bool + leeway_low: float + leeway_up: float + pumping_efficiency: float + + @model_validator(mode="before") + def check_type_validity(cls, values: Dict[str, Any]) -> Dict[str, Optional[Any]]: + cls.validate_ge("inter_daily_breakdown", values.get("inter_daily_breakdown", 0), 0) + cls.validate_ge("intra_daily_modulation", values.get("intra_daily_modulation", 1), 1) + cls.validate_ge("inter_monthly_breakdown", values.get("inter_monthly_breakdown", 0), 0) + cls.validate_ge("reservoir_capacity", values.get("reservoir_capacity", 0), 0) + cls.validate_ge("initialize_reservoir_date", values.get("initialize_reservoir_date", 0), 0) + cls.validate_le("initialize_reservoir_date", values.get("initialize_reservoir_date", 11), 11) + cls.validate_ge("leeway_low", values.get("leeway_low", 0), 0) + cls.validate_ge("leeway_up", values.get("leeway_up", 0), 0) + cls.validate_ge("pumping_efficiency", values.get("pumping_efficiency", 0), 0) + return values + + @staticmethod + def validate_ge(field: str, value: Union[int, float], ge: int) -> None: + if value < ge: + raise ValueError(f"Field {field} must be greater than or equal to {ge}") + + @staticmethod + def validate_le(field: str, value: Union[int, float], le: int) -> None: + if value > le: + raise ValueError(f"Field {field} must be lower than or equal to {le}") HYDRO_PATH = "input/hydro/hydro" diff --git a/antarest/study/business/areas/properties_management.py b/antarest/study/business/areas/properties_management.py index 0bccdad784..4dcef05aab 100644 --- a/antarest/study/business/areas/properties_management.py +++ b/antarest/study/business/areas/properties_management.py @@ -2,9 +2,10 @@ import typing as t from builtins import sorted -from pydantic import root_validator +from pydantic import model_validator from antarest.core.exceptions import ChildNotFoundError +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.utils import FieldInfo, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study from antarest.study.storage.rawstudy.model.filesystem.config.area import AdequacyPatchMode @@ -37,18 +38,19 @@ def decode_filter(encoded_value: t.Set[str], current_filter: t.Optional[str] = N return ", ".join(sort_filter_options(encoded_value)) +@all_optional_model class PropertiesFormFields(FormFieldsBaseModel): - energy_cost_unsupplied: t.Optional[float] - energy_cost_spilled: t.Optional[float] - non_dispatch_power: t.Optional[bool] - dispatch_hydro_power: t.Optional[bool] - other_dispatch_power: t.Optional[bool] - filter_synthesis: t.Optional[t.Set[str]] - filter_by_year: t.Optional[t.Set[str]] + energy_cost_unsupplied: float + energy_cost_spilled: float + non_dispatch_power: bool + dispatch_hydro_power: bool + other_dispatch_power: bool + filter_synthesis: t.Set[str] + filter_by_year: t.Set[str] # version 830 - adequacy_patch_mode: t.Optional[AdequacyPatchMode] + adequacy_patch_mode: AdequacyPatchMode - @root_validator + @model_validator(mode="before") def validation(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: filters = { "filter_synthesis": values.get("filter_synthesis"), @@ -131,7 +133,7 @@ def get_value(field_info: FieldInfo) -> t.Any: encode = field_info.get("encode") or (lambda x: x) return encode(val) - return PropertiesFormFields.construct(**{name: get_value(info) for name, info in FIELDS_INFO.items()}) + return PropertiesFormFields.model_construct(**{name: get_value(info) for name, info in FIELDS_INFO.items()}) def set_field_values( self, diff --git a/antarest/study/business/areas/renewable_management.py b/antarest/study/business/areas/renewable_management.py index 1009c9d22c..4ee67289f8 100644 --- a/antarest/study/business/areas/renewable_management.py +++ b/antarest/study/business/areas/renewable_management.py @@ -1,12 +1,11 @@ import collections -import json import typing as t -from pydantic import validator +from pydantic import field_validator from antarest.core.exceptions import DuplicateRenewableCluster, RenewableClusterConfigNotFound, RenewableClusterNotFound from antarest.core.model import JSON -from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Study @@ -34,24 +33,13 @@ class TimeSeriesInterpretation(EnumIgnoreCase): PRODUCTION_FACTOR = "production-factor" +@all_optional_model @camel_case_model -class RenewableClusterInput(RenewableProperties, metaclass=AllOptionalMetaclass, use_none=True): +class RenewableClusterInput(RenewableProperties): """ Model representing the data structure required to edit an existing renewable cluster. """ - class Config: - @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: - schema["example"] = RenewableClusterInput( - group="Gas", - name="Gas Cluster XY", - enabled=False, - unitCount=100, - nominalCapacity=1000.0, - tsInterpretation="power-generation", - ) - class RenewableClusterCreation(RenewableClusterInput): """ @@ -59,7 +47,7 @@ class RenewableClusterCreation(RenewableClusterInput): """ # noinspection Pydantic - @validator("name", pre=True) + @field_validator("name", mode="before") def validate_name(cls, name: t.Optional[str]) -> str: """ Validator to check if the name is not empty. @@ -69,29 +57,17 @@ def validate_name(cls, name: t.Optional[str]) -> str: return name def to_config(self, study_version: t.Union[str, int]) -> RenewableConfigType: - values = self.dict(by_alias=False, exclude_none=True) + values = self.model_dump(by_alias=False, exclude_none=True) return create_renewable_config(study_version=study_version, **values) +@all_optional_model @camel_case_model -class RenewableClusterOutput(RenewableConfig, metaclass=AllOptionalMetaclass, use_none=True): +class RenewableClusterOutput(RenewableConfig): """ Model representing the output data structure to display the details of a renewable cluster. """ - class Config: - @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: - schema["example"] = RenewableClusterOutput( - id="Gas cluster YZ", - group="Gas", - name="Gas Cluster YZ", - enabled=False, - unitCount=100, - nominalCapacity=1000.0, - tsInterpretation="power-generation", - ) - def create_renewable_output( study_version: t.Union[str, int], @@ -99,7 +75,7 @@ def create_renewable_output( config: t.Mapping[str, t.Any], ) -> "RenewableClusterOutput": obj = create_renewable_config(study_version=study_version, **config, id=cluster_id) - kwargs = obj.dict(by_alias=False) + kwargs = obj.model_dump(by_alias=False) return RenewableClusterOutput(**kwargs) @@ -206,7 +182,7 @@ def _make_create_cluster_cmd(self, area_id: str, cluster: RenewableConfigType) - command = CreateRenewablesCluster( area_id=area_id, cluster_name=cluster.id, - parameters=cluster.dict(by_alias=True, exclude={"id"}), + parameters=cluster.model_dump(mode="json", by_alias=True, exclude={"id"}), command_context=self.storage_service.variant_study_service.command_factory.command_context, ) return command @@ -269,16 +245,16 @@ def update_cluster( old_config = create_renewable_config(study_version, **values) # use Python values to synchronize Config and Form values - new_values = cluster_data.dict(by_alias=False, exclude_none=True) + new_values = cluster_data.model_dump(by_alias=False, exclude_none=True) new_config = old_config.copy(exclude={"id"}, update=new_values) - new_data = json.loads(new_config.json(by_alias=True, exclude={"id"})) + new_data = new_config.model_dump(mode="json", by_alias=True, exclude={"id"}) # create the dict containing the new values using aliases - data: t.Dict[str, t.Any] = { - field.alias: new_data[field.alias] - for field_name, field in new_config.__fields__.items() - if field_name in new_values - } + data: t.Dict[str, t.Any] = {} + for field_name, field in new_config.model_fields.items(): + if field_name in new_values: + name = field.alias if field.alias else field_name + data[name] = new_data[name] # create the update config commands with the modified data command_context = self.storage_service.variant_study_service.command_factory.command_context @@ -288,7 +264,7 @@ def update_cluster( ] execute_or_add_commands(study, file_study, commands, self.storage_service) - values = new_config.dict(by_alias=False) + values = new_config.model_dump(by_alias=False) return RenewableClusterOutput(**values, id=cluster_id) def delete_clusters(self, study: Study, area_id: str, cluster_ids: t.Sequence[str]) -> None: @@ -340,7 +316,7 @@ def duplicate_cluster( # Cluster duplication current_cluster = self.get_cluster(study, area_id, source_id) current_cluster.name = new_cluster_name - creation_form = RenewableClusterCreation(**current_cluster.dict(by_alias=False, exclude={"id"})) + creation_form = RenewableClusterCreation(**current_cluster.model_dump(by_alias=False, exclude={"id"})) new_config = creation_form.to_config(study.version) create_cluster_cmd = self._make_create_cluster_cmd(area_id, new_config) @@ -358,7 +334,7 @@ def duplicate_cluster( execute_or_add_commands(study, self._get_file_study(study), commands, self.storage_service) - return RenewableClusterOutput(**new_config.dict(by_alias=False)) + return RenewableClusterOutput(**new_config.model_dump(by_alias=False)) def update_renewables_props( self, @@ -375,17 +351,17 @@ def update_renewables_props( for renewable_id, update_cluster in update_renewables_by_ids.items(): # Update the renewable cluster properties. old_cluster = old_renewables_by_ids[renewable_id] - new_cluster = old_cluster.copy(update=update_cluster.dict(by_alias=False, exclude_none=True)) + new_cluster = old_cluster.copy(update=update_cluster.model_dump(by_alias=False, exclude_none=True)) new_renewables_by_areas[area_id][renewable_id] = new_cluster # Convert the DTO to a configuration object and update the configuration file. properties = create_renewable_config( - study.version, **new_cluster.dict(by_alias=False, exclude_none=True) + study.version, **new_cluster.model_dump(by_alias=False, exclude_none=True) ) path = _CLUSTER_PATH.format(area_id=area_id, cluster_id=renewable_id) cmd = UpdateConfig( target=path, - data=json.loads(properties.json(by_alias=True, exclude={"id"})), + data=properties.model_dump(mode="json", by_alias=True, exclude={"id"}), command_context=self.storage_service.variant_study_service.command_factory.command_context, ) commands.append(cmd) diff --git a/antarest/study/business/areas/st_storage_management.py b/antarest/study/business/areas/st_storage_management.py index 776f57a039..47855a440f 100644 --- a/antarest/study/business/areas/st_storage_management.py +++ b/antarest/study/business/areas/st_storage_management.py @@ -1,12 +1,9 @@ import collections -import functools -import json import operator import typing as t import numpy as np -from pydantic import BaseModel, Extra, root_validator, validator -from requests.structures import CaseInsensitiveDict +from pydantic import BaseModel, field_validator, model_validator from typing_extensions import Literal from antarest.core.exceptions import ( @@ -18,7 +15,8 @@ STStorageNotFound, ) from antarest.core.model import JSON -from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model +from antarest.core.requests import CaseInsensitiveDict +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Study from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id @@ -26,7 +24,6 @@ STStorage880Config, STStorage880Properties, STStorageConfigType, - STStorageGroup, create_st_storage_config, ) from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy @@ -37,26 +34,13 @@ from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig +@all_optional_model @camel_case_model -class STStorageInput(STStorage880Properties, metaclass=AllOptionalMetaclass, use_none=True): +class STStorageInput(STStorage880Properties): """ Model representing the form used to EDIT an existing short-term storage. """ - class Config: - @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: - schema["example"] = STStorageInput( - name="Siemens Battery", - group=STStorageGroup.BATTERY, - injection_nominal_capacity=150, - withdrawal_nominal_capacity=150, - reservoir_capacity=600, - efficiency=0.94, - initial_level=0.5, - initial_level_optim=True, - ) - class STStorageCreation(STStorageInput): """ @@ -64,7 +48,7 @@ class STStorageCreation(STStorageInput): """ # noinspection Pydantic - @validator("name", pre=True) + @field_validator("name", mode="before") def validate_name(cls, name: t.Optional[str]) -> str: """ Validator to check if the name is not empty. @@ -75,30 +59,17 @@ def validate_name(cls, name: t.Optional[str]) -> str: # noinspection PyUnusedLocal def to_config(self, study_version: t.Union[str, int]) -> STStorageConfigType: - values = self.dict(by_alias=False, exclude_none=True) + values = self.model_dump(by_alias=False, exclude_none=True) return create_st_storage_config(study_version=study_version, **values) +@all_optional_model @camel_case_model -class STStorageOutput(STStorage880Config, metaclass=AllOptionalMetaclass, use_none=True): +class STStorageOutput(STStorage880Config): """ Model representing the form used to display the details of a short-term storage entry. """ - class Config: - @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: - schema["example"] = STStorageOutput( - id="siemens_battery", - name="Siemens Battery", - group=STStorageGroup.BATTERY, - injection_nominal_capacity=150, - withdrawal_nominal_capacity=150, - reservoir_capacity=600, - efficiency=0.94, - initial_level_optim=True, - ) - # ============= # Time series @@ -119,13 +90,13 @@ class STStorageMatrix(BaseModel): """ class Config: - extra = Extra.forbid + extra = "forbid" data: t.List[t.List[float]] index: t.List[int] columns: t.List[int] - @validator("data") + @field_validator("data") def validate_time_series(cls, data: t.List[t.List[float]]) -> t.List[t.List[float]]: """ Validator to check the integrity of the time series data. @@ -160,7 +131,7 @@ class STStorageMatrices(BaseModel): """ class Config: - extra = Extra.forbid + extra = "forbid" pmax_injection: STStorageMatrix pmax_withdrawal: STStorageMatrix @@ -168,7 +139,7 @@ class Config: upper_rule_curve: STStorageMatrix inflows: STStorageMatrix - @validator( + @field_validator( "pmax_injection", "pmax_withdrawal", "lower_rule_curve", @@ -183,23 +154,18 @@ def validate_time_series(cls, matrix: STStorageMatrix) -> STStorageMatrix: raise ValueError("Matrix values should be between 0 and 1") return matrix - @root_validator() - def validate_rule_curve( - cls, values: t.MutableMapping[str, STStorageMatrix] - ) -> t.MutableMapping[str, STStorageMatrix]: + @model_validator(mode="after") + def validate_rule_curve(self) -> "STStorageMatrices": """ Validator to ensure 'lower_rule_curve' values are less than or equal to 'upper_rule_curve' values. """ - if "lower_rule_curve" in values and "upper_rule_curve" in values: - lower_rule_curve = values["lower_rule_curve"] - upper_rule_curve = values["upper_rule_curve"] - lower_array = np.array(lower_rule_curve.data, dtype=np.float64) - upper_array = np.array(upper_rule_curve.data, dtype=np.float64) - # noinspection PyUnresolvedReferences - if (lower_array > upper_array).any(): - raise ValueError("Each 'lower_rule_curve' value must be lower or equal to each 'upper_rule_curve'") - return values + lower_array = np.array(self.lower_rule_curve.data, dtype=np.float64) + upper_array = np.array(self.upper_rule_curve.data, dtype=np.float64) + if (lower_array > upper_array).any(): + raise ValueError("Each 'lower_rule_curve' value must be lower or equal to each 'upper_rule_curve'") + + return self # noinspection SpellCheckingInspection @@ -237,7 +203,7 @@ def create_storage_output( config: t.Mapping[str, t.Any], ) -> "STStorageOutput": obj = create_st_storage_config(study_version=study_version, **config, id=cluster_id) - kwargs = obj.dict(by_alias=False) + kwargs = obj.model_dump(by_alias=False) return STStorageOutput(**kwargs) @@ -381,17 +347,17 @@ def update_storages_props( for storage_id, update_cluster in update_storages_by_ids.items(): # Update the storage cluster properties. old_cluster = old_storages_by_ids[storage_id] - new_cluster = old_cluster.copy(update=update_cluster.dict(by_alias=False, exclude_none=True)) + new_cluster = old_cluster.copy(update=update_cluster.model_dump(by_alias=False, exclude_none=True)) new_storages_by_areas[area_id][storage_id] = new_cluster # Convert the DTO to a configuration object and update the configuration file. properties = create_st_storage_config( - study.version, **new_cluster.dict(by_alias=False, exclude_none=True) + study.version, **new_cluster.model_dump(by_alias=False, exclude_none=True) ) path = _STORAGE_LIST_PATH.format(area_id=area_id, storage_id=storage_id) cmd = UpdateConfig( target=path, - data=json.loads(properties.json(by_alias=True, exclude={"id"})), + data=properties.model_dump(mode="json", by_alias=True, exclude={"id"}), command_context=self.storage_service.variant_study_service.command_factory.command_context, ) commands.append(cmd) @@ -460,16 +426,16 @@ def update_storage( old_config = create_st_storage_config(study_version, **values) # use Python values to synchronize Config and Form values - new_values = form.dict(by_alias=False, exclude_none=True) + new_values = form.model_dump(by_alias=False, exclude_none=True) new_config = old_config.copy(exclude={"id"}, update=new_values) - new_data = json.loads(new_config.json(by_alias=True, exclude={"id"})) + new_data = new_config.model_dump(mode="json", by_alias=True, exclude={"id"}) # create the dict containing the new values using aliases - data: t.Dict[str, t.Any] = { - field.alias: new_data[field.alias] - for field_name, field in new_config.__fields__.items() - if field_name in new_values - } + data: t.Dict[str, t.Any] = {} + for field_name, field in new_config.model_fields.items(): + if field_name in new_values: + name = field.alias if field.alias else field_name + data[name] = new_data[name] # create the update config commands with the modified data command_context = self.storage_service.variant_study_service.command_factory.command_context @@ -480,7 +446,7 @@ def update_storage( ] execute_or_add_commands(study, file_study, commands, self.storage_service) - values = new_config.dict(by_alias=False) + values = new_config.model_dump(by_alias=False) return STStorageOutput(**values, id=storage_id) def delete_storages( @@ -542,7 +508,7 @@ def duplicate_cluster(self, study: Study, area_id: str, source_id: str, new_clus # We should remove the field 'enabled' for studies before v8.8 as it didn't exist if int(study.version) < 880: fields_to_exclude.add("enabled") - creation_form = STStorageCreation(**current_cluster.dict(by_alias=False, exclude=fields_to_exclude)) + creation_form = STStorageCreation(**current_cluster.model_dump(by_alias=False, exclude=fields_to_exclude)) new_config = creation_form.to_config(study.version) create_cluster_cmd = self._make_create_cluster_cmd(area_id, new_config) @@ -571,7 +537,7 @@ def duplicate_cluster(self, study: Study, area_id: str, source_id: str, new_clus execute_or_add_commands(study, self._get_file_study(study), commands, self.storage_service) - return STStorageOutput(**new_config.dict(by_alias=False)) + return STStorageOutput(**new_config.model_dump(by_alias=False)) def get_matrix( self, @@ -672,17 +638,18 @@ def validate_matrices( Returns: bool: True if validation is successful. """ - # Create a partial function to retrieve matrix objects - get_matrix_obj = functools.partial(self._get_matrix_obj, study, area_id, storage_id) + + def validate_matrix(matrix_type: STStorageTimeSeries) -> STStorageMatrix: + return STStorageMatrix.model_validate(self._get_matrix_obj(study, area_id, storage_id, matrix_type)) # Validate matrices by constructing the `STStorageMatrices` object # noinspection SpellCheckingInspection STStorageMatrices( - pmax_injection=get_matrix_obj("pmax_injection"), - pmax_withdrawal=get_matrix_obj("pmax_withdrawal"), - lower_rule_curve=get_matrix_obj("lower_rule_curve"), - upper_rule_curve=get_matrix_obj("upper_rule_curve"), - inflows=get_matrix_obj("inflows"), + pmax_injection=validate_matrix("pmax_injection"), + pmax_withdrawal=validate_matrix("pmax_withdrawal"), + lower_rule_curve=validate_matrix("lower_rule_curve"), + upper_rule_curve=validate_matrix("upper_rule_curve"), + inflows=validate_matrix("inflows"), ) # Validation successful diff --git a/antarest/study/business/areas/thermal_management.py b/antarest/study/business/areas/thermal_management.py index 205965eb54..cbd49c8d63 100644 --- a/antarest/study/business/areas/thermal_management.py +++ b/antarest/study/business/areas/thermal_management.py @@ -1,9 +1,8 @@ import collections -import json import typing as t from pathlib import Path -from pydantic import validator +from pydantic import field_validator from antarest.core.exceptions import ( DuplicateThermalCluster, @@ -13,7 +12,7 @@ WrongMatrixHeightError, ) from antarest.core.model import JSON -from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Study from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id @@ -42,25 +41,13 @@ _ALL_CLUSTERS_PATH = "input/thermal/clusters" +@all_optional_model @camel_case_model -class ThermalClusterInput(Thermal870Properties, metaclass=AllOptionalMetaclass, use_none=True): +class ThermalClusterInput(Thermal870Properties): """ Model representing the data structure required to edit an existing thermal cluster within a study. """ - class Config: - @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: - schema["example"] = ThermalClusterInput( - group="Gas", - name="Gas Cluster XY", - enabled=False, - unitCount=100, - nominalCapacity=1000.0, - genTs="use global", - co2=7.0, - ) - class ThermalClusterCreation(ThermalClusterInput): """ @@ -68,7 +55,7 @@ class ThermalClusterCreation(ThermalClusterInput): """ # noinspection Pydantic - @validator("name", pre=True) + @field_validator("name", mode="before") def validate_name(cls, name: t.Optional[str]) -> str: """ Validator to check if the name is not empty. @@ -78,30 +65,17 @@ def validate_name(cls, name: t.Optional[str]) -> str: return name def to_config(self, study_version: t.Union[str, int]) -> ThermalConfigType: - values = self.dict(by_alias=False, exclude_none=True) + values = self.model_dump(by_alias=False, exclude_none=True) return create_thermal_config(study_version=study_version, **values) +@all_optional_model @camel_case_model -class ThermalClusterOutput(Thermal870Config, metaclass=AllOptionalMetaclass, use_none=True): +class ThermalClusterOutput(Thermal870Config): """ Model representing the output data structure to display the details of a thermal cluster within a study. """ - class Config: - @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: - schema["example"] = ThermalClusterOutput( - id="Gas cluster YZ", - group="Gas", - name="Gas Cluster YZ", - enabled=False, - unitCount=100, - nominalCapacity=1000.0, - genTs="use global", - co2=7.0, - ) - def create_thermal_output( study_version: t.Union[str, int], @@ -109,7 +83,7 @@ def create_thermal_output( config: t.Mapping[str, t.Any], ) -> "ThermalClusterOutput": obj = create_thermal_config(study_version=study_version, **config, id=cluster_id) - kwargs = obj.dict(by_alias=False) + kwargs = obj.model_dump(by_alias=False) return ThermalClusterOutput(**kwargs) @@ -240,15 +214,17 @@ def update_thermals_props( for thermal_id, update_cluster in update_thermals_by_ids.items(): # Update the thermal cluster properties. old_cluster = old_thermals_by_ids[thermal_id] - new_cluster = old_cluster.copy(update=update_cluster.dict(by_alias=False, exclude_none=True)) + new_cluster = old_cluster.copy(update=update_cluster.model_dump(by_alias=False, exclude_none=True)) new_thermals_by_areas[area_id][thermal_id] = new_cluster # Convert the DTO to a configuration object and update the configuration file. - properties = create_thermal_config(study.version, **new_cluster.dict(by_alias=False, exclude_none=True)) + properties = create_thermal_config( + study.version, **new_cluster.model_dump(by_alias=False, exclude_none=True) + ) path = _CLUSTER_PATH.format(area_id=area_id, cluster_id=thermal_id) cmd = UpdateConfig( target=path, - data=json.loads(properties.json(by_alias=True, exclude={"id"})), + data=properties.model_dump(mode="json", by_alias=True, exclude={"id"}), command_context=self.storage_service.variant_study_service.command_factory.command_context, ) commands.append(cmd) @@ -290,12 +266,13 @@ def create_cluster(self, study: Study, area_id: str, cluster_data: ThermalCluste def _make_create_cluster_cmd(self, area_id: str, cluster: ThermalConfigType) -> CreateCluster: # NOTE: currently, in the `CreateCluster` class, there is a confusion # between the cluster name and the cluster ID (which is a section name). - command = CreateCluster( - area_id=area_id, - cluster_name=cluster.id, - parameters=cluster.dict(by_alias=True, exclude={"id"}), - command_context=self.storage_service.variant_study_service.command_factory.command_context, - ) + args = { + "area_id": area_id, + "cluster_name": cluster.id, + "parameters": cluster.model_dump(mode="json", by_alias=True, exclude={"id"}), + "command_context": self.storage_service.variant_study_service.command_factory.command_context, + } + command = CreateCluster.model_validate(args) return command def update_cluster( @@ -334,16 +311,16 @@ def update_cluster( old_config = create_thermal_config(study_version, **values) # Use Python values to synchronize Config and Form values - new_values = cluster_data.dict(by_alias=False, exclude_none=True) + new_values = cluster_data.model_dump(by_alias=False, exclude_none=True) new_config = old_config.copy(exclude={"id"}, update=new_values) - new_data = json.loads(new_config.json(by_alias=True, exclude={"id"})) + new_data = new_config.model_dump(mode="json", by_alias=True, exclude={"id"}) # create the dict containing the new values using aliases - data: t.Dict[str, t.Any] = { - field.alias: new_data[field.alias] - for field_name, field in new_config.__fields__.items() - if field_name in new_values - } + data: t.Dict[str, t.Any] = {} + for field_name, field in new_config.model_fields.items(): + if field_name in new_values: + name = field.alias if field.alias else field_name + data[name] = new_data[name] # create the update config commands with the modified data command_context = self.storage_service.variant_study_service.command_factory.command_context @@ -353,8 +330,8 @@ def update_cluster( ] execute_or_add_commands(study, file_study, commands, self.storage_service) - values = new_config.dict(by_alias=False) - return ThermalClusterOutput(**values, id=cluster_id) + values = {**new_config.model_dump(mode="json", by_alias=False), "id": cluster_id} + return ThermalClusterOutput.model_validate(values) def delete_clusters(self, study: Study, area_id: str, cluster_ids: t.Sequence[str]) -> None: """ @@ -406,7 +383,7 @@ def duplicate_cluster( # Cluster duplication source_cluster = self.get_cluster(study, area_id, source_id) source_cluster.name = new_cluster_name - creation_form = ThermalClusterCreation(**source_cluster.dict(by_alias=False, exclude={"id"})) + creation_form = ThermalClusterCreation(**source_cluster.model_dump(by_alias=False, exclude={"id"})) new_config = creation_form.to_config(study.version) create_cluster_cmd = self._make_create_cluster_cmd(area_id, new_config) @@ -439,7 +416,7 @@ def duplicate_cluster( execute_or_add_commands(study, self._get_file_study(study), commands, self.storage_service) - return ThermalClusterOutput(**new_config.dict(by_alias=False)) + return ThermalClusterOutput(**new_config.model_dump(by_alias=False)) def validate_series(self, study: Study, area_id: str, cluster_id: str) -> bool: lower_cluster_id = cluster_id.lower() diff --git a/antarest/study/business/binding_constraint_management.py b/antarest/study/business/binding_constraint_management.py index c2106ef2f3..c898704818 100644 --- a/antarest/study/business/binding_constraint_management.py +++ b/antarest/study/business/binding_constraint_management.py @@ -1,11 +1,9 @@ import collections -import json import logging import typing as t import numpy as np -from pydantic import BaseModel, Field, root_validator, validator -from requests.utils import CaseInsensitiveDict +from pydantic import BaseModel, Field, field_validator, model_validator from antarest.core.exceptions import ( BindingConstraintNotFound, @@ -19,6 +17,7 @@ WrongMatrixHeightError, ) from antarest.core.model import JSON +from antarest.core.requests import CaseInsensitiveDict from antarest.core.utils.string import to_camel_case from antarest.study.business.all_optional_meta import camel_case_model from antarest.study.business.utils import execute_or_add_commands @@ -117,12 +116,12 @@ class ConstraintTerm(BaseModel): data: the constraint term data (link or cluster), if any. """ - id: t.Optional[str] - weight: t.Optional[float] - offset: t.Optional[int] - data: t.Optional[t.Union[LinkTerm, ClusterTerm]] + id: t.Optional[str] = None + weight: t.Optional[float] = None + offset: t.Optional[int] = None + data: t.Optional[t.Union[LinkTerm, ClusterTerm]] = None - @validator("id") + @field_validator("id") def id_to_lower(cls, v: t.Optional[str]) -> t.Optional[str]: """Ensure the ID is lower case.""" if v is None: @@ -253,7 +252,7 @@ class ConstraintInput(BindingConstraintMatrices, ConstraintInput870): class ConstraintCreation(ConstraintInput): name: str - @root_validator(pre=True) + @model_validator(mode="before") def check_matrices_dimensions(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: for _key in ["time_step"] + [m.value for m in TermMatrices]: _camel = to_camel_case(_key) @@ -304,20 +303,19 @@ def check_matrices_dimensions(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t. raise ValueError(err_msg) -@camel_case_model class ConstraintOutputBase(BindingConstraintPropertiesBase): id: str name: str terms: t.MutableSequence[ConstraintTerm] = Field(default_factory=lambda: []) + # I have to redefine the time_step attribute to give him another alias. + time_step: t.Optional[BindingConstraintFrequency] = Field(DEFAULT_TIMESTEP, alias="timeStep") # type: ignore -@camel_case_model class ConstraintOutput830(ConstraintOutputBase): - filter_year_by_year: str = "" - filter_synthesis: str = "" + filter_year_by_year: str = Field(default="", alias="filterYearByYear") + filter_synthesis: str = Field(default="", alias="filterSynthesis") -@camel_case_model class ConstraintOutput870(ConstraintOutput830): group: str = DEFAULT_GROUP @@ -349,7 +347,7 @@ def _get_references_by_widths( references_by_width: t.Dict[int, t.List[t.Tuple[str, str]]] = {} _total = len(bcs) for _index, bc in enumerate(bcs): - matrices_name = operator_matrix_file_map[bc.operator] if file_study.config.version >= 870 else ["{bc_id}"] + matrices_name = operator_matrix_file_map[bc.operator] if file_study.config.version >= 870 else ["{bc_id}"] # type: ignore for matrix_name in matrices_name: matrix_id = matrix_name.format(bc_id=bc.id) logger.info(f"⏲ Validating BC '{bc.id}': {matrix_id=} [{_index+1}/{_total}]") @@ -360,7 +358,7 @@ def _get_references_by_widths( continue matrix_height = matrix.shape[0] - expected_height = EXPECTED_MATRIX_SHAPES[bc.time_step][0] + expected_height = EXPECTED_MATRIX_SHAPES[bc.time_step][0] # type: ignore if matrix_height != expected_height: raise WrongMatrixHeightError( f"The binding constraint '{bc.name}' should have {expected_height} rows, currently: {matrix_height}" @@ -431,17 +429,22 @@ def parse_and_add_terms(key: str, value: t.Any, adapted_constraint: ConstraintOu id=key, weight=weight, offset=offset, - data={ - "area1": term_data[0], - "area2": term_data[1], - }, + data=LinkTerm.model_validate( + { + "area1": term_data[0], + "area2": term_data[1], + } + ), ) ) # Cluster term else: adapted_constraint.terms.append( ConstraintTerm( - id=key, weight=weight, offset=offset, data={"area": term_data[0], "cluster": term_data[1]} + id=key, + weight=weight, + offset=offset, + data=ClusterTerm.model_validate({"area": term_data[0], "cluster": term_data[1]}), ) ) @@ -589,7 +592,7 @@ def get_grouped_constraints(self, study: Study) -> t.Mapping[str, t.Sequence[Con storage_service = self.storage_service.get_storage(study) file_study = storage_service.get_raw(study) config = file_study.tree.get(["input", "bindingconstraints", "bindingconstraints"]) - grouped_constraints = CaseInsensitiveDict() # type: ignore + grouped_constraints = CaseInsensitiveDict() for constraint in config.values(): constraint_config = self.constraint_model_adapter(constraint, int(study.version)) @@ -698,7 +701,10 @@ def create_binding_constraint( check_attributes_coherence(data, version, data.operator or DEFAULT_OPERATOR) - new_constraint = {"name": data.name, **json.loads(data.json(exclude={"terms", "name"}, exclude_none=True))} + new_constraint = { + "name": data.name, + **data.model_dump(mode="json", exclude={"terms", "name"}, exclude_none=True), + } args = { **new_constraint, "command_context": self.storage_service.variant_study_service.command_factory.command_context, @@ -732,11 +738,11 @@ def update_binding_constraint( existing_constraint = self.get_binding_constraint(study, binding_constraint_id) study_version = int(study.version) - check_attributes_coherence(data, study_version, data.operator or existing_constraint.operator) + check_attributes_coherence(data, study_version, data.operator or existing_constraint.operator) # type: ignore upd_constraint = { "id": binding_constraint_id, - **json.loads(data.json(exclude={"terms", "name"}, exclude_none=True)), + **data.model_dump(mode="json", exclude={"terms", "name"}, exclude_none=True), } args = { **upd_constraint, @@ -756,7 +762,7 @@ def update_binding_constraint( updated_matrices = [term for term in [m.value for m in TermMatrices] if getattr(data, term)] time_step = data.time_step or existing_constraint.time_step command.validates_and_fills_matrices( - time_step=time_step, specific_matrices=updated_matrices, version=study_version, create=False + time_step=time_step, specific_matrices=updated_matrices, version=study_version, create=False # type: ignore ) execute_or_add_commands(study, file_study, [command], self.storage_service) @@ -822,11 +828,9 @@ def _update_constraint_with_terms( coeffs = { term_id: [term.weight, term.offset] if term.offset else [term.weight] for term_id, term in terms.items() } - command = UpdateBindingConstraint( - id=bc.id, - coeffs=coeffs, - command_context=self.storage_service.variant_study_service.command_factory.command_context, - ) + command_context = self.storage_service.variant_study_service.command_factory.command_context + args = {"id": bc.id, "coeffs": coeffs, "command_context": command_context} + command = UpdateBindingConstraint.model_validate(args) file_study = self.storage_service.get_storage(study).get_raw(study) execute_or_add_commands(study, file_study, [command], self.storage_service) @@ -849,7 +853,7 @@ def update_constraint_terms( if update_mode == "add": for term in constraint_terms: if term.data is None: - raise InvalidConstraintTerm(binding_constraint_id, term.json()) + raise InvalidConstraintTerm(binding_constraint_id, term.model_dump_json()) constraint = self.get_binding_constraint(study, binding_constraint_id) existing_terms = collections.OrderedDict((term.generate_id(), term) for term in constraint.terms) diff --git a/antarest/study/business/config_management.py b/antarest/study/business/config_management.py index 35f46c01ca..9c57502197 100644 --- a/antarest/study/business/config_management.py +++ b/antarest/study/business/config_management.py @@ -30,7 +30,7 @@ def set_playlist( file_study = self.storage_service.get_storage(study).get_raw(study) command = UpdatePlaylist( items=playlist, - weights=weights, + weights=weights, # type: ignore reverse=reverse, active=active, command_context=self.storage_service.variant_study_service.command_factory.command_context, diff --git a/antarest/study/business/correlation_management.py b/antarest/study/business/correlation_management.py index b9abcff2f2..f30b147bb3 100644 --- a/antarest/study/business/correlation_management.py +++ b/antarest/study/business/correlation_management.py @@ -3,11 +3,11 @@ The generators are of the same category and can be hydraulic, wind, load or solar. """ import collections -from typing import Dict, List, Sequence +from typing import Dict, List, Sequence, Union import numpy as np import numpy.typing as npt -from pydantic import conlist, validator +from pydantic import ValidationInfo, conlist, field_validator from antarest.core.exceptions import AreaNotFound from antarest.study.business.area_management import AreaInfoDTO @@ -28,7 +28,7 @@ class AreaCoefficientItem(FormFieldsBaseModel): """ class Config: - allow_population_by_field_name = True + populate_by_name = True area_id: str coefficient: float @@ -45,7 +45,7 @@ class CorrelationFormFields(FormFieldsBaseModel): correlation: List[AreaCoefficientItem] # noinspection PyMethodParameters - @validator("correlation") + @field_validator("correlation") def check_correlation(cls, correlation: List[AreaCoefficientItem]) -> List[AreaCoefficientItem]: if not correlation: raise ValueError("correlation must not be empty") @@ -72,13 +72,15 @@ class CorrelationMatrix(FormFieldsBaseModel): data: A 2D-array matrix of correlation coefficients. """ - index: conlist(str, min_items=1) # type: ignore - columns: conlist(str, min_items=1) # type: ignore + index: conlist(str, min_length=1) # type: ignore + columns: conlist(str, min_length=1) # type: ignore data: List[List[float]] # NonNegativeFloat not necessary # noinspection PyMethodParameters - @validator("data") - def validate_correlation_matrix(cls, data: List[List[float]], values: Dict[str, List[str]]) -> List[List[float]]: + @field_validator("data", mode="before") + def validate_correlation_matrix( + cls, data: List[List[float]], values: Union[Dict[str, List[str]], ValidationInfo] + ) -> List[List[float]]: """ Validates the correlation matrix by checking its shape and range of coefficients. @@ -100,8 +102,9 @@ def validate_correlation_matrix(cls, data: List[List[float]], values: Dict[str, """ array = np.array(data) - rows = len(values.get("index", [])) - cols = len(values.get("columns", [])) + new_values = values if isinstance(values, dict) else values.data + rows = len(new_values.get("index", [])) + cols = len(new_values.get("columns", [])) if array.size == 0: raise ValueError("correlation matrix must not be empty") @@ -116,20 +119,6 @@ def validate_correlation_matrix(cls, data: List[List[float]], values: Dict[str, return data - class Config: - schema_extra = { - "example": { - "columns": ["north", "east", "south", "west"], - "data": [ - [0.0, 0.0, 0.25, 0.0], - [0.0, 0.0, 0.75, 0.12], - [0.25, 0.75, 0.0, 0.75], - [0.0, 0.12, 0.75, 0.0], - ], - "index": ["north", "east", "south", "west"], - } - } - def _config_to_array( area_ids: Sequence[str], diff --git a/antarest/study/business/district_manager.py b/antarest/study/business/district_manager.py index 5a214c284c..c31d3d1def 100644 --- a/antarest/study/business/district_manager.py +++ b/antarest/study/business/district_manager.py @@ -57,16 +57,19 @@ def get_districts(self, study: Study) -> List[DistrictInfoDTO]: """ file_study = self.storage_service.get_storage(study).get_raw(study) all_areas = list(file_study.config.areas) - return [ - DistrictInfoDTO( - id=district_id, - name=district.name, - areas=district.get_areas(all_areas), - output=district.output, - comments=file_study.tree.get(["input", "areas", "sets", district_id]).get("comments", ""), + districts = [] + for district_id, district in file_study.config.sets.items(): + assert district.name is not None + districts.append( + DistrictInfoDTO( + id=district_id, + name=district.name, + areas=district.get_areas(all_areas), + output=district.output, + comments=file_study.tree.get(["input", "areas", "sets", district_id]).get("comments", ""), + ) ) - for district_id, district in file_study.config.sets.items() - ] + return districts def create_district( self, @@ -136,14 +139,14 @@ def update_district( file_study = self.storage_service.get_storage(study).get_raw(study) if district_id not in file_study.config.sets: raise DistrictNotFound(district_id) - areas = frozenset(dto.areas or []) - all_areas = frozenset(file_study.config.areas) + areas = set(dto.areas or []) + all_areas = set(file_study.config.areas) if invalid_areas := areas - all_areas: raise AreaNotFound(*invalid_areas) command = UpdateDistrict( id=district_id, base_filter=DistrictBaseFilter.remove_all, - filter_items=areas, + filter_items=dto.areas or [], output=dto.output, comments=dto.comments, command_context=self.storage_service.variant_study_service.command_factory.command_context, diff --git a/antarest/study/business/general_management.py b/antarest/study/business/general_management.py index 4cf2dc62e0..58e9743cdc 100644 --- a/antarest/study/business/general_management.py +++ b/antarest/study/business/general_management.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, List, Union, cast -from pydantic import PositiveInt, StrictBool, conint, root_validator +from pydantic import PositiveInt, StrictBool, ValidationInfo, conint, model_validator +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study @@ -51,39 +52,41 @@ class BuildingMode(EnumIgnoreCase): DayNumberType = conint(ge=1, le=366) +@all_optional_model class GeneralFormFields(FormFieldsBaseModel): - mode: Optional[Mode] - first_day: Optional[DayNumberType] # type: ignore - last_day: Optional[DayNumberType] # type: ignore - horizon: Optional[str] # Don't use `StrictStr` because it can be an int - first_month: Optional[Month] - first_week_day: Optional[WeekDay] - first_january: Optional[WeekDay] - leap_year: Optional[StrictBool] - nb_years: Optional[PositiveInt] - building_mode: Optional[BuildingMode] - selection_mode: Optional[StrictBool] - year_by_year: Optional[StrictBool] - simulation_synthesis: Optional[StrictBool] - mc_scenario: Optional[StrictBool] + mode: Mode + first_day: DayNumberType # type: ignore + last_day: DayNumberType # type: ignore + horizon: Union[str, int] + first_month: Month + first_week_day: WeekDay + first_january: WeekDay + leap_year: StrictBool + nb_years: PositiveInt + building_mode: BuildingMode + selection_mode: StrictBool + year_by_year: StrictBool + simulation_synthesis: StrictBool + mc_scenario: StrictBool # Geographic trimming + Thematic trimming. # For study versions < 710 - filtering: Optional[StrictBool] + filtering: StrictBool # For study versions >= 710 - geographic_trimming: Optional[StrictBool] - thematic_trimming: Optional[StrictBool] - - @root_validator - def day_fields_validation(cls, values: Dict[str, Any]) -> Dict[str, Any]: - first_day = values.get("first_day") - last_day = values.get("last_day") - leap_year = values.get("leap_year") + geographic_trimming: StrictBool + thematic_trimming: StrictBool + + @model_validator(mode="before") + def day_fields_validation(cls, values: Union[Dict[str, Any], ValidationInfo]) -> Dict[str, Any]: + new_values = values if isinstance(values, dict) else values.data + first_day = new_values.get("first_day") + last_day = new_values.get("last_day") + leap_year = new_values.get("leap_year") day_fields = [first_day, last_day, leap_year] if all(v is None for v in day_fields): # The user wishes to update another field than these three. # no need to validate anything: - return values + return new_values if any(v is None for v in day_fields): raise ValueError("First day, last day and leap year fields must be defined together") @@ -98,7 +101,7 @@ def day_fields_validation(cls, values: Dict[str, Any]) -> Dict[str, Any]: if last_day > num_days_in_year: raise ValueError(f"Last day cannot be greater than {num_days_in_year}") - return values + return new_values GENERAL = "general" diff --git a/antarest/study/business/link_management.py b/antarest/study/business/link_management.py index 744401772a..33395186db 100644 --- a/antarest/study/business/link_management.py +++ b/antarest/study/business/link_management.py @@ -4,7 +4,7 @@ from antarest.core.exceptions import ConfigFileNotFound from antarest.core.model import JSON -from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import RawStudy from antarest.study.storage.rawstudy.model.filesystem.config.links import LinkProperties @@ -28,8 +28,9 @@ class LinkInfoDTO(BaseModel): ui: t.Optional[LinkUIDTO] = None +@all_optional_model @camel_case_model -class LinkOutput(LinkProperties, metaclass=AllOptionalMetaclass, use_none=True): +class LinkOutput(LinkProperties): """ DTO object use to get the link information. """ @@ -109,7 +110,7 @@ def get_all_links_props(self, study: RawStudy) -> t.Mapping[t.Tuple[str, str], L for area2_id, properties_cfg in property_map.items(): area1_id, area2_id = sorted([area1_id, area2_id]) properties = LinkProperties(**properties_cfg) - links_by_ids[(area1_id, area2_id)] = LinkOutput(**properties.dict(by_alias=False)) + links_by_ids[(area1_id, area2_id)] = LinkOutput(**properties.model_dump(by_alias=False)) return links_by_ids @@ -125,11 +126,11 @@ def update_links_props( for (area1, area2), update_link_dto in update_links_by_ids.items(): # Update the link properties. old_link_dto = old_links_by_ids[(area1, area2)] - new_link_dto = old_link_dto.copy(update=update_link_dto.dict(by_alias=False, exclude_none=True)) + new_link_dto = old_link_dto.copy(update=update_link_dto.model_dump(by_alias=False, exclude_none=True)) new_links_by_ids[(area1, area2)] = new_link_dto # Convert the DTO to a configuration object and update the configuration file. - properties = LinkProperties(**new_link_dto.dict(by_alias=False)) + properties = LinkProperties(**new_link_dto.model_dump(by_alias=False)) path = f"{_ALL_LINKS_PATH}/{area1}/properties/{area2}" cmd = UpdateConfig( target=path, diff --git a/antarest/study/business/optimization_management.py b/antarest/study/business/optimization_management.py index 5defcc2a2a..2f1f511d98 100644 --- a/antarest/study/business/optimization_management.py +++ b/antarest/study/business/optimization_management.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Union, cast from pydantic.types import StrictBool +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study @@ -33,24 +34,20 @@ class SimplexOptimizationRange(EnumIgnoreCase): WEEK = "week" +@all_optional_model class OptimizationFormFields(FormFieldsBaseModel): - binding_constraints: Optional[StrictBool] - hurdle_costs: Optional[StrictBool] - transmission_capacities: Optional[ - Union[ - StrictBool, - Union[LegacyTransmissionCapacities, TransmissionCapacities], - ] - ] - thermal_clusters_min_stable_power: Optional[StrictBool] - thermal_clusters_min_ud_time: Optional[StrictBool] - day_ahead_reserve: Optional[StrictBool] - primary_reserve: Optional[StrictBool] - strategic_reserve: Optional[StrictBool] - spinning_reserve: Optional[StrictBool] - export_mps: Optional[Union[bool, str]] - unfeasible_problem_behavior: Optional[UnfeasibleProblemBehavior] - simplex_optimization_range: Optional[SimplexOptimizationRange] + binding_constraints: StrictBool + hurdle_costs: StrictBool + transmission_capacities: Union[StrictBool, LegacyTransmissionCapacities, TransmissionCapacities] + thermal_clusters_min_stable_power: StrictBool + thermal_clusters_min_ud_time: StrictBool + day_ahead_reserve: StrictBool + primary_reserve: StrictBool + strategic_reserve: StrictBool + spinning_reserve: StrictBool + export_mps: Union[bool, str] + unfeasible_problem_behavior: UnfeasibleProblemBehavior + simplex_optimization_range: SimplexOptimizationRange OPTIMIZATION_PATH = f"{GENERAL_DATA_PATH}/optimization" diff --git a/antarest/study/business/table_mode_management.py b/antarest/study/business/table_mode_management.py index bc31683139..d15fa660b0 100644 --- a/antarest/study/business/table_mode_management.py +++ b/antarest/study/business/table_mode_management.py @@ -83,36 +83,37 @@ def __init__( def _get_table_data_unsafe(self, study: RawStudy, table_type: TableModeType) -> TableDataDTO: if table_type == TableModeType.AREA: areas_map = self._area_manager.get_all_area_props(study) - data = {area_id: area.dict(by_alias=True) for area_id, area in areas_map.items()} + data = {area_id: area.model_dump(by_alias=True) for area_id, area in areas_map.items()} elif table_type == TableModeType.LINK: links_map = self._link_manager.get_all_links_props(study) data = { - f"{area1_id} / {area2_id}": link.dict(by_alias=True) for (area1_id, area2_id), link in links_map.items() + f"{area1_id} / {area2_id}": link.model_dump(by_alias=True) + for (area1_id, area2_id), link in links_map.items() } elif table_type == TableModeType.THERMAL: thermals_by_areas = self._thermal_manager.get_all_thermals_props(study) data = { - f"{area_id} / {cluster_id}": cluster.dict(by_alias=True, exclude={"id", "name"}) + f"{area_id} / {cluster_id}": cluster.model_dump(by_alias=True, exclude={"id", "name"}) for area_id, thermals_by_ids in thermals_by_areas.items() for cluster_id, cluster in thermals_by_ids.items() } elif table_type == TableModeType.RENEWABLE: renewables_by_areas = self._renewable_manager.get_all_renewables_props(study) data = { - f"{area_id} / {cluster_id}": cluster.dict(by_alias=True, exclude={"id", "name"}) + f"{area_id} / {cluster_id}": cluster.model_dump(by_alias=True, exclude={"id", "name"}) for area_id, renewables_by_ids in renewables_by_areas.items() for cluster_id, cluster in renewables_by_ids.items() } elif table_type == TableModeType.ST_STORAGE: storages_by_areas = self._st_storage_manager.get_all_storages_props(study) data = { - f"{area_id} / {cluster_id}": cluster.dict(by_alias=True, exclude={"id", "name"}) + f"{area_id} / {cluster_id}": cluster.model_dump(by_alias=True, exclude={"id", "name"}) for area_id, storages_by_ids in storages_by_areas.items() for cluster_id, cluster in storages_by_ids.items() } elif table_type == TableModeType.BINDING_CONSTRAINT: bc_seq = self._binding_constraint_manager.get_binding_constraints(study) - data = {bc.id: bc.dict(by_alias=True, exclude={"id", "name", "terms"}) for bc in bc_seq} + data = {bc.id: bc.model_dump(by_alias=True, exclude={"id", "name", "terms"}) for bc in bc_seq} else: # pragma: no cover raise NotImplementedError(f"Table type {table_type} not implemented") return data @@ -177,13 +178,13 @@ def update_table_data( # Use AreaOutput to update properties of areas, which may include `None` values area_props_by_ids = {key: AreaOutput(**values) for key, values in data.items()} areas_map = self._area_manager.update_areas_props(study, area_props_by_ids) - data = {area_id: area.dict(by_alias=True, exclude_none=True) for area_id, area in areas_map.items()} + data = {area_id: area.model_dump(by_alias=True, exclude_none=True) for area_id, area in areas_map.items()} return data elif table_type == TableModeType.LINK: links_map = {tuple(key.split(" / ")): LinkOutput(**values) for key, values in data.items()} updated_map = self._link_manager.update_links_props(study, links_map) # type: ignore data = { - f"{area1_id} / {area2_id}": link.dict(by_alias=True) + f"{area1_id} / {area2_id}": link.model_dump(by_alias=True) for (area1_id, area2_id), link in updated_map.items() } return data @@ -195,7 +196,7 @@ def update_table_data( thermals_by_areas[area_id][cluster_id] = ThermalClusterInput(**values) thermals_map = self._thermal_manager.update_thermals_props(study, thermals_by_areas) data = { - f"{area_id} / {cluster_id}": cluster.dict(by_alias=True, exclude={"id", "name"}) + f"{area_id} / {cluster_id}": cluster.model_dump(by_alias=True, exclude={"id", "name"}) for area_id, thermals_by_ids in thermals_map.items() for cluster_id, cluster in thermals_by_ids.items() } @@ -208,7 +209,7 @@ def update_table_data( renewables_by_areas[area_id][cluster_id] = RenewableClusterInput(**values) renewables_map = self._renewable_manager.update_renewables_props(study, renewables_by_areas) data = { - f"{area_id} / {cluster_id}": cluster.dict(by_alias=True, exclude={"id", "name"}) + f"{area_id} / {cluster_id}": cluster.model_dump(by_alias=True, exclude={"id", "name"}) for area_id, renewables_by_ids in renewables_map.items() for cluster_id, cluster in renewables_by_ids.items() } @@ -221,7 +222,7 @@ def update_table_data( storages_by_areas[area_id][cluster_id] = STStorageInput(**values) storages_map = self._st_storage_manager.update_storages_props(study, storages_by_areas) data = { - f"{area_id} / {cluster_id}": cluster.dict(by_alias=True, exclude={"id", "name"}) + f"{area_id} / {cluster_id}": cluster.model_dump(by_alias=True, exclude={"id", "name"}) for area_id, storages_by_ids in storages_map.items() for cluster_id, cluster in storages_by_ids.items() } @@ -229,7 +230,9 @@ def update_table_data( elif table_type == TableModeType.BINDING_CONSTRAINT: bcs_by_ids = {key: ConstraintInput(**values) for key, values in data.items()} bcs_map = self._binding_constraint_manager.update_binding_constraints(study, bcs_by_ids) - return {bc_id: bc.dict(by_alias=True, exclude={"id", "name", "terms"}) for bc_id, bc in bcs_map.items()} + return { + bc_id: bc.model_dump(by_alias=True, exclude={"id", "name", "terms"}) for bc_id, bc in bcs_map.items() + } else: # pragma: no cover raise NotImplementedError(f"Table type {table_type} not implemented") diff --git a/antarest/study/business/thematic_trimming_field_infos.py b/antarest/study/business/thematic_trimming_field_infos.py index 3baabd8014..b142b3d72b 100644 --- a/antarest/study/business/thematic_trimming_field_infos.py +++ b/antarest/study/business/thematic_trimming_field_infos.py @@ -4,11 +4,12 @@ import typing as t -from antarest.study.business.all_optional_meta import AllOptionalMetaclass +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.utils import FormFieldsBaseModel -class ThematicTrimmingFormFields(FormFieldsBaseModel, metaclass=AllOptionalMetaclass, use_none=True): +@all_optional_model +class ThematicTrimmingFormFields(FormFieldsBaseModel): """ This class manages the configuration of result filtering in a simulation. @@ -225,6 +226,5 @@ class ThematicTrimmingFormFields(FormFieldsBaseModel, metaclass=AllOptionalMetac } -def get_fields_info(study_version: t.Union[str, int]) -> t.Mapping[str, t.Mapping[str, t.Any]]: - study_version = int(study_version) +def get_fields_info(study_version: int) -> t.Mapping[str, t.Mapping[str, t.Any]]: return {key: info for key, info in FIELDS_INFO.items() if (info.get("start_version") or 0) <= study_version} diff --git a/antarest/study/business/thematic_trimming_management.py b/antarest/study/business/thematic_trimming_management.py index d4af9f960e..85811935d9 100644 --- a/antarest/study/business/thematic_trimming_management.py +++ b/antarest/study/business/thematic_trimming_management.py @@ -37,7 +37,7 @@ def set_field_values(self, study: Study, field_values: ThematicTrimmingFormField Set Thematic Trimming config from the webapp form """ file_study = self.storage_service.get_storage(study).get_raw(study) - field_values_dict = field_values.dict() + field_values_dict = field_values.model_dump() keys_by_bool: t.Dict[bool, t.List[t.Any]] = {True: [], False: []} fields_info = get_fields_info(int(study.version)) diff --git a/antarest/study/business/timeseries_config_management.py b/antarest/study/business/timeseries_config_management.py index 56cf07cfa1..b668a93aa0 100644 --- a/antarest/study/business/timeseries_config_management.py +++ b/antarest/study/business/timeseries_config_management.py @@ -1,8 +1,9 @@ import typing as t -from pydantic import StrictBool, StrictInt, root_validator, validator +from pydantic import StrictBool, StrictInt, field_validator, model_validator from antarest.core.model import JSON +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.business.utils import GENERAL_DATA_PATH, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study @@ -27,28 +28,30 @@ class SeasonCorrelation(EnumIgnoreCase): ANNUAL = "annual" +@all_optional_model class TSFormFieldsForType(FormFieldsBaseModel): - stochastic_ts_status: t.Optional[StrictBool] - number: t.Optional[StrictInt] - refresh: t.Optional[StrictBool] - refresh_interval: t.Optional[StrictInt] - season_correlation: t.Optional[SeasonCorrelation] - store_in_input: t.Optional[StrictBool] - store_in_output: t.Optional[StrictBool] - intra_modal: t.Optional[StrictBool] - inter_modal: t.Optional[StrictBool] - - + stochastic_ts_status: StrictBool + number: StrictInt + refresh: StrictBool + refresh_interval: StrictInt + season_correlation: SeasonCorrelation + store_in_input: StrictBool + store_in_output: StrictBool + intra_modal: StrictBool + inter_modal: StrictBool + + +@all_optional_model class TSFormFields(FormFieldsBaseModel): - load: t.Optional[TSFormFieldsForType] = None - hydro: t.Optional[TSFormFieldsForType] = None - thermal: t.Optional[TSFormFieldsForType] = None - wind: t.Optional[TSFormFieldsForType] = None - solar: t.Optional[TSFormFieldsForType] = None - renewables: t.Optional[TSFormFieldsForType] = None - ntc: t.Optional[TSFormFieldsForType] = None - - @root_validator(pre=True) + load: TSFormFieldsForType + hydro: TSFormFieldsForType + thermal: TSFormFieldsForType + wind: TSFormFieldsForType + solar: TSFormFieldsForType + renewables: TSFormFieldsForType + ntc: TSFormFieldsForType + + @model_validator(mode="before") def check_type_validity( cls, values: t.Dict[str, t.Optional[TSFormFieldsForType]] ) -> t.Dict[str, t.Optional[TSFormFieldsForType]]: @@ -61,7 +64,7 @@ def has_type(ts_type: TSType) -> bool: ) return values - @validator("thermal") + @field_validator("thermal") def thermal_validation(cls, v: TSFormFieldsForType) -> TSFormFieldsForType: if v.season_correlation is not None: raise ValueError("season_correlation is not allowed for 'thermal' type") @@ -118,7 +121,7 @@ def __set_field_values_for_type( field_values: TSFormFieldsForType, ) -> None: commands: t.List[UpdateConfig] = [] - values = field_values.dict() + values = field_values.model_dump() for field, path in PATH_BY_TS_STR_FIELD.items(): field_val = values[field] diff --git a/antarest/study/business/utils.py b/antarest/study/business/utils.py index 8c4b567b22..e08aa505d9 100644 --- a/antarest/study/business/utils.py +++ b/antarest/study/business/utils.py @@ -5,7 +5,7 @@ from antarest.core.exceptions import CommandApplicationError from antarest.core.jwt import DEFAULT_ADMIN_USER from antarest.core.requests import RequestParameters -from antarest.core.utils.string import to_camel_case +from antarest.study.business.all_optional_meta import camel_case_model from antarest.study.model import RawStudy, Study from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy from antarest.study.storage.storage_service import StudyStorageService @@ -57,12 +57,12 @@ def execute_or_add_commands( ) +@camel_case_model class FormFieldsBaseModel( BaseModel, - alias_generator=to_camel_case, extra="forbid", validate_assignment=True, - allow_population_by_field_name=True, + populate_by_name=True, ): """ Pydantic Model for webapp form diff --git a/antarest/study/business/xpansion_management.py b/antarest/study/business/xpansion_management.py index 66d25860dd..7a6deff54a 100644 --- a/antarest/study/business/xpansion_management.py +++ b/antarest/study/business/xpansion_management.py @@ -7,11 +7,11 @@ import zipfile from fastapi import HTTPException, UploadFile -from pydantic import BaseModel, Extra, Field, ValidationError, root_validator, validator +from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator from antarest.core.exceptions import BadZipBinary, ChildNotFoundError from antarest.core.model import JSON -from antarest.study.business.all_optional_meta import AllOptionalMetaclass +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.model import Study from antarest.study.storage.rawstudy.model.filesystem.bucket_node import BucketNode @@ -62,12 +62,12 @@ class XpansionSensitivitySettings(BaseModel): projection: t.List[str] = Field(default_factory=list, description="List of candidate names to project") capex: bool = Field(default=False, description="Whether to include capex in the sensitivity analysis") - @validator("projection", pre=True) + @field_validator("projection", mode="before") def projection_validation(cls, v: t.Optional[t.Sequence[str]]) -> t.Sequence[str]: return [] if v is None else v -class XpansionSettings(BaseModel, extra=Extra.ignore, validate_assignment=True, allow_population_by_field_name=True): +class XpansionSettings(BaseModel, extra="ignore", validate_assignment=True, populate_by_name=True): """ A data transfer object representing the general settings used for Xpansion. @@ -140,8 +140,8 @@ class XpansionSettings(BaseModel, extra=Extra.ignore, validate_assignment=True, # The sensitivity analysis is optional sensitivity_config: t.Optional[XpansionSensitivitySettings] = None - @root_validator(pre=True) - def normalize_values(cls, values: t.MutableMapping[str, t.Any]) -> t.MutableMapping[str, t.Any]: + @model_validator(mode="before") + def validate_float_values(cls, values: t.MutableMapping[str, t.Any]) -> t.MutableMapping[str, t.Any]: if "relaxed-optimality-gap" in values: values["relaxed_optimality_gap"] = values.pop("relaxed-optimality-gap") @@ -196,7 +196,8 @@ def from_config(cls, config_obj: JSON) -> "GetXpansionSettings": return cls.construct(**config_obj) -class UpdateXpansionSettings(XpansionSettings, metaclass=AllOptionalMetaclass, use_none=True): +@all_optional_model +class UpdateXpansionSettings(XpansionSettings): """ DTO object used to update the Xpansion settings. @@ -207,13 +208,6 @@ class UpdateXpansionSettings(XpansionSettings, metaclass=AllOptionalMetaclass, u # note: for some reason, the alias is not taken into account when using the metaclass, # so we have to redefine the fields with the alias. - - # On the other hand, we make these fields mandatory, because there is an anomaly on the front side: - # When the user does not select any file, the front sends a request without the "yearly-weights" - # or "additional-constraints" field, instead of sending the field with an empty value. - # This is not a problem as long as the front sends a request with all the fields (PUT case), - # but it is a problem for partial requests (PATCH case). - yearly_weights: str = Field( "", alias="yearly-weights", @@ -233,20 +227,18 @@ class XpansionCandidateDTO(BaseModel): name: str link: str annual_cost_per_mw: float = Field(alias="annual-cost-per-mw", ge=0) - unit_size: t.Optional[float] = Field(None, alias="unit-size", ge=0) - max_units: t.Optional[int] = Field(None, alias="max-units", ge=0) - max_investment: t.Optional[float] = Field(None, alias="max-investment", ge=0) - already_installed_capacity: t.Optional[int] = Field(None, alias="already-installed-capacity", ge=0) + unit_size: float = Field(None, alias="unit-size", ge=0) + max_units: int = Field(None, alias="max-units", ge=0) + max_investment: float = Field(None, alias="max-investment", ge=0) + already_installed_capacity: int = Field(None, alias="already-installed-capacity", ge=0) # this is obsolete (replaced by direct/indirect) - link_profile: t.Optional[str] = Field(None, alias="link-profile") + link_profile: str = Field(None, alias="link-profile") # this is obsolete (replaced by direct/indirect) - already_installed_link_profile: t.Optional[str] = Field(None, alias="already-installed-link-profile") - direct_link_profile: t.Optional[str] = Field(None, alias="direct-link-profile") - indirect_link_profile: t.Optional[str] = Field(None, alias="indirect-link-profile") - already_installed_direct_link_profile: t.Optional[str] = Field(None, alias="already-installed-direct-link-profile") - already_installed_indirect_link_profile: t.Optional[str] = Field( - None, alias="already-installed-indirect-link-profile" - ) + already_installed_link_profile: str = Field(None, alias="already-installed-link-profile") + direct_link_profile: str = Field(None, alias="direct-link-profile") + indirect_link_profile: str = Field(None, alias="indirect-link-profile") + already_installed_direct_link_profile: str = Field(None, alias="already-installed-direct-link-profile") + already_installed_indirect_link_profile: str = Field(None, alias="already-installed-indirect-link-profile") class LinkNotFound(HTTPException): @@ -334,10 +326,12 @@ def create_xpansion_configuration(self, study: Study, zipped_config: t.Optional[ ) raise BadZipBinary("Only zip file are allowed.") - xpansion_settings = XpansionSettings() - settings_obj = xpansion_settings.dict(by_alias=True, exclude_none=True, exclude={"sensitivity_config"}) + xpansion_settings = XpansionSettings() # type: ignore + settings_obj = xpansion_settings.model_dump( + by_alias=True, exclude_none=True, exclude={"sensitivity_config"} + ) if xpansion_settings.sensitivity_config: - sensitivity_obj = xpansion_settings.sensitivity_config.dict(by_alias=True, exclude_none=True) + sensitivity_obj = xpansion_settings.sensitivity_config.model_dump(by_alias=True, exclude_none=True) else: sensitivity_obj = {} @@ -377,7 +371,9 @@ def update_xpansion_settings( logger.info(f"Updating xpansion settings for study '{study.id}'") actual_settings = self.get_xpansion_settings(study) - settings_fields = new_xpansion_settings.dict(by_alias=False, exclude_none=True, exclude={"sensitivity_config"}) + settings_fields = new_xpansion_settings.model_dump( + by_alias=False, exclude_none=True, exclude={"sensitivity_config"} + ) updated_settings = actual_settings.copy(deep=True, update=settings_fields) file_study = self.study_storage_service.get_storage(study).get_raw(study) @@ -397,11 +393,11 @@ def update_xpansion_settings( msg = f"Additional constraints file '{constraints_file}' does not exist" raise XpansionFileNotFoundError(msg) from None - config_obj = updated_settings.dict(by_alias=True, exclude={"sensitivity_config"}) + config_obj = updated_settings.model_dump(by_alias=True, exclude={"sensitivity_config"}) file_study.tree.save(config_obj, ["user", "expansion", "settings"]) if new_xpansion_settings.sensitivity_config: - sensitivity_obj = new_xpansion_settings.sensitivity_config.dict(by_alias=True) + sensitivity_obj = new_xpansion_settings.sensitivity_config.model_dump(by_alias=True) file_study.tree.save(sensitivity_obj, ["user", "expansion", "sensitivity", "sensitivity_in"]) return self.get_xpansion_settings(study) @@ -541,7 +537,7 @@ def add_candidate(self, study: Study, xpansion_candidate: XpansionCandidateDTO) ) # The primary key is actually the name, the id does not matter and is never checked. logger.info(f"Adding candidate '{xpansion_candidate.name}' to study '{study.id}'") - candidates_obj[next_id] = xpansion_candidate.dict(by_alias=True, exclude_none=True) + candidates_obj[next_id] = xpansion_candidate.model_dump(by_alias=True, exclude_none=True) candidates_data = {"user": {"expansion": {"candidates": candidates_obj}}} file_study.tree.save(candidates_data) # Should we add a field in the study config containing the xpansion candidates like the links or the areas ? @@ -582,7 +578,7 @@ def update_candidate( for candidate_id, candidate in candidates.items(): if candidate["name"] == candidate_name: logger.info(f"Updating candidate '{candidate_name}' of study '{study.id}'") - candidates[candidate_id] = xpansion_candidate_dto.dict(by_alias=True, exclude_none=True) + candidates[candidate_id] = xpansion_candidate_dto.model_dump(by_alias=True, exclude_none=True) file_study.tree.save(candidates, ["user", "expansion", "candidates"]) return raise CandidateNotFoundError(f"The candidate '{xpansion_candidate_dto.name}' does not exist") @@ -602,7 +598,8 @@ def update_xpansion_constraints_settings(self, study: Study, constraints_file_na # Make sure filename is not `None`, because `None` values are ignored by the update. constraints_file_name = constraints_file_name or "" # noinspection PyArgumentList - xpansion_settings = UpdateXpansionSettings(additional_constraints=constraints_file_name) + args = {"additional_constraints": constraints_file_name} + xpansion_settings = UpdateXpansionSettings.model_validate(args) return self.update_xpansion_settings(study, xpansion_settings) def _raw_file_dir(self, raw_file_type: XpansionResourceFileType) -> t.List[str]: @@ -643,6 +640,7 @@ def _add_raw_files( content = file.file.read() if isinstance(content, str): content = content.encode(encoding="utf-8") + assert file.filename is not None buffer[file.filename] = content file_study.tree.save(data) diff --git a/antarest/study/model.py b/antarest/study/model.py index 4eff8109ab..6dd30e0537 100644 --- a/antarest/study/model.py +++ b/antarest/study/model.py @@ -6,7 +6,7 @@ from datetime import datetime, timedelta from pathlib import Path -from pydantic import BaseModel, validator +from pydantic import BaseModel, field_validator from sqlalchemy import ( # type: ignore Boolean, Column, @@ -339,13 +339,18 @@ class StudyMetadataDTO(BaseModel): workspace: str managed: bool archived: bool - horizon: t.Optional[str] - scenario: t.Optional[str] - status: t.Optional[str] - doc: t.Optional[str] + horizon: t.Optional[str] = None + scenario: t.Optional[str] = None + status: t.Optional[str] = None + doc: t.Optional[str] = None folder: t.Optional[str] = None tags: t.List[str] = [] + @field_validator("horizon", mode="before") + def transform_horizon_to_str(cls, val: t.Union[str, int, None]) -> t.Optional[str]: + # horizon can be an int. + return str(val) if val else val # type: ignore + class StudyMetadataPatchDTO(BaseModel): name: t.Optional[str] = None @@ -354,18 +359,20 @@ class StudyMetadataPatchDTO(BaseModel): scenario: t.Optional[str] = None status: t.Optional[str] = None doc: t.Optional[str] = None - tags: t.Sequence[str] = () + tags: t.List[str] = [] - @validator("tags", each_item=True) - def _normalize_tags(cls, v: str) -> str: + @field_validator("tags", mode="before") + def _normalize_tags(cls, v: t.List[str]) -> t.List[str]: """Remove leading and trailing whitespaces, and replace consecutive whitespaces by a single one.""" - tag = " ".join(v.split()) - if not tag: - raise ValueError("Tag cannot be empty") - elif len(tag) > 40: - raise ValueError(f"Tag is too long: {tag!r}") - else: - return tag + tags = [] + for tag in v: + tag = " ".join(tag.split()) + if not tag: + raise ValueError("Tag cannot be empty") + elif len(tag) > 40: + raise ValueError(f"Tag is too long: {tag!r}") + tags.append(tag) + return tags class StudySimSettingsDTO(BaseModel): diff --git a/antarest/study/service.py b/antarest/study/service.py index 8e4b537e19..48ccd72b49 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -1318,7 +1318,7 @@ def export_task(_notifier: TaskUpdateNotifier) -> TaskResult: else: json_response = json.dumps( - matrix.dict(), + matrix.model_dump(), ensure_ascii=False, allow_nan=True, indent=None, @@ -1473,6 +1473,7 @@ def _create_edit_study_command( context = self.storage_service.variant_study_service.command_factory.command_context if isinstance(tree_node, IniFileNode): + assert not isinstance(data, (bytes, list)) return UpdateConfig( target=url, data=data, @@ -1488,6 +1489,7 @@ def _create_edit_study_command( matrix=matrix.tolist(), command_context=context, ) + assert isinstance(data, (list, str)) return ReplaceMatrix( target=url, matrix=data, @@ -1495,6 +1497,9 @@ def _create_edit_study_command( ) elif isinstance(tree_node, RawFileNode): if url.split("/")[-1] == "comments": + if isinstance(data, bytes): + data = data.decode("utf-8") + assert isinstance(data, str) return UpdateComments( comments=data, command_context=context, @@ -2413,7 +2418,7 @@ def unarchive_output_task( src=str(src), dest=str(dest), remove_src=not keep_src_zip, - ).dict(), + ).model_dump(), name=task_name, ref_id=study.id, request_params=params, diff --git a/antarest/study/storage/abstract_storage_service.py b/antarest/study/storage/abstract_storage_service.py index 892f855970..aafd782261 100644 --- a/antarest/study/storage/abstract_storage_service.py +++ b/antarest/study/storage/abstract_storage_service.py @@ -76,7 +76,7 @@ def get_study_information( try: patch_obj = json.loads(additional_data.patch or "{}") - patch = Patch.parse_obj(patch_obj) + patch = Patch.model_validate(patch_obj) except ValueError as e: # The conversion to JSON and the parsing can fail if the patch is not valid logger.warning(f"Failed to parse patch for study {study.id}", exc_info=e) @@ -316,7 +316,7 @@ def _read_additional_data_from_files(self, file_study: FileStudy) -> StudyAdditi horizon = file_study.tree.get(url=["settings", "generaldata", "general", "horizon"]) author = file_study.tree.get(url=["study", "antares", "author"]) patch = self.patch_service.get_from_filestudy(file_study) - study_additional_data = StudyAdditionalData(horizon=horizon, author=author, patch=patch.json()) + study_additional_data = StudyAdditionalData(horizon=horizon, author=author, patch=patch.model_dump_json()) return study_additional_data def archive_study_output(self, study: T, output_id: str) -> bool: diff --git a/antarest/study/storage/patch_service.py b/antarest/study/storage/patch_service.py index c52752ae25..f0586fccf3 100644 --- a/antarest/study/storage/patch_service.py +++ b/antarest/study/storage/patch_service.py @@ -23,7 +23,7 @@ def get(self, study: t.Union[RawStudy, VariantStudy], get_from_file: bool = Fals # the `study.additional_data.patch` field is optional if study.additional_data.patch: patch_obj = json.loads(study.additional_data.patch or "{}") - return Patch.parse_obj(patch_obj) + return Patch.model_validate(patch_obj) patch = Patch() patch_path = Path(study.path) / PATCH_JSON @@ -55,9 +55,9 @@ def set_reference_output( def save(self, study: t.Union[RawStudy, VariantStudy], patch: Patch) -> None: if self.repository: study.additional_data = study.additional_data or StudyAdditionalData() - study.additional_data.patch = patch.json() + study.additional_data.patch = patch.model_dump_json() self.repository.save(study) patch_path = (Path(study.path)) / PATCH_JSON patch_path.parent.mkdir(parents=True, exist_ok=True) - patch_path.write_text(patch.json()) + patch_path.write_text(patch.model_dump_json()) diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/area.py b/antarest/study/storage/rawstudy/model/filesystem/config/area.py index 5ade25159f..a43bda176c 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/area.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/area.py @@ -4,7 +4,7 @@ import typing as t -from pydantic import Field, root_validator, validator +from pydantic import Field, field_validator, model_validator from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.storage.rawstudy.model.filesystem.config.field_validators import ( @@ -42,7 +42,7 @@ class OptimizationProperties(IniProperties): >>> opt = OptimizationProperties(**obj) - >>> pprint(opt.dict(by_alias=True), width=80) + >>> pprint(opt.model_dump(by_alias=True), width=80) {'filtering': {'filter-synthesis': 'hourly, daily, weekly, monthly, annual', 'filter-year-by-year': 'hourly, annual'}, 'nodal optimization': {'dispatchable-hydro-power': False, @@ -63,7 +63,7 @@ class OptimizationProperties(IniProperties): Convert the object to a dictionary for writing to a configuration file: - >>> pprint(opt.dict(by_alias=True, exclude_defaults=True), width=80) + >>> pprint(opt.model_dump(by_alias=True, exclude_defaults=True), width=80) {'filtering': {'filter-synthesis': 'hourly, weekly, monthly, annual', 'filter-year-by-year': 'hourly, monthly, annual'}, 'nodal optimization': {'dispatchable-hydro-power': False, @@ -77,7 +77,7 @@ class FilteringSection(IniProperties): filter_synthesis: str = Field("", alias="filter-synthesis") filter_year_by_year: str = Field("", alias="filter-year-by-year") - @validator("filter_synthesis", "filter_year_by_year", pre=True) + @field_validator("filter_synthesis", "filter_year_by_year", mode="before") def _validate_filtering(cls, v: t.Any) -> str: return validate_filtering(v) @@ -92,7 +92,7 @@ class ModalOptimizationSection(IniProperties): spread_spilled_energy_cost: float = Field(default=0.0, ge=0, alias="spread-spilled-energy-cost") filtering: FilteringSection = Field( - default_factory=FilteringSection, + default_factory=FilteringSection, # type: ignore alias="filtering", ) nodal_optimization: ModalOptimizationSection = Field( @@ -147,33 +147,29 @@ class AreaUI(IniProperties): ... "color_b": 255, ... } >>> ui = AreaUI(**obj) - >>> pprint(ui.dict(by_alias=True), width=80) + >>> pprint(ui.model_dump(by_alias=True), width=80) {'colorRgb': '#0080FF', 'x': 1148, 'y': 144} Update the color: >>> ui.color_rgb = (192, 168, 127) - >>> pprint(ui.dict(by_alias=True), width=80) + >>> pprint(ui.model_dump(by_alias=True), width=80) {'colorRgb': '#C0A87F', 'x': 1148, 'y': 144} """ - x: int = Field(0, description="x coordinate of the area in the map") - y: int = Field(0, description="y coordinate of the area in the map") - color_rgb: str = Field( - "#E66C2C", - alias="colorRgb", - description="color of the area in the map", - ) + x: t.Optional[int] = 0 # x coordinate of the area in the map + y: t.Optional[int] = 0 # y coordinate of the area in the map + color_rgb: t.Optional[str] = "#E66C2C" # color of the area in the map - @validator("color_rgb", pre=True) + @field_validator("color_rgb", mode="before") def _validate_color_rgb(cls, v: t.Any) -> str: return validate_color_rgb(v) - @root_validator(pre=True) + @model_validator(mode="before") def _validate_colors(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: return validate_colors(values) - def to_config(self) -> t.Mapping[str, t.Any]: + def to_config(self) -> t.Dict[str, t.Any]: """ Convert the object to a dictionary for writing to a configuration file: @@ -186,6 +182,7 @@ def to_config(self) -> t.Mapping[str, t.Any]: >>> pprint(ui.to_config(), width=80) {'color_b': 255, 'color_g': 128, 'color_r': 0, 'x': 1148, 'y': 144} """ + assert self.color_rgb is not None r = int(self.color_rgb[1:3], 16) g = int(self.color_rgb[3:5], 16) b = int(self.color_rgb[5:7], 16) @@ -204,7 +201,7 @@ class UIProperties(IniProperties): UIProperties has default values for `style` and `layers`: >>> ui = UIProperties() - >>> pprint(ui.dict(), width=80) + >>> pprint(ui.model_dump(), width=80) {'layer_styles': {0: {'color_rgb': '#E66C2C', 'x': 0, 'y': 0}}, 'layers': {0}, 'style': {'color_rgb': '#E66C2C', 'x': 0, 'y': 0}} @@ -232,7 +229,7 @@ class UIProperties(IniProperties): ... } >>> ui = UIProperties(**obj) - >>> pprint(ui.dict(), width=80) + >>> pprint(ui.model_dump(), width=80) {'layer_styles': {0: {'color_rgb': '#0080FF', 'x': 1148, 'y': 144}, 4: {'color_rgb': '#0080FF', 'x': 1148, 'y': 144}, 6: {'color_rgb': '#C0A863', 'x': 1148, 'y': 144}, @@ -247,38 +244,27 @@ class UIProperties(IniProperties): default_factory=AreaUI, description="style of the area in the map: coordinates and color", ) - layers: t.Set[int] = Field( - default_factory=set, - description="layers where the area is visible", - ) + layers: t.Set[int] = {0} # layers where the area is visible layer_styles: t.Dict[int, AreaUI] = Field( default_factory=dict, description="style of the area in each layer", alias="layerStyles", ) - @root_validator(pre=True) - def _set_default_style(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: + @staticmethod + def _set_default_style(values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: """Defined the default style if missing.""" - style = values.get("style") + style = values.get("style", None) if style is None: values["style"] = AreaUI() elif isinstance(style, dict): values["style"] = AreaUI(**style) else: - values["style"] = AreaUI(**style.dict()) - return values - - @root_validator(pre=True) - def _set_default_layers(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: - """Define the default layers if missing.""" - _layers = values.get("layers") - if _layers is None: - values["layers"] = {0} + values["style"] = AreaUI(**style.model_dump()) return values - @root_validator(pre=True) - def _set_default_layer_styles(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: + @staticmethod + def _set_default_layer_styles(values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: """Define the default layer styles if missing.""" layer_styles = values.get("layer_styles") if layer_styles is None: @@ -290,13 +276,15 @@ def _set_default_layer_styles(cls, values: t.MutableMapping[str, t.Any]) -> t.Ma if isinstance(style, dict): values["layer_styles"][key] = AreaUI(**style) else: - values["layer_styles"][key] = AreaUI(**style.dict()) + values["layer_styles"][key] = AreaUI(**style.model_dump()) else: raise TypeError(f"Invalid type for layer_styles: {type(layer_styles)}") return values - @root_validator(pre=True) + @model_validator(mode="before") def _validate_layers(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: + cls._set_default_style(values) + cls._set_default_layer_styles(values) # Parse the `[ui]` section (if any) ui_section = values.pop("ui", {}) if ui_section: @@ -335,7 +323,7 @@ def _validate_layers(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str return values - def to_config(self) -> t.Mapping[str, t.Mapping[str, t.Any]]: + def to_config(self) -> t.Dict[str, t.Dict[str, t.Any]]: """ Convert the object to a dictionary for writing to a configuration file: @@ -363,7 +351,7 @@ def to_config(self) -> t.Mapping[str, t.Mapping[str, t.Any]]: 'x': 1148, 'y': 144}} """ - obj: t.MutableMapping[str, t.MutableMapping[str, t.Any]] = { + obj: t.Dict[str, t.Dict[str, t.Any]] = { "ui": {}, "layerX": {}, "layerY": {}, @@ -374,6 +362,7 @@ def to_config(self) -> t.Mapping[str, t.Mapping[str, t.Any]]: for layer, style in self.layer_styles.items(): obj["layerX"][str(layer)] = style.x obj["layerY"][str(layer)] = style.y + assert style.color_rgb is not None r = int(style.color_rgb[1:3], 16) g = int(style.color_rgb[3:5], 16) b = int(style.color_rgb[5:7], 16) @@ -393,7 +382,7 @@ class AreaFolder(IniProperties): Create and validate a new AreaProperties object from a dictionary read from a configuration file. >>> obj = AreaFolder() - >>> pprint(obj.dict(), width=80) + >>> pprint(obj.model_dump(), width=80) {'adequacy_patch': None, 'optimization': {'filtering': {'filter_synthesis': '', 'filter_year_by_year': ''}, @@ -438,7 +427,7 @@ class AreaFolder(IniProperties): ... } >>> obj = AreaFolder.construct(**data) - >>> pprint(obj.dict(), width=80) + >>> pprint(obj.model_dump(), width=80) {'adequacy_patch': None, 'optimization': {'filtering': {'filter-synthesis': 'annual, centennial'}, 'nodal optimization': {'spread-spilled-energy-cost': '15.5', @@ -500,7 +489,7 @@ class ThermalAreasProperties(IniProperties): ... }, ... } >>> area = ThermalAreasProperties(**obj) - >>> pprint(area.dict(), width=80) + >>> pprint(area.model_dump(), width=80) {'spilled_energy_cost': {'cz': 100.0}, 'unserverd_energy_cost': {'at': 4000.8, 'be': 3500.0, @@ -511,7 +500,7 @@ class ThermalAreasProperties(IniProperties): >>> area.unserverd_energy_cost["at"] = 6500.0 >>> area.unserverd_energy_cost["fr"] = 0.0 - >>> pprint(area.dict(), width=80) + >>> pprint(area.model_dump(), width=80) {'spilled_energy_cost': {'cz': 100.0}, 'unserverd_energy_cost': {'at': 6500.0, 'be': 3500.0, 'de': 1250.0, 'fr': 0.0}} @@ -534,7 +523,7 @@ class ThermalAreasProperties(IniProperties): description="spilled energy cost (€/MWh) of each area", ) - @validator("unserverd_energy_cost", "spilled_energy_cost", pre=True) + @field_validator("unserverd_energy_cost", "spilled_energy_cost", mode="before") def _validate_energy_cost(cls, v: t.Any) -> t.MutableMapping[str, float]: if isinstance(v, dict): return {str(k): float(v) for k, v in v.items()} diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py b/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py index 1c84019294..dc4a8e12b3 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py @@ -7,15 +7,15 @@ import functools import typing as t -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel, Field @functools.total_ordering class ItemProperties( BaseModel, - extra=Extra.forbid, + extra="forbid", validate_assignment=True, - allow_population_by_field_name=True, + populate_by_name=True, ): """ Common properties related to thermal and renewable clusters, and short-term storage. @@ -35,7 +35,7 @@ class ItemProperties( group: str = Field(default="", description="Cluster group") - name: str = Field(description="Cluster name", regex=r"[a-zA-Z0-9_(),& -]+") + name: str = Field(description="Cluster name", pattern=r"[a-zA-Z0-9_(),& -]+") def __lt__(self, other: t.Any) -> bool: """ diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/field_validators.py b/antarest/study/storage/rawstudy/model/filesystem/config/field_validators.py index 74f93f5c46..f5272daa0c 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/field_validators.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/field_validators.py @@ -3,7 +3,7 @@ _ALL_FILTERING = ["hourly", "daily", "weekly", "monthly", "annual"] -def extract_filtering(v: t.Any) -> t.Sequence[str]: +def extract_filtering(v: t.Any) -> t.List[str]: """ Extract filtering values from a comma-separated list of values. """ diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/files.py b/antarest/study/storage/rawstudy/model/filesystem/config/files.py index b8140b8c87..8013634bea 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/files.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/files.py @@ -14,7 +14,6 @@ DEFAULT_GROUP, DEFAULT_OPERATOR, DEFAULT_TIMESTEP, - BindingConstraintFrequency, ) from antarest.study.storage.rawstudy.model.filesystem.config.exceptions import ( SimulationParsingError, diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/identifier.py b/antarest/study/storage/rawstudy/model/filesystem/config/identifier.py index b4b1a3c3f3..1efc0fef1d 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/identifier.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/identifier.py @@ -1,21 +1,21 @@ import typing as t -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import BaseModel, Field, model_validator __all__ = ("IgnoreCaseIdentifier", "LowerCaseIdentifier") class IgnoreCaseIdentifier( BaseModel, - extra=Extra.forbid, + extra="forbid", validate_assignment=True, - allow_population_by_field_name=True, + populate_by_name=True, ): """ Base class for all configuration sections with an ID. """ - id: str = Field(description="ID (section name)", regex=r"[a-zA-Z0-9_(),& -]+") + id: str = Field(description="ID (section name)", pattern=r"[a-zA-Z0-9_(),& -]+") @classmethod def generate_id(cls, name: str) -> str: @@ -33,7 +33,7 @@ def generate_id(cls, name: str) -> str: return transform_name_to_id(name, lower=False) - @root_validator(pre=True) + @model_validator(mode="before") def validate_id(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: """ Calculate an ID based on the name, if not provided. @@ -44,18 +44,21 @@ def validate_id(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.A Returns: The updated values. """ - if storage_id := values.get("id"): - # If the ID is provided, it comes from a INI section name. - # In some legacy case, the ID was in lower case, so we need to convert it. - values["id"] = cls.generate_id(storage_id) - return values - if not values.get("name"): - return values - name = values["name"] - if storage_id := cls.generate_id(name): - values["id"] = storage_id - else: - raise ValueError(f"Invalid name '{name}'.") + + # For some reason I can't explain, values can be an object. If so, no validation is needed. + if isinstance(values, dict): + if storage_id := values.get("id"): + # If the ID is provided, it comes from a INI section name. + # In some legacy case, the ID was in lower case, so we need to convert it. + values["id"] = cls.generate_id(storage_id) + return values + if not values.get("name"): + return values + name = values["name"] + if storage_id := cls.generate_id(name): + values["id"] = storage_id + else: + raise ValueError(f"Invalid name '{name}'.") return values @@ -64,7 +67,7 @@ class LowerCaseIdentifier(IgnoreCaseIdentifier): Base class for all configuration sections with a lower case ID. """ - id: str = Field(description="ID (section name)", regex=r"[a-z0-9_(),& -]+") + id: str = Field(description="ID (section name)", pattern=r"[a-z0-9_(),& -]+") @classmethod def generate_id(cls, name: str) -> str: diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/ini_properties.py b/antarest/study/storage/rawstudy/model/filesystem/config/ini_properties.py index 51f10a5ca5..5c6be19ba4 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/ini_properties.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/ini_properties.py @@ -1,7 +1,7 @@ import json import typing as t -from pydantic import BaseModel, Extra +from pydantic import BaseModel class IniProperties( @@ -9,17 +9,17 @@ class IniProperties( # On reading, if the configuration contains an extra field, it is better # to forbid it, because it allows errors to be detected early. # Ignoring extra attributes can hide errors. - extra=Extra.forbid, + extra="forbid", # If a field is updated on assignment, it is also validated. validate_assignment=True, # On testing, we can use snake_case for field names. - allow_population_by_field_name=True, + populate_by_name=True, ): """ Base class for configuration sections. """ - def to_config(self) -> t.Mapping[str, t.Any]: + def to_config(self) -> t.Dict[str, t.Any]: """ Convert the object to a dictionary for writing to a configuration file (`*.ini`). @@ -28,14 +28,16 @@ def to_config(self) -> t.Mapping[str, t.Any]: """ config = {} - for field_name, field in self.__fields__.items(): + for field_name, field in self.model_fields.items(): value = getattr(self, field_name) if value is None: continue + alias = field.alias + assert alias is not None if isinstance(value, IniProperties): - config[field.alias] = value.to_config() + config[alias] = value.to_config() else: - config[field.alias] = json.loads(json.dumps(value)) + config[alias] = json.loads(json.dumps(value)) return config @classmethod @@ -44,7 +46,7 @@ def construct(cls, _fields_set: t.Optional[t.Set[str]] = None, **values: t.Any) Construct a new model instance from a dict of values, replacing aliases with real field names. """ # The pydantic construct() function does not allow aliases to be handled. - aliases = {(field.alias or name): name for name, field in cls.__fields__.items()} + aliases = {(field.alias or name): name for name, field in cls.model_fields.items()} renamed_values = {aliases.get(k, k): v for k, v in values.items()} if _fields_set is not None: _fields_set = {aliases.get(f, f) for f in _fields_set} diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/links.py b/antarest/study/storage/rawstudy/model/filesystem/config/links.py index 7ebc0e2176..ec97c52895 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/links.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/links.py @@ -4,7 +4,7 @@ import typing as t -from pydantic import Field, root_validator, validator +from pydantic import Field, field_validator, model_validator from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.storage.rawstudy.model.filesystem.config.field_validators import ( @@ -84,7 +84,7 @@ class LinkProperties(IniProperties): >>> opt = LinkProperties(**obj) - >>> pprint(opt.dict(by_alias=True), width=80) + >>> pprint(opt.model_dump(by_alias=True), width=80) {'asset-type': , 'colorRgb': '#50C0FF', 'comments': 'This is a link', @@ -134,20 +134,20 @@ class LinkProperties(IniProperties): description="color of the area in the map", ) - @validator("filter_synthesis", "filter_year_by_year", pre=True) + @field_validator("filter_synthesis", "filter_year_by_year", mode="before") def _validate_filtering(cls, v: t.Any) -> str: return validate_filtering(v) - @validator("color_rgb", pre=True) + @field_validator("color_rgb", mode="before") def _validate_color_rgb(cls, v: t.Any) -> str: return validate_color_rgb(v) - @root_validator(pre=True) + @model_validator(mode="before") def _validate_colors(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: return validate_colors(values) # noinspection SpellCheckingInspection - def to_config(self) -> t.Mapping[str, t.Any]: + def to_config(self) -> t.Dict[str, t.Any]: """ Convert the object to a dictionary for writing to a configuration file. """ diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/model.py b/antarest/study/storage/rawstudy/model/filesystem/config/model.py index 46e02bb051..f0a0e2ab11 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/model.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/model.py @@ -2,7 +2,7 @@ import typing as t from pathlib import Path -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, Field, model_validator from antarest.core.utils.utils import DTO from antarest.study.business.enum_ignore_case import EnumIgnoreCase @@ -52,7 +52,7 @@ class Link(BaseModel, extra="ignore"): filters_synthesis: t.List[str] = Field(default_factory=list) filters_year: t.List[str] = Field(default_factory=list) - @root_validator(pre=True) + @model_validator(mode="before") def validation(cls, values: t.MutableMapping[str, t.Any]) -> t.MutableMapping[str, t.Any]: # note: field names are in kebab-case in the INI file filters_synthesis = values.pop("filter-synthesis", values.pop("filters_synthesis", "")) @@ -82,7 +82,7 @@ class DistrictSet(BaseModel): Object linked to /inputs/sets.ini information """ - ALL = ["hourly", "daily", "weekly", "monthly", "annual"] + ALL: t.List[str] = ["hourly", "daily", "weekly", "monthly", "annual"] name: t.Optional[str] = None inverted_set: bool = False areas: t.Optional[t.List[str]] = None diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/st_storage.py b/antarest/study/storage/rawstudy/model/filesystem/config/st_storage.py index 3355ba571a..5f71d89a29 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/st_storage.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/st_storage.py @@ -97,7 +97,7 @@ class STStorage880Properties(STStorageProperties): # Activity status: # - True: the plant may generate. # - False: Ignored by the simulator. - enabled: bool = Field(default=True, description="Activity status") + enabled: t.Optional[bool] = True # Activity status # noinspection SpellCheckingInspection diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/thermal.py b/antarest/study/storage/rawstudy/model/filesystem/config/thermal.py index dcd0bc7729..18165c2010 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/thermal.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/thermal.py @@ -427,4 +427,4 @@ def create_thermal_config(study_version: t.Union[str, int], **kwargs: t.Any) -> ValueError: If the study version is not supported. """ cls = get_thermal_config_cls(study_version) - return cls(**kwargs) + return cls.model_validate(kwargs) diff --git a/antarest/study/storage/rawstudy/model/filesystem/factory.py b/antarest/study/storage/rawstudy/model/filesystem/factory.py index 040e747629..7d70756c71 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/factory.py +++ b/antarest/study/storage/rawstudy/model/filesystem/factory.py @@ -92,7 +92,7 @@ def _create_from_fs_unsafe( from_cache = self.cache.get(cache_id) if from_cache is not None: logger.info(f"Study {study_id} read from cache") - config = FileStudyTreeConfigDTO.parse_obj(from_cache).to_build_config() + config = FileStudyTreeConfigDTO.model_validate(from_cache).to_build_config() if output_path: config.output_path = output_path config.outputs = parse_outputs(output_path) @@ -106,7 +106,7 @@ def _create_from_fs_unsafe( logger.info(f"Cache new entry from StudyFactory (studyID: {study_id})") self.cache.put( cache_id, - FileStudyTreeConfigDTO.from_build_config(config).dict(), + FileStudyTreeConfigDTO.from_build_config(config).model_dump(), ) return result diff --git a/antarest/study/storage/rawstudy/raw_study_service.py b/antarest/study/storage/rawstudy/raw_study_service.py index 4990d70bf6..1fc66caa6e 100644 --- a/antarest/study/storage/rawstudy/raw_study_service.py +++ b/antarest/study/storage/rawstudy/raw_study_service.py @@ -92,7 +92,7 @@ def update_from_raw_meta(self, metadata: RawStudy, fallback_on_default: t.Option metadata.updated_at = metadata.updated_at or datetime.utcnow() if metadata.additional_data is None: metadata.additional_data = StudyAdditionalData() - metadata.additional_data.patch = metadata.additional_data.patch or Patch().json() + metadata.additional_data.patch = metadata.additional_data.patch or Patch().model_dump_json() metadata.additional_data.author = metadata.additional_data.author or "Unknown" else: diff --git a/antarest/study/storage/study_download_utils.py b/antarest/study/storage/study_download_utils.py index 0c922fed03..55d91cf35e 100644 --- a/antarest/study/storage/study_download_utils.py +++ b/antarest/study/storage/study_download_utils.py @@ -333,7 +333,7 @@ def export( if filetype == ExportFormat.JSON: with open(target_file, "w") as fh: json.dump( - matrix.dict(), + matrix.model_dump(), fh, ensure_ascii=False, allow_nan=True, diff --git a/antarest/study/storage/variantstudy/business/command_extractor.py b/antarest/study/storage/variantstudy/business/command_extractor.py index 4ac5070a69..72aa744922 100644 --- a/antarest/study/storage/variantstudy/business/command_extractor.py +++ b/antarest/study/storage/variantstudy/business/command_extractor.py @@ -211,7 +211,7 @@ def _extract_cluster(self, study: FileStudy, area_id: str, cluster_id: str, rene create_cluster_command( area_id=area_id, cluster_name=cluster.id, - parameters=cluster.dict(by_alias=True, exclude_defaults=True, exclude={"id"}), + parameters=cluster.model_dump(by_alias=True, exclude_defaults=True, exclude={"id"}), command_context=self.command_context, ), self.generate_replace_matrix( @@ -311,6 +311,7 @@ def extract_district(self, study: FileStudy, district_id: str) -> t.List[IComman district_config = study_config.sets[district_id] base_filter = DistrictBaseFilter.add_all if district_config.inverted_set else DistrictBaseFilter.remove_all district_fetched_config = study_tree.get(["input", "areas", "sets", district_id]) + assert district_config.name is not None study_commands.append( CreateDistrict( name=district_config.name, @@ -370,7 +371,8 @@ def extract_binding_constraint( matrices[name] = matrix["data"] # Create the command to create the binding constraint - create_cmd = CreateBindingConstraint(**binding, **matrices, coeffs=terms, command_context=self.command_context) + kwargs = {**binding, **matrices, "coeffs": terms, "command_context": self.command_context} + create_cmd = CreateBindingConstraint.model_validate(kwargs) return [create_cmd] @@ -404,8 +406,8 @@ def generate_update_playlist( config = study_tree.get(["settings", "generaldata"]) playlist = get_playlist(config) return UpdatePlaylist( - items=playlist.keys() if playlist else None, - weights=({year for year, weight in playlist.items() if weight != 1} if playlist else None), + items=list(playlist.keys()) if playlist else None, + weights=({year: weight for year, weight in playlist.items() if weight != 1} if playlist else None), active=bool(playlist and len(playlist) > 0), reverse=False, command_context=self.command_context, @@ -441,6 +443,7 @@ def generate_update_district( study_tree = study.tree district_config = study_config.sets[district_id] district_fetched_config = study_tree.get(["input", "areas", "sets", district_id]) + assert district_config.name is not None return UpdateDistrict( id=district_config.name, base_filter=DistrictBaseFilter.add_all if district_config.inverted_set else DistrictBaseFilter.remove_all, diff --git a/antarest/study/storage/variantstudy/business/command_reverter.py b/antarest/study/storage/variantstudy/business/command_reverter.py index 9cf5d70b13..cbc4c44de6 100644 --- a/antarest/study/storage/variantstudy/business/command_reverter.py +++ b/antarest/study/storage/variantstudy/business/command_reverter.py @@ -123,7 +123,7 @@ def _revert_update_binding_constraint( if matrix is not None: args[matrix_name] = matrix_service.get_matrix_id(matrix) - return [UpdateBindingConstraint(**args)] + return [UpdateBindingConstraint.model_validate(args)] return base_command.get_command_extractor().extract_binding_constraint(base, base_command.id) diff --git a/antarest/study/storage/variantstudy/business/utils.py b/antarest/study/storage/variantstudy/business/utils.py index 75396ccbc6..b03e5fec2e 100644 --- a/antarest/study/storage/variantstudy/business/utils.py +++ b/antarest/study/storage/variantstudy/business/utils.py @@ -110,7 +110,7 @@ def transform_command_to_dto( commands_dto.append( CommandDTO( action=prev_command.command_name.value, - args=cur_command_args_batch, + args=cur_command_args_batch, # type: ignore ) ) cur_command_args_batch = [command.to_dto().args] @@ -118,5 +118,5 @@ def transform_command_to_dto( cur_dto = ref_commands_dto[cur_dto_index] cur_dto_arg_count = 1 if isinstance(cur_dto.args, dict) else len(cur_dto.args) prev_command = command - commands_dto.append(CommandDTO(action=prev_command.command_name.value, args=cur_command_args_batch)) + commands_dto.append(CommandDTO(action=prev_command.command_name.value, args=cur_command_args_batch)) # type: ignore return commands_dto diff --git a/antarest/study/storage/variantstudy/business/utils_binding_constraint.py b/antarest/study/storage/variantstudy/business/utils_binding_constraint.py index 8d7464c5aa..4c6f931bf5 100644 --- a/antarest/study/storage/variantstudy/business/utils_binding_constraint.py +++ b/antarest/study/storage/variantstudy/business/utils_binding_constraint.py @@ -1,7 +1,6 @@ import typing as t from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import ( - DEFAULT_TIMESTEP, BindingConstraintFrequency, BindingConstraintOperator, ) diff --git a/antarest/study/storage/variantstudy/command_factory.py b/antarest/study/storage/variantstudy/command_factory.py index 45d325585b..05433350d5 100644 --- a/antarest/study/storage/variantstudy/command_factory.py +++ b/antarest/study/storage/variantstudy/command_factory.py @@ -1,3 +1,4 @@ +import copy import typing as t from antarest.core.model import JSON @@ -105,10 +106,15 @@ def to_command(self, command_dto: CommandDTO) -> t.List[ICommand]: """ args = command_dto.args if isinstance(args, dict): - return [self._to_single_command(command_dto.action, args, command_dto.version, command_dto.id)] + # In some cases, pydantic can modify inplace the given args. + # We don't want that so before doing so we copy the dictionnary. + new_args = copy.deepcopy(args) + return [self._to_single_command(command_dto.action, new_args, command_dto.version, command_dto.id)] elif isinstance(args, list): return [ - self._to_single_command(command_dto.action, argument, command_dto.version, command_dto.id) + self._to_single_command( + command_dto.action, copy.deepcopy(argument), command_dto.version, command_dto.id + ) for argument in args ] raise NotImplementedError() diff --git a/antarest/study/storage/variantstudy/model/command/create_area.py b/antarest/study/storage/variantstudy/model/command/create_area.py index 0ef68b61d3..b3411881f9 100644 --- a/antarest/study/storage/variantstudy/model/command/create_area.py +++ b/antarest/study/storage/variantstudy/model/command/create_area.py @@ -49,8 +49,8 @@ class CreateArea(ICommand): # Overloaded metadata # =================== - command_name = CommandName.CREATE_AREA - version = 1 + command_name: CommandName = CommandName.CREATE_AREA + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py index 2925e67c3d..6e95ba18e5 100644 --- a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py @@ -1,13 +1,12 @@ -import json import typing as t from abc import ABCMeta from enum import Enum import numpy as np -from pydantic import BaseModel, Extra, Field, root_validator, validator +from pydantic import BaseModel, Field, field_validator, model_validator from antarest.matrixstore.model import MatrixData -from antarest.study.business.all_optional_meta import AllOptionalMetaclass +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import ( DEFAULT_GROUP, DEFAULT_OPERATOR, @@ -79,18 +78,24 @@ def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixTyp # ================================================================================= -class BindingConstraintPropertiesBase(BaseModel, extra=Extra.forbid, allow_population_by_field_name=True): - enabled: bool = True +class BindingConstraintPropertiesBase(BaseModel, extra="forbid", populate_by_name=True): + enabled: t.Optional[bool] = True time_step: BindingConstraintFrequency = Field(DEFAULT_TIMESTEP, alias="type") - operator: BindingConstraintOperator = DEFAULT_OPERATOR - comments: str = "" + operator: t.Optional[BindingConstraintOperator] = DEFAULT_OPERATOR + comments: t.Optional[str] = "" + + @model_validator(mode="before") + def replace_with_alias(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + if "type" in values: + values["time_step"] = values.pop("type") + return values @classmethod def from_dict(cls, **attrs: t.Any) -> "BindingConstraintPropertiesBase": """ Instantiate a class from a dictionary excluding unknown or `None` fields. """ - attrs = {k: v for k, v in attrs.items() if k in cls.__fields__ and v is not None} + attrs = {k: v for k, v in attrs.items() if k in cls.model_fields and v is not None} return cls(**attrs) @@ -98,7 +103,7 @@ class BindingConstraintProperties830(BindingConstraintPropertiesBase): filter_year_by_year: str = Field("", alias="filter-year-by-year") filter_synthesis: str = Field("", alias="filter-synthesis") - @validator("filter_synthesis", "filter_year_by_year", pre=True) + @field_validator("filter_synthesis", "filter_year_by_year", mode="before") def _validate_filtering(cls, v: t.Any) -> str: return validate_filtering(v) @@ -142,7 +147,8 @@ def create_binding_constraint_config(study_version: t.Union[str, int], **kwargs: return cls.from_dict(**kwargs) -class OptionalProperties(BindingConstraintProperties870, metaclass=AllOptionalMetaclass, use_none=True): +@all_optional_model +class OptionalProperties(BindingConstraintProperties870): pass @@ -151,32 +157,18 @@ class OptionalProperties(BindingConstraintProperties870, metaclass=AllOptionalMe # ================================================================================= -class BindingConstraintMatrices(BaseModel, extra=Extra.forbid, allow_population_by_field_name=True): +@camel_case_model +class BindingConstraintMatrices(BaseModel, extra="forbid", populate_by_name=True): """ Class used to store the matrices of a binding constraint. """ - values: t.Optional[t.Union[MatrixType, str]] = Field( - None, - description="2nd member matrix for studies before v8.7", - ) - less_term_matrix: t.Optional[t.Union[MatrixType, str]] = Field( - None, - description="less term matrix for v8.7+ studies", - alias="lessTermMatrix", - ) - greater_term_matrix: t.Optional[t.Union[MatrixType, str]] = Field( - None, - description="greater term matrix for v8.7+ studies", - alias="greaterTermMatrix", - ) - equal_term_matrix: t.Optional[t.Union[MatrixType, str]] = Field( - None, - description="equal term matrix for v8.7+ studies", - alias="equalTermMatrix", - ) - - @root_validator(pre=True) + values: t.Optional[t.Union[MatrixType, str]] = None # 2nd member matrix for studies before v8.7 + less_term_matrix: t.Optional[t.Union[MatrixType, str]] = None # less term matrix for v8.7+ studies + greater_term_matrix: t.Optional[t.Union[MatrixType, str]] = None # greater term matrix for v8.7+ studies + equal_term_matrix: t.Optional[t.Union[MatrixType, str]] = None # equal term matrix for v8.7+ studies + + @model_validator(mode="before") def check_matrices( cls, values: t.Dict[str, t.Optional[t.Union[MatrixType, str]]] ) -> t.Dict[str, t.Optional[t.Union[MatrixType, str]]]: @@ -203,10 +195,10 @@ class AbstractBindingConstraintCommand(OptionalProperties, BindingConstraintMatr Abstract class for binding constraint commands. """ - coeffs: t.Optional[t.Dict[str, t.List[float]]] + coeffs: t.Optional[t.Dict[str, t.List[float]]] = None def to_dto(self) -> CommandDTO: - json_command = json.loads(self.json(exclude={"command_context"})) + json_command = self.model_dump(mode="json", exclude={"command_context"}) args = {} for field in ["enabled", "coeffs", "comments", "time_step", "operator"]: if json_command[field]: @@ -388,7 +380,7 @@ class CreateBindingConstraint(AbstractBindingConstraintCommand): Command used to create a binding constraint. """ - command_name = CommandName.CREATE_BINDING_CONSTRAINT + command_name: CommandName = CommandName.CREATE_BINDING_CONSTRAINT version: int = 1 # Properties of the `CREATE_BINDING_CONSTRAINT` command: @@ -415,8 +407,8 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: bd_id = transform_name_to_id(self.name) study_version = study_data.config.version - props = create_binding_constraint_config(study_version, **self.dict()) - obj = json.loads(props.json(by_alias=True)) + props = create_binding_constraint_config(study_version, **self.model_dump()) + obj = props.model_dump(mode="json", by_alias=True) new_binding = {"id": bd_id, "name": self.name, **obj} @@ -442,9 +434,9 @@ def _create_diff(self, other: "ICommand") -> t.List["ICommand"]: bd_id = transform_name_to_id(self.name) args = {"id": bd_id, "command_context": other.command_context} - excluded_fields = frozenset(ICommand.__fields__) - self_command = json.loads(self.json(exclude=excluded_fields)) - other_command = json.loads(other.json(exclude=excluded_fields)) + excluded_fields = set(ICommand.model_fields) + self_command = self.model_dump(mode="json", exclude=excluded_fields) + other_command = other.model_dump(mode="json", exclude=excluded_fields) properties = [ "enabled", "coeffs", @@ -468,7 +460,7 @@ def _create_diff(self, other: "ICommand") -> t.List["ICommand"]: if self_matrix_id != other_matrix_id: args[matrix_name] = other_matrix_id - return [UpdateBindingConstraint(**args)] + return [UpdateBindingConstraint.model_validate(args)] def match(self, other: "ICommand", equal: bool = False) -> bool: if not isinstance(other, self.__class__): diff --git a/antarest/study/storage/variantstudy/model/command/create_cluster.py b/antarest/study/storage/variantstudy/model/command/create_cluster.py index a884eb7b9c..d45def780b 100644 --- a/antarest/study/storage/variantstudy/model/command/create_cluster.py +++ b/antarest/study/storage/variantstudy/model/command/create_cluster.py @@ -1,6 +1,6 @@ import typing as t -from pydantic import validator +from pydantic import Field, ValidationInfo, field_validator from antarest.core.model import JSON from antarest.core.utils.utils import assert_this @@ -26,46 +26,51 @@ class CreateCluster(ICommand): # Overloaded metadata # =================== - command_name = CommandName.CREATE_THERMAL_CLUSTER - version = 1 + command_name: CommandName = CommandName.CREATE_THERMAL_CLUSTER + version: int = 1 # Command parameters # ================== area_id: str cluster_name: str - parameters: t.Dict[str, str] - prepro: t.Optional[t.Union[t.List[t.List[MatrixData]], str]] = None - modulation: t.Optional[t.Union[t.List[t.List[MatrixData]], str]] = None + parameters: t.Dict[str, t.Any] + prepro: t.Optional[t.Union[t.List[t.List[MatrixData]], str]] = Field(None, validate_default=True) + modulation: t.Optional[t.Union[t.List[t.List[MatrixData]], str]] = Field(None, validate_default=True) - @validator("cluster_name") + @field_validator("cluster_name", mode="before") def validate_cluster_name(cls, val: str) -> str: valid_name = transform_name_to_id(val, lower=False) if valid_name != val: raise ValueError("Cluster name must only contains [a-zA-Z0-9],&,-,_,(,) characters") return val - @validator("prepro", always=True) + @field_validator("prepro", mode="before") def validate_prepro( - cls, v: t.Optional[t.Union[t.List[t.List[MatrixData]], str]], values: t.Any + cls, + v: t.Optional[t.Union[t.List[t.List[MatrixData]], str]], + values: t.Union[t.Dict[str, t.Any], ValidationInfo], ) -> t.Optional[t.Union[t.List[t.List[MatrixData]], str]]: + new_values = values if isinstance(values, dict) else values.data if v is None: - v = values["command_context"].generator_matrix_constants.get_thermal_prepro_data() + v = new_values["command_context"].generator_matrix_constants.get_thermal_prepro_data() return v - else: - return validate_matrix(v, values) + return validate_matrix(v, new_values) - @validator("modulation", always=True) + @field_validator("modulation", mode="before") def validate_modulation( - cls, v: t.Optional[t.Union[t.List[t.List[MatrixData]], str]], values: t.Any + cls, + v: t.Optional[t.Union[t.List[t.List[MatrixData]], str]], + values: t.Union[t.Dict[str, t.Any], ValidationInfo], ) -> t.Optional[t.Union[t.List[t.List[MatrixData]], str]]: + new_values = values if isinstance(values, dict) else values.data if v is None: - v = values["command_context"].generator_matrix_constants.get_thermal_prepro_modulation() + v = new_values["command_context"].generator_matrix_constants.get_thermal_prepro_modulation() return v else: - return validate_matrix(v, values) + return validate_matrix(v, new_values) def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: # Search the Area in the configuration diff --git a/antarest/study/storage/variantstudy/model/command/create_district.py b/antarest/study/storage/variantstudy/model/command/create_district.py index 9311345db0..f035847edd 100644 --- a/antarest/study/storage/variantstudy/model/command/create_district.py +++ b/antarest/study/storage/variantstudy/model/command/create_district.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Tuple, cast -from pydantic import validator +from pydantic import field_validator from antarest.study.storage.rawstudy.model.filesystem.config.model import ( DistrictSet, @@ -27,8 +27,8 @@ class CreateDistrict(ICommand): # Overloaded metadata # =================== - command_name = CommandName.CREATE_DISTRICT - version = 1 + command_name: CommandName = CommandName.CREATE_DISTRICT + version: int = 1 # Command parameters # ================== @@ -39,7 +39,7 @@ class CreateDistrict(ICommand): output: bool = True comments: str = "" - @validator("name") + @field_validator("name") def validate_district_name(cls, val: str) -> str: valid_name = transform_name_to_id(val, lower=False) if valid_name != val: diff --git a/antarest/study/storage/variantstudy/model/command/create_link.py b/antarest/study/storage/variantstudy/model/command/create_link.py index f716b0c4d6..6c6b39c850 100644 --- a/antarest/study/storage/variantstudy/model/command/create_link.py +++ b/antarest/study/storage/variantstudy/model/command/create_link.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast -from pydantic import root_validator, validator +from pydantic import ValidationInfo, field_validator, model_validator from antarest.core.model import JSON from antarest.core.utils.utils import assert_this @@ -39,30 +39,31 @@ class CreateLink(ICommand): # Overloaded metadata # =================== - command_name = CommandName.CREATE_LINK - version = 1 + command_name: CommandName = CommandName.CREATE_LINK + version: int = 1 # Command parameters # ================== area1: str area2: str - parameters: Optional[Dict[str, str]] = None + parameters: Optional[Dict[str, Any]] = None series: Optional[Union[List[List[MatrixData]], str]] = None direct: Optional[Union[List[List[MatrixData]], str]] = None indirect: Optional[Union[List[List[MatrixData]], str]] = None - @validator("series", "direct", "indirect", always=True) + @field_validator("series", "direct", "indirect", mode="before") def validate_series( - cls, v: Optional[Union[List[List[MatrixData]], str]], values: Any + cls, v: Optional[Union[List[List[MatrixData]], str]], values: Union[Dict[str, Any], ValidationInfo] ) -> Optional[Union[List[List[MatrixData]], str]]: - return validate_matrix(v, values) if v is not None else v + new_values = values if isinstance(values, dict) else values.data + return validate_matrix(v, new_values) if v is not None else v - @root_validator - def validate_areas(cls, values: Dict[str, Any]) -> Any: - if values.get("area1") == values.get("area2"): + @model_validator(mode="after") + def validate_areas(self) -> "CreateLink": + if self.area1 == self.area2: raise ValueError("Cannot create link on same node") - return values + return self def _create_link_in_config(self, area_from: str, area_to: str, study_data: FileStudyTreeConfig) -> None: self.parameters = self.parameters or {} diff --git a/antarest/study/storage/variantstudy/model/command/create_renewables_cluster.py b/antarest/study/storage/variantstudy/model/command/create_renewables_cluster.py index c0c9aa44f8..fed67b3af4 100644 --- a/antarest/study/storage/variantstudy/model/command/create_renewables_cluster.py +++ b/antarest/study/storage/variantstudy/model/command/create_renewables_cluster.py @@ -1,6 +1,6 @@ import typing as t -from pydantic import validator +from pydantic import field_validator from antarest.core.model import JSON from antarest.study.storage.rawstudy.model.filesystem.config.model import ( @@ -24,17 +24,17 @@ class CreateRenewablesCluster(ICommand): # Overloaded metadata # =================== - command_name = CommandName.CREATE_RENEWABLES_CLUSTER - version = 1 + command_name: CommandName = CommandName.CREATE_RENEWABLES_CLUSTER + version: int = 1 # Command parameters # ================== area_id: str cluster_name: str - parameters: t.Dict[str, str] + parameters: t.Dict[str, t.Any] - @validator("cluster_name") + @field_validator("cluster_name") def validate_cluster_name(cls, val: str) -> str: valid_name = transform_name_to_id(val, lower=False) if valid_name != val: diff --git a/antarest/study/storage/variantstudy/model/command/create_st_storage.py b/antarest/study/storage/variantstudy/model/command/create_st_storage.py index 90fb980c4b..db1c3dc258 100644 --- a/antarest/study/storage/variantstudy/model/command/create_st_storage.py +++ b/antarest/study/storage/variantstudy/model/command/create_st_storage.py @@ -1,9 +1,7 @@ -import json import typing as t import numpy as np -from pydantic import Field, validator -from pydantic.fields import ModelField +from pydantic import Field, ValidationInfo, model_validator from antarest.core.model import JSON from antarest.matrixstore.model import MatrixData @@ -40,34 +38,19 @@ class CreateSTStorage(ICommand): # Overloaded metadata # =================== - command_name = CommandName.CREATE_ST_STORAGE - version = 1 + command_name: CommandName = CommandName.CREATE_ST_STORAGE + version: int = 1 # Command parameters # ================== - area_id: str = Field(description="Area ID", regex=r"[a-z0-9_(),& -]+") + area_id: str = Field(description="Area ID", pattern=r"[a-z0-9_(),& -]+") parameters: STStorageConfigType - pmax_injection: t.Optional[t.Union[MatrixType, str]] = Field( - None, - description="Charge capacity (modulation)", - ) - pmax_withdrawal: t.Optional[t.Union[MatrixType, str]] = Field( - None, - description="Discharge capacity (modulation)", - ) - lower_rule_curve: t.Optional[t.Union[MatrixType, str]] = Field( - None, - description="Lower rule curve (coefficient)", - ) - upper_rule_curve: t.Optional[t.Union[MatrixType, str]] = Field( - None, - description="Upper rule curve (coefficient)", - ) - inflows: t.Optional[t.Union[MatrixType, str]] = Field( - None, - description="Inflows (MW)", - ) + pmax_injection: t.Optional[t.Union[MatrixType, str]] = None # Charge capacity (modulation) + pmax_withdrawal: t.Optional[t.Union[MatrixType, str]] = None # Discharge capacity (modulation) + lower_rule_curve: t.Optional[t.Union[MatrixType, str]] = None # Lower rule curve (coefficient) + upper_rule_curve: t.Optional[t.Union[MatrixType, str]] = None # Upper rule curve (coefficient) + inflows: t.Optional[t.Union[MatrixType, str]] = None # Inflows (MW) @property def storage_id(self) -> str: @@ -79,12 +62,9 @@ def storage_name(self) -> str: """The label representing the name of the storage for the user.""" return self.parameters.name - @validator(*_MATRIX_NAMES, always=True) - def register_matrix( - cls, - v: t.Optional[t.Union[MatrixType, str]], - values: t.Dict[str, t.Any], - field: ModelField, + @staticmethod + def validate_field( + v: t.Optional[t.Union[MatrixType, str]], values: t.Dict[str, t.Any], field: str ) -> t.Optional[t.Union[MatrixType, str]]: """ Validates a matrix array or link, and store the matrix array in the matrix repository. @@ -100,7 +80,7 @@ def register_matrix( Args: v: The matrix array or link to be validated and registered. values: A dictionary containing additional values used for validation. - field: The field being validated. + field: The name of the validated parameter Returns: The ID of the validated and stored matrix prefixed by "matrix://". @@ -122,7 +102,7 @@ def register_matrix( "upper_rule_curve": constants.get_st_storage_upper_rule_curve, "inflows": constants.get_st_storage_inflows, } - method = methods[field.name] + method = methods[field] return method() if isinstance(v, str): # Check the matrix link @@ -136,7 +116,7 @@ def register_matrix( raise ValueError("Matrix values cannot contain NaN") # All matrices except "inflows" are constrained between 0 and 1 constrained = set(_MATRIX_NAMES) - {"inflows"} - if field.name in constrained and (np.any(array < 0) or np.any(array > 1)): + if field in constrained and (np.any(array < 0) or np.any(array > 1)): raise ValueError("Matrix values should be between 0 and 1") v = t.cast(MatrixType, array.tolist()) return validate_matrix(v, values) @@ -144,6 +124,13 @@ def register_matrix( # pragma: no cover raise TypeError(repr(v)) + @model_validator(mode="before") + def validate_matrices(cls, values: t.Union[t.Dict[str, t.Any], ValidationInfo]) -> t.Dict[str, t.Any]: + new_values = values if isinstance(values, dict) else values.data + for field in _MATRIX_NAMES: + new_values[field] = cls.validate_field(new_values.get(field, None), new_values, field) + return new_values + def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: """ Applies configuration changes to the study data: add the short-term storage in the storages list. @@ -211,14 +198,14 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: Returns: The output of the command execution. """ - output, data = self._apply_config(study_data.config) + output, _ = self._apply_config(study_data.config) if not output.status: return output # Fill-in the "list.ini" file with the parameters. # On creation, it's better to write all the parameters in the file. config = study_data.tree.get(["input", "st-storage", "clusters", self.area_id, "list"]) - config[self.storage_id] = json.loads(self.parameters.json(by_alias=True, exclude={"id"})) + config[self.storage_id] = self.parameters.model_dump(mode="json", by_alias=True, exclude={"id"}) new_data: JSON = { "input": { @@ -240,7 +227,7 @@ def to_dto(self) -> CommandDTO: Returns: The DTO object representing the current command. """ - parameters = json.loads(self.parameters.json(by_alias=True, exclude={"id"})) + parameters = self.parameters.model_dump(mode="json", by_alias=True, exclude={"id"}) return CommandDTO( action=self.command_name.value, args={ @@ -305,7 +292,7 @@ def _create_diff(self, other: "ICommand") -> t.List["ICommand"]: if getattr(self, attr) != getattr(other, attr) ] if self.parameters != other.parameters: - data: t.Dict[str, t.Any] = json.loads(other.parameters.json(by_alias=True, exclude={"id"})) + data: t.Dict[str, t.Any] = other.parameters.model_dump(mode="json", by_alias=True, exclude={"id"}) commands.append( UpdateConfig( target=f"input/st-storage/clusters/{self.area_id}/list/{self.storage_id}", diff --git a/antarest/study/storage/variantstudy/model/command/generate_thermal_cluster_timeseries.py b/antarest/study/storage/variantstudy/model/command/generate_thermal_cluster_timeseries.py index c604e9d6e0..f7ce46c775 100644 --- a/antarest/study/storage/variantstudy/model/command/generate_thermal_cluster_timeseries.py +++ b/antarest/study/storage/variantstudy/model/command/generate_thermal_cluster_timeseries.py @@ -31,8 +31,8 @@ class GenerateThermalClusterTimeSeries(ICommand): Command used to generate thermal cluster timeseries for an entire study """ - command_name = CommandName.GENERATE_THERMAL_CLUSTER_TIMESERIES - version = 1 + command_name: CommandName = CommandName.GENERATE_THERMAL_CLUSTER_TIMESERIES + version: int = 1 def _apply_config(self, study_data: FileStudyTreeConfig) -> OutputTuple: return CommandOutput(status=True, message="Nothing to do"), {} diff --git a/antarest/study/storage/variantstudy/model/command/icommand.py b/antarest/study/storage/variantstudy/model/command/icommand.py index 98b8756dda..ee2a5223ee 100644 --- a/antarest/study/storage/variantstudy/model/command/icommand.py +++ b/antarest/study/storage/variantstudy/model/command/icommand.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod import typing_extensions as te -from pydantic import BaseModel, Extra +from pydantic import BaseModel from antarest.core.utils.utils import assert_this from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig @@ -23,7 +23,7 @@ OutputTuple: te.TypeAlias = t.Tuple[CommandOutput, t.Dict[str, t.Any]] -class ICommand(ABC, BaseModel, extra=Extra.forbid, arbitrary_types_allowed=True, copy_on_model_validation="deep"): +class ICommand(ABC, BaseModel, extra="forbid", arbitrary_types_allowed=True): """ Interface for all commands that can be applied to a study. @@ -126,9 +126,9 @@ def match(self, other: "ICommand", equal: bool = False) -> bool: """ if not isinstance(other, self.__class__): return False - excluded_fields = set(ICommand.__fields__) - this_values = self.dict(exclude=excluded_fields) - that_values = other.dict(exclude=excluded_fields) + excluded_fields = set(ICommand.model_fields) + this_values = self.model_dump(exclude=excluded_fields) + that_values = other.model_dump(exclude=excluded_fields) return this_values == that_values @abstractmethod diff --git a/antarest/study/storage/variantstudy/model/command/remove_area.py b/antarest/study/storage/variantstudy/model/command/remove_area.py index f39c8aac9c..e17fbd599e 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_area.py +++ b/antarest/study/storage/variantstudy/model/command/remove_area.py @@ -21,7 +21,7 @@ class RemoveArea(ICommand): Command used to remove an area. """ - command_name = CommandName.REMOVE_AREA + command_name: CommandName = CommandName.REMOVE_AREA version: int = 1 # Properties of the `REMOVE_AREA` command: diff --git a/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py index d2bd4de9ea..29b0a5de87 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py @@ -15,7 +15,7 @@ class RemoveBindingConstraint(ICommand): Command used to remove a binding constraint. """ - command_name = CommandName.REMOVE_BINDING_CONSTRAINT + command_name: CommandName = CommandName.REMOVE_BINDING_CONSTRAINT version: int = 1 # Properties of the `REMOVE_BINDING_CONSTRAINT` command: diff --git a/antarest/study/storage/variantstudy/model/command/remove_cluster.py b/antarest/study/storage/variantstudy/model/command/remove_cluster.py index a9460f4025..cbb3f4f75e 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_cluster.py +++ b/antarest/study/storage/variantstudy/model/command/remove_cluster.py @@ -18,8 +18,8 @@ class RemoveCluster(ICommand): # Overloaded metadata # =================== - command_name = CommandName.REMOVE_THERMAL_CLUSTER - version = 1 + command_name: CommandName = CommandName.REMOVE_THERMAL_CLUSTER + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/remove_district.py b/antarest/study/storage/variantstudy/model/command/remove_district.py index 36995b5c11..eb496c261b 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_district.py +++ b/antarest/study/storage/variantstudy/model/command/remove_district.py @@ -15,8 +15,8 @@ class RemoveDistrict(ICommand): # Overloaded metadata # =================== - command_name = CommandName.REMOVE_DISTRICT - version = 1 + command_name: CommandName = CommandName.REMOVE_DISTRICT + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/remove_link.py b/antarest/study/storage/variantstudy/model/command/remove_link.py index c82597f32b..c4492b81ba 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_link.py +++ b/antarest/study/storage/variantstudy/model/command/remove_link.py @@ -1,6 +1,6 @@ import typing as t -from pydantic import root_validator, validator +from pydantic import field_validator, model_validator from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig, transform_name_to_id from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy @@ -17,7 +17,7 @@ class RemoveLink(ICommand): # Overloaded metadata # =================== - command_name = CommandName.REMOVE_LINK + command_name: CommandName = CommandName.REMOVE_LINK version: int = 1 # Command parameters @@ -28,7 +28,7 @@ class RemoveLink(ICommand): area2: str # noinspection PyMethodParameters - @validator("area1", "area2", pre=True) + @field_validator("area1", "area2", mode="before") def _validate_id(cls, area: str) -> str: if isinstance(area, str): # Area IDs must be in lowercase and not empty. @@ -42,16 +42,12 @@ def _validate_id(cls, area: str) -> str: return area # noinspection PyMethodParameters - @root_validator(pre=False) - def _validate_link(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - area1 = values.get("area1") - area2 = values.get("area2") - - if area1 and area2: - # By convention, the source area is always the smallest one (in lexicographic order). - values["area1"], values["area2"] = sorted([area1, area2]) - - return values + @model_validator(mode="after") + def _validate_link(self) -> "RemoveLink": + # By convention, the source area is always the smallest one (in lexicographic order). + if self.area1 > self.area2: + self.area1, self.area2 = self.area2, self.area1 + return self def _check_link_exists(self, study_cfg: FileStudyTreeConfig) -> OutputTuple: """ diff --git a/antarest/study/storage/variantstudy/model/command/remove_renewables_cluster.py b/antarest/study/storage/variantstudy/model/command/remove_renewables_cluster.py index 0e41eef9b5..75656aa1ce 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_renewables_cluster.py +++ b/antarest/study/storage/variantstudy/model/command/remove_renewables_cluster.py @@ -15,8 +15,8 @@ class RemoveRenewablesCluster(ICommand): # Overloaded metadata # =================== - command_name = CommandName.REMOVE_RENEWABLES_CLUSTER - version = 1 + command_name: CommandName = CommandName.REMOVE_RENEWABLES_CLUSTER + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/remove_st_storage.py b/antarest/study/storage/variantstudy/model/command/remove_st_storage.py index 116f402c08..cf1bfe1744 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_st_storage.py +++ b/antarest/study/storage/variantstudy/model/command/remove_st_storage.py @@ -20,14 +20,14 @@ class RemoveSTStorage(ICommand): # Overloaded metadata # =================== - command_name = CommandName.REMOVE_ST_STORAGE - version = 1 + command_name: CommandName = CommandName.REMOVE_ST_STORAGE + version: int = 1 # Command parameters # ================== - area_id: str = Field(description="Area ID", regex=r"[a-z0-9_(),& -]+") - storage_id: str = Field(description="Short term storage ID", regex=r"[a-z0-9_(),& -]+") + area_id: str = Field(description="Area ID", pattern=r"[a-z0-9_(),& -]+") + storage_id: str = Field(description="Short term storage ID", pattern=r"[a-z0-9_(),& -]+") def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: """ diff --git a/antarest/study/storage/variantstudy/model/command/replace_matrix.py b/antarest/study/storage/variantstudy/model/command/replace_matrix.py index 6a51ca86b1..d6882c1590 100644 --- a/antarest/study/storage/variantstudy/model/command/replace_matrix.py +++ b/antarest/study/storage/variantstudy/model/command/replace_matrix.py @@ -1,6 +1,6 @@ import typing as t -from pydantic import validator +from pydantic import Field, ValidationInfo, field_validator from antarest.core.exceptions import ChildNotFoundError from antarest.core.model import JSON @@ -23,16 +23,18 @@ class ReplaceMatrix(ICommand): # Overloaded metadata # =================== - command_name = CommandName.REPLACE_MATRIX - version = 1 + command_name: CommandName = CommandName.REPLACE_MATRIX + version: int = 1 # Command parameters # ================== target: str - matrix: t.Union[t.List[t.List[MatrixData]], str] + matrix: t.Union[t.List[t.List[MatrixData]], str] = Field(validate_default=True) - _validate_matrix = validator("matrix", each_item=True, always=True, allow_reuse=True)(validate_matrix) + @field_validator("matrix", mode="before") + def matrix_validator(cls, matrix: t.Union[t.List[t.List[MatrixData]], str], values: ValidationInfo) -> str: + return validate_matrix(matrix, values.data) def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: return ( diff --git a/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py index 808aaaecc8..05ef24398b 100644 --- a/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py @@ -1,4 +1,3 @@ -import json import typing as t from antarest.core.model import JSON @@ -90,7 +89,7 @@ class UpdateBindingConstraint(AbstractBindingConstraintCommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_BINDING_CONSTRAINT + command_name: CommandName = CommandName.UPDATE_BINDING_CONSTRAINT version: int = 1 # Command parameters @@ -171,14 +170,14 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: ) study_version = study_data.config.version - props = create_binding_constraint_config(study_version, **self.dict()) - obj = json.loads(props.json(by_alias=True, exclude_unset=True)) + props = create_binding_constraint_config(study_version, **self.model_dump()) + obj = props.model_dump(mode="json", by_alias=True, exclude_unset=True) updated_cfg = binding_constraints[index] updated_cfg.update(obj) - excluded_fields = set(ICommand.__fields__) | {"id"} - updated_properties = self.dict(exclude=excluded_fields, exclude_none=True) + excluded_fields = set(ICommand.model_fields) | {"id"} + updated_properties = self.model_dump(exclude=excluded_fields, exclude_none=True) # This 2nd check is here to remove the last term. if self.coeffs or updated_properties == {"coeffs": {}}: # Remove terms which IDs contain a "%" or a "." in their name @@ -191,8 +190,8 @@ def to_dto(self) -> CommandDTO: matrices = ["values"] + [m.value for m in TermMatrices] matrix_service = self.command_context.matrix_service - excluded_fields = frozenset(ICommand.__fields__) - json_command = json.loads(self.json(exclude=excluded_fields, exclude_none=True)) + excluded_fields = set(ICommand.model_fields) + json_command = self.model_dump(mode="json", exclude=excluded_fields, exclude_none=True) for key in json_command: if key in matrices: json_command[key] = matrix_service.get_matrix_id(json_command[key]) diff --git a/antarest/study/storage/variantstudy/model/command/update_comments.py b/antarest/study/storage/variantstudy/model/command/update_comments.py index 028cbc5060..f2dc803b7f 100644 --- a/antarest/study/storage/variantstudy/model/command/update_comments.py +++ b/antarest/study/storage/variantstudy/model/command/update_comments.py @@ -16,8 +16,8 @@ class UpdateComments(ICommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_COMMENTS - version = 1 + command_name: CommandName = CommandName.UPDATE_COMMENTS + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/update_config.py b/antarest/study/storage/variantstudy/model/command/update_config.py index b19444f6ae..394fd494ed 100644 --- a/antarest/study/storage/variantstudy/model/command/update_config.py +++ b/antarest/study/storage/variantstudy/model/command/update_config.py @@ -32,8 +32,8 @@ class UpdateConfig(ICommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_CONFIG - version = 1 + command_name: CommandName = CommandName.UPDATE_CONFIG + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/update_district.py b/antarest/study/storage/variantstudy/model/command/update_district.py index 1a14e37acd..ad9e8d53d7 100644 --- a/antarest/study/storage/variantstudy/model/command/update_district.py +++ b/antarest/study/storage/variantstudy/model/command/update_district.py @@ -16,17 +16,17 @@ class UpdateDistrict(ICommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_DISTRICT - version = 1 + command_name: CommandName = CommandName.UPDATE_DISTRICT + version: int = 1 # Command parameters # ================== id: str - base_filter: Optional[DistrictBaseFilter] - filter_items: Optional[List[str]] - output: Optional[bool] - comments: Optional[str] + base_filter: Optional[DistrictBaseFilter] = None + filter_items: Optional[List[str]] = None + output: Optional[bool] = None + comments: Optional[str] = None def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]: base_set = study_data.sets[self.id] diff --git a/antarest/study/storage/variantstudy/model/command/update_playlist.py b/antarest/study/storage/variantstudy/model/command/update_playlist.py index c70dfebbb5..2aad7478eb 100644 --- a/antarest/study/storage/variantstudy/model/command/update_playlist.py +++ b/antarest/study/storage/variantstudy/model/command/update_playlist.py @@ -16,8 +16,8 @@ class UpdatePlaylist(ICommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_PLAYLIST - version = 1 + command_name: CommandName = CommandName.UPDATE_PLAYLIST + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/update_raw_file.py b/antarest/study/storage/variantstudy/model/command/update_raw_file.py index 3e7b3b8759..de033c564c 100644 --- a/antarest/study/storage/variantstudy/model/command/update_raw_file.py +++ b/antarest/study/storage/variantstudy/model/command/update_raw_file.py @@ -17,8 +17,8 @@ class UpdateRawFile(ICommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_FILE - version = 1 + command_name: CommandName = CommandName.UPDATE_FILE + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/update_scenario_builder.py b/antarest/study/storage/variantstudy/model/command/update_scenario_builder.py index ff8e1311ac..4edc33fd50 100644 --- a/antarest/study/storage/variantstudy/model/command/update_scenario_builder.py +++ b/antarest/study/storage/variantstudy/model/command/update_scenario_builder.py @@ -1,8 +1,8 @@ import typing as t import numpy as np -from requests.structures import CaseInsensitiveDict +from antarest.core.requests import CaseInsensitiveDict from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy from antarest.study.storage.variantstudy.model.command.common import CommandName, CommandOutput @@ -33,13 +33,13 @@ class UpdateScenarioBuilder(ICommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_SCENARIO_BUILDER - version = 1 + command_name: CommandName = CommandName.UPDATE_SCENARIO_BUILDER + version: int = 1 # Command parameters # ================== - data: t.Dict[str, t.Any] + data: t.Union[t.Dict[str, t.Any], t.Mapping[str, t.Any], t.MutableMapping[str, t.Any]] def _apply(self, study_data: FileStudy) -> CommandOutput: """ diff --git a/antarest/study/storage/variantstudy/model/command_context.py b/antarest/study/storage/variantstudy/model/command_context.py index a361d40959..aaba33a759 100644 --- a/antarest/study/storage/variantstudy/model/command_context.py +++ b/antarest/study/storage/variantstudy/model/command_context.py @@ -12,4 +12,3 @@ class CommandContext(BaseModel): class Config: arbitrary_types_allowed = True - copy_on_model_validation = False diff --git a/antarest/study/storage/variantstudy/model/model.py b/antarest/study/storage/variantstudy/model/model.py index e170bf4383..02aefe08ef 100644 --- a/antarest/study/storage/variantstudy/model/model.py +++ b/antarest/study/storage/variantstudy/model/model.py @@ -57,7 +57,7 @@ class CommandDTO(BaseModel): version: The version of the command. """ - id: t.Optional[str] + id: t.Optional[str] = None action: str args: t.Union[t.MutableSequence[JSON], JSON] version: int = 1 diff --git a/antarest/study/storage/variantstudy/snapshot_generator.py b/antarest/study/storage/variantstudy/snapshot_generator.py index ee4532349f..e969a9c1f7 100644 --- a/antarest/study/storage/variantstudy/snapshot_generator.py +++ b/antarest/study/storage/variantstudy/snapshot_generator.py @@ -121,7 +121,7 @@ def generate_snapshot( else: try: - notifier(results.json()) + notifier(results.model_dump_json()) except Exception as exc: # This exception is ignored, because it is not critical. logger.warning(f"Error while sending notification: {exc}", exc_info=True) @@ -191,7 +191,7 @@ def _read_additional_data(self, file_study: FileStudy) -> StudyAdditionalData: horizon = file_study.tree.get(url=["settings", "generaldata", "general", "horizon"]) author = file_study.tree.get(url=["study", "antares", "author"]) patch = self.patch_service.get_from_filestudy(file_study) - study_additional_data = StudyAdditionalData(horizon=horizon, author=author, patch=patch.json()) + study_additional_data = StudyAdditionalData(horizon=horizon, author=author, patch=patch.model_dump_json()) return study_additional_data def _update_cache(self, file_study: FileStudy) -> None: @@ -199,7 +199,7 @@ def _update_cache(self, file_study: FileStudy) -> None: self.cache.invalidate(f"{CacheConstants.RAW_STUDY}/{file_study.config.study_id}") self.cache.put( f"{CacheConstants.STUDY_FACTORY}/{file_study.config.study_id}", - FileStudyTreeConfigDTO.from_build_config(file_study.config).dict(), + FileStudyTreeConfigDTO.from_build_config(file_study.config).model_dump(), ) diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index a0c03a8457..aacc8ff240 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -613,7 +613,7 @@ def callback(notifier: TaskUpdateNotifier) -> TaskResult: message=f"{study_id} generated successfully" if generate_result.success else f"{study_id} not generated", - return_value=generate_result.json(), + return_value=generate_result.model_dump_json(), ) metadata.generation_task = self.task_service.add_task( @@ -704,7 +704,7 @@ def notify(command_index: int, command_result: bool, command_message: str) -> No success=command_result, message=command_message, ) - notifier(command_result_obj.json()) + notifier(command_result_obj.model_dump_json()) self.event_bus.push( Event( type=EventType.STUDY_VARIANT_GENERATION_COMMAND_RESULT, diff --git a/antarest/study/web/studies_blueprint.py b/antarest/study/web/studies_blueprint.py index 1dbc4810f7..e5510eb176 100644 --- a/antarest/study/web/studies_blueprint.py +++ b/antarest/study/web/studies_blueprint.py @@ -715,7 +715,7 @@ def output_download( extra={"user": current_user.id}, ) params = RequestParameters(user=current_user) - accept = request.headers.get("Accept") + accept = request.headers["Accept"] filetype = ExportFormat.from_dto(accept) content = study_service.download_outputs( diff --git a/antarest/study/web/study_data_blueprint.py b/antarest/study/web/study_data_blueprint.py index 8a960da5f3..c77e3dde7d 100644 --- a/antarest/study/web/study_data_blueprint.py +++ b/antarest/study/web/study_data_blueprint.py @@ -17,7 +17,7 @@ from antarest.matrixstore.matrix_editor import MatrixEditInstruction from antarest.study.business.adequacy_patch_management import AdequacyPatchFormFields from antarest.study.business.advanced_parameters_management import AdvancedParamsFormFields -from antarest.study.business.allocation_management import AllocationFormFields, AllocationMatrix +from antarest.study.business.allocation_management import AllocationField, AllocationFormFields, AllocationMatrix from antarest.study.business.area_management import AreaCreationDTO, AreaInfoDTO, AreaType, LayerInfoDTO, UpdateAreaUi from antarest.study.business.areas.hydro_management import InflowStructure, ManagementOptionsFormFields from antarest.study.business.areas.properties_management import PropertiesFormFields @@ -48,7 +48,12 @@ ConstraintOutput, ConstraintTerm, ) -from antarest.study.business.correlation_management import CorrelationFormFields, CorrelationManager, CorrelationMatrix +from antarest.study.business.correlation_management import ( + AreaCoefficientItem, + CorrelationFormFields, + CorrelationManager, + CorrelationMatrix, +) from antarest.study.business.district_manager import DistrictCreationDTO, DistrictInfoDTO, DistrictUpdateDTO from antarest.study.business.general_management import GeneralFormFields from antarest.study.business.link_management import LinkInfoDTO @@ -110,7 +115,7 @@ def create_study_data_routes(study_service: StudyService, config: Config) -> API "/studies/{uuid}/areas", tags=[APITag.study_data], summary="Get all areas basic info", - response_model=t.Union[t.List[AreaInfoDTO], t.Dict[str, t.Any]], # type: ignore + response_model=t.Union[t.List[AreaInfoDTO], t.Dict[str, t.Any]], ) def get_areas( uuid: str, @@ -491,7 +496,7 @@ def get_inflow_structure( "/studies/{uuid}/areas/{area_id}/hydro/inflow-structure", tags=[APITag.study_data], summary="Update inflow structure values", - response_model=InflowStructure, + response_model=None, ) def update_inflow_structure( uuid: str, @@ -1132,7 +1137,7 @@ def get_binding_constraint_list( "/studies/{uuid}/bindingconstraints/{binding_constraint_id}", tags=[APITag.study_data], summary="Get binding constraint", - response_model=ConstraintOutput, # type: ignore + response_model=ConstraintOutput, ) def get_binding_constraint( uuid: str, @@ -1520,8 +1525,8 @@ def set_allocation_form_fields( ..., example=AllocationFormFields( allocation=[ - {"areaId": "EAST", "coefficient": 1}, - {"areaId": "NORTH", "coefficient": 0.20}, + AllocationField.model_validate({"areaId": "EAST", "coefficient": 1}), + AllocationField.model_validate({"areaId": "NORTH", "coefficient": 0.20}), ] ), ), @@ -1553,8 +1558,8 @@ def set_allocation_form_fields( def get_correlation_matrix( uuid: str, columns: t.Optional[str] = Query( - None, - examples={ + default=None, + openapi_examples={ "all areas": { "description": "get the correlation matrix for all areas (by default)", "value": "", @@ -1686,8 +1691,8 @@ def set_correlation_form_fields( ..., example=CorrelationFormFields( correlation=[ - {"areaId": "east", "coefficient": 80}, - {"areaId": "north", "coefficient": 20}, + AreaCoefficientItem.model_validate({"areaId": "east", "coefficient": 80}), + AreaCoefficientItem.model_validate({"areaId": "north", "coefficient": 20}), ] ), ), diff --git a/antarest/study/web/variant_blueprint.py b/antarest/study/web/variant_blueprint.py index 060ff167e9..4ea821fcc8 100644 --- a/antarest/study/web/variant_blueprint.py +++ b/antarest/study/web/variant_blueprint.py @@ -95,12 +95,7 @@ def create_variant( "/studies/{uuid}/variants", tags=[APITag.study_variant_management], summary="Get children variants", - responses={ - 200: { - "description": "The list of children study variant", - "model": List[StudyMetadataDTO], - } - }, + response_model=None, ) def get_variants( uuid: str, diff --git a/antarest/study/web/xpansion_studies_blueprint.py b/antarest/study/web/xpansion_studies_blueprint.py index 1b46af1a84..3bce9bb47b 100644 --- a/antarest/study/web/xpansion_studies_blueprint.py +++ b/antarest/study/web/xpansion_studies_blueprint.py @@ -127,7 +127,7 @@ def add_candidate( current_user: JWTUser = Depends(auth.get_current_user), ) -> XpansionCandidateDTO: logger.info( - f"Adding new candidate {xpansion_candidate_dto.dict(by_alias=True)} to study {uuid}", + f"Adding new candidate {xpansion_candidate_dto.model_dump(by_alias=True)} to study {uuid}", extra={"user": current_user.id}, ) params = RequestParameters(user=current_user) diff --git a/antarest/tools/lib.py b/antarest/tools/lib.py index 60e3215f5b..a2e2dff1eb 100644 --- a/antarest/tools/lib.py +++ b/antarest/tools/lib.py @@ -8,17 +8,7 @@ from zipfile import ZipFile import numpy as np - -try: - # The HTTPX equivalent of `requests.Session` is `httpx.Client`. - import httpx as requests - from httpx import Client as Session -except ImportError: - # noinspection PyUnresolvedReferences, PyPackageRequirements - import requests - - # noinspection PyUnresolvedReferences,PyPackageRequirements - from requests import Session +from httpx import Client from antarest.core.cache.business.local_chache import LocalCache from antarest.core.config import CacheConfig @@ -55,22 +45,17 @@ def __init__( study_id: str, host: Optional[str] = None, token: Optional[str] = None, - session: Optional[Session] = None, + session: Optional[Client] = None, ): self.study_id = study_id # todo: find the correct way to handle certificates. - # By default, Requests/Httpx verifies SSL certificates for HTTPS requests. + # By default, Httpx verifies SSL certificates for HTTPS requests. # When verify is set to `False`, requests will accept any TLS certificate presented - # by the server,and will ignore hostname mismatches and/or expired certificates, + # by the server, and will ignore hostname mismatches and/or expired certificates, # which will make your application vulnerable to man-in-the-middle (MitM) attacks. - # Setting verify to False may be useful during local development or testing. - if Session.__name__ == "Client": - # noinspection PyArgumentList - self.session = session or Session(verify=False) - else: - self.session = session or Session() - self.session.verify = False + # Setting verify to `False` may be useful during local development or testing. + self.session = session or Client(verify=False) self.host = host if session is None and host is None: @@ -102,7 +87,7 @@ def apply_commands( res = self.session.post( self.build_url(f"/v1/studies/{self.study_id}/commands"), - json=[command.dict() for command in commands], + json=[command.model_dump() for command in commands], ) res.raise_for_status() stopwatch.log_elapsed(lambda x: logger.info(f"Command upload done in {x}s")) @@ -212,7 +197,7 @@ def extract_commands(study_path: Path, commands_output_dir: Path) -> None: (commands_output_dir / COMMAND_FILE).write_text( json.dumps( - [command.dict(exclude={"id"}) for command in command_list], + [command.model_dump(exclude={"id"}) for command in command_list], indent=2, ) ) @@ -304,7 +289,7 @@ def generate_diff( (output_dir / COMMAND_FILE).write_text( json.dumps( - [command.to_dto().dict(exclude={"id"}) for command in diff_commands], + [command.to_dto().model_dump(exclude={"id"}) for command in diff_commands], indent=2, ) ) diff --git a/antarest/utils.py b/antarest/utils.py index 1f61717ada..fe33e6011a 100644 --- a/antarest/utils.py +++ b/antarest/utils.py @@ -1,18 +1,11 @@ -import datetime import logging from enum import Enum from pathlib import Path from typing import Any, Dict, Mapping, Optional, Tuple import redis -import sqlalchemy.ext.baked # type: ignore -import uvicorn # type: ignore from fastapi import FastAPI -from fastapi_jwt_auth import AuthJWT # type: ignore -from ratelimit import RateLimitMiddleware # type: ignore -from ratelimit.backends.redis import RedisBackend # type: ignore -from ratelimit.backends.simple import MemoryBackend # type: ignore -from sqlalchemy import create_engine +from sqlalchemy import create_engine # type: ignore from sqlalchemy.engine.base import Engine # type: ignore from sqlalchemy.pool import NullPool # type: ignore diff --git a/antarest/worker/archive_worker.py b/antarest/worker/archive_worker.py index 0ef7bedc31..00b76bace0 100644 --- a/antarest/worker/archive_worker.py +++ b/antarest/worker/archive_worker.py @@ -46,10 +46,10 @@ def __init__( ) def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: - logger.info(f"Executing task {task_info.json()}") + logger.info(f"Executing task {task_info.model_dump_json()}") try: # sourcery skip: extract-method - archive_args = ArchiveTaskArgs.parse_obj(task_info.task_args) + archive_args = ArchiveTaskArgs.model_validate(task_info.task_args) dest = self.translate_path(Path(archive_args.dest)) src = self.translate_path(Path(archive_args.src)) stopwatch = StopWatch() @@ -63,7 +63,7 @@ def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: return TaskResult(success=True, message="") except Exception as e: logger.warning( - f"Task {task_info.json()} failed", + f"Task {task_info.model_dump_json()} failed", exc_info=e, ) return TaskResult(success=False, message=str(e)) diff --git a/antarest/worker/simulator_worker.py b/antarest/worker/simulator_worker.py index d37a8825f5..4ed221e782 100644 --- a/antarest/worker/simulator_worker.py +++ b/antarest/worker/simulator_worker.py @@ -69,7 +69,7 @@ def execute_kirshoff_constraint_generation_task(self, task_info: WorkerTaskComma def execute_timeseries_generation_task(self, task_info: WorkerTaskCommand) -> TaskResult: result = TaskResult(success=True, message="", return_value="") - task = GenerateTimeseriesTaskArgs.parse_obj(task_info.task_args) + task = GenerateTimeseriesTaskArgs.model_validate(task_info.task_args) binary = ( self.binaries[task.study_version] if task.study_version in self.binaries diff --git a/antarest/worker/worker.py b/antarest/worker/worker.py index d292d7a8b8..caf89c16a2 100644 --- a/antarest/worker/worker.py +++ b/antarest/worker/worker.py @@ -100,8 +100,8 @@ def _loop(self) -> None: time.sleep(1) async def _listen_for_tasks(self, event: Event) -> None: - logger.info(f"Accepting new task {event.json()}") - task_info = WorkerTaskCommand.parse_obj(event.payload) + logger.info(f"Accepting new task {event.model_dump_json()}") + task_info = WorkerTaskCommand.model_validate(event.payload) self.event_bus.push( Event( type=EventType.WORKER_TASK_STARTED, @@ -119,7 +119,7 @@ def _safe_execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: return self._execute_task(task_info) except Exception as e: logger.error( - f"Unexpected error occurred when executing task {task_info.json()}", + f"Unexpected error occurred when executing task {task_info.model_dump_json()}", exc_info=e, ) return TaskResult(success=False, message=repr(e)) diff --git a/pyproject.toml b/pyproject.toml index 1d01065859..90405665a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,8 +38,13 @@ where = ["."] include = ["antarest*"] [tool.mypy] +exclude = "antarest/fastapi_jwt_auth/*" strict = true -files = "antarest/**/*.py" +files = "antarest" + +[[tool.mypy.overrides]] +module = ["antarest/fastapi_jwt_auth.*"] +follow_imports = "skip" [[tool.mypy.overrides]] module = [ @@ -71,7 +76,7 @@ line-length = 120 exclude = "(antares-?launcher/*|alembic/*)" [tool.coverage.run] -omit = ["antarest/tools/cli.py", "antarest/tools/admin.py"] +omit = ["antarest/tools/cli.py", "antarest/tools/admin.py", "antarest/fastapi_jwt_auth/*.py"] relative_files = true # avoids absolute path issues in CI [tool.isort] diff --git a/requirements-dev.txt b/requirements-dev.txt index 5ea6126c6c..973973dc76 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,7 +2,7 @@ # Version of Black should match the versions set in `.github/workflows/main.yml` black~=23.7.0 isort~=5.12.0 -mypy~=1.4.1 +mypy~=1.11.1 pyinstaller==5.6.2 pyinstaller-hooks-contrib==2024.6 diff --git a/requirements-test.txt b/requirements-test.txt index 8e408b2677..9d58c1bb5b 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -5,4 +5,5 @@ pytest-cov~=4.0.0 # In this version DataFrame conversion to Excel is done using 'xlsxwriter' library. # But Excel files reading is done using 'openpyxl' library, during testing only. -openpyxl~=3.1.2 \ No newline at end of file +openpyxl~=3.1.2 +jinja2~=3.1.3 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index eb42c65793..4659b879ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,17 +2,34 @@ Antares-Launcher~=1.3.2 antares-study-version~=1.0.3 antares-timeseries-generation~=0.1.5 +# When you install `fastapi[all]`, you get FastAPI along with additional dependencies: +# - `uvicorn`: A fast ASGI server commonly used to run FastAPI applications. +# - `pydantic`: A data validation library integrated into FastAPI for defining data models and validating input and output of endpoints. +# - `httpx`: A modern HTTP client library for making HTTP requests from your FastAPI application. +# - `starlette`: The underlying ASGI framework used by FastAPI for handling HTTP requests. +# - `uvloop` (on certain systems): A fast event loop based on uvicorn that improves FastAPI application performance. +# - `python-multipart`: A library for handling multipart data in HTTP requests, commonly used for processing file uploads. +# - `watchdog`: A file watching library used by FastAPI's automatic reloading tool to update the application in real-time when files are modified. +# - `email-validator`: A library for email address validation, used for validating email fields in Pydantic models. +# - `python-dotenv`: A library for managing environment variables, commonly used to load application configurations from `.env` files. + +# We prefer to add only the specific libraries we need for our project +# and **manage their versions** for better control and to avoid unnecessary dependencies. +fastapi~=0.110.3 +uvicorn[standard]~=0.30.6 +pydantic~=2.8.2 +httpx~=0.27.0 +python-multipart~=0.0.9 + alembic~=1.7.5 asgi-ratelimit[redis]==0.7.0 bcrypt~=3.2.0 click~=8.0.3 contextvars~=2.4 -fastapi-jwt-auth~=0.5.0 -fastapi[all]~=0.73.0 filelock~=3.4.2 gunicorn~=20.1.0 -Jinja2~=3.0.3 jsonref~=0.2 +PyJWT~=2.9.0 MarkupSafe~=2.0.1 numpy~=1.22.1 pandas~=1.4.0 @@ -20,18 +37,13 @@ paramiko~=2.12.0 plyer~=2.0.0 psycopg2-binary==2.9.4 py7zr~=0.20.6 -pydantic~=1.9.0 PyQt5~=5.15.6 python-json-logger~=2.0.7 -python-multipart~=0.0.5 PyYAML~=5.4.1; python_version <= '3.9' PyYAML~=5.3.1; python_version > '3.9' redis~=4.1.2 -requests~=2.27.1 SQLAlchemy~=1.4.46 -starlette~=0.17.1 tables==3.6.1; python_version <= '3.8' tables==3.9.2; python_version > '3.8' -typing_extensions~=4.7.1 -uvicorn[standard]~=0.15.0 -xlsxwriter~=3.2.0 +typing_extensions~=4.12.2 +xlsxwriter~=3.2.0 \ No newline at end of file diff --git a/tests/cache/test_local_cache.py b/tests/cache/test_local_cache.py index b9fae75ee9..79deddfea4 100644 --- a/tests/cache/test_local_cache.py +++ b/tests/cache/test_local_cache.py @@ -29,11 +29,11 @@ def test_lifecycle(): id = "some_id" duration = 3600 timeout = int(time.time()) + duration - cache_element = LocalCacheElement(duration=duration, data=config.dict(), timeout=timeout) + cache_element = LocalCacheElement(duration=duration, data=config.model_dump(), timeout=timeout) # PUT - cache.put(id=id, data=config.dict(), duration=duration) + cache.put(id=id, data=config.model_dump(), duration=duration) assert cache.cache[id] == cache_element # GET - assert cache.get(id=id) == config.dict() + assert cache.get(id=id) == config.model_dump() diff --git a/tests/cache/test_redis_cache.py b/tests/cache/test_redis_cache.py index a7c76d07c1..fe47050edf 100644 --- a/tests/cache/test_redis_cache.py +++ b/tests/cache/test_redis_cache.py @@ -28,7 +28,7 @@ def test_lifecycle(): id = "some_id" redis_key = f"cache:{id}" duration = 3600 - cache_element = RedisCacheElement(duration=duration, data=config.dict()).json() + cache_element = RedisCacheElement(duration=duration, data=config.model_dump()).model_dump_json() # GET redis_client.get.return_value = cache_element @@ -39,7 +39,7 @@ def test_lifecycle(): # PUT duration = 7200 - cache_element = RedisCacheElement(duration=duration, data=config.dict()).json() - cache.put(id=id, data=config.dict(), duration=duration) + cache_element = RedisCacheElement(duration=duration, data=config.model_dump()).model_dump_json() + cache.put(id=id, data=config.model_dump(), duration=duration) redis_client.set.assert_called_once_with(redis_key, cache_element) redis_client.expire.assert_called_with(redis_key, duration) diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index cfe36a244e..ba03ad815f 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -107,7 +107,7 @@ def test_service(core_config: Config, event_bus: IEventBus) -> None: "status": TaskStatus.FAILED, "type": None, } - assert res.dict() == expected + assert res.model_dump() == expected # Test Case: add a task that fails and wait for it # ================================================ diff --git a/tests/eventbus/test_redis_event_bus.py b/tests/eventbus/test_redis_event_bus.py index 8e28ab1163..173ee311e3 100644 --- a/tests/eventbus/test_redis_event_bus.py +++ b/tests/eventbus/test_redis_event_bus.py @@ -17,7 +17,7 @@ def test_lifecycle(): payload="foo", permissions=PermissionInfo(public_mode=PublicMode.READ), ) - serialized = event.json() + serialized = event.model_dump_json() pubsub_mock.get_message.return_value = {"data": serialized} eventbus.push_event(event) redis_client.publish.assert_called_once_with("events", serialized) diff --git a/tests/eventbus/test_websocket_manager.py b/tests/eventbus/test_websocket_manager.py index fc512854c8..43b21ffdbb 100644 --- a/tests/eventbus/test_websocket_manager.py +++ b/tests/eventbus/test_websocket_manager.py @@ -25,7 +25,7 @@ async def test_subscriptions(self): await ws_manager.connect(mock_connection, user) assert len(ws_manager.active_connections) == 1 - ws_manager.process_message(subscribe_message.json(), mock_connection) + ws_manager.process_message(subscribe_message.model_dump_json(), mock_connection) connections = ws_manager.active_connections[0] assert len(connections.channel_subscriptions) == 1 assert connections.channel_subscriptions[0] == "foo" @@ -39,7 +39,7 @@ async def test_subscriptions(self): mock_connection.send_text.assert_has_calls([call("msg1"), call("msg2")]) - ws_manager.process_message(unsubscribe_message.json(), mock_connection) + ws_manager.process_message(unsubscribe_message.model_dump_json(), mock_connection) assert len(connections.channel_subscriptions) == 0 ws_manager.disconnect(mock_connection) diff --git a/tests/integration/filesystem_blueprint/test_model.py b/tests/integration/filesystem_blueprint/test_model.py index 3c21340363..aff018d00b 100644 --- a/tests/integration/filesystem_blueprint/test_model.py +++ b/tests/integration/filesystem_blueprint/test_model.py @@ -16,7 +16,7 @@ def test_init(self) -> None: "common": "/path/to/workspaces/common_studies", }, } - dto = FilesystemDTO.parse_obj(example) + dto = FilesystemDTO.model_validate(example) assert dto.name == example["name"] assert dto.mount_dirs["default"] == Path(example["mount_dirs"]["default"]) assert dto.mount_dirs["common"] == Path(example["mount_dirs"]["common"]) @@ -32,7 +32,7 @@ def test_init(self) -> None: "free_bytes": 1e9 - 0.6e9, "message": f"{0.6e9 / 1e9:%} used", } - dto = MountPointDTO.parse_obj(example) + dto = MountPointDTO.model_validate(example) assert dto.name == example["name"] assert dto.path == Path(example["path"]) assert dto.total_bytes == example["total_bytes"] @@ -75,7 +75,7 @@ def test_init(self) -> None: "accessed": "2024-01-11T17:54:09", "message": "OK", } - dto = FileInfoDTO.parse_obj(example) + dto = FileInfoDTO.model_validate(example) assert dto.path == Path(example["path"]) assert dto.file_type == example["file_type"] assert dto.file_count == example["file_count"] diff --git a/tests/integration/launcher_blueprint/test_launcher_local.py b/tests/integration/launcher_blueprint/test_launcher_local.py index 08d1175889..1ee93ea65b 100644 --- a/tests/integration/launcher_blueprint/test_launcher_local.py +++ b/tests/integration/launcher_blueprint/test_launcher_local.py @@ -64,10 +64,8 @@ def test_get_launcher_nb_cores( ) assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY, res.json() actual = res.json() - assert actual == { - "description": "Unknown solver configuration: 'unknown'", - "exception": "UnknownSolverConfig", - } + assert actual["description"] == "Input should be 'slurm', 'local' or 'default'" + assert actual["exception"] == "RequestValidationError" def test_get_launcher_time_limit( self, @@ -118,7 +116,5 @@ def test_get_launcher_time_limit( ) assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY, res.json() actual = res.json() - assert actual == { - "description": "Unknown solver configuration: 'unknown'", - "exception": "UnknownSolverConfig", - } + assert actual["description"] == "Input should be 'slurm', 'local' or 'default'" + assert actual["exception"] == "RequestValidationError" diff --git a/tests/integration/raw_studies_blueprint/test_aggregate_raw_data.py b/tests/integration/raw_studies_blueprint/test_aggregate_raw_data.py index 591791c8ee..37ef364d5c 100644 --- a/tests/integration/raw_studies_blueprint/test_aggregate_raw_data.py +++ b/tests/integration/raw_studies_blueprint/test_aggregate_raw_data.py @@ -7,14 +7,7 @@ import pytest from starlette.testclient import TestClient -from antarest.study.business.aggregator_management import ( - MCAllAreasQueryFile, - MCAllLinksQueryFile, - MCIndAreasQueryFile, - MCIndLinksQueryFile, -) from antarest.study.storage.df_download import TableExportFormat -from antarest.study.storage.rawstudy.model.filesystem.matrix.matrix import MatrixFrequency from tests.integration.raw_studies_blueprint.assets import ASSETS_DIR # define the requests parameters for the `economy/mc-ind` outputs aggregation @@ -22,8 +15,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "", "areas_ids": "", "columns_names": "", @@ -33,8 +26,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.DETAILS, - "frequency": MatrixFrequency.HOURLY, + "query_file": "details", + "frequency": "hourly", "mc_years": "1", "areas_ids": "de,fr,it", "columns_names": "", @@ -44,8 +37,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.WEEKLY, + "query_file": "values", + "frequency": "weekly", "mc_years": "1,2", "areas_ids": "", "columns_names": "OP. COST,MRG. PRICE", @@ -55,8 +48,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "2", "areas_ids": "es,fr,de", "columns_names": "", @@ -66,8 +59,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.ANNUAL, + "query_file": "values", + "frequency": "annual", "mc_years": "", "areas_ids": "", "columns_names": "", @@ -77,8 +70,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "columns_names": "COSt,NODu", }, "test-06.result.tsv", @@ -86,8 +79,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.DETAILS, - "frequency": MatrixFrequency.HOURLY, + "query_file": "details", + "frequency": "hourly", "columns_names": "COSt,NODu", }, "test-07.result.tsv", @@ -98,8 +91,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "", "columns_names": "", }, @@ -108,8 +101,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "1", "columns_names": "", }, @@ -118,8 +111,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "1,2", "columns_names": "UCAP LIn.,FLOw qUAD.", }, @@ -128,8 +121,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "1", "links_ids": "de - fr", }, @@ -138,8 +131,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "columns_names": "MArG. COsT,CONG. PRoB +", }, "test-05.result.tsv", @@ -150,8 +143,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "format": "csv", }, "test-01.result.tsv", @@ -159,8 +152,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "format": "tsv", }, "test-01.result.tsv", @@ -168,8 +161,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "format": "xlsx", }, "test-01.result.tsv", @@ -180,20 +173,20 @@ INCOHERENT_REQUESTS_BODIES__IND = [ { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "123456789", }, { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "columns_names": "fake_col", }, { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "links_ids": "fake_id", }, ] @@ -202,17 +195,17 @@ { "output_id": "20201014-1425eco-goodbye", "query_file": "fake_query_file", - "frequency": MatrixFrequency.HOURLY, + "frequency": "hourly", }, { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, + "query_file": "values", "frequency": "fake_frequency", }, { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "format": "fake_format", }, ] @@ -222,8 +215,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "areas_ids": "", "columns_names": "", }, @@ -232,8 +225,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.DETAILS, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "details", + "frequency": "monthly", "areas_ids": "de,fr,it", "columns_names": "", }, @@ -242,8 +235,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "areas_ids": "", "columns_names": "OP. CoST,MRG. PrICE", }, @@ -252,8 +245,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "areas_ids": "es,fr,de", "columns_names": "", }, @@ -262,8 +255,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "values", + "frequency": "monthly", "areas_ids": "", "columns_names": "", }, @@ -272,8 +265,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.ID, - "frequency": MatrixFrequency.DAILY, + "query_file": "id", + "frequency": "daily", "areas_ids": "", "columns_names": "", }, @@ -282,8 +275,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "columns_names": "COsT,NoDU", }, "test-07-all.result.tsv", @@ -291,8 +284,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.DETAILS, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "details", + "frequency": "monthly", "columns_names": "COsT,NoDU", }, "test-08-all.result.tsv", @@ -303,8 +296,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "columns_names": "", }, "test-01-all.result.tsv", @@ -312,8 +305,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "values", + "frequency": "monthly", "columns_names": "", }, "test-02-all.result.tsv", @@ -321,8 +314,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "columns_names": "", }, "test-03-all.result.tsv", @@ -330,8 +323,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "values", + "frequency": "monthly", "links_ids": "de - fr", }, "test-04-all.result.tsv", @@ -339,8 +332,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.ID, - "frequency": MatrixFrequency.DAILY, + "query_file": "id", + "frequency": "daily", "links_ids": "", }, "test-05-all.result.tsv", @@ -348,8 +341,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "columns_names": "MARG. COsT,CONG. ProB +", }, "test-06-all.result.tsv", @@ -360,8 +353,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "format": "csv", }, "test-01-all.result.tsv", @@ -369,8 +362,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "format": "tsv", }, "test-01-all.result.tsv", @@ -378,8 +371,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "format": "xlsx", }, "test-01-all.result.tsv", @@ -390,19 +383,19 @@ INCOHERENT_REQUESTS_BODIES__ALL = [ { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", }, { "output_id": "20201014-1427eco", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "columns_names": "fake_col", }, { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "values", + "frequency": "monthly", "links_ids": "fake_id", }, ] @@ -411,17 +404,17 @@ { "output_id": "20201014-1427eco", "query_file": "fake_query_file", - "frequency": MatrixFrequency.MONTHLY, + "frequency": "monthly", }, { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, + "query_file": "values", "frequency": "fake_frequency", }, { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "format": "fake_format", }, ] @@ -562,8 +555,8 @@ def test_aggregation_with_wrong_output(self, client: TestClient, user_access_tok res = client.get( f"/v1/studies/{internal_study_id}/areas/aggregate/mc-ind/unknown_id", params={ - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", }, ) assert res.status_code == 404, res.json() @@ -574,8 +567,8 @@ def test_aggregation_with_wrong_output(self, client: TestClient, user_access_tok res = client.get( f"/v1/studies/{internal_study_id}/links/aggregate/mc-ind/unknown_id", params={ - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", }, ) assert res.status_code == 404, res.json() @@ -592,8 +585,8 @@ def test_empty_columns(self, client: TestClient, user_access_token: str, interna res = client.get( f"/v1/studies/{internal_study_id}/areas/aggregate/mc-ind/20201014-1425eco-goodbye", params={ - "query_file": MCIndAreasQueryFile.DETAILS, - "frequency": MatrixFrequency.HOURLY, + "query_file": "details", + "frequency": "hourly", "columns_names": "fake_col", }, ) @@ -605,8 +598,8 @@ def test_empty_columns(self, client: TestClient, user_access_token: str, interna res = client.get( f"/v1/studies/{internal_study_id}/links/aggregate/mc-ind/20201014-1425eco-goodbye", params={ - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "columns_names": "fake_col", }, ) @@ -628,7 +621,7 @@ def test_non_existing_folder( shutil.rmtree(mc_ind_folder) res = client.get( f"/v1/studies/{internal_study_id}/areas/aggregate/mc-ind/20201014-1425eco-goodbye", - params={"query_file": MCIndAreasQueryFile.VALUES, "frequency": MatrixFrequency.HOURLY}, + params={"query_file": "values", "frequency": "hourly"}, ) assert res.status_code == 404, res.json() assert "economy/mc-ind" in res.json()["description"] @@ -770,8 +763,8 @@ def test_aggregation_with_wrong_output(self, client: TestClient, user_access_tok res = client.get( f"/v1/studies/{internal_study_id}/areas/aggregate/mc-all/unknown_id", params={ - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", }, ) assert res.status_code == 404, res.json() @@ -782,8 +775,8 @@ def test_aggregation_with_wrong_output(self, client: TestClient, user_access_tok res = client.get( f"/v1/studies/{internal_study_id}/links/aggregate/mc-all/unknown_id", params={ - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", }, ) assert res.status_code == 404, res.json() @@ -801,8 +794,8 @@ def test_empty_columns(self, client: TestClient, user_access_token: str, interna res = client.get( f"/v1/studies/{internal_study_id}/areas/aggregate/mc-all/20201014-1427eco", params={ - "query_file": MCAllAreasQueryFile.DETAILS, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "details", + "frequency": "monthly", "columns_names": "fake_col", }, ) @@ -814,8 +807,8 @@ def test_empty_columns(self, client: TestClient, user_access_token: str, interna res = client.get( f"/v1/studies/{internal_study_id}/links/aggregate/mc-all/20241807-1540eco-extra-outputs", params={ - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "columns_names": "fake_col", }, ) @@ -835,7 +828,7 @@ def test_non_existing_folder( shutil.rmtree(mc_all_path) res = client.get( f"/v1/studies/{internal_study_id}/links/aggregate/mc-all/20241807-1540eco-extra-outputs", - params={"query_file": MCAllLinksQueryFile.VALUES, "frequency": MatrixFrequency.DAILY}, + params={"query_file": "values", "frequency": "daily"}, ) assert res.status_code == 404, res.json() assert "economy/mc-all" in res.json()["description"] diff --git a/tests/integration/raw_studies_blueprint/test_download_matrices.py b/tests/integration/raw_studies_blueprint/test_download_matrices.py index c491108f5c..bcd8652e76 100644 --- a/tests/integration/raw_studies_blueprint/test_download_matrices.py +++ b/tests/integration/raw_studies_blueprint/test_download_matrices.py @@ -347,7 +347,7 @@ def test_download_matrices(self, client: TestClient, user_access_token: str, int for export_format in ["tsv", "xlsx"]: res = client.get( f"/v1/studies/{study_860_id}/raw/download", - params={"path": "input/hydro/series/de/mingen", "format": {export_format}}, + params={"path": "input/hydro/series/de/mingen", "format": export_format}, headers=user_headers, ) assert res.status_code == 200 diff --git a/tests/integration/studies_blueprint/assets/test_synthesis/raw_study.synthesis.json b/tests/integration/studies_blueprint/assets/test_synthesis/raw_study.synthesis.json index 80ffa29e32..1e0f3ada52 100644 --- a/tests/integration/studies_blueprint/assets/test_synthesis/raw_study.synthesis.json +++ b/tests/integration/studies_blueprint/assets/test_synthesis/raw_study.synthesis.json @@ -38,22 +38,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -77,22 +62,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -116,22 +86,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -155,22 +110,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -194,22 +134,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -233,22 +158,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -272,22 +182,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -311,22 +206,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -350,22 +230,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], @@ -413,22 +278,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -452,22 +302,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -491,22 +326,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -530,22 +350,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -569,22 +374,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -608,22 +398,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -647,22 +422,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -686,22 +446,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -725,22 +470,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], @@ -788,22 +518,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -827,22 +542,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -866,22 +566,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -905,22 +590,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -944,22 +614,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -983,22 +638,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -1022,22 +662,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -1061,22 +686,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -1100,22 +710,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], @@ -1151,22 +746,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -1190,22 +770,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -1229,22 +794,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -1268,22 +818,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -1307,22 +842,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -1346,22 +866,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -1385,22 +890,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -1424,22 +914,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -1463,22 +938,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], diff --git a/tests/integration/studies_blueprint/assets/test_synthesis/variant_study.synthesis.json b/tests/integration/studies_blueprint/assets/test_synthesis/variant_study.synthesis.json index a77fa18a58..7e449747e4 100644 --- a/tests/integration/studies_blueprint/assets/test_synthesis/variant_study.synthesis.json +++ b/tests/integration/studies_blueprint/assets/test_synthesis/variant_study.synthesis.json @@ -38,22 +38,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -77,22 +62,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -116,22 +86,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -155,22 +110,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -194,22 +134,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -233,22 +158,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -272,22 +182,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -311,22 +206,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -350,22 +230,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], @@ -413,22 +278,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -452,22 +302,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -491,22 +326,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -530,22 +350,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -569,22 +374,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -608,22 +398,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -647,22 +422,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -686,22 +446,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -725,22 +470,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], @@ -788,22 +518,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -827,22 +542,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -866,22 +566,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -905,22 +590,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -944,22 +614,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -983,22 +638,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -1022,22 +662,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -1061,22 +686,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -1100,22 +710,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], @@ -1151,22 +746,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -1190,22 +770,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -1229,22 +794,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -1268,22 +818,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -1307,22 +842,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -1346,22 +866,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -1385,22 +890,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -1424,22 +914,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -1463,22 +938,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], diff --git a/tests/integration/studies_blueprint/test_comments.py b/tests/integration/studies_blueprint/test_comments.py index 39ce84e35b..6bfccb04d5 100644 --- a/tests/integration/studies_blueprint/test_comments.py +++ b/tests/integration/studies_blueprint/test_comments.py @@ -26,12 +26,9 @@ def test_raw_study( This test verifies that we can retrieve and modify the comments of a study. It also performs performance measurements and analyzes. """ - + client.headers = {"Authorization": f"Bearer {user_access_token}"} # Get the comments of the study and compare with the expected file - res = client.get( - f"/v1/studies/{internal_study_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/comments") assert res.status_code == 200, res.json() actual = res.json() actual_xml = ElementTree.parse(io.StringIO(actual)).getroot() @@ -40,10 +37,7 @@ def test_raw_study( # Ensure the duration is relatively short start = time.time() - res = client.get( - f"/v1/studies/{internal_study_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/comments") assert res.status_code == 200, res.json() duration = time.time() - start assert 0 <= duration <= 0.1, f"Duration is {duration} seconds" @@ -51,16 +45,12 @@ def test_raw_study( # Update the comments of the study res = client.put( f"/v1/studies/{internal_study_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, json={"comments": "Ceci est un commentaire en français."}, ) assert res.status_code == 204, res.json() # Get the comments of the study and compare with the expected file - res = client.get( - f"/v1/studies/{internal_study_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/comments") assert res.status_code == 200, res.json() assert res.json() == "Ceci est un commentaire en français." @@ -74,10 +64,10 @@ def test_variant_study( This test verifies that we can retrieve and modify the comments of a VARIANT study. It also performs performance measurements and analyzes. """ + client.headers = {"Authorization": f"Bearer {user_access_token}"} # First, we create a copy of the study, and we convert it to a managed study. res = client.post( f"/v1/studies/{internal_study_id}/copy", - headers={"Authorization": f"Bearer {user_access_token}"}, params={"dest": "default", "with_outputs": False, "use_task": False}, # type: ignore ) assert res.status_code == 201, res.json() @@ -85,20 +75,13 @@ def test_variant_study( assert base_study_id is not None # Then, we create a new variant of the base study - res = client.post( - f"/v1/studies/{base_study_id}/variants", - headers={"Authorization": f"Bearer {user_access_token}"}, - params={"name": "Variant XYZ"}, - ) + res = client.post(f"/v1/studies/{base_study_id}/variants", params={"name": "Variant XYZ"}) assert res.status_code == 200, res.json() # should be CREATED variant_id = res.json() assert variant_id is not None # Get the comments of the study and compare with the expected file - res = client.get( - f"/v1/studies/{variant_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/comments") assert res.status_code == 200, res.json() actual = res.json() actual_xml = ElementTree.parse(io.StringIO(actual)).getroot() @@ -107,10 +90,7 @@ def test_variant_study( # Ensure the duration is relatively short start = time.time() - res = client.get( - f"/v1/studies/{variant_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/comments") assert res.status_code == 200, res.json() duration = time.time() - start assert 0 <= duration <= 0.3, f"Duration is {duration} seconds" @@ -118,15 +98,11 @@ def test_variant_study( # Update the comments of the study res = client.put( f"/v1/studies/{variant_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, json={"comments": "Ceci est un commentaire en français."}, ) assert res.status_code == 204, res.json() # Get the comments of the study and compare with the expected file - res = client.get( - f"/v1/studies/{variant_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/comments") assert res.status_code == 200, res.json() assert res.json() == "Ceci est un commentaire en français." diff --git a/tests/integration/studies_blueprint/test_disk_usage.py b/tests/integration/studies_blueprint/test_disk_usage.py index d2257a40fd..885063428e 100644 --- a/tests/integration/studies_blueprint/test_disk_usage.py +++ b/tests/integration/studies_blueprint/test_disk_usage.py @@ -69,7 +69,7 @@ def test_disk_usage_endpoint( # Wait for task completion res = client.get(f"/v1/tasks/{task_id}", headers=user_headers, params={"wait_for_completion": True}) assert res.status_code == 200 - task_result = TaskDTO.parse_obj(res.json()) + task_result = TaskDTO.model_validate(res.json()) assert task_result.status == TaskStatus.COMPLETED assert task_result.result is not None assert task_result.result.success diff --git a/tests/integration/studies_blueprint/test_get_studies.py b/tests/integration/studies_blueprint/test_get_studies.py index af8f790f20..4e58929d04 100644 --- a/tests/integration/studies_blueprint/test_get_studies.py +++ b/tests/integration/studies_blueprint/test_get_studies.py @@ -1471,7 +1471,7 @@ def test_get_studies__invalid_parameters( res = client.get(STUDIES_URL, headers=headers, params={"sortBy": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"not a valid enumeration member", description), f"{description=}" + assert re.search("Input should be", description), f"{description=}" # Invalid `pageNb` parameter (negative integer) res = client.get(STUDIES_URL, headers=headers, params={"pageNb": -1}) @@ -1483,7 +1483,7 @@ def test_get_studies__invalid_parameters( res = client.get(STUDIES_URL, headers=headers, params={"pageNb": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"not a valid integer", description), f"{description=}" + assert re.search(r"should be a valid integer", description), f"{description=}" # Invalid `pageSize` parameter (negative integer) res = client.get(STUDIES_URL, headers=headers, params={"pageSize": -1}) @@ -1495,43 +1495,43 @@ def test_get_studies__invalid_parameters( res = client.get(STUDIES_URL, headers=headers, params={"pageSize": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"not a valid integer", description), f"{description=}" + assert re.search(r"should be a valid integer", description), f"{description=}" # Invalid `managed` parameter (not a boolean) res = client.get(STUDIES_URL, headers=headers, params={"managed": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"could not be parsed to a boolean", description), f"{description=}" + assert re.search(r"should be a valid boolean", description), f"{description=}" # Invalid `archived` parameter (not a boolean) res = client.get(STUDIES_URL, headers=headers, params={"archived": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"could not be parsed to a boolean", description), f"{description=}" + assert re.search(r"should be a valid boolean", description), f"{description=}" # Invalid `variant` parameter (not a boolean) res = client.get(STUDIES_URL, headers=headers, params={"variant": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"could not be parsed to a boolean", description), f"{description=}" + assert re.search(r"should be a valid boolean", description), f"{description=}" # Invalid `versions` parameter (not a list of integers) res = client.get(STUDIES_URL, headers=headers, params={"versions": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"string does not match regex", description), f"{description=}" + assert re.search(r"String should match pattern", description), f"{description=}" # Invalid `users` parameter (not a list of integers) res = client.get(STUDIES_URL, headers=headers, params={"users": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"string does not match regex", description), f"{description=}" + assert re.search(r"String should match pattern", description), f"{description=}" # Invalid `exists` parameter (not a boolean) res = client.get(STUDIES_URL, headers=headers, params={"exists": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"could not be parsed to a boolean", description), f"{description=}" + assert re.search(r"should be a valid boolean", description), f"{description=}" def test_studies_counting(client: TestClient, admin_access_token: str, user_access_token: str) -> None: diff --git a/tests/integration/studies_blueprint/test_update_tags.py b/tests/integration/studies_blueprint/test_update_tags.py index 9ee37c7d70..8dd389cd5a 100644 --- a/tests/integration/studies_blueprint/test_update_tags.py +++ b/tests/integration/studies_blueprint/test_update_tags.py @@ -16,14 +16,10 @@ def test_update_tags( This test verifies that we can update the tags of a study. It also tests the tags normalization. """ - + client.headers = {"Authorization": f"Bearer {user_access_token}"} # Classic usage: set some tags to a study study_tags = ["Tag1", "Tag2"] - res = client.put( - f"/v1/studies/{internal_study_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"tags": study_tags}, - ) + res = client.put(f"/v1/studies/{internal_study_id}", json={"tags": study_tags}) assert res.status_code == 200, res.json() actual = res.json() assert set(actual["tags"]) == set(study_tags) @@ -32,11 +28,7 @@ def test_update_tags( # - "Tag1" is preserved, but with the same case as the existing one. # - "Tag2" is replaced by "Tag3". study_tags = ["tag1", "Tag3"] - res = client.put( - f"/v1/studies/{internal_study_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"tags": study_tags}, - ) + res = client.put(f"/v1/studies/{internal_study_id}", json={"tags": study_tags}) assert res.status_code == 200, res.json() actual = res.json() assert set(actual["tags"]) != set(study_tags) # not the same case @@ -45,22 +37,14 @@ def test_update_tags( # String normalization: whitespaces are stripped and # consecutive whitespaces are replaced by a single one. study_tags = [" \xa0Foo \t Bar \n ", " \t Baz\xa0\xa0"] - res = client.put( - f"/v1/studies/{internal_study_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"tags": study_tags}, - ) + res = client.put(f"/v1/studies/{internal_study_id}", json={"tags": study_tags}) assert res.status_code == 200, res.json() actual = res.json() assert set(actual["tags"]) == {"Foo Bar", "Baz"} # We can have symbols in the tags study_tags = ["Foo-Bar", ":Baz%"] - res = client.put( - f"/v1/studies/{internal_study_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"tags": study_tags}, - ) + res = client.put(f"/v1/studies/{internal_study_id}", json={"tags": study_tags}) assert res.status_code == 200, res.json() actual = res.json() assert set(actual["tags"]) == {"Foo-Bar", ":Baz%"} @@ -71,13 +55,10 @@ def test_update_tags__invalid_tags( user_access_token: str, internal_study_id: str, ) -> None: + client.headers = {"Authorization": f"Bearer {user_access_token}"} # We cannot have empty tags study_tags = [""] - res = client.put( - f"/v1/studies/{internal_study_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"tags": study_tags}, - ) + res = client.put(f"/v1/studies/{internal_study_id}", json={"tags": study_tags}) assert res.status_code == 422, res.json() description = res.json()["description"] assert "Tag cannot be empty" in description @@ -85,11 +66,7 @@ def test_update_tags__invalid_tags( # We cannot have tags longer than 40 characters study_tags = ["very long tags, very long tags, very long tags"] assert len(study_tags[0]) > 40 - res = client.put( - f"/v1/studies/{internal_study_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"tags": study_tags}, - ) + res = client.put(f"/v1/studies/{internal_study_id}", json={"tags": study_tags}) assert res.status_code == 422, res.json() description = res.json()["description"] assert "Tag is too long" in description diff --git a/tests/integration/study_data_blueprint/test_advanced_parameters.py b/tests/integration/study_data_blueprint/test_advanced_parameters.py index 4aff92b0cd..9e22f92742 100644 --- a/tests/integration/study_data_blueprint/test_advanced_parameters.py +++ b/tests/integration/study_data_blueprint/test_advanced_parameters.py @@ -96,7 +96,7 @@ def test_set_advanced_parameters_values( ) assert res.status_code == 422 assert res.json()["exception"] == "RequestValidationError" - assert res.json()["description"] == "Invalid value: fake_correlation" + assert res.json()["description"] == "Value error, Invalid value: fake_correlation" obj = {"unitCommitmentMode": "milp"} res = client.put( diff --git a/tests/integration/study_data_blueprint/test_binding_constraints.py b/tests/integration/study_data_blueprint/test_binding_constraints.py index 79b5d5cbd0..8d5366c745 100644 --- a/tests/integration/study_data_blueprint/test_binding_constraints.py +++ b/tests/integration/study_data_blueprint/test_binding_constraints.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd import pytest -from requests.exceptions import HTTPError +from httpx._exceptions import HTTPError from starlette.testclient import TestClient from antarest.study.business.binding_constraint_management import ClusterTerm, ConstraintTerm, LinkTerm @@ -307,7 +307,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st assert res.status_code == 422, res.json() assert res.json() == { "body": {"data": {}, "id": f"{area1_id}.{cluster_id}"}, - "description": "field required", + "description": "Field required", "exception": "RequestValidationError", } diff --git a/tests/integration/study_data_blueprint/test_config_general.py b/tests/integration/study_data_blueprint/test_config_general.py index 3084a86b4e..e327de11a0 100644 --- a/tests/integration/study_data_blueprint/test_config_general.py +++ b/tests/integration/study_data_blueprint/test_config_general.py @@ -33,7 +33,7 @@ def test_get_general_form_values( "firstJanuary": "Monday", "firstMonth": "january", "firstWeekDay": "Monday", - "horizon": "2030", + "horizon": 2030, "lastDay": 7, "leapYear": False, "mcScenario": True, diff --git a/tests/integration/study_data_blueprint/test_renewable.py b/tests/integration/study_data_blueprint/test_renewable.py index b6c450e8f3..a4b3375ac7 100644 --- a/tests/integration/study_data_blueprint/test_renewable.py +++ b/tests/integration/study_data_blueprint/test_renewable.py @@ -24,7 +24,6 @@ * validate the consistency of the matrices (and properties) """ -import json import re import typing as t @@ -38,7 +37,7 @@ from antarest.study.storage.rawstudy.model.filesystem.config.renewable import RenewableProperties from tests.integration.utils import wait_task_completion -DEFAULT_PROPERTIES = json.loads(RenewableProperties(name="Dummy").json()) +DEFAULT_PROPERTIES = RenewableProperties(name="Dummy").model_dump(mode="json") DEFAULT_PROPERTIES = {to_camel_case(k): v for k, v in DEFAULT_PROPERTIES.items() if k != "name"} # noinspection SpellCheckingInspection @@ -525,13 +524,10 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var In this test, we want to check that renewable clusters can be managed in the context of a "variant" study. """ + client.headers = {"Authorization": f"Bearer {user_access_token}"} # Create an area area_name = "France" - res = client.post( - f"/v1/studies/{variant_id}/areas", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"name": area_name, "type": "AREA"}, - ) + res = client.post(f"/v1/studies/{variant_id}/areas", json={"name": area_name, "type": "AREA"}) assert res.status_code in {200, 201}, res.json() area_cfg = res.json() area_id = area_cfg["id"] @@ -540,7 +536,6 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var cluster_name = "Th1" res = client.post( f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable", - headers={"Authorization": f"Bearer {user_access_token}"}, json={ "name": cluster_name, "group": "Wind Offshore", @@ -553,9 +548,7 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var # Update the renewable cluster res = client.patch( - f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable/{cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"unitCount": 15}, + f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable/{cluster_id}", json={"unitCount": 15} ) assert res.status_code == 200, res.json() cluster_cfg = res.json() @@ -565,19 +558,13 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var matrix = np.random.randint(0, 2, size=(8760, 1)).tolist() matrix_path = f"input/renewables/series/{area_id}/{cluster_id.lower()}/series" args = {"target": matrix_path, "matrix": matrix} - res = client.post( - f"/v1/studies/{variant_id}/commands", - json=[{"action": "replace_matrix", "args": args}], - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.post(f"/v1/studies/{variant_id}/commands", json=[{"action": "replace_matrix", "args": args}]) assert res.status_code in {200, 201}, res.json() # Duplicate the renewable cluster new_name = "Th2" res = client.post( - f"/v1/studies/{variant_id}/areas/{area_id}/renewables/{cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - params={"newName": new_name}, + f"/v1/studies/{variant_id}/areas/{area_id}/renewables/{cluster_id}", params={"newName": new_name} ) assert res.status_code in {200, 201}, res.json() cluster_cfg = res.json() @@ -585,10 +572,7 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var new_id = cluster_cfg["id"] # Check that the duplicate has the right properties - res = client.get( - f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable/{new_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable/{new_id}") assert res.status_code == 200, res.json() cluster_cfg = res.json() assert cluster_cfg["group"] == "Wind Offshore" @@ -597,27 +581,19 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var # Check that the duplicate has the right matrix new_cluster_matrix_path = f"input/renewables/series/{area_id}/{new_id.lower()}/series" - res = client.get( - f"/v1/studies/{variant_id}/raw", - params={"path": new_cluster_matrix_path}, - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/raw", params={"path": new_cluster_matrix_path}) assert res.status_code == 200 assert res.json()["data"] == matrix # Delete the renewable cluster - res = client.delete( - f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable", - headers={"Authorization": f"Bearer {user_access_token}"}, - json=[cluster_id], + # usage of request instead of delete as httpx doesn't support delete with a payload anymore. + res = client.request( + method="DELETE", url=f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable", json=[cluster_id] ) assert res.status_code == 204, res.json() # Check the list of variant commands - res = client.get( - f"/v1/studies/{variant_id}/commands", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/commands") assert res.status_code == 200, res.json() commands = res.json() assert len(commands) == 7 diff --git a/tests/integration/study_data_blueprint/test_st_storage.py b/tests/integration/study_data_blueprint/test_st_storage.py index 68fe46b138..07a1e4bf6c 100644 --- a/tests/integration/study_data_blueprint/test_st_storage.py +++ b/tests/integration/study_data_blueprint/test_st_storage.py @@ -1,4 +1,3 @@ -import json import re import typing as t from unittest.mock import ANY @@ -19,11 +18,11 @@ _ST_STORAGE_OUTPUT_860 = create_storage_output(860, cluster_id="dummy", config={"name": "dummy"}) _ST_STORAGE_OUTPUT_880 = create_storage_output(880, cluster_id="dummy", config={"name": "dummy"}) -DEFAULT_CONFIG_860 = json.loads(_ST_STORAGE_860_CONFIG.json(by_alias=True, exclude={"id", "name"})) -DEFAULT_CONFIG_880 = json.loads(_ST_STORAGE_880_CONFIG.json(by_alias=True, exclude={"id", "name"})) +DEFAULT_CONFIG_860 = _ST_STORAGE_860_CONFIG.model_dump(mode="json", by_alias=True, exclude={"id", "name"}) +DEFAULT_CONFIG_880 = _ST_STORAGE_880_CONFIG.model_dump(mode="json", by_alias=True, exclude={"id", "name"}) -DEFAULT_OUTPUT_860 = json.loads(_ST_STORAGE_OUTPUT_860.json(by_alias=True, exclude={"id", "name"})) -DEFAULT_OUTPUT_880 = json.loads(_ST_STORAGE_OUTPUT_880.json(by_alias=True, exclude={"id", "name"})) +DEFAULT_OUTPUT_860 = _ST_STORAGE_OUTPUT_860.model_dump(mode="json", by_alias=True, exclude={"id", "name"}) +DEFAULT_OUTPUT_880 = _ST_STORAGE_OUTPUT_880.model_dump(mode="json", by_alias=True, exclude={"id", "name"}) # noinspection SpellCheckingInspection @@ -313,7 +312,7 @@ def test_lifecycle__nominal( json=[siemens_battery_id], ) assert res.status_code == 204, res.json() - assert res.text in {"", "null"} # Old FastAPI versions return 'null'. + assert not res.text # If the short-term storage list is empty, the deletion should be a no-op. res = client.request( @@ -323,7 +322,7 @@ def test_lifecycle__nominal( json=[], ) assert res.status_code == 204, res.json() - assert res.text in {"", "null"} # Old FastAPI versions return 'null'. + assert not res.text # It's possible to delete multiple short-term storages at once. # In the following example, we will create two short-term storages: @@ -383,7 +382,7 @@ def test_lifecycle__nominal( json=[grand_maison_id, duplicated_output["id"]], ) assert res.status_code == 204, res.json() - assert res.text in {"", "null"} # Old FastAPI versions return 'null'. + assert not res.text # Only one st-storage should remain. res = client.get( @@ -475,7 +474,7 @@ def test_lifecycle__nominal( assert res.status_code == 422, res.json() obj = res.json() description = obj["description"] - assert re.search(r"not a valid enumeration member", description, flags=re.IGNORECASE) + assert re.search(r"Input should be", description) # Check PATCH with the wrong `area_id` res = client.patch( @@ -577,40 +576,26 @@ def test__default_values( Then the short-term storage is created with initialLevel = 0.0, and initialLevelOptim = False. """ # Create a new study in version 860 (or higher) - user_headers = {"Authorization": f"Bearer {user_access_token}"} - res = client.post( - "/v1/studies", - headers=user_headers, - params={"name": "MyStudy", "version": study_version}, - ) + client.headers = {"Authorization": f"Bearer {user_access_token}"} + res = client.post("/v1/studies", params={"name": "MyStudy", "version": study_version}) assert res.status_code in {200, 201}, res.json() study_id = res.json() if study_type == "variant": # Create Variant - res = client.post( - f"/v1/studies/{study_id}/variants", - headers=user_headers, - params={"name": "Variant 1"}, - ) + res = client.post(f"/v1/studies/{study_id}/variants", params={"name": "Variant 1"}) assert res.status_code in {200, 201}, res.json() study_id = res.json() # Create a new area named "FR" - res = client.post( - f"/v1/studies/{study_id}/areas", - headers=user_headers, - json={"name": "FR", "type": "AREA"}, - ) + res = client.post(f"/v1/studies/{study_id}/areas", json={"name": "FR", "type": "AREA"}) assert res.status_code in {200, 201}, res.json() area_id = res.json()["id"] # Create a new short-term storage named "Tesla Battery" tesla_battery = "Tesla Battery" res = client.post( - f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers=user_headers, - json={"name": tesla_battery, "group": "Battery"}, + f"/v1/studies/{study_id}/areas/{area_id}/storages", json={"name": tesla_battery, "group": "Battery"} ) assert res.status_code == 200, res.json() tesla_battery_id = res.json()["id"] @@ -621,7 +606,6 @@ def test__default_values( # are properly set in the configuration file. res = client.get( f"/v1/studies/{study_id}/raw", - headers=user_headers, params={"path": f"input/st-storage/clusters/{area_id}/list/{tesla_battery_id}"}, ) assert res.status_code == 200, res.json() @@ -634,28 +618,19 @@ def test__default_values( # in the variant commands. # Create a variant of the study - res = client.post( - f"/v1/studies/{study_id}/variants", - headers=user_headers, - params={"name": "MyVariant"}, - ) + res = client.post(f"/v1/studies/{study_id}/variants", params={"name": "MyVariant"}) assert res.status_code in {200, 201}, res.json() variant_id = res.json() # Create a new short-term storage named "Siemens Battery" siemens_battery = "Siemens Battery" res = client.post( - f"/v1/studies/{variant_id}/areas/{area_id}/storages", - headers=user_headers, - json={"name": siemens_battery, "group": "Battery"}, + f"/v1/studies/{variant_id}/areas/{area_id}/storages", json={"name": siemens_battery, "group": "Battery"} ) assert res.status_code == 200, res.json() # Check the variant commands - res = client.get( - f"/v1/studies/{variant_id}/commands", - headers=user_headers, - ) + res = client.get(f"/v1/studies/{variant_id}/commands") assert res.status_code == 200, res.json() commands = res.json() assert len(commands) == 1 @@ -679,17 +654,12 @@ def test__default_values( # Update the initialLevel property of the "Siemens Battery" short-term storage to 0.5 siemens_battery_id = transform_name_to_id(siemens_battery) res = client.patch( - f"/v1/studies/{variant_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers=user_headers, - json={"initialLevel": 0.5}, + f"/v1/studies/{variant_id}/areas/{area_id}/storages/{siemens_battery_id}", json={"initialLevel": 0.5} ) assert res.status_code == 200, res.json() # Check the variant commands - res = client.get( - f"/v1/studies/{variant_id}/commands", - headers=user_headers, - ) + res = client.get(f"/v1/studies/{variant_id}/commands") assert res.status_code == 200, res.json() commands = res.json() assert len(commands) == 2 @@ -698,7 +668,7 @@ def test__default_values( "id": ANY, "action": "update_config", "args": { - "data": "0.5", + "data": 0.5, "target": "input/st-storage/clusters/fr/list/siemens battery/initiallevel", }, "version": 1, @@ -708,16 +678,12 @@ def test__default_values( # Update the initialLevel property of the "Siemens Battery" short-term storage back to 0 res = client.patch( f"/v1/studies/{variant_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers=user_headers, json={"initialLevel": 0.0, "injectionNominalCapacity": 1600}, ) assert res.status_code == 200, res.json() # Check the variant commands - res = client.get( - f"/v1/studies/{variant_id}/commands", - headers=user_headers, - ) + res = client.get(f"/v1/studies/{variant_id}/commands") assert res.status_code == 200, res.json() commands = res.json() assert len(commands) == 3 @@ -727,11 +693,11 @@ def test__default_values( "action": "update_config", "args": [ { - "data": "1600.0", + "data": 1600.0, "target": "input/st-storage/clusters/fr/list/siemens battery/injectionnominalcapacity", }, { - "data": "0.0", + "data": 0.0, "target": "input/st-storage/clusters/fr/list/siemens battery/initiallevel", }, ], @@ -743,7 +709,6 @@ def test__default_values( # are properly set in the configuration file. res = client.get( f"/v1/studies/{variant_id}/raw", - headers=user_headers, params={"path": f"input/st-storage/clusters/{area_id}/list/{siemens_battery_id}"}, ) assert res.status_code == 200, res.json() @@ -791,13 +756,10 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var In this test, we want to check that short-term storages can be managed in the context of a "variant" study. """ + client.headers = {"Authorization": f"Bearer {user_access_token}"} # Create an area area_name = "France" - res = client.post( - f"/v1/studies/{variant_id}/areas", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"name": area_name, "type": "AREA"}, - ) + res = client.post(f"/v1/studies/{variant_id}/areas", json={"name": area_name, "type": "AREA"}) assert res.status_code in {200, 201}, res.json() area_cfg = res.json() area_id = area_cfg["id"] @@ -806,7 +768,6 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var cluster_name = "Tesla1" res = client.post( f"/v1/studies/{variant_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, json={ "name": cluster_name, "group": "Battery", @@ -820,9 +781,7 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var # Update the short-term storage res = client.patch( - f"/v1/studies/{variant_id}/areas/{area_id}/storages/{cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"reservoirCapacity": 5600}, + f"/v1/studies/{variant_id}/areas/{area_id}/storages/{cluster_id}", json={"reservoirCapacity": 5600} ) assert res.status_code == 200, res.json() cluster_cfg = res.json() @@ -832,19 +791,13 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var matrix = np.random.randint(0, 2, size=(8760, 1)).tolist() matrix_path = f"input/st-storage/series/{area_id}/{cluster_id.lower()}/pmax_injection" args = {"target": matrix_path, "matrix": matrix} - res = client.post( - f"/v1/studies/{variant_id}/commands", - json=[{"action": "replace_matrix", "args": args}], - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.post(f"/v1/studies/{variant_id}/commands", json=[{"action": "replace_matrix", "args": args}]) assert res.status_code in {200, 201}, res.json() # Duplicate the short-term storage new_name = "Tesla2" res = client.post( - f"/v1/studies/{variant_id}/areas/{area_id}/storages/{cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - params={"newName": new_name}, + f"/v1/studies/{variant_id}/areas/{area_id}/storages/{cluster_id}", params={"newName": new_name} ) assert res.status_code in {200, 201}, res.json() cluster_cfg = res.json() @@ -852,10 +805,7 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var new_id = cluster_cfg["id"] # Check that the duplicate has the right properties - res = client.get( - f"/v1/studies/{variant_id}/areas/{area_id}/storages/{new_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/areas/{area_id}/storages/{new_id}") assert res.status_code == 200, res.json() cluster_cfg = res.json() assert cluster_cfg["group"] == "Battery" @@ -865,27 +815,19 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var # Check that the duplicate has the right matrix new_cluster_matrix_path = f"input/st-storage/series/{area_id}/{new_id.lower()}/pmax_injection" - res = client.get( - f"/v1/studies/{variant_id}/raw", - params={"path": new_cluster_matrix_path}, - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/raw", params={"path": new_cluster_matrix_path}) assert res.status_code == 200 assert res.json()["data"] == matrix # Delete the short-term storage - res = client.delete( - f"/v1/studies/{variant_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, - json=[cluster_id], + # usage of request instead of delete as httpx doesn't support delete with a payload anymore. + res = client.request( + method="DELETE", url=f"/v1/studies/{variant_id}/areas/{area_id}/storages", json=[cluster_id] ) assert res.status_code == 204, res.json() # Check the list of variant commands - res = client.get( - f"/v1/studies/{variant_id}/commands", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/commands") assert res.status_code == 200, res.json() commands = res.json() assert len(commands) == 7 diff --git a/tests/integration/study_data_blueprint/test_thermal.py b/tests/integration/study_data_blueprint/test_thermal.py index 33e9aef03d..606217b800 100644 --- a/tests/integration/study_data_blueprint/test_thermal.py +++ b/tests/integration/study_data_blueprint/test_thermal.py @@ -28,7 +28,6 @@ * validate the consistency of the matrices (and properties) """ import io -import json import re import typing as t @@ -42,7 +41,7 @@ from antarest.study.storage.rawstudy.model.filesystem.config.thermal import ThermalProperties from tests.integration.utils import wait_task_completion -DEFAULT_PROPERTIES = json.loads(ThermalProperties(name="Dummy").json()) +DEFAULT_PROPERTIES = ThermalProperties(name="Dummy").model_dump(mode="json") DEFAULT_PROPERTIES = {to_camel_case(k): v for k, v in DEFAULT_PROPERTIES.items() if k != "name"} # noinspection SpellCheckingInspection @@ -499,7 +498,6 @@ def test_lifecycle( json={"nox": 10.0}, ) assert res.status_code == 200 - assert res.json()["nox"] == 10.0 # Update with the field `efficiency`. Should succeed even with versions prior to v8.7 res = client.patch( @@ -508,7 +506,6 @@ def test_lifecycle( json={"efficiency": 97.0}, ) assert res.status_code == 200 - assert res.json()["efficiency"] == 97.0 # ============================= # THERMAL CLUSTER DUPLICATION @@ -937,13 +934,10 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var In this test, we want to check that thermal clusters can be managed in the context of a "variant" study. """ + client.headers = {"Authorization": f"Bearer {user_access_token}"} # Create an area area_name = "France" - res = client.post( - f"/v1/studies/{variant_id}/areas", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"name": area_name, "type": "AREA"}, - ) + res = client.post(f"/v1/studies/{variant_id}/areas", json={"name": area_name, "type": "AREA"}) assert res.status_code in {200, 201}, res.json() area_cfg = res.json() area_id = area_cfg["id"] @@ -952,7 +946,6 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var cluster_name = "Th1" res = client.post( f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, json={ "name": cluster_name, "group": "Nuclear", @@ -966,11 +959,7 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var # Update the thermal cluster res = client.patch( - f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal/{cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={ - "marginalCost": 0.2, - }, + f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal/{cluster_id}", json={"marginalCost": 0.2} ) assert res.status_code == 200, res.json() cluster_cfg = res.json() @@ -980,19 +969,13 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var matrix = np.random.randint(0, 2, size=(8760, 1)).tolist() matrix_path = f"input/thermal/prepro/{area_id}/{cluster_id.lower()}/data" args = {"target": matrix_path, "matrix": matrix} - res = client.post( - f"/v1/studies/{variant_id}/commands", - json=[{"action": "replace_matrix", "args": args}], - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.post(f"/v1/studies/{variant_id}/commands", json=[{"action": "replace_matrix", "args": args}]) assert res.status_code in {200, 201}, res.json() # Duplicate the thermal cluster new_name = "Th2" res = client.post( - f"/v1/studies/{variant_id}/areas/{area_id}/thermals/{cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - params={"newName": new_name}, + f"/v1/studies/{variant_id}/areas/{area_id}/thermals/{cluster_id}", params={"newName": new_name} ) assert res.status_code in {200, 201}, res.json() cluster_cfg = res.json() @@ -1000,10 +983,7 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var new_id = cluster_cfg["id"] # Check that the duplicate has the right properties - res = client.get( - f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal/{new_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal/{new_id}") assert res.status_code == 200, res.json() cluster_cfg = res.json() assert cluster_cfg["group"] == "Nuclear" @@ -1013,27 +993,19 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var # Check that the duplicate has the right matrix new_cluster_matrix_path = f"input/thermal/prepro/{area_id}/{new_id.lower()}/data" - res = client.get( - f"/v1/studies/{variant_id}/raw", - params={"path": new_cluster_matrix_path}, - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/raw", params={"path": new_cluster_matrix_path}) assert res.status_code == 200 assert res.json()["data"] == matrix # Delete the thermal cluster - res = client.delete( - f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, - json=[cluster_id], + # usage of request instead of delete as httpx doesn't support delete with a payload anymore. + res = client.request( + method="DELETE", url=f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal", json=[cluster_id] ) assert res.status_code == 204, res.json() # Check the list of variant commands - res = client.get( - f"/v1/studies/{variant_id}/commands", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/commands") assert res.status_code == 200, res.json() commands = res.json() assert len(commands) == 7 @@ -1170,9 +1142,9 @@ def test_thermal_cluster_deletion(self, client: TestClient, user_access_token: s assert res.status_code == 200, res.json() # check that deleting the thermal cluster in area_1 fails - res = client.delete( - f"/v1/studies/{internal_study_id}/areas/area_1/clusters/thermal", - json=["cluster_1"], + # usage of request instead of delete as httpx doesn't support delete with a payload anymore. + res = client.request( + method="DELETE", url=f"/v1/studies/{internal_study_id}/areas/area_1/clusters/thermal", json=["cluster_1"] ) assert res.status_code == 403, res.json() @@ -1183,16 +1155,14 @@ def test_thermal_cluster_deletion(self, client: TestClient, user_access_token: s assert res.status_code == 200, res.json() # check that deleting the thermal cluster in area_1 succeeds - res = client.delete( - f"/v1/studies/{internal_study_id}/areas/area_1/clusters/thermal", - json=["cluster_1"], + res = client.request( + method="DELETE", url=f"/v1/studies/{internal_study_id}/areas/area_1/clusters/thermal", json=["cluster_1"] ) assert res.status_code == 204, res.json() # check that deleting the thermal cluster in area_2 fails - res = client.delete( - f"/v1/studies/{internal_study_id}/areas/area_2/clusters/thermal", - json=["cluster_2"], + res = client.request( + method="DELETE", url=f"/v1/studies/{internal_study_id}/areas/area_2/clusters/thermal", json=["cluster_2"] ) assert res.status_code == 403, res.json() @@ -1203,15 +1173,13 @@ def test_thermal_cluster_deletion(self, client: TestClient, user_access_token: s assert res.status_code == 200, res.json() # check that deleting the thermal cluster in area_2 succeeds - res = client.delete( - f"/v1/studies/{internal_study_id}/areas/area_2/clusters/thermal", - json=["cluster_2"], + res = client.request( + method="DELETE", url=f"/v1/studies/{internal_study_id}/areas/area_2/clusters/thermal", json=["cluster_2"] ) assert res.status_code == 204, res.json() # check that deleting the thermal cluster in area_3 succeeds - res = client.delete( - f"/v1/studies/{internal_study_id}/areas/area_3/clusters/thermal", - json=["cluster_3"], + res = client.request( + method="DELETE", url=f"/v1/studies/{internal_study_id}/areas/area_3/clusters/thermal", json=["cluster_3"] ) assert res.status_code == 204, res.json() diff --git a/tests/integration/test_apidoc.py b/tests/integration/test_apidoc.py index 562f7ccb8e..4216c10b9a 100644 --- a/tests/integration/test_apidoc.py +++ b/tests/integration/test_apidoc.py @@ -1,11 +1,12 @@ -from fastapi.openapi.utils import get_flat_models_from_routes -from fastapi.utils import get_model_definitions -from pydantic.schema import get_model_name_map from starlette.testclient import TestClient +from antarest import __version__ + def test_apidoc(client: TestClient) -> None: - # Asserts that the apidoc can be loaded - flat_models = get_flat_models_from_routes(client.app.routes) - model_name_map = get_model_name_map(flat_models) - get_model_definitions(flat_models=flat_models, model_name_map=model_name_map) + # Local import to avoid breaking all tests if FastAPI changes its API + from fastapi.openapi.utils import get_openapi + + routes = client.app.routes + openapi = get_openapi(title="Antares Web", version=__version__, routes=routes) + assert openapi diff --git a/tests/integration/test_core_blueprint.py b/tests/integration/test_core_blueprint.py index 61a0fcbe3a..60949ea8c9 100644 --- a/tests/integration/test_core_blueprint.py +++ b/tests/integration/test_core_blueprint.py @@ -1,4 +1,3 @@ -import http import re from unittest import mock @@ -36,23 +35,3 @@ def test_version_info(self, app: FastAPI): "dependencies": mock.ANY, } assert actual == expected - - -class TestKillWorker: - def test_kill_worker__not_granted(self, app: FastAPI): - client = TestClient(app, raise_server_exceptions=False) - res = client.get("/kill") - assert res.status_code == http.HTTPStatus.UNAUTHORIZED, res.json() - assert res.json() == {"detail": "Missing cookie access_token_cookie"} - - def test_kill_worker__nominal_case(self, app: FastAPI): - client = TestClient(app, raise_server_exceptions=False) - # login as "admin" - res = client.post("/v1/login", json={"username": "admin", "password": "admin"}) - res.raise_for_status() - credentials = res.json() - admin_access_token = credentials["access_token"] - # kill the worker - res = client.get("/kill", headers={"Authorization": f"Bearer {admin_access_token}"}) - assert res.status_code == 500, res.json() - assert not res.content diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 2119f96658..8c50801404 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -440,7 +440,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: }, ) - client.post( + res = client.post( f"/v1/studies/{study_id}/commands", json=[ { @@ -453,6 +453,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: } ], ) + res.raise_for_status() client.post( f"/v1/studies/{study_id}/commands", @@ -594,13 +595,14 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: }, ] - client.post( + res = client.post( f"/v1/studies/{study_id}/links", json={ "area1": "area 1", "area2": "area 2", }, ) + res.raise_for_status() res_links = client.get(f"/v1/studies/{study_id}/links?with_ui=true") assert res_links.json() == [ { @@ -613,15 +615,16 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: # -- `layers` integration tests res = client.get(f"/v1/studies/{study_id}/layers") - assert res.json() == [LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).dict()] + res.raise_for_status() + assert res.json() == [LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).model_dump()] res = client.post(f"/v1/studies/{study_id}/layers?name=test") assert res.json() == "1" res = client.get(f"/v1/studies/{study_id}/layers") assert res.json() == [ - LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).dict(), - LayerInfoDTO(id="1", name="test", areas=[]).dict(), + LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).model_dump(), + LayerInfoDTO(id="1", name="test", areas=[]).model_dump(), ] res = client.put(f"/v1/studies/{study_id}/layers/1?name=test2") @@ -632,8 +635,8 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: assert res.status_code in {200, 201}, res.json() res = client.get(f"/v1/studies/{study_id}/layers") assert res.json() == [ - LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).dict(), - LayerInfoDTO(id="1", name="test2", areas=["area 2"]).dict(), + LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).model_dump(), + LayerInfoDTO(id="1", name="test2", areas=["area 2"]).model_dump(), ] # Delete the layer '1' that has 1 area @@ -643,7 +646,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: # Ensure the layer is deleted res = client.get(f"/v1/studies/{study_id}/layers") assert res.json() == [ - LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).dict(), + LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).model_dump(), ] # Create the layer again without areas @@ -657,7 +660,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: # Ensure the layer is deleted res = client.get(f"/v1/studies/{study_id}/layers") assert res.json() == [ - LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).dict(), + LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).model_dump(), ] # Try to delete a non-existing layer @@ -743,7 +746,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: "simplexOptimizationRange": SimplexOptimizationRange.WEEK.value, } - client.put( + res = client.put( f"/v1/studies/{study_id}/config/optimization/form", json={ "strategicReserve": False, @@ -751,6 +754,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: "simplexOptimizationRange": SimplexOptimizationRange.DAY.value, }, ) + res.raise_for_status() res_optimization_config = client.get(f"/v1/studies/{study_id}/config/optimization/form") res_optimization_config_json = res_optimization_config.json() assert res_optimization_config_json == { @@ -813,7 +817,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: ) assert res.status_code == 422 assert res.json()["exception"] == "RequestValidationError" - assert res.json()["description"] == "value is not a valid integer" + assert res.json()["description"] == "Input should be a valid integer" # General form @@ -1257,7 +1261,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: }, "ntc": {"stochasticTsStatus": False, "intraModal": False}, } - res_ts_config = client.put( + client.put( f"/v1/studies/{study_id}/config/timeseries/form", json={ "thermal": {"stochasticTsStatus": True}, diff --git a/tests/integration/variant_blueprint/test_st_storage.py b/tests/integration/variant_blueprint/test_st_storage.py index c28af6790d..d24001689f 100644 --- a/tests/integration/variant_blueprint/test_st_storage.py +++ b/tests/integration/variant_blueprint/test_st_storage.py @@ -221,18 +221,5 @@ def test_lifecycle( ) assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY description = res.json()["description"] - """ - 4 validation errors for CreateSTStorage - parameters -> group - value is not a valid enumeration member […] - parameters -> injectionnominalcapacity - ensure this value is greater than or equal to 0 (type=value_error.number.not_ge; limit_value=0) - parameters -> initialleveloptim - value could not be parsed to a boolean (type=type_error.bool) - pmax_withdrawal - Matrix values should be between 0 and 1 (type=value_error) - """ - assert "parameters -> group" in description - assert "parameters -> injectionnominalcapacity" in description - assert "parameters -> initialleveloptim" in description - assert "pmax_withdrawal" in description + assert "Matrix values should be between 0 and 1" in description + assert "1 validation error for CreateSTStorage" in description diff --git a/tests/integration/variant_blueprint/test_thermal_cluster.py b/tests/integration/variant_blueprint/test_thermal_cluster.py index 245fd8a02a..914dbc3070 100644 --- a/tests/integration/variant_blueprint/test_thermal_cluster.py +++ b/tests/integration/variant_blueprint/test_thermal_cluster.py @@ -128,7 +128,7 @@ def test_cascade_update( ) assert res.status_code == http.HTTPStatus.OK, res.json() task = TaskDTO(**res.json()) - assert task.dict() == { + assert task.model_dump() == { "completion_date_utc": mock.ANY, "creation_date_utc": mock.ANY, "id": task_id, diff --git a/tests/integration/variant_blueprint/test_variant_manager.py b/tests/integration/variant_blueprint/test_variant_manager.py index 82fa7ab95c..a53b98aaad 100644 --- a/tests/integration/variant_blueprint/test_variant_manager.py +++ b/tests/integration/variant_blueprint/test_variant_manager.py @@ -187,7 +187,7 @@ def test_variant_manager( res = client.get(f"/v1/tasks/{res.json()}?wait_for_completion=true", headers=admin_headers) assert res.status_code == 200 - task_result = TaskDTO.parse_obj(res.json()) + task_result = TaskDTO.model_validate(res.json()) assert task_result.status == TaskStatus.COMPLETED assert task_result.result.success # type: ignore @@ -234,7 +234,7 @@ def test_comments(client: TestClient, admin_access_token: str, variant_id: str) # Wait for task completion res = client.get(f"/v1/tasks/{task_id}", headers=admin_headers, params={"wait_for_completion": True}) assert res.status_code == 200 - task_result = TaskDTO.parse_obj(res.json()) + task_result = TaskDTO.model_validate(res.json()) assert task_result.status == TaskStatus.COMPLETED assert task_result.result is not None assert task_result.result.success @@ -308,7 +308,7 @@ def test_outputs(client: TestClient, admin_access_token: str, variant_id: str, t # Wait for task completion res = client.get(f"/v1/tasks/{task_id}", headers=admin_headers, params={"wait_for_completion": True}) res.raise_for_status() - task_result = TaskDTO.parse_obj(res.json()) + task_result = TaskDTO.model_validate(res.json()) assert task_result.status == TaskStatus.COMPLETED assert task_result.result is not None assert task_result.result.success diff --git a/tests/integration/xpansion_studies_blueprint/test_integration_xpansion.py b/tests/integration/xpansion_studies_blueprint/test_integration_xpansion.py index b3a563a71e..aff978eff8 100644 --- a/tests/integration/xpansion_studies_blueprint/test_integration_xpansion.py +++ b/tests/integration/xpansion_studies_blueprint/test_integration_xpansion.py @@ -11,7 +11,6 @@ def _create_area( client: TestClient, - headers: t.Mapping[str, str], study_id: str, area_name: str, *, @@ -19,7 +18,6 @@ def _create_area( ) -> str: res = client.post( f"/v1/studies/{study_id}/areas", - headers=headers, json={"name": area_name, "type": "AREA", "metadata": {"country": country}}, ) assert res.status_code in {200, 201}, res.json() @@ -28,32 +26,28 @@ def _create_area( def _create_link( client: TestClient, - headers: t.Mapping[str, str], study_id: str, src_area_id: str, dst_area_id: str, ) -> None: - res = client.post( - f"/v1/studies/{study_id}/links", - headers=headers, - json={"area1": src_area_id, "area2": dst_area_id}, - ) + res = client.post(f"/v1/studies/{study_id}/links", json={"area1": src_area_id, "area2": dst_area_id}) assert res.status_code in {200, 201}, res.json() def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_token: str) -> None: headers = {"Authorization": f"Bearer {admin_access_token}"} + client.headers = headers - res = client.post("/v1/studies", headers=headers, params={"name": "foo", "version": "860"}) + res = client.post("/v1/studies", params={"name": "foo", "version": "860"}) assert res.status_code == 201, res.json() study_id = res.json() - area1_id = _create_area(client, headers, study_id, "area1", country="FR") - area2_id = _create_area(client, headers, study_id, "area2", country="DE") - area3_id = _create_area(client, headers, study_id, "area3", country="DE") - _create_link(client, headers, study_id, area1_id, area2_id) + area1_id = _create_area(client, study_id, "area1", country="FR") + area2_id = _create_area(client, study_id, "area2", country="DE") + area3_id = _create_area(client, study_id, "area3", country="DE") + _create_link(client, study_id, area1_id, area2_id) - res = client.post(f"/v1/studies/{study_id}/extensions/xpansion", headers=headers) + res = client.post(f"/v1/studies/{study_id}/extensions/xpansion") assert res.status_code in {200, 201}, res.json() expansion_path = tmp_path / "internal_workspace" / study_id / "user" / "expansion" @@ -61,9 +55,9 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t # Create a client for Xpansion with the xpansion URL xpansion_base_url = f"/v1/studies/{study_id}/extensions/xpansion/" - xp_client = TestClient(client.app, base_url=urljoin(client.base_url, xpansion_base_url)) - - res = xp_client.get("settings", headers=headers) + xp_client = TestClient(client.app, base_url=urljoin(str(client.base_url), xpansion_base_url)) + xp_client.headers = headers + res = xp_client.get("settings") assert res.status_code == 200 assert res.json() == { "master": "integer", @@ -82,7 +76,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "sensitivity_config": {"epsilon": 0.0, "projection": [], "capex": False}, } - res = xp_client.put("settings", headers=headers, json={"optimality_gap": 42}) + res = xp_client.put("settings", json={"optimality_gap": 42}) assert res.status_code == 200 assert res.json() == { "master": "integer", @@ -101,13 +95,13 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "sensitivity_config": {"epsilon": 0.0, "projection": [], "capex": False}, } - res = xp_client.put("settings", headers=headers, json={"additional-constraints": "missing.txt"}) + res = xp_client.put("settings", json={"additional-constraints": "missing.txt"}) assert res.status_code == 404 err_obj = res.json() assert re.search(r"file 'missing.txt' does not exist", err_obj["description"]) assert err_obj["exception"] == "XpansionFileNotFoundError" - res = xp_client.put("settings/additional-constraints", headers=headers, params={"filename": "missing.txt"}) + res = xp_client.put("settings/additional-constraints", params={"filename": "missing.txt"}) assert res.status_code == 404 err_obj = res.json() assert re.search(r"file 'missing.txt' does not exist", err_obj["description"]) @@ -127,7 +121,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "image/jpeg", ) } - res = xp_client.post("resources/constraints", headers=headers, files=files) + res = xp_client.post("resources/constraints", files=files) assert res.status_code in {200, 201} actual_path = expansion_path / "constraints" / filename_constraints1 assert actual_path.read_text() == content_constraints1 @@ -140,7 +134,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t ), } - res = xp_client.post("resources/constraints", headers=headers, files=files) + res = xp_client.post("resources/constraints", files=files) assert res.status_code == 409 err_obj = res.json() assert re.search( @@ -157,7 +151,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "image/jpeg", ), } - res = xp_client.post("resources/constraints", headers=headers, files=files) + res = xp_client.post("resources/constraints", files=files) assert res.status_code in {200, 201} files = { @@ -167,14 +161,14 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "image/jpeg", ), } - res = xp_client.post("resources/constraints", headers=headers, files=files) + res = xp_client.post("resources/constraints", files=files) assert res.status_code in {200, 201} - res = xp_client.get(f"resources/constraints/{filename_constraints1}", headers=headers) + res = xp_client.get(f"resources/constraints/{filename_constraints1}") assert res.status_code == 200 assert res.json() == content_constraints1 - res = xp_client.get("resources/constraints/", headers=headers) + res = xp_client.get("resources/constraints/") assert res.status_code == 200 assert res.json() == [ filename_constraints1, @@ -182,14 +176,10 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t filename_constraints3, ] - res = xp_client.put( - "settings/additional-constraints", - headers=headers, - params={"filename": filename_constraints1}, - ) + res = xp_client.put("settings/additional-constraints", params={"filename": filename_constraints1}) assert res.status_code == 200 - res = xp_client.delete(f"resources/constraints/{filename_constraints1}", headers=headers) + res = xp_client.delete(f"resources/constraints/{filename_constraints1}") assert res.status_code == 409 err_obj = res.json() assert re.search( @@ -199,10 +189,10 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t ) assert err_obj["exception"] == "FileCurrentlyUsedInSettings" - res = xp_client.put("settings/additional-constraints", headers=headers) + res = xp_client.put("settings/additional-constraints") assert res.status_code == 200 - res = xp_client.delete(f"resources/constraints/{filename_constraints1}", headers=headers) + res = xp_client.delete(f"resources/constraints/{filename_constraints1}") assert res.status_code == 200 candidate1 = { @@ -211,7 +201,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "annual-cost-per-mw": 1, "max-investment": 1.0, } - res = xp_client.post("candidates", headers=headers, json=candidate1) + res = xp_client.post("candidates", json=candidate1) assert res.status_code in {200, 201} candidate2 = { @@ -220,7 +210,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "annual-cost-per-mw": 1, "max-investment": 1.0, } - res = xp_client.post("candidates", headers=headers, json=candidate2) + res = xp_client.post("candidates", json=candidate2) assert res.status_code == 404 err_obj = res.json() assert re.search( @@ -236,7 +226,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "annual-cost-per-mw": 1, "max-investment": 1.0, } - res = xp_client.post("candidates", headers=headers, json=candidate3) + res = xp_client.post("candidates", json=candidate3) assert res.status_code == 404 err_obj = res.json() assert re.search( @@ -259,12 +249,12 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "txt/csv", ) } - res = xp_client.post("resources/capacities", headers=headers, files=files) + res = xp_client.post("resources/capacities", files=files) assert res.status_code in {200, 201} actual_path = expansion_path / "capa" / filename_capa1 assert actual_path.read_text() == content_capa1 - res = xp_client.post("resources/capacities", headers=headers, files=files) + res = xp_client.post("resources/capacities", files=files) assert res.status_code == 409 err_obj = res.json() assert re.search( @@ -281,7 +271,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "txt/csv", ) } - res = xp_client.post("resources/capacities", headers=headers, files=files) + res = xp_client.post("resources/capacities", files=files) assert res.status_code in {200, 201} files = { @@ -291,11 +281,11 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "txt/csv", ) } - res = xp_client.post("resources/capacities", headers=headers, files=files) + res = xp_client.post("resources/capacities", files=files) assert res.status_code in {200, 201} # get single capa - res = xp_client.get(f"resources/capacities/{filename_capa1}", headers=headers) + res = xp_client.get(f"resources/capacities/{filename_capa1}") assert res.status_code == 200 assert res.json() == { "columns": [0], @@ -303,7 +293,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "index": [0], } - res = xp_client.get("resources/capacities", headers=headers) + res = xp_client.get("resources/capacities") assert res.status_code == 200 assert res.json() == [filename_capa1, filename_capa2, filename_capa3] @@ -314,21 +304,21 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "max-investment": 1.0, "link-profile": filename_capa1, } - res = xp_client.post("candidates", headers=headers, json=candidate4) + res = xp_client.post("candidates", json=candidate4) assert res.status_code in {200, 201} - res = xp_client.get(f"candidates/{candidate1['name']}", headers=headers) + res = xp_client.get(f"candidates/{candidate1['name']}") assert res.status_code == 200 - assert res.json() == XpansionCandidateDTO.parse_obj(candidate1).dict(by_alias=True) + assert res.json() == XpansionCandidateDTO.model_validate(candidate1).model_dump(by_alias=True) - res = xp_client.get("candidates", headers=headers) + res = xp_client.get("candidates") assert res.status_code == 200 assert res.json() == [ - XpansionCandidateDTO.parse_obj(candidate1).dict(by_alias=True), - XpansionCandidateDTO.parse_obj(candidate4).dict(by_alias=True), + XpansionCandidateDTO.model_validate(candidate1).model_dump(by_alias=True), + XpansionCandidateDTO.model_validate(candidate4).model_dump(by_alias=True), ] - res = xp_client.delete(f"resources/capacities/{filename_capa1}", headers=headers) + res = xp_client.delete(f"resources/capacities/{filename_capa1}") assert res.status_code == 409 err_obj = res.json() assert re.search( @@ -343,13 +333,13 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "annual-cost-per-mw": 1, "max-investment": 1.0, } - res = xp_client.put(f"candidates/{candidate4['name']}", headers=headers, json=candidate5) + res = xp_client.put(f"candidates/{candidate4['name']}", json=candidate5) assert res.status_code == 200 - res = xp_client.delete(f"resources/capacities/{filename_capa1}", headers=headers) + res = xp_client.delete(f"resources/capacities/{filename_capa1}") assert res.status_code == 200 - res = client.delete(f"/v1/studies/{study_id}/extensions/xpansion", headers=headers) + res = client.delete(f"/v1/studies/{study_id}/extensions/xpansion") assert res.status_code == 200 assert not expansion_path.exists() diff --git a/tests/launcher/test_service.py b/tests/launcher/test_service.py index 9095673070..8703b4af05 100644 --- a/tests/launcher/test_service.py +++ b/tests/launcher/test_service.py @@ -15,7 +15,7 @@ from antarest.core.config import ( Config, - InvalidConfigurationError, + Launcher, LauncherConfig, LocalConfig, NbCoresConfig, @@ -89,7 +89,7 @@ def test_service_run_study(self, get_current_user_mock) -> None: study_id="study_uuid", job_status=JobStatus.PENDING, launcher="local", - launcher_params=LauncherParametersDTO().json(), + launcher_params=LauncherParametersDTO().model_dump_json(), ) repository = Mock() repository.save.return_value = pending @@ -130,12 +130,12 @@ def test_service_run_study(self, get_current_user_mock) -> None: # so we need to compare them manually. mock_call = repository.save.mock_calls[0] actual_obj: JobResult = mock_call.args[0] - assert actual_obj.to_dto().dict() == pending.to_dto().dict() + assert actual_obj.to_dto().model_dump() == pending.to_dto().model_dump() event_bus.push.assert_called_once_with( Event( type=EventType.STUDY_JOB_STARTED, - payload=pending.to_dto().dict(), + payload=pending.to_dto().model_dump(), permissions=PermissionInfo(owner=0), ) ) @@ -486,11 +486,6 @@ def test_service_get_solver_versions( "unknown", {}, id="local-config-unknown", - marks=pytest.mark.xfail( - reason="Configuration is not available for the 'unknown' launcher", - raises=InvalidConfigurationError, - strict=True, - ), ), pytest.param( { @@ -518,11 +513,6 @@ def test_service_get_solver_versions( "unknown", {}, id="slurm-config-unknown", - marks=pytest.mark.xfail( - reason="Configuration is not available for the 'unknown' launcher", - raises=InvalidConfigurationError, - strict=True, - ), ), pytest.param( { @@ -563,10 +553,13 @@ def test_get_nb_cores( ) # Fetch the number of cores - actual = launcher_service.get_nb_cores(solver) - - # Check the result - assert actual == NbCoresConfig(**expected) + try: + actual = launcher_service.get_nb_cores(Launcher(solver)) + except ValueError as e: + assert e.args[0] == f"'{solver}' is not a valid Launcher" + else: + # Check the result + assert actual == NbCoresConfig(**expected) @pytest.mark.unit_test def test_service_kill_job(self, tmp_path: Path) -> None: @@ -890,7 +883,7 @@ def test_save_solver_stats(self, tmp_path: Path) -> None: solver_stats=expected_saved_stats, owner_id=1, ) - assert actual_obj.to_dto().dict() == expected_obj.to_dto().dict() + assert actual_obj.to_dto().model_dump() == expected_obj.to_dto().model_dump() zip_file = tmp_path / "test.zip" with ZipFile(zip_file, "w", ZIP_DEFLATED) as output_data: @@ -907,7 +900,7 @@ def test_save_solver_stats(self, tmp_path: Path) -> None: solver_stats="0\n1", owner_id=1, ) - assert actual_obj.to_dto().dict() == expected_obj.to_dto().dict() + assert actual_obj.to_dto().model_dump() == expected_obj.to_dto().model_dump() @pytest.mark.parametrize( ["running_jobs", "expected_result", "default_launcher"], @@ -990,7 +983,7 @@ def test_get_load( job_repository.get_running.return_value = running_jobs - launcher_expected_result = LauncherLoadDTO.parse_obj(expected_result) + launcher_expected_result = LauncherLoadDTO.model_validate(expected_result) actual_result = launcher_service.get_load() assert launcher_expected_result.launcher_status == actual_result.launcher_status diff --git a/tests/launcher/test_web.py b/tests/launcher/test_web.py index e0800cf019..bfa0ebbb4c 100644 --- a/tests/launcher/test_web.py +++ b/tests/launcher/test_web.py @@ -1,5 +1,5 @@ import http -from typing import Dict, List, Union +from typing import List, Union from unittest.mock import Mock, call from uuid import uuid4 @@ -74,7 +74,7 @@ def test_result() -> None: res = client.get(f"/v1/launcher/jobs/{job}") assert res.status_code == 200 - assert JobResultDTO.parse_obj(res.json()) == result.to_dto() + assert JobResultDTO.model_validate(res.json()) == result.to_dto() service.get_result.assert_called_once_with(job, RequestParameters(DEFAULT_ADMIN_USER)) @@ -98,11 +98,11 @@ def test_jobs() -> None: client = TestClient(app) res = client.get(f"/v1/launcher/jobs?study={str(study_id)}") assert res.status_code == 200 - assert [JobResultDTO.parse_obj(j) for j in res.json()] == [result.to_dto()] + assert [JobResultDTO.model_validate(j) for j in res.json()] == [result.to_dto()] res = client.get("/v1/launcher/jobs") assert res.status_code == 200 - assert [JobResultDTO.parse_obj(j) for j in res.json()] == [result.to_dto()] + assert [JobResultDTO.model_validate(j) for j in res.json()] == [result.to_dto()] service.get_jobs.assert_has_calls( [ call( @@ -136,7 +136,7 @@ def test_get_solver_versions() -> None: pytest.param( "", http.HTTPStatus.UNPROCESSABLE_ENTITY, - {"detail": "Unknown solver configuration: ''"}, + "Input should be 'slurm', 'local' or 'default'", id="empty", ), pytest.param("default", http.HTTPStatus.OK, ["1", "2", "3"], id="default"), @@ -145,7 +145,7 @@ def test_get_solver_versions() -> None: pytest.param( "remote", http.HTTPStatus.UNPROCESSABLE_ENTITY, - {"detail": "Unknown solver configuration: 'remote'"}, + "Input should be 'slurm', 'local' or 'default'", id="remote", ), ], @@ -153,7 +153,7 @@ def test_get_solver_versions() -> None: def test_get_solver_versions__with_query_string( solver: str, status_code: http.HTTPStatus, - expected: Union[List[str], Dict[str, str]], + expected: Union[List[str], str], ) -> None: service = Mock() if status_code == http.HTTPStatus.OK: @@ -165,7 +165,12 @@ def test_get_solver_versions__with_query_string( client = TestClient(app) res = client.get(f"/v1/launcher/versions?solver={solver}") assert res.status_code == status_code # OK or UNPROCESSABLE_ENTITY - assert res.json() == expected + if status_code == http.HTTPStatus.OK: + assert res.json() == expected + else: + actual = res.json()["detail"][0] + assert actual["type"] == "enum" + assert actual["msg"] == expected @pytest.mark.unit_test diff --git a/tests/login/test_login_service.py b/tests/login/test_login_service.py index e48a54a918..609b47a0e7 100644 --- a/tests/login/test_login_service.py +++ b/tests/login/test_login_service.py @@ -370,7 +370,7 @@ def test_get_group_info(self, login_service: LoginService) -> None: actual = login_service.get_group_info("superman", _param) assert actual is not None assert actual.name == "Superman" - assert [obj.dict() for obj in actual.users] == [ + assert [obj.model_dump() for obj in actual.users] == [ {"id": 2, "name": "Clark Kent", "role": RoleType.ADMIN}, {"id": 3, "name": "Lois Lane", "role": RoleType.READER}, ] @@ -450,7 +450,7 @@ def test_get_user_info(self, login_service: LoginService) -> None: clark_id = 2 actual = login_service.get_user_info(clark_id, _param) assert actual is not None - assert actual.dict() == { + assert actual.model_dump() == { "id": clark_id, "name": "Clark Kent", "roles": [ @@ -468,7 +468,7 @@ def test_get_user_info(self, login_service: LoginService) -> None: lois_id = 3 actual = login_service.get_user_info(lois_id, _param) assert actual is not None - assert actual.dict() == { + assert actual.model_dump() == { "id": lois_id, "name": "Lois Lane", "roles": [ @@ -491,7 +491,7 @@ def test_get_user_info(self, login_service: LoginService) -> None: _param = get_user_param(login_service, user_id=lois_id, group_id="superman") actual = login_service.get_user_info(lois_id, _param) assert actual is not None - assert actual.dict() == { + assert actual.model_dump() == { "id": lois_id, "name": "Lois Lane", "roles": [ @@ -512,7 +512,7 @@ def test_get_user_info(self, login_service: LoginService) -> None: _param = get_bot_param(login_service, bot_id=bot.id) actual = login_service.get_user_info(lois_id, _param) assert actual is not None - assert actual.dict() == { + assert actual.model_dump() == { "id": lois_id, "name": "Lois Lane", "roles": [ @@ -566,13 +566,13 @@ def test_get_bot_info(self, login_service: LoginService) -> None: _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") actual = login_service.get_bot_info(joh_bot.id, _param) assert actual is not None - assert actual.dict() == {"id": 6, "isAuthor": True, "name": "Maria", "roles": []} + assert actual.model_dump() == {"id": 6, "isAuthor": True, "name": "Maria", "roles": []} # Joh Fredersen can get its own bot _param = get_user_param(login_service, user_id=joh_id, group_id="superman") actual = login_service.get_bot_info(joh_bot.id, _param) assert actual is not None - assert actual.dict() == {"id": 6, "isAuthor": True, "name": "Maria", "roles": []} + assert actual.model_dump() == {"id": 6, "isAuthor": True, "name": "Maria", "roles": []} # The bot cannot get itself _param = get_bot_param(login_service, bot_id=joh_bot.id) @@ -601,13 +601,13 @@ def test_get_all_bots_by_owner(self, login_service: LoginService) -> None: _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") actual = login_service.get_all_bots_by_owner(joh_id, _param) expected = [{"id": joh_bot.id, "is_author": True, "name": "Maria", "owner": joh_id}] - assert [obj.to_dto().dict() for obj in actual] == expected + assert [obj.to_dto().model_dump() for obj in actual] == expected # Freder Fredersen can get its own bot _param = get_user_param(login_service, user_id=joh_id, group_id="superman") actual = login_service.get_all_bots_by_owner(joh_id, _param) expected = [{"id": joh_bot.id, "is_author": True, "name": "Maria", "owner": joh_id}] - assert [obj.to_dto().dict() for obj in actual] == expected + assert [obj.to_dto().model_dump() for obj in actual] == expected # The bot cannot get itself _param = get_bot_param(login_service, bot_id=joh_bot.id) @@ -718,7 +718,7 @@ def test_get_all_groups(self, login_service: LoginService) -> None: # The site admin can get all groups _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") actual = login_service.get_all_groups(_param) - assert [g.dict() for g in actual] == [ + assert [g.model_dump() for g in actual] == [ {"id": "admin", "name": "X-Men"}, {"id": "superman", "name": "Superman"}, {"id": "metropolis", "name": "Metropolis"}, @@ -727,19 +727,19 @@ def test_get_all_groups(self, login_service: LoginService) -> None: # The group admin can its own groups _param = get_user_param(login_service, user_id=2, group_id="superman") actual = login_service.get_all_groups(_param) - assert [g.dict() for g in actual] == [{"id": "superman", "name": "Superman"}] + assert [g.model_dump() for g in actual] == [{"id": "superman", "name": "Superman"}] # The user can get its own groups _param = get_user_param(login_service, user_id=3, group_id="superman") actual = login_service.get_all_groups(_param) - assert [g.dict() for g in actual] == [{"id": "superman", "name": "Superman"}] + assert [g.model_dump() for g in actual] == [{"id": "superman", "name": "Superman"}] @with_db_context def test_get_all_users(self, login_service: LoginService) -> None: # The site admin can get all users _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") actual = login_service.get_all_users(_param) - assert [u.dict() for u in actual] == [ + assert [u.model_dump() for u in actual] == [ {"id": 1, "name": "Professor Xavier"}, {"id": 2, "name": "Clark Kent"}, {"id": 3, "name": "Lois Lane"}, @@ -751,7 +751,7 @@ def test_get_all_users(self, login_service: LoginService) -> None: # note: I don't know why the group admin can get all users -- Laurent _param = get_user_param(login_service, user_id=2, group_id="superman") actual = login_service.get_all_users(_param) - assert [u.dict() for u in actual] == [ + assert [u.model_dump() for u in actual] == [ {"id": 1, "name": "Professor Xavier"}, {"id": 2, "name": "Clark Kent"}, {"id": 3, "name": "Lois Lane"}, @@ -762,7 +762,7 @@ def test_get_all_users(self, login_service: LoginService) -> None: # The user can get its own users _param = get_user_param(login_service, user_id=3, group_id="superman") actual = login_service.get_all_users(_param) - assert [u.dict() for u in actual] == [ + assert [u.model_dump() for u in actual] == [ {"id": 2, "name": "Clark Kent"}, {"id": 3, "name": "Lois Lane"}, ] @@ -777,7 +777,7 @@ def test_get_all_bots(self, login_service: LoginService) -> None: # The site admin can get all bots _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") actual = login_service.get_all_bots(_param) - assert [b.to_dto().dict() for b in actual] == [ + assert [b.to_dto().model_dump() for b in actual] == [ {"id": joh_bot.id, "is_author": True, "name": "Maria", "owner": joh_id}, ] @@ -796,7 +796,7 @@ def test_get_all_roles_in_group(self, login_service: LoginService) -> None: # The site admin can get all roles in a given group _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") actual = login_service.get_all_roles_in_group("superman", _param) - assert [b.to_dto().dict() for b in actual] == [ + assert [b.to_dto().model_dump() for b in actual] == [ { "group": {"id": "superman", "name": "Superman"}, "identity": {"id": 2, "name": "Clark Kent"}, @@ -812,7 +812,7 @@ def test_get_all_roles_in_group(self, login_service: LoginService) -> None: # The group admin can get all roles his own group _param = get_user_param(login_service, user_id=2, group_id="superman") actual = login_service.get_all_roles_in_group("superman", _param) - assert [b.to_dto().dict() for b in actual] == [ + assert [b.to_dto().model_dump() for b in actual] == [ { "group": {"id": "superman", "name": "Superman"}, "identity": {"id": 2, "name": "Clark Kent"}, diff --git a/tests/login/test_web.py b/tests/login/test_web.py index 0f7175fc54..636ff48a4b 100644 --- a/tests/login/test_web.py +++ b/tests/login/test_web.py @@ -7,12 +7,12 @@ import pytest from fastapi import FastAPI -from fastapi_jwt_auth import AuthJWT from starlette.testclient import TestClient from antarest.core.config import Config, SecurityConfig from antarest.core.jwt import JWTGroup, JWTUser from antarest.core.requests import RequestParameters +from antarest.fastapi_jwt_auth import AuthJWT from antarest.login.main import build_login from antarest.login.model import ( Bot, @@ -177,7 +177,7 @@ def test_user() -> None: client = TestClient(app) res = client.get("/v1/users", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == [User(id=1, name="user").to_dto().dict()] + assert res.json() == [User(id=1, name="user").to_dto().model_dump()] @pytest.mark.unit_test @@ -189,7 +189,7 @@ def test_user_id() -> None: client = TestClient(app) res = client.get("/v1/users/1", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == User(id=1, name="user").to_dto().dict() + assert res.json() == User(id=1, name="user").to_dto().model_dump() @pytest.mark.unit_test @@ -201,7 +201,7 @@ def test_user_id_with_details() -> None: client = TestClient(app) res = client.get("/v1/users/1?details=true", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == IdentityDTO(id=1, name="user", roles=[]).dict() + assert res.json() == IdentityDTO(id=1, name="user", roles=[]).model_dump() @pytest.mark.unit_test @@ -216,12 +216,12 @@ def test_user_create() -> None: res = client.post( "/v1/users", headers=create_auth_token(app), - json=user.dict(), + json=user.model_dump(), ) assert res.status_code == 200 service.create_user.assert_called_once_with(user, PARAMS) - assert res.json() == user_id.to_dto().dict() + assert res.json() == user_id.to_dto().model_dump() @pytest.mark.unit_test @@ -232,7 +232,7 @@ def test_user_save() -> None: app = create_app(service) client = TestClient(app) - user_obj = user.to_dto().dict() + user_obj = user.to_dto().model_dump() res = client.put( "/v1/users/0", headers=create_auth_token(app), @@ -244,7 +244,7 @@ def test_user_save() -> None: assert service.save_user.call_count == 1 call = service.save_user.call_args_list[0] - assert call[0][0].to_dto().dict() == user_obj + assert call[0][0].to_dto().model_dump() == user_obj assert call[0][1] == PARAMS @@ -269,7 +269,7 @@ def test_group() -> None: client = TestClient(app) res = client.get("/v1/groups", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == [Group(id="my-group", name="group").to_dto().dict()] + assert res.json() == [Group(id="my-group", name="group").to_dto().model_dump()] @pytest.mark.unit_test @@ -281,7 +281,7 @@ def test_group_id() -> None: client = TestClient(app) res = client.get("/v1/groups/1", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == Group(id="my-group", name="group").to_dto().dict() + assert res.json() == Group(id="my-group", name="group").to_dto().model_dump() @pytest.mark.unit_test @@ -299,7 +299,7 @@ def test_group_create() -> None: ) assert res.status_code == 200 - assert res.json() == group.to_dto().dict() + assert res.json() == group.to_dto().model_dump() @pytest.mark.unit_test @@ -329,7 +329,7 @@ def test_role() -> None: client = TestClient(app) res = client.get("/v1/roles/group/g", headers=create_auth_token(app)) assert res.status_code == 200 - assert [RoleDetailDTO.parse_obj(el) for el in res.json()] == [role.to_dto()] + assert [RoleDetailDTO.model_validate(el) for el in res.json()] == [role.to_dto()] @pytest.mark.unit_test @@ -351,7 +351,7 @@ def test_role_create() -> None: ) assert res.status_code == 200 - assert RoleDetailDTO.parse_obj(res.json()) == role.to_dto().dict() + assert RoleDetailDTO.model_validate(res.json()) == role.to_dto() @pytest.mark.unit_test @@ -391,10 +391,10 @@ def test_bot_create() -> None: service.save_bot.return_value = bot service.get_group.return_value = Group(id="group", name="group") - print(create.json()) + create.model_dump_json() app = create_app(service) client = TestClient(app) - res = client.post("/v1/bots", headers=create_auth_token(app), json=create.dict()) + res = client.post("/v1/bots", headers=create_auth_token(app), json=create.model_dump()) assert res.status_code == 200 assert len(res.json().split(".")) == 3 @@ -410,7 +410,7 @@ def test_bot() -> None: client = TestClient(app) res = client.get("/v1/bots/0", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == bot.to_dto().dict() + assert res.json() == bot.to_dto().model_dump() @pytest.mark.unit_test @@ -424,11 +424,11 @@ def test_all_bots() -> None: client = TestClient(app) res = client.get("/v1/bots", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == [b.to_dto().dict() for b in bots] + assert res.json() == [b.to_dto().model_dump() for b in bots] res = client.get("/v1/bots?owner=4", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == [b.to_dto().dict() for b in bots] + assert res.json() == [b.to_dto().model_dump() for b in bots] service.get_all_bots.assert_called_once() service.get_all_bots_by_owner.assert_called_once() diff --git a/tests/matrixstore/test_matrix_editor.py b/tests/matrixstore/test_matrix_editor.py index 2907cfb793..e12a8b4dd7 100644 --- a/tests/matrixstore/test_matrix_editor.py +++ b/tests/matrixstore/test_matrix_editor.py @@ -70,7 +70,7 @@ class TestMatrixSlice: ) def test_init(self, kwargs: Dict[str, Any], expected: Dict[str, Any]) -> None: obj = MatrixSlice(**kwargs) - assert obj.dict(by_alias=False) == expected + assert obj.model_dump(by_alias=False) == expected class TestOperation: @@ -97,12 +97,12 @@ class TestOperation: ) def test_init(self, kwargs: Dict[str, Any], expected: Dict[str, Any]) -> None: obj = Operation(**kwargs) - assert obj.dict(by_alias=False) == expected + assert obj.model_dump(by_alias=False) == expected @pytest.mark.parametrize("operation", list(OPERATIONS)) def test_init__valid_operation(self, operation: str) -> None: obj = Operation(operation=operation, value=123) - assert obj.dict(by_alias=False) == { + assert obj.model_dump(by_alias=False) == { "operation": operation, "value": 123.0, } @@ -192,4 +192,4 @@ class TestMatrixEditInstruction: ) def test_init(self, kwargs: Dict[str, Any], expected: Dict[str, Any]) -> None: obj = MatrixEditInstruction(**kwargs) - assert obj.dict(by_alias=False) == expected + assert obj.model_dump(by_alias=False) == expected diff --git a/tests/matrixstore/test_service.py b/tests/matrixstore/test_service.py index 5c6eb837c7..41828a6cf8 100644 --- a/tests/matrixstore/test_service.py +++ b/tests/matrixstore/test_service.py @@ -485,12 +485,8 @@ def test_dataset_lifecycle() -> None: dataset_repo.delete.assert_called_once() -def _create_upload_file(filename: str, file: io.BytesIO, content_type: str = "") -> UploadFile: - if hasattr(UploadFile, "content_type"): - # `content_type` attribute was replace by a read-ony property in starlette-v0.24. - headers = Headers(headers={"content-type": content_type}) - # noinspection PyTypeChecker,PyArgumentList - return UploadFile(filename=filename, file=file, headers=headers) - else: - # noinspection PyTypeChecker,PyArgumentList - return UploadFile(filename=filename, file=file, content_type=content_type) +def _create_upload_file(filename: str, file: t.IO = None, content_type: str = "") -> UploadFile: + # `content_type` attribute was replace by a read-ony property in starlette-v0.24. + headers = Headers(headers={"content-type": content_type}) + # noinspection PyTypeChecker,PyArgumentList + return UploadFile(filename=filename, file=file, headers=headers) diff --git a/tests/matrixstore/test_web.py b/tests/matrixstore/test_web.py index 36b7fc366d..e1d4d230d2 100644 --- a/tests/matrixstore/test_web.py +++ b/tests/matrixstore/test_web.py @@ -3,10 +3,10 @@ import pytest from fastapi import FastAPI -from fastapi_jwt_auth import AuthJWT from starlette.testclient import TestClient from antarest.core.config import Config, SecurityConfig +from antarest.fastapi_jwt_auth import AuthJWT from antarest.main import JwtSettings from antarest.matrixstore.main import build_matrix_service from antarest.matrixstore.model import MatrixDTO, MatrixInfoDTO @@ -62,7 +62,7 @@ def test_create() -> None: json=matrix_data, ) assert res.status_code == 200 - assert res.json() == matrix.dict() + assert res.json() == matrix.model_dump() @pytest.mark.unit_test @@ -84,7 +84,7 @@ def test_get() -> None: client = TestClient(app) res = client.get("/v1/matrix/123", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == matrix.dict() + assert res.json() == matrix.model_dump() service.get.assert_called_once_with("123") @@ -114,4 +114,4 @@ def test_import() -> None: files={"file": ("Matrix.zip", bytes(5), "application/zip")}, ) assert res.status_code == 200 - assert res.json() == matrix_info + assert [MatrixInfoDTO.model_validate(res.json()[0])] == matrix_info diff --git a/tests/storage/business/test_arealink_manager.py b/tests/storage/business/test_arealink_manager.py index a8beff5fc0..93681165c5 100644 --- a/tests/storage/business/test_arealink_manager.py +++ b/tests/storage/business/test_arealink_manager.py @@ -162,35 +162,36 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService): variant_id, [ CommandDTO( + id=None, action=CommandName.UPDATE_CONFIG.value, args=[ { "target": "input/areas/test/ui/ui/x", - "data": "100", + "data": 100, }, { "target": "input/areas/test/ui/ui/y", - "data": "200", + "data": 200, }, { "target": "input/areas/test/ui/ui/color_r", - "data": "255", + "data": 255, }, { "target": "input/areas/test/ui/ui/color_g", - "data": "0", + "data": 0, }, { "target": "input/areas/test/ui/ui/color_b", - "data": "100", + "data": 100, }, { "target": "input/areas/test/ui/layerX/0", - "data": "100", + "data": 100, }, { "target": "input/areas/test/ui/layerY/0", - "data": "200", + "data": 200, }, { "target": "input/areas/test/ui/layerColor/0", @@ -290,7 +291,7 @@ def test_get_all_area(): area_manager.patch_service = Mock() area_manager.patch_service.get.return_value = Patch( areas={"a1": PatchArea(country="fr")}, - thermal_clusters={"a1.a": PatchCluster.parse_obj({"code-oi": "1"})}, + thermal_clusters={"a1.a": PatchCluster.model_validate({"code-oi": "1"})}, ) file_tree_mock.get.side_effect = [ { @@ -350,7 +351,7 @@ def test_get_all_area(): }, ] areas = area_manager.get_all_areas(study, AreaType.AREA) - assert expected_areas == [area.dict() for area in areas] + assert expected_areas == [area.model_dump() for area in areas] expected_clusters = [ { @@ -363,7 +364,7 @@ def test_get_all_area(): } ] clusters = area_manager.get_all_areas(study, AreaType.DISTRICT) - assert expected_clusters == [area.dict() for area in clusters] + assert expected_clusters == [area.model_dump() for area in clusters] file_tree_mock.get.side_effect = [{}, {}, {}] expected_all = [ @@ -401,14 +402,14 @@ def test_get_all_area(): }, ] all_areas = area_manager.get_all_areas(study) - assert expected_all == [area.dict() for area in all_areas] + assert expected_all == [area.model_dump() for area in all_areas] links = link_manager.get_all_links(study) assert [ {"area1": "a1", "area2": "a2", "ui": None}, {"area1": "a1", "area2": "a3", "ui": None}, {"area1": "a2", "area2": "a3", "ui": None}, - ] == [link.dict() for link in links] + ] == [link.model_dump() for link in links] def test_update_area(): @@ -451,7 +452,7 @@ def test_update_area(): new_area_info = area_manager.update_area_metadata(study, "a1", PatchArea(country="fr")) assert new_area_info.id == "a1" - assert new_area_info.metadata == {"country": "fr", "tags": []} + assert new_area_info.metadata.model_dump() == {"country": "fr", "tags": []} def test_update_clusters(): @@ -484,7 +485,7 @@ def test_update_clusters(): area_manager.patch_service = Mock() area_manager.patch_service.get.return_value = Patch( areas={"a1": PatchArea(country="fr")}, - thermal_clusters={"a1.a": PatchCluster.parse_obj({"code-oi": "1"})}, + thermal_clusters={"a1.a": PatchCluster.model_validate({"code-oi": "1"})}, ) file_tree_mock.get.side_effect = [ { diff --git a/tests/storage/business/test_config_manager.py b/tests/storage/business/test_config_manager.py index f4d344a27d..d6b648a2f6 100644 --- a/tests/storage/business/test_config_manager.py +++ b/tests/storage/business/test_config_manager.py @@ -19,7 +19,7 @@ def test_thematic_trimming_config() -> None: - command_context = CommandContext.construct() + command_context = CommandContext.model_construct() command_factory_mock = Mock() command_factory_mock.command_context = command_context raw_study_service = Mock(spec=RawStudyService) @@ -54,27 +54,27 @@ def test_thematic_trimming_config() -> None: study.version = config.version = 700 actual = thematic_trimming_manager.get_field_values(study) - fields_info = get_fields_info(study.version) + fields_info = get_fields_info(int(study.version)) expected = ThematicTrimmingFormFields(**dict.fromkeys(fields_info, True)) assert actual == expected study.version = config.version = 800 actual = thematic_trimming_manager.get_field_values(study) - fields_info = get_fields_info(study.version) + fields_info = get_fields_info(int(study.version)) expected = ThematicTrimmingFormFields(**dict.fromkeys(fields_info, True)) expected.avl_dtg = False assert actual == expected study.version = config.version = 820 actual = thematic_trimming_manager.get_field_values(study) - fields_info = get_fields_info(study.version) + fields_info = get_fields_info(int(study.version)) expected = ThematicTrimmingFormFields(**dict.fromkeys(fields_info, True)) expected.avl_dtg = False assert actual == expected study.version = config.version = 830 actual = thematic_trimming_manager.get_field_values(study) - fields_info = get_fields_info(study.version) + fields_info = get_fields_info(int(study.version)) expected = ThematicTrimmingFormFields(**dict.fromkeys(fields_info, True)) expected.dens = False expected.profit_by_plant = False @@ -82,7 +82,7 @@ def test_thematic_trimming_config() -> None: study.version = config.version = 840 actual = thematic_trimming_manager.get_field_values(study) - fields_info = get_fields_info(study.version) + fields_info = get_fields_info(int(study.version)) expected = ThematicTrimmingFormFields(**dict.fromkeys(fields_info, False)) expected.cong_fee_alg = True assert actual == expected diff --git a/tests/storage/business/test_patch_service.py b/tests/storage/business/test_patch_service.py index ed7dd6c444..c2142427e8 100644 --- a/tests/storage/business/test_patch_service.py +++ b/tests/storage/business/test_patch_service.py @@ -182,7 +182,7 @@ def test_set_output_ref(self, tmp_path: Path): additional_data=StudyAdditionalData( author="john.doe", horizon="foo-horizon", - patch=patch_outputs.json(), + patch=patch_outputs.model_dump_json(), ), archived=False, owner=None, diff --git a/tests/storage/business/test_timeseries_config_manager.py b/tests/storage/business/test_timeseries_config_manager.py index 387e42f5dc..de7f04fb76 100644 --- a/tests/storage/business/test_timeseries_config_manager.py +++ b/tests/storage/business/test_timeseries_config_manager.py @@ -44,7 +44,7 @@ def file_study_720(tmpdir: Path) -> FileStudy: def test_ts_field_values(file_study_820: FileStudy, file_study_720: FileStudy): command_factory_mock = Mock() - command_factory_mock.command_context = CommandContext.construct() + command_factory_mock.command_context = CommandContext.model_construct() raw_study_service = Mock(spec=RawStudyService) diff --git a/tests/storage/business/test_xpansion_manager.py b/tests/storage/business/test_xpansion_manager.py index 100bddd286..14d67a048f 100644 --- a/tests/storage/business/test_xpansion_manager.py +++ b/tests/storage/business/test_xpansion_manager.py @@ -197,7 +197,7 @@ def test_get_xpansion_settings(tmp_path: Path, version: int, expected_output: JS xpansion_manager.create_xpansion_configuration(study) actual = xpansion_manager.get_xpansion_settings(study) - assert actual.dict(by_alias=True) == expected_output + assert actual.model_dump(by_alias=True) == expected_output @pytest.mark.unit_test @@ -244,7 +244,7 @@ def test_update_xpansion_settings(tmp_path: Path) -> None: "timelimit": int(1e12), "sensitivity_config": {"epsilon": 10500.0, "projection": ["foo"], "capex": False}, } - assert actual.dict(by_alias=True) == expected + assert actual.model_dump(by_alias=True) == expected @pytest.mark.unit_test @@ -254,7 +254,7 @@ def test_add_candidate(tmp_path: Path) -> None: actual = empty_study.tree.get(["user", "expansion", "candidates"]) assert actual == {} - new_candidate = XpansionCandidateDTO.parse_obj( + new_candidate = XpansionCandidateDTO.model_validate( { "name": "candidate_1", "link": "area1 - area2", @@ -263,7 +263,7 @@ def test_add_candidate(tmp_path: Path) -> None: } ) - new_candidate2 = XpansionCandidateDTO.parse_obj( + new_candidate2 = XpansionCandidateDTO.model_validate( { "name": "candidate_2", "link": "area1 - area2", @@ -284,13 +284,13 @@ def test_add_candidate(tmp_path: Path) -> None: xpansion_manager.add_candidate(study, new_candidate) - candidates = {"1": new_candidate.dict(by_alias=True, exclude_none=True)} + candidates = {"1": new_candidate.model_dump(by_alias=True, exclude_none=True)} actual = empty_study.tree.get(["user", "expansion", "candidates"]) assert actual == candidates xpansion_manager.add_candidate(study, new_candidate2) - candidates["2"] = new_candidate2.dict(by_alias=True, exclude_none=True) + candidates["2"] = new_candidate2.model_dump(by_alias=True, exclude_none=True) actual = empty_study.tree.get(["user", "expansion", "candidates"]) assert actual == candidates @@ -302,7 +302,7 @@ def test_get_candidate(tmp_path: Path) -> None: assert empty_study.tree.get(["user", "expansion", "candidates"]) == {} - new_candidate = XpansionCandidateDTO.parse_obj( + new_candidate = XpansionCandidateDTO.model_validate( { "name": "candidate_1", "link": "area1 - area2", @@ -311,7 +311,7 @@ def test_get_candidate(tmp_path: Path) -> None: } ) - new_candidate2 = XpansionCandidateDTO.parse_obj( + new_candidate2 = XpansionCandidateDTO.model_validate( { "name": "candidate_2", "link": "area1 - area2", @@ -335,7 +335,7 @@ def test_get_candidates(tmp_path: Path) -> None: assert empty_study.tree.get(["user", "expansion", "candidates"]) == {} - new_candidate = XpansionCandidateDTO.parse_obj( + new_candidate = XpansionCandidateDTO.model_validate( { "name": "candidate_1", "link": "area1 - area2", @@ -344,7 +344,7 @@ def test_get_candidates(tmp_path: Path) -> None: } ) - new_candidate2 = XpansionCandidateDTO.parse_obj( + new_candidate2 = XpansionCandidateDTO.model_validate( { "name": "candidate_2", "link": "area1 - area2", @@ -372,7 +372,7 @@ def test_update_candidates(tmp_path: Path) -> None: make_link_and_areas(empty_study) - new_candidate = XpansionCandidateDTO.parse_obj( + new_candidate = XpansionCandidateDTO.model_validate( { "name": "candidate_1", "link": "area1 - area2", @@ -382,7 +382,7 @@ def test_update_candidates(tmp_path: Path) -> None: ) xpansion_manager.add_candidate(study, new_candidate) - new_candidate2 = XpansionCandidateDTO.parse_obj( + new_candidate2 = XpansionCandidateDTO.model_validate( { "name": "candidate_1", "link": "area1 - area2", @@ -403,7 +403,7 @@ def test_delete_candidate(tmp_path: Path) -> None: make_link_and_areas(empty_study) - new_candidate = XpansionCandidateDTO.parse_obj( + new_candidate = XpansionCandidateDTO.model_validate( { "name": "candidate_1", "link": "area1 - area2", @@ -413,7 +413,7 @@ def test_delete_candidate(tmp_path: Path) -> None: ) xpansion_manager.add_candidate(study, new_candidate) - new_candidate2 = XpansionCandidateDTO.parse_obj( + new_candidate2 = XpansionCandidateDTO.model_validate( { "name": "candidate_2", "link": "area1 - area2", @@ -488,14 +488,14 @@ def test_add_resources(tmp_path: Path) -> None: settings = xpansion_manager.get_xpansion_settings(study) settings.yearly_weights = filename3 - update_settings = UpdateXpansionSettings(**settings.dict()) + update_settings = UpdateXpansionSettings(**settings.model_dump()) xpansion_manager.update_xpansion_settings(study, update_settings) with pytest.raises(FileCurrentlyUsedInSettings): xpansion_manager.delete_resource(study, XpansionResourceFileType.WEIGHTS, filename3) settings.yearly_weights = "" - update_settings = UpdateXpansionSettings(**settings.dict()) + update_settings = UpdateXpansionSettings(**settings.model_dump()) xpansion_manager.update_xpansion_settings(study, update_settings) xpansion_manager.delete_resource(study, XpansionResourceFileType.WEIGHTS, filename3) diff --git a/tests/storage/rawstudies/test_factory.py b/tests/storage/rawstudies/test_factory.py index e2cd4391b3..7b4841e295 100644 --- a/tests/storage/rawstudies/test_factory.py +++ b/tests/storage/rawstudies/test_factory.py @@ -53,8 +53,8 @@ def test_factory_cache() -> None: cache.get.return_value = None study = factory.create_from_fs(path, study_id) assert study.config == config - cache.put.assert_called_once_with(cache_id, FileStudyTreeConfigDTO.from_build_config(config).dict()) + cache.put.assert_called_once_with(cache_id, FileStudyTreeConfigDTO.from_build_config(config).model_dump()) - cache.get.return_value = FileStudyTreeConfigDTO.from_build_config(config).dict() + cache.get.return_value = FileStudyTreeConfigDTO.from_build_config(config).model_dump() study = factory.create_from_fs(path, study_id) assert study.config == config diff --git a/tests/storage/repository/filesystem/config/test_config_files.py b/tests/storage/repository/filesystem/config/test_config_files.py index 93482d359d..eec1e2f066 100644 --- a/tests/storage/repository/filesystem/config/test_config_files.py +++ b/tests/storage/repository/filesystem/config/test_config_files.py @@ -403,7 +403,7 @@ def test_parse_thermal_860(study_path: Path, version, caplog) -> None: assert not caplog.text else: expected = [ThermalConfig(id="t1", name="t1")] - assert "extra fields not permitted" in caplog.text + assert "Extra inputs are not permitted" in caplog.text assert actual == expected diff --git a/tests/storage/test_model.py b/tests/storage/test_model.py index d17d6c89a4..329fa1965e 100644 --- a/tests/storage/test_model.py +++ b/tests/storage/test_model.py @@ -55,5 +55,5 @@ def test_file_study_tree_config_dto(): enr_modelling="aggregated", ) config_dto = FileStudyTreeConfigDTO.from_build_config(config) - assert sorted(list(config_dto.dict()) + ["cache"]) == sorted(list(config.__dict__)) + assert sorted(list(config_dto.model_dump()) + ["cache"]) == sorted(list(config.__dict__)) assert config_dto.to_build_config() == config diff --git a/tests/storage/test_service.py b/tests/storage/test_service.py index 7ebf94a09e..4751e60ef0 100644 --- a/tests/storage/test_service.py +++ b/tests/storage/test_service.py @@ -585,7 +585,7 @@ def test_download_output() -> None: name="east", type=StudyDownloadType.AREA, data={ - 1: [ + "1": [ TimeSerie(name="H. VAL", unit="Euro/MWh", data=[0.5]), TimeSerie(name="some cluster", unit="Euro/MWh", data=[0.8]), ] @@ -649,7 +649,7 @@ def test_download_output() -> None: TimeSeriesData( name="east^west", type=StudyDownloadType.LINK, - data={1: [TimeSerie(name="H. VAL", unit="Euro/MWh", data=[0.5])]}, + data={"1": [TimeSerie(name="H. VAL", unit="Euro/MWh", data=[0.5])]}, ) ], warnings=[], @@ -683,7 +683,7 @@ def test_download_output() -> None: name="north", type=StudyDownloadType.DISTRICT, data={ - 1: [ + "1": [ TimeSerie(name="H. VAL", unit="Euro/MWh", data=[0.5]), TimeSerie(name="some cluster", unit="Euro/MWh", data=[0.8]), ] @@ -1379,7 +1379,7 @@ def test_unarchive_output(tmp_path: Path) -> None: src=str(tmp_path / "output" / f"{output_id}.zip"), dest=str(tmp_path / "output" / output_id), remove_src=False, - ).dict(), + ).model_dump(), name=f"Unarchive output {study_name}/{output_id} ({study_id})", ref_id=study_id, request_params=RequestParameters(user=DEFAULT_ADMIN_USER), @@ -1510,7 +1510,7 @@ def test_archive_output_locks(tmp_path: Path) -> None: src=str(tmp_path / "output" / f"{output_id}.zip"), dest=str(tmp_path / "output" / output_id), remove_src=False, - ).dict(), + ).model_dump(), name=f"Unarchive output {study_name}/{output_id} ({study_id})", ref_id=study_id, request_params=RequestParameters(user=DEFAULT_ADMIN_USER), diff --git a/tests/storage/web/test_studies_bp.py b/tests/storage/web/test_studies_bp.py index 2ee0a9c7e2..22b2260b34 100644 --- a/tests/storage/web/test_studies_bp.py +++ b/tests/storage/web/test_studies_bp.py @@ -287,7 +287,7 @@ def test_list_studies(tmp_path: str) -> None: client = TestClient(app) result = client.get("/v1/studies") - assert {k: StudyMetadataDTO.parse_obj(v) for k, v in result.json().items()} == studies + assert {k: StudyMetadataDTO.model_validate(v) for k, v in result.json().items()} == studies def test_study_metadata(tmp_path: str) -> None: @@ -322,7 +322,7 @@ def test_study_metadata(tmp_path: str) -> None: client = TestClient(app) result = client.get("/v1/studies/1") - assert StudyMetadataDTO.parse_obj(result.json()) == study + assert StudyMetadataDTO.model_validate(result.json()) == study @pytest.mark.unit_test @@ -370,7 +370,7 @@ def test_export_files(tmp_path: Path) -> None: res.raise_for_status() result = json.loads(data.getvalue()) - assert FileDownloadTaskDTO(**result).json() == expected.json() + assert FileDownloadTaskDTO(**result).model_dump_json() == expected.model_dump_json() mock_storage_service.export_study.assert_called_once_with(UUID, PARAMS, True) @@ -554,9 +554,9 @@ def test_output_download(tmp_path: Path) -> None: client = TestClient(app, raise_server_exceptions=False) res = client.post( f"/v1/studies/{UUID}/outputs/my-output-id/download", - json=study_download.dict(), + json=study_download.model_dump(), ) - assert res.json() == output_data.dict() + assert res.json() == output_data.model_dump() @pytest.mark.unit_test @@ -657,7 +657,8 @@ def test_sim_result() -> None: ) client = TestClient(app, raise_server_exceptions=False) res = client.get(f"/v1/studies/{study_id}/outputs") - assert res.json() == result_data + actual_object = [StudySimResultDTO.parse_obj(res.json()[0])] + assert actual_object == result_data @pytest.mark.unit_test diff --git a/tests/study/business/areas/test_st_storage_management.py b/tests/study/business/areas/test_st_storage_management.py index 74089bc7cd..d6e0fa8981 100644 --- a/tests/study/business/areas/test_st_storage_management.py +++ b/tests/study/business/areas/test_st_storage_management.py @@ -64,6 +64,8 @@ "east": {"list": {}}, } +GEN = np.random.default_rng(1000) + class TestSTStorageManager: @pytest.fixture(name="study_storage_service") @@ -135,7 +137,7 @@ def test_get_all_storages__nominal_case( # Check actual = { - area_id: [form.dict(by_alias=True) for form in clusters_by_ids.values()] + area_id: [form.model_dump(by_alias=True) for form in clusters_by_ids.values()] for area_id, clusters_by_ids in all_storages.items() } expected = { @@ -241,7 +243,7 @@ def test_get_st_storages__nominal_case( groups = manager.get_storages(study, area_id="West") # Check - actual = [form.dict(by_alias=True) for form in groups] + actual = [form.model_dump(by_alias=True) for form in groups] expected = [ { "efficiency": 0.94, @@ -353,7 +355,7 @@ def test_get_st_storage__nominal_case( edit_form = manager.get_storage(study, area_id="West", storage_id="storage1") # Assert that the returned storage fields match the expected fields - actual = edit_form.dict(by_alias=True) + actual = edit_form.model_dump(by_alias=True) expected = { "efficiency": 0.94, "group": STStorageGroup.BATTERY, @@ -443,11 +445,11 @@ def test_update_storage__nominal_case( actual = {(call_args[0][0], tuple(call_args[0][1])) for call_args in file_study.tree.save.call_args_list} expected = { ( - str(0.0), + 0.0, ("input", "st-storage", "clusters", "West", "list", "storage1", "initiallevel"), ), ( - str(False), + False, ("input", "st-storage", "clusters", "West", "list", "storage1", "initialleveloptim"), ), } @@ -516,16 +518,15 @@ def test_get_matrix__nominal_case( # Prepare the mocks storage = study_storage_service.get_storage(study) file_study = storage.get_raw(study) - array = np.random.rand(8760, 1) * 1000 + array = GEN.random((8760, 1)) * 1000 + expected = { + "index": list(range(8760)), + "columns": [0], + "data": array.tolist(), + } file_study.tree = Mock( spec=FileStudyTree, - get=Mock( - return_value={ - "index": list(range(8760)), - "columns": [0], - "data": array.tolist(), - } - ), + get=Mock(return_value=expected), ) # Given the following arguments @@ -535,8 +536,8 @@ def test_get_matrix__nominal_case( matrix = manager.get_matrix(study, area_id="West", storage_id="storage1", ts_name="inflows") # Assert that the returned storage fields match the expected fields - actual = matrix.dict(by_alias=True) - assert actual == matrix + actual = matrix.model_dump(by_alias=True) + assert actual == expected def test_get_matrix__config_not_found( self, @@ -604,7 +605,7 @@ def test_get_matrix__invalid_matrix( # Prepare the mocks storage = study_storage_service.get_storage(study) file_study = storage.get_raw(study) - array = np.random.rand(365, 1) * 1000 + array = GEN.random((365, 1)) * 1000 matrix = { "index": list(range(365)), "columns": [0], @@ -637,11 +638,11 @@ def test_validate_matrices__nominal( # prepare some random matrices, insuring `lower_rule_curve` <= `upper_rule_curve` matrices = { - "pmax_injection": np.random.rand(8760, 1), - "pmax_withdrawal": np.random.rand(8760, 1), - "lower_rule_curve": np.random.rand(8760, 1) / 2, - "upper_rule_curve": np.random.rand(8760, 1) / 2 + 0.5, - "inflows": np.random.rand(8760, 1) * 1000, + "pmax_injection": GEN.random((8760, 1)), + "pmax_withdrawal": GEN.random((8760, 1)), + "lower_rule_curve": GEN.random((8760, 1)) / 2, + "upper_rule_curve": GEN.random((8760, 1)) / 2 + 0.5, + "inflows": GEN.random((8760, 1)) * 1000, } # Prepare the mocks @@ -674,11 +675,11 @@ def test_validate_matrices__out_of_bound( # prepare some random matrices, insuring `lower_rule_curve` <= `upper_rule_curve` matrices = { - "pmax_injection": np.random.rand(8760, 1) * 2 - 0.5, # out of bound - "pmax_withdrawal": np.random.rand(8760, 1) * 2 - 0.5, # out of bound - "lower_rule_curve": np.random.rand(8760, 1) * 2 - 0.5, # out of bound - "upper_rule_curve": np.random.rand(8760, 1) * 2 - 0.5, # out of bound - "inflows": np.random.rand(8760, 1) * 1000, + "pmax_injection": GEN.random((8760, 1)) * 2 - 0.5, # out of bound + "pmax_withdrawal": GEN.random((8760, 1)) * 2 - 0.5, # out of bound + "lower_rule_curve": GEN.random((8760, 1)) * 2 - 0.5, # out of bound + "upper_rule_curve": GEN.random((8760, 1)) * 2 - 0.5, # out of bound + "inflows": GEN.random((8760, 1)) * 1000, } # Prepare the mocks @@ -695,7 +696,6 @@ def tree_get(url: t.Sequence[str], **_: t.Any) -> t.MutableMapping[str, t.Any]: file_study = storage.get_raw(study) file_study.tree = Mock(spec=FileStudyTree, get=tree_get) - # Given the following arguments, the validation shouldn't raise any exception manager = STStorageManager(study_storage_service) # Run the method being tested and expect an exception @@ -704,29 +704,11 @@ def tree_get(url: t.Sequence[str], **_: t.Any) -> t.MutableMapping[str, t.Any]: match=re.escape("4 validation errors"), ) as ctx: manager.validate_matrices(study, area_id="West", storage_id="storage1") - errors = ctx.value.errors() - assert errors == [ - { - "loc": ("pmax_injection",), - "msg": "Matrix values should be between 0 and 1", - "type": "value_error", - }, - { - "loc": ("pmax_withdrawal",), - "msg": "Matrix values should be between 0 and 1", - "type": "value_error", - }, - { - "loc": ("lower_rule_curve",), - "msg": "Matrix values should be between 0 and 1", - "type": "value_error", - }, - { - "loc": ("upper_rule_curve",), - "msg": "Matrix values should be between 0 and 1", - "type": "value_error", - }, - ] + assert ctx.value.error_count() == 4 + for error in ctx.value.errors(): + assert error["type"] == "value_error" + assert error["msg"] == "Value error, Matrix values should be between 0 and 1" + assert error["loc"][0] in ["upper_rule_curve", "lower_rule_curve", "pmax_withdrawal", "pmax_injection"] # noinspection SpellCheckingInspection def test_validate_matrices__rule_curve( @@ -738,13 +720,15 @@ def test_validate_matrices__rule_curve( # The study must be fetched from the database study: RawStudy = db_session.query(Study).get(study_uuid) - # prepare some random matrices, insuring `lower_rule_curve` <= `upper_rule_curve` + # prepare some random matrices, not respecting `lower_rule_curve` <= `upper_rule_curve` + upper_curve = np.zeros((8760, 1)) + lower_curve = np.ones((8760, 1)) matrices = { - "pmax_injection": np.random.rand(8760, 1), - "pmax_withdrawal": np.random.rand(8760, 1), - "lower_rule_curve": np.random.rand(8760, 1), - "upper_rule_curve": np.random.rand(8760, 1), - "inflows": np.random.rand(8760, 1) * 1000, + "pmax_injection": GEN.random((8760, 1)), + "pmax_withdrawal": GEN.random((8760, 1)), + "lower_rule_curve": lower_curve, + "upper_rule_curve": upper_curve, + "inflows": GEN.random((8760, 1)) * 1000, } # Prepare the mocks @@ -761,7 +745,7 @@ def tree_get(url: t.Sequence[str], **_: t.Any) -> t.MutableMapping[str, t.Any]: file_study = storage.get_raw(study) file_study.tree = Mock(spec=FileStudyTree, get=tree_get) - # Given the following arguments, the validation shouldn't raise any exception + # Given the following arguments manager = STStorageManager(study_storage_service) # Run the method being tested and expect an exception @@ -771,6 +755,8 @@ def tree_get(url: t.Sequence[str], **_: t.Any) -> t.MutableMapping[str, t.Any]: ) as ctx: manager.validate_matrices(study, area_id="West", storage_id="storage1") error = ctx.value.errors()[0] - assert error["loc"] == ("__root__",) - assert "lower_rule_curve" in error["msg"] - assert "upper_rule_curve" in error["msg"] + assert error["type"] == "value_error" + assert ( + error["msg"] + == "Value error, Each 'lower_rule_curve' value must be lower or equal to each 'upper_rule_curve'" + ) diff --git a/tests/study/business/areas/test_thermal_management.py b/tests/study/business/areas/test_thermal_management.py index dd52ec9538..03866ac640 100644 --- a/tests/study/business/areas/test_thermal_management.py +++ b/tests/study/business/areas/test_thermal_management.py @@ -132,7 +132,7 @@ def test_get_cluster__study_legacy( form = manager.get_cluster(study, area_id="north", cluster_id="2 avail and must 1") # Assert that the returned fields match the expected fields - actual = form.dict(by_alias=True) + actual = form.model_dump(by_alias=True) expected = { "id": "2 avail and must 1", "group": ThermalClusterGroup.GAS, @@ -198,7 +198,7 @@ def test_get_clusters__study_legacy( groups = manager.get_clusters(study, area_id="north") # Assert that the returned fields match the expected fields - actual = [form.dict(by_alias=True) for form in groups] + actual = [form.model_dump(by_alias=True) for form in groups] expected = [ { "id": "2 avail and must 1", @@ -354,7 +354,7 @@ def test_create_cluster__study_legacy( form = manager.create_cluster(study, area_id="north", cluster_data=cluster_data) # Assert that the returned fields match the expected fields - actual = form.dict(by_alias=True) + actual = form.model_dump(by_alias=True) expected = { "co2": 12.59, "enabled": True, @@ -414,7 +414,7 @@ def test_update_cluster( # Assert that the returned fields match the expected fields form = manager.get_cluster(study, area_id="north", cluster_id="2 avail and must 1") - actual = form.dict(by_alias=True) + actual = form.model_dump(by_alias=True) expected = { "id": "2 avail and must 1", "group": ThermalClusterGroup.GAS, diff --git a/tests/study/business/test_all_optional_metaclass.py b/tests/study/business/test_all_optional_metaclass.py index 1c379f6460..0cb179c143 100644 --- a/tests/study/business/test_all_optional_metaclass.py +++ b/tests/study/business/test_all_optional_metaclass.py @@ -1,349 +1,48 @@ -import typing as t +from pydantic import BaseModel, Field -import pytest -from pydantic import BaseModel, Field, ValidationError +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model -from antarest.study.business.all_optional_meta import AllOptionalMetaclass -# ============================================== -# Classic way to use default and optional values -# ============================================== +class Model(BaseModel): + float_with_default: float = 1 + float_without_default: float + boolean_with_default: bool = True + boolean_without_default: bool + field_with_alias: str = Field(default="default", alias="field-with-alias") -class ClassicModel(BaseModel): - mandatory: float = Field(ge=0, le=1) - mandatory_with_default: float = Field(ge=0, le=1, default=0.2) - mandatory_with_none: float = Field(ge=0, le=1, default=None) - optional: t.Optional[float] = Field(ge=0, le=1) - optional_with_default: t.Optional[float] = Field(ge=0, le=1, default=0.2) - optional_with_none: t.Optional[float] = Field(ge=0, le=1, default=None) - - -class ClassicSubModel(ClassicModel): - pass - - -class TestClassicModel: - """ - Test that default and optional values work as expected. - """ - - @pytest.mark.parametrize("cls", [ClassicModel, ClassicSubModel]) - def test_classes(self, cls: t.Type[BaseModel]) -> None: - assert cls.__fields__["mandatory"].required is True - assert cls.__fields__["mandatory"].allow_none is False - assert cls.__fields__["mandatory"].default is None - assert cls.__fields__["mandatory"].default_factory is None # undefined - - assert cls.__fields__["mandatory_with_default"].required is False - assert cls.__fields__["mandatory_with_default"].allow_none is False - assert cls.__fields__["mandatory_with_default"].default == 0.2 - assert cls.__fields__["mandatory_with_default"].default_factory is None # undefined - - assert cls.__fields__["mandatory_with_none"].required is False - assert cls.__fields__["mandatory_with_none"].allow_none is True - assert cls.__fields__["mandatory_with_none"].default is None - assert cls.__fields__["mandatory_with_none"].default_factory is None # undefined - - assert cls.__fields__["optional"].required is False - assert cls.__fields__["optional"].allow_none is True - assert cls.__fields__["optional"].default is None - assert cls.__fields__["optional"].default_factory is None # undefined - - assert cls.__fields__["optional_with_default"].required is False - assert cls.__fields__["optional_with_default"].allow_none is True - assert cls.__fields__["optional_with_default"].default == 0.2 - assert cls.__fields__["optional_with_default"].default_factory is None # undefined - - assert cls.__fields__["optional_with_none"].required is False - assert cls.__fields__["optional_with_none"].allow_none is True - assert cls.__fields__["optional_with_none"].default is None - assert cls.__fields__["optional_with_none"].default_factory is None # undefined - - @pytest.mark.parametrize("cls", [ClassicModel, ClassicSubModel]) - def test_initialization(self, cls: t.Type[ClassicModel]) -> None: - # We can build a model without providing optional or default values. - # The initialized value will be the default value or `None` for optional values. - obj = cls(mandatory=0.5) - assert obj.mandatory == 0.5 - assert obj.mandatory_with_default == 0.2 - assert obj.mandatory_with_none is None - assert obj.optional is None - assert obj.optional_with_default == 0.2 - assert obj.optional_with_none is None - - # We must provide a value for mandatory fields. - with pytest.raises(ValidationError): - cls() - - @pytest.mark.parametrize("cls", [ClassicModel, ClassicSubModel]) - def test_validation(self, cls: t.Type[ClassicModel]) -> None: - # We CANNOT use `None` as a value for a field with a default value. - with pytest.raises(ValidationError): - cls(mandatory=0.5, mandatory_with_default=None) - - # We can use `None` as a value for optional fields with default value. - cls(mandatory=0.5, optional_with_default=None) - - # We can validate a model with valid values. - cls( - mandatory=0.5, - mandatory_with_default=0.2, - mandatory_with_none=0.2, - optional=0.5, - optional_with_default=0.2, - optional_with_none=0.2, - ) - - # We CANNOT validate a model with invalid values. - with pytest.raises(ValidationError): - cls(mandatory=2) - - with pytest.raises(ValidationError): - cls(mandatory=0.5, mandatory_with_default=2) - - with pytest.raises(ValidationError): - cls(mandatory=0.5, mandatory_with_none=2) - - with pytest.raises(ValidationError): - cls(mandatory=0.5, optional=2) - - with pytest.raises(ValidationError): - cls(mandatory=0.5, optional_with_default=2) - - with pytest.raises(ValidationError): - cls(mandatory=0.5, optional_with_none=2) - - -# ========================== -# Using AllOptionalMetaclass -# ========================== - - -class AllOptionalModel(BaseModel, metaclass=AllOptionalMetaclass): - mandatory: float = Field(ge=0, le=1) - mandatory_with_default: float = Field(ge=0, le=1, default=0.2) - mandatory_with_none: float = Field(ge=0, le=1, default=None) - optional: t.Optional[float] = Field(ge=0, le=1) - optional_with_default: t.Optional[float] = Field(ge=0, le=1, default=0.2) - optional_with_none: t.Optional[float] = Field(ge=0, le=1, default=None) - - -class AllOptionalSubModel(AllOptionalModel): +@all_optional_model +class OptionalModel(Model): pass -class InheritedAllOptionalModel(ClassicModel, metaclass=AllOptionalMetaclass): +@all_optional_model +@camel_case_model +class OptionalCamelCaseModel(Model): pass -class TestAllOptionalModel: - """ - Test that AllOptionalMetaclass works with base classes. - """ - - @pytest.mark.parametrize("cls", [AllOptionalModel, AllOptionalSubModel, InheritedAllOptionalModel]) - def test_classes(self, cls: t.Type[BaseModel]) -> None: - assert cls.__fields__["mandatory"].required is False - assert cls.__fields__["mandatory"].allow_none is True - assert cls.__fields__["mandatory"].default is None - assert cls.__fields__["mandatory"].default_factory is None # undefined - - assert cls.__fields__["mandatory_with_default"].required is False - assert cls.__fields__["mandatory_with_default"].allow_none is True - assert cls.__fields__["mandatory_with_default"].default == 0.2 - assert cls.__fields__["mandatory_with_default"].default_factory is None # undefined - - assert cls.__fields__["mandatory_with_none"].required is False - assert cls.__fields__["mandatory_with_none"].allow_none is True - assert cls.__fields__["mandatory_with_none"].default is None - assert cls.__fields__["mandatory_with_none"].default_factory is None # undefined - - assert cls.__fields__["optional"].required is False - assert cls.__fields__["optional"].allow_none is True - assert cls.__fields__["optional"].default is None - assert cls.__fields__["optional"].default_factory is None # undefined - - assert cls.__fields__["optional_with_default"].required is False - assert cls.__fields__["optional_with_default"].allow_none is True - assert cls.__fields__["optional_with_default"].default == 0.2 - assert cls.__fields__["optional_with_default"].default_factory is None # undefined - - assert cls.__fields__["optional_with_none"].required is False - assert cls.__fields__["optional_with_none"].allow_none is True - assert cls.__fields__["optional_with_none"].default is None - assert cls.__fields__["optional_with_none"].default_factory is None # undefined - - @pytest.mark.parametrize("cls", [AllOptionalModel, AllOptionalSubModel, InheritedAllOptionalModel]) - def test_initialization(self, cls: t.Type[AllOptionalModel]) -> None: - # We can build a model without providing values. - # The initialized value will be the default value or `None` for optional values. - # Note that the mandatory fields are not required anymore, and can be `None`. - obj = cls() - assert obj.mandatory is None - assert obj.mandatory_with_default == 0.2 - assert obj.mandatory_with_none is None - assert obj.optional is None - assert obj.optional_with_default == 0.2 - assert obj.optional_with_none is None - - # If we convert the model to a dictionary, without `None` values, - # we should have a dictionary with default values only. - actual = obj.dict(exclude_none=True) - expected = { - "mandatory_with_default": 0.2, - "optional_with_default": 0.2, - } - assert actual == expected - - @pytest.mark.parametrize("cls", [AllOptionalModel, AllOptionalSubModel, InheritedAllOptionalModel]) - def test_validation(self, cls: t.Type[AllOptionalModel]) -> None: - # We can use `None` as a value for all fields. - cls(mandatory=None) - cls(mandatory_with_default=None) - cls(mandatory_with_none=None) - cls(optional=None) - cls(optional_with_default=None) - cls(optional_with_none=None) - - # We can validate a model with valid values. - cls( - mandatory=0.5, - mandatory_with_default=0.2, - mandatory_with_none=0.2, - optional=0.5, - optional_with_default=0.2, - optional_with_none=0.2, - ) - - # We CANNOT validate a model with invalid values. - with pytest.raises(ValidationError): - cls(mandatory=2) - - with pytest.raises(ValidationError): - cls(mandatory_with_default=2) - - with pytest.raises(ValidationError): - cls(mandatory_with_none=2) - - with pytest.raises(ValidationError): - cls(optional=2) - - with pytest.raises(ValidationError): - cls(optional_with_default=2) - - with pytest.raises(ValidationError): - cls(optional_with_none=2) - - -# The `use_none` keyword argument is set to `True` to allow the use of `None` -# as a default value for the fields of the model. - - -class UseNoneModel(BaseModel, metaclass=AllOptionalMetaclass, use_none=True): - mandatory: float = Field(ge=0, le=1) - mandatory_with_default: float = Field(ge=0, le=1, default=0.2) - mandatory_with_none: float = Field(ge=0, le=1, default=None) - optional: t.Optional[float] = Field(ge=0, le=1) - optional_with_default: t.Optional[float] = Field(ge=0, le=1, default=0.2) - optional_with_none: t.Optional[float] = Field(ge=0, le=1, default=None) - - -class UseNoneSubModel(UseNoneModel): - pass - - -class InheritedUseNoneModel(ClassicModel, metaclass=AllOptionalMetaclass, use_none=True): - pass - - -class TestUseNoneModel: - @pytest.mark.parametrize("cls", [UseNoneModel, UseNoneSubModel, InheritedUseNoneModel]) - def test_classes(self, cls: t.Type[BaseModel]) -> None: - assert cls.__fields__["mandatory"].required is False - assert cls.__fields__["mandatory"].allow_none is True - assert cls.__fields__["mandatory"].default is None - assert cls.__fields__["mandatory"].default_factory is None # undefined - - assert cls.__fields__["mandatory_with_default"].required is False - assert cls.__fields__["mandatory_with_default"].allow_none is True - assert cls.__fields__["mandatory_with_default"].default is None - assert cls.__fields__["mandatory_with_default"].default_factory is None # undefined - - assert cls.__fields__["mandatory_with_none"].required is False - assert cls.__fields__["mandatory_with_none"].allow_none is True - assert cls.__fields__["mandatory_with_none"].default is None - assert cls.__fields__["mandatory_with_none"].default_factory is None # undefined - - assert cls.__fields__["optional"].required is False - assert cls.__fields__["optional"].allow_none is True - assert cls.__fields__["optional"].default is None - assert cls.__fields__["optional"].default_factory is None # undefined - - assert cls.__fields__["optional_with_default"].required is False - assert cls.__fields__["optional_with_default"].allow_none is True - assert cls.__fields__["optional_with_default"].default is None - assert cls.__fields__["optional_with_default"].default_factory is None # undefined - - assert cls.__fields__["optional_with_none"].required is False - assert cls.__fields__["optional_with_none"].allow_none is True - assert cls.__fields__["optional_with_none"].default is None - assert cls.__fields__["optional_with_none"].default_factory is None # undefined - - @pytest.mark.parametrize("cls", [UseNoneModel, UseNoneSubModel, InheritedUseNoneModel]) - def test_initialization(self, cls: t.Type[UseNoneModel]) -> None: - # We can build a model without providing values. - # The initialized value will be the default value or `None` for optional values. - # Note that the mandatory fields are not required anymore, and can be `None`. - obj = cls() - assert obj.mandatory is None - assert obj.mandatory_with_default is None - assert obj.mandatory_with_none is None - assert obj.optional is None - assert obj.optional_with_default is None - assert obj.optional_with_none is None - - # If we convert the model to a dictionary, without `None` values, - # we should have an empty dictionary. - actual = obj.dict(exclude_none=True) - expected = {} - assert actual == expected - - @pytest.mark.parametrize("cls", [UseNoneModel, UseNoneSubModel, InheritedUseNoneModel]) - def test_validation(self, cls: t.Type[UseNoneModel]) -> None: - # We can use `None` as a value for all fields. - cls(mandatory=None) - cls(mandatory_with_default=None) - cls(mandatory_with_none=None) - cls(optional=None) - cls(optional_with_default=None) - cls(optional_with_none=None) - - # We can validate a model with valid values. - cls( - mandatory=0.5, - mandatory_with_default=0.2, - mandatory_with_none=0.2, - optional=0.5, - optional_with_default=0.2, - optional_with_none=0.2, - ) - - # We CANNOT validate a model with invalid values. - with pytest.raises(ValidationError): - cls(mandatory=2) - - with pytest.raises(ValidationError): - cls(mandatory_with_default=2) - - with pytest.raises(ValidationError): - cls(mandatory_with_none=2) - - with pytest.raises(ValidationError): - cls(optional=2) - - with pytest.raises(ValidationError): - cls(optional_with_default=2) - - with pytest.raises(ValidationError): - cls(optional_with_none=2) +def test_model() -> None: + optional_model = OptionalModel() + assert optional_model.float_with_default is None + assert optional_model.float_without_default is None + assert optional_model.boolean_with_default is None + assert optional_model.boolean_without_default is None + assert optional_model.field_with_alias is None + + optional_model = OptionalModel(boolean_with_default=False) + assert optional_model.float_with_default is None + assert optional_model.float_without_default is None + assert optional_model.boolean_with_default is False + assert optional_model.boolean_without_default is None + assert optional_model.field_with_alias is None + + # build with alias should succeed + args = {"field-with-alias": "test"} + optional_model = OptionalModel(**args) + assert optional_model.field_with_alias == "test" + + # build with camel_case should succeed + args = {"fieldWithAlias": "test"} + camel_case_model = OptionalCamelCaseModel(**args) + assert camel_case_model.field_with_alias == "test" diff --git a/tests/study/business/test_allocation_manager.py b/tests/study/business/test_allocation_manager.py index 82f49b2ec4..cfc889fed9 100644 --- a/tests/study/business/test_allocation_manager.py +++ b/tests/study/business/test_allocation_manager.py @@ -35,7 +35,7 @@ def test_base(self): def test_camel_case(self): field = AllocationField(areaId="NORTH", coefficient=1) - assert field.dict(by_alias=True) == { + assert field.model_dump(by_alias=True) == { "areaId": "NORTH", "coefficient": 1, } diff --git a/tests/study/storage/variantstudy/model/test_dbmodel.py b/tests/study/storage/variantstudy/model/test_dbmodel.py index 6ed1bbcba1..e541d6b8ab 100644 --- a/tests/study/storage/variantstudy/model/test_dbmodel.py +++ b/tests/study/storage/variantstudy/model/test_dbmodel.py @@ -152,7 +152,7 @@ def test_init(self, db_session: Session, variant_study_id: str) -> None: # check CommandBlock.to_dto() dto = obj.to_dto() # note: it is easier to compare the dict representation of the DTO - assert dto.dict() == { + assert dto.model_dump() == { "id": command_id, "action": command, "args": json.loads(args), diff --git a/tests/study/storage/variantstudy/test_snapshot_generator.py b/tests/study/storage/variantstudy/test_snapshot_generator.py index e9de3da131..de069ab462 100644 --- a/tests/study/storage/variantstudy/test_snapshot_generator.py +++ b/tests/study/storage/variantstudy/test_snapshot_generator.py @@ -851,7 +851,7 @@ def test_generate__nominal_case( assert len(db_recorder.sql_statements) == 5, str(db_recorder) # Check: the variant generation must succeed. - assert results.dict() == { + assert results.model_dump() == { "success": True, "details": [ { @@ -1036,7 +1036,7 @@ def test_generate__with_denormalize_true( ) # Check the results - assert results.dict() == { + assert results.model_dump() == { "success": True, "details": [ { @@ -1159,7 +1159,7 @@ def test_generate__notification_failure( ) # Check the results - assert results.dict() == { + assert results.model_dump() == { "success": True, "details": [ { @@ -1241,7 +1241,7 @@ def test_generate__variant_of_variant( ) # Check the results - assert results.dict() == { + assert results.model_dump() == { "success": True, "details": [ { diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index f7314cdaaa..cc6fba806f 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -905,9 +905,9 @@ def test_get_all__non_admin_permissions_filter( user_2 = User(id=102, name="user2") user_3 = User(id=103, name="user3") - group_1 = Group(id=101, name="group1") - group_2 = Group(id=102, name="group2") - group_3 = Group(id=103, name="group3") + group_1 = Group(id="101", name="group1") + group_2 = Group(id="102", name="group2") + group_3 = Group(id="103", name="group3") user_groups_mapping = {101: [group_2.id], 102: [group_1.id], 103: []} @@ -1179,23 +1179,23 @@ def test_update_tags( (None, [], False, ["5", "6"]), (None, [], True, ["1", "2", "3", "4", "7", "8"]), (None, [], None, ["1", "2", "3", "4", "5", "6", "7", "8"]), - (None, [1, 3, 5, 7], False, ["5"]), - (None, [1, 3, 5, 7], True, ["1", "3", "7"]), - (None, [1, 3, 5, 7], None, ["1", "3", "5", "7"]), + (None, ["1", "3", "5", "7"], False, ["5"]), + (None, ["1", "3", "5", "7"], True, ["1", "3", "7"]), + (None, ["1", "3", "5", "7"], None, ["1", "3", "5", "7"]), (True, [], False, ["5"]), (True, [], True, ["1", "2", "3", "4", "8"]), (True, [], None, ["1", "2", "3", "4", "5", "8"]), - (True, [1, 3, 5, 7], False, ["5"]), - (True, [1, 3, 5, 7], True, ["1", "3"]), - (True, [1, 3, 5, 7], None, ["1", "3", "5"]), - (True, [2, 4, 6, 8], True, ["2", "4", "8"]), - (True, [2, 4, 6, 8], None, ["2", "4", "8"]), + (True, ["1", "3", "5", "7"], False, ["5"]), + (True, ["1", "3", "5", "7"], True, ["1", "3"]), + (True, ["1", "3", "5", "7"], None, ["1", "3", "5"]), + (True, ["2", "4", "6", "8"], True, ["2", "4", "8"]), + (True, ["2", "4", "6", "8"], None, ["2", "4", "8"]), (False, [], False, ["6"]), (False, [], True, ["7"]), (False, [], None, ["6", "7"]), - (False, [1, 3, 5, 7], False, []), - (False, [1, 3, 5, 7], True, ["7"]), - (False, [1, 3, 5, 7], None, ["7"]), + (False, ["1", "3", "5", "7"], False, []), + (False, ["1", "3", "5", "7"], True, ["7"]), + (False, ["1", "3", "5", "7"], None, ["7"]), ], ) def test_count_studies__general_case( @@ -1209,14 +1209,14 @@ def test_count_studies__general_case( icache: Mock = Mock(spec=ICache) repository = StudyMetadataRepository(cache_service=icache, session=db_session) - study_1 = VariantStudy(id=1, name="study-1") - study_2 = VariantStudy(id=2, name="study-2") - study_3 = VariantStudy(id=3, name="study-3") - study_4 = VariantStudy(id=4, name="study-4") - study_5 = RawStudy(id=5, name="study-5", missing=datetime.datetime.now(), workspace=DEFAULT_WORKSPACE_NAME) - study_6 = RawStudy(id=6, name="study-6", missing=datetime.datetime.now(), workspace=test_workspace) - study_7 = RawStudy(id=7, name="study-7", missing=None, workspace=test_workspace) - study_8 = RawStudy(id=8, name="study-8", missing=None, workspace=DEFAULT_WORKSPACE_NAME) + study_1 = VariantStudy(id="1", name="study-1") + study_2 = VariantStudy(id="2", name="study-2") + study_3 = VariantStudy(id="3", name="study-3") + study_4 = VariantStudy(id="4", name="study-4") + study_5 = RawStudy(id="5", name="study-5", missing=datetime.datetime.now(), workspace=DEFAULT_WORKSPACE_NAME) + study_6 = RawStudy(id="6", name="study-6", missing=datetime.datetime.now(), workspace=test_workspace) + study_7 = RawStudy(id="7", name="study-7", missing=None, workspace=test_workspace) + study_8 = RawStudy(id="8", name="study-8", missing=None, workspace=DEFAULT_WORKSPACE_NAME) db_session.add_all([study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8]) db_session.commit() diff --git a/tests/variantstudy/model/command/test_create_cluster.py b/tests/variantstudy/model/command/test_create_cluster.py index 6554bbe6c2..34da4ecd5d 100644 --- a/tests/variantstudy/model/command/test_create_cluster.py +++ b/tests/variantstudy/model/command/test_create_cluster.py @@ -16,11 +16,13 @@ from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig from antarest.study.storage.variantstudy.model.command_context import CommandContext +GEN = np.random.default_rng(1000) + class TestCreateCluster: def test_init(self, command_context: CommandContext): - prepro = np.random.rand(365, 6).tolist() - modulation = np.random.rand(8760, 4).tolist() + prepro = GEN.random((365, 6)).tolist() + modulation = GEN.random((8760, 4)).tolist() cl = CreateCluster( area_id="foo", cluster_name="Cluster1", @@ -40,7 +42,7 @@ def test_init(self, command_context: CommandContext): modulation_id = command_context.matrix_service.create(modulation) assert cl.area_id == "foo" assert cl.cluster_name == "Cluster1" - assert cl.parameters == {"group": "Nuclear", "nominalcapacity": "2400", "unitcount": "2"} + assert cl.parameters == {"group": "Nuclear", "nominalcapacity": 2400, "unitcount": 2} assert cl.prepro == f"matrix://{prepro_id}" assert cl.modulation == f"matrix://{modulation_id}" @@ -73,8 +75,8 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): "market-bid-cost": "30", } - prepro = np.random.rand(365, 6).tolist() - modulation = np.random.rand(8760, 4).tolist() + prepro = GEN.random((365, 6)).tolist() + modulation = GEN.random((8760, 4)).tolist() command = CreateCluster( area_id=area_id, cluster_name=cluster_name, @@ -135,8 +137,8 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): ) def test_to_dto(self, command_context: CommandContext): - prepro = np.random.rand(365, 6).tolist() - modulation = np.random.rand(8760, 4).tolist() + prepro = GEN.random((365, 6)).tolist() + modulation = GEN.random((8760, 4)).tolist() command = CreateCluster( area_id="foo", cluster_name="Cluster1", @@ -148,12 +150,12 @@ def test_to_dto(self, command_context: CommandContext): prepro_id = command_context.matrix_service.create(prepro) modulation_id = command_context.matrix_service.create(modulation) dto = command.to_dto() - assert dto.dict() == { + assert dto.model_dump() == { "action": "create_cluster", "args": { "area_id": "foo", "cluster_name": "Cluster1", - "parameters": {"group": "Nuclear", "nominalcapacity": "2400", "unitcount": "2"}, + "parameters": {"group": "Nuclear", "nominalcapacity": 2400, "unitcount": 2}, "prepro": prepro_id, "modulation": modulation_id, }, @@ -163,8 +165,8 @@ def test_to_dto(self, command_context: CommandContext): def test_match(command_context: CommandContext): - prepro = np.random.rand(365, 6).tolist() - modulation = np.random.rand(8760, 4).tolist() + prepro = GEN.random((365, 6)).tolist() + modulation = GEN.random((8760, 4)).tolist() base = CreateCluster( area_id="foo", cluster_name="foo", @@ -223,8 +225,8 @@ def test_revert(command_context: CommandContext): def test_create_diff(command_context: CommandContext): - prepro_a = np.random.rand(365, 6).tolist() - modulation_a = np.random.rand(8760, 4).tolist() + prepro_a = GEN.random((365, 6)).tolist() + modulation_a = GEN.random((8760, 4)).tolist() base = CreateCluster( area_id="foo", cluster_name="foo", @@ -234,8 +236,8 @@ def test_create_diff(command_context: CommandContext): command_context=command_context, ) - prepro_b = np.random.rand(365, 6).tolist() - modulation_b = np.random.rand(8760, 4).tolist() + prepro_b = GEN.random((365, 6)).tolist() + modulation_b = GEN.random((8760, 4)).tolist() other_match = CreateCluster( area_id="foo", cluster_name="foo", diff --git a/tests/variantstudy/model/command/test_create_link.py b/tests/variantstudy/model/command/test_create_link.py index eb11d65398..abe1fb0547 100644 --- a/tests/variantstudy/model/command/test_create_link.py +++ b/tests/variantstudy/model/command/test_create_link.py @@ -25,16 +25,15 @@ def test_validation(self, empty_study: FileStudy, command_context: CommandContex area1_id = transform_name_to_id(area1) area2 = "Area2" - area2_id = transform_name_to_id(area2) - CreateArea.parse_obj( + CreateArea.model_validate( { "area_name": area1, "command_context": command_context, } ).apply(empty_study) - CreateArea.parse_obj( + CreateArea.model_validate( { "area_name": area2, "command_context": command_context, @@ -42,7 +41,7 @@ def test_validation(self, empty_study: FileStudy, command_context: CommandContex ).apply(empty_study) with pytest.raises(ValidationError): - create_link_command: ICommand = CreateLink( + CreateLink( area1=area1_id, area2=area1_id, parameters={}, @@ -61,21 +60,21 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): area3 = "Area3" area3_id = transform_name_to_id(area3) - CreateArea.parse_obj( + CreateArea.model_validate( { "area_name": area1, "command_context": command_context, } ).apply(empty_study) - CreateArea.parse_obj( + CreateArea.model_validate( { "area_name": area2, "command_context": command_context, } ).apply(empty_study) - CreateArea.parse_obj( + CreateArea.model_validate( { "area_name": area3, "command_context": command_context, @@ -133,7 +132,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): # TODO:assert matrix default content : 1 column, 8760 rows, value = 1 - output = CreateLink.parse_obj( + output = CreateLink.model_validate( { "area1": area1_id, "area2": area2_id, @@ -161,7 +160,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): "filter-year-by-year": "hourly", } - create_link_command: ICommand = CreateLink.parse_obj( + create_link_command: ICommand = CreateLink.model_validate( { "area1": area3_id, "area2": area1_id, diff --git a/tests/variantstudy/model/command/test_create_renewables_cluster.py b/tests/variantstudy/model/command/test_create_renewables_cluster.py index 78e6dcf15e..b3f4f6a922 100644 --- a/tests/variantstudy/model/command/test_create_renewables_cluster.py +++ b/tests/variantstudy/model/command/test_create_renewables_cluster.py @@ -34,7 +34,7 @@ def test_init(self, command_context: CommandContext) -> None: # Check the command data assert cl.area_id == "foo" assert cl.cluster_name == "Cluster1" - assert cl.parameters == {"group": "Solar Thermal", "nominalcapacity": "2400", "unitcount": "2"} + assert cl.parameters == {"group": "Solar Thermal", "nominalcapacity": 2400, "unitcount": 2} def test_validate_cluster_name(self, command_context: CommandContext) -> None: with pytest.raises(ValidationError, match="cluster_name"): @@ -119,12 +119,12 @@ def test_to_dto(self, command_context: CommandContext) -> None: command_context=command_context, ) dto = command.to_dto() - assert dto.dict() == { + assert dto.model_dump() == { "action": "create_renewables_cluster", # "renewables" with a final "s". "args": { "area_id": "foo", "cluster_name": "Cluster1", - "parameters": {"group": "Solar Thermal", "nominalcapacity": "2400", "unitcount": "2"}, + "parameters": {"group": "Solar Thermal", "nominalcapacity": 2400, "unitcount": 2}, }, "id": None, "version": 1, diff --git a/tests/variantstudy/model/command/test_create_st_storage.py b/tests/variantstudy/model/command/test_create_st_storage.py index 7e2a7fa7b5..db4dcb5c4f 100644 --- a/tests/variantstudy/model/command/test_create_st_storage.py +++ b/tests/variantstudy/model/command/test_create_st_storage.py @@ -17,6 +17,8 @@ from antarest.study.storage.variantstudy.model.command_context import CommandContext from antarest.study.storage.variantstudy.model.model import CommandDTO +GEN = np.random.default_rng(1000) + @pytest.fixture(name="recent_study") def recent_study_fixture(empty_study: FileStudy) -> FileStudy: @@ -63,8 +65,8 @@ def recent_study_fixture(empty_study: FileStudy) -> FileStudy: class TestCreateSTStorage: # noinspection SpellCheckingInspection def test_init(self, command_context: CommandContext): - pmax_injection = np.random.rand(8760, 1) - inflows = np.random.uniform(0, 1000, size=(8760, 1)) + pmax_injection = GEN.random((8760, 1)) + inflows = GEN.uniform(0, 1000, size=(8760, 1)) cmd = CreateSTStorage( command_context=command_context, area_id="area_fr", @@ -100,17 +102,22 @@ def test_init__invalid_storage_name(self, recent_study: FileStudy, command_conte parameters=STStorageConfig(**parameters), ) # We get 2 errors because the `storage_name` is duplicated in the `parameters`: - assert ctx.value.errors() == [ - { - "loc": ("__root__",), - "msg": "Invalid name '?%$$'.", - "type": "value_error", - } - ] + assert ctx.value.error_count() == 1 + raised_error = ctx.value.errors()[0] + assert raised_error["type"] == "value_error" + assert raised_error["msg"] == "Value error, Invalid name '?%$$'." + assert raised_error["input"] == { + "efficiency": 0.94, + "group": "Battery", + "initialleveloptim": True, + "injectionnominalcapacity": 1500, + "name": "?%$$", + "reservoircapacity": 20000, + "withdrawalnominalcapacity": 1500, + } - # noinspection SpellCheckingInspection def test_init__invalid_matrix_values(self, command_context: CommandContext): - array = np.random.rand(8760, 1) # OK + array = GEN.random((8760, 1)) array[10] = 25 # BAD with pytest.raises(ValidationError) as ctx: CreateSTStorage( @@ -119,17 +126,15 @@ def test_init__invalid_matrix_values(self, command_context: CommandContext): parameters=STStorageConfig(**PARAMETERS), pmax_injection=array.tolist(), # type: ignore ) - assert ctx.value.errors() == [ - { - "loc": ("pmax_injection",), - "msg": "Matrix values should be between 0 and 1", - "type": "value_error", - } - ] + assert ctx.value.error_count() == 1 + raised_error = ctx.value.errors()[0] + assert raised_error["type"] == "value_error" + assert raised_error["msg"] == "Value error, Matrix values should be between 0 and 1" + assert "pmax_injection" in raised_error["input"] # noinspection SpellCheckingInspection def test_init__invalid_matrix_shape(self, command_context: CommandContext): - array = np.random.rand(24, 1) # BAD SHAPE + array = GEN.random((24, 1)) # BAD SHAPE with pytest.raises(ValidationError) as ctx: CreateSTStorage( command_context=command_context, @@ -137,18 +142,14 @@ def test_init__invalid_matrix_shape(self, command_context: CommandContext): parameters=STStorageConfig(**PARAMETERS), pmax_injection=array.tolist(), # type: ignore ) - assert ctx.value.errors() == [ - { - "loc": ("pmax_injection",), - "msg": "Invalid matrix shape (24, 1), expected (8760, 1)", - "type": "value_error", - } - ] - - # noinspection SpellCheckingInspection + assert ctx.value.error_count() == 1 + raised_error = ctx.value.errors()[0] + assert raised_error["type"] == "value_error" + assert raised_error["msg"] == "Value error, Invalid matrix shape (24, 1), expected (8760, 1)" + assert "pmax_injection" in raised_error["input"] def test_init__invalid_nan_value(self, command_context: CommandContext): - array = np.random.rand(8760, 1) # OK + array = GEN.random((8760, 1)) # OK array[20] = np.nan # BAD with pytest.raises(ValidationError) as ctx: CreateSTStorage( @@ -157,37 +158,25 @@ def test_init__invalid_nan_value(self, command_context: CommandContext): parameters=STStorageConfig(**PARAMETERS), pmax_injection=array.tolist(), # type: ignore ) - assert ctx.value.errors() == [ - { - "loc": ("pmax_injection",), - "msg": "Matrix values cannot contain NaN", - "type": "value_error", - } - ] - - # noinspection SpellCheckingInspection + assert ctx.value.error_count() == 1 + raised_error = ctx.value.errors()[0] + assert raised_error["type"] == "value_error" + assert raised_error["msg"] == "Value error, Matrix values cannot contain NaN" + assert "pmax_injection" in raised_error["input"] def test_init__invalid_matrix_type(self, command_context: CommandContext): - array = {"data": [1, 2, 3]} with pytest.raises(ValidationError) as ctx: CreateSTStorage( command_context=command_context, area_id="area_fr", parameters=STStorageConfig(**PARAMETERS), - pmax_injection=array, # type: ignore + pmax_injection=[1, 2, 3], ) - assert ctx.value.errors() == [ - { - "loc": ("pmax_injection",), - "msg": "value is not a valid list", - "type": "type_error.list", - }, - { - "loc": ("pmax_injection",), - "msg": "str type expected", - "type": "type_error.str", - }, - ] + assert ctx.value.error_count() == 1 + raised_error = ctx.value.errors()[0] + assert raised_error["type"] == "value_error" + assert raised_error["msg"] == "Value error, Invalid matrix shape (3,), expected (8760, 1)" + assert "pmax_injection" in raised_error["input"] def test_apply_config__invalid_version(self, empty_study: FileStudy, command_context: CommandContext): # Given an old study in version 720 @@ -293,8 +282,8 @@ def test_apply__nominal_case(self, recent_study: FileStudy, command_context: Com create_area.apply(recent_study) # Then, apply the command to create a new ST Storage - pmax_injection = np.random.rand(8760, 1) - inflows = np.random.uniform(0, 1000, size=(8760, 1)) + pmax_injection = GEN.random((8760, 1)) + inflows = GEN.uniform(0, 1000, size=(8760, 1)) cmd = CreateSTStorage( command_context=command_context, area_id=transform_name_to_id(create_area.area_name), @@ -427,8 +416,8 @@ def test_create_diff__not_equals(self, command_context: CommandContext): area_id="area_fr", parameters=STStorageConfig(**PARAMETERS), ) - upper_rule_curve = np.random.rand(8760, 1) - inflows = np.random.uniform(0, 1000, size=(8760, 1)) + upper_rule_curve = GEN.random((8760, 1)) + inflows = GEN.uniform(0, 1000, size=(8760, 1)) other = CreateSTStorage( command_context=command_context, area_id=cmd.area_id, diff --git a/tests/variantstudy/model/command/test_manage_district.py b/tests/variantstudy/model/command/test_manage_district.py index 78d30bb19f..2e4b93180b 100644 --- a/tests/variantstudy/model/command/test_manage_district.py +++ b/tests/variantstudy/model/command/test_manage_district.py @@ -14,7 +14,6 @@ def test_manage_district(empty_study: FileStudy, command_context: CommandContext): - study_path = empty_study.config.study_path area1 = "Area1" area1_id = transform_name_to_id(area1) @@ -22,23 +21,22 @@ def test_manage_district(empty_study: FileStudy, command_context: CommandContext area2_id = transform_name_to_id(area2) area3 = "Area3" - area3_id = transform_name_to_id(area3) - CreateArea.parse_obj( + CreateArea.model_validate( { "area_name": area1, "command_context": command_context, } ).apply(empty_study) - CreateArea.parse_obj( + CreateArea.model_validate( { "area_name": area2, "command_context": command_context, } ).apply(empty_study) - CreateArea.parse_obj( + CreateArea.model_validate( { "area_name": area3, "command_context": command_context, diff --git a/tests/variantstudy/model/command/test_remove_area.py b/tests/variantstudy/model/command/test_remove_area.py index 8849bffbd3..31d404166b 100644 --- a/tests/variantstudy/model/command/test_remove_area.py +++ b/tests/variantstudy/model/command/test_remove_area.py @@ -200,7 +200,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): if empty_study.config.version >= 810: default_ruleset[f"r,{area_id2},0,{renewable_id.lower()}"] = 1 if empty_study.config.version >= 870: - default_ruleset[f"bc,bd 2,0"] = 1 + default_ruleset["bc,bd 2,0"] = 1 if empty_study.config.version >= 920: default_ruleset[f"hfl,{area_id2},0"] = 1 if empty_study.config.version >= 910: diff --git a/tests/variantstudy/model/command/test_remove_link.py b/tests/variantstudy/model/command/test_remove_link.py index 2704a54013..c95147e838 100644 --- a/tests/variantstudy/model/command/test_remove_link.py +++ b/tests/variantstudy/model/command/test_remove_link.py @@ -58,7 +58,7 @@ def test_remove_link__validation(self, area1: str, area2: str, expected: t.Dict[ and that the areas are well-ordered in alphabetical order (Antares Solver convention). """ command = RemoveLink(area1=area1, area2=area2, command_context=Mock(spec=CommandContext)) - actual = command.dict(include={"area1", "area2"}) + actual = command.model_dump(include={"area1", "area2"}) assert actual == expected @staticmethod diff --git a/tests/variantstudy/model/command/test_remove_st_storage.py b/tests/variantstudy/model/command/test_remove_st_storage.py index 944f3b57b4..a3f103a13f 100644 --- a/tests/variantstudy/model/command/test_remove_st_storage.py +++ b/tests/variantstudy/model/command/test_remove_st_storage.py @@ -71,9 +71,11 @@ def test_init__invalid_storage_id(self, recent_study: FileStudy, command_context assert ctx.value.errors() == [ { "ctx": {"pattern": "[a-z0-9_(),& -]+"}, + "input": "?%$$", "loc": ("storage_id",), - "msg": 'string does not match regex "[a-z0-9_(),& -]+"', - "type": "value_error.str.regex", + "msg": "String should match pattern '[a-z0-9_(),& -]+'", + "type": "string_pattern_mismatch", + "url": "https://errors.pydantic.dev/2.8/v/string_pattern_mismatch", } ] diff --git a/tests/variantstudy/model/command/test_replace_matrix.py b/tests/variantstudy/model/command/test_replace_matrix.py index 5436f1e98d..f776b211db 100644 --- a/tests/variantstudy/model/command/test_replace_matrix.py +++ b/tests/variantstudy/model/command/test_replace_matrix.py @@ -20,7 +20,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): area1 = "Area1" area1_id = transform_name_to_id(area1) - CreateArea.parse_obj( + CreateArea.model_validate( { "area_name": area1, "command_context": command_context, @@ -28,7 +28,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): ).apply(empty_study) target_element = f"input/hydro/common/capacity/maxpower_{area1_id}" - replace_matrix = ReplaceMatrix.parse_obj( + replace_matrix = ReplaceMatrix.model_validate( { "target": target_element, "matrix": [[0]], @@ -44,7 +44,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): assert matrix_id in target_path.read_text() target_element = "fake/matrix/path" - replace_matrix = ReplaceMatrix.parse_obj( + replace_matrix = ReplaceMatrix.model_validate( { "target": target_element, "matrix": [[0]], diff --git a/tests/variantstudy/model/command/test_update_config.py b/tests/variantstudy/model/command/test_update_config.py index 999adb6c70..575f165241 100644 --- a/tests/variantstudy/model/command/test_update_config.py +++ b/tests/variantstudy/model/command/test_update_config.py @@ -20,7 +20,7 @@ def test_update_config(empty_study: FileStudy, command_context: CommandContext): area1 = "Area1" area1_id = transform_name_to_id(area1) - CreateArea.parse_obj( + CreateArea.model_validate( { "area_name": area1, "command_context": command_context, diff --git a/tests/variantstudy/model/test_variant_model.py b/tests/variantstudy/model/test_variant_model.py index 98c73b949f..89e10bd3c0 100644 --- a/tests/variantstudy/model/test_variant_model.py +++ b/tests/variantstudy/model/test_variant_model.py @@ -141,7 +141,7 @@ def test_commands_service( repository=variant_study_service.repository, ) results = generator.generate_snapshot(saved_id, jwt_user, denormalize=False) - assert results.dict() == { + assert results.model_dump() == { "success": True, "details": [ {