From b3c962ce0e981412a9be4c3120af082dbc6ce24b Mon Sep 17 00:00:00 2001 From: MartinBelthle Date: Wed, 2 Oct 2024 10:46:59 +0200 Subject: [PATCH] fix(pydantic): allow `str` fields to be populated by `int` (#2166) --- antarest/core/cache/business/local_chache.py | 5 +- antarest/core/cache/business/redis_cache.py | 5 +- antarest/core/configdata/model.py | 4 +- antarest/core/core_blueprint.py | 4 +- antarest/core/filetransfer/model.py | 6 +- antarest/core/interfaces/eventbus.py | 5 +- antarest/core/jwt.py | 7 +- antarest/core/model.py | 4 +- antarest/core/serialization/__init__.py | 15 ++ antarest/core/tasks/model.py | 14 +- antarest/core/utils/__init__.py | 2 + antarest/core/utils/utils.py | 13 -- antarest/core/version_info.py | 4 +- antarest/eventbus/web.py | 5 +- antarest/front.py | 4 +- antarest/launcher/adapters/log_parser.py | 4 +- antarest/launcher/model.py | 16 +- antarest/launcher/ssh_config.py | 6 +- antarest/login/auth.py | 5 +- antarest/login/model.py | 26 +-- antarest/login/web.py | 5 +- antarest/matrixstore/matrix_editor.py | 10 +- antarest/matrixstore/model.py | 12 +- antarest/service_creator.py | 14 +- antarest/study/business/area_management.py | 9 +- .../business/areas/st_storage_management.py | 7 +- .../business/binding_constraint_management.py | 11 +- antarest/study/business/district_manager.py | 5 +- antarest/study/business/link_management.py | 7 +- .../study/business/xpansion_management.py | 9 +- antarest/study/model.py | 39 ++--- antarest/study/repository.py | 9 +- .../model/filesystem/config/cluster.py | 6 +- .../model/filesystem/config/identifier.py | 6 +- .../rawstudy/model/filesystem/config/model.py | 15 +- .../command/create_binding_constraint.py | 7 +- .../variantstudy/model/command_context.py | 5 +- .../study/storage/variantstudy/model/model.py | 8 +- antarest/worker/archive_worker.py | 5 +- antarest/worker/simulator_worker.py | 5 +- antarest/worker/worker.py | 7 +- .../study_data_blueprint/test_thermal.py | 149 +++++------------- .../business/test_all_optional_metaclass.py | 5 +- 43 files changed, 225 insertions(+), 284 deletions(-) diff --git a/antarest/core/cache/business/local_chache.py b/antarest/core/cache/business/local_chache.py index ac2a026db5..e903ff9080 100644 --- a/antarest/core/cache/business/local_chache.py +++ b/antarest/core/cache/business/local_chache.py @@ -15,16 +15,15 @@ import time from typing import Dict, List, Optional -from pydantic import BaseModel - from antarest.core.config import CacheConfig from antarest.core.interfaces.cache import ICache from antarest.core.model import JSON +from antarest.core.serialization import AntaresBaseModel logger = logging.getLogger(__name__) -class LocalCacheElement(BaseModel): +class LocalCacheElement(AntaresBaseModel): timeout: int duration: int data: JSON diff --git a/antarest/core/cache/business/redis_cache.py b/antarest/core/cache/business/redis_cache.py index 7793f280c7..11eb3fcffd 100644 --- a/antarest/core/cache/business/redis_cache.py +++ b/antarest/core/cache/business/redis_cache.py @@ -13,17 +13,16 @@ import logging from typing import List, Optional -from pydantic import BaseModel from redis.client import Redis from antarest.core.interfaces.cache import ICache from antarest.core.model import JSON -from antarest.core.serialization import from_json +from antarest.core.serialization import AntaresBaseModel, from_json logger = logging.getLogger(__name__) -class RedisCacheElement(BaseModel): +class RedisCacheElement(AntaresBaseModel): duration: int data: JSON diff --git 
a/antarest/core/configdata/model.py b/antarest/core/configdata/model.py index 3e0d6b970e..bd243387ec 100644 --- a/antarest/core/configdata/model.py +++ b/antarest/core/configdata/model.py @@ -13,13 +13,13 @@ from enum import Enum from typing import Any, Optional -from pydantic import BaseModel from sqlalchemy import Column, Integer, String # type: ignore from antarest.core.persistence import Base +from antarest.core.serialization import AntaresBaseModel -class ConfigDataDTO(BaseModel): +class ConfigDataDTO(AntaresBaseModel): key: str value: Optional[str] diff --git a/antarest/core/core_blueprint.py b/antarest/core/core_blueprint.py index 27f6591109..d344531699 100644 --- a/antarest/core/core_blueprint.py +++ b/antarest/core/core_blueprint.py @@ -13,14 +13,14 @@ from typing import Any from fastapi import APIRouter -from pydantic import BaseModel from antarest.core.config import Config +from antarest.core.serialization import AntaresBaseModel from antarest.core.utils.web import APITag from antarest.core.version_info import VersionInfoDTO, get_commit_id, get_dependencies -class StatusDTO(BaseModel): +class StatusDTO(AntaresBaseModel): status: str diff --git a/antarest/core/filetransfer/model.py b/antarest/core/filetransfer/model.py index 72463e0bad..6ce194ed19 100644 --- a/antarest/core/filetransfer/model.py +++ b/antarest/core/filetransfer/model.py @@ -15,10 +15,10 @@ from http.client import HTTPException from typing import Optional -from pydantic import BaseModel from sqlalchemy import Boolean, Column, DateTime, Integer, String # type: ignore from antarest.core.persistence import Base +from antarest.core.serialization import AntaresBaseModel class FileDownloadNotFound(HTTPException): @@ -37,7 +37,7 @@ def __init__(self) -> None: ) -class FileDownloadDTO(BaseModel): +class FileDownloadDTO(AntaresBaseModel): id: str name: str filename: str @@ -47,7 +47,7 @@ class FileDownloadDTO(BaseModel): error_message: str = "" -class FileDownloadTaskDTO(BaseModel): +class FileDownloadTaskDTO(AntaresBaseModel): file: FileDownloadDTO task: str diff --git a/antarest/core/interfaces/eventbus.py b/antarest/core/interfaces/eventbus.py index 10965ea831..6057a3bbb3 100644 --- a/antarest/core/interfaces/eventbus.py +++ b/antarest/core/interfaces/eventbus.py @@ -14,9 +14,8 @@ from enum import Enum from typing import Any, Awaitable, Callable, List, Optional -from pydantic import BaseModel - from antarest.core.model import PermissionInfo +from antarest.core.serialization import AntaresBaseModel class EventType(str, Enum): @@ -56,7 +55,7 @@ class EventChannelDirectory: STUDY_GENERATION = "GENERATION_TASK/" -class Event(BaseModel): +class Event(AntaresBaseModel): type: EventType payload: Any permissions: PermissionInfo diff --git a/antarest/core/jwt.py b/antarest/core/jwt.py index b42cc3273b..aa8323abb8 100644 --- a/antarest/core/jwt.py +++ b/antarest/core/jwt.py @@ -12,13 +12,12 @@ from typing import List, Union -from pydantic import BaseModel - from antarest.core.roles import RoleType +from antarest.core.serialization import AntaresBaseModel from antarest.login.model import ADMIN_ID, Group, Identity -class JWTGroup(BaseModel): +class JWTGroup(AntaresBaseModel): """ Sub JWT domain with groups data belongs to user """ @@ -28,7 +27,7 @@ class JWTGroup(BaseModel): role: RoleType -class JWTUser(BaseModel): +class JWTUser(AntaresBaseModel): """ JWT domain with user data. 
""" diff --git a/antarest/core/model.py b/antarest/core/model.py index b500e9a5c5..dd4ea511aa 100644 --- a/antarest/core/model.py +++ b/antarest/core/model.py @@ -13,7 +13,7 @@ import enum from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from pydantic import BaseModel +from antarest.core.serialization import AntaresBaseModel if TYPE_CHECKING: # These dependencies are only used for type checking with mypy. @@ -43,7 +43,7 @@ class StudyPermissionType(str, enum.Enum): MANAGE_PERMISSIONS = "MANAGE_PERMISSIONS" -class PermissionInfo(BaseModel): +class PermissionInfo(AntaresBaseModel): owner: Optional[int] = None groups: List[str] = [] public_mode: PublicMode = PublicMode.NONE diff --git a/antarest/core/serialization/__init__.py b/antarest/core/serialization/__init__.py index 6368c02f1e..a8616e3eae 100644 --- a/antarest/core/serialization/__init__.py +++ b/antarest/core/serialization/__init__.py @@ -32,3 +32,18 @@ def to_json(data: t.Any, indent: t.Optional[int] = None) -> bytes: def to_json_string(data: t.Any, indent: t.Optional[int] = None) -> str: return to_json(data, indent=indent).decode("utf-8") + + +class AntaresBaseModel(pydantic.BaseModel): + """ + Due to pydantic migration from v1 to v2, we can have this issue: + + class A(BaseModel): + a: str + + A(a=2) raises ValidationError as we give an int instead of a str + + To avoid this issue we created our own BaseModel class that inherits from BaseModel and allows such object creation. + """ + + model_config = pydantic.config.ConfigDict(coerce_numbers_to_str=True) diff --git a/antarest/core/tasks/model.py b/antarest/core/tasks/model.py index a7ca9aedb9..9bd921fd0c 100644 --- a/antarest/core/tasks/model.py +++ b/antarest/core/tasks/model.py @@ -15,12 +15,12 @@ from datetime import datetime from enum import Enum -from pydantic import BaseModel from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore from sqlalchemy.engine.base import Engine # type: ignore from sqlalchemy.orm import relationship, sessionmaker # type: ignore from antarest.core.persistence import Base +from antarest.core.serialization import AntaresBaseModel if t.TYPE_CHECKING: # avoid circular import @@ -57,30 +57,30 @@ def is_final(self) -> bool: ] -class TaskResult(BaseModel, extra="forbid"): +class TaskResult(AntaresBaseModel, extra="forbid"): success: bool message: str # Can be used to store json serialized result return_value: t.Optional[str] = None -class TaskLogDTO(BaseModel, extra="forbid"): +class TaskLogDTO(AntaresBaseModel, extra="forbid"): id: str message: str -class CustomTaskEventMessages(BaseModel, extra="forbid"): +class CustomTaskEventMessages(AntaresBaseModel, extra="forbid"): start: str running: str end: str -class TaskEventPayload(BaseModel, extra="forbid"): +class TaskEventPayload(AntaresBaseModel, extra="forbid"): id: str message: str -class TaskDTO(BaseModel, extra="forbid"): +class TaskDTO(AntaresBaseModel, extra="forbid"): id: str name: str owner: t.Optional[int] = None @@ -93,7 +93,7 @@ class TaskDTO(BaseModel, extra="forbid"): ref_id: t.Optional[str] = None -class TaskListFilter(BaseModel, extra="forbid"): +class TaskListFilter(AntaresBaseModel, extra="forbid"): status: t.List[TaskStatus] = [] name: t.Optional[str] = None type: t.List[TaskType] = [] diff --git a/antarest/core/utils/__init__.py b/antarest/core/utils/__init__.py index 058c6b221a..d1e93fb6a8 100644 --- a/antarest/core/utils/__init__.py +++ b/antarest/core/utils/__init__.py @@ -9,3 +9,5 @@ # SPDX-License-Identifier: MPL-2.0 # 
# This file is part of the Antares project. + +__all__ = "AntaresBaseModel" diff --git a/antarest/core/utils/utils.py b/antarest/core/utils/utils.py index 89002edcfe..63576fc1a4 100644 --- a/antarest/core/utils/utils.py +++ b/antarest/core/utils/utils.py @@ -24,10 +24,8 @@ from pathlib import Path import py7zr -import redis from fastapi import HTTPException -from antarest.core.config import RedisConfig from antarest.core.exceptions import ShouldNotHappenException logger = logging.getLogger(__name__) @@ -131,17 +129,6 @@ def get_local_path() -> Path: return filepath -def new_redis_instance(config: RedisConfig) -> redis.Redis: # type: ignore - redis_client = redis.Redis( - host=config.host, - port=config.port, - password=config.password, - db=0, - retry_on_error=[redis.ConnectionError, redis.TimeoutError], # type: ignore - ) - return redis_client # type: ignore - - class StopWatch: def __init__(self) -> None: self.current_time: float = time.time() diff --git a/antarest/core/version_info.py b/antarest/core/version_info.py index ea96eaf4ef..ec20a767c9 100644 --- a/antarest/core/version_info.py +++ b/antarest/core/version_info.py @@ -18,10 +18,10 @@ from pathlib import Path from typing import Dict -from pydantic import BaseModel +from antarest.core.serialization import AntaresBaseModel -class VersionInfoDTO(BaseModel): +class VersionInfoDTO(AntaresBaseModel): name: str = "AntaREST" version: str gitcommit: str diff --git a/antarest/eventbus/web.py b/antarest/eventbus/web.py index a04a4571bb..7f7050a793 100644 --- a/antarest/eventbus/web.py +++ b/antarest/eventbus/web.py @@ -17,7 +17,6 @@ from typing import List, Optional from fastapi import Depends, HTTPException, Query -from pydantic import BaseModel from starlette.websockets import WebSocket, WebSocketDisconnect from antarest.core.application import AppBuildContext @@ -26,7 +25,7 @@ from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTUser from antarest.core.model import PermissionInfo, StudyPermissionType from antarest.core.permissions import check_permission -from antarest.core.serialization import to_json_string +from antarest.core.serialization import AntaresBaseModel, to_json_string from antarest.fastapi_jwt_auth import AuthJWT from antarest.login.auth import Auth @@ -38,7 +37,7 @@ class WebsocketMessageAction(str, Enum): UNSUBSCRIBE = "UNSUBSCRIBE" -class WebsocketMessage(BaseModel): +class WebsocketMessage(AntaresBaseModel): action: WebsocketMessageAction payload: str diff --git a/antarest/front.py b/antarest/front.py index 8de0f05e82..a0699812bf 100644 --- a/antarest/front.py +++ b/antarest/front.py @@ -25,13 +25,13 @@ from typing import Any, Optional, Sequence from fastapi import FastAPI -from pydantic import BaseModel from starlette.middleware.base import BaseHTTPMiddleware, DispatchFunction, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import FileResponse from starlette.staticfiles import StaticFiles from starlette.types import ASGIApp +from antarest.core.serialization import AntaresBaseModel from antarest.core.utils.string import to_camel_case @@ -77,7 +77,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - return await call_next(request) -class BackEndConfig(BaseModel): +class BackEndConfig(AntaresBaseModel): """ Configuration about backend URLs served to the frontend. 
""" diff --git a/antarest/launcher/adapters/log_parser.py b/antarest/launcher/adapters/log_parser.py index efd73d1b70..16cadd74ea 100644 --- a/antarest/launcher/adapters/log_parser.py +++ b/antarest/launcher/adapters/log_parser.py @@ -14,7 +14,7 @@ import re import typing as t -from pydantic import BaseModel +from antarest.core.serialization import AntaresBaseModel _SearchFunc = t.Callable[[str], t.Optional[t.Match[str]]] @@ -63,7 +63,7 @@ ) -class LaunchProgressDTO(BaseModel): +class LaunchProgressDTO(AntaresBaseModel): """ Measure the progress of a study simulation. diff --git a/antarest/launcher/model.py b/antarest/launcher/model.py index d80400a4ba..d053a55c85 100644 --- a/antarest/launcher/model.py +++ b/antarest/launcher/model.py @@ -14,23 +14,23 @@ import typing as t from datetime import datetime -from pydantic import BaseModel, Field +from pydantic import Field from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, Sequence, String # type: ignore from sqlalchemy.orm import relationship # type: ignore from antarest.core.persistence import Base -from antarest.core.serialization import from_json +from antarest.core.serialization import AntaresBaseModel, from_json from antarest.login.model import Identity, UserInfo from antarest.study.business.all_optional_meta import camel_case_model -class XpansionParametersDTO(BaseModel): +class XpansionParametersDTO(AntaresBaseModel): output_id: t.Optional[str] = None sensitivity_mode: bool = False enabled: bool = True -class LauncherParametersDTO(BaseModel): +class LauncherParametersDTO(AntaresBaseModel): # Warning ! This class must be retro-compatible (that's the reason for the weird bool/XpansionParametersDTO union) # The reason is that it's stored in json format in database and deserialized using the latest class version # If compatibility is to be broken, an (alembic) data migration script should be added @@ -91,7 +91,7 @@ class JobLogType(str, enum.Enum): AFTER = "AFTER" -class JobResultDTO(BaseModel): +class JobResultDTO(AntaresBaseModel): """ A data transfer object (DTO) representing the job result. @@ -232,16 +232,16 @@ def __repr__(self) -> str: ) -class JobCreationDTO(BaseModel): +class JobCreationDTO(AntaresBaseModel): job_id: str -class LauncherEnginesDTO(BaseModel): +class LauncherEnginesDTO(AntaresBaseModel): engines: t.List[str] @camel_case_model -class LauncherLoadDTO(BaseModel, extra="forbid", validate_assignment=True, populate_by_name=True): +class LauncherLoadDTO(AntaresBaseModel, extra="forbid", validate_assignment=True, populate_by_name=True): """ DTO representing the load of the SLURM cluster or local machine. 
diff --git a/antarest/launcher/ssh_config.py b/antarest/launcher/ssh_config.py index 5238e07608..7d4524d04d 100644 --- a/antarest/launcher/ssh_config.py +++ b/antarest/launcher/ssh_config.py @@ -14,10 +14,12 @@ from typing import Any, Dict, Optional import paramiko -from pydantic import BaseModel, model_validator +from pydantic import model_validator +from antarest.core.serialization import AntaresBaseModel -class SSHConfigDTO(BaseModel): + +class SSHConfigDTO(AntaresBaseModel): config_path: pathlib.Path username: str hostname: str diff --git a/antarest/login/auth.py b/antarest/login/auth.py index f86ca5903b..e0227a51f5 100644 --- a/antarest/login/auth.py +++ b/antarest/login/auth.py @@ -15,13 +15,12 @@ from typing import Any, Callable, Coroutine, Dict, Optional, Tuple, Union from fastapi import Depends -from pydantic import BaseModel from ratelimit.types import Scope # type: ignore from starlette.requests import Request from antarest.core.config import Config from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTUser -from antarest.core.serialization import from_json +from antarest.core.serialization import AntaresBaseModel, from_json from antarest.fastapi_jwt_auth import AuthJWT logger = logging.getLogger(__name__) @@ -79,7 +78,7 @@ def get_user_from_token(token: str, jwt_manager: AuthJWT) -> Optional[JWTUser]: return None -class JwtSettings(BaseModel): +class JwtSettings(AntaresBaseModel): authjwt_secret_key: str authjwt_token_location: Tuple[str, ...] authjwt_access_token_expires: Union[int, timedelta] = Auth.ACCESS_TOKEN_DURATION diff --git a/antarest/login/model.py b/antarest/login/model.py index 4f85763c9b..11a3bef802 100644 --- a/antarest/login/model.py +++ b/antarest/login/model.py @@ -15,7 +15,6 @@ import uuid import bcrypt -from pydantic.main import BaseModel from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, Sequence, String # type: ignore from sqlalchemy.engine.base import Engine # type: ignore from sqlalchemy.exc import IntegrityError # type: ignore @@ -24,6 +23,7 @@ from antarest.core.persistence import Base from antarest.core.roles import RoleType +from antarest.core.serialization import AntaresBaseModel if t.TYPE_CHECKING: # avoid circular import @@ -44,58 +44,58 @@ """Name of the site administrator.""" -class UserInfo(BaseModel): +class UserInfo(AntaresBaseModel): id: int name: str -class BotRoleCreateDTO(BaseModel): +class BotRoleCreateDTO(AntaresBaseModel): group: str role: int -class BotCreateDTO(BaseModel): +class BotCreateDTO(AntaresBaseModel): name: str roles: t.List[BotRoleCreateDTO] is_author: bool = True -class UserCreateDTO(BaseModel): +class UserCreateDTO(AntaresBaseModel): name: str password: str -class GroupDTO(BaseModel): +class GroupDTO(AntaresBaseModel): id: t.Optional[str] = None name: str -class RoleCreationDTO(BaseModel): +class RoleCreationDTO(AntaresBaseModel): type: RoleType group_id: str identity_id: int -class RoleDTO(BaseModel): +class RoleDTO(AntaresBaseModel): group_id: t.Optional[str] group_name: str identity_id: int type: RoleType -class IdentityDTO(BaseModel): +class IdentityDTO(AntaresBaseModel): id: int name: str roles: t.List[RoleDTO] -class RoleDetailDTO(BaseModel): +class RoleDetailDTO(AntaresBaseModel): group: GroupDTO identity: UserInfo type: RoleType -class BotIdentityDTO(BaseModel): +class BotIdentityDTO(AntaresBaseModel): id: int name: str isAuthor: bool @@ -107,7 +107,7 @@ class BotDTO(UserInfo): is_author: bool -class UserRoleDTO(BaseModel): +class UserRoleDTO(AntaresBaseModel): id: int name: str role: RoleType @@ 
-311,7 +311,7 @@ def to_dto(self) -> RoleDetailDTO: ) -class CredentialsDTO(BaseModel): +class CredentialsDTO(AntaresBaseModel): user: int access_token: str refresh_token: str diff --git a/antarest/login/web.py b/antarest/login/web.py index 5bc85c62a1..6f3968d6a9 100644 --- a/antarest/login/web.py +++ b/antarest/login/web.py @@ -16,13 +16,12 @@ from fastapi import APIRouter, Depends, HTTPException from markupsafe import escape -from pydantic import BaseModel from antarest.core.config import Config from antarest.core.jwt import JWTGroup, JWTUser from antarest.core.requests import RequestParameters, UserHasNotPermissionError from antarest.core.roles import RoleType -from antarest.core.serialization import from_json +from antarest.core.serialization import AntaresBaseModel, from_json from antarest.core.utils.web import APITag from antarest.fastapi_jwt_auth import AuthJWT from antarest.login.auth import Auth @@ -46,7 +45,7 @@ logger = logging.getLogger(__name__) -class UserCredentials(BaseModel): +class UserCredentials(AntaresBaseModel): username: str password: str diff --git a/antarest/matrixstore/matrix_editor.py b/antarest/matrixstore/matrix_editor.py index 838af83860..89bb866336 100644 --- a/antarest/matrixstore/matrix_editor.py +++ b/antarest/matrixstore/matrix_editor.py @@ -14,10 +14,12 @@ import operator from typing import Any, Dict, List, Optional, Tuple -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import Field, field_validator, model_validator +from antarest.core.serialization import AntaresBaseModel -class MatrixSlice(BaseModel): + +class MatrixSlice(AntaresBaseModel): # NOTE: This Markdown documentation is reflected in the Swagger API """ Represents a group of cells in a matrix for updating. @@ -97,7 +99,7 @@ def check_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: @functools.total_ordering -class Operation(BaseModel): +class Operation(AntaresBaseModel): # NOTE: This Markdown documentation is reflected in the Swagger API """ Represents an update operation to be performed on matrix cells. @@ -140,7 +142,7 @@ def __le__(self, other: Any) -> bool: return NotImplemented # pragma: no cover -class MatrixEditInstruction(BaseModel): +class MatrixEditInstruction(AntaresBaseModel): # NOTE: This Markdown documentation is reflected in the Swagger API """ Provides edit instructions to be applied to a matrix. 
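Because DTOs such as CredentialsDTO and UserCredentials mix int and str fields, note that coerce_numbers_to_str only widens number-to-string coercion; the other lax-mode rules are unchanged. A small sketch with an illustrative model name, not part of the patch:

    import pydantic

    class Probe(pydantic.BaseModel):
        model_config = pydantic.config.ConfigDict(coerce_numbers_to_str=True)
        name: str
        count: int

    p = Probe(name=42, count="7")       # 42 -> "42"; "7" -> 7 (normal lax coercion)
    assert (p.name, p.count) == ("42", 7)

    try:
        Probe(name="x", count="seven")  # a non-numeric string for an int field still fails
    except pydantic.ValidationError:
        pass

Floats and Decimals passed to str fields are coerced to their string form in the same way as ints.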
diff --git a/antarest/matrixstore/model.py b/antarest/matrixstore/model.py index 244cafadca..2dccd350ab 100644 --- a/antarest/matrixstore/model.py +++ b/antarest/matrixstore/model.py @@ -14,11 +14,11 @@ import typing as t import uuid -from pydantic import BaseModel from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Table # type: ignore from sqlalchemy.orm import relationship # type: ignore from antarest.core.persistence import Base +from antarest.core.serialization import AntaresBaseModel from antarest.login.model import GroupDTO, Identity, UserInfo @@ -58,12 +58,12 @@ def __eq__(self, other: t.Any) -> bool: return res -class MatrixInfoDTO(BaseModel): +class MatrixInfoDTO(AntaresBaseModel): id: str name: str -class MatrixDataSetDTO(BaseModel): +class MatrixDataSetDTO(AntaresBaseModel): id: str name: str matrices: t.List[MatrixInfoDTO] @@ -209,7 +209,7 @@ def __eq__(self, other: t.Any) -> bool: MatrixData = float -class MatrixDTO(BaseModel): +class MatrixDTO(AntaresBaseModel): width: int height: int index: t.List[str] @@ -219,7 +219,7 @@ class MatrixDTO(BaseModel): id: str = "" -class MatrixContent(BaseModel): +class MatrixContent(AntaresBaseModel): """ Matrix content (Data Frame array) @@ -234,7 +234,7 @@ class MatrixContent(BaseModel): columns: t.List[t.Union[int, str]] -class MatrixDataSetUpdateDTO(BaseModel): +class MatrixDataSetUpdateDTO(AntaresBaseModel): name: str groups: t.List[str] public: bool diff --git a/antarest/service_creator.py b/antarest/service_creator.py index 5859806a61..b942418b2c 100644 --- a/antarest/service_creator.py +++ b/antarest/service_creator.py @@ -25,7 +25,7 @@ from antarest.core.application import AppBuildContext from antarest.core.cache.main import build_cache -from antarest.core.config import Config +from antarest.core.config import Config, RedisConfig from antarest.core.filetransfer.main import build_filetransfer_service from antarest.core.filetransfer.service import FileTransferManager from antarest.core.interfaces.cache import ICache @@ -34,7 +34,6 @@ from antarest.core.persistence import upgrade_db from antarest.core.tasks.main import build_taskjob_manager from antarest.core.tasks.service import ITaskService -from antarest.core.utils.utils import new_redis_instance from antarest.eventbus.main import build_eventbus from antarest.launcher.main import build_launcher from antarest.login.main import build_login @@ -109,6 +108,17 @@ def init_db_engine( return engine +def new_redis_instance(config: RedisConfig) -> redis.Redis: # type: ignore + redis_client = redis.Redis( + host=config.host, + port=config.port, + password=config.password, + db=0, + retry_on_error=[redis.ConnectionError, redis.TimeoutError], # type: ignore + ) + return redis_client # type: ignore + + def create_event_bus(app_ctxt: t.Optional[AppBuildContext], config: Config) -> t.Tuple[IEventBus, t.Optional[redis.Redis]]: # type: ignore redis_client = new_redis_instance(config.redis) if config.redis is not None else None return ( diff --git a/antarest/study/business/area_management.py b/antarest/study/business/area_management.py index 5682a6fa60..992cc31984 100644 --- a/antarest/study/business/area_management.py +++ b/antarest/study/business/area_management.py @@ -15,10 +15,11 @@ import re import typing as t -from pydantic import BaseModel, Field +from pydantic import Field from antarest.core.exceptions import ConfigFileNotFound, DuplicateAreaName, LayerNotAllowedToBeDeleted, LayerNotFound from antarest.core.model import JSON +from antarest.core.serialization import 
AntaresBaseModel from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Patch, PatchArea, PatchCluster, RawStudy, Study @@ -47,7 +48,7 @@ class AreaType(enum.Enum): DISTRICT = "DISTRICT" -class AreaCreationDTO(BaseModel): +class AreaCreationDTO(AntaresBaseModel): name: str type: AreaType metadata: t.Optional[PatchArea] = None @@ -76,13 +77,13 @@ class AreaInfoDTO(AreaCreationDTO): thermals: t.Optional[t.List[ClusterInfoDTO]] = None -class LayerInfoDTO(BaseModel): +class LayerInfoDTO(AntaresBaseModel): id: str name: str areas: t.List[str] -class UpdateAreaUi(BaseModel, extra="forbid", populate_by_name=True): +class UpdateAreaUi(AntaresBaseModel, extra="forbid", populate_by_name=True): """ DTO for updating area UI diff --git a/antarest/study/business/areas/st_storage_management.py b/antarest/study/business/areas/st_storage_management.py index dcb5db5dc8..9f57cb50c1 100644 --- a/antarest/study/business/areas/st_storage_management.py +++ b/antarest/study/business/areas/st_storage_management.py @@ -15,7 +15,7 @@ import typing as t import numpy as np -from pydantic import BaseModel, field_validator, model_validator +from pydantic import field_validator, model_validator from typing_extensions import Literal from antarest.core.exceptions import ( @@ -28,6 +28,7 @@ ) from antarest.core.model import JSON from antarest.core.requests import CaseInsensitiveDict +from antarest.core.serialization import AntaresBaseModel from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Study @@ -117,7 +118,7 @@ def json_schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: # ============= -class STStorageMatrix(BaseModel): +class STStorageMatrix(AntaresBaseModel): """ Short-Term Storage Matrix Model. @@ -157,7 +158,7 @@ def validate_time_series(cls, data: t.List[t.List[float]]) -> t.List[t.List[floa # noinspection SpellCheckingInspection -class STStorageMatrices(BaseModel): +class STStorageMatrices(AntaresBaseModel): """ Short-Term Storage Matrices Validation Model. diff --git a/antarest/study/business/binding_constraint_management.py b/antarest/study/business/binding_constraint_management.py index 4ef2a5fffc..5db14243ab 100644 --- a/antarest/study/business/binding_constraint_management.py +++ b/antarest/study/business/binding_constraint_management.py @@ -15,7 +15,7 @@ import typing as t import numpy as np -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import Field, field_validator, model_validator from antarest.core.exceptions import ( BindingConstraintNotFound, @@ -30,6 +30,7 @@ ) from antarest.core.model import JSON from antarest.core.requests import CaseInsensitiveDict +from antarest.core.serialization import AntaresBaseModel from antarest.core.utils.string import to_camel_case from antarest.study.business.all_optional_meta import camel_case_model from antarest.study.business.utils import execute_or_add_commands @@ -79,7 +80,7 @@ } -class LinkTerm(BaseModel): +class LinkTerm(AntaresBaseModel): """ DTO for a constraint term on a link between two areas. @@ -98,7 +99,7 @@ def generate_id(self) -> str: return "%".join(ids) -class ClusterTerm(BaseModel): +class ClusterTerm(AntaresBaseModel): """ DTO for a constraint term on a cluster in an area. 
@@ -117,7 +118,7 @@ def generate_id(self) -> str: return ".".join(ids) -class ConstraintTerm(BaseModel): +class ConstraintTerm(AntaresBaseModel): """ DTO for a constraint term. @@ -147,7 +148,7 @@ def generate_id(self) -> str: return self.data.generate_id() -class ConstraintFilters(BaseModel, frozen=True, extra="forbid"): +class ConstraintFilters(AntaresBaseModel, frozen=True, extra="forbid"): """ Binding Constraint Filters gathering the main filtering parameters. diff --git a/antarest/study/business/district_manager.py b/antarest/study/business/district_manager.py index c642d61531..38a760b905 100644 --- a/antarest/study/business/district_manager.py +++ b/antarest/study/business/district_manager.py @@ -12,9 +12,8 @@ from typing import List -from pydantic import BaseModel - from antarest.core.exceptions import AreaNotFound, DistrictAlreadyExist, DistrictNotFound +from antarest.core.serialization import AntaresBaseModel from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Study from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id @@ -24,7 +23,7 @@ from antarest.study.storage.variantstudy.model.command.update_district import UpdateDistrict -class DistrictUpdateDTO(BaseModel): +class DistrictUpdateDTO(AntaresBaseModel): #: Indicates whether this district is used in the output (usually all #: districts are visible, but the user can decide to hide some of them). output: bool diff --git a/antarest/study/business/link_management.py b/antarest/study/business/link_management.py index 6f696d2a5e..f14c43ef07 100644 --- a/antarest/study/business/link_management.py +++ b/antarest/study/business/link_management.py @@ -12,10 +12,9 @@ import typing as t -from pydantic import BaseModel - from antarest.core.exceptions import ConfigFileNotFound from antarest.core.model import JSON +from antarest.core.serialization import AntaresBaseModel from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import RawStudy @@ -28,13 +27,13 @@ _ALL_LINKS_PATH = "input/links" -class LinkUIDTO(BaseModel): +class LinkUIDTO(AntaresBaseModel): color: str width: float style: str -class LinkInfoDTO(BaseModel): +class LinkInfoDTO(AntaresBaseModel): area1: str area2: str ui: t.Optional[LinkUIDTO] = None diff --git a/antarest/study/business/xpansion_management.py b/antarest/study/business/xpansion_management.py index ad459d036f..318adde367 100644 --- a/antarest/study/business/xpansion_management.py +++ b/antarest/study/business/xpansion_management.py @@ -19,10 +19,11 @@ import zipfile from fastapi import HTTPException, UploadFile -from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator +from pydantic import Field, ValidationError, field_validator, model_validator from antarest.core.exceptions import BadZipBinary, ChildNotFoundError from antarest.core.model import JSON +from antarest.core.serialization import AntaresBaseModel from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.model import Study @@ -55,7 +56,7 @@ class Solver(EnumIgnoreCase): XPRESS = "Xpress" -class XpansionSensitivitySettings(BaseModel): +class XpansionSensitivitySettings(AntaresBaseModel): """ A DTO representing the sensitivity analysis settings used for Xpansion. 
@@ -76,7 +77,7 @@ def projection_validation(cls, v: t.Optional[t.Sequence[str]]) -> t.Sequence[str return [] if v is None else v -class XpansionSettings(BaseModel, extra="ignore", validate_assignment=True, populate_by_name=True): +class XpansionSettings(AntaresBaseModel, extra="ignore", validate_assignment=True, populate_by_name=True): """ A data transfer object representing the general settings used for Xpansion. @@ -230,7 +231,7 @@ class UpdateXpansionSettings(XpansionSettings): ) -class XpansionCandidateDTO(BaseModel): +class XpansionCandidateDTO(AntaresBaseModel): # The id of the candidate is irrelevant, so it should stay hidden for the user # The names should be the section titles of the file, and the id should be removed name: str diff --git a/antarest/study/model.py b/antarest/study/model.py index 081fd48b4a..5e528e57fb 100644 --- a/antarest/study/model.py +++ b/antarest/study/model.py @@ -18,7 +18,7 @@ from datetime import datetime, timedelta from pathlib import Path -from pydantic import BaseModel, field_validator +from pydantic import field_validator from sqlalchemy import ( # type: ignore Boolean, Column, @@ -34,6 +34,7 @@ from antarest.core.exceptions import ShouldNotHappenException from antarest.core.model import PublicMode from antarest.core.persistence import Base +from antarest.core.serialization import AntaresBaseModel from antarest.login.model import Group, GroupDTO, Identity from antarest.study.css4_colors import COLOR_NAMES @@ -150,7 +151,7 @@ class StudyContentStatus(enum.Enum): ERROR = "ERROR" -class CommentsDto(BaseModel): +class CommentsDto(AntaresBaseModel): comments: str @@ -299,7 +300,7 @@ class StudyFolder: groups: t.List[Group] -class PatchStudy(BaseModel): +class PatchStudy(AntaresBaseModel): scenario: t.Optional[str] = None doc: t.Optional[str] = None status: t.Optional[str] = None @@ -307,12 +308,12 @@ class PatchStudy(BaseModel): tags: t.List[str] = [] -class PatchArea(BaseModel): +class PatchArea(AntaresBaseModel): country: t.Optional[str] = None tags: t.List[str] = [] -class PatchCluster(BaseModel): +class PatchCluster(AntaresBaseModel): type: t.Optional[str] = None code_oi: t.Optional[str] = None @@ -322,23 +323,23 @@ def alias_generator(cls, string: str) -> str: return "-".join(string.split("_")) -class PatchOutputs(BaseModel): +class PatchOutputs(AntaresBaseModel): reference: t.Optional[str] = None -class Patch(BaseModel): +class Patch(AntaresBaseModel): study: t.Optional[PatchStudy] = None areas: t.Optional[t.Dict[str, PatchArea]] = None thermal_clusters: t.Optional[t.Dict[str, PatchCluster]] = None outputs: t.Optional[PatchOutputs] = None -class OwnerInfo(BaseModel): +class OwnerInfo(AntaresBaseModel): id: t.Optional[int] = None name: str -class StudyMetadataDTO(BaseModel): +class StudyMetadataDTO(AntaresBaseModel): id: str name: str version: int @@ -364,7 +365,7 @@ def transform_horizon_to_str(cls, val: t.Union[str, int, None]) -> t.Optional[st return str(val) if val else val # type: ignore -class StudyMetadataPatchDTO(BaseModel): +class StudyMetadataPatchDTO(AntaresBaseModel): name: t.Optional[str] = None author: t.Optional[str] = None horizon: t.Optional[str] = None @@ -387,7 +388,7 @@ def _normalize_tags(cls, v: t.List[str]) -> t.List[str]: return tags -class StudySimSettingsDTO(BaseModel): +class StudySimSettingsDTO(AntaresBaseModel): general: t.Dict[str, t.Any] input: t.Dict[str, t.Any] output: t.Dict[str, t.Any] @@ -398,7 +399,7 @@ class StudySimSettingsDTO(BaseModel): playlist: t.Optional[t.List[int]] = None -class 
StudySimResultDTO(BaseModel): +class StudySimResultDTO(AntaresBaseModel): name: str type: str settings: StudySimSettingsDTO @@ -478,7 +479,7 @@ def suffix(self) -> str: return mapping[self] -class StudyDownloadDTO(BaseModel): +class StudyDownloadDTO(AntaresBaseModel): """ DTO used to download outputs """ @@ -494,32 +495,32 @@ class StudyDownloadDTO(BaseModel): includeClusters: bool = False -class MatrixIndex(BaseModel): +class MatrixIndex(AntaresBaseModel): start_date: str = "" steps: int = 8760 first_week_size: int = 7 level: StudyDownloadLevelDTO = StudyDownloadLevelDTO.HOURLY -class TimeSerie(BaseModel): +class TimeSerie(AntaresBaseModel): name: str unit: str data: t.List[t.Optional[float]] = [] -class TimeSeriesData(BaseModel): +class TimeSeriesData(AntaresBaseModel): type: StudyDownloadType name: str data: t.Dict[str, t.List[TimeSerie]] = {} -class MatrixAggregationResultDTO(BaseModel): +class MatrixAggregationResultDTO(AntaresBaseModel): index: MatrixIndex data: t.List[TimeSeriesData] warnings: t.List[str] -class MatrixAggregationResult(BaseModel): +class MatrixAggregationResult(AntaresBaseModel): index: MatrixIndex data: t.Dict[t.Tuple[StudyDownloadType, str], t.Dict[str, t.List[TimeSerie]]] warnings: t.List[str] @@ -539,6 +540,6 @@ def to_dto(self) -> MatrixAggregationResultDTO: ) -class ReferenceStudy(BaseModel): +class ReferenceStudy(AntaresBaseModel): version: str template_name: str diff --git a/antarest/study/repository.py b/antarest/study/repository.py index f4d2e691c1..a485b24652 100644 --- a/antarest/study/repository.py +++ b/antarest/study/repository.py @@ -14,7 +14,7 @@ import enum import typing as t -from pydantic import BaseModel, NonNegativeInt +from pydantic import NonNegativeInt from sqlalchemy import and_, func, not_, or_, sql # type: ignore from sqlalchemy.orm import Query, Session, joinedload, with_polymorphic # type: ignore @@ -22,6 +22,7 @@ from antarest.core.jwt import JWTUser from antarest.core.model import PublicMode from antarest.core.requests import RequestParameters +from antarest.core.serialization import AntaresBaseModel from antarest.core.utils.fastapi_sqlalchemy import db from antarest.login.model import Group from antarest.study.model import DEFAULT_WORKSPACE_NAME, RawStudy, Study, StudyAdditionalData, Tag @@ -47,7 +48,7 @@ def escape_like(string: str, escape_char: str = "\\") -> str: return string.replace(escape_char, escape_char * 2).replace("%", escape_char + "%").replace("_", escape_char + "_") -class AccessPermissions(BaseModel, frozen=True, extra="forbid"): +class AccessPermissions(AntaresBaseModel, frozen=True, extra="forbid"): """ This class object is build to pass on the user identity and its associated groups information into the listing function get_all below @@ -84,7 +85,7 @@ def from_params(cls, params: t.Union[RequestParameters, JWTUser]) -> "AccessPerm return cls() -class StudyFilter(BaseModel, frozen=True, extra="forbid"): +class StudyFilter(AntaresBaseModel, frozen=True, extra="forbid"): """Study filter class gathering the main filtering parameters Attributes: @@ -127,7 +128,7 @@ class StudySortBy(str, enum.Enum): DATE_DESC = "-date" -class StudyPagination(BaseModel, frozen=True, extra="forbid"): +class StudyPagination(AntaresBaseModel, frozen=True, extra="forbid"): """ Pagination of a studies query results diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py b/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py index 62393794e2..f2a6349d90 100644 --- 
a/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py @@ -19,12 +19,14 @@ import functools import typing as t -from pydantic import BaseModel, Field +from pydantic import Field + +from antarest.core.serialization import AntaresBaseModel @functools.total_ordering class ItemProperties( - BaseModel, + AntaresBaseModel, extra="forbid", validate_assignment=True, populate_by_name=True, diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/identifier.py b/antarest/study/storage/rawstudy/model/filesystem/config/identifier.py index 2cb5d9ec64..ab428fca75 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/identifier.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/identifier.py @@ -12,13 +12,15 @@ import typing as t -from pydantic import BaseModel, Field, model_validator +from pydantic import Field, model_validator __all__ = ("IgnoreCaseIdentifier", "LowerCaseIdentifier") +from antarest.core.serialization import AntaresBaseModel + class IgnoreCaseIdentifier( - BaseModel, + AntaresBaseModel, extra="forbid", validate_assignment=True, populate_by_name=True, diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/model.py b/antarest/study/storage/rawstudy/model/filesystem/config/model.py index d0abd57710..5b4bbcab8f 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/model.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/model.py @@ -14,8 +14,9 @@ import typing as t from pathlib import Path -from pydantic import BaseModel, Field, model_validator +from pydantic import Field, model_validator +from antarest.core.serialization import AntaresBaseModel from antarest.core.utils.utils import DTO from antarest.study.business.enum_ignore_case import EnumIgnoreCase @@ -49,7 +50,7 @@ def __str__(self) -> str: return self.value -class Link(BaseModel, extra="ignore"): +class Link(AntaresBaseModel, extra="ignore"): """ Object linked to /input/links//properties.ini information @@ -74,7 +75,7 @@ def validation(cls, values: t.MutableMapping[str, t.Any]) -> t.MutableMapping[st return values -class Area(BaseModel, extra="forbid"): +class Area(AntaresBaseModel, extra="forbid"): """ Object linked to /input//optimization.ini information """ @@ -89,7 +90,7 @@ class Area(BaseModel, extra="forbid"): st_storages: t.List[STStorageConfigType] = [] -class DistrictSet(BaseModel): +class DistrictSet(AntaresBaseModel): """ Object linked to /inputs/sets.ini information """ @@ -108,7 +109,7 @@ def get_areas(self, all_areas: t.List[str]) -> t.List[str]: return self.areas or [] -class Simulation(BaseModel): +class Simulation(AntaresBaseModel): """ Object linked to /output//about-the-study/** information """ @@ -130,7 +131,7 @@ def get_file(self) -> str: return f"{self.date}{modes[self.mode]}{dash}{self.name}" -class BindingConstraintDTO(BaseModel): +class BindingConstraintDTO(AntaresBaseModel): """ Object linked to `input/bindingconstraints/bindingconstraints.ini` information @@ -302,7 +303,7 @@ def transform_name_to_id(name: str, lower: bool = True) -> str: return valid_id.lower() if lower else valid_id -class FileStudyTreeConfigDTO(BaseModel): +class FileStudyTreeConfigDTO(AntaresBaseModel): study_path: Path path: Path study_id: str diff --git a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py index 50ce659fd9..650a19898d 100644 --- 
a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py @@ -15,8 +15,9 @@ from enum import Enum import numpy as np -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import Field, field_validator, model_validator +from antarest.core.serialization import AntaresBaseModel from antarest.matrixstore.model import MatrixData from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import ( @@ -90,7 +91,7 @@ def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixTyp # ================================================================================= -class BindingConstraintPropertiesBase(BaseModel, extra="forbid", populate_by_name=True): +class BindingConstraintPropertiesBase(AntaresBaseModel, extra="forbid", populate_by_name=True): enabled: bool = True time_step: BindingConstraintFrequency = Field(DEFAULT_TIMESTEP, alias="type") operator: BindingConstraintOperator = DEFAULT_OPERATOR @@ -163,7 +164,7 @@ class OptionalProperties(BindingConstraintProperties870): @camel_case_model -class BindingConstraintMatrices(BaseModel, extra="forbid", populate_by_name=True): +class BindingConstraintMatrices(AntaresBaseModel, extra="forbid", populate_by_name=True): """ Class used to store the matrices of a binding constraint. """ diff --git a/antarest/study/storage/variantstudy/model/command_context.py b/antarest/study/storage/variantstudy/model/command_context.py index 5996e63528..f4b9d6f6c7 100644 --- a/antarest/study/storage/variantstudy/model/command_context.py +++ b/antarest/study/storage/variantstudy/model/command_context.py @@ -10,14 +10,13 @@ # # This file is part of the Antares project. -from pydantic import BaseModel - +from antarest.core.serialization import AntaresBaseModel from antarest.matrixstore.service import ISimpleMatrixService from antarest.study.storage.patch_service import PatchService from antarest.study.storage.variantstudy.business.matrix_constants_generator import GeneratorMatrixConstants -class CommandContext(BaseModel): +class CommandContext(AntaresBaseModel): generator_matrix_constants: GeneratorMatrixConstants matrix_service: ISimpleMatrixService patch_service: PatchService diff --git a/antarest/study/storage/variantstudy/model/model.py b/antarest/study/storage/variantstudy/model/model.py index 0be3c75353..3814b2519c 100644 --- a/antarest/study/storage/variantstudy/model/model.py +++ b/antarest/study/storage/variantstudy/model/model.py @@ -14,9 +14,9 @@ import uuid import typing_extensions as te -from pydantic import BaseModel from antarest.core.model import JSON +from antarest.core.serialization import AntaresBaseModel from antarest.study.model import StudyMetadataDTO LegacyDetailsDTO = t.Tuple[str, bool, str] @@ -45,7 +45,7 @@ class NewDetailsDTO(te.TypedDict): DetailsDTO = t.Union[LegacyDetailsDTO, NewDetailsDTO] -class GenerationResultInfoDTO(BaseModel): +class GenerationResultInfoDTO(AntaresBaseModel): """ Result information of a snapshot generation process. @@ -58,7 +58,7 @@ class GenerationResultInfoDTO(BaseModel): details: t.MutableSequence[DetailsDTO] -class CommandDTO(BaseModel): +class CommandDTO(AntaresBaseModel): """ This class represents a command. 
@@ -75,7 +75,7 @@ class CommandDTO(BaseModel): version: int = 1 -class CommandResultDTO(BaseModel): +class CommandResultDTO(AntaresBaseModel): """ This class represents the result of a command. diff --git a/antarest/worker/archive_worker.py b/antarest/worker/archive_worker.py index 4fbc6a0631..a488d42d0d 100644 --- a/antarest/worker/archive_worker.py +++ b/antarest/worker/archive_worker.py @@ -13,10 +13,9 @@ import logging from pathlib import Path -from pydantic import BaseModel - from antarest.core.config import Config from antarest.core.interfaces.eventbus import IEventBus +from antarest.core.serialization import AntaresBaseModel from antarest.core.tasks.model import TaskResult from antarest.core.utils.utils import StopWatch, unzip from antarest.worker.worker import AbstractWorker, WorkerTaskCommand @@ -24,7 +23,7 @@ logger = logging.getLogger(__name__) -class ArchiveTaskArgs(BaseModel): +class ArchiveTaskArgs(AntaresBaseModel): src: str dest: str remove_src: bool = False diff --git a/antarest/worker/simulator_worker.py b/antarest/worker/simulator_worker.py index 583f340e7a..5dba1d13db 100644 --- a/antarest/worker/simulator_worker.py +++ b/antarest/worker/simulator_worker.py @@ -18,11 +18,10 @@ from pathlib import Path from typing import cast -from pydantic import BaseModel - from antarest.core.cache.business.local_chache import LocalCache from antarest.core.config import Config, LocalConfig from antarest.core.interfaces.eventbus import IEventBus +from antarest.core.serialization import AntaresBaseModel from antarest.core.tasks.model import TaskResult from antarest.core.utils.fastapi_sqlalchemy import db from antarest.launcher.adapters.log_manager import follow @@ -38,7 +37,7 @@ GENERATE_KIRSHOFF_CONSTRAINTS_TASK_NAME = "generate-kirshoff-constraints" -class GenerateTimeseriesTaskArgs(BaseModel): +class GenerateTimeseriesTaskArgs(AntaresBaseModel): study_id: str study_path: str managed: bool diff --git a/antarest/worker/worker.py b/antarest/worker/worker.py index 7dc2534764..f73dd551f0 100644 --- a/antarest/worker/worker.py +++ b/antarest/worker/worker.py @@ -16,11 +16,10 @@ from concurrent.futures import Future, ThreadPoolExecutor from typing import Any, Dict, List, Union -from pydantic import BaseModel - from antarest.core.interfaces.eventbus import Event, EventType, IEventBus from antarest.core.interfaces.service import IService from antarest.core.model import PermissionInfo, PublicMode +from antarest.core.serialization import AntaresBaseModel from antarest.core.tasks.model import TaskResult logger = logging.getLogger(__name__) @@ -28,12 +27,12 @@ MAX_WORKERS = 10 -class WorkerTaskResult(BaseModel): +class WorkerTaskResult(AntaresBaseModel): task_id: str task_result: TaskResult -class WorkerTaskCommand(BaseModel): +class WorkerTaskCommand(AntaresBaseModel): task_id: str task_type: str task_args: Dict[str, Union[int, float, bool, str]] diff --git a/tests/integration/study_data_blueprint/test_thermal.py b/tests/integration/study_data_blueprint/test_thermal.py index 6567fc205f..de00145805 100644 --- a/tests/integration/study_data_blueprint/test_thermal.py +++ b/tests/integration/study_data_blueprint/test_thermal.py @@ -297,22 +297,17 @@ class TestThermal: @pytest.mark.parametrize( "version", [pytest.param(0, id="No Upgrade"), pytest.param(860, id="v8.6"), pytest.param(870, id="v8.7")] ) - def test_lifecycle( - self, client: TestClient, user_access_token: str, internal_study_id: str, admin_access_token: str, version: int - ) -> None: + def test_lifecycle(self, client: TestClient, 
user_access_token: str, internal_study_id: str, version: int) -> None: + client.headers = {"Authorization": f"Bearer {user_access_token}"} # ============================= # STUDY UPGRADE # ============================= if version != 0: - res = client.put( - f"/v1/studies/{internal_study_id}/upgrade", - headers={"Authorization": f"Bearer {admin_access_token}"}, - params={"target_version": version}, - ) + res = client.put(f"/v1/studies/{internal_study_id}/upgrade", params={"target_version": version}) res.raise_for_status() task_id = res.json() - task = wait_task_completion(client, admin_access_token, task_id) + task = wait_task_completion(client, user_access_token, task_id) from antarest.core.tasks.model import TaskStatus assert task.status == TaskStatus.COMPLETED, task @@ -347,14 +342,18 @@ def test_lifecycle( # or an invalid name should also raise a validation error. attempts = [{}, {"name": ""}, {"name": "!??"}] for attempt in attempts: - res = client.post( - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, - json=attempt, - ) + res = client.post(f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", json=attempt) assert res.status_code == 422, res.json() assert res.json()["exception"] in {"ValidationError", "RequestValidationError"}, res.json() + # creating a thermal cluster with a name as a string should not raise an Exception + res = client.post(f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", json={"name": 111}) + assert res.status_code == 200, res.json() + res = client.request( + "DELETE", f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", json=["111"] + ) + assert res.status_code == 204, res.json() + # We can create a thermal cluster with the following properties: fr_gas_conventional_props = { **DEFAULT_PROPERTIES, @@ -371,9 +370,7 @@ def test_lifecycle( "marketBidCost": 181.267, } res = client.post( - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, - json=fr_gas_conventional_props, + f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", json=fr_gas_conventional_props ) assert res.status_code == 200, res.json() fr_gas_conventional_id = res.json()["id"] @@ -395,10 +392,7 @@ def test_lifecycle( assert res.json() == fr_gas_conventional_cfg # reading the properties of a thermal cluster - res = client.get( - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}") assert res.status_code == 200, res.json() assert res.json() == fr_gas_conventional_cfg @@ -410,17 +404,11 @@ def test_lifecycle( matrix_path = f"input/thermal/prepro/{area_id}/{fr_gas_conventional_id.lower()}/data" args = {"target": matrix_path, "matrix": matrix} res = client.post( - f"/v1/studies/{internal_study_id}/commands", - json=[{"action": "replace_matrix", "args": args}], - headers={"Authorization": f"Bearer {user_access_token}"}, + f"/v1/studies/{internal_study_id}/commands", json=[{"action": "replace_matrix", "args": args}] ) assert res.status_code in {200, 201}, res.json() - res = client.get( - f"/v1/studies/{internal_study_id}/raw", - params={"path": matrix_path}, - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = 
client.get(f"/v1/studies/{internal_study_id}/raw", params={"path": matrix_path}) assert res.status_code == 200 assert res.json()["data"] == matrix @@ -429,17 +417,13 @@ def test_lifecycle( # ================================== # Reading the list of thermal clusters - res = client.get( - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal") assert res.status_code == 200, res.json() assert res.json() == EXISTING_CLUSTERS + [fr_gas_conventional_cfg] # updating properties res = client.patch( f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, json={ "name": "FR_Gas conventional old 1", "nominalCapacity": 32.1, @@ -453,10 +437,7 @@ def test_lifecycle( } assert res.json() == fr_gas_conventional_cfg - res = client.get( - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}") assert res.status_code == 200, res.json() assert res.json() == fr_gas_conventional_cfg @@ -467,7 +448,6 @@ def test_lifecycle( # updating properties res = client.patch( f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, json={ "marginalCost": 182.456, "startupCost": 6140.8, @@ -489,24 +469,19 @@ def test_lifecycle( bad_properties = {"unitCount": 0} res = client.patch( f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, json=bad_properties, ) assert res.status_code == 422, res.json() assert res.json()["exception"] == "RequestValidationError", res.json() # The thermal cluster properties should not have been updated. - res = client.get( - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}") assert res.status_code == 200, res.json() assert res.json() == fr_gas_conventional_cfg # Update with a pollutant. Should succeed even with versions prior to v8.6 res = client.patch( f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, json={"nox": 10.0}, ) assert res.status_code == 200 @@ -514,7 +489,6 @@ def test_lifecycle( # Update with the field `efficiency`. 
Should succeed even with versions prior to v8.7 res = client.patch( f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, json={"efficiency": 97.0}, ) assert res.status_code == 200 @@ -526,7 +500,6 @@ def test_lifecycle( new_name = "Duplicate of Fr_Gas_Conventional" res = client.post( f"/v1/studies/{internal_study_id}/areas/{area_id}/thermals/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, params={"newName": new_name}, ) assert res.status_code in {200, 201}, res.json() @@ -544,11 +517,7 @@ def test_lifecycle( # asserts the matrix has also been duplicated new_cluster_matrix_path = f"input/thermal/prepro/{area_id}/{duplicated_id.lower()}/data" - res = client.get( - f"/v1/studies/{internal_study_id}/raw", - params={"path": new_cluster_matrix_path}, - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/raw", params={"path": new_cluster_matrix_path}) assert res.status_code == 200 assert res.json()["data"] == matrix @@ -558,8 +527,7 @@ def test_lifecycle( # Everything is fine at the beginning res = client.get( - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}/validate", - headers={"Authorization": f"Bearer {user_access_token}"}, + f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}/validate" ) assert res.status_code == 200 assert res.json() is True @@ -575,8 +543,7 @@ def test_lifecycle( # Validation should fail res = client.get( - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}/validate", - headers={"Authorization": f"Bearer {user_access_token}"}, + f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}/validate" ) assert res.status_code == 422 obj = res.json() @@ -594,8 +561,7 @@ def test_lifecycle( # Validation should succeed again res = client.get( - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}/validate", - headers={"Authorization": f"Bearer {user_access_token}"}, + f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}/validate" ) assert res.status_code == 200 assert res.json() is True @@ -612,8 +578,7 @@ def test_lifecycle( # Validation should fail res = client.get( - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}/validate", - headers={"Authorization": f"Bearer {user_access_token}"}, + f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}/validate" ) assert res.status_code == 422 obj = res.json() @@ -648,19 +613,12 @@ def test_lifecycle( bc_obj["lessTermMatrix"] = matrix.tolist() # noinspection SpellCheckingInspection - res = client.post( - f"/v1/studies/{internal_study_id}/bindingconstraints", - json=bc_obj, - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.post(f"/v1/studies/{internal_study_id}/bindingconstraints", json=bc_obj) assert res.status_code in {200, 201}, res.json() # verify that we can't delete the thermal cluster because it is referenced in a binding constraint res = client.request( - "DELETE", - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, - json=[fr_gas_conventional_id], + "DELETE", 
f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", json=[fr_gas_conventional_id] ) assert res.status_code == 403, res.json() description = res.json()["description"] @@ -668,37 +626,23 @@ def test_lifecycle( assert res.json()["exception"] == "ReferencedObjectDeletionNotAllowed" # delete the binding constraint - res = client.delete( - f"/v1/studies/{internal_study_id}/bindingconstraints/{bc_obj['name']}", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.delete(f"/v1/studies/{internal_study_id}/bindingconstraints/{bc_obj['name']}") assert res.status_code == 200, res.json() # Now we can delete the thermal cluster res = client.request( - "DELETE", - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, - json=[fr_gas_conventional_id], + "DELETE", f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", json=[fr_gas_conventional_id] ) assert res.status_code == 204, res.json() # check that the binding constraint has been deleted # noinspection SpellCheckingInspection - res = client.get( - f"/v1/studies/{internal_study_id}/bindingconstraints", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/bindingconstraints") assert res.status_code == 200, res.json() assert len(res.json()) == 0 # If the thermal cluster list is empty, the deletion should be a no-op. - res = client.request( - "DELETE", - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, - json=[], - ) + res = client.request("DELETE", f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", json=[]) assert res.status_code == 204, res.json() assert res.text in {"", "null"} # Old FastAPI versions return 'null'. @@ -709,17 +653,13 @@ def test_lifecycle( res = client.request( "DELETE", f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, json=[other_cluster_id1, other_cluster_id2], ) assert res.status_code == 204, res.json() assert res.text in {"", "null"} # Old FastAPI versions return 'null'. # The list of thermal clusters should not contain the deleted ones. 
- res = client.get( - f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal") assert res.status_code == 200, res.json() deleted_clusters = [other_cluster_id1, other_cluster_id2, fr_gas_conventional_id] for cluster in res.json(): @@ -734,7 +674,6 @@ def test_lifecycle( res = client.request( "DELETE", f"/v1/studies/{internal_study_id}/areas/{bad_area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, json=[fr_gas_conventional_id], ) assert res.status_code == 500, res.json() @@ -750,10 +689,7 @@ def test_lifecycle( # Check DELETE with the wrong value of `study_id` bad_study_id = "bad_study" res = client.request( - "DELETE", - f"/v1/studies/{bad_study_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, - json=[fr_gas_conventional_id], + "DELETE", f"/v1/studies/{bad_study_id}/areas/{area_id}/clusters/thermal", json=[fr_gas_conventional_id] ) obj = res.json() description = obj["description"] @@ -762,8 +698,7 @@ def test_lifecycle( # Check GET with wrong `area_id` res = client.get( - f"/v1/studies/{internal_study_id}/areas/{bad_area_id}/clusters/thermal/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, + f"/v1/studies/{internal_study_id}/areas/{bad_area_id}/clusters/thermal/{fr_gas_conventional_id}" ) obj = res.json() description = obj["description"] @@ -771,10 +706,7 @@ def test_lifecycle( assert res.status_code == 404, res.json() # Check GET with wrong `study_id` - res = client.get( - f"/v1/studies/{bad_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{bad_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}") obj = res.json() description = obj["description"] assert res.status_code == 404, res.json() @@ -783,7 +715,6 @@ def test_lifecycle( # Check POST with wrong `study_id` res = client.post( f"/v1/studies/{bad_study_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, json={"name": fr_gas_conventional, "group": "Battery"}, ) obj = res.json() @@ -794,7 +725,6 @@ def test_lifecycle( # Check POST with wrong `area_id` res = client.post( f"/v1/studies/{internal_study_id}/areas/{bad_area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, json={ "name": fr_gas_conventional, "group": "Oil", @@ -817,7 +747,6 @@ def test_lifecycle( # Check POST with wrong `group` res = client.post( f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, json={"name": fr_gas_conventional, "group": "GroupFoo"}, ) assert res.status_code == 200, res.json() @@ -828,7 +757,6 @@ def test_lifecycle( # Check PATCH with the wrong `area_id` res = client.patch( f"/v1/studies/{internal_study_id}/areas/{bad_area_id}/clusters/thermal/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, json={ "group": "Oil", "unitCount": 1, @@ -850,7 +778,6 @@ def test_lifecycle( bad_cluster_id = "bad_cluster" res = client.patch( f"/v1/studies/{internal_study_id}/areas/{area_id}/clusters/thermal/{bad_cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, json={ "group": "Oil", "unitCount": 1, @@ -871,7 +798,6 @@ def test_lifecycle( # Check PATCH 
with the wrong `study_id` res = client.patch( f"/v1/studies/{bad_study_id}/areas/{area_id}/clusters/thermal/{fr_gas_conventional_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, json={ "group": "Oil", "unitCount": 1, @@ -891,9 +817,7 @@ def test_lifecycle( # Cannot duplicate a fake cluster unknown_id = "unknown" res = client.post( - f"/v1/studies/{internal_study_id}/areas/{area_id}/thermals/{unknown_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - params={"newName": "duplicate"}, + f"/v1/studies/{internal_study_id}/areas/{area_id}/thermals/{unknown_id}", params={"newName": "duplicate"} ) assert res.status_code == 404, res.json() obj = res.json() @@ -903,7 +827,6 @@ def test_lifecycle( # Cannot duplicate with an existing id res = client.post( f"/v1/studies/{internal_study_id}/areas/{area_id}/thermals/{duplicated_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, params={"newName": new_name.upper()}, # different case but same ID ) assert res.status_code == 409, res.json() diff --git a/tests/study/business/test_all_optional_metaclass.py b/tests/study/business/test_all_optional_metaclass.py index b8d1197c5e..5001019595 100644 --- a/tests/study/business/test_all_optional_metaclass.py +++ b/tests/study/business/test_all_optional_metaclass.py @@ -10,12 +10,13 @@ # # This file is part of the Antares project. -from pydantic import BaseModel, Field +from pydantic import Field +from antarest.core.serialization import AntaresBaseModel from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model -class Model(BaseModel): +class Model(AntaresBaseModel): float_with_default: float = 1 float_without_default: float boolean_with_default: bool = True