From c3647ab849e24291ccba5f6a8e9f590a45d37f3e Mon Sep 17 00:00:00 2001
From: Laurent LAPORTE <43534797+laurent-laporte-pro@users.noreply.github.com>
Date: Thu, 12 Sep 2024 10:58:53 +0200
Subject: [PATCH] build(python): bump main project dependencies (#1728)

Bumps main project dependencies:
- `pydantic` from 1.9 to 2.8: a huge breaking change, but with large
  performance benefits expected on serialization
- `fastapi` from 0.73 to 0.110
- `uvicorn` from 0.15 to 0.30
- `mypy` from 1.4 to 1.11

It also brings a few changes to dependencies:
- Drop `requests` in favor of `httpx`
- Drop `fastapi-jwt-auth`, as it does not and will not support pydantic v2.
  We've decided to fork and adapt its code, as it's really light
  (see the new folder `/antarest/fastapi_jwt_auth`)

These changes also induced other minor dependency bumps: `jinja2`,
`typing_extensions`, `PyJWT`, `python-multipart`

Last, this work includes fixes to how the API prefix is added when running
in standalone mode (desktop version). We now distinguish the properties
`root_path` (used when running behind a proxy) and `api_prefix` (which
actually makes our server add a prefix).

Co-authored-by: belthlemar
Co-authored-by: Sylvain Leclerc
---
 antarest/core/application.py | 48 +
 antarest/core/cache/business/redis_cache.py | 2 +-
 antarest/core/config.py | 19 +-
 antarest/core/core_blueprint.py | 18 +-
 antarest/core/filesystem_blueprint.py | 34 +-
 antarest/core/filetransfer/main.py | 9 +-
 antarest/core/filetransfer/model.py | 2 +-
 antarest/core/maintenance/main.py | 9 +-
 antarest/core/model.py | 2 +-
 antarest/core/permissions.py | 15 +-
 antarest/core/requests.py | 45 +-
 antarest/core/tasks/main.py | 9 +-
 antarest/core/tasks/model.py | 43 +-
 antarest/core/tasks/service.py | 10 +-
 .../utils/fastapi_sqlalchemy/middleware.py | 4 +-
 antarest/core/version_info.py | 2 +-
 antarest/eventbus/business/redis_eventbus.py | 4 +-
 antarest/eventbus/main.py | 9 +-
 antarest/eventbus/web.py | 11 +-
 antarest/fastapi_jwt_auth/LICENSE | 21 +
 antarest/fastapi_jwt_auth/README.md | 4 +
 antarest/fastapi_jwt_auth/__init__.py | 7 +
 antarest/fastapi_jwt_auth/auth_config.py | 114 +++
 antarest/fastapi_jwt_auth/auth_jwt.py | 817 ++++++++++++++++++
 antarest/fastapi_jwt_auth/config.py | 90 ++
 antarest/fastapi_jwt_auth/exceptions.py | 89 ++
 antarest/front.py | 139 +++
 antarest/gui.py | 16 +-
 .../launcher/adapters/abstractlauncher.py | 4 +-
 antarest/launcher/main.py | 9 +-
 antarest/launcher/model.py | 21 +-
 antarest/launcher/service.py | 19 +-
 antarest/launcher/ssh_client.py | 2 +-
 antarest/launcher/ssh_config.py | 4 +-
 antarest/launcher/web.py | 14 +-
 antarest/login/auth.py | 6 +-
 antarest/login/ldap.py | 9 +-
 antarest/login/main.py | 19 +-
 antarest/login/service.py | 2 +-
 antarest/login/web.py | 24 +-
 antarest/main.py | 139 +--
 antarest/matrixstore/main.py | 11 +-
 antarest/matrixstore/matrix_editor.py | 35 +-
 antarest/matrixstore/service.py | 1 +
 antarest/singleton_services.py | 4 +-
 .../business/adequacy_patch_management.py | 22 +-
 .../advanced_parameters_management.py | 52 +-
 antarest/study/business/all_optional_meta.py | 99 +--
 .../study/business/allocation_management.py | 40 +-
 antarest/study/business/area_management.py | 47 +-
 .../study/business/areas/hydro_management.py | 36 +-
 .../business/areas/properties_management.py | 24 +-
 .../business/areas/renewable_management.py | 67 +-
 .../business/areas/st_storage_management.py | 98 +--
 .../business/areas/thermal_management.py | 81 +-
 .../business/binding_constraint_management.py | 64 +-
 .../study/business/correlation_management.py | 31 +-
antarest/study/business/district_manager.py | 27 +- antarest/study/business/general_management.py | 57 +- antarest/study/business/link_management.py | 11 +- .../study/business/optimization_management.py | 33 +- .../study/business/table_mode_management.py | 27 +- .../business/thematic_trimming_field_infos.py | 8 +- .../business/thematic_trimming_management.py | 2 +- .../business/timeseries_config_management.py | 49 +- antarest/study/business/utils.py | 6 +- .../study/business/xpansion_management.py | 66 +- antarest/study/main.py | 18 +- antarest/study/model.py | 37 +- antarest/study/service.py | 9 +- .../study/storage/abstract_storage_service.py | 4 +- antarest/study/storage/patch_service.py | 6 +- .../rawstudy/model/filesystem/config/area.py | 72 +- .../model/filesystem/config/cluster.py | 8 +- .../filesystem/config/field_validators.py | 2 +- .../rawstudy/model/filesystem/config/files.py | 1 - .../model/filesystem/config/identifier.py | 39 +- .../model/filesystem/config/ini_properties.py | 18 +- .../rawstudy/model/filesystem/config/links.py | 12 +- .../rawstudy/model/filesystem/config/model.py | 6 +- .../model/filesystem/config/thermal.py | 2 +- .../rawstudy/model/filesystem/factory.py | 4 +- .../storage/rawstudy/raw_study_service.py | 2 +- .../study/storage/study_download_utils.py | 2 +- .../business/command_extractor.py | 11 +- .../variantstudy/business/command_reverter.py | 2 +- .../business/utils_binding_constraint.py | 1 - .../storage/variantstudy/command_factory.py | 10 +- .../variantstudy/model/command/create_area.py | 4 +- .../command/create_binding_constraint.py | 61 +- .../model/command/create_cluster.py | 37 +- .../model/command/create_district.py | 8 +- .../variantstudy/model/command/create_link.py | 23 +- .../command/create_renewables_cluster.py | 10 +- .../model/command/create_st_storage.py | 50 +- .../generate_thermal_cluster_timeseries.py | 4 +- .../variantstudy/model/command/icommand.py | 10 +- .../variantstudy/model/command/remove_area.py | 2 +- .../command/remove_binding_constraint.py | 2 +- .../model/command/remove_cluster.py | 17 +- .../model/command/remove_district.py | 4 +- .../variantstudy/model/command/remove_link.py | 22 +- .../command/remove_renewables_cluster.py | 4 +- .../model/command/remove_st_storage.py | 8 +- .../model/command/replace_matrix.py | 12 +- .../command/update_binding_constraint.py | 15 +- .../model/command/update_comments.py | 4 +- .../model/command/update_config.py | 4 +- .../model/command/update_district.py | 12 +- .../model/command/update_playlist.py | 4 +- .../model/command/update_raw_file.py | 4 +- .../model/command/update_scenario_builder.py | 8 +- .../variantstudy/model/command_context.py | 1 - .../study/storage/variantstudy/model/model.py | 2 +- .../variantstudy/snapshot_generator.py | 6 +- .../variantstudy/variant_study_service.py | 4 +- antarest/study/web/studies_blueprint.py | 2 +- antarest/study/web/study_data_blueprint.py | 28 +- antarest/study/web/variant_blueprint.py | 7 +- .../study/web/xpansion_studies_blueprint.py | 2 +- antarest/tools/cli.py | 15 +- antarest/tools/lib.py | 63 +- antarest/utils.py | 50 +- antarest/worker/archive_worker.py | 6 +- antarest/worker/simulator_worker.py | 2 +- antarest/worker/worker.py | 6 +- pyproject.toml | 10 +- requirements-dev.txt | 2 +- requirements-test.txt | 3 +- requirements.txt | 32 +- resources/application.yaml | 19 +- resources/deploy/config.prod.yaml | 4 + resources/deploy/config.yaml | 3 +- resources/templates/.placeholder | 0 scripts/build-front.sh | 1 - sonar-project.properties | 2 +- 
tests/cache/test_local_cache.py | 6 +- tests/cache/test_redis_cache.py | 6 +- tests/core/test_tasks.py | 2 +- tests/eventbus/test_redis_event_bus.py | 2 +- tests/eventbus/test_websocket_manager.py | 4 +- .../filesystem_blueprint/test_model.py | 6 +- .../launcher_blueprint/test_launcher_local.py | 12 +- .../test_aggregate_raw_data.py | 211 +++-- .../test_download_matrices.py | 2 +- .../test_synthesis/raw_study.synthesis.json | 612 +------------ .../variant_study.synthesis.json | 612 +------------ .../studies_blueprint/test_comments.py | 42 +- .../studies_blueprint/test_disk_usage.py | 2 +- .../studies_blueprint/test_get_studies.py | 18 +- .../studies_blueprint/test_update_tags.py | 39 +- .../test_advanced_parameters.py | 2 +- .../test_binding_constraints.py | 4 +- .../test_config_general.py | 2 +- .../study_data_blueprint/test_renewable.py | 48 +- .../study_data_blueprint/test_st_storage.py | 123 +-- .../study_data_blueprint/test_thermal.py | 78 +- tests/integration/test_apidoc.py | 15 +- tests/integration/test_core_blueprint.py | 21 - tests/integration/test_integration.py | 28 +- .../test_integration_variantmanager_tool.py | 6 +- .../variant_blueprint/test_st_storage.py | 17 +- .../variant_blueprint/test_thermal_cluster.py | 2 +- .../variant_blueprint/test_variant_manager.py | 6 +- .../test_integration_xpansion.py | 96 +- tests/launcher/test_service.py | 35 +- tests/launcher/test_web.py | 29 +- tests/login/test_login_service.py | 36 +- tests/login/test_web.py | 42 +- tests/matrixstore/test_matrix_editor.py | 8 +- tests/matrixstore/test_service.py | 14 +- tests/matrixstore/test_web.py | 15 +- .../storage/business/test_arealink_manager.py | 29 +- tests/storage/business/test_config_manager.py | 12 +- tests/storage/business/test_patch_service.py | 2 +- .../test_timeseries_config_manager.py | 2 +- .../storage/business/test_xpansion_manager.py | 32 +- tests/storage/integration/conftest.py | 2 +- tests/storage/integration/test_STA_mini.py | 56 +- tests/storage/integration/test_exporter.py | 7 +- tests/storage/rawstudies/test_factory.py | 4 +- .../filesystem/config/test_config_files.py | 2 +- tests/storage/test_model.py | 2 +- tests/storage/test_service.py | 10 +- tests/storage/web/test_studies_bp.py | 282 +----- .../areas/test_st_storage_management.py | 106 +-- .../business/areas/test_thermal_management.py | 8 +- .../business/test_all_optional_metaclass.py | 375 +------- .../study/business/test_allocation_manager.py | 2 +- .../variantstudy/model/test_dbmodel.py | 2 +- .../variantstudy/test_snapshot_generator.py | 8 +- tests/study/test_repository.py | 44 +- tests/test_front.py | 111 +++ .../model/command/test_create_cluster.py | 32 +- .../model/command/test_create_link.py | 17 +- .../command/test_create_renewables_cluster.py | 6 +- .../model/command/test_create_st_storage.py | 101 +-- .../model/command/test_manage_district.py | 8 +- .../model/command/test_remove_area.py | 2 +- .../model/command/test_remove_link.py | 2 +- .../model/command/test_remove_st_storage.py | 6 +- .../model/command/test_replace_matrix.py | 6 +- .../model/command/test_update_config.py | 2 +- .../variantstudy/model/test_variant_model.py | 2 +- 204 files changed, 3415 insertions(+), 3797 deletions(-) create mode 100644 antarest/core/application.py create mode 100644 antarest/fastapi_jwt_auth/LICENSE create mode 100644 antarest/fastapi_jwt_auth/README.md create mode 100644 antarest/fastapi_jwt_auth/__init__.py create mode 100644 antarest/fastapi_jwt_auth/auth_config.py create mode 100644 antarest/fastapi_jwt_auth/auth_jwt.py 
create mode 100644 antarest/fastapi_jwt_auth/config.py create mode 100644 antarest/fastapi_jwt_auth/exceptions.py create mode 100644 antarest/front.py delete mode 100644 resources/templates/.placeholder create mode 100644 tests/test_front.py diff --git a/antarest/core/application.py b/antarest/core/application.py new file mode 100644 index 0000000000..3f09bd4102 --- /dev/null +++ b/antarest/core/application.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from dataclasses import dataclass +from typing import Optional + +from fastapi import APIRouter, FastAPI + + +@dataclass(frozen=True) +class AppBuildContext: + """ + Base elements of the application, for use at construction time: + - app: the actual fastapi application, where middlewares, exception handlers, etc. may be added + - api_root: the route under which all API and WS endpoints must be registered + + API routes should not be added straight to app, but under api_root instead, + so that they are correctly prefixed if needed (/api for standalone mode). + + Warning: the inclusion of api_root must happen AFTER all subroutes + have been registered, hence the build method. + """ + + app: FastAPI + api_root: APIRouter + + def build(self) -> FastAPI: + """ + Finalizes the app construction by including the API route. + Must be performed AFTER all subroutes have been added. + """ + self.app.include_router(self.api_root) + return self.app + + +def create_app_ctxt(app: FastAPI, api_root: Optional[APIRouter] = None) -> AppBuildContext: + if not api_root: + api_root = APIRouter() + return AppBuildContext(app, api_root) diff --git a/antarest/core/cache/business/redis_cache.py b/antarest/core/cache/business/redis_cache.py index 75423cd190..0234fd1666 100644 --- a/antarest/core/cache/business/redis_cache.py +++ b/antarest/core/cache/business/redis_cache.py @@ -40,7 +40,7 @@ def put(self, id: str, data: JSON, duration: int = 3600) -> None: redis_element = RedisCacheElement(duration=duration, data=data) redis_key = f"cache:{id}" logger.info(f"Adding cache key {id}") - self.redis.set(redis_key, redis_element.json()) + self.redis.set(redis_key, redis_element.model_dump_json()) self.redis.expire(redis_key, duration) def get(self, id: str, refresh_timeout: Optional[int] = None) -> Optional[JSON]: diff --git a/antarest/core/config.py b/antarest/core/config.py index 7aa54d02f8..4b85331d6a 100644 --- a/antarest/core/config.py +++ b/antarest/core/config.py @@ -13,6 +13,7 @@ import multiprocessing import tempfile from dataclasses import asdict, dataclass, field +from enum import Enum from pathlib import Path from typing import Dict, List, Optional @@ -24,6 +25,12 @@ DEFAULT_WORKSPACE_NAME = "default" +class Launcher(str, Enum): + SLURM = "slurm" + LOCAL = "local" + DEFAULT = "default" + + @dataclass(frozen=True) class ExternalAuthConfig: """ @@ -399,7 +406,7 @@ def __post_init__(self) -> None: msg = f"Invalid configuration: {self.default=} must be one of {possible!r}" raise ValueError(msg) - def get_nb_cores(self, launcher: str) -> "NbCoresConfig": + def get_nb_cores(self, launcher: Launcher) -> "NbCoresConfig": """ Retrieve the number of cores configuration for a given launcher: "local" or "slurm". 
If "default" is specified, retrieve the configuration of the default launcher. @@ -416,12 +423,12 @@ def get_nb_cores(self, launcher: str) -> "NbCoresConfig": """ config_map = {"local": self.local, "slurm": self.slurm} config_map["default"] = config_map[self.default] - launcher_config = config_map.get(launcher) + launcher_config = config_map.get(launcher.value) if launcher_config is None: - raise InvalidConfigurationError(launcher) + raise InvalidConfigurationError(launcher.value) return launcher_config.nb_cores - def get_time_limit(self, launcher: str) -> TimeLimitConfig: + def get_time_limit(self, launcher: Launcher) -> TimeLimitConfig: """ Retrieve the time limit for a job of the given launcher: "local" or "slurm". If "default" is specified, retrieve the configuration of the default launcher. @@ -438,7 +445,7 @@ def get_time_limit(self, launcher: str) -> TimeLimitConfig: """ config_map = {"local": self.local, "slurm": self.slurm} config_map["default"] = config_map[self.default] - launcher_config = config_map.get(launcher) + launcher_config = config_map.get(launcher.value) if launcher_config is None: raise InvalidConfigurationError(launcher) return launcher_config.time_limit @@ -586,6 +593,7 @@ class Config: cache: CacheConfig = CacheConfig() tasks: TaskConfig = TaskConfig() root_path: str = "" + api_prefix: str = "" @classmethod def from_dict(cls, data: JSON) -> "Config": @@ -604,6 +612,7 @@ def from_dict(cls, data: JSON) -> "Config": cache=CacheConfig.from_dict(data["cache"]) if "cache" in data else defaults.cache, tasks=TaskConfig.from_dict(data["tasks"]) if "tasks" in data else defaults.tasks, root_path=data.get("root_path", defaults.root_path), + api_prefix=data.get("api_prefix", defaults.api_prefix), ) @classmethod diff --git a/antarest/core/core_blueprint.py b/antarest/core/core_blueprint.py index 988ad9f641..27f6591109 100644 --- a/antarest/core/core_blueprint.py +++ b/antarest/core/core_blueprint.py @@ -10,18 +10,14 @@ # # This file is part of the Antares project. -import logging from typing import Any -from fastapi import APIRouter, Depends +from fastapi import APIRouter from pydantic import BaseModel from antarest.core.config import Config -from antarest.core.jwt import JWTUser -from antarest.core.requests import UserHasNotPermissionError from antarest.core.utils.web import APITag from antarest.core.version_info import VersionInfoDTO, get_commit_id, get_dependencies -from antarest.login.auth import Auth class StatusDTO(BaseModel): @@ -36,7 +32,6 @@ def create_utils_routes(config: Config) -> APIRouter: config: main server configuration """ bp = APIRouter() - auth = Auth(config) @bp.get("/health", tags=[APITag.misc], response_model=StatusDTO) def health() -> Any: @@ -66,15 +61,4 @@ def version_info() -> Any: dependencies=get_dependencies(), ) - @bp.get("/kill", include_in_schema=False) - def kill_worker( - current_user: JWTUser = Depends(auth.get_current_user), - ) -> Any: - if not current_user.is_site_admin(): - raise UserHasNotPermissionError() - logging.getLogger(__name__).critical("Killing the worker") - # PyInstaller modifies the behavior of built-in functions, such as `exit`. - # It is advisable to use `sys.exit` or raise the `SystemExit` exception instead. 
- raise SystemExit(f"Worker killed by the user #{current_user.id}") - return bp diff --git a/antarest/core/filesystem_blueprint.py b/antarest/core/filesystem_blueprint.py index 02652f2213..d625d19c07 100644 --- a/antarest/core/filesystem_blueprint.py +++ b/antarest/core/filesystem_blueprint.py @@ -23,21 +23,21 @@ import typing_extensions as te from fastapi import APIRouter, Depends, HTTPException -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel, Field from starlette.responses import PlainTextResponse, StreamingResponse from antarest.core.config import Config from antarest.core.utils.web import APITag from antarest.login.auth import Auth -FilesystemName = te.Annotated[str, Field(regex=r"^\w+$", description="Filesystem name")] -MountPointName = te.Annotated[str, Field(regex=r"^\w+$", description="Mount point name")] +FilesystemName = te.Annotated[str, Field(pattern=r"^\w+$", description="Filesystem name")] +MountPointName = te.Annotated[str, Field(pattern=r"^\w+$", description="Mount point name")] class FilesystemDTO( BaseModel, - extra=Extra.forbid, - schema_extra={ + extra="forbid", + json_schema_extra={ "example": { "name": "ws", "mount_dirs": { @@ -62,8 +62,8 @@ class FilesystemDTO( class MountPointDTO( BaseModel, - extra=Extra.forbid, - schema_extra={ + extra="forbid", + json_schema_extra={ "example": { "name": "default", "path": "/path/to/workspaces/internal_studies", @@ -89,10 +89,10 @@ class MountPointDTO( name: MountPointName path: Path = Field(description="Full path of the mount point in Antares Web Server") - total_bytes: int = Field(0, description="Total size of the mount point in bytes") - used_bytes: int = Field(0, description="Used size of the mount point in bytes") - free_bytes: int = Field(0, description="Free size of the mount point in bytes") - message: str = Field("", description="A message describing the status of the mount point") + total_bytes: int = Field(default=0, description="Total size of the mount point in bytes") + used_bytes: int = Field(default=0, description="Used size of the mount point in bytes") + free_bytes: int = Field(default=0, description="Free size of the mount point in bytes") + message: str = Field(default="", description="A message describing the status of the mount point") @classmethod async def from_path(cls, name: str, path: Path) -> "MountPointDTO": @@ -110,8 +110,8 @@ async def from_path(cls, name: str, path: Path) -> "MountPointDTO": class FileInfoDTO( BaseModel, - extra=Extra.forbid, - schema_extra={ + extra="forbid", + json_schema_extra={ "example": { "path": "/path/to/workspaces/internal_studies/5a503c20-24a3-4734-9cf8-89565c9db5ec/study.antares", "file_type": "file", @@ -142,12 +142,12 @@ class FileInfoDTO( path: Path = Field(description="Full path of the file or directory in Antares Web Server") file_type: str = Field(description="Type of the file or directory") - file_count: int = Field(1, description="Number of files and folders in the directory (1 for files)") - size_bytes: int = Field(0, description="Size of the file or total size of the directory in bytes") + file_count: int = Field(default=1, description="Number of files and folders in the directory (1 for files)") + size_bytes: int = Field(default=0, description="Size of the file or total size of the directory in bytes") created: datetime.datetime = Field(description="Creation date of the file or directory (local time)") modified: datetime.datetime = Field(description="Last modification date of the file or directory (local time)") accessed: 
datetime.datetime = Field(description="Last access date of the file or directory (local time)") - message: str = Field("OK", description="A message describing the status of the file") + message: str = Field(default="OK", description="A message describing the status of the file") @classmethod async def from_path(cls, full_path: Path, *, details: bool = False) -> "FileInfoDTO": @@ -160,6 +160,7 @@ async def from_path(cls, full_path: Path, *, details: bool = False) -> "FileInfo path=full_path, file_type="unknown", file_count=0, # missing + size_bytes=0, # missing created=datetime.datetime.min, modified=datetime.datetime.min, accessed=datetime.datetime.min, @@ -174,6 +175,7 @@ async def from_path(cls, full_path: Path, *, details: bool = False) -> "FileInfo created=datetime.datetime.fromtimestamp(file_stat.st_ctime), modified=datetime.datetime.fromtimestamp(file_stat.st_mtime), accessed=datetime.datetime.fromtimestamp(file_stat.st_atime), + message="OK", ) if stat.S_ISDIR(file_stat.st_mode): diff --git a/antarest/core/filetransfer/main.py b/antarest/core/filetransfer/main.py index cbf732ef22..3583dc5701 100644 --- a/antarest/core/filetransfer/main.py +++ b/antarest/core/filetransfer/main.py @@ -12,8 +12,9 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.filetransfer.repository import FileDownloadRepository from antarest.core.filetransfer.service import FileTransferManager @@ -22,10 +23,10 @@ def build_filetransfer_service( - application: Optional[FastAPI], event_bus: IEventBus, config: Config + app_ctxt: Optional[AppBuildContext], event_bus: IEventBus, config: Config ) -> FileTransferManager: ftm = FileTransferManager(repository=FileDownloadRepository(), event_bus=event_bus, config=config) - if application: - application.include_router(create_file_transfer_api(ftm, config)) + if app_ctxt: + app_ctxt.api_root.include_router(create_file_transfer_api(ftm, config)) return ftm diff --git a/antarest/core/filetransfer/model.py b/antarest/core/filetransfer/model.py index 1ae51c1009..72463e0bad 100644 --- a/antarest/core/filetransfer/model.py +++ b/antarest/core/filetransfer/model.py @@ -41,7 +41,7 @@ class FileDownloadDTO(BaseModel): id: str name: str filename: str - expiration_date: Optional[str] + expiration_date: Optional[str] = None ready: bool failed: bool = False error_message: str = "" diff --git a/antarest/core/maintenance/main.py b/antarest/core/maintenance/main.py index b129d85691..8717150d07 100644 --- a/antarest/core/maintenance/main.py +++ b/antarest/core/maintenance/main.py @@ -12,8 +12,9 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.interfaces.cache import ICache from antarest.core.interfaces.eventbus import DummyEventBusService, IEventBus @@ -23,7 +24,7 @@ def build_maintenance_manager( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, cache: ICache, event_bus: IEventBus = DummyEventBusService(), @@ -31,7 +32,7 @@ def build_maintenance_manager( repository = MaintenanceRepository() service = MaintenanceService(config, repository, event_bus, cache) - if application: - application.include_router(create_maintenance_api(service, config)) + if app_ctxt: + app_ctxt.api_root.include_router(create_maintenance_api(service, config)) 
return service diff --git a/antarest/core/model.py b/antarest/core/model.py index d13c2931f6..b500e9a5c5 100644 --- a/antarest/core/model.py +++ b/antarest/core/model.py @@ -21,7 +21,7 @@ JSON = Dict[str, Any] ELEMENT = Union[str, int, float, bool, bytes] -SUB_JSON = Union[ELEMENT, JSON, List, None] +SUB_JSON = Union[ELEMENT, JSON, List[Any], None] class PublicMode(str, enum.Enum): diff --git a/antarest/core/permissions.py b/antarest/core/permissions.py index 0a2d5d6d6f..55defacf7d 100644 --- a/antarest/core/permissions.py +++ b/antarest/core/permissions.py @@ -11,6 +11,7 @@ # This file is part of the Antares project. import logging +import typing as t from antarest.core.jwt import JWTUser from antarest.core.model import PermissionInfo, PublicMode, StudyPermissionType @@ -19,8 +20,8 @@ logger = logging.getLogger(__name__) -permission_matrix = { - StudyPermissionType.READ: { +permission_matrix: t.Dict[str, t.Dict[str, t.Sequence[t.Union[RoleType, PublicMode]]]] = { + StudyPermissionType.READ.value: { "roles": [ RoleType.ADMIN, RoleType.RUNNER, @@ -34,15 +35,15 @@ PublicMode.READ, ], }, - StudyPermissionType.RUN: { + StudyPermissionType.RUN.value: { "roles": [RoleType.ADMIN, RoleType.RUNNER, RoleType.WRITER], "public_modes": [PublicMode.FULL, PublicMode.EDIT, PublicMode.EXECUTE], }, - StudyPermissionType.WRITE: { + StudyPermissionType.WRITE.value: { "roles": [RoleType.ADMIN, RoleType.WRITER], "public_modes": [PublicMode.FULL, PublicMode.EDIT], }, - StudyPermissionType.MANAGE_PERMISSIONS: { + StudyPermissionType.MANAGE_PERMISSIONS.value: { "roles": [RoleType.ADMIN], "public_modes": [], }, @@ -77,11 +78,11 @@ def check_permission( allowed_roles = permission_matrix[permission]["roles"] group_permission = any( - role in allowed_roles # type: ignore + role in allowed_roles for role in [group.role for group in (user.groups or []) if group.id in permission_info.groups] ) if group_permission: return True allowed_public_modes = permission_matrix[permission]["public_modes"] - return permission_info.public_mode in allowed_public_modes # type: ignore + return permission_info.public_mode in allowed_public_modes diff --git a/antarest/core/requests.py b/antarest/core/requests.py index 52614f8ab7..d276f79dca 100644 --- a/antarest/core/requests.py +++ b/antarest/core/requests.py @@ -10,8 +10,10 @@ # # This file is part of the Antares project. 
+import typing as t +from collections import OrderedDict from dataclasses import dataclass -from typing import Optional +from typing import Any, Generator, Tuple from fastapi import HTTPException from markupsafe import escape @@ -29,13 +31,52 @@ } +class CaseInsensitiveDict(t.MutableMapping[str, t.Any]): # copy of the requests class to avoid importing the package + def __init__(self, data=None, **kwargs) -> None: # type: ignore + self._store: OrderedDict[str, t.Any] = OrderedDict() + if data is None: + data = {} + self.update(data, **kwargs) + + def __setitem__(self, key: str, value: t.Any) -> None: + self._store[key.lower()] = (key, value) + + def __getitem__(self, key: str) -> t.Any: + return self._store[key.lower()][1] + + def __delitem__(self, key: str) -> None: + del self._store[key.lower()] + + def __iter__(self) -> t.Any: + return (casedkey for casedkey, mappedvalue in self._store.values()) + + def __len__(self) -> int: + return len(self._store) + + def lower_items(self) -> Generator[Tuple[Any, Any], Any, None]: + return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items()) + + def __eq__(self, other: t.Any) -> bool: + if isinstance(other, t.Mapping): + other = CaseInsensitiveDict(other) + else: + return NotImplemented + return dict(self.lower_items()) == dict(other.lower_items()) + + def copy(self) -> "CaseInsensitiveDict": + return CaseInsensitiveDict(self._store.values()) + + def __repr__(self) -> str: + return str(dict(self.items())) + + @dataclass class RequestParameters: """ DTO object to handle data inside request to send to service """ - user: Optional[JWTUser] = None + user: t.Optional[JWTUser] = None def get_user_id(self) -> str: return str(escape(str(self.user.id))) if self.user else "Unknown" diff --git a/antarest/core/tasks/main.py b/antarest/core/tasks/main.py index ae3d7dafb8..74685ba836 100644 --- a/antarest/core/tasks/main.py +++ b/antarest/core/tasks/main.py @@ -12,8 +12,9 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.interfaces.eventbus import DummyEventBusService, IEventBus from antarest.core.tasks.repository import TaskJobRepository @@ -22,14 +23,14 @@ def build_taskjob_manager( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, event_bus: IEventBus = DummyEventBusService(), ) -> ITaskService: repository = TaskJobRepository() service = TaskJobService(config, repository, event_bus) - if application: - application.include_router(create_tasks_api(service, config)) + if app_ctxt: + app_ctxt.api_root.include_router(create_tasks_api(service, config)) return service diff --git a/antarest/core/tasks/model.py b/antarest/core/tasks/model.py index 9fb0aadb42..a7ca9aedb9 100644 --- a/antarest/core/tasks/model.py +++ b/antarest/core/tasks/model.py @@ -15,7 +15,7 @@ from datetime import datetime from enum import Enum -from pydantic import BaseModel, Extra +from pydantic import BaseModel from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore from sqlalchemy.engine.base import Engine # type: ignore from sqlalchemy.orm import relationship, sessionmaker # type: ignore @@ -57,43 +57,43 @@ def is_final(self) -> bool: ] -class TaskResult(BaseModel, extra=Extra.forbid): +class TaskResult(BaseModel, extra="forbid"): success: bool message: str # Can be used to store json serialized result - return_value: t.Optional[str] 
+ return_value: t.Optional[str] = None -class TaskLogDTO(BaseModel, extra=Extra.forbid): +class TaskLogDTO(BaseModel, extra="forbid"): id: str message: str -class CustomTaskEventMessages(BaseModel, extra=Extra.forbid): +class CustomTaskEventMessages(BaseModel, extra="forbid"): start: str running: str end: str -class TaskEventPayload(BaseModel, extra=Extra.forbid): +class TaskEventPayload(BaseModel, extra="forbid"): id: str message: str -class TaskDTO(BaseModel, extra=Extra.forbid): +class TaskDTO(BaseModel, extra="forbid"): id: str name: str - owner: t.Optional[int] + owner: t.Optional[int] = None status: TaskStatus creation_date_utc: str - completion_date_utc: t.Optional[str] - result: t.Optional[TaskResult] - logs: t.Optional[t.List[TaskLogDTO]] + completion_date_utc: t.Optional[str] = None + result: t.Optional[TaskResult] = None + logs: t.Optional[t.List[TaskLogDTO]] = None type: t.Optional[str] = None ref_id: t.Optional[str] = None -class TaskListFilter(BaseModel, extra=Extra.forbid): +class TaskListFilter(BaseModel, extra="forbid"): status: t.List[TaskStatus] = [] name: t.Optional[str] = None type: t.List[TaskType] = [] @@ -171,6 +171,15 @@ class TaskJob(Base): # type: ignore study: "Study" = relationship("Study", back_populates="jobs", uselist=False) def to_dto(self, with_logs: bool = False) -> TaskDTO: + result = None + if self.completion_date: + assert self.result_status is not None + assert self.result_msg is not None + result = TaskResult( + success=self.result_status, + message=self.result_msg, + return_value=self.result, + ) return TaskDTO( id=self.id, owner=self.owner_id, @@ -178,15 +187,7 @@ def to_dto(self, with_logs: bool = False) -> TaskDTO: completion_date_utc=str(self.completion_date) if self.completion_date else None, name=self.name, status=TaskStatus(self.status), - result=( - TaskResult( - success=self.result_status, - message=self.result_msg, - return_value=self.result, - ) - if self.completion_date - else None - ), + result=result, logs=sorted([log.to_dto() for log in self.logs], key=lambda log: log.id) if with_logs else None, type=self.type, ref_id=self.ref_id, diff --git a/antarest/core/tasks/service.py b/antarest/core/tasks/service.py index 34dd8a3d73..e4855e01c3 100644 --- a/antarest/core/tasks/service.py +++ b/antarest/core/tasks/service.py @@ -144,7 +144,7 @@ def _create_awaiter( res_wrapper: t.List[TaskResult], ) -> t.Callable[[Event], t.Awaitable[None]]: async def _await_task_end(event: Event) -> None: - task_event = WorkerTaskResult.parse_obj(event.payload) + task_event = WorkerTaskResult.model_validate(event.payload) if task_event.task_id == task_id: res_wrapper.append(task_event.task_result) @@ -252,7 +252,7 @@ def _launch_task( message=custom_event_messages.start if custom_event_messages is not None else f"Task {task.id} added", - ).dict(), + ).model_dump(), permissions=PermissionInfo(owner=request_params.user.impersonator), ) ) @@ -361,7 +361,7 @@ def _run_task( message=custom_event_messages.running if custom_event_messages is not None else f"Task {task_id} is running", - ).dict(), + ).model_dump(), permissions=PermissionInfo(public_mode=PublicMode.READ), channel=EventChannelDirectory.TASK + task_id, ) @@ -407,7 +407,7 @@ def _run_task( if custom_event_messages is not None else f"Task {task_id} {event_msg}" ), - ).dict(), + ).model_dump(), permissions=PermissionInfo(public_mode=PublicMode.READ), channel=EventChannelDirectory.TASK + task_id, ) @@ -432,7 +432,7 @@ def _run_task( self.event_bus.push( Event( type=EventType.TASK_FAILED, - 
payload=TaskEventPayload(id=task_id, message=message).dict(), + payload=TaskEventPayload(id=task_id, message=message).model_dump(), permissions=PermissionInfo(public_mode=PublicMode.READ), channel=EventChannelDirectory.TASK + task_id, ) diff --git a/antarest/core/utils/fastapi_sqlalchemy/middleware.py b/antarest/core/utils/fastapi_sqlalchemy/middleware.py index 9a98b4ef1b..dcc1f95b25 100644 --- a/antarest/core/utils/fastapi_sqlalchemy/middleware.py +++ b/antarest/core/utils/fastapi_sqlalchemy/middleware.py @@ -12,7 +12,7 @@ from antarest.core.utils.fastapi_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError -_Session: sessionmaker = None +_Session: Optional[sessionmaker] = None _session: ContextVar[Optional[Session]] = ContextVar("_session", default=None) @@ -93,4 +93,4 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: _session.reset(self.token) -db: DBSessionMeta = DBSession +db: Type[DBSession] = DBSession diff --git a/antarest/core/version_info.py b/antarest/core/version_info.py index fe144d1a02..ea96eaf4ef 100644 --- a/antarest/core/version_info.py +++ b/antarest/core/version_info.py @@ -28,7 +28,7 @@ class VersionInfoDTO(BaseModel): dependencies: Dict[str, str] class Config: - schema_extra = { + json_schema_extra = { "example": { "name": "AntaREST", "version": "2.13.2", diff --git a/antarest/eventbus/business/redis_eventbus.py b/antarest/eventbus/business/redis_eventbus.py index 82325984ef..f3642bf994 100644 --- a/antarest/eventbus/business/redis_eventbus.py +++ b/antarest/eventbus/business/redis_eventbus.py @@ -29,10 +29,10 @@ def __init__(self, redis_client: Redis) -> None: # type: ignore self.pubsub.subscribe(REDIS_STORE_KEY) def push_event(self, event: Event) -> None: - self.redis.publish(REDIS_STORE_KEY, event.json()) + self.redis.publish(REDIS_STORE_KEY, event.model_dump_json()) def queue_event(self, event: Event, queue: str) -> None: - self.redis.rpush(queue, event.json()) + self.redis.rpush(queue, event.model_dump_json()) def pull_queue(self, queue: str) -> Optional[Event]: event = self.redis.lpop(queue) diff --git a/antarest/eventbus/main.py b/antarest/eventbus/main.py index 7c53f5cbce..6ccf56e644 100644 --- a/antarest/eventbus/main.py +++ b/antarest/eventbus/main.py @@ -12,9 +12,10 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI from redis import Redis +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.eventbus.business.local_eventbus import LocalEventBus from antarest.eventbus.business.redis_eventbus import RedisEventBus @@ -23,7 +24,7 @@ def build_eventbus( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, autostart: bool = True, redis_client: Optional[Redis] = None, # type: ignore @@ -33,6 +34,6 @@ def build_eventbus( autostart, ) - if application: - configure_websockets(application, config, eventbus) + if app_ctxt: + configure_websockets(app_ctxt, config, eventbus) return eventbus diff --git a/antarest/eventbus/web.py b/antarest/eventbus/web.py index 481db7b371..ba363db6c4 100644 --- a/antarest/eventbus/web.py +++ b/antarest/eventbus/web.py @@ -17,16 +17,17 @@ from http import HTTPStatus from typing import List, Optional -from fastapi import Depends, FastAPI, HTTPException, Query -from fastapi_jwt_auth import AuthJWT # type: ignore +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Query from pydantic import BaseModel from starlette.websockets import 
WebSocket, WebSocketDisconnect +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.interfaces.eventbus import Event, IEventBus from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTUser from antarest.core.model import PermissionInfo, StudyPermissionType from antarest.core.permissions import check_permission +from antarest.fastapi_jwt_auth import AuthJWT from antarest.login.auth import Auth logger = logging.getLogger(__name__) @@ -91,16 +92,16 @@ async def broadcast(self, message: str, permissions: PermissionInfo, channel: st await connection.websocket.send_text(message) -def configure_websockets(application: FastAPI, config: Config, event_bus: IEventBus) -> None: +def configure_websockets(app_ctxt: AppBuildContext, config: Config, event_bus: IEventBus) -> None: manager = ConnectionManager() async def send_event_to_ws(event: Event) -> None: - event_data = event.dict() + event_data = event.model_dump() del event_data["permissions"] del event_data["channel"] await manager.broadcast(json.dumps(event_data), event.permissions, event.channel) - @application.websocket("/ws") + @app_ctxt.api_root.websocket("/ws") async def connect( websocket: WebSocket, token: str = Query(...), diff --git a/antarest/fastapi_jwt_auth/LICENSE b/antarest/fastapi_jwt_auth/LICENSE new file mode 100644 index 0000000000..ad5e3c8004 --- /dev/null +++ b/antarest/fastapi_jwt_auth/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Nyoman Pradipta Dewantara + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/antarest/fastapi_jwt_auth/README.md b/antarest/fastapi_jwt_auth/README.md new file mode 100644 index 0000000000..c4dfd9017e --- /dev/null +++ b/antarest/fastapi_jwt_auth/README.md @@ -0,0 +1,4 @@ +# FastAPI JWT Auth + +Forked from https://github.com/IndominusByte/fastapi-jwt-auth, +licensed under MIT license, see [LICENSE](LICENSE). 
diff --git a/antarest/fastapi_jwt_auth/__init__.py b/antarest/fastapi_jwt_auth/__init__.py new file mode 100644 index 0000000000..a3748ffe33 --- /dev/null +++ b/antarest/fastapi_jwt_auth/__init__.py @@ -0,0 +1,7 @@ +"""FastAPI extension that provides JWT Auth support (secure, easy to use and lightweight)""" + +__version__ = "0.5.0" + +__all__ = ["AuthJWT"] + +from .auth_jwt import AuthJWT diff --git a/antarest/fastapi_jwt_auth/auth_config.py b/antarest/fastapi_jwt_auth/auth_config.py new file mode 100644 index 0000000000..55ca04a25d --- /dev/null +++ b/antarest/fastapi_jwt_auth/auth_config.py @@ -0,0 +1,114 @@ +from datetime import timedelta +from typing import Callable, List + +from pydantic import ValidationError + +from .config import LoadConfig + + +class AuthConfig: + _token = None + _token_location = {"headers"} + + _secret_key = None + _public_key = None + _private_key = None + _algorithm = "HS256" + _decode_algorithms = None + _decode_leeway = 0 + _encode_issuer = None + _decode_issuer = None + _decode_audience = None + _denylist_enabled = False + _denylist_token_checks = {"access", "refresh"} + _header_name = "Authorization" + _header_type = "Bearer" + _token_in_denylist_callback = None + _access_token_expires = timedelta(minutes=15) + _refresh_token_expires = timedelta(days=30) + + # option for create cookies + _access_cookie_key = "access_token_cookie" + _refresh_cookie_key = "refresh_token_cookie" + _access_cookie_path = "/" + _refresh_cookie_path = "/" + _cookie_max_age = None + _cookie_domain = None + _cookie_secure = False + _cookie_samesite = None + + # option for double submit csrf protection + _cookie_csrf_protect = True + _access_csrf_cookie_key = "csrf_access_token" + _refresh_csrf_cookie_key = "csrf_refresh_token" + _access_csrf_cookie_path = "/" + _refresh_csrf_cookie_path = "/" + _access_csrf_header_name = "X-CSRF-Token" + _refresh_csrf_header_name = "X-CSRF-Token" + _csrf_methods = {"POST", "PUT", "PATCH", "DELETE"} + + @property + def jwt_in_cookies(self) -> bool: + return "cookies" in self._token_location + + @property + def jwt_in_headers(self) -> bool: + return "headers" in self._token_location + + @classmethod + def load_config(cls, settings: Callable[..., List[tuple]]) -> "AuthConfig": + try: + config = LoadConfig(**{key.lower(): value for key, value in settings()}) + + cls._token_location = config.authjwt_token_location + cls._secret_key = config.authjwt_secret_key + cls._public_key = config.authjwt_public_key + cls._private_key = config.authjwt_private_key + cls._algorithm = config.authjwt_algorithm + cls._decode_algorithms = config.authjwt_decode_algorithms + cls._decode_leeway = config.authjwt_decode_leeway + cls._encode_issuer = config.authjwt_encode_issuer + cls._decode_issuer = config.authjwt_decode_issuer + cls._decode_audience = config.authjwt_decode_audience + cls._denylist_enabled = config.authjwt_denylist_enabled + cls._denylist_token_checks = config.authjwt_denylist_token_checks + cls._header_name = config.authjwt_header_name + cls._header_type = config.authjwt_header_type + cls._access_token_expires = config.authjwt_access_token_expires + cls._refresh_token_expires = config.authjwt_refresh_token_expires + # option for create cookies + cls._access_cookie_key = config.authjwt_access_cookie_key + cls._refresh_cookie_key = config.authjwt_refresh_cookie_key + cls._access_cookie_path = config.authjwt_access_cookie_path + cls._refresh_cookie_path = config.authjwt_refresh_cookie_path + cls._cookie_max_age = config.authjwt_cookie_max_age + 
cls._cookie_domain = config.authjwt_cookie_domain
+            cls._cookie_secure = config.authjwt_cookie_secure
+            cls._cookie_samesite = config.authjwt_cookie_samesite
+            # option for double submit csrf protection
+            cls._cookie_csrf_protect = config.authjwt_cookie_csrf_protect
+            cls._access_csrf_cookie_key = config.authjwt_access_csrf_cookie_key
+            cls._refresh_csrf_cookie_key = config.authjwt_refresh_csrf_cookie_key
+            cls._access_csrf_cookie_path = config.authjwt_access_csrf_cookie_path
+            cls._refresh_csrf_cookie_path = config.authjwt_refresh_csrf_cookie_path
+            cls._access_csrf_header_name = config.authjwt_access_csrf_header_name
+            cls._refresh_csrf_header_name = config.authjwt_refresh_csrf_header_name
+            cls._csrf_methods = config.authjwt_csrf_methods
+        except ValidationError as e:
+            raise e
+        except Exception:
+            raise TypeError("Config must be pydantic 'BaseSettings' or list of tuple")
+
+    @classmethod
+    def token_in_denylist_loader(cls, callback: Callable[..., bool]) -> "AuthConfig":
+        """
+        This decorator sets the callback function that will be called when
+        a protected endpoint is accessed, to check whether the JWT has been
+        revoked. By default, this callback is not used.
+
+        *HINT*: The callback must be a function that takes a decrypted_token
+        argument: the decoded JWT (a Python dictionary). It returns *`True`*
+        if the token has been revoked, or *`False`* otherwise.
+        """
+        cls._token_in_denylist_callback = callback
diff --git a/antarest/fastapi_jwt_auth/auth_jwt.py b/antarest/fastapi_jwt_auth/auth_jwt.py
new file mode 100644
index 0000000000..19fa3173b4
--- /dev/null
+++ b/antarest/fastapi_jwt_auth/auth_jwt.py
@@ -0,0 +1,817 @@
+import hmac
+import re
+import uuid
+from datetime import datetime, timedelta, timezone
+from typing import Dict, Optional, Sequence, Union
+
+import jwt
+from fastapi import Request, Response, WebSocket
+from jwt.algorithms import has_crypto, requires_cryptography
+
+from .auth_config import AuthConfig
+from .exceptions import (
+    AccessTokenRequired,
+    CSRFError,
+    FreshTokenRequired,
+    InvalidHeaderError,
+    JWTDecodeError,
+    MissingTokenError,
+    RefreshTokenRequired,
+    RevokedTokenError,
+)
+
+TYPE_ERROR_MSG = "The response must be an object response FastAPI"
+
+
+class AuthJWT(AuthConfig):
+    def __init__(self, req: Request = None, res: Response = None):
+        """
+        Get the JWT header from the incoming request, or get the
+        request and response objects if the JWT is in the cookies
+
+        :param req: the incoming request
+        :param res: response from endpoint
+        """
+        if res and self.jwt_in_cookies:
+            self._response = res
+
+        if req:
+            # get request object when cookies in token location
+            if self.jwt_in_cookies:
+                self._request = req
+            # get jwt in headers when headers in token location
+            if self.jwt_in_headers:
+                auth = req.headers.get(self._header_name.lower())
+                if auth:
+                    self._get_jwt_from_headers(auth)
+
+    def _get_jwt_from_headers(self, auth: str) -> "AuthJWT":
+        """
+        Get token from the headers
+
+        :param auth: value from HeaderName
+        """
+        header_name, header_type = self._header_name, self._header_type
+
+        parts = auth.split()
+
+        # Make sure the header is in a valid format that we are expecting, ie
+        if not header_type:
+            # <HeaderName>: <JWT>
+            if len(parts) != 1:
+                msg = "Bad {} header. Expected value '<JWT>'".format(header_name)
+                raise InvalidHeaderError(status_code=422, message=msg)
+            self._token = parts[0]
+        else:
+            # <HeaderName>: <HeaderType> <JWT>
+            if not re.match(r"{}\s".format(header_type), auth) or len(parts) != 2:
+                msg = "Bad {} header. Expected value '{} <JWT>'".format(header_name, header_type)
+                raise InvalidHeaderError(status_code=422, message=msg)
+            self._token = parts[1]
+
+    def _get_jwt_identifier(self) -> str:
+        return str(uuid.uuid4())
+
+    def _get_int_from_datetime(self, value: datetime) -> int:
+        """
+        :param value: datetime with or without timezone; if it doesn't contain
+            a timezone, it is treated as UTC
+        :return: Seconds since the Epoch
+        """
+        if not isinstance(value, datetime):  # pragma: no cover
+            raise TypeError("a datetime is required")
+        return int(value.timestamp())
+
+    def _get_secret_key(self, algorithm: str, process: str) -> str:
+        """
+        Get key with a different algorithm
+
+        :param algorithm: algorithm for decode and encode token
+        :param process: for indicating get key for encode or decode token
+
+        :return: plain text or RSA depends on algorithm
+        """
+        symmetric_algorithms, asymmetric_algorithms = {"HS256", "HS384", "HS512"}, requires_cryptography
+
+        if algorithm not in symmetric_algorithms and algorithm not in asymmetric_algorithms:
+            raise ValueError("Algorithm {} could not be found".format(algorithm))
+
+        if algorithm in symmetric_algorithms:
+            if not self._secret_key:
+                raise RuntimeError("authjwt_secret_key must be set when using symmetric algorithm {}".format(algorithm))
+
+            return self._secret_key
+
+        if algorithm in asymmetric_algorithms and not has_crypto:
+            raise RuntimeError(
+                "Missing dependencies for using asymmetric algorithms. run 'pip install fastapi-jwt-auth[asymmetric]'"
+            )
+
+        if process == "encode":
+            if not self._private_key:
+                raise RuntimeError(
+                    "authjwt_private_key must be set when using asymmetric algorithm {}".format(algorithm)
+                )
+
+            return self._private_key
+
+        if process == "decode":
+            if not self._public_key:
+                raise RuntimeError(
+                    "authjwt_public_key must be set when using asymmetric algorithm {}".format(algorithm)
+                )
+
+            return self._public_key
+
+    def _create_token(
+        self,
+        subject: Union[str, int],
+        type_token: str,
+        exp_time: Optional[int],
+        fresh: Optional[bool] = False,
+        algorithm: Optional[str] = None,
+        headers: Optional[Dict] = None,
+        issuer: Optional[str] = None,
+        audience: Optional[Union[str, Sequence[str]]] = None,
+        user_claims: Optional[Dict] = {},
+    ) -> str:
+        """
+        Create token for access_token and refresh_token (utf-8)
+
+        :param subject: Identifier for who this token is for, e.g. an id or username from the database.
+        :param type_token: indicate token is access_token or refresh_token
+        :param exp_time: Set the duration of the JWT
+        :param fresh: Optional; when the token is an access_token this param is required
+        :param algorithm: algorithm allowed to encode the token
+        :param headers: valid dict for specifying additional headers in JWT header section
+        :param issuer: expected issuer in the JWT
+        :param audience: expected audience in the JWT
+        :param user_claims: Custom claims to include in this token.
This data must be dictionary + + :return: Encoded token + """ + # Validation type data + if not isinstance(subject, (str, int)): + raise TypeError("subject must be a string or integer") + if not isinstance(fresh, bool): + raise TypeError("fresh must be a boolean") + if audience and not isinstance(audience, (str, list, tuple, set, frozenset)): + raise TypeError("audience must be a string or sequence") + if algorithm and not isinstance(algorithm, str): + raise TypeError("algorithm must be a string") + if user_claims and not isinstance(user_claims, dict): + raise TypeError("user_claims must be a dictionary") + + # Data section + reserved_claims = { + "sub": subject, + "iat": self._get_int_from_datetime(datetime.now(timezone.utc)), + "nbf": self._get_int_from_datetime(datetime.now(timezone.utc)), + "jti": self._get_jwt_identifier(), + } + + custom_claims = {"type": type_token} + + # for access_token only fresh needed + if type_token == "access": + custom_claims["fresh"] = fresh + # if cookie in token location and csrf protection enabled + if self.jwt_in_cookies and self._cookie_csrf_protect: + custom_claims["csrf"] = self._get_jwt_identifier() + + if exp_time: + reserved_claims["exp"] = exp_time + if issuer: + reserved_claims["iss"] = issuer + if audience: + reserved_claims["aud"] = audience + + algorithm = algorithm or self._algorithm + + secret_key = self._get_secret_key(algorithm, "encode") + + return jwt.encode( + {**reserved_claims, **custom_claims, **user_claims}, secret_key, algorithm=algorithm, headers=headers + ) + + def _has_token_in_denylist_callback(self) -> bool: + """ + Return True if token denylist callback set + """ + return self._token_in_denylist_callback is not None + + def _check_token_is_revoked(self, raw_token: Dict[str, Union[str, int, bool]]) -> None: + """ + Ensure that AUTHJWT_DENYLIST_ENABLED is true and callback regulated, and then + call function denylist callback with passing decode JWT, if true + raise exception Token has been revoked + """ + if not self._denylist_enabled: + return + + if not self._has_token_in_denylist_callback(): + raise RuntimeError( + "A token_in_denylist_callback must be provided via " + "the '@AuthJWT.token_in_denylist_loader' if " + "authjwt_denylist_enabled is 'True'" + ) + + if self._token_in_denylist_callback.__func__(raw_token): + raise RevokedTokenError(status_code=401, message="Token has been revoked") + + def _get_expired_time( + self, type_token: str, expires_time: Optional[Union[timedelta, int, bool]] = None + ) -> Union[None, int]: + """ + Dynamic token expired, if expires_time is False exp claim not created + + :param type_token: indicate token is access_token or refresh_token + :param expires_time: duration expired jwt + + :return: duration exp claim jwt + """ + if expires_time and not isinstance(expires_time, (timedelta, int, bool)): + raise TypeError("expires_time must be between timedelta, int, bool") + + if expires_time is not False: + if type_token == "access": + expires_time = expires_time or self._access_token_expires + if type_token == "refresh": + expires_time = expires_time or self._refresh_token_expires + + if expires_time is not False: + if isinstance(expires_time, bool): + if type_token == "access": + expires_time = self._access_token_expires + if type_token == "refresh": + expires_time = self._refresh_token_expires + if isinstance(expires_time, timedelta): + expires_time = int(expires_time.total_seconds()) + + return self._get_int_from_datetime(datetime.now(timezone.utc)) + expires_time + else: + return None + + 
def create_access_token( + self, + subject: Union[str, int], + fresh: Optional[bool] = False, + algorithm: Optional[str] = None, + headers: Optional[Dict] = None, + expires_time: Optional[Union[timedelta, int, bool]] = None, + audience: Optional[Union[str, Sequence[str]]] = None, + user_claims: Optional[Dict] = {}, + ) -> str: + """ + Create a access token with 15 minutes for expired time (default), + info for param and return check to function create token + + :return: hash token + """ + return self._create_token( + subject=subject, + type_token="access", + exp_time=self._get_expired_time("access", expires_time), + fresh=fresh, + algorithm=algorithm, + headers=headers, + audience=audience, + user_claims=user_claims, + issuer=self._encode_issuer, + ) + + def create_refresh_token( + self, + subject: Union[str, int], + algorithm: Optional[str] = None, + headers: Optional[Dict] = None, + expires_time: Optional[Union[timedelta, int, bool]] = None, + audience: Optional[Union[str, Sequence[str]]] = None, + user_claims: Optional[Dict] = {}, + ) -> str: + """ + Create a refresh token with 30 days for expired time (default), + info for param and return check to function create token + + :return: hash token + """ + return self._create_token( + subject=subject, + type_token="refresh", + exp_time=self._get_expired_time("refresh", expires_time), + algorithm=algorithm, + headers=headers, + audience=audience, + user_claims=user_claims, + ) + + def _get_csrf_token(self, encoded_token: str) -> str: + """ + Returns the CSRF double submit token from an encoded JWT. + + :param encoded_token: The encoded JWT + :return: The CSRF double submit token + """ + return self._verified_token(encoded_token)["csrf"] + + def set_access_cookies( + self, encoded_access_token: str, response: Optional[Response] = None, max_age: Optional[int] = None + ) -> None: + """ + Configures the response to set access token in a cookie. + this will also set the CSRF double submit values in a separate cookie + + :param encoded_access_token: The encoded access token to set in the cookies + :param response: The FastAPI response object to set the access cookies in + :param max_age: The max age of the cookie value should be the number of seconds (integer) + """ + if not self.jwt_in_cookies: + raise RuntimeWarning( + "set_access_cookies() called without 'authjwt_token_location' configured to use cookies" + ) + + if max_age and not isinstance(max_age, int): + raise TypeError("max_age must be a integer") + if response and not isinstance(response, Response): + raise TypeError(TYPE_ERROR_MSG) + + response = response or self._response + + # Set the access JWT in the cookie + response.set_cookie( + self._access_cookie_key, + encoded_access_token, + max_age=max_age or self._cookie_max_age, + path=self._access_cookie_path, + domain=self._cookie_domain, + secure=self._cookie_secure, + httponly=True, + samesite=self._cookie_samesite, + ) + + # If enabled, set the csrf double submit access cookie + if self._cookie_csrf_protect: + response.set_cookie( + self._access_csrf_cookie_key, + self._get_csrf_token(encoded_access_token), + max_age=max_age or self._cookie_max_age, + path=self._access_csrf_cookie_path, + domain=self._cookie_domain, + secure=self._cookie_secure, + httponly=False, + samesite=self._cookie_samesite, + ) + + def set_refresh_cookies( + self, encoded_refresh_token: str, response: Optional[Response] = None, max_age: Optional[int] = None + ) -> None: + """ + Configures the response to set refresh token in a cookie. 
+    def set_refresh_cookies(
+        self, encoded_refresh_token: str, response: Optional[Response] = None, max_age: Optional[int] = None
+    ) -> None:
+        """
+        Configures the response to set the refresh token in a cookie.
+        This also sets the CSRF double submit value in a separate cookie.
+
+        :param encoded_refresh_token: The encoded refresh token to set in the cookies
+        :param response: The FastAPI response object to set the refresh cookies in
+        :param max_age: The max age of the cookie, in seconds (integer)
+        """
+        if not self.jwt_in_cookies:
+            raise RuntimeWarning(
+                "set_refresh_cookies() called without 'authjwt_token_location' configured to use cookies"
+            )
+
+        if max_age and not isinstance(max_age, int):
+            raise TypeError("max_age must be an integer")
+        if response and not isinstance(response, Response):
+            raise TypeError(TYPE_ERROR_MSG)
+
+        response = response or self._response
+
+        # Set the refresh JWT in the cookie
+        response.set_cookie(
+            self._refresh_cookie_key,
+            encoded_refresh_token,
+            max_age=max_age or self._cookie_max_age,
+            path=self._refresh_cookie_path,
+            domain=self._cookie_domain,
+            secure=self._cookie_secure,
+            httponly=True,
+            samesite=self._cookie_samesite,
+        )
+
+        # If enabled, set the csrf double submit refresh cookie
+        if self._cookie_csrf_protect:
+            response.set_cookie(
+                self._refresh_csrf_cookie_key,
+                self._get_csrf_token(encoded_refresh_token),
+                max_age=max_age or self._cookie_max_age,
+                path=self._refresh_csrf_cookie_path,
+                domain=self._cookie_domain,
+                secure=self._cookie_secure,
+                httponly=False,
+                samesite=self._cookie_samesite,
+            )
+
+    def unset_jwt_cookies(self, response: Optional[Response] = None) -> None:
+        """
+        Unset (delete) all JWT cookies
+
+        :param response: The FastAPI response object to delete the JWT cookies in.
+        """
+        self.unset_access_cookies(response)
+        self.unset_refresh_cookies(response)
+
+    def unset_access_cookies(self, response: Optional[Response] = None) -> None:
+        """
+        Remove the access token and the access CSRF double submit token from the response cookies
+
+        :param response: The FastAPI response object to delete the access cookies in.
+        """
+        if not self.jwt_in_cookies:
+            raise RuntimeWarning(
+                "unset_access_cookies() called without 'authjwt_token_location' configured to use cookies"
+            )
+
+        if response and not isinstance(response, Response):
+            raise TypeError(TYPE_ERROR_MSG)
+
+        response = response or self._response
+
+        response.delete_cookie(self._access_cookie_key, path=self._access_cookie_path, domain=self._cookie_domain)
+
+        if self._cookie_csrf_protect:
+            response.delete_cookie(
+                self._access_csrf_cookie_key, path=self._access_csrf_cookie_path, domain=self._cookie_domain
+            )
+
+    def unset_refresh_cookies(self, response: Optional[Response] = None) -> None:
+        """
+        Remove the refresh token and the refresh CSRF double submit token from the response cookies
+
+        :param response: The FastAPI response object to delete the refresh cookies in.
+        """
+        if not self.jwt_in_cookies:
+            raise RuntimeWarning(
+                "unset_refresh_cookies() called without 'authjwt_token_location' configured to use cookies"
+            )
+
+        if response and not isinstance(response, Response):
+            raise TypeError(TYPE_ERROR_MSG)
+
+        response = response or self._response
+
+        response.delete_cookie(self._refresh_cookie_key, path=self._refresh_cookie_path, domain=self._cookie_domain)
+
+        if self._cookie_csrf_protect:
+            response.delete_cookie(
+                self._refresh_csrf_cookie_key, path=self._refresh_csrf_cookie_path, domain=self._cookie_domain
+            )
+
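A typical consumer of the unset helpers above is a logout endpoint, as in upstream fastapi-jwt-auth's documentation; the route below is an illustrative sketch, not part of the patch:

```python
from fastapi import Depends, FastAPI

from antarest.fastapi_jwt_auth import AuthJWT

app = FastAPI()

# Illustrative logout endpoint: requires a valid access token, then deletes
# the access/refresh cookies and their CSRF double submit companions.
@app.delete("/logout")
def logout(auth: AuthJWT = Depends()) -> dict:
    auth.jwt_required()
    auth.unset_jwt_cookies()
    return {"msg": "logout successful"}
```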
+    def _verify_and_get_jwt_optional_in_cookies(
+        self,
+        request: Union[Request, WebSocket],
+        csrf_token: Optional[str] = None,
+    ) -> None:
+        """
+        Optionally check whether the cookies contain a valid access token. If an access
+        token is present in the cookies, self._token will be set. Raises an exception
+        when the access token is invalid or doesn't match the CSRF double submit token.
+
+        :param request: the HTTP request or WebSocket from which the cookies are read
+        :param csrf_token: the CSRF double submit token
+        """
+        if not isinstance(request, (Request, WebSocket)):
+            raise TypeError("request must be an instance of 'Request' or 'WebSocket'")
+
+        cookie_key = self._access_cookie_key
+        cookie = request.cookies.get(cookie_key)
+        if not isinstance(request, WebSocket):
+            csrf_token = request.headers.get(self._access_csrf_header_name)
+
+        if cookie and self._cookie_csrf_protect and not csrf_token:
+            if isinstance(request, WebSocket) or request.method in self._csrf_methods:
+                raise CSRFError(status_code=401, message="Missing CSRF Token")
+
+        # set token from cookie and verify jwt
+        self._token = cookie
+        self._verify_jwt_optional_in_request(self._token)
+
+        decoded_token = self.get_raw_jwt()
+
+        if decoded_token and self._cookie_csrf_protect and csrf_token:
+            if isinstance(request, WebSocket) or request.method in self._csrf_methods:
+                if "csrf" not in decoded_token:
+                    raise JWTDecodeError(status_code=422, message="Missing claim: csrf")
+                if not hmac.compare_digest(csrf_token, decoded_token["csrf"]):
+                    raise CSRFError(status_code=401, message="CSRF double submit tokens do not match")
+
+    def _verify_and_get_jwt_in_cookies(
+        self,
+        type_token: str,
+        request: Union[Request, WebSocket],
+        csrf_token: Optional[str] = None,
+        fresh: Optional[bool] = False,
+    ) -> None:
+        """
+        Check whether the cookies contain a valid access or refresh token. If a token is
+        present in the cookies, self._token will be set. Raises an exception when the
+        access or refresh token is invalid or doesn't match the CSRF double submit token.
+
+        :param type_token: indicates whether the expected token is an access or a refresh token
+        :param request: the HTTP request or WebSocket from which the cookies are read
+        :param csrf_token: the CSRF double submit token
+        :param fresh: if True, check that the token is fresh
+        """
+        if type_token not in ["access", "refresh"]:
+            raise ValueError("type_token must be either 'access' or 'refresh'")
+        if not isinstance(request, (Request, WebSocket)):
+            raise TypeError("request must be an instance of 'Request' or 'WebSocket'")
+
+        if type_token == "access":
+            cookie_key = self._access_cookie_key
+            cookie = request.cookies.get(cookie_key)
+            if not isinstance(request, WebSocket):
+                csrf_token = request.headers.get(self._access_csrf_header_name)
+        if type_token == "refresh":
+            cookie_key = self._refresh_cookie_key
+            cookie = request.cookies.get(cookie_key)
+            if not isinstance(request, WebSocket):
+                csrf_token = request.headers.get(self._refresh_csrf_header_name)
+
+        if not cookie:
+            raise MissingTokenError(status_code=401, message="Missing cookie {}".format(cookie_key))
+
+        if self._cookie_csrf_protect and not csrf_token:
+            if isinstance(request, WebSocket) or request.method in self._csrf_methods:
+                raise CSRFError(status_code=401, message="Missing CSRF Token")
+
+        # set token from cookie and verify jwt
+        self._token = cookie
+        self._verify_jwt_in_request(self._token, type_token, "cookies", fresh)
+
+        decoded_token = self.get_raw_jwt()
+
+        if self._cookie_csrf_protect and csrf_token:
+            if isinstance(request, WebSocket) or request.method in self._csrf_methods:
+                if "csrf" not in decoded_token:
+                    raise JWTDecodeError(status_code=422, message="Missing claim: csrf")
+                if not hmac.compare_digest(csrf_token, decoded_token["csrf"]):
+                    raise CSRFError(status_code=401, message="CSRF double submit tokens do not match")
+
submit tokens do not match") + + def _verify_jwt_optional_in_request(self, token: str) -> None: + """ + Optionally check if this request has a valid access token + + :param token: The encoded JWT + """ + if token: + self._verifying_token(token) + + if token and self.get_raw_jwt(token)["type"] != "access": + raise AccessTokenRequired(status_code=422, message="Only access tokens are allowed") + + def _verify_jwt_in_request( + self, token: str, type_token: str, token_from: str, fresh: Optional[bool] = False + ) -> None: + """ + Ensure that the requester has a valid token. this also check the freshness of the access token + + :param token: The encoded JWT + :param type_token: indicate token is access or refresh token + :param token_from: indicate token from headers cookies, websocket + :param fresh: check freshness token if True + """ + if type_token not in ["access", "refresh"]: + raise ValueError("type_token must be between 'access' or 'refresh'") + if token_from not in ["headers", "cookies", "websocket"]: + raise ValueError("token_from must be between 'headers', 'cookies', 'websocket'") + + if not token: + if token_from == "headers": + raise MissingTokenError(status_code=401, message="Missing {} Header".format(self._header_name)) + if token_from == "websocket": + raise MissingTokenError( + status_code=1008, message="Missing {} token from Query or Path".format(type_token) + ) + + # verify jwt + issuer = self._decode_issuer if type_token == "access" else None + self._verifying_token(token, issuer) + + if self.get_raw_jwt(token)["type"] != type_token: + msg = "Only {} tokens are allowed".format(type_token) + if type_token == "access": + raise AccessTokenRequired(status_code=422, message=msg) + if type_token == "refresh": + raise RefreshTokenRequired(status_code=422, message=msg) + + if fresh and not self.get_raw_jwt(token)["fresh"]: + raise FreshTokenRequired(status_code=401, message="Fresh token required") + + def _verifying_token(self, encoded_token: str, issuer: Optional[str] = None) -> None: + """ + Verified token and check if token is revoked + + :param encoded_token: token hash + :param issuer: expected issuer in the JWT + """ + raw_token = self._verified_token(encoded_token, issuer) + if raw_token["type"] in self._denylist_token_checks: + self._check_token_is_revoked(raw_token) + + def _verified_token(self, encoded_token: str, issuer: Optional[str] = None) -> Dict[str, Union[str, int, bool]]: + """ + Verified token and catch all error from jwt package and return decode token + + :param encoded_token: token hash + :param issuer: expected issuer in the JWT + + :return: raw data from the hash token in the form of a dictionary + """ + algorithms = self._decode_algorithms or [self._algorithm] + secret_key = self._get_secret_key(self._algorithm, "decode") + try: + return jwt.decode( + encoded_token, + secret_key, + issuer=issuer, + audience=self._decode_audience, + leeway=self._decode_leeway, + algorithms=algorithms, + ) + except Exception as err: + raise JWTDecodeError(status_code=422, message=str(err)) + + def jwt_required( + self, + auth_from: str = "request", + token: Optional[str] = None, + websocket: Optional[WebSocket] = None, + csrf_token: Optional[str] = None, + ) -> None: + """ + Only access token can access this function + + :param auth_from: for identity get token from HTTP or WebSocket + :param token: the encoded JWT, it's required if the protected endpoint use WebSocket to + authorization and get token from Query Url or Path + :param websocket: an instance of WebSocket, it's 
+    def jwt_required(
+        self,
+        auth_from: str = "request",
+        token: Optional[str] = None,
+        websocket: Optional[WebSocket] = None,
+        csrf_token: Optional[str] = None,
+    ) -> None:
+        """
+        Ensure that the requester has a valid access token; only an access
+        token can pass this check
+
+        :param auth_from: whether the token comes from an HTTP request or a WebSocket
+        :param token: the encoded JWT; required if the protected endpoint uses a WebSocket
+            and gets the token from the query string or the path
+        :param websocket: an instance of WebSocket; required if the protected endpoint
+            uses cookie-based authorization
+        :param csrf_token: the CSRF double submit token. Since a WebSocket cannot carry
+            additional headers, the csrf_token must be passed manually, e.g. in the query
+            string or the path
+        """
+        if auth_from == "websocket":
+            if websocket:
+                self._verify_and_get_jwt_in_cookies("access", websocket, csrf_token)
+            else:
+                self._verify_jwt_in_request(token, "access", "websocket")
+
+        if auth_from == "request":
+            if len(self._token_location) == 2:
+                if self._token and self.jwt_in_headers:
+                    self._verify_jwt_in_request(self._token, "access", "headers")
+                if not self._token and self.jwt_in_cookies:
+                    self._verify_and_get_jwt_in_cookies("access", self._request)
+            else:
+                if self.jwt_in_headers:
+                    self._verify_jwt_in_request(self._token, "access", "headers")
+                if self.jwt_in_cookies:
+                    self._verify_and_get_jwt_in_cookies("access", self._request)
+
+    def jwt_optional(
+        self,
+        auth_from: str = "request",
+        token: Optional[str] = None,
+        websocket: Optional[WebSocket] = None,
+        csrf_token: Optional[str] = None,
+    ) -> None:
+        """
+        If an access token is present in the request, its data can be read with
+        get_raw_jwt() or get_jwt_subject(). If no access token is present, the
+        endpoint will still be called, but get_raw_jwt() or get_jwt_subject()
+        will return None.
+
+        :param auth_from: whether the token comes from an HTTP request or a WebSocket
+        :param token: the encoded JWT; required if the protected endpoint uses a WebSocket
+            and gets the token from the query string or the path
+        :param websocket: an instance of WebSocket; required if the protected endpoint
+            uses cookie-based authorization
+        :param csrf_token: the CSRF double submit token. Since a WebSocket cannot carry
+            additional headers, the csrf_token must be passed manually, e.g. in the query
+            string or the path
+        """
+        if auth_from == "websocket":
+            if websocket:
+                self._verify_and_get_jwt_optional_in_cookies(websocket, csrf_token)
+            else:
+                self._verify_jwt_optional_in_request(token)
+
+        if auth_from == "request":
+            if len(self._token_location) == 2:
+                if self._token and self.jwt_in_headers:
+                    self._verify_jwt_optional_in_request(self._token)
+                if not self._token and self.jwt_in_cookies:
+                    self._verify_and_get_jwt_optional_in_cookies(self._request)
+            else:
+                if self.jwt_in_headers:
+                    self._verify_jwt_optional_in_request(self._token)
+                if self.jwt_in_cookies:
+                    self._verify_and_get_jwt_optional_in_cookies(self._request)
+
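Usage of the fork mirrors upstream fastapi-jwt-auth; only the import path changes (the patch does exactly this in antarest/login/auth.py). A minimal protected endpoint, with an illustrative route:

```python
from fastapi import Depends, FastAPI

from antarest.fastapi_jwt_auth import AuthJWT

app = FastAPI()

# The AuthJWT dependency reads the token from the configured locations
# (headers and/or cookies); jwt_required() raises one of the exceptions
# defined in exceptions.py if no valid access token is found.
@app.get("/protected")
def protected(auth: AuthJWT = Depends()) -> dict:
    auth.jwt_required()
    return {"subject": auth.get_jwt_subject()}
```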
+    def jwt_refresh_token_required(
+        self,
+        auth_from: str = "request",
+        token: Optional[str] = None,
+        websocket: Optional[WebSocket] = None,
+        csrf_token: Optional[str] = None,
+    ) -> None:
+        """
+        Ensure that the requester has a valid refresh token
+
+        :param auth_from: whether the token comes from an HTTP request or a WebSocket
+        :param token: the encoded JWT; required if the protected endpoint uses a WebSocket
+            and gets the token from the query string or the path
+        :param websocket: an instance of WebSocket; required if the protected endpoint
+            uses cookie-based authorization
+        :param csrf_token: the CSRF double submit token. Since a WebSocket cannot carry
+            additional headers, the csrf_token must be passed manually, e.g. in the query
+            string or the path
+        """
+        if auth_from == "websocket":
+            if websocket:
+                self._verify_and_get_jwt_in_cookies("refresh", websocket, csrf_token)
+            else:
+                self._verify_jwt_in_request(token, "refresh", "websocket")
+
+        if auth_from == "request":
+            if len(self._token_location) == 2:
+                if self._token and self.jwt_in_headers:
+                    self._verify_jwt_in_request(self._token, "refresh", "headers")
+                if not self._token and self.jwt_in_cookies:
+                    self._verify_and_get_jwt_in_cookies("refresh", self._request)
+            else:
+                if self.jwt_in_headers:
+                    self._verify_jwt_in_request(self._token, "refresh", "headers")
+                if self.jwt_in_cookies:
+                    self._verify_and_get_jwt_in_cookies("refresh", self._request)
+
+    def fresh_jwt_required(
+        self,
+        auth_from: str = "request",
+        token: Optional[str] = None,
+        websocket: Optional[WebSocket] = None,
+        csrf_token: Optional[str] = None,
+    ) -> None:
+        """
+        Ensure that the requester has an access token that is both valid and fresh
+
+        :param auth_from: whether the token comes from an HTTP request or a WebSocket
+        :param token: the encoded JWT; required if the protected endpoint uses a WebSocket
+            and gets the token from the query string or the path
+        :param websocket: an instance of WebSocket; required if the protected endpoint
+            uses cookie-based authorization
+        :param csrf_token: the CSRF double submit token. Since a WebSocket cannot carry
+            additional headers, the csrf_token must be passed manually, e.g. in the query
+            string or the path
+        """
+        if auth_from == "websocket":
+            if websocket:
+                self._verify_and_get_jwt_in_cookies("access", websocket, csrf_token, True)
+            else:
+                self._verify_jwt_in_request(token, "access", "websocket", True)
+
+        if auth_from == "request":
+            if len(self._token_location) == 2:
+                if self._token and self.jwt_in_headers:
+                    self._verify_jwt_in_request(self._token, "access", "headers", True)
+                if not self._token and self.jwt_in_cookies:
+                    self._verify_and_get_jwt_in_cookies("access", self._request, fresh=True)
+            else:
+                if self.jwt_in_headers:
+                    self._verify_jwt_in_request(self._token, "access", "headers", True)
+                if self.jwt_in_cookies:
+                    self._verify_and_get_jwt_in_cookies("access", self._request, fresh=True)
+
+    def get_raw_jwt(self, encoded_token: Optional[str] = None) -> Optional[Dict[str, Union[str, int, bool]]]:
+        """
+        Return the Python dictionary containing all of the claims of the JWT that is
+        accessing the endpoint. If no JWT is currently present, return None instead.
+
+        :param encoded_token: The encoded JWT given as a parameter
+        :return: the claims of the JWT
+        """
+        token = encoded_token or self._token
+
+        if token:
+            return self._verified_token(token)
+        return None
+
+    def get_jti(self, encoded_token: str) -> str:
+        """
+        Returns the JTI (unique identifier) of an encoded JWT
+
+        :param encoded_token: The encoded JWT given as a parameter
+        :return: the JTI, as a string
+        """
+        return self._verified_token(encoded_token)["jti"]
+
+    def get_jwt_subject(self) -> Optional[Union[str, int]]:
+        """
+        Return the subject of the JWT that is accessing this endpoint.
+        If no JWT is present, `None` is returned instead.
+
+        :return: the "sub" claim of the JWT
+        """
+        if self._token:
+            return self._verified_token(self._token)["sub"]
+        return None
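The fork also keeps upstream's configuration mechanism: a callable decorated with @AuthJWT.load_config returns a pydantic model whose fields are validated against the LoadConfig model defined in config.py below (antarest/main.py wires this up with its JwtSettings class). A hedged sketch with illustrative values:

```python
from pydantic import BaseModel

from antarest.fastapi_jwt_auth import AuthJWT

class Settings(BaseModel):
    # Field names must match LoadConfig; the values here are illustrative.
    authjwt_secret_key: str = "change-me"
    authjwt_token_location: set = {"headers", "cookies"}
    authjwt_cookie_csrf_protect: bool = True

@AuthJWT.load_config  # type: ignore
def get_settings() -> Settings:
    return Settings()
```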
diff --git a/antarest/fastapi_jwt_auth/config.py b/antarest/fastapi_jwt_auth/config.py
new file mode 100644
index 0000000000..f04ace9cd7
--- /dev/null
+++ b/antarest/fastapi_jwt_auth/config.py
@@ -0,0 +1,90 @@
+import typing as t
+from datetime import timedelta
+from typing import List, Optional, Sequence, Union
+
+from pydantic import BaseModel, StrictBool, StrictInt, StrictStr, field_validator, model_validator
+
+
+class LoadConfig(BaseModel):
+    authjwt_token_location: Optional[t.Set[StrictStr]] = {"headers"}
+    authjwt_secret_key: Optional[StrictStr] = None
+    authjwt_public_key: Optional[StrictStr] = None
+    authjwt_private_key: Optional[StrictStr] = None
+    authjwt_algorithm: Optional[StrictStr] = "HS256"
+    authjwt_decode_algorithms: Optional[List[StrictStr]] = None
+    authjwt_decode_leeway: Optional[Union[StrictInt, timedelta]] = 0
+    authjwt_encode_issuer: Optional[StrictStr] = None
+    authjwt_decode_issuer: Optional[StrictStr] = None
+    authjwt_decode_audience: Optional[Union[StrictStr, Sequence[StrictStr]]] = None
+    authjwt_denylist_enabled: Optional[StrictBool] = False
+    authjwt_denylist_token_checks: Optional[t.Set[StrictStr]] = {"access", "refresh"}
+    authjwt_header_name: Optional[StrictStr] = "Authorization"
+    authjwt_header_type: Optional[StrictStr] = "Bearer"
+    authjwt_access_token_expires: Optional[Union[StrictBool, StrictInt, timedelta]] = timedelta(minutes=15)
+    authjwt_refresh_token_expires: Optional[Union[StrictBool, StrictInt, timedelta]] = timedelta(days=30)
+    # options for cookie creation
+    authjwt_access_cookie_key: Optional[StrictStr] = "access_token_cookie"
+    authjwt_refresh_cookie_key: Optional[StrictStr] = "refresh_token_cookie"
+    authjwt_access_cookie_path: Optional[StrictStr] = "/"
+    authjwt_refresh_cookie_path: Optional[StrictStr] = "/"
+    authjwt_cookie_max_age: Optional[StrictInt] = None
+    authjwt_cookie_domain: Optional[StrictStr] = None
+    authjwt_cookie_secure: Optional[StrictBool] = False
+    authjwt_cookie_samesite: Optional[StrictStr] = None
+    # options for double submit CSRF protection
+    authjwt_cookie_csrf_protect: Optional[StrictBool] = True
+    authjwt_access_csrf_cookie_key: Optional[StrictStr] = "csrf_access_token"
+    authjwt_refresh_csrf_cookie_key: Optional[StrictStr] = "csrf_refresh_token"
+    authjwt_access_csrf_cookie_path: Optional[StrictStr] = "/"
+    authjwt_refresh_csrf_cookie_path: Optional[StrictStr] = "/"
+    authjwt_access_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token"
+    authjwt_refresh_csrf_header_name: Optional[StrictStr] = "X-CSRF-Token"
+    authjwt_csrf_methods: Optional[t.Set[StrictStr]] = {"POST", "PUT", "PATCH", "DELETE"}
+
+    @field_validator("authjwt_access_token_expires")
+    def validate_access_token_expires(
+        cls, v: Optional[Union[StrictBool, StrictInt, timedelta]]
+    ) -> Optional[Union[StrictBool, StrictInt, timedelta]]:
+        if v is True:
+            raise ValueError("The 'authjwt_access_token_expires' only accepts False as a boolean value")
+        return v
+
+    @field_validator("authjwt_refresh_token_expires")
+    def validate_refresh_token_expires(
+        cls, v: Optional[Union[StrictBool, StrictInt, timedelta]]
+    ) -> Optional[Union[StrictBool, StrictInt, timedelta]]:
+        if v is True:
+            raise ValueError("The 'authjwt_refresh_token_expires' only accepts False as a boolean value")
+        return v
+
+    @field_validator("authjwt_cookie_samesite")
+    def validate_cookie_samesite(cls, v: Optional[StrictStr]) -> Optional[StrictStr]:
+        if v not in ["strict", "lax", "none"]:
+            raise ValueError("The 'authjwt_cookie_samesite' must be one of 'strict', 'lax' or 'none'")
+        return v
+
+    @model_validator(mode="before")
+    def check_type_validity(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
+        # Raise ValueError (which pydantic wraps into a ValidationError) rather than
+        # constructing a ValidationError directly, which pydantic v2 does not allow.
+        for method in values.get("authjwt_csrf_methods") or []:
+            if method.upper() not in ["POST", "PUT", "PATCH", "DELETE"]:
+                raise ValueError(f"The 'authjwt_csrf_methods' must contain HTTP request methods, got {method.upper()}")
+
+        samesite = values.get("authjwt_cookie_samesite")
+        if samesite is not None and samesite not in ["strict", "lax", "none"]:
+            raise ValueError(f"The 'authjwt_cookie_samesite' must be one of 'strict', 'lax' or 'none', got {samesite}")
+
+        for location in values.get("authjwt_token_location") or []:
+            if location not in ["headers", "cookies"]:
+                raise ValueError(f"The 'authjwt_token_location' must be 'headers' or 'cookies', got {location}")
+
+        return values
+
+    class Config:
+        str_min_length = 1
+        str_strip_whitespace = True
diff --git a/antarest/fastapi_jwt_auth/exceptions.py b/antarest/fastapi_jwt_auth/exceptions.py
new file mode 100644
index 0000000000..1057571c0a
--- /dev/null
+++ b/antarest/fastapi_jwt_auth/exceptions.py
@@ -0,0 +1,89 @@
+class AuthJWTException(Exception):
+    """
+    Base exception which all fastapi_jwt_auth errors extend
+    """
+
+    pass
+
+
+class InvalidHeaderError(AuthJWTException):
+    """
+    An error getting the JWT or the JWT header information from a request
+    """
+
+    def __init__(self, status_code: int, message: str):
+        self.status_code = status_code
+        self.message = message
+
+
+class JWTDecodeError(AuthJWTException):
+    """
+    An error decoding a JWT
+    """
+
+    def __init__(self, status_code: int, message: str):
+        self.status_code = status_code
+        self.message = message
+
+
+class CSRFError(AuthJWTException):
+    """
+    An error with CSRF protection
+    """
+
+    def __init__(self, status_code: int, message: str):
+        self.status_code = status_code
+        self.message = message
+
+
+class MissingTokenError(AuthJWTException):
+    """
+    Error raised when no token can be found in the request
+    """
+
+    def __init__(self, status_code: int, message: str):
+        self.status_code = status_code
+        self.message = message
+
+
+class RevokedTokenError(AuthJWTException):
+    """
+    Error raised when a revoked token attempts to access a protected endpoint
+    """
+
+    def __init__(self, status_code: int, message: str):
+        self.status_code = status_code
+        self.message = message
+
+
+class AccessTokenRequired(AuthJWTException):
+    """
+    Error raised when a valid, non-access JWT attempts to access an endpoint
+    protected by jwt_required, jwt_optional or fresh_jwt_required
+    """
+
+    def __init__(self, status_code: int, message: str):
+        self.status_code = status_code
+        self.message = message
+
+
+class RefreshTokenRequired(AuthJWTException):
+    """
+    Error raised when a valid, non-refresh JWT attempts to access an endpoint
+    protected by jwt_refresh_token_required
+    """
+
+    def __init__(self, status_code: int, message: str):
+        self.status_code = status_code
+        self.message = message
+
+
+class FreshTokenRequired(AuthJWTException):
+    """
+    Error raised when a valid, non-fresh JWT attempts to access an endpoint
+    protected by fresh_jwt_required
+    """
+
+    def __init__(self, status_code: int, message: str):
+        self.status_code = status_code
+        self.message = message
diff --git a/antarest/front.py b/antarest/front.py
new file mode 100644
index 0000000000..8de0f05e82
--- /dev/null
+++ b/antarest/front.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2024, RTE (https://www.rte-france.com)
+#
+# See AUTHORS.txt
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+#
+# SPDX-License-Identifier: MPL-2.0
+#
+# This file is part of the Antares project.
+"""
+This module contains the logic necessary to serve both
+the front-end application and the backend HTTP application.
+
+This includes:
+  - serving static frontend files
+  - redirecting "not found" requests to home, which itself redirects to index.html
+  - providing the endpoint /config.json, which the front-end uses to know
+    the API and websocket prefixes
+"""
+
+import re
+from pathlib import Path
+from typing import Any, Optional, Sequence
+
+from fastapi import FastAPI
+from pydantic import BaseModel
+from starlette.middleware.base import BaseHTTPMiddleware, DispatchFunction, RequestResponseEndpoint
+from starlette.requests import Request
+from starlette.responses import FileResponse
+from starlette.staticfiles import StaticFiles
+from starlette.types import ASGIApp
+
+from antarest.core.utils.string import to_camel_case
+
+
+class RedirectMiddleware(BaseHTTPMiddleware):
+    """
+    Middleware that rewrites the URL path to "/" for incoming requests
+    that do not match the known end points. This is useful for redirecting requests
+    to the main page of a ReactJS application when the user refreshes the browser.
+    """
+
+    def __init__(
+        self,
+        app: ASGIApp,
+        dispatch: Optional[DispatchFunction] = None,
+        route_paths: Sequence[str] = (),
+    ) -> None:
+        """
+        Initializes an instance of the RedirectMiddleware.
+
+        Args:
+            app: The ASGI application to which the middleware is applied.
+            dispatch: The dispatch function to use.
+            route_paths: The known route paths of the application.
+                Requests that do not match any of these paths will be rewritten to the root path.
+
+        Note:
+            The `route_paths` should contain all the known endpoints of the application.
+        """
+        dispatch = self.dispatch if dispatch is None else dispatch
+        super().__init__(app, dispatch)
+        self.known_prefixes = {re.findall(r"/(?:(?!/).)*", p)[0] for p in route_paths if p != "/"}
+
+    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Any:
+        """
+        Intercepts the incoming request and rewrites the URL path if necessary.
+        Passes the modified or original request to the next middleware or endpoint handler.
+        """
+        url_path = request.scope["path"]
+        if url_path in {"", "/"}:
+            pass
+        elif not any(url_path.startswith(ep) for ep in self.known_prefixes):
+            request.scope["path"] = "/"
+        return await call_next(request)
+
+
+class BackEndConfig(BaseModel):
+    """
+    Configuration about backend URLs served to the frontend.
+    """
+
+    rest_endpoint: str
+    ws_endpoint: str
+
+    class Config:
+        populate_by_name = True
+        alias_generator = to_camel_case
+
+
+def create_backend_config(api_prefix: str) -> BackEndConfig:
+    if not api_prefix.startswith("/"):
+        api_prefix = "/" + api_prefix
+    return BackEndConfig(rest_endpoint=f"{api_prefix}", ws_endpoint=f"{api_prefix}/ws")
+
+
+def add_front_app(application: FastAPI, resources_dir: Path, api_prefix: str) -> None:
+    """
+    This function adds the logic necessary to serve both
+    the front-end application and the backend HTTP application.
+ + This includes: + - serving static frontend files + - redirecting "not found" requests to home, which itself redirects to index.html + - providing the endpoint /config.json, which the front-end uses to know + what are the API and websocket prefixes + """ + backend_config = create_backend_config(api_prefix) + + front_app_dir = resources_dir / "webapp" + + # Serve front-end files + application.mount( + "/static", + StaticFiles(directory=front_app_dir), + name="static", + ) + + # Redirect home to index.html + @application.get("/", include_in_schema=False) + def home(request: Request) -> Any: + return FileResponse(front_app_dir / "index.html", 200) + + # Serve config for the front-end at /config.json + @application.get("/config.json", include_in_schema=False) + def get_api_paths_config(request: Request) -> BackEndConfig: + return backend_config + + # When the web application is running in Desktop mode, the ReactJS web app + # is served at the `/static` entry point. Any requests that are not API + # requests should be redirected to the `index.html` file, which will handle + # the route provided by the URL. + route_paths = [r.path for r in application.routes] # type: ignore + application.add_middleware( + RedirectMiddleware, + route_paths=route_paths, + ) diff --git a/antarest/gui.py b/antarest/gui.py index af4c41d1e3..f36904a060 100644 --- a/antarest/gui.py +++ b/antarest/gui.py @@ -18,16 +18,8 @@ from multiprocessing import Process from pathlib import Path -try: - # `httpx` is a modern alternative to the `requests` library - import httpx as requests - from httpx import ConnectError as ConnectionError -except ImportError: - # noinspection PyUnresolvedReferences, PyPackageRequirements - import requests - from requests import ConnectionError - -import uvicorn # type: ignore +import httpx +import uvicorn from PyQt5.QtGui import QIcon from PyQt5.QtWidgets import QAction, QApplication, QMenu, QSystemTrayIcon @@ -102,8 +94,8 @@ def main() -> None: ) server.start() for _ in range(30, 0, -1): - with contextlib.suppress(ConnectionError): - res = requests.get("http://localhost:8080") + with contextlib.suppress(httpx.ConnectError): + res = httpx.get("http://localhost:8080") if res.status_code == 200: break time.sleep(1) diff --git a/antarest/launcher/adapters/abstractlauncher.py b/antarest/launcher/adapters/abstractlauncher.py index 64aed89fe0..4b845f8de0 100644 --- a/antarest/launcher/adapters/abstractlauncher.py +++ b/antarest/launcher/adapters/abstractlauncher.py @@ -100,7 +100,7 @@ def update_log(log_line: str) -> None: ) launch_progress_json = self.cache.get(id=f"Launch_Progress_{job_id}") or {} - launch_progress_dto = LaunchProgressDTO.parse_obj(launch_progress_json) + launch_progress_dto = LaunchProgressDTO.model_validate(launch_progress_json) if launch_progress_dto.parse_log_lines(log_line.splitlines()): self.event_bus.push( Event( @@ -114,6 +114,6 @@ def update_log(log_line: str) -> None: channel=EventChannelDirectory.JOB_STATUS + job_id, ) ) - self.cache.put(f"Launch_Progress_{job_id}", launch_progress_dto.dict()) + self.cache.put(f"Launch_Progress_{job_id}", launch_progress_dto.model_dump()) return update_log diff --git a/antarest/launcher/main.py b/antarest/launcher/main.py index 3370936f54..1916b9607b 100644 --- a/antarest/launcher/main.py +++ b/antarest/launcher/main.py @@ -12,8 +12,9 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from antarest.core.application import AppBuildContext from antarest.core.config import Config from 
antarest.core.filetransfer.service import FileTransferManager from antarest.core.interfaces.cache import ICache @@ -26,7 +27,7 @@ def build_launcher( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, study_service: StudyService, file_transfer_manager: FileTransferManager, @@ -49,7 +50,7 @@ def build_launcher( cache=cache, ) - if service_launcher and application: - application.include_router(create_launcher_api(service_launcher, config)) + if service_launcher and app_ctxt: + app_ctxt.api_root.include_router(create_launcher_api(service_launcher, config)) return service_launcher diff --git a/antarest/launcher/model.py b/antarest/launcher/model.py index 8b47d3a64f..9a330454dc 100644 --- a/antarest/launcher/model.py +++ b/antarest/launcher/model.py @@ -15,17 +15,17 @@ import typing as t from datetime import datetime -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, Sequence, String # type: ignore from sqlalchemy.orm import relationship # type: ignore from antarest.core.persistence import Base -from antarest.core.utils.string import to_camel_case from antarest.login.model import Identity, UserInfo +from antarest.study.business.all_optional_meta import camel_case_model class XpansionParametersDTO(BaseModel): - output_id: t.Optional[str] + output_id: t.Optional[str] = None sensitivity_mode: bool = False enabled: bool = True @@ -54,7 +54,7 @@ def from_launcher_params(cls, params: t.Optional[str]) -> "LauncherParametersDTO """ if params is None: return cls() - return cls.parse_obj(json.loads(params)) + return cls.model_validate(json.loads(params)) class LogType(str, enum.Enum): @@ -125,7 +125,7 @@ class JobResultDTO(BaseModel): class Config: @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: + def json_schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: schema["example"] = JobResultDTO( id="b2a9f6a7-7f8f-4f7a-9a8b-1f9b4c5d6e7f", study_id="b2a9f6a7-7f8f-4f7a-9a8b-1f9b4c5d6e7f", @@ -139,7 +139,7 @@ def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: exit_code=0, solver_stats="time: 1651s, call_count: 1, optimization_issues: []", owner=UserInfo(id=0o007, name="James BOND"), - ) + ).model_dump() class JobLog(Base): # type: ignore @@ -240,13 +240,8 @@ class LauncherEnginesDTO(BaseModel): engines: t.List[str] -class LauncherLoadDTO( - BaseModel, - extra="forbid", - validate_assignment=True, - allow_population_by_field_name=True, - alias_generator=to_camel_case, -): +@camel_case_model +class LauncherLoadDTO(BaseModel, extra="forbid", validate_assignment=True, populate_by_name=True): """ DTO representing the load of the SLURM cluster or local machine. 
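The @camel_case_model decorator applied to LauncherLoadDTO above replaces the inline alias_generator configuration used with pydantic v1. A hedged sketch of the behaviour such a decorator is expected to provide in pydantic v2 (the DTO name and field names below are illustrative):

```python
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel

# Sketch only: serialize fields under camelCase aliases while still
# accepting snake_case names on input, as camel_case_model is meant to do.
class LoadDTO(BaseModel):
    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    allocated_cpu_rate: float
    cluster_load_rate: float

dto = LoadDTO(allocated_cpu_rate=0.5, cluster_load_rate=0.8)
print(dto.model_dump(by_alias=True))
# -> {'allocatedCpuRate': 0.5, 'clusterLoadRate': 0.8}
```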
diff --git a/antarest/launcher/service.py b/antarest/launcher/service.py index 346b03c371..2dad1aef00 100644 --- a/antarest/launcher/service.py +++ b/antarest/launcher/service.py @@ -22,7 +22,7 @@ from fastapi import HTTPException -from antarest.core.config import Config, NbCoresConfig +from antarest.core.config import Config, Launcher, NbCoresConfig from antarest.core.exceptions import StudyNotFoundError from antarest.core.filetransfer.model import FileDownloadTaskDTO from antarest.core.filetransfer.service import FileTransferManager @@ -114,7 +114,7 @@ def _init_extensions(self) -> Dict[str, ILauncherExtension]: def get_launchers(self) -> List[str]: return list(self.launchers.keys()) - def get_nb_cores(self, launcher: str) -> NbCoresConfig: + def get_nb_cores(self, launcher: Launcher) -> NbCoresConfig: """ Retrieve the configuration of the launcher's nb of cores. @@ -123,9 +123,6 @@ def get_nb_cores(self, launcher: str) -> NbCoresConfig: Returns: Number of cores of the launcher - - Raises: - InvalidConfigurationError: if the launcher configuration is not available """ return self.config.launcher.get_nb_cores(launcher) @@ -186,7 +183,7 @@ def update( self.event_bus.push( Event( type=EventType.STUDY_JOB_COMPLETED if final_status else EventType.STUDY_JOB_STATUS_UPDATE, - payload=job_result.to_dto().dict(), + payload=job_result.to_dto().model_dump(), permissions=PermissionInfo(public_mode=PublicMode.READ), channel=EventChannelDirectory.JOB_STATUS + job_result.id, ) @@ -247,7 +244,7 @@ def run_study( study_id=study_uuid, job_status=JobStatus.PENDING, launcher=launcher, - launcher_params=launcher_parameters.json() if launcher_parameters else None, + launcher_params=launcher_parameters.model_dump_json() if launcher_parameters else None, owner_id=(owner_id or None), ) self.job_result_repository.save(job_status) @@ -263,7 +260,7 @@ def run_study( self.event_bus.push( Event( type=EventType.STUDY_JOB_STARTED, - payload=job_status.to_dto().dict(), + payload=job_status.to_dto().model_dump(), permissions=PermissionInfo.from_study(study_info), ) ) @@ -304,7 +301,7 @@ def kill_job(self, job_id: str, params: RequestParameters) -> JobResult: self.event_bus.push( Event( type=EventType.STUDY_JOB_CANCELLED, - payload=job_status.to_dto().dict(), + payload=job_status.to_dto().model_dump(), permissions=PermissionInfo.from_study(study), channel=EventChannelDirectory.JOB_STATUS + job_result.id, ) @@ -718,5 +715,7 @@ def get_launch_progress(self, job_id: str, params: RequestParameters) -> float: if launcher is None: raise ValueError(f"Job {job_id} has no launcher") - launch_progress_json = self.launchers[launcher].cache.get(id=f"Launch_Progress_{job_id}") or {"progress": 0} + launch_progress_json: Dict[str, float] = self.launchers[launcher].cache.get(id=f"Launch_Progress_{job_id}") or { + "progress": 0 + } return launch_progress_json.get("progress", 0) diff --git a/antarest/launcher/ssh_client.py b/antarest/launcher/ssh_client.py index 38e7334b75..175c8a739b 100644 --- a/antarest/launcher/ssh_client.py +++ b/antarest/launcher/ssh_client.py @@ -43,7 +43,7 @@ class SlurmError(Exception): def execute_command(ssh_config: SSHConfigDTO, args: List[str]) -> Any: command = " ".join(args) try: - with ssh_client(ssh_config) as client: # type: ignore + with ssh_client(ssh_config) as client: # type: paramiko.SSHClient _, stdout, stderr = client.exec_command(command, timeout=10) output = stdout.read().decode("utf-8").strip() error = stderr.read().decode("utf-8").strip() diff --git a/antarest/launcher/ssh_config.py 
b/antarest/launcher/ssh_config.py
index f35d4d0da0..5238e07608 100644
--- a/antarest/launcher/ssh_config.py
+++ b/antarest/launcher/ssh_config.py
@@ -14,7 +14,7 @@
 from typing import Any, Dict, Optional
 
 import paramiko
-from pydantic import BaseModel, root_validator
+from pydantic import BaseModel, model_validator
 
 
 class SSHConfigDTO(BaseModel):
@@ -26,7 +26,7 @@ class SSHConfigDTO(BaseModel):
     key_password: Optional[str] = ""
     password: Optional[str] = ""
 
-    @root_validator()
+    @model_validator(mode="before")
     def validate_connection_information(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         if "private_key_file" not in values and "password" not in values:
             raise paramiko.AuthenticationException("SSH config needs at least a private key or a password")
diff --git a/antarest/launcher/web.py b/antarest/launcher/web.py
index 88bceaad35..c07261384a 100644
--- a/antarest/launcher/web.py
+++ b/antarest/launcher/web.py
@@ -18,7 +18,7 @@
 from fastapi import APIRouter, Depends, Query
 from fastapi.exceptions import HTTPException
 
-from antarest.core.config import Config, InvalidConfigurationError
+from antarest.core.config import Config, InvalidConfigurationError, Launcher
 from antarest.core.filetransfer.model import FileDownloadTaskDTO
 from antarest.core.jwt import JWTUser
 from antarest.core.requests import RequestParameters
@@ -54,8 +54,8 @@ def __init__(self, solver: str) -> None:
 
 
 LauncherQuery = Query(
-    "default",
-    examples={
+    default=Launcher.DEFAULT,
+    openapi_examples={
         "Default launcher": {
             "description": "Default solver (auto-detected)",
             "value": "default",
@@ -245,7 +245,7 @@
         summary="Get list of supported solver versions",
         response_model=List[str],
     )
-    def get_solver_versions(solver: str = LauncherQuery) -> List[str]:
+    def get_solver_versions(solver: Launcher = Launcher.DEFAULT) -> List[str]:
         """
         Get list of supported solver versions defined in the configuration.
 
@@ -253,8 +253,6 @@
         - `solver`: name of the configuration to read: "default", "slurm" or "local".
         """
         logger.info(f"Fetching the list of solver versions for the '{solver}' configuration")
-        if solver not in {"default", "slurm", "local"}:
-            raise UnknownSolverConfig(solver)
         return service.get_solver_versions(solver)
 
     # noinspection SpellCheckingInspection
@@ -264,7 +262,7 @@
         summary="Retrieving Min, Default, and Max Core Count",
         response_model=Dict[str, int],
     )
-    def get_nb_cores(launcher: str = LauncherQuery) -> Dict[str, int]:
+    def get_nb_cores(launcher: Launcher = Launcher.DEFAULT) -> Dict[str, int]:
         """
         Retrieve the number of cores of the launcher.
 
@@ -289,7 +287,7 @@
         tags=[APITag.launcher],
         summary="Retrieve the time limit for a job (in hours)",
     )
-    def get_time_limit(launcher: str = LauncherQuery) -> Dict[str, int]:
+    def get_time_limit(launcher: Launcher = LauncherQuery) -> Dict[str, int]:
         """
         Retrieve the time limit for a job (in hours)
         of the given launcher: "local" or "slurm".
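Typing the solver/launcher query parameters with the Launcher enum lets FastAPI reject unknown values with a 422 on its own, which is why the manual UnknownSolverConfig check could be dropped above. A self-contained sketch of the pattern, with a simplified stand-in for antarest.core.config.Launcher and an illustrative route:

```python
import enum
from typing import List

from fastapi import FastAPI

class Launcher(str, enum.Enum):
    # Simplified stand-in for antarest.core.config.Launcher
    DEFAULT = "default"
    SLURM = "slurm"
    LOCAL = "local"

app = FastAPI()

@app.get("/v1/launcher/versions")
def get_solver_versions(solver: Launcher = Launcher.DEFAULT) -> List[str]:
    # FastAPI has already validated `solver` here: requesting
    # /v1/launcher/versions?solver=foo returns a 422 without this
    # handler ever running.
    return [solver.value]
```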
diff --git a/antarest/login/auth.py b/antarest/login/auth.py index 6cd2a3139f..ad4884111a 100644 --- a/antarest/login/auth.py +++ b/antarest/login/auth.py @@ -16,13 +16,13 @@ from typing import Any, Callable, Coroutine, Dict, Optional, Tuple, Union from fastapi import Depends -from fastapi_jwt_auth import AuthJWT # type: ignore from pydantic import BaseModel from ratelimit.types import Scope # type: ignore from starlette.requests import Request from antarest.core.config import Config from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTUser +from antarest.fastapi_jwt_auth import AuthJWT logger = logging.getLogger(__name__) @@ -66,14 +66,14 @@ def get_current_user(self, auth_jwt: AuthJWT = Depends()) -> JWTUser: auth_jwt.jwt_required() - user = JWTUser.parse_obj(json.loads(auth_jwt.get_jwt_subject())) + user = JWTUser.model_validate(json.loads(auth_jwt.get_jwt_subject())) return user @staticmethod def get_user_from_token(token: str, jwt_manager: AuthJWT) -> Optional[JWTUser]: try: token_data = jwt_manager._verified_token(token) - return JWTUser.parse_obj(json.loads(token_data["sub"])) + return JWTUser.model_validate(json.loads(token_data["sub"])) except Exception as e: logger.debug("Failed to retrieve user from token", exc_info=e) return None diff --git a/antarest/login/ldap.py b/antarest/login/ldap.py index 636e2a564f..1635efe09d 100644 --- a/antarest/login/ldap.py +++ b/antarest/login/ldap.py @@ -14,12 +14,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional -try: - # `httpx` is a modern alternative to the `requests` library - import httpx as requests -except ImportError: - # noinspection PyUnresolvedReferences, PyPackageRequirements - import requests +import httpx from antarest.core.config import Config from antarest.core.model import JSON @@ -110,7 +105,7 @@ def _fetch(self, name: str, password: str) -> Optional[ExternalUser]: auth = AuthDTO(user=name, password=password) try: - res = requests.post(url=f"{self.url}/auth", json=auth.to_json()) + res = httpx.post(url=f"{self.url}/auth", json=auth.to_json()) except Exception as e: logger.warning( "Failed to retrieve user from external auth service", diff --git a/antarest/login/main.py b/antarest/login/main.py index fe1acea550..09fac37ee5 100644 --- a/antarest/login/main.py +++ b/antarest/login/main.py @@ -14,15 +14,16 @@ from http import HTTPStatus from typing import Any, Optional -from fastapi import FastAPI -from fastapi_jwt_auth import AuthJWT # type: ignore -from fastapi_jwt_auth.exceptions import AuthJWTException # type: ignore +from fastapi import APIRouter, FastAPI from starlette.requests import Request from starlette.responses import JSONResponse +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.interfaces.eventbus import DummyEventBusService, IEventBus from antarest.core.utils.fastapi_sqlalchemy import db +from antarest.fastapi_jwt_auth import AuthJWT +from antarest.fastapi_jwt_auth.exceptions import AuthJWTException from antarest.login.ldap import LdapService from antarest.login.repository import BotRepository, GroupRepository, RoleRepository, UserLdapRepository, UserRepository from antarest.login.service import LoginService @@ -30,7 +31,7 @@ def build_login( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, service: Optional[LoginService] = None, event_bus: IEventBus = DummyEventBusService(), @@ -39,7 +40,7 @@ def build_login( Login module linking dependency Args: - application: flask application + 
app_ctxt: application config: server configuration service: used by testing to inject mock. Let None to use true instantiation event_bus: used by testing to inject mock. Let None to use true instantiation @@ -66,9 +67,9 @@ def build_login( event_bus=event_bus, ) - if application: + if app_ctxt: - @application.exception_handler(AuthJWTException) + @app_ctxt.app.exception_handler(AuthJWTException) def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> Any: return JSONResponse( status_code=HTTPStatus.UNAUTHORIZED, @@ -83,6 +84,6 @@ def check_if_token_is_revoked(decrypted_token: Any) -> bool: with db(): return token_type == "bots" and service is not None and not service.exists_bot(user_id) - if application: - application.include_router(create_login_api(service, config)) + if app_ctxt: + app_ctxt.api_root.include_router(create_login_api(service, config)) return service diff --git a/antarest/login/service.py b/antarest/login/service.py index 8ae4cbdb55..d3103a6dfc 100644 --- a/antarest/login/service.py +++ b/antarest/login/service.py @@ -485,7 +485,7 @@ def authenticate(self, name: str, pwd: str) -> Optional[JWTUser]: """ intern: Optional[User] = self.users.get_by_name(name) - if intern and intern.password.check(pwd): # type: ignore + if intern and intern.password.check(pwd): logger.info("successful login from intern user %s", name) return self.get_jwt(intern.id) diff --git a/antarest/login/web.py b/antarest/login/web.py index 8f98012529..6e2b20a19c 100644 --- a/antarest/login/web.py +++ b/antarest/login/web.py @@ -16,7 +16,6 @@ from typing import Any, List, Optional, Union from fastapi import APIRouter, Depends, HTTPException -from fastapi_jwt_auth import AuthJWT # type: ignore from markupsafe import escape from pydantic import BaseModel @@ -25,6 +24,7 @@ from antarest.core.requests import RequestParameters, UserHasNotPermissionError from antarest.core.roles import RoleType from antarest.core.utils.web import APITag +from antarest.fastapi_jwt_auth import AuthJWT from antarest.login.auth import Auth from antarest.login.model import ( BotCreateDTO, @@ -67,8 +67,8 @@ def create_login_api(service: LoginService, config: Config) -> APIRouter: auth = Auth(config) def generate_tokens(user: JWTUser, jwt_manager: AuthJWT, expire: Optional[timedelta] = None) -> CredentialsDTO: - access_token = jwt_manager.create_access_token(subject=user.json(), expires_time=expire) - refresh_token = jwt_manager.create_refresh_token(subject=user.json()) + access_token = jwt_manager.create_access_token(subject=user.model_dump_json(), expires_time=expire) + refresh_token = jwt_manager.create_refresh_token(subject=user.model_dump_json()) return CredentialsDTO( user=user.id, access_token=access_token.decode() if isinstance(access_token, bytes) else access_token, @@ -126,11 +126,7 @@ def users_get_all( params = RequestParameters(user=current_user) return service.get_all_users(params, details) - @bp.get( - "/users/{id}", - tags=[APITag.users], - response_model=Union[IdentityDTO, UserInfo], # type: ignore - ) + @bp.get("/users/{id}", tags=[APITag.users], response_model=Union[IdentityDTO, UserInfo]) def users_get_id( id: int, details: bool = False, @@ -204,11 +200,7 @@ def groups_get_all( params = RequestParameters(user=current_user) return service.get_all_groups(params, details) - @bp.get( - "/groups/{id}", - tags=[APITag.users], - response_model=Union[GroupDetailDTO, GroupDTO], # type: ignore - ) + @bp.get("/groups/{id}", tags=[APITag.users], response_model=Union[GroupDetailDTO, GroupDTO]) def groups_get_id( 
id: str, details: bool = False, @@ -326,11 +318,7 @@ def bots_create( tokens = generate_tokens(jwt, jwt_manager, expire=timedelta(days=368 * 200)) return tokens.access_token - @bp.get( - "/bots/{id}", - tags=[APITag.users], - response_model=Union[BotIdentityDTO, BotDTO], # type: ignore - ) + @bp.get("/bots/{id}", tags=[APITag.users], response_model=Union[BotIdentityDTO, BotDTO]) def get_bot( id: int, verbose: Optional[int] = None, diff --git a/antarest/main.py b/antarest/main.py index c5e4d30197..b1fd1480b6 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -13,29 +13,25 @@ import argparse import copy import logging -import re +from contextlib import asynccontextmanager from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Tuple, cast +from typing import Any, AsyncGenerator, Dict, Optional, Tuple, cast import pydantic -import uvicorn # type: ignore -import uvicorn.config # type: ignore -from fastapi import FastAPI, HTTPException +import uvicorn +import uvicorn.config +from fastapi import APIRouter, FastAPI, HTTPException from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError -from fastapi_jwt_auth import AuthJWT # type: ignore from ratelimit import RateLimitMiddleware # type: ignore from ratelimit.backends.redis import RedisBackend # type: ignore from ratelimit.backends.simple import MemoryBackend # type: ignore -from starlette.middleware.base import BaseHTTPMiddleware, DispatchFunction, RequestResponseEndpoint from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import JSONResponse -from starlette.staticfiles import StaticFiles -from starlette.templating import Jinja2Templates -from starlette.types import ASGIApp from antarest import __version__ +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.core_blueprint import create_utils_routes from antarest.core.filesystem_blueprint import create_file_system_blueprint @@ -46,6 +42,8 @@ from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware from antarest.core.utils.utils import get_local_path from antarest.core.utils.web import tags_metadata +from antarest.fastapi_jwt_auth import AuthJWT +from antarest.front import add_front_app from antarest.login.auth import Auth, JwtSettings from antarest.login.model import init_admin_user from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector @@ -189,55 +187,6 @@ def parse_arguments() -> argparse.Namespace: return parser.parse_args() -class URLRewriterMiddleware(BaseHTTPMiddleware): - """ - Middleware that rewrites the URL path to "/" (root path) for incoming requests - that do not match the known end points. This is useful for redirecting requests - to the main page of a ReactJS application when the user refreshes the browser. - """ - - def __init__( - self, - app: ASGIApp, - dispatch: Optional[DispatchFunction] = None, - root_path: str = "", - route_paths: Sequence[str] = (), - ) -> None: - """ - Initializes an instance of the URLRewriterMiddleware. - - Args: - app: The ASGI application to which the middleware is applied. - dispatch: The dispatch function to use. - root_path: The root path of the application. - The URL path will be rewritten relative to this root path. - route_paths: The known route paths of the application. - Requests that do not match any of these paths will be rewritten to the root path. 
-
-        Note:
-            The `root_path` can be set to a specific component of the URL path, such as "api".
-            The `route_paths` should contain all the known endpoints of the application.
-        """
-        dispatch = self.dispatch if dispatch is None else dispatch
-        super().__init__(app, dispatch)
-        self.root_path = f"/{root_path}" if root_path else ""
-        self.known_prefixes = {re.findall(r"/(?:(?!/).)*", p)[0] for p in route_paths if p != "/"}
-
-    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Any:
-        """
-        Intercepts the incoming request and rewrites the URL path if necessary.
-        Passes the modified or original request to the next middleware or endpoint handler.
-        """
-        url_path = request.scope["path"]
-        if url_path in {"", "/"}:
-            pass
-        elif self.root_path and url_path.startswith(self.root_path):
-            request.scope["path"] = url_path[len(self.root_path) :]
-        elif not any(url_path.startswith(ep) for ep in self.known_prefixes):
-            request.scope["path"] = "/"
-        return await call_next(request)
-
-
 def fastapi_app(
     config_file: Path,
     resource_path: Optional[Path] = None,
@@ -250,46 +199,39 @@ def fastapi_app(
 
     logger.info("Initiating application")
 
+    @asynccontextmanager
+    async def set_default_executor(app: FastAPI) -> AsyncGenerator[None, None]:
+        import asyncio
+        from concurrent.futures import ThreadPoolExecutor
+
+        loop = asyncio.get_running_loop()
+        loop.set_default_executor(ThreadPoolExecutor(max_workers=config.server.worker_threadpool_size))
+        yield
+
     application = FastAPI(
         title="AntaREST",
         version=__version__,
         docs_url=None,
         root_path=config.root_path,
         openapi_tags=tags_metadata,
+        lifespan=set_default_executor,
+        openapi_url=f"{config.api_prefix}/openapi.json",
     )
+    api_root = APIRouter(prefix=config.api_prefix)
+
+    app_ctxt = AppBuildContext(application, api_root)
+
     # Database
     engine = init_db_engine(config_file, config, auto_upgrade_db)
     application.add_middleware(DBSessionMiddleware, custom_engine=engine, session_args=SESSION_ARGS)
+    # Since Starlette version 0.24.0, the middlewares are lazily built inside this function.
+    # But we need to instantiate this middleware eagerly because the study service needs it,
+    # so we manually instantiate it here.
+ DBSessionMiddleware(None, custom_engine=engine, session_args=cast(Dict[str, bool], SESSION_ARGS)) application.add_middleware(LoggingMiddleware) - if mount_front: - application.mount( - "/static", - StaticFiles(directory=str(res / "webapp")), - name="static", - ) - templates = Jinja2Templates(directory=str(res / "templates")) - - @application.get("/", include_in_schema=False) - def home(request: Request) -> Any: - return templates.TemplateResponse("index.html", {"request": request}) - - else: - # noinspection PyUnusedLocal - @application.get("/", include_in_schema=False) - def home(request: Request) -> Any: - return "" - - @application.on_event("startup") - def set_default_executor() -> None: - import asyncio - from concurrent.futures import ThreadPoolExecutor - - loop = asyncio.get_running_loop() - loop.set_default_executor(ThreadPoolExecutor(max_workers=config.server.worker_threadpool_size)) - # TODO move that elsewhere @AuthJWT.load_config # type: ignore def get_config() -> JwtSettings: @@ -308,8 +250,8 @@ def get_config() -> JwtSettings: allow_methods=["*"], allow_headers=["*"], ) - application.include_router(create_utils_routes(config)) - application.include_router(create_file_system_blueprint(config)) + api_root.include_router(create_utils_routes(config)) + api_root.include_router(create_file_system_blueprint(config)) # noinspection PyUnusedLocal @application.exception_handler(HTTPException) @@ -419,19 +361,9 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: ) init_admin_user(engine=engine, session_args=SESSION_ARGS, admin_password=config.security.admin_pwd) - services = create_services(config, application) + services = create_services(config, app_ctxt) - if mount_front: - # When the web application is running in Desktop mode, the ReactJS web app - # is served at the `/static` entry point. Any requests that are not API - # requests should be redirected to the `index.html` file, which will handle - # the route provided by the URL. 
- route_paths = [r.path for r in application.routes] # type: ignore - application.add_middleware( - URLRewriterMiddleware, - root_path=application.root_path, - route_paths=route_paths, - ) + application.include_router(api_root) if config.server.services and Module.WATCHER.value in config.server.services: watcher = cast(Watcher, services["watcher"]) @@ -446,6 +378,15 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: auto_archiver.start() customize_openapi(application) + + if mount_front: + add_front_app(application, res, config.api_prefix) + else: + # noinspection PyUnusedLocal + @application.get("/", include_in_schema=False) + def home(request: Request) -> Any: + return "" + cancel_orphan_tasks(engine=engine, session_args=SESSION_ARGS) return application, services diff --git a/antarest/matrixstore/main.py b/antarest/matrixstore/main.py index eaec4f956a..d8eaf0390a 100644 --- a/antarest/matrixstore/main.py +++ b/antarest/matrixstore/main.py @@ -12,8 +12,9 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.filetransfer.service import FileTransferManager from antarest.core.tasks.service import ITaskService @@ -24,7 +25,7 @@ def build_matrix_service( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, file_transfer_manager: FileTransferManager, task_service: ITaskService, @@ -35,7 +36,7 @@ def build_matrix_service( Matrix module linking dependency Args: - application: flask application + app_ctxt: application config: server configuration file_transfer_manager: File transfer manager task_service: Task manager @@ -60,7 +61,7 @@ def build_matrix_service( config=config, ) - if application: - application.include_router(create_matrix_api(service, file_transfer_manager, config)) + if app_ctxt: + app_ctxt.api_root.include_router(create_matrix_api(service, file_transfer_manager, config)) return service diff --git a/antarest/matrixstore/matrix_editor.py b/antarest/matrixstore/matrix_editor.py index 6fb55225ee..838af83860 100644 --- a/antarest/matrixstore/matrix_editor.py +++ b/antarest/matrixstore/matrix_editor.py @@ -14,7 +14,7 @@ import operator from typing import Any, Dict, List, Optional, Tuple -from pydantic import BaseModel, Extra, Field, root_validator, validator +from pydantic import BaseModel, Field, field_validator, model_validator class MatrixSlice(BaseModel): @@ -35,8 +35,8 @@ class MatrixSlice(BaseModel): column_to: int class Config: - extra = Extra.forbid - schema_extra = { + extra = "forbid" + json_schema_extra = { "example": { "column_from": 5, "column_to": 8, @@ -45,7 +45,7 @@ class Config: } } - @root_validator(pre=True) + @model_validator(mode="before") def check_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ Converts and validates the slice coordinates. @@ -107,12 +107,12 @@ class Operation(BaseModel): - `value`: The value associated with the operation. 
""" - operation: str = Field(regex=r"[=/*+-]|ABS") + operation: str = Field(pattern=r"[=/*+-]|ABS") value: float class Config: - extra = Extra.forbid - schema_extra = {"example": {"operation": "=", "value": 120.0}} + extra = "forbid" + json_schema_extra = {"example": {"operation": "=", "value": 120.0}} # noinspection SpellCheckingInspection def compute(self, x: Any, use_coords: bool = False) -> Any: @@ -157,16 +157,17 @@ class MatrixEditInstruction(BaseModel): operation: Operation class Config: - extra = Extra.forbid - - @staticmethod - def schema_extra(schema: Dict[str, Any]) -> None: - schema["example"] = MatrixEditInstruction( - coordinates=[(0, 10), (0, 11), (0, 12)], - operation=Operation(operation="=", value=120.0), - ) + extra = "forbid" + json_schema_extra = { + "example": { + "column_from": 5, + "column_to": 8, + "row_from": 0, + "row_to": 8760, + } + } - @root_validator(pre=True) + @model_validator(mode="before") def check_slice_coordinates(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ Validates the 'slices' and 'coordinates' fields. @@ -191,7 +192,7 @@ def check_slice_coordinates(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values - @validator("coordinates") + @field_validator("coordinates") def validate_coordinates(cls, coordinates: Optional[List[Tuple[int, int]]]) -> Optional[List[Tuple[int, int]]]: """ Validates the `coordinates` field. diff --git a/antarest/matrixstore/service.py b/antarest/matrixstore/service.py index 492acb046a..7f9b83cca2 100644 --- a/antarest/matrixstore/service.py +++ b/antarest/matrixstore/service.py @@ -225,6 +225,7 @@ def create_by_importation(self, file: UploadFile, is_json: bool = False) -> t.Li A list of `MatrixInfoDTO` objects containing the SHA256 hash of the imported matrices. """ with file.file as f: + assert file.filename is not None if file.content_type == "application/zip": with contextlib.closing(f): buffer = io.BytesIO(f.read()) diff --git a/antarest/singleton_services.py b/antarest/singleton_services.py index 039c3d00b5..13395a439a 100644 --- a/antarest/singleton_services.py +++ b/antarest/singleton_services.py @@ -56,13 +56,13 @@ def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IServi services: Dict[Module, IService] = {} if Module.WATCHER in services_list: - watcher = create_watcher(config=config, application=None, study_service=study_service) + watcher = create_watcher(config=config, app_ctxt=None, study_service=study_service) services[Module.WATCHER] = watcher if Module.MATRIX_GC in services_list: matrix_gc = create_matrix_gc( config=config, - application=None, + app_ctxt=None, study_service=study_service, matrix_service=matrix_service, ) diff --git a/antarest/study/business/adequacy_patch_management.py b/antarest/study/business/adequacy_patch_management.py index d05b066ef1..ddc2214891 100644 --- a/antarest/study/business/adequacy_patch_management.py +++ b/antarest/study/business/adequacy_patch_management.py @@ -10,10 +10,11 @@ # # This file is part of the Antares project. 
-from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from pydantic.types import StrictBool, confloat, conint +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study @@ -29,18 +30,19 @@ class PriceTakingOrder(EnumIgnoreCase): ThresholdType = confloat(ge=0) +@all_optional_model class AdequacyPatchFormFields(FormFieldsBaseModel): # version 830 - enable_adequacy_patch: Optional[StrictBool] - ntc_from_physical_areas_out_to_physical_areas_in_adequacy_patch: Optional[StrictBool] - ntc_between_physical_areas_out_adequacy_patch: Optional[StrictBool] + enable_adequacy_patch: StrictBool + ntc_from_physical_areas_out_to_physical_areas_in_adequacy_patch: StrictBool + ntc_between_physical_areas_out_adequacy_patch: StrictBool # version 850 - price_taking_order: Optional[PriceTakingOrder] - include_hurdle_cost_csr: Optional[StrictBool] - check_csr_cost_function: Optional[StrictBool] - threshold_initiate_curtailment_sharing_rule: Optional[ThresholdType] # type: ignore - threshold_display_local_matching_rule_violations: Optional[ThresholdType] # type: ignore - threshold_csr_variable_bounds_relaxation: Optional[conint(ge=0, strict=True)] # type: ignore + price_taking_order: PriceTakingOrder + include_hurdle_cost_csr: StrictBool + check_csr_cost_function: StrictBool + threshold_initiate_curtailment_sharing_rule: ThresholdType # type: ignore + threshold_display_local_matching_rule_violations: ThresholdType # type: ignore + threshold_csr_variable_bounds_relaxation: conint(ge=0, strict=True) # type: ignore ADEQUACY_PATCH_PATH = f"{GENERAL_DATA_PATH}/adequacy patch" diff --git a/antarest/study/business/advanced_parameters_management.py b/antarest/study/business/advanced_parameters_management.py index eed0420d72..3e68ff0aa1 100644 --- a/antarest/study/business/advanced_parameters_management.py +++ b/antarest/study/business/advanced_parameters_management.py @@ -10,12 +10,13 @@ # # This file is part of the Antares project. 
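The `@all_optional_model` decorator applied to the form-field classes above (and defined later in this patch, in all_optional_meta.py) turns every declared field into an optional one with a `None` default, so partial form submissions validate without errors. Illustration on a reduced model (names are made up, not the real `AdequacyPatchFormFields`):

    from pydantic import BaseModel
    from pydantic.types import StrictBool

    from antarest.study.business.all_optional_meta import all_optional_model

    @all_optional_model
    class ExampleFields(BaseModel):
        enable_adequacy_patch: StrictBool

    ExampleFields()                             # valid: every field defaults to None
    ExampleFields(enable_adequacy_patch=True)   # valid: partial payloads are accepted
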
-from typing import Any, Dict, List, Optional +from typing import Any, Dict, List -from pydantic import validator +from pydantic import field_validator from pydantic.types import StrictInt, StrictStr from antarest.core.exceptions import InvalidFieldForVersionError +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study @@ -72,33 +73,34 @@ class RenewableGenerationModeling(EnumIgnoreCase): CLUSTERS = "clusters" +@all_optional_model class AdvancedParamsFormFields(FormFieldsBaseModel): # Advanced parameters - accuracy_on_correlation: Optional[StrictStr] + accuracy_on_correlation: StrictStr # Other preferences - initial_reservoir_levels: Optional[InitialReservoirLevel] - power_fluctuations: Optional[PowerFluctuation] - shedding_policy: Optional[SheddingPolicy] - hydro_pricing_mode: Optional[HydroPricingMode] - hydro_heuristic_policy: Optional[HydroHeuristicPolicy] - unit_commitment_mode: Optional[UnitCommitmentMode] - number_of_cores_mode: Optional[SimulationCore] - day_ahead_reserve_management: Optional[ReserveManagement] - renewable_generation_modelling: Optional[RenewableGenerationModeling] + initial_reservoir_levels: InitialReservoirLevel + power_fluctuations: PowerFluctuation + shedding_policy: SheddingPolicy + hydro_pricing_mode: HydroPricingMode + hydro_heuristic_policy: HydroHeuristicPolicy + unit_commitment_mode: UnitCommitmentMode + number_of_cores_mode: SimulationCore + day_ahead_reserve_management: ReserveManagement + renewable_generation_modelling: RenewableGenerationModeling # Seeds - seed_tsgen_wind: Optional[StrictInt] - seed_tsgen_load: Optional[StrictInt] - seed_tsgen_hydro: Optional[StrictInt] - seed_tsgen_thermal: Optional[StrictInt] - seed_tsgen_solar: Optional[StrictInt] - seed_tsnumbers: Optional[StrictInt] - seed_unsupplied_energy_costs: Optional[StrictInt] - seed_spilled_energy_costs: Optional[StrictInt] - seed_thermal_costs: Optional[StrictInt] - seed_hydro_costs: Optional[StrictInt] - seed_initial_reservoir_levels: Optional[StrictInt] - - @validator("accuracy_on_correlation") + seed_tsgen_wind: StrictInt + seed_tsgen_load: StrictInt + seed_tsgen_hydro: StrictInt + seed_tsgen_thermal: StrictInt + seed_tsgen_solar: StrictInt + seed_tsnumbers: StrictInt + seed_unsupplied_energy_costs: StrictInt + seed_spilled_energy_costs: StrictInt + seed_thermal_costs: StrictInt + seed_hydro_costs: StrictInt + seed_initial_reservoir_levels: StrictInt + + @field_validator("accuracy_on_correlation") def check_accuracy_on_correlation(cls, v: str) -> str: sanitized_v = v.strip().replace(" ", "") if not sanitized_v: diff --git a/antarest/study/business/all_optional_meta.py b/antarest/study/business/all_optional_meta.py index 6e23e82e85..7a8440226b 100644 --- a/antarest/study/business/all_optional_meta.py +++ b/antarest/study/business/all_optional_meta.py @@ -10,87 +10,37 @@ # # This file is part of the Antares project. +import copy import typing as t -import pydantic.fields -import pydantic.main -from pydantic import BaseModel +from pydantic import BaseModel, create_model from antarest.core.utils.string import to_camel_case +ModelClass = t.TypeVar("ModelClass", bound=BaseModel) -class AllOptionalMetaclass(pydantic.main.ModelMetaclass): - """ - Metaclass that makes all fields of a Pydantic model optional. 
- - Usage: - class MyModel(BaseModel, metaclass=AllOptionalMetaclass): - field1: str - field2: int - ... - Instances of the model can be created even if not all fields are provided during initialization. - Default values, when provided, are used unless `use_none` is set to `True`. +def all_optional_model(model: t.Type[ModelClass]) -> t.Type[ModelClass]: """ + This decorator can be used to make all fields of a pydantic model optional. - def __new__( - cls: t.Type["AllOptionalMetaclass"], - name: str, - bases: t.Tuple[t.Type[t.Any], ...], - namespaces: t.Dict[str, t.Any], - use_none: bool = False, - **kwargs: t.Dict[str, t.Any], - ) -> t.Any: - """ - Create a new instance of the metaclass. - - Args: - name: Name of the class to create. - bases: Base classes of the class to create (a Pydantic model). - namespaces: namespace of the class to create that defines the fields of the model. - use_none: If `True`, the default value of the fields is set to `None`. - Note that this field is not part of the Pydantic model, but it is an extension. - **kwargs: Additional keyword arguments used by the metaclass. - """ - # Modify the annotations of the class (but not of the ancestor classes) - # in order to make all fields optional. - # If the current model inherits from another model, the annotations of the ancestor models - # are not modified, because the fields are already converted to `ModelField`. - annotations = namespaces.get("__annotations__", {}) - for field_name, field_type in annotations.items(): - if not field_name.startswith("__"): - # Making already optional fields optional is not a problem (nothing is changed). - annotations[field_name] = t.Optional[field_type] - namespaces["__annotations__"] = annotations - - if use_none: - # Modify the namespace fields to set their default value to `None`. - for field_name, field_info in namespaces.items(): - if isinstance(field_info, pydantic.fields.FieldInfo): - field_info.default = None - field_info.default_factory = None - - # Create the class: all annotations are converted into `ModelField`. - instance = super().__new__(cls, name, bases, namespaces, **kwargs) - - # Modify the inherited fields of the class to make them optional - # and set their default value to `None`. - model_field: pydantic.fields.ModelField - for field_name, model_field in instance.__fields__.items(): - model_field.required = False - model_field.allow_none = True - if use_none: - model_field.default = None - model_field.default_factory = None - model_field.field_info.default = None - - return instance + Args: + model: The pydantic model to modify. + Returns: + The modified model. + """ + kwargs = {} + for field_name, field_info in model.model_fields.items(): + new = copy.deepcopy(field_info) + new.default = None + new.annotation = t.Optional[field_info.annotation] # type: ignore + kwargs[field_name] = (new.annotation, new) -MODEL = t.TypeVar("MODEL", bound=t.Type[BaseModel]) + return create_model(f"Partial{model.__name__}", __base__=model, __module__=model.__module__, **kwargs) # type: ignore -def camel_case_model(model: MODEL) -> MODEL: +def camel_case_model(model: t.Type[BaseModel]) -> t.Type[BaseModel]: """ This decorator can be used to modify a model to use camel case aliases. @@ -100,7 +50,14 @@ Returns: The modified model. 
""" - model.__config__.alias_generator = to_camel_case - for field_name, field in model.__fields__.items(): - field.alias = to_camel_case(field_name) + model.model_config["alias_generator"] = to_camel_case + + # Manually overriding already defined alias names (in base classes), + # otherwise they have precedence over generated ones. + # TODO There is probably a better way to handle those cases + for field_name, field in model.model_fields.items(): + new_alias = to_camel_case(field_name) + field.alias = new_alias + field.validation_alias = new_alias + field.serialization_alias = new_alias return model diff --git a/antarest/study/business/allocation_management.py b/antarest/study/business/allocation_management.py index 7576a9ab63..79c22f26cd 100644 --- a/antarest/study/business/allocation_management.py +++ b/antarest/study/business/allocation_management.py @@ -10,11 +10,13 @@ # # This file is part of the Antares project. -from typing import Dict, List +from typing import Dict, List, Union import numpy import numpy as np -from pydantic import conlist, root_validator, validator +from annotated_types import Len +from pydantic import ValidationInfo, field_validator, model_validator +from typing_extensions import Annotated from antarest.core.exceptions import AllocationDataNotFound, AreaNotFound from antarest.study.business.area_management import AreaInfoDTO @@ -36,9 +38,9 @@ class AllocationFormFields(FormFieldsBaseModel): allocation: List[AllocationField] - @root_validator - def check_allocation(cls, values: Dict[str, List[AllocationField]]) -> Dict[str, List[AllocationField]]: - allocation = values.get("allocation", []) + @model_validator(mode="after") + def check_allocation(self) -> "AllocationFormFields": + allocation = self.allocation if not allocation: raise ValueError("allocation must not be empty") @@ -56,7 +58,7 @@ def check_allocation(cls, values: Dict[str, List[AllocationField]]) -> Dict[str, if sum(a.coefficient for a in allocation) <= 0: raise ValueError("sum of allocation coefficients must be positive") - return values + return self class AllocationMatrix(FormFieldsBaseModel): @@ -67,14 +69,14 @@ class AllocationMatrix(FormFieldsBaseModel): data: 2D-array matrix of consumption coefficients """ - index: conlist(str, min_items=1) # type: ignore - columns: conlist(str, min_items=1) # type: ignore + index: Annotated[List[str], Len(min_length=1)] + columns: Annotated[List[str], Len(min_length=1)] data: List[List[float]] # NonNegativeFloat not necessary # noinspection PyMethodParameters - @validator("data") + @field_validator("data", mode="before") def validate_hydro_allocation_matrix( - cls, data: List[List[float]], values: Dict[str, List[str]] + cls, data: List[List[float]], values: Union[Dict[str, List[str]], ValidationInfo] ) -> List[List[float]]: """ Validate the hydraulic allocation matrix. 
@@ -89,8 +91,9 @@ def validate_hydro_allocation_matrix( """ array = np.array(data) - rows = len(values.get("index", [])) - cols = len(values.get("columns", [])) + new_values = values if isinstance(values, dict) else values.data + rows = len(new_values.get("index", [])) + cols = len(new_values.get("columns", [])) if array.size == 0: raise ValueError("allocation matrix must not be empty") @@ -136,7 +139,7 @@ def get_allocation_data(self, study: Study, area_id: str) -> Dict[str, List[Allo if not allocation_data: raise AllocationDataNotFound(area_id) - return allocation_data.get("[allocation]", {}) + return allocation_data.get("[allocation]", {}) # type: ignore def get_allocation_form_fields( self, all_areas: List[AreaInfoDTO], study: Study, area_id: str @@ -160,13 +163,10 @@ def get_allocation_form_fields( allocations = self.get_allocation_data(study, area_id) filtered_allocations = {area: value for area, value in allocations.items() if area in areas_ids} - - return AllocationFormFields.construct( - allocation=[ - AllocationField.construct(area_id=area, coefficient=value) - for area, value in filtered_allocations.items() - ] - ) + final_allocations = [ + AllocationField.construct(area_id=area, coefficient=value) for area, value in filtered_allocations.items() + ] + return AllocationFormFields.model_validate({"allocation": final_allocations}) def set_allocation_form_fields( self, diff --git a/antarest/study/business/area_management.py b/antarest/study/business/area_management.py index 4e6297e349..e79f2df575 100644 --- a/antarest/study/business/area_management.py +++ b/antarest/study/business/area_management.py @@ -15,11 +15,11 @@ import re import typing as t -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel, Field from antarest.core.exceptions import ConfigFileNotFound, DuplicateAreaName, LayerNotAllowedToBeDeleted, LayerNotFound from antarest.core.model import JSON -from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Patch, PatchArea, PatchCluster, RawStudy, Study from antarest.study.repository import StudyMetadataRepository @@ -50,8 +50,8 @@ class AreaType(enum.Enum): class AreaCreationDTO(BaseModel): name: str type: AreaType - metadata: t.Optional[PatchArea] - set: t.Optional[t.List[str]] + metadata: t.Optional[PatchArea] = None + set: t.Optional[t.List[str]] = None # review: is this class necessary? @@ -82,7 +82,7 @@ class LayerInfoDTO(BaseModel): areas: t.List[str] -class UpdateAreaUi(BaseModel, extra="forbid", allow_population_by_field_name=True): +class UpdateAreaUi(BaseModel, extra="forbid", populate_by_name=True): """ DTO for updating area UI @@ -107,7 +107,7 @@ class UpdateAreaUi(BaseModel, extra="forbid", allow_population_by_field_name=Tru ... } >>> model = UpdateAreaUi(**obj) - >>> pprint(model.dict(by_alias=True), width=80) + >>> pprint(model.model_dump(by_alias=True), width=80) {'colorRgb': [230, 108, 44], 'layerColor': {0: '230, 108, 44', 4: '230, 108, 44', @@ -179,9 +179,9 @@ class _BaseAreaDTO( OptimizationProperties.FilteringSection, OptimizationProperties.ModalOptimizationSection, AdequacyPathProperties.AdequacyPathSection, - extra=Extra.forbid, + extra="forbid", validate_assignment=True, - allow_population_by_field_name=True, + populate_by_name=True, ): """ Represents an area output. 
@@ -200,8 +200,9 @@ class _BaseAreaDTO( # noinspection SpellCheckingInspection +@all_optional_model @camel_case_model -class AreaOutput(_BaseAreaDTO, metaclass=AllOptionalMetaclass, use_none=True): +class AreaOutput(_BaseAreaDTO): """ DTO object used to get the area information using a flat structure. """ @@ -227,30 +228,32 @@ def from_model( obj = { "average_unsupplied_energy_cost": average_unsupplied_energy_cost, "average_spilled_energy_cost": average_spilled_energy_cost, - **area_folder.optimization.filtering.dict(by_alias=False), - **area_folder.optimization.nodal_optimization.dict(by_alias=False), + **area_folder.optimization.filtering.model_dump(by_alias=False), + **area_folder.optimization.nodal_optimization.model_dump(by_alias=False), # adequacy_patch is only available if study version >= 830. - **(area_folder.adequacy_patch.adequacy_patch.dict(by_alias=False) if area_folder.adequacy_patch else {}), + **( + area_folder.adequacy_patch.adequacy_patch.model_dump(by_alias=False) + if area_folder.adequacy_patch + else {} + ), } return cls(**obj) def _to_optimization(self) -> OptimizationProperties: - obj = {name: getattr(self, name) for name in OptimizationProperties.FilteringSection.__fields__} + obj = {name: getattr(self, name) for name in OptimizationProperties.FilteringSection.model_fields} filtering_section = OptimizationProperties.FilteringSection(**obj) - obj = {name: getattr(self, name) for name in OptimizationProperties.ModalOptimizationSection.__fields__} + obj = {name: getattr(self, name) for name in OptimizationProperties.ModalOptimizationSection.model_fields} nodal_optimization_section = OptimizationProperties.ModalOptimizationSection(**obj) - return OptimizationProperties( - filtering=filtering_section, - nodal_optimization=nodal_optimization_section, - ) + args = {"filtering": filtering_section, "nodal_optimization": nodal_optimization_section} + return OptimizationProperties.model_validate(args) def _to_adequacy_patch(self) -> t.Optional[AdequacyPathProperties]: - obj = {name: getattr(self, name) for name in AdequacyPathProperties.AdequacyPathSection.__fields__} + obj = {name: getattr(self, name) for name in AdequacyPathProperties.AdequacyPathSection.model_fields} # If all fields are `None`, the object is empty. if all(value is None for value in obj.values()): return None adequacy_path_section = AdequacyPathProperties.AdequacyPathSection(**obj) - return AdequacyPathProperties(adequacy_patch=adequacy_path_section) + return AdequacyPathProperties.model_validate({"adequacy_patch": adequacy_path_section}) @property def area_folder(self) -> AreaFolder: @@ -359,7 +362,7 @@ def update_areas_props( for area_id, update_area in update_areas_by_ids.items(): # Update the area properties. old_area = old_areas_by_ids[area_id] - new_area = old_area.copy(update=update_area.dict(by_alias=False, exclude_none=True)) + new_area = old_area.copy(update=update_area.model_dump(by_alias=False, exclude_none=True)) new_areas_by_ids[area_id] = new_area # Convert the DTO to a configuration object and update the configuration file. 
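The alias handling that `camel_case_model` now sets up explicitly (alias, validation_alias and serialization_alias) amounts to the behaviour sketched below: the same model must accept camelCase payloads from the frontend and dump camelCase back out. Toy model for illustration, using the project's `to_camel_case` as the generator:

    from typing import List
    from pydantic import BaseModel, ConfigDict
    from antarest.core.utils.string import to_camel_case

    class AreaUiExample(BaseModel):
        model_config = ConfigDict(alias_generator=to_camel_case, populate_by_name=True)
        color_rgb: List[int]

    # camelCase in, camelCase out:
    ui = AreaUiExample.model_validate({"colorRgb": [230, 108, 44]})
    assert ui.model_dump(by_alias=True) == {"colorRgb": [230, 108, 44]}
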
@@ -740,7 +743,7 @@ def update_thermal_cluster_metadata( id=area_id, name=file_study.config.areas[area_id].name, type=AreaType.AREA, - metadata=patch.areas.get(area_id, PatchArea()).dict(), + metadata=patch.areas.get(area_id, PatchArea()), thermals=self._get_clusters(file_study, area_id, patch), set=None, ) diff --git a/antarest/study/business/areas/hydro_management.py b/antarest/study/business/areas/hydro_management.py index b7376dae10..d464336d93 100644 --- a/antarest/study/business/areas/hydro_management.py +++ b/antarest/study/business/areas/hydro_management.py @@ -10,10 +10,11 @@ # # This file is part of the Antares project. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union -from pydantic import Field +from pydantic import Field, model_validator +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.utils import FieldInfo, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study from antarest.study.storage.storage_service import StudyStorageService @@ -36,22 +37,23 @@ class InflowStructure(FormFieldsBaseModel): ) +@all_optional_model class ManagementOptionsFormFields(FormFieldsBaseModel): - inter_daily_breakdown: Optional[float] = Field(ge=0) - intra_daily_modulation: Optional[float] = Field(ge=1) - inter_monthly_breakdown: Optional[float] = Field(ge=0) - reservoir: Optional[bool] - reservoir_capacity: Optional[float] = Field(ge=0) - follow_load: Optional[bool] - use_water: Optional[bool] - hard_bounds: Optional[bool] - initialize_reservoir_date: Optional[int] = Field(ge=0, le=11) - use_heuristic: Optional[bool] - power_to_level: Optional[bool] - use_leeway: Optional[bool] - leeway_low: Optional[float] = Field(ge=0) - leeway_up: Optional[float] = Field(ge=0) - pumping_efficiency: Optional[float] = Field(ge=0) + inter_daily_breakdown: float = Field(ge=0) + intra_daily_modulation: float = Field(ge=1) + inter_monthly_breakdown: float = Field(ge=0) + reservoir: bool + reservoir_capacity: float = Field(ge=0) + follow_load: bool + use_water: bool + hard_bounds: bool + initialize_reservoir_date: int = Field(ge=0, le=11) + use_heuristic: bool + power_to_level: bool + use_leeway: bool + leeway_low: float = Field(ge=0) + leeway_up: float = Field(ge=0) + pumping_efficiency: float = Field(ge=0) HYDRO_PATH = "input/hydro/hydro" diff --git a/antarest/study/business/areas/properties_management.py b/antarest/study/business/areas/properties_management.py index f4abd5b1bb..ac3dbb902a 100644 --- a/antarest/study/business/areas/properties_management.py +++ b/antarest/study/business/areas/properties_management.py @@ -14,9 +14,10 @@ import typing as t from builtins import sorted -from pydantic import root_validator +from pydantic import model_validator from antarest.core.exceptions import ChildNotFoundError +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.utils import FieldInfo, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study from antarest.study.storage.rawstudy.model.filesystem.config.area import AdequacyPatchMode @@ -49,18 +50,19 @@ def decode_filter(encoded_value: t.Set[str], current_filter: t.Optional[str] = N return ", ".join(sort_filter_options(encoded_value)) +@all_optional_model class PropertiesFormFields(FormFieldsBaseModel): - energy_cost_unsupplied: t.Optional[float] - energy_cost_spilled: t.Optional[float] - non_dispatch_power: t.Optional[bool] - dispatch_hydro_power: t.Optional[bool] 
- other_dispatch_power: t.Optional[bool] - filter_synthesis: t.Optional[t.Set[str]] - filter_by_year: t.Optional[t.Set[str]] + energy_cost_unsupplied: float + energy_cost_spilled: float + non_dispatch_power: bool + dispatch_hydro_power: bool + other_dispatch_power: bool + filter_synthesis: t.Set[str] + filter_by_year: t.Set[str] # version 830 - adequacy_patch_mode: t.Optional[AdequacyPatchMode] + adequacy_patch_mode: AdequacyPatchMode - @root_validator + @model_validator(mode="before") def validation(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: filters = { "filter_synthesis": values.get("filter_synthesis"), @@ -143,7 +145,7 @@ def get_value(field_info: FieldInfo) -> t.Any: encode = field_info.get("encode") or (lambda x: x) return encode(val) - return PropertiesFormFields.construct(**{name: get_value(info) for name, info in FIELDS_INFO.items()}) + return PropertiesFormFields.model_construct(**{name: get_value(info) for name, info in FIELDS_INFO.items()}) def set_field_values( self, diff --git a/antarest/study/business/areas/renewable_management.py b/antarest/study/business/areas/renewable_management.py index 5a9c2d47bb..12af8abdb5 100644 --- a/antarest/study/business/areas/renewable_management.py +++ b/antarest/study/business/areas/renewable_management.py @@ -11,14 +11,13 @@ # This file is part of the Antares project. import collections -import json import typing as t -from pydantic import validator +from pydantic import field_validator from antarest.core.exceptions import DuplicateRenewableCluster, RenewableClusterConfigNotFound, RenewableClusterNotFound from antarest.core.model import JSON -from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Study @@ -46,23 +45,26 @@ class TimeSeriesInterpretation(EnumIgnoreCase): PRODUCTION_FACTOR = "production-factor" +@all_optional_model @camel_case_model -class RenewableClusterInput(RenewableProperties, metaclass=AllOptionalMetaclass, use_none=True): +class RenewableClusterInput(RenewableProperties): """ Model representing the data structure required to edit an existing renewable cluster. """ class Config: + populate_by_name = True + @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: + def json_schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: schema["example"] = RenewableClusterInput( group="Gas", name="Gas Cluster XY", enabled=False, - unitCount=100, - nominalCapacity=1000.0, - tsInterpretation="power-generation", - ) + unit_count=100, + nominal_capacity=1000.0, + ts_interpretation="power-generation", + ).model_dump() class RenewableClusterCreation(RenewableClusterInput): @@ -71,7 +73,7 @@ class RenewableClusterCreation(RenewableClusterInput): """ # noinspection Pydantic - @validator("name", pre=True) + @field_validator("name", mode="before") def validate_name(cls, name: t.Optional[str]) -> str: """ Validator to check if the name is not empty. 
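`PropertiesFormFields.model_construct(...)` above is the v2 spelling of `construct`: it builds an instance without running any validation, which is appropriate when the values have just been read from an already-validated study tree. Contrast with `model_validate` on a toy model:

    from typing import Optional
    from pydantic import BaseModel, ValidationError

    class FieldsExample(BaseModel):
        energy_cost_unsupplied: Optional[float] = None

    try:
        FieldsExample.model_validate({"energy_cost_unsupplied": "not-a-number"})
    except ValidationError:
        pass  # model_validate parses and rejects bad payloads

    # model_construct skips parsing and validation entirely; values are trusted as-is.
    fields = FieldsExample.model_construct(energy_cost_unsupplied=0.5)
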
@@ -81,28 +83,29 @@ def validate_name(cls, name: t.Optional[str]) -> str: return name def to_config(self, study_version: t.Union[str, int]) -> RenewableConfigType: - values = self.dict(by_alias=False, exclude_none=True) + values = self.model_dump(by_alias=False, exclude_none=True) return create_renewable_config(study_version=study_version, **values) +@all_optional_model @camel_case_model -class RenewableClusterOutput(RenewableConfig, metaclass=AllOptionalMetaclass, use_none=True): +class RenewableClusterOutput(RenewableConfig): """ Model representing the output data structure to display the details of a renewable cluster. """ class Config: @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: + def json_schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: schema["example"] = RenewableClusterOutput( id="Gas cluster YZ", group="Gas", name="Gas Cluster YZ", enabled=False, - unitCount=100, - nominalCapacity=1000.0, - tsInterpretation="power-generation", - ) + unit_count=100, + nominal_capacity=1000.0, + ts_interpretation="power-generation", + ).model_dump() def create_renewable_output( @@ -111,7 +114,7 @@ def create_renewable_output( config: t.Mapping[str, t.Any], ) -> "RenewableClusterOutput": obj = create_renewable_config(study_version=study_version, **config, id=cluster_id) - kwargs = obj.dict(by_alias=False) + kwargs = obj.model_dump(by_alias=False) return RenewableClusterOutput(**kwargs) @@ -218,7 +221,7 @@ def _make_create_cluster_cmd(self, area_id: str, cluster: RenewableConfigType) - command = CreateRenewablesCluster( area_id=area_id, cluster_name=cluster.id, - parameters=cluster.dict(by_alias=True, exclude={"id"}), + parameters=cluster.model_dump(mode="json", by_alias=True, exclude={"id"}), command_context=self.storage_service.variant_study_service.command_factory.command_context, ) return command @@ -281,16 +284,16 @@ def update_cluster( old_config = create_renewable_config(study_version, **values) # use Python values to synchronize Config and Form values - new_values = cluster_data.dict(by_alias=False, exclude_none=True) + new_values = cluster_data.model_dump(by_alias=False, exclude_none=True) new_config = old_config.copy(exclude={"id"}, update=new_values) - new_data = json.loads(new_config.json(by_alias=True, exclude={"id"})) + new_data = new_config.model_dump(mode="json", by_alias=True, exclude={"id"}) # create the dict containing the new values using aliases - data: t.Dict[str, t.Any] = { - field.alias: new_data[field.alias] - for field_name, field in new_config.__fields__.items() - if field_name in new_values - } + data: t.Dict[str, t.Any] = {} + for field_name, field in new_config.model_fields.items(): + if field_name in new_values: + name = field.alias if field.alias else field_name + data[name] = new_data[name] # create the update config commands with the modified data command_context = self.storage_service.variant_study_service.command_factory.command_context @@ -300,7 +303,7 @@ def update_cluster( ] execute_or_add_commands(study, file_study, commands, self.storage_service) - values = new_config.dict(by_alias=False) + values = new_config.model_dump(by_alias=False) return RenewableClusterOutput(**values, id=cluster_id) def delete_clusters(self, study: Study, area_id: str, cluster_ids: t.Sequence[str]) -> None: @@ -352,7 +355,7 @@ def duplicate_cluster( # Cluster duplication current_cluster = self.get_cluster(study, area_id, source_id) current_cluster.name = new_cluster_name - creation_form = 
RenewableClusterCreation(**current_cluster.dict(by_alias=False, exclude={"id"})) + creation_form = RenewableClusterCreation(**current_cluster.model_dump(by_alias=False, exclude={"id"})) new_config = creation_form.to_config(study.version) create_cluster_cmd = self._make_create_cluster_cmd(area_id, new_config) @@ -370,7 +373,7 @@ def duplicate_cluster( execute_or_add_commands(study, self._get_file_study(study), commands, self.storage_service) - return RenewableClusterOutput(**new_config.dict(by_alias=False)) + return RenewableClusterOutput(**new_config.model_dump(by_alias=False)) def update_renewables_props( self, @@ -387,17 +390,17 @@ def update_renewables_props( for renewable_id, update_cluster in update_renewables_by_ids.items(): # Update the renewable cluster properties. old_cluster = old_renewables_by_ids[renewable_id] - new_cluster = old_cluster.copy(update=update_cluster.dict(by_alias=False, exclude_none=True)) + new_cluster = old_cluster.copy(update=update_cluster.model_dump(by_alias=False, exclude_none=True)) new_renewables_by_areas[area_id][renewable_id] = new_cluster # Convert the DTO to a configuration object and update the configuration file. properties = create_renewable_config( - study.version, **new_cluster.dict(by_alias=False, exclude_none=True) + study.version, **new_cluster.model_dump(by_alias=False, exclude_none=True) ) path = _CLUSTER_PATH.format(area_id=area_id, cluster_id=renewable_id) cmd = UpdateConfig( target=path, - data=json.loads(properties.json(by_alias=True, exclude={"id"})), + data=properties.model_dump(mode="json", by_alias=True, exclude={"id"}), command_context=self.storage_service.variant_study_service.command_factory.command_context, ) commands.append(cmd) diff --git a/antarest/study/business/areas/st_storage_management.py b/antarest/study/business/areas/st_storage_management.py index ee64362d8b..dcb5db5dc8 100644 --- a/antarest/study/business/areas/st_storage_management.py +++ b/antarest/study/business/areas/st_storage_management.py @@ -11,14 +11,11 @@ # This file is part of the Antares project. import collections -import functools -import json import operator import typing as t import numpy as np -from pydantic import BaseModel, Extra, root_validator, validator -from requests.structures import CaseInsensitiveDict +from pydantic import BaseModel, field_validator, model_validator from typing_extensions import Literal from antarest.core.exceptions import ( @@ -30,7 +27,8 @@ STStorageNotFound, ) from antarest.core.model import JSON -from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model +from antarest.core.requests import CaseInsensitiveDict +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Study from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id @@ -49,15 +47,16 @@ from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig +@all_optional_model @camel_case_model -class STStorageInput(STStorage880Properties, metaclass=AllOptionalMetaclass, use_none=True): +class STStorageInput(STStorage880Properties): """ Model representing the form used to EDIT an existing short-term storage. 
""" class Config: @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: + def json_schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: schema["example"] = STStorageInput( name="Siemens Battery", group=STStorageGroup.BATTERY, @@ -67,7 +66,7 @@ def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: efficiency=0.94, initial_level=0.5, initial_level_optim=True, - ) + ).model_dump() class STStorageCreation(STStorageInput): @@ -76,7 +75,7 @@ class STStorageCreation(STStorageInput): """ # noinspection Pydantic - @validator("name", pre=True) + @field_validator("name", mode="before") def validate_name(cls, name: t.Optional[str]) -> str: """ Validator to check if the name is not empty. @@ -87,19 +86,20 @@ def validate_name(cls, name: t.Optional[str]) -> str: # noinspection PyUnusedLocal def to_config(self, study_version: t.Union[str, int]) -> STStorageConfigType: - values = self.dict(by_alias=False, exclude_none=True) + values = self.model_dump(by_alias=False, exclude_none=True) return create_st_storage_config(study_version=study_version, **values) +@all_optional_model @camel_case_model -class STStorageOutput(STStorage880Config, metaclass=AllOptionalMetaclass, use_none=True): +class STStorageOutput(STStorage880Config): """ Model representing the form used to display the details of a short-term storage entry. """ class Config: @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: + def json_schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: schema["example"] = STStorageOutput( id="siemens_battery", name="Siemens Battery", @@ -109,7 +109,7 @@ def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: reservoir_capacity=600, efficiency=0.94, initial_level_optim=True, - ) + ).model_dump() # ============= @@ -131,13 +131,13 @@ class STStorageMatrix(BaseModel): """ class Config: - extra = Extra.forbid + extra = "forbid" data: t.List[t.List[float]] index: t.List[int] columns: t.List[int] - @validator("data") + @field_validator("data") def validate_time_series(cls, data: t.List[t.List[float]]) -> t.List[t.List[float]]: """ Validator to check the integrity of the time series data. @@ -172,7 +172,7 @@ class STStorageMatrices(BaseModel): """ class Config: - extra = Extra.forbid + extra = "forbid" pmax_injection: STStorageMatrix pmax_withdrawal: STStorageMatrix @@ -180,7 +180,7 @@ class Config: upper_rule_curve: STStorageMatrix inflows: STStorageMatrix - @validator( + @field_validator( "pmax_injection", "pmax_withdrawal", "lower_rule_curve", @@ -195,23 +195,18 @@ def validate_time_series(cls, matrix: STStorageMatrix) -> STStorageMatrix: raise ValueError("Matrix values should be between 0 and 1") return matrix - @root_validator() - def validate_rule_curve( - cls, values: t.MutableMapping[str, STStorageMatrix] - ) -> t.MutableMapping[str, STStorageMatrix]: + @model_validator(mode="after") + def validate_rule_curve(self) -> "STStorageMatrices": """ Validator to ensure 'lower_rule_curve' values are less than or equal to 'upper_rule_curve' values. 
""" - if "lower_rule_curve" in values and "upper_rule_curve" in values: - lower_rule_curve = values["lower_rule_curve"] - upper_rule_curve = values["upper_rule_curve"] - lower_array = np.array(lower_rule_curve.data, dtype=np.float64) - upper_array = np.array(upper_rule_curve.data, dtype=np.float64) - # noinspection PyUnresolvedReferences - if (lower_array > upper_array).any(): - raise ValueError("Each 'lower_rule_curve' value must be lower or equal to each 'upper_rule_curve'") - return values + lower_array = np.array(self.lower_rule_curve.data, dtype=np.float64) + upper_array = np.array(self.upper_rule_curve.data, dtype=np.float64) + if (lower_array > upper_array).any(): + raise ValueError("Each 'lower_rule_curve' value must be lower or equal to each 'upper_rule_curve'") + + return self # noinspection SpellCheckingInspection @@ -249,7 +244,7 @@ def create_storage_output( config: t.Mapping[str, t.Any], ) -> "STStorageOutput": obj = create_st_storage_config(study_version=study_version, **config, id=cluster_id) - kwargs = obj.dict(by_alias=False) + kwargs = obj.model_dump(by_alias=False) return STStorageOutput(**kwargs) @@ -393,17 +388,17 @@ def update_storages_props( for storage_id, update_cluster in update_storages_by_ids.items(): # Update the storage cluster properties. old_cluster = old_storages_by_ids[storage_id] - new_cluster = old_cluster.copy(update=update_cluster.dict(by_alias=False, exclude_none=True)) + new_cluster = old_cluster.copy(update=update_cluster.model_dump(by_alias=False, exclude_none=True)) new_storages_by_areas[area_id][storage_id] = new_cluster # Convert the DTO to a configuration object and update the configuration file. properties = create_st_storage_config( - study.version, **new_cluster.dict(by_alias=False, exclude_none=True) + study.version, **new_cluster.model_dump(by_alias=False, exclude_none=True) ) path = _STORAGE_LIST_PATH.format(area_id=area_id, storage_id=storage_id) cmd = UpdateConfig( target=path, - data=json.loads(properties.json(by_alias=True, exclude={"id"})), + data=properties.model_dump(mode="json", by_alias=True, exclude={"id"}), command_context=self.storage_service.variant_study_service.command_factory.command_context, ) commands.append(cmd) @@ -472,16 +467,16 @@ def update_storage( old_config = create_st_storage_config(study_version, **values) # use Python values to synchronize Config and Form values - new_values = form.dict(by_alias=False, exclude_none=True) + new_values = form.model_dump(by_alias=False, exclude_none=True) new_config = old_config.copy(exclude={"id"}, update=new_values) - new_data = json.loads(new_config.json(by_alias=True, exclude={"id"})) + new_data = new_config.model_dump(mode="json", by_alias=True, exclude={"id"}) # create the dict containing the new values using aliases - data: t.Dict[str, t.Any] = { - field.alias: new_data[field.alias] - for field_name, field in new_config.__fields__.items() - if field_name in new_values - } + data: t.Dict[str, t.Any] = {} + for field_name, field in new_config.model_fields.items(): + if field_name in new_values: + name = field.alias if field.alias else field_name + data[name] = new_data[name] # create the update config commands with the modified data command_context = self.storage_service.variant_study_service.command_factory.command_context @@ -492,7 +487,7 @@ def update_storage( ] execute_or_add_commands(study, file_study, commands, self.storage_service) - values = new_config.dict(by_alias=False) + values = new_config.model_dump(by_alias=False) return STStorageOutput(**values, 
id=storage_id) def delete_storages( @@ -554,7 +549,7 @@ def duplicate_cluster(self, study: Study, area_id: str, source_id: str, new_clus # We should remove the field 'enabled' for studies before v8.8 as it didn't exist if int(study.version) < 880: fields_to_exclude.add("enabled") - creation_form = STStorageCreation(**current_cluster.dict(by_alias=False, exclude=fields_to_exclude)) + creation_form = STStorageCreation(**current_cluster.model_dump(by_alias=False, exclude=fields_to_exclude)) new_config = creation_form.to_config(study.version) create_cluster_cmd = self._make_create_cluster_cmd(area_id, new_config) @@ -583,7 +578,7 @@ def duplicate_cluster(self, study: Study, area_id: str, source_id: str, new_clus execute_or_add_commands(study, self._get_file_study(study), commands, self.storage_service) - return STStorageOutput(**new_config.dict(by_alias=False)) + return STStorageOutput(**new_config.model_dump(by_alias=False)) def get_matrix( self, @@ -684,17 +679,18 @@ def validate_matrices( Returns: bool: True if validation is successful. """ - # Create a partial function to retrieve matrix objects - get_matrix_obj = functools.partial(self._get_matrix_obj, study, area_id, storage_id) + + def validate_matrix(matrix_type: STStorageTimeSeries) -> STStorageMatrix: + return STStorageMatrix.model_validate(self._get_matrix_obj(study, area_id, storage_id, matrix_type)) # Validate matrices by constructing the `STStorageMatrices` object # noinspection SpellCheckingInspection STStorageMatrices( - pmax_injection=get_matrix_obj("pmax_injection"), - pmax_withdrawal=get_matrix_obj("pmax_withdrawal"), - lower_rule_curve=get_matrix_obj("lower_rule_curve"), - upper_rule_curve=get_matrix_obj("upper_rule_curve"), - inflows=get_matrix_obj("inflows"), + pmax_injection=validate_matrix("pmax_injection"), + pmax_withdrawal=validate_matrix("pmax_withdrawal"), + lower_rule_curve=validate_matrix("lower_rule_curve"), + upper_rule_curve=validate_matrix("upper_rule_curve"), + inflows=validate_matrix("inflows"), ) # Validation successful diff --git a/antarest/study/business/areas/thermal_management.py b/antarest/study/business/areas/thermal_management.py index f37bc757f1..c1a06068aa 100644 --- a/antarest/study/business/areas/thermal_management.py +++ b/antarest/study/business/areas/thermal_management.py @@ -11,11 +11,10 @@ # This file is part of the Antares project. import collections -import json import typing as t from pathlib import Path -from pydantic import validator +from pydantic import field_validator from antarest.core.exceptions import ( DuplicateThermalCluster, @@ -25,7 +24,7 @@ WrongMatrixHeightError, ) from antarest.core.model import JSON -from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import Study from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id @@ -54,33 +53,35 @@ _ALL_CLUSTERS_PATH = "input/thermal/clusters" +@all_optional_model @camel_case_model -class ThermalClusterInput(Thermal870Properties, metaclass=AllOptionalMetaclass, use_none=True): +class ThermalClusterInput(Thermal870Properties): """ Model representing the data structure required to edit an existing thermal cluster within a study. 
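The repeated `model_dump(mode="json", by_alias=True, exclude={"id"})` calls in these files replace the pydantic v1 round-trip `json.loads(obj.json(...))`: v2 can produce JSON-compatible Python values directly. Equivalence on a toy model:

    import json
    from pydantic import BaseModel

    class PropsExample(BaseModel):
        nominal_capacity: float = 1000.0

    # v1 idiom removed by this patch:   json.loads(props.json(by_alias=True))
    data = PropsExample().model_dump(mode="json", by_alias=True)
    assert data == json.loads(PropsExample().model_dump_json(by_alias=True))
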
""" class Config: @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: + def json_schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: schema["example"] = ThermalClusterInput( group="Gas", name="Gas Cluster XY", enabled=False, - unitCount=100, - nominalCapacity=1000.0, - genTs="use global", + unit_count=100, + nominal_capacity=1000.0, + gen_ts="use global", co2=7.0, - ) + ).model_dump() +@camel_case_model class ThermalClusterCreation(ThermalClusterInput): """ Model representing the data structure required to create a new thermal cluster within a study. """ # noinspection Pydantic - @validator("name", pre=True) + @field_validator("name", mode="before") def validate_name(cls, name: t.Optional[str]) -> str: """ Validator to check if the name is not empty. @@ -90,29 +91,30 @@ def validate_name(cls, name: t.Optional[str]) -> str: return name def to_config(self, study_version: t.Union[str, int]) -> ThermalConfigType: - values = self.dict(by_alias=False, exclude_none=True) + values = self.model_dump(by_alias=False, exclude_none=True) return create_thermal_config(study_version=study_version, **values) +@all_optional_model @camel_case_model -class ThermalClusterOutput(Thermal870Config, metaclass=AllOptionalMetaclass, use_none=True): +class ThermalClusterOutput(Thermal870Config): """ Model representing the output data structure to display the details of a thermal cluster within a study. """ class Config: @staticmethod - def schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: + def json_schema_extra(schema: t.MutableMapping[str, t.Any]) -> None: schema["example"] = ThermalClusterOutput( id="Gas cluster YZ", group="Gas", name="Gas Cluster YZ", enabled=False, - unitCount=100, - nominalCapacity=1000.0, - genTs="use global", + unit_count=100, + nominal_capacity=1000.0, + gen_ts="use global", co2=7.0, - ) + ).model_dump() def create_thermal_output( @@ -121,7 +123,7 @@ def create_thermal_output( config: t.Mapping[str, t.Any], ) -> "ThermalClusterOutput": obj = create_thermal_config(study_version=study_version, **config, id=cluster_id) - kwargs = obj.dict(by_alias=False) + kwargs = obj.model_dump(by_alias=False) return ThermalClusterOutput(**kwargs) @@ -252,15 +254,17 @@ def update_thermals_props( for thermal_id, update_cluster in update_thermals_by_ids.items(): # Update the thermal cluster properties. old_cluster = old_thermals_by_ids[thermal_id] - new_cluster = old_cluster.copy(update=update_cluster.dict(by_alias=False, exclude_none=True)) + new_cluster = old_cluster.copy(update=update_cluster.model_dump(by_alias=False, exclude_none=True)) new_thermals_by_areas[area_id][thermal_id] = new_cluster # Convert the DTO to a configuration object and update the configuration file. 
- properties = create_thermal_config(study.version, **new_cluster.dict(by_alias=False, exclude_none=True)) + properties = create_thermal_config( + study.version, **new_cluster.model_dump(by_alias=False, exclude_none=True) + ) path = _CLUSTER_PATH.format(area_id=area_id, cluster_id=thermal_id) cmd = UpdateConfig( target=path, - data=json.loads(properties.json(by_alias=True, exclude={"id"})), + data=properties.model_dump(mode="json", by_alias=True, exclude={"id"}), command_context=self.storage_service.variant_study_service.command_factory.command_context, ) commands.append(cmd) @@ -302,12 +306,13 @@ def create_cluster(self, study: Study, area_id: str, cluster_data: ThermalCluste def _make_create_cluster_cmd(self, area_id: str, cluster: ThermalConfigType) -> CreateCluster: # NOTE: currently, in the `CreateCluster` class, there is a confusion # between the cluster name and the cluster ID (which is a section name). - command = CreateCluster( - area_id=area_id, - cluster_name=cluster.id, - parameters=cluster.dict(by_alias=True, exclude={"id"}), - command_context=self.storage_service.variant_study_service.command_factory.command_context, - ) + args = { + "area_id": area_id, + "cluster_name": cluster.id, + "parameters": cluster.model_dump(mode="json", by_alias=True, exclude={"id"}), + "command_context": self.storage_service.variant_study_service.command_factory.command_context, + } + command = CreateCluster.model_validate(args) return command def update_cluster( @@ -346,16 +351,16 @@ def update_cluster( old_config = create_thermal_config(study_version, **values) # Use Python values to synchronize Config and Form values - new_values = cluster_data.dict(by_alias=False, exclude_none=True) + new_values = cluster_data.model_dump(by_alias=False, exclude_none=True) new_config = old_config.copy(exclude={"id"}, update=new_values) - new_data = json.loads(new_config.json(by_alias=True, exclude={"id"})) + new_data = new_config.model_dump(mode="json", by_alias=True, exclude={"id"}) # create the dict containing the new values using aliases - data: t.Dict[str, t.Any] = { - field.alias: new_data[field.alias] - for field_name, field in new_config.__fields__.items() - if field_name in new_values - } + data: t.Dict[str, t.Any] = {} + for field_name, field in new_config.model_fields.items(): + if field_name in new_values: + name = field.alias if field.alias else field_name + data[name] = new_data[name] # create the update config commands with the modified data command_context = self.storage_service.variant_study_service.command_factory.command_context @@ -365,8 +370,8 @@ def update_cluster( ] execute_or_add_commands(study, file_study, commands, self.storage_service) - values = new_config.dict(by_alias=False) - return ThermalClusterOutput(**values, id=cluster_id) + values = {**new_config.model_dump(mode="json", by_alias=False), "id": cluster_id} + return ThermalClusterOutput.model_validate(values) def delete_clusters(self, study: Study, area_id: str, cluster_ids: t.Sequence[str]) -> None: """ @@ -418,7 +423,7 @@ def duplicate_cluster( # Cluster duplication source_cluster = self.get_cluster(study, area_id, source_id) source_cluster.name = new_cluster_name - creation_form = ThermalClusterCreation(**source_cluster.dict(by_alias=False, exclude={"id"})) + creation_form = ThermalClusterCreation(**source_cluster.model_dump(by_alias=False, exclude={"id"})) new_config = creation_form.to_config(study.version) create_cluster_cmd = self._make_create_cluster_cmd(area_id, new_config) @@ -451,7 +456,7 @@ def duplicate_cluster( 
execute_or_add_commands(study, self._get_file_study(study), commands, self.storage_service) - return ThermalClusterOutput(**new_config.dict(by_alias=False)) + return ThermalClusterOutput(**new_config.model_dump(by_alias=False)) def validate_series(self, study: Study, area_id: str, cluster_id: str) -> bool: lower_cluster_id = cluster_id.lower() diff --git a/antarest/study/business/binding_constraint_management.py b/antarest/study/business/binding_constraint_management.py index 90735ff6b2..4ef2a5fffc 100644 --- a/antarest/study/business/binding_constraint_management.py +++ b/antarest/study/business/binding_constraint_management.py @@ -11,13 +11,11 @@ # This file is part of the Antares project. import collections -import json import logging import typing as t import numpy as np -from pydantic import BaseModel, Field, root_validator, validator -from requests.utils import CaseInsensitiveDict +from pydantic import BaseModel, Field, field_validator, model_validator from antarest.core.exceptions import ( BindingConstraintNotFound, @@ -31,6 +29,7 @@ WrongMatrixHeightError, ) from antarest.core.model import JSON +from antarest.core.requests import CaseInsensitiveDict from antarest.core.utils.string import to_camel_case from antarest.study.business.all_optional_meta import camel_case_model from antarest.study.business.utils import execute_or_add_commands @@ -129,12 +128,12 @@ class ConstraintTerm(BaseModel): data: the constraint term data (link or cluster), if any. """ - id: t.Optional[str] - weight: t.Optional[float] - offset: t.Optional[int] - data: t.Optional[t.Union[LinkTerm, ClusterTerm]] + id: t.Optional[str] = None + weight: t.Optional[float] = None + offset: t.Optional[int] = None + data: t.Optional[t.Union[LinkTerm, ClusterTerm]] = None - @validator("id") + @field_validator("id") def id_to_lower(cls, v: t.Optional[str]) -> t.Optional[str]: """Ensure the ID is lower case.""" if v is None: @@ -265,7 +264,7 @@ class ConstraintInput(BindingConstraintMatrices, ConstraintInput870): class ConstraintCreation(ConstraintInput): name: str - @root_validator(pre=True) + @model_validator(mode="before") def check_matrices_dimensions(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: for _key in ["time_step"] + [m.value for m in TermMatrices]: _camel = to_camel_case(_key) @@ -316,20 +315,19 @@ def check_matrices_dimensions(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t. raise ValueError(err_msg) -@camel_case_model class ConstraintOutputBase(BindingConstraintPropertiesBase): id: str name: str terms: t.MutableSequence[ConstraintTerm] = Field(default_factory=lambda: []) + # We have to redefine the time_step attribute to give it another alias. 
+ time_step: t.Optional[BindingConstraintFrequency] = Field(DEFAULT_TIMESTEP, alias="timeStep") # type: ignore -@camel_case_model class ConstraintOutput830(ConstraintOutputBase): - filter_year_by_year: str = "" - filter_synthesis: str = "" + filter_year_by_year: str = Field(default="", alias="filterYearByYear") + filter_synthesis: str = Field(default="", alias="filterSynthesis") -@camel_case_model class ConstraintOutput870(ConstraintOutput830): group: str = DEFAULT_GROUP @@ -372,7 +370,7 @@ def _get_references_by_widths( continue matrix_height = matrix.shape[0] - expected_height = EXPECTED_MATRIX_SHAPES[bc.time_step][0] + expected_height = EXPECTED_MATRIX_SHAPES[bc.time_step][0] # type: ignore if matrix_height != expected_height: raise WrongMatrixHeightError( f"The binding constraint '{bc.name}' should have {expected_height} rows, currently: {matrix_height}" @@ -443,17 +441,22 @@ def parse_and_add_terms(key: str, value: t.Any, adapted_constraint: ConstraintOu id=key, weight=weight, offset=offset, - data={ - "area1": term_data[0], - "area2": term_data[1], - }, + data=LinkTerm.model_validate( + { + "area1": term_data[0], + "area2": term_data[1], + } + ), ) ) # Cluster term else: adapted_constraint.terms.append( ConstraintTerm( - id=key, weight=weight, offset=offset, data={"area": term_data[0], "cluster": term_data[1]} + id=key, + weight=weight, + offset=offset, + data=ClusterTerm.model_validate({"area": term_data[0], "cluster": term_data[1]}), ) ) @@ -601,7 +604,7 @@ def get_grouped_constraints(self, study: Study) -> t.Mapping[str, t.Sequence[Con storage_service = self.storage_service.get_storage(study) file_study = storage_service.get_raw(study) config = file_study.tree.get(["input", "bindingconstraints", "bindingconstraints"]) - grouped_constraints = CaseInsensitiveDict() # type: ignore + grouped_constraints = CaseInsensitiveDict() for constraint in config.values(): constraint_config = self.constraint_model_adapter(constraint, int(study.version)) @@ -710,7 +713,10 @@ def create_binding_constraint( check_attributes_coherence(data, version, data.operator or DEFAULT_OPERATOR) - new_constraint = {"name": data.name, **json.loads(data.json(exclude={"terms", "name"}, exclude_none=True))} + new_constraint = { + "name": data.name, + **data.model_dump(mode="json", exclude={"terms", "name"}, exclude_none=True), + } args = { **new_constraint, "command_context": self.storage_service.variant_study_service.command_factory.command_context, @@ -748,7 +754,7 @@ def update_binding_constraint( upd_constraint = { "id": binding_constraint_id, - **json.loads(data.json(exclude={"terms", "name"}, exclude_none=True)), + **data.model_dump(mode="json", exclude={"terms", "name"}, exclude_none=True), } args = { **upd_constraint, @@ -768,7 +774,7 @@ def update_binding_constraint( updated_matrices = [term for term in [m.value for m in TermMatrices] if getattr(data, term)] time_step = data.time_step or existing_constraint.time_step command.validates_and_fills_matrices( - time_step=time_step, specific_matrices=updated_matrices, version=study_version, create=False + time_step=time_step, specific_matrices=updated_matrices, version=study_version, create=False # type: ignore ) execute_or_add_commands(study, file_study, [command], self.storage_service) @@ -834,11 +840,9 @@ def _update_constraint_with_terms( coeffs = { term_id: [term.weight, term.offset] if term.offset else [term.weight] for term_id, term in terms.items() } - command = UpdateBindingConstraint( - id=bc.id, - coeffs=coeffs, - 
command_context=self.storage_service.variant_study_service.command_factory.command_context, - ) + command_context = self.storage_service.variant_study_service.command_factory.command_context + args = {"id": bc.id, "coeffs": coeffs, "command_context": command_context} + command = UpdateBindingConstraint.model_validate(args) file_study = self.storage_service.get_storage(study).get_raw(study) execute_or_add_commands(study, file_study, [command], self.storage_service) @@ -861,7 +865,7 @@ def update_constraint_terms( if update_mode == "add": for term in constraint_terms: if term.data is None: - raise InvalidConstraintTerm(binding_constraint_id, term.json()) + raise InvalidConstraintTerm(binding_constraint_id, term.model_dump_json()) constraint = self.get_binding_constraint(study, binding_constraint_id) existing_terms = collections.OrderedDict((term.generate_id(), term) for term in constraint.terms) diff --git a/antarest/study/business/correlation_management.py b/antarest/study/business/correlation_management.py index 56b9e33fba..c05b293b3e 100644 --- a/antarest/study/business/correlation_management.py +++ b/antarest/study/business/correlation_management.py @@ -15,11 +15,11 @@ The generators are of the same category and can be hydraulic, wind, load or solar. """ import collections -from typing import Dict, List, Sequence +from typing import Dict, List, Sequence, Union import numpy as np import numpy.typing as npt -from pydantic import conlist, validator +from pydantic import ValidationInfo, field_validator from antarest.core.exceptions import AreaNotFound from antarest.study.business.area_management import AreaInfoDTO @@ -40,7 +40,7 @@ class AreaCoefficientItem(FormFieldsBaseModel): """ class Config: - allow_population_by_field_name = True + populate_by_name = True area_id: str coefficient: float @@ -57,7 +57,7 @@ class CorrelationFormFields(FormFieldsBaseModel): correlation: List[AreaCoefficientItem] # noinspection PyMethodParameters - @validator("correlation") + @field_validator("correlation") def check_correlation(cls, correlation: List[AreaCoefficientItem]) -> List[AreaCoefficientItem]: if not correlation: raise ValueError("correlation must not be empty") @@ -84,13 +84,21 @@ class CorrelationMatrix(FormFieldsBaseModel): data: A 2D-array matrix of correlation coefficients. """ - index: conlist(str, min_items=1) # type: ignore - columns: conlist(str, min_items=1) # type: ignore + index: List[str] + columns: List[str] data: List[List[float]] # NonNegativeFloat not necessary + @field_validator("index", "columns", mode="before") + def validate_list_length(cls, values: List[str]) -> List[str]: + if len(values) == 0: + raise ValueError("correlation matrix cannot have 0 columns/index") + return values + # noinspection PyMethodParameters - @validator("data") - def validate_correlation_matrix(cls, data: List[List[float]], values: Dict[str, List[str]]) -> List[List[float]]: + @field_validator("data", mode="before") + def validate_correlation_matrix( + cls, data: List[List[float]], values: Union[Dict[str, List[str]], ValidationInfo] + ) -> List[List[float]]: """ Validates the correlation matrix by checking its shape and range of coefficients. 
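A minimal, self-contained sketch of the cross-field validation pattern used by `CorrelationMatrix` above: in pydantic v2, a `@field_validator` can read previously validated fields through `ValidationInfo.data`, which replaces the v1 `values` dict (the model and field names below are illustrative, not project code):

    from typing import List

    from pydantic import BaseModel, ValidationInfo, field_validator

    class MatrixSketch(BaseModel):
        index: List[str]
        columns: List[str]
        data: List[List[float]]

        @field_validator("data")
        @classmethod
        def check_shape(cls, data: List[List[float]], info: ValidationInfo) -> List[List[float]]:
            # Fields validate in declaration order, so `index` and `columns`
            # are already available in `info.data` when `data` is checked.
            rows = len(info.data.get("index", []))
            cols = len(info.data.get("columns", []))
            if len(data) != rows or any(len(row) != cols for row in data):
                raise ValueError(f"expected a {rows}x{cols} matrix")
            return data

    MatrixSketch(index=["a", "b"], columns=["x"], data=[[1.0], [2.0]])  # passes

The patch instead keeps `mode="before"` and accepts `Union[Dict, ValidationInfo]`, which preserves the v1 call sites at the cost of the `values.data` indirection seen in the hunk below.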
@@ -112,8 +120,9 @@ def validate_correlation_matrix(cls, data: List[List[float]], values: Dict[str, """ array = np.array(data) - rows = len(values.get("index", [])) - cols = len(values.get("columns", [])) + new_values = values if isinstance(values, dict) else values.data + rows = len(new_values.get("index", [])) + cols = len(new_values.get("columns", [])) if array.size == 0: raise ValueError("correlation matrix must not be empty") @@ -129,7 +138,7 @@ def validate_correlation_matrix(cls, data: List[List[float]], values: Dict[str, return data class Config: - schema_extra = { + json_schema_extra = { "example": { "columns": ["north", "east", "south", "west"], "data": [ diff --git a/antarest/study/business/district_manager.py b/antarest/study/business/district_manager.py index 3563230790..c642d61531 100644 --- a/antarest/study/business/district_manager.py +++ b/antarest/study/business/district_manager.py @@ -69,16 +69,19 @@ def get_districts(self, study: Study) -> List[DistrictInfoDTO]: """ file_study = self.storage_service.get_storage(study).get_raw(study) all_areas = list(file_study.config.areas) - return [ - DistrictInfoDTO( - id=district_id, - name=district.name, - areas=district.get_areas(all_areas), - output=district.output, - comments=file_study.tree.get(["input", "areas", "sets", district_id]).get("comments", ""), + districts = [] + for district_id, district in file_study.config.sets.items(): + assert district.name is not None + districts.append( + DistrictInfoDTO( + id=district_id, + name=district.name, + areas=district.get_areas(all_areas), + output=district.output, + comments=file_study.tree.get(["input", "areas", "sets", district_id]).get("comments", ""), + ) ) - for district_id, district in file_study.config.sets.items() - ] + return districts def create_district( self, @@ -148,14 +151,14 @@ def update_district( file_study = self.storage_service.get_storage(study).get_raw(study) if district_id not in file_study.config.sets: raise DistrictNotFound(district_id) - areas = frozenset(dto.areas or []) - all_areas = frozenset(file_study.config.areas) + areas = set(dto.areas or []) + all_areas = set(file_study.config.areas) if invalid_areas := areas - all_areas: raise AreaNotFound(*invalid_areas) command = UpdateDistrict( id=district_id, base_filter=DistrictBaseFilter.remove_all, - filter_items=areas, + filter_items=dto.areas or [], output=dto.output, comments=dto.comments, command_context=self.storage_service.variant_study_service.command_factory.command_context, diff --git a/antarest/study/business/general_management.py b/antarest/study/business/general_management.py index c66a7b02fa..53e1b9a2ed 100644 --- a/antarest/study/business/general_management.py +++ b/antarest/study/business/general_management.py @@ -10,10 +10,11 @@ # # This file is part of the Antares project. 
-from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, List, Union, cast -from pydantic import PositiveInt, StrictBool, conint, root_validator +from pydantic import PositiveInt, StrictBool, ValidationInfo, conint, model_validator +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study @@ -63,39 +64,41 @@ class BuildingMode(EnumIgnoreCase): DayNumberType = conint(ge=1, le=366) +@all_optional_model class GeneralFormFields(FormFieldsBaseModel): - mode: Optional[Mode] - first_day: Optional[DayNumberType] # type: ignore - last_day: Optional[DayNumberType] # type: ignore - horizon: Optional[str] # Don't use `StrictStr` because it can be an int - first_month: Optional[Month] - first_week_day: Optional[WeekDay] - first_january: Optional[WeekDay] - leap_year: Optional[StrictBool] - nb_years: Optional[PositiveInt] - building_mode: Optional[BuildingMode] - selection_mode: Optional[StrictBool] - year_by_year: Optional[StrictBool] - simulation_synthesis: Optional[StrictBool] - mc_scenario: Optional[StrictBool] + mode: Mode + first_day: DayNumberType # type: ignore + last_day: DayNumberType # type: ignore + horizon: Union[str, int] + first_month: Month + first_week_day: WeekDay + first_january: WeekDay + leap_year: StrictBool + nb_years: PositiveInt + building_mode: BuildingMode + selection_mode: StrictBool + year_by_year: StrictBool + simulation_synthesis: StrictBool + mc_scenario: StrictBool # Geographic trimming + Thematic trimming. # For study versions < 710 - filtering: Optional[StrictBool] + filtering: StrictBool # For study versions >= 710 - geographic_trimming: Optional[StrictBool] - thematic_trimming: Optional[StrictBool] - - @root_validator - def day_fields_validation(cls, values: Dict[str, Any]) -> Dict[str, Any]: - first_day = values.get("first_day") - last_day = values.get("last_day") - leap_year = values.get("leap_year") + geographic_trimming: StrictBool + thematic_trimming: StrictBool + + @model_validator(mode="before") + def day_fields_validation(cls, values: Union[Dict[str, Any], ValidationInfo]) -> Dict[str, Any]: + new_values = values if isinstance(values, dict) else values.data + first_day = new_values.get("first_day") + last_day = new_values.get("last_day") + leap_year = new_values.get("leap_year") day_fields = [first_day, last_day, leap_year] if all(v is None for v in day_fields): # The user wishes to update another field than these three. 
# no need to validate anything: - return values + return new_values if any(v is None for v in day_fields): raise ValueError("First day, last day and leap year fields must be defined together") @@ -110,7 +113,7 @@ def day_fields_validation(cls, values: Dict[str, Any]) -> Dict[str, Any]: if last_day > num_days_in_year: raise ValueError(f"Last day cannot be greater than {num_days_in_year}") - return values + return new_values GENERAL = "general" diff --git a/antarest/study/business/link_management.py b/antarest/study/business/link_management.py index 8515b42fee..6f696d2a5e 100644 --- a/antarest/study/business/link_management.py +++ b/antarest/study/business/link_management.py @@ -16,7 +16,7 @@ from antarest.core.exceptions import ConfigFileNotFound from antarest.core.model import JSON -from antarest.study.business.all_optional_meta import AllOptionalMetaclass, camel_case_model +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.business.utils import execute_or_add_commands from antarest.study.model import RawStudy from antarest.study.storage.rawstudy.model.filesystem.config.links import LinkProperties @@ -40,8 +40,9 @@ class LinkInfoDTO(BaseModel): ui: t.Optional[LinkUIDTO] = None +@all_optional_model @camel_case_model -class LinkOutput(LinkProperties, metaclass=AllOptionalMetaclass, use_none=True): +class LinkOutput(LinkProperties): """ DTO object use to get the link information. """ @@ -121,7 +122,7 @@ def get_all_links_props(self, study: RawStudy) -> t.Mapping[t.Tuple[str, str], L for area2_id, properties_cfg in property_map.items(): area1_id, area2_id = sorted([area1_id, area2_id]) properties = LinkProperties(**properties_cfg) - links_by_ids[(area1_id, area2_id)] = LinkOutput(**properties.dict(by_alias=False)) + links_by_ids[(area1_id, area2_id)] = LinkOutput(**properties.model_dump(by_alias=False)) return links_by_ids @@ -137,11 +138,11 @@ def update_links_props( for (area1, area2), update_link_dto in update_links_by_ids.items(): # Update the link properties. old_link_dto = old_links_by_ids[(area1, area2)] - new_link_dto = old_link_dto.copy(update=update_link_dto.dict(by_alias=False, exclude_none=True)) + new_link_dto = old_link_dto.copy(update=update_link_dto.model_dump(by_alias=False, exclude_none=True)) new_links_by_ids[(area1, area2)] = new_link_dto # Convert the DTO to a configuration object and update the configuration file. - properties = LinkProperties(**new_link_dto.dict(by_alias=False)) + properties = LinkProperties(**new_link_dto.model_dump(by_alias=False)) path = f"{_ALL_LINKS_PATH}/{area1}/properties/{area2}" cmd = UpdateConfig( target=path, diff --git a/antarest/study/business/optimization_management.py b/antarest/study/business/optimization_management.py index 5ee764046e..cb599d12e2 100644 --- a/antarest/study/business/optimization_management.py +++ b/antarest/study/business/optimization_management.py @@ -10,10 +10,11 @@ # # This file is part of the Antares project. 
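The `@all_optional_model` decorator applied above stands in for the former `AllOptionalMetaclass`. A hedged sketch of the general technique in pydantic v2, not the project's actual implementation in `all_optional_meta.py`:

    import typing as t

    from pydantic import BaseModel, create_model

    def all_optional_sketch(model: t.Type[BaseModel]) -> t.Type[BaseModel]:
        # Rebuild the model so every field becomes Optional[...] with a
        # default of None, which suits PATCH-style partial form updates.
        fields = {
            name: (t.Optional[field.annotation], None)
            for name, field in model.model_fields.items()
        }
        return create_model(model.__name__, __base__=model, **fields)

    class FormSketch(BaseModel):
        horizon: str
        nb_years: int

    PartialForm = all_optional_sketch(FormSketch)
    PartialForm()  # every field may now be omitted

This is why `GeneralFormFields` and `OptimizationFormFields` can drop their explicit `Optional[...]` annotations while still accepting partial payloads.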
-from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Union, cast from pydantic.types import StrictBool +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study @@ -45,24 +46,20 @@ class SimplexOptimizationRange(EnumIgnoreCase): WEEK = "week" +@all_optional_model class OptimizationFormFields(FormFieldsBaseModel): - binding_constraints: Optional[StrictBool] - hurdle_costs: Optional[StrictBool] - transmission_capacities: Optional[ - Union[ - StrictBool, - Union[LegacyTransmissionCapacities, TransmissionCapacities], - ] - ] - thermal_clusters_min_stable_power: Optional[StrictBool] - thermal_clusters_min_ud_time: Optional[StrictBool] - day_ahead_reserve: Optional[StrictBool] - primary_reserve: Optional[StrictBool] - strategic_reserve: Optional[StrictBool] - spinning_reserve: Optional[StrictBool] - export_mps: Optional[Union[bool, str]] - unfeasible_problem_behavior: Optional[UnfeasibleProblemBehavior] - simplex_optimization_range: Optional[SimplexOptimizationRange] + binding_constraints: StrictBool + hurdle_costs: StrictBool + transmission_capacities: Union[StrictBool, LegacyTransmissionCapacities, TransmissionCapacities] + thermal_clusters_min_stable_power: StrictBool + thermal_clusters_min_ud_time: StrictBool + day_ahead_reserve: StrictBool + primary_reserve: StrictBool + strategic_reserve: StrictBool + spinning_reserve: StrictBool + export_mps: Union[bool, str] + unfeasible_problem_behavior: UnfeasibleProblemBehavior + simplex_optimization_range: SimplexOptimizationRange OPTIMIZATION_PATH = f"{GENERAL_DATA_PATH}/optimization" diff --git a/antarest/study/business/table_mode_management.py b/antarest/study/business/table_mode_management.py index 848beb513b..342c1c5abb 100644 --- a/antarest/study/business/table_mode_management.py +++ b/antarest/study/business/table_mode_management.py @@ -95,36 +95,37 @@ def __init__( def _get_table_data_unsafe(self, study: RawStudy, table_type: TableModeType) -> TableDataDTO: if table_type == TableModeType.AREA: areas_map = self._area_manager.get_all_area_props(study) - data = {area_id: area.dict(by_alias=True) for area_id, area in areas_map.items()} + data = {area_id: area.model_dump(by_alias=True) for area_id, area in areas_map.items()} elif table_type == TableModeType.LINK: links_map = self._link_manager.get_all_links_props(study) data = { - f"{area1_id} / {area2_id}": link.dict(by_alias=True) for (area1_id, area2_id), link in links_map.items() + f"{area1_id} / {area2_id}": link.model_dump(by_alias=True) + for (area1_id, area2_id), link in links_map.items() } elif table_type == TableModeType.THERMAL: thermals_by_areas = self._thermal_manager.get_all_thermals_props(study) data = { - f"{area_id} / {cluster_id}": cluster.dict(by_alias=True, exclude={"id", "name"}) + f"{area_id} / {cluster_id}": cluster.model_dump(by_alias=True, exclude={"id", "name"}) for area_id, thermals_by_ids in thermals_by_areas.items() for cluster_id, cluster in thermals_by_ids.items() } elif table_type == TableModeType.RENEWABLE: renewables_by_areas = self._renewable_manager.get_all_renewables_props(study) data = { - f"{area_id} / {cluster_id}": cluster.dict(by_alias=True, exclude={"id", "name"}) + f"{area_id} / {cluster_id}": cluster.model_dump(by_alias=True, exclude={"id", "name"}) for area_id, 
renewables_by_ids in renewables_by_areas.items() for cluster_id, cluster in renewables_by_ids.items() } elif table_type == TableModeType.ST_STORAGE: storages_by_areas = self._st_storage_manager.get_all_storages_props(study) data = { - f"{area_id} / {cluster_id}": cluster.dict(by_alias=True, exclude={"id", "name"}) + f"{area_id} / {cluster_id}": cluster.model_dump(by_alias=True, exclude={"id", "name"}) for area_id, storages_by_ids in storages_by_areas.items() for cluster_id, cluster in storages_by_ids.items() } elif table_type == TableModeType.BINDING_CONSTRAINT: bc_seq = self._binding_constraint_manager.get_binding_constraints(study) - data = {bc.id: bc.dict(by_alias=True, exclude={"id", "name", "terms"}) for bc in bc_seq} + data = {bc.id: bc.model_dump(by_alias=True, exclude={"id", "name", "terms"}) for bc in bc_seq} else: # pragma: no cover raise NotImplementedError(f"Table type {table_type} not implemented") return data @@ -189,13 +190,13 @@ def update_table_data( # Use AreaOutput to update properties of areas, which may include `None` values area_props_by_ids = {key: AreaOutput(**values) for key, values in data.items()} areas_map = self._area_manager.update_areas_props(study, area_props_by_ids) - data = {area_id: area.dict(by_alias=True, exclude_none=True) for area_id, area in areas_map.items()} + data = {area_id: area.model_dump(by_alias=True, exclude_none=True) for area_id, area in areas_map.items()} return data elif table_type == TableModeType.LINK: links_map = {tuple(key.split(" / ")): LinkOutput(**values) for key, values in data.items()} updated_map = self._link_manager.update_links_props(study, links_map) # type: ignore data = { - f"{area1_id} / {area2_id}": link.dict(by_alias=True) + f"{area1_id} / {area2_id}": link.model_dump(by_alias=True) for (area1_id, area2_id), link in updated_map.items() } return data @@ -207,7 +208,7 @@ def update_table_data( thermals_by_areas[area_id][cluster_id] = ThermalClusterInput(**values) thermals_map = self._thermal_manager.update_thermals_props(study, thermals_by_areas) data = { - f"{area_id} / {cluster_id}": cluster.dict(by_alias=True, exclude={"id", "name"}) + f"{area_id} / {cluster_id}": cluster.model_dump(by_alias=True, exclude={"id", "name"}) for area_id, thermals_by_ids in thermals_map.items() for cluster_id, cluster in thermals_by_ids.items() } @@ -220,7 +221,7 @@ def update_table_data( renewables_by_areas[area_id][cluster_id] = RenewableClusterInput(**values) renewables_map = self._renewable_manager.update_renewables_props(study, renewables_by_areas) data = { - f"{area_id} / {cluster_id}": cluster.dict(by_alias=True, exclude={"id", "name"}) + f"{area_id} / {cluster_id}": cluster.model_dump(by_alias=True, exclude={"id", "name"}) for area_id, renewables_by_ids in renewables_map.items() for cluster_id, cluster in renewables_by_ids.items() } @@ -233,7 +234,7 @@ def update_table_data( storages_by_areas[area_id][cluster_id] = STStorageInput(**values) storages_map = self._st_storage_manager.update_storages_props(study, storages_by_areas) data = { - f"{area_id} / {cluster_id}": cluster.dict(by_alias=True, exclude={"id", "name"}) + f"{area_id} / {cluster_id}": cluster.model_dump(by_alias=True, exclude={"id", "name"}) for area_id, storages_by_ids in storages_map.items() for cluster_id, cluster in storages_by_ids.items() } @@ -241,7 +242,9 @@ def update_table_data( elif table_type == TableModeType.BINDING_CONSTRAINT: bcs_by_ids = {key: ConstraintInput(**values) for key, values in data.items()} bcs_map = 
self._binding_constraint_manager.update_binding_constraints(study, bcs_by_ids) - return {bc_id: bc.dict(by_alias=True, exclude={"id", "name", "terms"}) for bc_id, bc in bcs_map.items()} + return { + bc_id: bc.model_dump(by_alias=True, exclude={"id", "name", "terms"}) for bc_id, bc in bcs_map.items() + } else: # pragma: no cover raise NotImplementedError(f"Table type {table_type} not implemented") diff --git a/antarest/study/business/thematic_trimming_field_infos.py b/antarest/study/business/thematic_trimming_field_infos.py index 5a2f07c682..06086005fd 100644 --- a/antarest/study/business/thematic_trimming_field_infos.py +++ b/antarest/study/business/thematic_trimming_field_infos.py @@ -16,11 +16,12 @@ import typing as t -from antarest.study.business.all_optional_meta import AllOptionalMetaclass +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.utils import FormFieldsBaseModel -class ThematicTrimmingFormFields(FormFieldsBaseModel, metaclass=AllOptionalMetaclass, use_none=True): +@all_optional_model +class ThematicTrimmingFormFields(FormFieldsBaseModel): """ This class manages the configuration of result filtering in a simulation. @@ -237,6 +238,5 @@ class ThematicTrimmingFormFields(FormFieldsBaseModel, metaclass=AllOptionalMetac } -def get_fields_info(study_version: t.Union[str, int]) -> t.Mapping[str, t.Mapping[str, t.Any]]: - study_version = int(study_version) +def get_fields_info(study_version: int) -> t.Mapping[str, t.Mapping[str, t.Any]]: return {key: info for key, info in FIELDS_INFO.items() if (info.get("start_version") or 0) <= study_version} diff --git a/antarest/study/business/thematic_trimming_management.py b/antarest/study/business/thematic_trimming_management.py index 17ca5a0290..96bd5732d6 100644 --- a/antarest/study/business/thematic_trimming_management.py +++ b/antarest/study/business/thematic_trimming_management.py @@ -49,7 +49,7 @@ def set_field_values(self, study: Study, field_values: ThematicTrimmingFormField Set Thematic Trimming config from the webapp form """ file_study = self.storage_service.get_storage(study).get_raw(study) - field_values_dict = field_values.dict() + field_values_dict = field_values.model_dump() keys_by_bool: t.Dict[bool, t.List[t.Any]] = {True: [], False: []} fields_info = get_fields_info(int(study.version)) diff --git a/antarest/study/business/timeseries_config_management.py b/antarest/study/business/timeseries_config_management.py index 8f02a1c512..7e77612fd8 100644 --- a/antarest/study/business/timeseries_config_management.py +++ b/antarest/study/business/timeseries_config_management.py @@ -12,9 +12,10 @@ import typing as t -from pydantic import StrictBool, StrictInt, root_validator, validator +from pydantic import StrictBool, StrictInt, field_validator, model_validator from antarest.core.model import JSON +from antarest.study.business.all_optional_meta import all_optional_model from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.business.utils import GENERAL_DATA_PATH, FormFieldsBaseModel, execute_or_add_commands from antarest.study.model import Study @@ -39,28 +40,30 @@ class SeasonCorrelation(EnumIgnoreCase): ANNUAL = "annual" +@all_optional_model class TSFormFieldsForType(FormFieldsBaseModel): - stochastic_ts_status: t.Optional[StrictBool] - number: t.Optional[StrictInt] - refresh: t.Optional[StrictBool] - refresh_interval: t.Optional[StrictInt] - season_correlation: t.Optional[SeasonCorrelation] - store_in_input: t.Optional[StrictBool] - 
store_in_output: t.Optional[StrictBool] - intra_modal: t.Optional[StrictBool] - inter_modal: t.Optional[StrictBool] - - + stochastic_ts_status: StrictBool + number: StrictInt + refresh: StrictBool + refresh_interval: StrictInt + season_correlation: SeasonCorrelation + store_in_input: StrictBool + store_in_output: StrictBool + intra_modal: StrictBool + inter_modal: StrictBool + + +@all_optional_model class TSFormFields(FormFieldsBaseModel): - load: t.Optional[TSFormFieldsForType] = None - hydro: t.Optional[TSFormFieldsForType] = None - thermal: t.Optional[TSFormFieldsForType] = None - wind: t.Optional[TSFormFieldsForType] = None - solar: t.Optional[TSFormFieldsForType] = None - renewables: t.Optional[TSFormFieldsForType] = None - ntc: t.Optional[TSFormFieldsForType] = None - - @root_validator(pre=True) + load: TSFormFieldsForType + hydro: TSFormFieldsForType + thermal: TSFormFieldsForType + wind: TSFormFieldsForType + solar: TSFormFieldsForType + renewables: TSFormFieldsForType + ntc: TSFormFieldsForType + + @model_validator(mode="before") def check_type_validity( cls, values: t.Dict[str, t.Optional[TSFormFieldsForType]] ) -> t.Dict[str, t.Optional[TSFormFieldsForType]]: @@ -73,7 +76,7 @@ def has_type(ts_type: TSType) -> bool: ) return values - @validator("thermal") + @field_validator("thermal") def thermal_validation(cls, v: TSFormFieldsForType) -> TSFormFieldsForType: if v.season_correlation is not None: raise ValueError("season_correlation is not allowed for 'thermal' type") @@ -130,7 +133,7 @@ def __set_field_values_for_type( field_values: TSFormFieldsForType, ) -> None: commands: t.List[UpdateConfig] = [] - values = field_values.dict() + values = field_values.model_dump() for field, path in PATH_BY_TS_STR_FIELD.items(): field_val = values[field] diff --git a/antarest/study/business/utils.py b/antarest/study/business/utils.py index ac1ea418a1..1afacecc04 100644 --- a/antarest/study/business/utils.py +++ b/antarest/study/business/utils.py @@ -17,7 +17,7 @@ from antarest.core.exceptions import CommandApplicationError from antarest.core.jwt import DEFAULT_ADMIN_USER from antarest.core.requests import RequestParameters -from antarest.core.utils.string import to_camel_case +from antarest.study.business.all_optional_meta import camel_case_model from antarest.study.model import RawStudy, Study from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy from antarest.study.storage.storage_service import StudyStorageService @@ -69,12 +69,12 @@ def execute_or_add_commands( ) +@camel_case_model class FormFieldsBaseModel( BaseModel, - alias_generator=to_camel_case, extra="forbid", validate_assignment=True, - allow_population_by_field_name=True, + populate_by_name=True, ): """ Pydantic Model for webapp form diff --git a/antarest/study/business/xpansion_management.py b/antarest/study/business/xpansion_management.py index 0951ce9dfb..6c02724afd 100644 --- a/antarest/study/business/xpansion_management.py +++ b/antarest/study/business/xpansion_management.py @@ -19,11 +19,11 @@ import zipfile from fastapi import HTTPException, UploadFile -from pydantic import BaseModel, Extra, Field, ValidationError, root_validator, validator +from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator from antarest.core.exceptions import BadZipBinary, ChildNotFoundError from antarest.core.model import JSON -from antarest.study.business.all_optional_meta import AllOptionalMetaclass +from antarest.study.business.all_optional_meta import all_optional_model from 
antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.model import Study from antarest.study.storage.rawstudy.model.filesystem.bucket_node import BucketNode @@ -74,12 +74,12 @@ class XpansionSensitivitySettings(BaseModel): projection: t.List[str] = Field(default_factory=list, description="List of candidate names to project") capex: bool = Field(default=False, description="Whether to include capex in the sensitivity analysis") - @validator("projection", pre=True) + @field_validator("projection", mode="before") def projection_validation(cls, v: t.Optional[t.Sequence[str]]) -> t.Sequence[str]: return [] if v is None else v -class XpansionSettings(BaseModel, extra=Extra.ignore, validate_assignment=True, allow_population_by_field_name=True): +class XpansionSettings(BaseModel, extra="ignore", validate_assignment=True, populate_by_name=True): """ A data transfer object representing the general settings used for Xpansion. @@ -152,8 +152,8 @@ class XpansionSettings(BaseModel, extra=Extra.ignore, validate_assignment=True, # The sensitivity analysis is optional sensitivity_config: t.Optional[XpansionSensitivitySettings] = None - @root_validator(pre=True) - def normalize_values(cls, values: t.MutableMapping[str, t.Any]) -> t.MutableMapping[str, t.Any]: + @model_validator(mode="before") + def validate_float_values(cls, values: t.MutableMapping[str, t.Any]) -> t.MutableMapping[str, t.Any]: if "relaxed-optimality-gap" in values: values["relaxed_optimality_gap"] = values.pop("relaxed-optimality-gap") @@ -208,7 +208,8 @@ def from_config(cls, config_obj: JSON) -> "GetXpansionSettings": return cls.construct(**config_obj) -class UpdateXpansionSettings(XpansionSettings, metaclass=AllOptionalMetaclass, use_none=True): +@all_optional_model +class UpdateXpansionSettings(XpansionSettings): """ DTO object used to update the Xpansion settings. @@ -219,13 +220,6 @@ class UpdateXpansionSettings(XpansionSettings, metaclass=AllOptionalMetaclass, u # note: for some reason, the alias is not taken into account when using the metaclass, # so we have to redefine the fields with the alias. - - # On the other hand, we make these fields mandatory, because there is an anomaly on the front side: - # When the user does not select any file, the front sends a request without the "yearly-weights" - # or "additional-constraints" field, instead of sending the field with an empty value. - # This is not a problem as long as the front sends a request with all the fields (PUT case), - # but it is a problem for partial requests (PATCH case). 
- yearly_weights: str = Field( "", alias="yearly-weights", @@ -245,19 +239,21 @@ class XpansionCandidateDTO(BaseModel): name: str link: str annual_cost_per_mw: float = Field(alias="annual-cost-per-mw", ge=0) - unit_size: t.Optional[float] = Field(None, alias="unit-size", ge=0) - max_units: t.Optional[int] = Field(None, alias="max-units", ge=0) - max_investment: t.Optional[float] = Field(None, alias="max-investment", ge=0) - already_installed_capacity: t.Optional[int] = Field(None, alias="already-installed-capacity", ge=0) + unit_size: t.Optional[float] = Field(default=None, alias="unit-size", ge=0) + max_units: t.Optional[int] = Field(default=None, alias="max-units", ge=0) + max_investment: t.Optional[float] = Field(default=None, alias="max-investment", ge=0) + already_installed_capacity: t.Optional[int] = Field(default=None, alias="already-installed-capacity", ge=0) # this is obsolete (replaced by direct/indirect) - link_profile: t.Optional[str] = Field(None, alias="link-profile") + link_profile: t.Optional[str] = Field(default=None, alias="link-profile") # this is obsolete (replaced by direct/indirect) - already_installed_link_profile: t.Optional[str] = Field(None, alias="already-installed-link-profile") - direct_link_profile: t.Optional[str] = Field(None, alias="direct-link-profile") - indirect_link_profile: t.Optional[str] = Field(None, alias="indirect-link-profile") - already_installed_direct_link_profile: t.Optional[str] = Field(None, alias="already-installed-direct-link-profile") + already_installed_link_profile: t.Optional[str] = Field(default=None, alias="already-installed-link-profile") + direct_link_profile: t.Optional[str] = Field(default=None, alias="direct-link-profile") + indirect_link_profile: t.Optional[str] = Field(default=None, alias="indirect-link-profile") + already_installed_direct_link_profile: t.Optional[str] = Field( + default=None, alias="already-installed-direct-link-profile" + ) already_installed_indirect_link_profile: t.Optional[str] = Field( - None, alias="already-installed-indirect-link-profile" + default=None, alias="already-installed-indirect-link-profile" ) @@ -347,9 +343,11 @@ def create_xpansion_configuration(self, study: Study, zipped_config: t.Optional[ raise BadZipBinary("Only zip file are allowed.") xpansion_settings = XpansionSettings() - settings_obj = xpansion_settings.dict(by_alias=True, exclude_none=True, exclude={"sensitivity_config"}) + settings_obj = xpansion_settings.model_dump( + by_alias=True, exclude_none=True, exclude={"sensitivity_config"} + ) if xpansion_settings.sensitivity_config: - sensitivity_obj = xpansion_settings.sensitivity_config.dict(by_alias=True, exclude_none=True) + sensitivity_obj = xpansion_settings.sensitivity_config.model_dump(by_alias=True, exclude_none=True) else: sensitivity_obj = {} @@ -389,7 +387,9 @@ def update_xpansion_settings( logger.info(f"Updating xpansion settings for study '{study.id}'") actual_settings = self.get_xpansion_settings(study) - settings_fields = new_xpansion_settings.dict(by_alias=False, exclude_none=True, exclude={"sensitivity_config"}) + settings_fields = new_xpansion_settings.model_dump( + by_alias=False, exclude_none=True, exclude={"sensitivity_config"} + ) updated_settings = actual_settings.copy(deep=True, update=settings_fields) file_study = self.study_storage_service.get_storage(study).get_raw(study) @@ -409,11 +409,11 @@ def update_xpansion_settings( msg = f"Additional constraints file '{constraints_file}' does not exist" raise XpansionFileNotFoundError(msg) from None - config_obj = 
updated_settings.dict(by_alias=True, exclude={"sensitivity_config"}) + config_obj = updated_settings.model_dump(by_alias=True, exclude={"sensitivity_config"}) file_study.tree.save(config_obj, ["user", "expansion", "settings"]) if new_xpansion_settings.sensitivity_config: - sensitivity_obj = new_xpansion_settings.sensitivity_config.dict(by_alias=True) + sensitivity_obj = new_xpansion_settings.sensitivity_config.model_dump(by_alias=True) file_study.tree.save(sensitivity_obj, ["user", "expansion", "sensitivity", "sensitivity_in"]) return self.get_xpansion_settings(study) @@ -553,7 +553,7 @@ def add_candidate(self, study: Study, xpansion_candidate: XpansionCandidateDTO) ) # The primary key is actually the name, the id does not matter and is never checked. logger.info(f"Adding candidate '{xpansion_candidate.name}' to study '{study.id}'") - candidates_obj[next_id] = xpansion_candidate.dict(by_alias=True, exclude_none=True) + candidates_obj[next_id] = xpansion_candidate.model_dump(by_alias=True, exclude_none=True) candidates_data = {"user": {"expansion": {"candidates": candidates_obj}}} file_study.tree.save(candidates_data) # Should we add a field in the study config containing the xpansion candidates like the links or the areas ? @@ -594,7 +594,7 @@ def update_candidate( for candidate_id, candidate in candidates.items(): if candidate["name"] == candidate_name: logger.info(f"Updating candidate '{candidate_name}' of study '{study.id}'") - candidates[candidate_id] = xpansion_candidate_dto.dict(by_alias=True, exclude_none=True) + candidates[candidate_id] = xpansion_candidate_dto.model_dump(by_alias=True, exclude_none=True) file_study.tree.save(candidates, ["user", "expansion", "candidates"]) return raise CandidateNotFoundError(f"The candidate '{xpansion_candidate_dto.name}' does not exist") @@ -614,7 +614,8 @@ def update_xpansion_constraints_settings(self, study: Study, constraints_file_na # Make sure filename is not `None`, because `None` values are ignored by the update. 
constraints_file_name = constraints_file_name or "" # noinspection PyArgumentList - xpansion_settings = UpdateXpansionSettings(additional_constraints=constraints_file_name) + args = {"additional_constraints": constraints_file_name} + xpansion_settings = UpdateXpansionSettings.model_validate(args) return self.update_xpansion_settings(study, xpansion_settings) def _raw_file_dir(self, raw_file_type: XpansionResourceFileType) -> t.List[str]: @@ -655,6 +656,7 @@ def _add_raw_files( content = file.file.read() if isinstance(content, str): content = content.encode(encoding="utf-8") + assert file.filename is not None buffer[file.filename] = content file_study.tree.save(data) diff --git a/antarest/study/main.py b/antarest/study/main.py index e80a353493..1efa9cb00b 100644 --- a/antarest/study/main.py +++ b/antarest/study/main.py @@ -12,8 +12,9 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.filetransfer.service import FileTransferManager from antarest.core.interfaces.cache import ICache @@ -39,7 +40,7 @@ def build_study_service( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, user_service: LoginService, matrix_service: ISimpleMatrixService, @@ -124,16 +125,17 @@ def build_study_service( config=config, ) - if application: - application.include_router(create_study_routes(study_service, file_transfer_manager, config)) - application.include_router(create_raw_study_routes(study_service, config)) - application.include_router(create_study_data_routes(study_service, config)) - application.include_router( + if app_ctxt: + api_root = app_ctxt.api_root + api_root.include_router(create_study_routes(study_service, file_transfer_manager, config)) + api_root.include_router(create_raw_study_routes(study_service, config)) + api_root.include_router(create_study_data_routes(study_service, config)) + api_root.include_router( create_study_variant_routes( study_service=study_service, config=config, ) ) - application.include_router(create_xpansion_routes(study_service, config)) + api_root.include_router(create_xpansion_routes(study_service, config)) return study_service diff --git a/antarest/study/model.py b/antarest/study/model.py index 543490d3c2..081fd48b4a 100644 --- a/antarest/study/model.py +++ b/antarest/study/model.py @@ -18,7 +18,7 @@ from datetime import datetime, timedelta from pathlib import Path -from pydantic import BaseModel, validator +from pydantic import BaseModel, field_validator from sqlalchemy import ( # type: ignore Boolean, Column, @@ -351,13 +351,18 @@ class StudyMetadataDTO(BaseModel): workspace: str managed: bool archived: bool - horizon: t.Optional[str] - scenario: t.Optional[str] - status: t.Optional[str] - doc: t.Optional[str] + horizon: t.Optional[str] = None + scenario: t.Optional[str] = None + status: t.Optional[str] = None + doc: t.Optional[str] = None folder: t.Optional[str] = None tags: t.List[str] = [] + @field_validator("horizon", mode="before") + def transform_horizon_to_str(cls, val: t.Union[str, int, None]) -> t.Optional[str]: + # horizon can be an int. 
+ return str(val) if val else val # type: ignore + class StudyMetadataPatchDTO(BaseModel): name: t.Optional[str] = None @@ -366,18 +371,20 @@ class StudyMetadataPatchDTO(BaseModel): scenario: t.Optional[str] = None status: t.Optional[str] = None doc: t.Optional[str] = None - tags: t.Sequence[str] = () + tags: t.List[str] = [] - @validator("tags", each_item=True) - def _normalize_tags(cls, v: str) -> str: + @field_validator("tags", mode="before") + def _normalize_tags(cls, v: t.List[str]) -> t.List[str]: """Remove leading and trailing whitespaces, and replace consecutive whitespaces by a single one.""" - tag = " ".join(v.split()) - if not tag: - raise ValueError("Tag cannot be empty") - elif len(tag) > 40: - raise ValueError(f"Tag is too long: {tag!r}") - else: - return tag + tags = [] + for tag in v: + tag = " ".join(tag.split()) + if not tag: + raise ValueError("Tag cannot be empty") + elif len(tag) > 40: + raise ValueError(f"Tag is too long: {tag!r}") + tags.append(tag) + return tags class StudySimSettingsDTO(BaseModel): diff --git a/antarest/study/service.py b/antarest/study/service.py index 2bff77b1ec..125c4e1531 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -1330,7 +1330,7 @@ def export_task(_notifier: TaskUpdateNotifier) -> TaskResult: else: json_response = json.dumps( - matrix.dict(), + matrix.model_dump(), ensure_ascii=False, allow_nan=True, indent=None, @@ -1485,6 +1485,7 @@ def _create_edit_study_command( context = self.storage_service.variant_study_service.command_factory.command_context if isinstance(tree_node, IniFileNode): + assert not isinstance(data, (bytes, list)) return UpdateConfig( target=url, data=data, @@ -1500,6 +1501,7 @@ def _create_edit_study_command( matrix=matrix.tolist(), command_context=context, ) + assert isinstance(data, (list, str)) return ReplaceMatrix( target=url, matrix=data, @@ -1507,6 +1509,9 @@ def _create_edit_study_command( ) elif isinstance(tree_node, RawFileNode): if url.split("/")[-1] == "comments": + if isinstance(data, bytes): + data = data.decode("utf-8") + assert isinstance(data, str) return UpdateComments( comments=data, command_context=context, @@ -2425,7 +2430,7 @@ def unarchive_output_task( src=str(src), dest=str(dest), remove_src=not keep_src_zip, - ).dict(), + ).model_dump(), name=task_name, ref_id=study.id, request_params=params, diff --git a/antarest/study/storage/abstract_storage_service.py b/antarest/study/storage/abstract_storage_service.py index 045c33c81c..5f2c033ca3 100644 --- a/antarest/study/storage/abstract_storage_service.py +++ b/antarest/study/storage/abstract_storage_service.py @@ -88,7 +88,7 @@ def get_study_information( try: patch_obj = json.loads(additional_data.patch or "{}") - patch = Patch.parse_obj(patch_obj) + patch = Patch.model_validate(patch_obj) except ValueError as e: # The conversion to JSON and the parsing can fail if the patch is not valid logger.warning(f"Failed to parse patch for study {study.id}", exc_info=e) @@ -328,7 +328,7 @@ def _read_additional_data_from_files(self, file_study: FileStudy) -> StudyAdditi horizon = file_study.tree.get(url=["settings", "generaldata", "general", "horizon"]) author = file_study.tree.get(url=["study", "antares", "author"]) patch = self.patch_service.get_from_filestudy(file_study) - study_additional_data = StudyAdditionalData(horizon=horizon, author=author, patch=patch.json()) + study_additional_data = StudyAdditionalData(horizon=horizon, author=author, patch=patch.model_dump_json()) return study_additional_data def archive_study_output(self, 
study: T, output_id: str) -> bool: diff --git a/antarest/study/storage/patch_service.py b/antarest/study/storage/patch_service.py index 7570401aab..0f3eb68c05 100644 --- a/antarest/study/storage/patch_service.py +++ b/antarest/study/storage/patch_service.py @@ -35,7 +35,7 @@ def get(self, study: t.Union[RawStudy, VariantStudy], get_from_file: bool = Fals # the `study.additional_data.patch` field is optional if study.additional_data.patch: patch_obj = json.loads(study.additional_data.patch or "{}") - return Patch.parse_obj(patch_obj) + return Patch.model_validate(patch_obj) patch = Patch() patch_path = Path(study.path) / PATCH_JSON @@ -67,9 +67,9 @@ def set_reference_output( def save(self, study: t.Union[RawStudy, VariantStudy], patch: Patch) -> None: if self.repository: study.additional_data = study.additional_data or StudyAdditionalData() - study.additional_data.patch = patch.json() + study.additional_data.patch = patch.model_dump_json() self.repository.save(study) patch_path = (Path(study.path)) / PATCH_JSON patch_path.parent.mkdir(parents=True, exist_ok=True) - patch_path.write_text(patch.json()) + patch_path.write_text(patch.model_dump_json()) diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/area.py b/antarest/study/storage/rawstudy/model/filesystem/config/area.py index 964e46810d..74d15e5b55 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/area.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/area.py @@ -16,7 +16,7 @@ import typing as t -from pydantic import Field, root_validator, validator +from pydantic import Field, field_validator, model_validator from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.storage.rawstudy.model.filesystem.config.field_validators import ( @@ -54,7 +54,7 @@ class OptimizationProperties(IniProperties): >>> opt = OptimizationProperties(**obj) - >>> pprint(opt.dict(by_alias=True), width=80) + >>> pprint(opt.model_dump(by_alias=True), width=80) {'filtering': {'filter-synthesis': 'hourly, daily, weekly, monthly, annual', 'filter-year-by-year': 'hourly, annual'}, 'nodal optimization': {'dispatchable-hydro-power': False, @@ -75,7 +75,7 @@ class OptimizationProperties(IniProperties): Convert the object to a dictionary for writing to a configuration file: - >>> pprint(opt.dict(by_alias=True, exclude_defaults=True), width=80) + >>> pprint(opt.model_dump(by_alias=True, exclude_defaults=True), width=80) {'filtering': {'filter-synthesis': 'hourly, weekly, monthly, annual', 'filter-year-by-year': 'hourly, monthly, annual'}, 'nodal optimization': {'dispatchable-hydro-power': False, @@ -89,7 +89,7 @@ class FilteringSection(IniProperties): filter_synthesis: str = Field("", alias="filter-synthesis") filter_year_by_year: str = Field("", alias="filter-year-by-year") - @validator("filter_synthesis", "filter_year_by_year", pre=True) + @field_validator("filter_synthesis", "filter_year_by_year", mode="before") def _validate_filtering(cls, v: t.Any) -> str: return validate_filtering(v) @@ -159,33 +159,33 @@ class AreaUI(IniProperties): ... "color_b": 255, ... 
} >>> ui = AreaUI(**obj) - >>> pprint(ui.dict(by_alias=True), width=80) + >>> pprint(ui.model_dump(by_alias=True), width=80) {'colorRgb': '#0080FF', 'x': 1148, 'y': 144} Update the color: >>> ui.color_rgb = (192, 168, 127) - >>> pprint(ui.dict(by_alias=True), width=80) + >>> pprint(ui.model_dump(by_alias=True), width=80) {'colorRgb': '#C0A87F', 'x': 1148, 'y': 144} """ - x: int = Field(0, description="x coordinate of the area in the map") - y: int = Field(0, description="y coordinate of the area in the map") + x: int = Field(default=0, description="x coordinate of the area in the map") + y: int = Field(default=0, description="y coordinate of the area in the map") color_rgb: str = Field( - "#E66C2C", + default="#E66C2C", alias="colorRgb", description="color of the area in the map", ) - @validator("color_rgb", pre=True) + @field_validator("color_rgb", mode="before") def _validate_color_rgb(cls, v: t.Any) -> str: return validate_color_rgb(v) - @root_validator(pre=True) + @model_validator(mode="before") def _validate_colors(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: return validate_colors(values) - def to_config(self) -> t.Mapping[str, t.Any]: + def to_config(self) -> t.Dict[str, t.Any]: """ Convert the object to a dictionary for writing to a configuration file: @@ -198,6 +198,7 @@ def to_config(self) -> t.Mapping[str, t.Any]: >>> pprint(ui.to_config(), width=80) {'color_b': 255, 'color_g': 128, 'color_r': 0, 'x': 1148, 'y': 144} """ + assert self.color_rgb is not None r = int(self.color_rgb[1:3], 16) g = int(self.color_rgb[3:5], 16) b = int(self.color_rgb[5:7], 16) @@ -216,7 +217,7 @@ class UIProperties(IniProperties): UIProperties has default values for `style` and `layers`: >>> ui = UIProperties() - >>> pprint(ui.dict(), width=80) + >>> pprint(ui.model_dump(), width=80) {'layer_styles': {0: {'color_rgb': '#E66C2C', 'x': 0, 'y': 0}}, 'layers': {0}, 'style': {'color_rgb': '#E66C2C', 'x': 0, 'y': 0}} @@ -244,7 +245,7 @@ class UIProperties(IniProperties): ... 
} >>> ui = UIProperties(**obj) - >>> pprint(ui.dict(), width=80) + >>> pprint(ui.model_dump(), width=80) {'layer_styles': {0: {'color_rgb': '#0080FF', 'x': 1148, 'y': 144}, 4: {'color_rgb': '#0080FF', 'x': 1148, 'y': 144}, 6: {'color_rgb': '#C0A863', 'x': 1148, 'y': 144}, @@ -260,7 +261,7 @@ class UIProperties(IniProperties): description="style of the area in the map: coordinates and color", ) layers: t.Set[int] = Field( - default_factory=set, + default_factory=lambda: {0}, description="layers where the area is visible", ) layer_styles: t.Dict[int, AreaUI] = Field( @@ -269,28 +270,20 @@ class UIProperties(IniProperties): alias="layerStyles", ) - @root_validator(pre=True) - def _set_default_style(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: + @staticmethod + def _set_default_style(values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: """Defined the default style if missing.""" - style = values.get("style") + style = values.get("style", None) if style is None: values["style"] = AreaUI() elif isinstance(style, dict): values["style"] = AreaUI(**style) else: - values["style"] = AreaUI(**style.dict()) + values["style"] = AreaUI(**style.model_dump()) return values - @root_validator(pre=True) - def _set_default_layers(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: - """Define the default layers if missing.""" - _layers = values.get("layers") - if _layers is None: - values["layers"] = {0} - return values - - @root_validator(pre=True) - def _set_default_layer_styles(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: + @staticmethod + def _set_default_layer_styles(values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: """Define the default layer styles if missing.""" layer_styles = values.get("layer_styles") if layer_styles is None: @@ -302,13 +295,15 @@ def _set_default_layer_styles(cls, values: t.MutableMapping[str, t.Any]) -> t.Ma if isinstance(style, dict): values["layer_styles"][key] = AreaUI(**style) else: - values["layer_styles"][key] = AreaUI(**style.dict()) + values["layer_styles"][key] = AreaUI(**style.model_dump()) else: raise TypeError(f"Invalid type for layer_styles: {type(layer_styles)}") return values - @root_validator(pre=True) + @model_validator(mode="before") def _validate_layers(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: + cls._set_default_style(values) + cls._set_default_layer_styles(values) # Parse the `[ui]` section (if any) ui_section = values.pop("ui", {}) if ui_section: @@ -347,7 +342,7 @@ def _validate_layers(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str return values - def to_config(self) -> t.Mapping[str, t.Mapping[str, t.Any]]: + def to_config(self) -> t.Dict[str, t.Dict[str, t.Any]]: """ Convert the object to a dictionary for writing to a configuration file: @@ -375,7 +370,7 @@ def to_config(self) -> t.Mapping[str, t.Mapping[str, t.Any]]: 'x': 1148, 'y': 144}} """ - obj: t.MutableMapping[str, t.MutableMapping[str, t.Any]] = { + obj: t.Dict[str, t.Dict[str, t.Any]] = { "ui": {}, "layerX": {}, "layerY": {}, @@ -386,6 +381,7 @@ def to_config(self) -> t.Mapping[str, t.Mapping[str, t.Any]]: for layer, style in self.layer_styles.items(): obj["layerX"][str(layer)] = style.x obj["layerY"][str(layer)] = style.y + assert style.color_rgb is not None r = int(style.color_rgb[1:3], 16) g = int(style.color_rgb[3:5], 16) b = int(style.color_rgb[5:7], 16) @@ -405,7 +401,7 @@ class AreaFolder(IniProperties): Create and validate a new AreaProperties object from a 
dictionary read from a configuration file. >>> obj = AreaFolder() - >>> pprint(obj.dict(), width=80) + >>> pprint(obj.model_dump(), width=80) {'adequacy_patch': None, 'optimization': {'filtering': {'filter_synthesis': '', 'filter_year_by_year': ''}, @@ -450,7 +446,7 @@ class AreaFolder(IniProperties): ... } >>> obj = AreaFolder.construct(**data) - >>> pprint(obj.dict(), width=80) + >>> pprint(obj.model_dump(), width=80) {'adequacy_patch': None, 'optimization': {'filtering': {'filter-synthesis': 'annual, centennial'}, 'nodal optimization': {'spread-spilled-energy-cost': '15.5', @@ -512,7 +508,7 @@ class ThermalAreasProperties(IniProperties): ... }, ... } >>> area = ThermalAreasProperties(**obj) - >>> pprint(area.dict(), width=80) + >>> pprint(area.model_dump(), width=80) {'spilled_energy_cost': {'cz': 100.0}, 'unserverd_energy_cost': {'at': 4000.8, 'be': 3500.0, @@ -523,7 +519,7 @@ class ThermalAreasProperties(IniProperties): >>> area.unserverd_energy_cost["at"] = 6500.0 >>> area.unserverd_energy_cost["fr"] = 0.0 - >>> pprint(area.dict(), width=80) + >>> pprint(area.model_dump(), width=80) {'spilled_energy_cost': {'cz': 100.0}, 'unserverd_energy_cost': {'at': 6500.0, 'be': 3500.0, 'de': 1250.0, 'fr': 0.0}} @@ -546,7 +542,7 @@ class ThermalAreasProperties(IniProperties): description="spilled energy cost (€/MWh) of each area", ) - @validator("unserverd_energy_cost", "spilled_energy_cost", pre=True) + @field_validator("unserverd_energy_cost", "spilled_energy_cost", mode="before") def _validate_energy_cost(cls, v: t.Any) -> t.MutableMapping[str, float]: if isinstance(v, dict): return {str(k): float(v) for k, v in v.items()} diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py b/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py index 0ead589109..62393794e2 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/cluster.py @@ -19,15 +19,15 @@ import functools import typing as t -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel, Field @functools.total_ordering class ItemProperties( BaseModel, - extra=Extra.forbid, + extra="forbid", validate_assignment=True, - allow_population_by_field_name=True, + populate_by_name=True, ): """ Common properties related to thermal and renewable clusters, and short-term storage. @@ -47,7 +47,7 @@ class ItemProperties( group: str = Field(default="", description="Cluster group") - name: str = Field(description="Cluster name", regex=r"[a-zA-Z0-9_(),& -]+") + name: str = Field(description="Cluster name", pattern=r"[a-zA-Z0-9_(),& -]+") def __lt__(self, other: t.Any) -> bool: """ diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/field_validators.py b/antarest/study/storage/rawstudy/model/filesystem/config/field_validators.py index 3a544d6bb8..f6a371769b 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/field_validators.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/field_validators.py @@ -15,7 +15,7 @@ _ALL_FILTERING = ["hourly", "daily", "weekly", "monthly", "annual"] -def extract_filtering(v: t.Any) -> t.Sequence[str]: +def extract_filtering(v: t.Any) -> t.List[str]: """ Extract filtering values from a comma-separated list of values. 
""" diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/files.py b/antarest/study/storage/rawstudy/model/filesystem/config/files.py index 296ee6cf19..15856dbf60 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/files.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/files.py @@ -26,7 +26,6 @@ DEFAULT_GROUP, DEFAULT_OPERATOR, DEFAULT_TIMESTEP, - BindingConstraintFrequency, ) from antarest.study.storage.rawstudy.model.filesystem.config.exceptions import ( SimulationParsingError, diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/identifier.py b/antarest/study/storage/rawstudy/model/filesystem/config/identifier.py index d30cc34fae..2cb5d9ec64 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/identifier.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/identifier.py @@ -12,22 +12,22 @@ import typing as t -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import BaseModel, Field, model_validator __all__ = ("IgnoreCaseIdentifier", "LowerCaseIdentifier") class IgnoreCaseIdentifier( BaseModel, - extra=Extra.forbid, + extra="forbid", validate_assignment=True, - allow_population_by_field_name=True, + populate_by_name=True, ): """ Base class for all configuration sections with an ID. """ - id: str = Field(description="ID (section name)", regex=r"[a-zA-Z0-9_(),& -]+") + id: str = Field(description="ID (section name)", pattern=r"[a-zA-Z0-9_(),& -]+") @classmethod def generate_id(cls, name: str) -> str: @@ -45,7 +45,7 @@ def generate_id(cls, name: str) -> str: return transform_name_to_id(name, lower=False) - @root_validator(pre=True) + @model_validator(mode="before") def validate_id(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: """ Calculate an ID based on the name, if not provided. @@ -56,18 +56,21 @@ def validate_id(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.A Returns: The updated values. """ - if storage_id := values.get("id"): - # If the ID is provided, it comes from a INI section name. - # In some legacy case, the ID was in lower case, so we need to convert it. - values["id"] = cls.generate_id(storage_id) - return values - if not values.get("name"): - return values - name = values["name"] - if storage_id := cls.generate_id(name): - values["id"] = storage_id - else: - raise ValueError(f"Invalid name '{name}'.") + + # For some reason I can't explain, values can be an object. If so, no validation is needed. + if isinstance(values, dict): + if storage_id := values.get("id"): + # If the ID is provided, it comes from a INI section name. + # In some legacy case, the ID was in lower case, so we need to convert it. + values["id"] = cls.generate_id(storage_id) + return values + if not values.get("name"): + return values + name = values["name"] + if storage_id := cls.generate_id(name): + values["id"] = storage_id + else: + raise ValueError(f"Invalid name '{name}'.") return values @@ -76,7 +79,7 @@ class LowerCaseIdentifier(IgnoreCaseIdentifier): Base class for all configuration sections with a lower case ID. 
""" - id: str = Field(description="ID (section name)", regex=r"[a-z0-9_(),& -]+") + id: str = Field(description="ID (section name)", pattern=r"[a-z0-9_(),& -]+") @classmethod def generate_id(cls, name: str) -> str: diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/ini_properties.py b/antarest/study/storage/rawstudy/model/filesystem/config/ini_properties.py index 84410b9e37..5464975bbf 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/ini_properties.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/ini_properties.py @@ -13,7 +13,7 @@ import json import typing as t -from pydantic import BaseModel, Extra +from pydantic import BaseModel class IniProperties( @@ -21,17 +21,17 @@ class IniProperties( # On reading, if the configuration contains an extra field, it is better # to forbid it, because it allows errors to be detected early. # Ignoring extra attributes can hide errors. - extra=Extra.forbid, + extra="forbid", # If a field is updated on assignment, it is also validated. validate_assignment=True, # On testing, we can use snake_case for field names. - allow_population_by_field_name=True, + populate_by_name=True, ): """ Base class for configuration sections. """ - def to_config(self) -> t.Mapping[str, t.Any]: + def to_config(self) -> t.Dict[str, t.Any]: """ Convert the object to a dictionary for writing to a configuration file (`*.ini`). @@ -40,14 +40,16 @@ def to_config(self) -> t.Mapping[str, t.Any]: """ config = {} - for field_name, field in self.__fields__.items(): + for field_name, field in self.model_fields.items(): value = getattr(self, field_name) if value is None: continue + alias = field.alias + assert alias is not None if isinstance(value, IniProperties): - config[field.alias] = value.to_config() + config[alias] = value.to_config() else: - config[field.alias] = json.loads(json.dumps(value)) + config[alias] = json.loads(json.dumps(value)) return config @classmethod @@ -56,7 +58,7 @@ def construct(cls, _fields_set: t.Optional[t.Set[str]] = None, **values: t.Any) Construct a new model instance from a dict of values, replacing aliases with real field names. """ # The pydantic construct() function does not allow aliases to be handled. 
- aliases = {(field.alias or name): name for name, field in cls.__fields__.items()} + aliases = {(field.alias or name): name for name, field in cls.model_fields.items()} renamed_values = {aliases.get(k, k): v for k, v in values.items()} if _fields_set is not None: _fields_set = {aliases.get(f, f) for f in _fields_set} diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/links.py b/antarest/study/storage/rawstudy/model/filesystem/config/links.py index 12e1be47e5..88059ef50e 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/links.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/links.py @@ -16,7 +16,7 @@ import typing as t -from pydantic import Field, root_validator, validator +from pydantic import Field, field_validator, model_validator from antarest.study.business.enum_ignore_case import EnumIgnoreCase from antarest.study.storage.rawstudy.model.filesystem.config.field_validators import ( @@ -96,7 +96,7 @@ class LinkProperties(IniProperties): >>> opt = LinkProperties(**obj) - >>> pprint(opt.dict(by_alias=True), width=80) + >>> pprint(opt.model_dump(by_alias=True), width=80) {'asset-type': , 'colorRgb': '#50C0FF', 'comments': 'This is a link', @@ -146,20 +146,20 @@ class LinkProperties(IniProperties): description="color of the area in the map", ) - @validator("filter_synthesis", "filter_year_by_year", pre=True) + @field_validator("filter_synthesis", "filter_year_by_year", mode="before") def _validate_filtering(cls, v: t.Any) -> str: return validate_filtering(v) - @validator("color_rgb", pre=True) + @field_validator("color_rgb", mode="before") def _validate_color_rgb(cls, v: t.Any) -> str: return validate_color_rgb(v) - @root_validator(pre=True) + @model_validator(mode="before") def _validate_colors(cls, values: t.MutableMapping[str, t.Any]) -> t.Mapping[str, t.Any]: return validate_colors(values) # noinspection SpellCheckingInspection - def to_config(self) -> t.Mapping[str, t.Any]: + def to_config(self) -> t.Dict[str, t.Any]: """ Convert the object to a dictionary for writing to a configuration file. 
""" diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/model.py b/antarest/study/storage/rawstudy/model/filesystem/config/model.py index 5e916b5308..d0abd57710 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/model.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/model.py @@ -14,7 +14,7 @@ import typing as t from pathlib import Path -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, Field, model_validator from antarest.core.utils.utils import DTO from antarest.study.business.enum_ignore_case import EnumIgnoreCase @@ -64,7 +64,7 @@ class Link(BaseModel, extra="ignore"): filters_synthesis: t.List[str] = Field(default_factory=list) filters_year: t.List[str] = Field(default_factory=list) - @root_validator(pre=True) + @model_validator(mode="before") def validation(cls, values: t.MutableMapping[str, t.Any]) -> t.MutableMapping[str, t.Any]: # note: field names are in kebab-case in the INI file filters_synthesis = values.pop("filter-synthesis", values.pop("filters_synthesis", "")) @@ -94,7 +94,7 @@ class DistrictSet(BaseModel): Object linked to /inputs/sets.ini information """ - ALL = ["hourly", "daily", "weekly", "monthly", "annual"] + ALL: t.List[str] = ["hourly", "daily", "weekly", "monthly", "annual"] name: t.Optional[str] = None inverted_set: bool = False areas: t.Optional[t.List[str]] = None diff --git a/antarest/study/storage/rawstudy/model/filesystem/config/thermal.py b/antarest/study/storage/rawstudy/model/filesystem/config/thermal.py index 6c247a3600..f3839566fd 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/config/thermal.py +++ b/antarest/study/storage/rawstudy/model/filesystem/config/thermal.py @@ -439,4 +439,4 @@ def create_thermal_config(study_version: t.Union[str, int], **kwargs: t.Any) -> ValueError: If the study version is not supported. 
""" cls = get_thermal_config_cls(study_version) - return cls(**kwargs) + return cls.model_validate(kwargs) diff --git a/antarest/study/storage/rawstudy/model/filesystem/factory.py b/antarest/study/storage/rawstudy/model/filesystem/factory.py index 798aef8ac2..40b28e33d8 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/factory.py +++ b/antarest/study/storage/rawstudy/model/filesystem/factory.py @@ -104,7 +104,7 @@ def _create_from_fs_unsafe( from_cache = self.cache.get(cache_id) if from_cache is not None: logger.info(f"Study {study_id} read from cache") - config = FileStudyTreeConfigDTO.parse_obj(from_cache).to_build_config() + config = FileStudyTreeConfigDTO.model_validate(from_cache).to_build_config() if output_path: config.output_path = output_path config.outputs = parse_outputs(output_path) @@ -118,7 +118,7 @@ def _create_from_fs_unsafe( logger.info(f"Cache new entry from StudyFactory (studyID: {study_id})") self.cache.put( cache_id, - FileStudyTreeConfigDTO.from_build_config(config).dict(), + FileStudyTreeConfigDTO.from_build_config(config).model_dump(), ) return result diff --git a/antarest/study/storage/rawstudy/raw_study_service.py b/antarest/study/storage/rawstudy/raw_study_service.py index df36d92c7f..5e9295107a 100644 --- a/antarest/study/storage/rawstudy/raw_study_service.py +++ b/antarest/study/storage/rawstudy/raw_study_service.py @@ -104,7 +104,7 @@ def update_from_raw_meta(self, metadata: RawStudy, fallback_on_default: t.Option metadata.updated_at = metadata.updated_at or datetime.utcnow() if metadata.additional_data is None: metadata.additional_data = StudyAdditionalData() - metadata.additional_data.patch = metadata.additional_data.patch or Patch().json() + metadata.additional_data.patch = metadata.additional_data.patch or Patch().model_dump_json() metadata.additional_data.author = metadata.additional_data.author or "Unknown" else: diff --git a/antarest/study/storage/study_download_utils.py b/antarest/study/storage/study_download_utils.py index c870ecad60..91b2b4f9af 100644 --- a/antarest/study/storage/study_download_utils.py +++ b/antarest/study/storage/study_download_utils.py @@ -345,7 +345,7 @@ def export( if filetype == ExportFormat.JSON: with open(target_file, "w") as fh: json.dump( - matrix.dict(), + matrix.model_dump(), fh, ensure_ascii=False, allow_nan=True, diff --git a/antarest/study/storage/variantstudy/business/command_extractor.py b/antarest/study/storage/variantstudy/business/command_extractor.py index 5c5e697204..adbcc811d8 100644 --- a/antarest/study/storage/variantstudy/business/command_extractor.py +++ b/antarest/study/storage/variantstudy/business/command_extractor.py @@ -223,7 +223,7 @@ def _extract_cluster(self, study: FileStudy, area_id: str, cluster_id: str, rene create_cluster_command( area_id=area_id, cluster_name=cluster.id, - parameters=cluster.dict(by_alias=True, exclude_defaults=True, exclude={"id"}), + parameters=cluster.model_dump(by_alias=True, exclude_defaults=True, exclude={"id"}), command_context=self.command_context, ), self.generate_replace_matrix( @@ -323,6 +323,7 @@ def extract_district(self, study: FileStudy, district_id: str) -> t.List[IComman district_config = study_config.sets[district_id] base_filter = DistrictBaseFilter.add_all if district_config.inverted_set else DistrictBaseFilter.remove_all district_fetched_config = study_tree.get(["input", "areas", "sets", district_id]) + assert district_config.name is not None study_commands.append( CreateDistrict( name=district_config.name, @@ -382,7 +383,8 @@ def 
extract_binding_constraint( matrices[name] = matrix["data"] # Create the command to create the binding constraint - create_cmd = CreateBindingConstraint(**binding, **matrices, coeffs=terms, command_context=self.command_context) + kwargs = {**binding, **matrices, "coeffs": terms, "command_context": self.command_context} + create_cmd = CreateBindingConstraint.model_validate(kwargs) return [create_cmd] @@ -416,8 +418,8 @@ def generate_update_playlist( config = study_tree.get(["settings", "generaldata"]) playlist = get_playlist(config) return UpdatePlaylist( - items=playlist.keys() if playlist else None, - weights=({year for year, weight in playlist.items() if weight != 1} if playlist else None), + items=list(playlist.keys()) if playlist else None, + weights=({year: weight for year, weight in playlist.items() if weight != 1} if playlist else None), active=bool(playlist and len(playlist) > 0), reverse=False, command_context=self.command_context, @@ -453,6 +455,7 @@ def generate_update_district( study_tree = study.tree district_config = study_config.sets[district_id] district_fetched_config = study_tree.get(["input", "areas", "sets", district_id]) + assert district_config.name is not None return UpdateDistrict( id=district_config.name, base_filter=DistrictBaseFilter.add_all if district_config.inverted_set else DistrictBaseFilter.remove_all, diff --git a/antarest/study/storage/variantstudy/business/command_reverter.py b/antarest/study/storage/variantstudy/business/command_reverter.py index bf034b4167..5134a86bea 100644 --- a/antarest/study/storage/variantstudy/business/command_reverter.py +++ b/antarest/study/storage/variantstudy/business/command_reverter.py @@ -135,7 +135,7 @@ def _revert_update_binding_constraint( if matrix is not None: args[matrix_name] = matrix_service.get_matrix_id(matrix) - return [UpdateBindingConstraint(**args)] + return [UpdateBindingConstraint.model_validate(args)] return base_command.get_command_extractor().extract_binding_constraint(base, base_command.id) diff --git a/antarest/study/storage/variantstudy/business/utils_binding_constraint.py b/antarest/study/storage/variantstudy/business/utils_binding_constraint.py index bf577294eb..328e43f0f6 100644 --- a/antarest/study/storage/variantstudy/business/utils_binding_constraint.py +++ b/antarest/study/storage/variantstudy/business/utils_binding_constraint.py @@ -13,7 +13,6 @@ import typing as t from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import ( - DEFAULT_TIMESTEP, BindingConstraintFrequency, BindingConstraintOperator, ) diff --git a/antarest/study/storage/variantstudy/command_factory.py b/antarest/study/storage/variantstudy/command_factory.py index b5c63d258a..d178cfa79f 100644 --- a/antarest/study/storage/variantstudy/command_factory.py +++ b/antarest/study/storage/variantstudy/command_factory.py @@ -11,6 +11,7 @@ # This file is part of the Antares project. import typing as t +import copy from antarest.core.model import JSON from antarest.matrixstore.service import ISimpleMatrixService @@ -117,10 +118,15 @@ def to_command(self, command_dto: CommandDTO) -> t.List[ICommand]: """ args = command_dto.args if isinstance(args, dict): - return [self._to_single_command(command_dto.action, args, command_dto.version, command_dto.id)] + # In some cases, pydantic can modify the given args in place. + # We don't want that, so we copy the dictionary first.
+ new_args = copy.deepcopy(args) + return [self._to_single_command(command_dto.action, new_args, command_dto.version, command_dto.id)] elif isinstance(args, list): return [ - self._to_single_command(command_dto.action, argument, command_dto.version, command_dto.id) + self._to_single_command( + command_dto.action, copy.deepcopy(argument), command_dto.version, command_dto.id + ) for argument in args ] raise NotImplementedError() diff --git a/antarest/study/storage/variantstudy/model/command/create_area.py b/antarest/study/storage/variantstudy/model/command/create_area.py index 65854d772a..bfc8711b8e 100644 --- a/antarest/study/storage/variantstudy/model/command/create_area.py +++ b/antarest/study/storage/variantstudy/model/command/create_area.py @@ -61,8 +61,8 @@ class CreateArea(ICommand): # Overloaded metadata # =================== - command_name = CommandName.CREATE_AREA - version = 1 + command_name: CommandName = CommandName.CREATE_AREA + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py index c3814cc340..50ce659fd9 100644 --- a/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/create_binding_constraint.py @@ -10,16 +10,15 @@ # # This file is part of the Antares project. -import json import typing as t from abc import ABCMeta from enum import Enum import numpy as np -from pydantic import BaseModel, Extra, Field, root_validator, validator +from pydantic import BaseModel, Field, field_validator, model_validator from antarest.matrixstore.model import MatrixData -from antarest.study.business.all_optional_meta import AllOptionalMetaclass +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import ( DEFAULT_GROUP, DEFAULT_OPERATOR, @@ -91,26 +90,24 @@ def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixTyp # ================================================================================= -class BindingConstraintPropertiesBase(BaseModel, extra=Extra.forbid, allow_population_by_field_name=True): +class BindingConstraintPropertiesBase(BaseModel, extra="forbid", populate_by_name=True): enabled: bool = True time_step: BindingConstraintFrequency = Field(DEFAULT_TIMESTEP, alias="type") operator: BindingConstraintOperator = DEFAULT_OPERATOR comments: str = "" - @classmethod - def from_dict(cls, **attrs: t.Any) -> "BindingConstraintPropertiesBase": - """ - Instantiate a class from a dictionary excluding unknown or `None` fields. 
- """ - attrs = {k: v for k, v in attrs.items() if k in cls.__fields__ and v is not None} - return cls(**attrs) + @model_validator(mode="before") + def replace_with_alias(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + if "type" in values: + values["time_step"] = values.pop("type") + return values class BindingConstraintProperties830(BindingConstraintPropertiesBase): filter_year_by_year: str = Field("", alias="filter-year-by-year") filter_synthesis: str = Field("", alias="filter-synthesis") - @validator("filter_synthesis", "filter_year_by_year", pre=True) + @field_validator("filter_synthesis", "filter_year_by_year", mode="before") def _validate_filtering(cls, v: t.Any) -> str: return validate_filtering(v) @@ -151,10 +148,12 @@ def create_binding_constraint_config(study_version: t.Union[str, int], **kwargs: The binding_constraint configuration model. """ cls = get_binding_constraint_config_cls(study_version) - return cls.from_dict(**kwargs) + attrs = {k: v for k, v in kwargs.items() if k in cls.model_fields and v is not None} + return cls(**attrs) -class OptionalProperties(BindingConstraintProperties870, metaclass=AllOptionalMetaclass, use_none=True): +@all_optional_model +class OptionalProperties(BindingConstraintProperties870): pass @@ -163,32 +162,30 @@ class OptionalProperties(BindingConstraintProperties870, metaclass=AllOptionalMe # ================================================================================= -class BindingConstraintMatrices(BaseModel, extra=Extra.forbid, allow_population_by_field_name=True): +@camel_case_model +class BindingConstraintMatrices(BaseModel, extra="forbid", populate_by_name=True): """ Class used to store the matrices of a binding constraint. """ values: t.Optional[t.Union[MatrixType, str]] = Field( - None, + default=None, description="2nd member matrix for studies before v8.7", ) less_term_matrix: t.Optional[t.Union[MatrixType, str]] = Field( - None, + default=None, description="less term matrix for v8.7+ studies", - alias="lessTermMatrix", ) greater_term_matrix: t.Optional[t.Union[MatrixType, str]] = Field( - None, + default=None, description="greater term matrix for v8.7+ studies", - alias="greaterTermMatrix", ) equal_term_matrix: t.Optional[t.Union[MatrixType, str]] = Field( - None, + default=None, description="equal term matrix for v8.7+ studies", - alias="equalTermMatrix", ) - @root_validator(pre=True) + @model_validator(mode="before") def check_matrices( cls, values: t.Dict[str, t.Optional[t.Union[MatrixType, str]]] ) -> t.Dict[str, t.Optional[t.Union[MatrixType, str]]]: @@ -215,10 +212,10 @@ class AbstractBindingConstraintCommand(OptionalProperties, BindingConstraintMatr Abstract class for binding constraint commands. """ - coeffs: t.Optional[t.Dict[str, t.List[float]]] + coeffs: t.Optional[t.Dict[str, t.List[float]]] = None def to_dto(self) -> CommandDTO: - json_command = json.loads(self.json(exclude={"command_context"})) + json_command = self.model_dump(mode="json", exclude={"command_context"}) args = {} for field in ["enabled", "coeffs", "comments", "time_step", "operator"]: if json_command[field]: @@ -400,7 +397,7 @@ class CreateBindingConstraint(AbstractBindingConstraintCommand): Command used to create a binding constraint. 
""" - command_name = CommandName.CREATE_BINDING_CONSTRAINT + command_name: CommandName = CommandName.CREATE_BINDING_CONSTRAINT version: int = 1 # Properties of the `CREATE_BINDING_CONSTRAINT` command: @@ -427,8 +424,8 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: bd_id = transform_name_to_id(self.name) study_version = study_data.config.version - props = create_binding_constraint_config(study_version, **self.dict()) - obj = json.loads(props.json(by_alias=True)) + props = create_binding_constraint_config(study_version, **self.model_dump()) + obj = props.model_dump(mode="json", by_alias=True) new_binding = {"id": bd_id, "name": self.name, **obj} @@ -454,9 +451,9 @@ def _create_diff(self, other: "ICommand") -> t.List["ICommand"]: bd_id = transform_name_to_id(self.name) args = {"id": bd_id, "command_context": other.command_context} - excluded_fields = frozenset(ICommand.__fields__) - self_command = json.loads(self.json(exclude=excluded_fields)) - other_command = json.loads(other.json(exclude=excluded_fields)) + excluded_fields = set(ICommand.model_fields) + self_command = self.model_dump(mode="json", exclude=excluded_fields) + other_command = other.model_dump(mode="json", exclude=excluded_fields) properties = [ "enabled", "coeffs", @@ -480,7 +477,7 @@ def _create_diff(self, other: "ICommand") -> t.List["ICommand"]: if self_matrix_id != other_matrix_id: args[matrix_name] = other_matrix_id - return [UpdateBindingConstraint(**args)] + return [UpdateBindingConstraint.model_validate(args)] def match(self, other: "ICommand", equal: bool = False) -> bool: if not isinstance(other, self.__class__): diff --git a/antarest/study/storage/variantstudy/model/command/create_cluster.py b/antarest/study/storage/variantstudy/model/command/create_cluster.py index 9b43d9de5b..a1c6ef17aa 100644 --- a/antarest/study/storage/variantstudy/model/command/create_cluster.py +++ b/antarest/study/storage/variantstudy/model/command/create_cluster.py @@ -12,7 +12,7 @@ import typing as t -from pydantic import validator +from pydantic import Field, ValidationInfo, field_validator from antarest.core.model import JSON from antarest.core.utils.utils import assert_this @@ -38,46 +38,51 @@ class CreateCluster(ICommand): # Overloaded metadata # =================== - command_name = CommandName.CREATE_THERMAL_CLUSTER - version = 1 + command_name: CommandName = CommandName.CREATE_THERMAL_CLUSTER + version: int = 1 # Command parameters # ================== area_id: str cluster_name: str - parameters: t.Dict[str, str] - prepro: t.Optional[t.Union[t.List[t.List[MatrixData]], str]] = None - modulation: t.Optional[t.Union[t.List[t.List[MatrixData]], str]] = None + parameters: t.Dict[str, t.Any] + prepro: t.Optional[t.Union[t.List[t.List[MatrixData]], str]] = Field(None, validate_default=True) + modulation: t.Optional[t.Union[t.List[t.List[MatrixData]], str]] = Field(None, validate_default=True) - @validator("cluster_name") + @field_validator("cluster_name", mode="before") def validate_cluster_name(cls, val: str) -> str: valid_name = transform_name_to_id(val, lower=False) if valid_name != val: raise ValueError("Cluster name must only contains [a-zA-Z0-9],&,-,_,(,) characters") return val - @validator("prepro", always=True) + @field_validator("prepro", mode="before") def validate_prepro( - cls, v: t.Optional[t.Union[t.List[t.List[MatrixData]], str]], values: t.Any + cls, + v: t.Optional[t.Union[t.List[t.List[MatrixData]], str]], + values: t.Union[t.Dict[str, t.Any], ValidationInfo], ) -> 
t.Optional[t.Union[t.List[t.List[MatrixData]], str]]: + new_values = values if isinstance(values, dict) else values.data if v is None: - v = values["command_context"].generator_matrix_constants.get_thermal_prepro_data() + v = new_values["command_context"].generator_matrix_constants.get_thermal_prepro_data() return v - else: - return validate_matrix(v, values) + return validate_matrix(v, new_values) - @validator("modulation", always=True) + @field_validator("modulation", mode="before") def validate_modulation( - cls, v: t.Optional[t.Union[t.List[t.List[MatrixData]], str]], values: t.Any + cls, + v: t.Optional[t.Union[t.List[t.List[MatrixData]], str]], + values: t.Union[t.Dict[str, t.Any], ValidationInfo], ) -> t.Optional[t.Union[t.List[t.List[MatrixData]], str]]: + new_values = values if isinstance(values, dict) else values.data if v is None: - v = values["command_context"].generator_matrix_constants.get_thermal_prepro_modulation() + v = new_values["command_context"].generator_matrix_constants.get_thermal_prepro_modulation() return v else: - return validate_matrix(v, values) + return validate_matrix(v, new_values) def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: # Search the Area in the configuration diff --git a/antarest/study/storage/variantstudy/model/command/create_district.py b/antarest/study/storage/variantstudy/model/command/create_district.py index ac7086edac..afb9736806 100644 --- a/antarest/study/storage/variantstudy/model/command/create_district.py +++ b/antarest/study/storage/variantstudy/model/command/create_district.py @@ -13,7 +13,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Tuple, cast -from pydantic import validator +from pydantic import field_validator from antarest.study.storage.rawstudy.model.filesystem.config.model import ( DistrictSet, @@ -39,8 +39,8 @@ class CreateDistrict(ICommand): # Overloaded metadata # =================== - command_name = CommandName.CREATE_DISTRICT - version = 1 + command_name: CommandName = CommandName.CREATE_DISTRICT + version: int = 1 # Command parameters # ================== @@ -51,7 +51,7 @@ class CreateDistrict(ICommand): output: bool = True comments: str = "" - @validator("name") + @field_validator("name") def validate_district_name(cls, val: str) -> str: valid_name = transform_name_to_id(val, lower=False) if valid_name != val: diff --git a/antarest/study/storage/variantstudy/model/command/create_link.py b/antarest/study/storage/variantstudy/model/command/create_link.py index 21b1f1dc67..ef2d21f0a9 100644 --- a/antarest/study/storage/variantstudy/model/command/create_link.py +++ b/antarest/study/storage/variantstudy/model/command/create_link.py @@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast -from pydantic import root_validator, validator +from pydantic import ValidationInfo, field_validator, model_validator from antarest.core.model import JSON from antarest.core.utils.utils import assert_this @@ -51,30 +51,31 @@ class CreateLink(ICommand): # Overloaded metadata # =================== - command_name = CommandName.CREATE_LINK - version = 1 + command_name: CommandName = CommandName.CREATE_LINK + version: int = 1 # Command parameters # ================== area1: str area2: str - parameters: Optional[Dict[str, str]] = None + parameters: Optional[Dict[str, Any]] = None series: Optional[Union[List[List[MatrixData]], str]] = None direct: Optional[Union[List[List[MatrixData]], str]] = None indirect: 
Optional[Union[List[List[MatrixData]], str]] = None - @validator("series", "direct", "indirect", always=True) + @field_validator("series", "direct", "indirect", mode="before") def validate_series( - cls, v: Optional[Union[List[List[MatrixData]], str]], values: Any + cls, v: Optional[Union[List[List[MatrixData]], str]], values: Union[Dict[str, Any], ValidationInfo] ) -> Optional[Union[List[List[MatrixData]], str]]: - return validate_matrix(v, values) if v is not None else v + new_values = values if isinstance(values, dict) else values.data + return validate_matrix(v, new_values) if v is not None else v - @root_validator - def validate_areas(cls, values: Dict[str, Any]) -> Any: - if values.get("area1") == values.get("area2"): + @model_validator(mode="after") + def validate_areas(self) -> "CreateLink": + if self.area1 == self.area2: raise ValueError("Cannot create link on same node") - return values + return self def _create_link_in_config(self, area_from: str, area_to: str, study_data: FileStudyTreeConfig) -> None: self.parameters = self.parameters or {} diff --git a/antarest/study/storage/variantstudy/model/command/create_renewables_cluster.py b/antarest/study/storage/variantstudy/model/command/create_renewables_cluster.py index 9d27742118..1a932dd30d 100644 --- a/antarest/study/storage/variantstudy/model/command/create_renewables_cluster.py +++ b/antarest/study/storage/variantstudy/model/command/create_renewables_cluster.py @@ -12,7 +12,7 @@ import typing as t -from pydantic import validator +from pydantic import field_validator from antarest.core.model import JSON from antarest.study.storage.rawstudy.model.filesystem.config.model import ( @@ -36,17 +36,17 @@ class CreateRenewablesCluster(ICommand): # Overloaded metadata # =================== - command_name = CommandName.CREATE_RENEWABLES_CLUSTER - version = 1 + command_name: CommandName = CommandName.CREATE_RENEWABLES_CLUSTER + version: int = 1 # Command parameters # ================== area_id: str cluster_name: str - parameters: t.Dict[str, str] + parameters: t.Dict[str, t.Any] - @validator("cluster_name") + @field_validator("cluster_name") def validate_cluster_name(cls, val: str) -> str: valid_name = transform_name_to_id(val, lower=False) if valid_name != val: diff --git a/antarest/study/storage/variantstudy/model/command/create_st_storage.py b/antarest/study/storage/variantstudy/model/command/create_st_storage.py index fd8f445d4b..8244957da5 100644 --- a/antarest/study/storage/variantstudy/model/command/create_st_storage.py +++ b/antarest/study/storage/variantstudy/model/command/create_st_storage.py @@ -10,12 +10,10 @@ # # This file is part of the Antares project. 
-import json import typing as t import numpy as np -from pydantic import Field, validator -from pydantic.fields import ModelField +from pydantic import Field, ValidationInfo, model_validator from antarest.core.model import JSON from antarest.matrixstore.model import MatrixData @@ -52,32 +50,32 @@ class CreateSTStorage(ICommand): # Overloaded metadata # =================== - command_name = CommandName.CREATE_ST_STORAGE - version = 1 + command_name: CommandName = CommandName.CREATE_ST_STORAGE + version: int = 1 # Command parameters # ================== - area_id: str = Field(description="Area ID", regex=r"[a-z0-9_(),& -]+") + area_id: str = Field(description="Area ID", pattern=r"[a-z0-9_(),& -]+") parameters: STStorageConfigType pmax_injection: t.Optional[t.Union[MatrixType, str]] = Field( - None, + default=None, description="Charge capacity (modulation)", ) pmax_withdrawal: t.Optional[t.Union[MatrixType, str]] = Field( - None, + default=None, description="Discharge capacity (modulation)", ) lower_rule_curve: t.Optional[t.Union[MatrixType, str]] = Field( - None, + default=None, description="Lower rule curve (coefficient)", ) upper_rule_curve: t.Optional[t.Union[MatrixType, str]] = Field( - None, + default=None, description="Upper rule curve (coefficient)", ) inflows: t.Optional[t.Union[MatrixType, str]] = Field( - None, + default=None, description="Inflows (MW)", ) @@ -91,12 +89,9 @@ def storage_name(self) -> str: """The label representing the name of the storage for the user.""" return self.parameters.name - @validator(*_MATRIX_NAMES, always=True) - def register_matrix( - cls, - v: t.Optional[t.Union[MatrixType, str]], - values: t.Dict[str, t.Any], - field: ModelField, + @staticmethod + def validate_field( + v: t.Optional[t.Union[MatrixType, str]], values: t.Dict[str, t.Any], field: str ) -> t.Optional[t.Union[MatrixType, str]]: """ Validates a matrix array or link, and store the matrix array in the matrix repository. @@ -112,7 +107,7 @@ def register_matrix( Args: v: The matrix array or link to be validated and registered. values: A dictionary containing additional values used for validation. - field: The field being validated. + field: The name of the validated parameter Returns: The ID of the validated and stored matrix prefixed by "matrix://". 
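Pydantic v2 removes the `ModelField` argument that v1 injected into validators, which is why the patch turns the v1 `register_matrix` validator into a plain static helper (`validate_field`) and drives it from a `model_validator(mode="before")`, as the next hunk shows. When a per-field hook is enough, the v2 counterpart of v1's `always=True` is `validate_default=True` combined with `ValidationInfo.field_name`. A minimal sketch under that assumption, with hypothetical field names rather than the project's actual models:

import typing as t

from pydantic import BaseModel, Field, ValidationInfo, field_validator


class MatrixHolder(BaseModel):
    # Hypothetical stand-in for a command with optional matrix fields.
    # `validate_default=True` replaces pydantic v1's `always=True`.
    pmax_injection: t.Optional[str] = Field(None, validate_default=True)
    inflows: t.Optional[str] = Field(None, validate_default=True)

    @field_validator("pmax_injection", "inflows", mode="before")
    def _register_matrix(cls, v: t.Optional[str], info: ValidationInfo) -> str:
        # `info.field_name` replaces the v1 `ModelField.name` argument.
        if v is None:
            return f"matrix://default_{info.field_name}"
        return v


assert MatrixHolder().inflows == "matrix://default_inflows"

The before-mode model validator chosen by the patch is the safer option when a default depends on other raw values (such as `command_context`), which a field validator only sees through `info.data`, and only for fields validated before it.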
@@ -134,7 +129,7 @@ def register_matrix( "upper_rule_curve": constants.get_st_storage_upper_rule_curve, "inflows": constants.get_st_storage_inflows, } - method = methods[field.name] + method = methods[field] return method() if isinstance(v, str): # Check the matrix link @@ -148,7 +143,7 @@ def register_matrix( raise ValueError("Matrix values cannot contain NaN") # All matrices except "inflows" are constrained between 0 and 1 constrained = set(_MATRIX_NAMES) - {"inflows"} - if field.name in constrained and (np.any(array < 0) or np.any(array > 1)): + if field in constrained and (np.any(array < 0) or np.any(array > 1)): raise ValueError("Matrix values should be between 0 and 1") v = t.cast(MatrixType, array.tolist()) return validate_matrix(v, values) @@ -156,6 +151,13 @@ def register_matrix( # pragma: no cover raise TypeError(repr(v)) + @model_validator(mode="before") + def validate_matrices(cls, values: t.Union[t.Dict[str, t.Any], ValidationInfo]) -> t.Dict[str, t.Any]: + new_values = values if isinstance(values, dict) else values.data + for field in _MATRIX_NAMES: + new_values[field] = cls.validate_field(new_values.get(field, None), new_values, field) + return new_values + def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: """ Applies configuration changes to the study data: add the short-term storage in the storages list. @@ -223,14 +225,14 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: Returns: The output of the command execution. """ - output, data = self._apply_config(study_data.config) + output, _ = self._apply_config(study_data.config) if not output.status: return output # Fill-in the "list.ini" file with the parameters. # On creation, it's better to write all the parameters in the file. config = study_data.tree.get(["input", "st-storage", "clusters", self.area_id, "list"]) - config[self.storage_id] = json.loads(self.parameters.json(by_alias=True, exclude={"id"})) + config[self.storage_id] = self.parameters.model_dump(mode="json", by_alias=True, exclude={"id"}) new_data: JSON = { "input": { @@ -252,7 +254,7 @@ def to_dto(self) -> CommandDTO: Returns: The DTO object representing the current command. 
""" - parameters = json.loads(self.parameters.json(by_alias=True, exclude={"id"})) + parameters = self.parameters.model_dump(mode="json", by_alias=True, exclude={"id"}) return CommandDTO( action=self.command_name.value, args={ @@ -317,7 +319,7 @@ def _create_diff(self, other: "ICommand") -> t.List["ICommand"]: if getattr(self, attr) != getattr(other, attr) ] if self.parameters != other.parameters: - data: t.Dict[str, t.Any] = json.loads(other.parameters.json(by_alias=True, exclude={"id"})) + data: t.Dict[str, t.Any] = other.parameters.model_dump(mode="json", by_alias=True, exclude={"id"}) commands.append( UpdateConfig( target=f"input/st-storage/clusters/{self.area_id}/list/{self.storage_id}", diff --git a/antarest/study/storage/variantstudy/model/command/generate_thermal_cluster_timeseries.py b/antarest/study/storage/variantstudy/model/command/generate_thermal_cluster_timeseries.py index 4cd62588f5..ad7e5fa863 100644 --- a/antarest/study/storage/variantstudy/model/command/generate_thermal_cluster_timeseries.py +++ b/antarest/study/storage/variantstudy/model/command/generate_thermal_cluster_timeseries.py @@ -43,8 +43,8 @@ class GenerateThermalClusterTimeSeries(ICommand): Command used to generate thermal cluster timeseries for an entire study """ - command_name = CommandName.GENERATE_THERMAL_CLUSTER_TIMESERIES - version = 1 + command_name: CommandName = CommandName.GENERATE_THERMAL_CLUSTER_TIMESERIES + version: int = 1 def _apply_config(self, study_data: FileStudyTreeConfig) -> OutputTuple: return CommandOutput(status=True, message="Nothing to do"), {} diff --git a/antarest/study/storage/variantstudy/model/command/icommand.py b/antarest/study/storage/variantstudy/model/command/icommand.py index 85a4222ba4..eb9a1f1285 100644 --- a/antarest/study/storage/variantstudy/model/command/icommand.py +++ b/antarest/study/storage/variantstudy/model/command/icommand.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod import typing_extensions as te -from pydantic import BaseModel, Extra +from pydantic import BaseModel from antarest.core.utils.utils import assert_this from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig @@ -35,7 +35,7 @@ OutputTuple: te.TypeAlias = t.Tuple[CommandOutput, t.Dict[str, t.Any]] -class ICommand(ABC, BaseModel, extra=Extra.forbid, arbitrary_types_allowed=True, copy_on_model_validation="deep"): +class ICommand(ABC, BaseModel, extra="forbid", arbitrary_types_allowed=True): """ Interface for all commands that can be applied to a study. @@ -138,9 +138,9 @@ def match(self, other: "ICommand", equal: bool = False) -> bool: """ if not isinstance(other, self.__class__): return False - excluded_fields = set(ICommand.__fields__) - this_values = self.dict(exclude=excluded_fields) - that_values = other.dict(exclude=excluded_fields) + excluded_fields = set(ICommand.model_fields) + this_values = self.model_dump(exclude=excluded_fields) + that_values = other.model_dump(exclude=excluded_fields) return this_values == that_values @abstractmethod diff --git a/antarest/study/storage/variantstudy/model/command/remove_area.py b/antarest/study/storage/variantstudy/model/command/remove_area.py index 4323830432..3bf3e1856c 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_area.py +++ b/antarest/study/storage/variantstudy/model/command/remove_area.py @@ -33,7 +33,7 @@ class RemoveArea(ICommand): Command used to remove an area. 
""" - command_name = CommandName.REMOVE_AREA + command_name: CommandName = CommandName.REMOVE_AREA version: int = 1 # Properties of the `REMOVE_AREA` command: diff --git a/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py index ecc448001c..ee51c7e641 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/remove_binding_constraint.py @@ -27,7 +27,7 @@ class RemoveBindingConstraint(ICommand): Command used to remove a binding constraint. """ - command_name = CommandName.REMOVE_BINDING_CONSTRAINT + command_name: CommandName = CommandName.REMOVE_BINDING_CONSTRAINT version: int = 1 # Properties of the `REMOVE_BINDING_CONSTRAINT` command: diff --git a/antarest/study/storage/variantstudy/model/command/remove_cluster.py b/antarest/study/storage/variantstudy/model/command/remove_cluster.py index 210fcc61f8..3895b423e9 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_cluster.py +++ b/antarest/study/storage/variantstudy/model/command/remove_cluster.py @@ -30,8 +30,8 @@ class RemoveCluster(ICommand): # Overloaded metadata # =================== - command_name = CommandName.REMOVE_THERMAL_CLUSTER - version = 1 + command_name: CommandName = CommandName.REMOVE_THERMAL_CLUSTER + version: int = 1 # Command parameters # ================== @@ -187,7 +187,7 @@ def _remove_cluster_from_binding_constraints(self, study_data: FileStudy) -> Non # Collect the binding constraints that are related to the area to remove # by searching the terms that contain the ID of the area. - bc_to_remove = {} + bc_to_remove = [] lower_area_id = self.area_id.lower() lower_cluster_id = self.cluster_id.lower() for bc_index, bc in list(binding_constraints.items()): @@ -200,16 +200,15 @@ def _remove_cluster_from_binding_constraints(self, study_data: FileStudy) -> Non # noinspection PyTypeChecker related_area_id, related_cluster_id = map(str.lower, key.split(".")) if (lower_area_id, lower_cluster_id) == (related_area_id, related_cluster_id): - bc_to_remove[bc_index] = binding_constraints.pop(bc_index) + bc_to_remove.append(binding_constraints.pop(bc_index)["id"]) break matrix_suffixes = ["_lt", "_gt", "_eq"] if study_data.config.version >= 870 else [""] existing_files = study_data.tree.get(["input", "bindingconstraints"], depth=1) - for bc_index, bc in bc_to_remove.items(): - for suffix in matrix_suffixes: - matrix_id = f"{bc['id']}{suffix}" - if matrix_id in existing_files: - study_data.tree.delete(["input", "bindingconstraints", matrix_id]) + for bc_id, suffix in zip(bc_to_remove, matrix_suffixes): + matrix_id = f"{bc_id}{suffix}" + if matrix_id in existing_files: + study_data.tree.delete(["input", "bindingconstraints", matrix_id]) study_data.tree.save(binding_constraints, url) diff --git a/antarest/study/storage/variantstudy/model/command/remove_district.py b/antarest/study/storage/variantstudy/model/command/remove_district.py index 9d20a53a50..586a827943 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_district.py +++ b/antarest/study/storage/variantstudy/model/command/remove_district.py @@ -27,8 +27,8 @@ class RemoveDistrict(ICommand): # Overloaded metadata # =================== - command_name = CommandName.REMOVE_DISTRICT - version = 1 + command_name: CommandName = CommandName.REMOVE_DISTRICT + version: int = 1 # Command parameters # ================== diff --git 
a/antarest/study/storage/variantstudy/model/command/remove_link.py b/antarest/study/storage/variantstudy/model/command/remove_link.py index 68289a8b0d..a384ccff58 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_link.py +++ b/antarest/study/storage/variantstudy/model/command/remove_link.py @@ -12,7 +12,7 @@ import typing as t -from pydantic import root_validator, validator +from pydantic import field_validator, model_validator from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig, transform_name_to_id from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy @@ -29,7 +29,7 @@ class RemoveLink(ICommand): # Overloaded metadata # =================== - command_name = CommandName.REMOVE_LINK + command_name: CommandName = CommandName.REMOVE_LINK version: int = 1 # Command parameters @@ -40,7 +40,7 @@ class RemoveLink(ICommand): area2: str # noinspection PyMethodParameters - @validator("area1", "area2", pre=True) + @field_validator("area1", "area2", mode="before") def _validate_id(cls, area: str) -> str: if isinstance(area, str): # Area IDs must be in lowercase and not empty. @@ -54,16 +54,12 @@ def _validate_id(cls, area: str) -> str: return area # noinspection PyMethodParameters - @root_validator(pre=False) - def _validate_link(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - area1 = values.get("area1") - area2 = values.get("area2") - - if area1 and area2: - # By convention, the source area is always the smallest one (in lexicographic order). - values["area1"], values["area2"] = sorted([area1, area2]) - - return values + @model_validator(mode="after") + def _validate_link(self) -> "RemoveLink": + # By convention, the source area is always the smallest one (in lexicographic order). 
+ if self.area1 > self.area2: + self.area1, self.area2 = self.area2, self.area1 + return self def _check_link_exists(self, study_cfg: FileStudyTreeConfig) -> OutputTuple: """ diff --git a/antarest/study/storage/variantstudy/model/command/remove_renewables_cluster.py b/antarest/study/storage/variantstudy/model/command/remove_renewables_cluster.py index 620590c172..834dc1043b 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_renewables_cluster.py +++ b/antarest/study/storage/variantstudy/model/command/remove_renewables_cluster.py @@ -27,8 +27,8 @@ class RemoveRenewablesCluster(ICommand): # Overloaded metadata # =================== - command_name = CommandName.REMOVE_RENEWABLES_CLUSTER - version = 1 + command_name: CommandName = CommandName.REMOVE_RENEWABLES_CLUSTER + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/remove_st_storage.py b/antarest/study/storage/variantstudy/model/command/remove_st_storage.py index 7a38afe588..550587535d 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_st_storage.py +++ b/antarest/study/storage/variantstudy/model/command/remove_st_storage.py @@ -32,14 +32,14 @@ class RemoveSTStorage(ICommand): # Overloaded metadata # =================== - command_name = CommandName.REMOVE_ST_STORAGE - version = 1 + command_name: CommandName = CommandName.REMOVE_ST_STORAGE + version: int = 1 # Command parameters # ================== - area_id: str = Field(description="Area ID", regex=r"[a-z0-9_(),& -]+") - storage_id: str = Field(description="Short term storage ID", regex=r"[a-z0-9_(),& -]+") + area_id: str = Field(description="Area ID", pattern=r"[a-z0-9_(),& -]+") + storage_id: str = Field(description="Short term storage ID", pattern=r"[a-z0-9_(),& -]+") def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: """ diff --git a/antarest/study/storage/variantstudy/model/command/replace_matrix.py b/antarest/study/storage/variantstudy/model/command/replace_matrix.py index cdc5a74ec0..8a26f12f28 100644 --- a/antarest/study/storage/variantstudy/model/command/replace_matrix.py +++ b/antarest/study/storage/variantstudy/model/command/replace_matrix.py @@ -12,7 +12,7 @@ import typing as t -from pydantic import validator +from pydantic import Field, ValidationInfo, field_validator from antarest.core.exceptions import ChildNotFoundError from antarest.core.model import JSON @@ -35,16 +35,18 @@ class ReplaceMatrix(ICommand): # Overloaded metadata # =================== - command_name = CommandName.REPLACE_MATRIX - version = 1 + command_name: CommandName = CommandName.REPLACE_MATRIX + version: int = 1 # Command parameters # ================== target: str - matrix: t.Union[t.List[t.List[MatrixData]], str] + matrix: t.Union[t.List[t.List[MatrixData]], str] = Field(validate_default=True) - _validate_matrix = validator("matrix", each_item=True, always=True, allow_reuse=True)(validate_matrix) + @field_validator("matrix", mode="before") + def matrix_validator(cls, matrix: t.Union[t.List[t.List[MatrixData]], str], values: ValidationInfo) -> str: + return validate_matrix(matrix, values.data) def _apply_config(self, study_data: FileStudyTreeConfig) -> t.Tuple[CommandOutput, t.Dict[str, t.Any]]: return ( diff --git a/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py b/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py index 79165fd787..530cc92dfd 100644 --- 
a/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py +++ b/antarest/study/storage/variantstudy/model/command/update_binding_constraint.py @@ -10,7 +10,6 @@ # # This file is part of the Antares project. -import json import typing as t from antarest.core.model import JSON @@ -102,7 +101,7 @@ class UpdateBindingConstraint(AbstractBindingConstraintCommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_BINDING_CONSTRAINT + command_name: CommandName = CommandName.UPDATE_BINDING_CONSTRAINT version: int = 1 # Command parameters @@ -183,14 +182,14 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: ) study_version = study_data.config.version - props = create_binding_constraint_config(study_version, **self.dict()) - obj = json.loads(props.json(by_alias=True, exclude_unset=True)) + props = create_binding_constraint_config(study_version, **self.model_dump()) + obj = props.model_dump(mode="json", by_alias=True, exclude_unset=True) updated_cfg = binding_constraints[index] updated_cfg.update(obj) - excluded_fields = set(ICommand.__fields__) | {"id"} - updated_properties = self.dict(exclude=excluded_fields, exclude_none=True) + excluded_fields = set(ICommand.model_fields) | {"id"} + updated_properties = self.model_dump(exclude=excluded_fields, exclude_none=True) # This 2nd check is here to remove the last term. if self.coeffs or updated_properties == {"coeffs": {}}: # Remove terms which IDs contain a "%" or a "." in their name @@ -203,8 +202,8 @@ def to_dto(self) -> CommandDTO: matrices = ["values"] + [m.value for m in TermMatrices] matrix_service = self.command_context.matrix_service - excluded_fields = frozenset(ICommand.__fields__) - json_command = json.loads(self.json(exclude=excluded_fields, exclude_none=True)) + excluded_fields = set(ICommand.model_fields) + json_command = self.model_dump(mode="json", exclude=excluded_fields, exclude_none=True) for key in json_command: if key in matrices: json_command[key] = matrix_service.get_matrix_id(json_command[key]) diff --git a/antarest/study/storage/variantstudy/model/command/update_comments.py b/antarest/study/storage/variantstudy/model/command/update_comments.py index a74260e7bb..5a3d57a670 100644 --- a/antarest/study/storage/variantstudy/model/command/update_comments.py +++ b/antarest/study/storage/variantstudy/model/command/update_comments.py @@ -28,8 +28,8 @@ class UpdateComments(ICommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_COMMENTS - version = 1 + command_name: CommandName = CommandName.UPDATE_COMMENTS + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/update_config.py b/antarest/study/storage/variantstudy/model/command/update_config.py index fe859a0ea7..067b0ecba1 100644 --- a/antarest/study/storage/variantstudy/model/command/update_config.py +++ b/antarest/study/storage/variantstudy/model/command/update_config.py @@ -44,8 +44,8 @@ class UpdateConfig(ICommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_CONFIG - version = 1 + command_name: CommandName = CommandName.UPDATE_CONFIG + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/update_district.py b/antarest/study/storage/variantstudy/model/command/update_district.py index 4b6f7f2f97..e0d63dfafd 100644 --- a/antarest/study/storage/variantstudy/model/command/update_district.py +++ 
b/antarest/study/storage/variantstudy/model/command/update_district.py @@ -28,17 +28,17 @@ class UpdateDistrict(ICommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_DISTRICT - version = 1 + command_name: CommandName = CommandName.UPDATE_DISTRICT + version: int = 1 # Command parameters # ================== id: str - base_filter: Optional[DistrictBaseFilter] - filter_items: Optional[List[str]] - output: Optional[bool] - comments: Optional[str] + base_filter: Optional[DistrictBaseFilter] = None + filter_items: Optional[List[str]] = None + output: Optional[bool] = None + comments: Optional[str] = None def _apply_config(self, study_data: FileStudyTreeConfig) -> Tuple[CommandOutput, Dict[str, Any]]: base_set = study_data.sets[self.id] diff --git a/antarest/study/storage/variantstudy/model/command/update_playlist.py b/antarest/study/storage/variantstudy/model/command/update_playlist.py index 8841f6d3b5..52f6f70c5c 100644 --- a/antarest/study/storage/variantstudy/model/command/update_playlist.py +++ b/antarest/study/storage/variantstudy/model/command/update_playlist.py @@ -28,8 +28,8 @@ class UpdatePlaylist(ICommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_PLAYLIST - version = 1 + command_name: CommandName = CommandName.UPDATE_PLAYLIST + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/update_raw_file.py b/antarest/study/storage/variantstudy/model/command/update_raw_file.py index 91178ebb1f..1a3414f90b 100644 --- a/antarest/study/storage/variantstudy/model/command/update_raw_file.py +++ b/antarest/study/storage/variantstudy/model/command/update_raw_file.py @@ -29,8 +29,8 @@ class UpdateRawFile(ICommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_FILE - version = 1 + command_name: CommandName = CommandName.UPDATE_FILE + version: int = 1 # Command parameters # ================== diff --git a/antarest/study/storage/variantstudy/model/command/update_scenario_builder.py b/antarest/study/storage/variantstudy/model/command/update_scenario_builder.py index e3729cd220..86a127f776 100644 --- a/antarest/study/storage/variantstudy/model/command/update_scenario_builder.py +++ b/antarest/study/storage/variantstudy/model/command/update_scenario_builder.py @@ -13,8 +13,8 @@ import typing as t import numpy as np -from requests.structures import CaseInsensitiveDict +from antarest.core.requests import CaseInsensitiveDict from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy from antarest.study.storage.variantstudy.model.command.common import CommandName, CommandOutput @@ -45,13 +45,13 @@ class UpdateScenarioBuilder(ICommand): # Overloaded metadata # =================== - command_name = CommandName.UPDATE_SCENARIO_BUILDER - version = 1 + command_name: CommandName = CommandName.UPDATE_SCENARIO_BUILDER + version: int = 1 # Command parameters # ================== - data: t.Dict[str, t.Any] + data: t.Union[t.Dict[str, t.Any], t.Mapping[str, t.Any], t.MutableMapping[str, t.Any]] def _apply(self, study_data: FileStudy) -> CommandOutput: """ diff --git a/antarest/study/storage/variantstudy/model/command_context.py b/antarest/study/storage/variantstudy/model/command_context.py index a3e913355e..5996e63528 100644 --- a/antarest/study/storage/variantstudy/model/command_context.py +++ 
b/antarest/study/storage/variantstudy/model/command_context.py @@ -24,4 +24,3 @@ class CommandContext(BaseModel): class Config: arbitrary_types_allowed = True - copy_on_model_validation = False diff --git a/antarest/study/storage/variantstudy/model/model.py b/antarest/study/storage/variantstudy/model/model.py index 66add039af..0be3c75353 100644 --- a/antarest/study/storage/variantstudy/model/model.py +++ b/antarest/study/storage/variantstudy/model/model.py @@ -69,7 +69,7 @@ class CommandDTO(BaseModel): version: The version of the command. """ - id: t.Optional[str] + id: t.Optional[str] = None action: str args: t.Union[t.MutableSequence[JSON], JSON] version: int = 1 diff --git a/antarest/study/storage/variantstudy/snapshot_generator.py b/antarest/study/storage/variantstudy/snapshot_generator.py index 6353f723d1..086c6d3952 100644 --- a/antarest/study/storage/variantstudy/snapshot_generator.py +++ b/antarest/study/storage/variantstudy/snapshot_generator.py @@ -133,7 +133,7 @@ def generate_snapshot( else: try: - notifier(results.json()) + notifier(results.model_dump_json()) except Exception as exc: # This exception is ignored, because it is not critical. logger.warning(f"Error while sending notification: {exc}", exc_info=True) @@ -203,7 +203,7 @@ def _read_additional_data(self, file_study: FileStudy) -> StudyAdditionalData: horizon = file_study.tree.get(url=["settings", "generaldata", "general", "horizon"]) author = file_study.tree.get(url=["study", "antares", "author"]) patch = self.patch_service.get_from_filestudy(file_study) - study_additional_data = StudyAdditionalData(horizon=horizon, author=author, patch=patch.json()) + study_additional_data = StudyAdditionalData(horizon=horizon, author=author, patch=patch.model_dump_json()) return study_additional_data def _update_cache(self, file_study: FileStudy) -> None: @@ -211,7 +211,7 @@ def _update_cache(self, file_study: FileStudy) -> None: self.cache.invalidate(f"{CacheConstants.RAW_STUDY}/{file_study.config.study_id}") self.cache.put( f"{CacheConstants.STUDY_FACTORY}/{file_study.config.study_id}", - FileStudyTreeConfigDTO.from_build_config(file_study.config).dict(), + FileStudyTreeConfigDTO.from_build_config(file_study.config).model_dump(), ) diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index ac718bb518..66273f54a0 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -627,7 +627,7 @@ def callback(notifier: TaskUpdateNotifier) -> TaskResult: message=( f"{study_id} generated successfully" if generate_result.success else f"{study_id} not generated" ), - return_value=generate_result.json(), + return_value=generate_result.model_dump_json(), ) metadata.generation_task = self.task_service.add_task( @@ -718,7 +718,7 @@ def notify(command_index: int, command_result: bool, command_message: str) -> No success=command_result, message=command_message, ) - notifier(command_result_obj.json()) + notifier(command_result_obj.model_dump_json()) self.event_bus.push( Event( type=EventType.STUDY_VARIANT_GENERATION_COMMAND_RESULT, diff --git a/antarest/study/web/studies_blueprint.py b/antarest/study/web/studies_blueprint.py index 761e7d6b05..f532ac3051 100644 --- a/antarest/study/web/studies_blueprint.py +++ b/antarest/study/web/studies_blueprint.py @@ -727,7 +727,7 @@ def output_download( extra={"user": current_user.id}, ) params = RequestParameters(user=current_user) - accept = 
request.headers.get("Accept") + accept = request.headers["Accept"] filetype = ExportFormat.from_dto(accept) content = study_service.download_outputs( diff --git a/antarest/study/web/study_data_blueprint.py b/antarest/study/web/study_data_blueprint.py index 549cfdf221..9b1cca0fc5 100644 --- a/antarest/study/web/study_data_blueprint.py +++ b/antarest/study/web/study_data_blueprint.py @@ -29,7 +29,7 @@ from antarest.matrixstore.matrix_editor import MatrixEditInstruction from antarest.study.business.adequacy_patch_management import AdequacyPatchFormFields from antarest.study.business.advanced_parameters_management import AdvancedParamsFormFields -from antarest.study.business.allocation_management import AllocationFormFields, AllocationMatrix +from antarest.study.business.allocation_management import AllocationField, AllocationFormFields, AllocationMatrix from antarest.study.business.area_management import AreaCreationDTO, AreaInfoDTO, AreaType, LayerInfoDTO, UpdateAreaUi from antarest.study.business.areas.hydro_management import InflowStructure, ManagementOptionsFormFields from antarest.study.business.areas.properties_management import PropertiesFormFields @@ -60,7 +60,12 @@ ConstraintOutput, ConstraintTerm, ) -from antarest.study.business.correlation_management import CorrelationFormFields, CorrelationManager, CorrelationMatrix +from antarest.study.business.correlation_management import ( + AreaCoefficientItem, + CorrelationFormFields, + CorrelationManager, + CorrelationMatrix, +) from antarest.study.business.district_manager import DistrictCreationDTO, DistrictInfoDTO, DistrictUpdateDTO from antarest.study.business.general_management import GeneralFormFields from antarest.study.business.link_management import LinkInfoDTO @@ -122,7 +127,7 @@ def create_study_data_routes(study_service: StudyService, config: Config) -> API "/studies/{uuid}/areas", tags=[APITag.study_data], summary="Get all areas basic info", - response_model=t.Union[t.List[AreaInfoDTO], t.Dict[str, t.Any]], # type: ignore + response_model=t.Union[t.List[AreaInfoDTO], t.Dict[str, t.Any]], ) def get_areas( uuid: str, @@ -503,7 +508,6 @@ def get_inflow_structure( "/studies/{uuid}/areas/{area_id}/hydro/inflow-structure", tags=[APITag.study_data], summary="Update inflow structure values", - response_model=InflowStructure, ) def update_inflow_structure( uuid: str, @@ -518,7 +522,7 @@ def update_inflow_structure( ) params = RequestParameters(user=current_user) study = study_service.check_study_access(uuid, StudyPermissionType.WRITE, params) - return study_service.hydro_manager.update_inflow_structure(study, area_id, values) + study_service.hydro_manager.update_inflow_structure(study, area_id, values) @bp.put( "/studies/{uuid}/matrix", @@ -1144,7 +1148,7 @@ def get_binding_constraint_list( "/studies/{uuid}/bindingconstraints/{binding_constraint_id}", tags=[APITag.study_data], summary="Get binding constraint", - response_model=ConstraintOutput, # type: ignore + response_model=ConstraintOutput, # TODO: redundant ? 
) def get_binding_constraint( uuid: str, @@ -1532,8 +1536,8 @@ def set_allocation_form_fields( ..., example=AllocationFormFields( allocation=[ - {"areaId": "EAST", "coefficient": 1}, - {"areaId": "NORTH", "coefficient": 0.20}, + AllocationField.model_validate({"areaId": "EAST", "coefficient": 1}), + AllocationField.model_validate({"areaId": "NORTH", "coefficient": 0.20}), ] ), ), @@ -1565,8 +1569,8 @@ def set_allocation_form_fields( def get_correlation_matrix( uuid: str, columns: t.Optional[str] = Query( - None, - examples={ + default=None, + openapi_examples={ "all areas": { "description": "get the correlation matrix for all areas (by default)", "value": "", @@ -1698,8 +1702,8 @@ def set_correlation_form_fields( ..., example=CorrelationFormFields( correlation=[ - {"areaId": "east", "coefficient": 80}, - {"areaId": "north", "coefficient": 20}, + AreaCoefficientItem.model_validate({"areaId": "east", "coefficient": 80}), + AreaCoefficientItem.model_validate({"areaId": "north", "coefficient": 20}), ] ), ), diff --git a/antarest/study/web/variant_blueprint.py b/antarest/study/web/variant_blueprint.py index 98df4ffc2d..6acfed4250 100644 --- a/antarest/study/web/variant_blueprint.py +++ b/antarest/study/web/variant_blueprint.py @@ -107,12 +107,7 @@ def create_variant( "/studies/{uuid}/variants", tags=[APITag.study_variant_management], summary="Get children variants", - responses={ - 200: { - "description": "The list of children study variant", - "model": List[StudyMetadataDTO], - } - }, + response_model=None, # To cope with recursive models issues ) def get_variants( uuid: str, diff --git a/antarest/study/web/xpansion_studies_blueprint.py b/antarest/study/web/xpansion_studies_blueprint.py index 4044a4b9d4..93b60f01ff 100644 --- a/antarest/study/web/xpansion_studies_blueprint.py +++ b/antarest/study/web/xpansion_studies_blueprint.py @@ -139,7 +139,7 @@ def add_candidate( current_user: JWTUser = Depends(auth.get_current_user), ) -> XpansionCandidateDTO: logger.info( - f"Adding new candidate {xpansion_candidate_dto.dict(by_alias=True)} to study {uuid}", + f"Adding new candidate {xpansion_candidate_dto.model_dump(by_alias=True)} to study {uuid}", extra={"user": current_user.id}, ) params = RequestParameters(user=current_user) diff --git a/antarest/tools/cli.py b/antarest/tools/cli.py index d1b91d522b..85eef22128 100644 --- a/antarest/tools/cli.py +++ b/antarest/tools/cli.py @@ -15,10 +15,11 @@ from typing import Optional import click +from httpx import Client from antarest.study.model import NEW_DEFAULT_STUDY_VERSION from antarest.study.storage.study_upgrader import StudyUpgrader -from antarest.tools.lib import extract_commands, generate_diff, generate_study +from antarest.tools.lib import create_http_client, extract_commands, generate_diff, generate_study @click.group(context_settings={"max_content_width": 120}) @@ -43,6 +44,12 @@ def commands() -> None: type=str, help="Authentication token if server needs one", ) +@click.option( + "--no-verify", + is_flag=True, + default=False, + help="Disables SSL certificate verification", +) @click.option( "--output", "-o", @@ -82,6 +89,7 @@ def cli_apply_script( output: Optional[str], host: Optional[str], auth_token: Optional[str], + no_verify: bool, version: str, ) -> None: """Apply a variant script onto an AntaresWeb study variant""" @@ -95,7 +103,10 @@ def cli_apply_script( print("--study_id must be set") exit(1) - res = generate_study(Path(input), study_id, output, host, auth_token, version) + client = None + if host: + client = 
create_http_client(verify=not no_verify, auth_token=auth_token) + res = generate_study(Path(input), study_id, output, host, client, version) print(res) diff --git a/antarest/tools/lib.py b/antarest/tools/lib.py index b930fd87f4..6bfd3df997 100644 --- a/antarest/tools/lib.py +++ b/antarest/tools/lib.py @@ -20,17 +20,7 @@ from zipfile import ZipFile import numpy as np - -try: - # The HTTPX equivalent of `requests.Session` is `httpx.Client`. - import httpx as requests - from httpx import Client as Session -except ImportError: - # noinspection PyUnresolvedReferences, PyPackageRequirements - import requests - - # noinspection PyUnresolvedReferences,PyPackageRequirements - from requests import Session +from httpx import Client from antarest.core.cache.business.local_chache import LocalCache from antarest.core.config import CacheConfig @@ -61,34 +51,28 @@ def apply_commands(self, commands: List[CommandDTO], matrices_dir: Path) -> Gene raise NotImplementedError() +def set_auth_token(client: Client, auth_token: Optional[str] = None) -> Client: + if auth_token is not None: + client.headers.update({"Authorization": f"Bearer {auth_token}"}) + return client + + +def create_http_client(verify: bool, auth_token: Optional[str] = None) -> Client: + client = Client(verify=verify) + set_auth_token(client, auth_token) + return client + + class RemoteVariantGenerator(IVariantGenerator): def __init__( self, study_id: str, - host: Optional[str] = None, - token: Optional[str] = None, - session: Optional[Session] = None, + host: str, + session: Client, ): self.study_id = study_id - - # todo: find the correct way to handle certificates. - # By default, Requests/Httpx verifies SSL certificates for HTTPS requests. - # When verify is set to `False`, requests will accept any TLS certificate presented - # by the server,and will ignore hostname mismatches and/or expired certificates, - # which will make your application vulnerable to man-in-the-middle (MitM) attacks. - # Setting verify to False may be useful during local development or testing. 
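For context, a minimal usage sketch of how the new httpx-based helpers introduced in this hunk fit together (`create_http_client`, `generate_study` and the reworked `RemoteVariantGenerator` come from this patch; the host, token, study id and commands directory below are placeholders):

    from pathlib import Path

    from antarest.tools.lib import create_http_client, generate_study

    # Build an httpx.Client that verifies TLS certificates and, optionally,
    # carries a bearer token (the CLI passes verify=not no_verify).
    client = create_http_client(verify=True, auth_token="<token>")

    # Apply the commands found in ./commands onto a remote study variant.
    result = generate_study(
        Path("commands"),                        # commands directory (placeholder)
        "variant-study-id",                      # study id (placeholder)
        host="https://antares-web.example.com",  # server URL (placeholder)
        session=client,
    )
    print(result.success)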
- if Session.__name__ == "Client": - # noinspection PyArgumentList - self.session = session or Session(verify=False) - else: - self.session = session or Session() - self.session.verify = False - + self.session = session self.host = host - if session is None and host is None: - raise ValueError("Missing either session or host") - if token is not None: - self.session.headers.update({"Authorization": f"Bearer {token}"}) def apply_commands( self, @@ -114,7 +98,7 @@ def apply_commands( res = self.session.post( self.build_url(f"/v1/studies/{self.study_id}/commands"), - json=[command.dict() for command in commands], + json=[command.model_dump() for command in commands], ) res.raise_for_status() stopwatch.log_elapsed(lambda x: logger.info(f"Command upload done in {x}s")) @@ -224,7 +208,7 @@ def extract_commands(study_path: Path, commands_output_dir: Path) -> None: (commands_output_dir / COMMAND_FILE).write_text( json.dumps( - [command.dict(exclude={"id"}) for command in command_list], + [command.model_dump(exclude={"id"}) for command in command_list], indent=2, ) ) @@ -316,7 +300,7 @@ def generate_diff( (output_dir / COMMAND_FILE).write_text( json.dumps( - [command.to_dto().dict(exclude={"id"}) for command in diff_commands], + [command.to_dto().model_dump(exclude={"id"}) for command in diff_commands], indent=2, ) ) @@ -348,7 +332,7 @@ def generate_study( study_id: Optional[str], output: Optional[str] = None, host: Optional[str] = None, - token: Optional[str] = None, + session: Optional[Client] = None, study_version: str = NEW_DEFAULT_STUDY_VERSION, ) -> GenerationResultInfoDTO: """ @@ -362,7 +346,7 @@ def generate_study( If `study_id` and `host` are not provided, this must be specified. host: The URL of the Antares server to use for generating the new study. If `study_id` is not provided, this is ignored. - token: The authentication token to use when connecting to the Antares server. + session: The session to use when connecting to the Antares server. If `host` is not provided, this is ignored. study_version: The target version of the generated study. @@ -370,8 +354,9 @@ def generate_study( GenerationResultInfoDTO: A data transfer object containing information about the generation result. 
""" generator: Union[RemoteVariantGenerator, LocalVariantGenerator] - if study_id is not None and host is not None: - generator = RemoteVariantGenerator(study_id, host, token) + + if study_id is not None and host is not None and session is not None: + generator = RemoteVariantGenerator(study_id, host, session) elif output is None: raise TypeError("'output' must be set") else: diff --git a/antarest/utils.py b/antarest/utils.py index eb2e5dc479..1f7f4b7594 100644 --- a/antarest/utils.py +++ b/antarest/utils.py @@ -16,17 +16,15 @@ from typing import Any, Dict, Mapping, Optional, Tuple import redis -import sqlalchemy.ext.baked # type: ignore -import uvicorn # type: ignore -from fastapi import FastAPI -from fastapi_jwt_auth import AuthJWT # type: ignore +from fastapi import APIRouter, FastAPI from ratelimit import RateLimitMiddleware # type: ignore from ratelimit.backends.redis import RedisBackend # type: ignore from ratelimit.backends.simple import MemoryBackend # type: ignore -from sqlalchemy import create_engine +from sqlalchemy import create_engine # type: ignore from sqlalchemy.engine.base import Engine # type: ignore from sqlalchemy.pool import NullPool # type: ignore +from antarest.core.application import AppBuildContext from antarest.core.cache.main import build_cache from antarest.core.config import Config from antarest.core.filetransfer.main import build_filetransfer_service @@ -112,24 +110,24 @@ def init_db_engine( return engine -def create_event_bus(application: Optional[FastAPI], config: Config) -> Tuple[IEventBus, Optional[redis.Redis]]: # type: ignore +def create_event_bus(app_ctxt: Optional[AppBuildContext], config: Config) -> Tuple[IEventBus, Optional[redis.Redis]]: # type: ignore redis_client = new_redis_instance(config.redis) if config.redis is not None else None return ( - build_eventbus(application, config, True, redis_client), + build_eventbus(app_ctxt, config, True, redis_client), redis_client, ) def create_core_services( - application: Optional[FastAPI], config: Config + app_ctxt: Optional[AppBuildContext], config: Config ) -> Tuple[ICache, IEventBus, ITaskService, FileTransferManager, LoginService, MatrixService, StudyService,]: - event_bus, redis_client = create_event_bus(application, config) + event_bus, redis_client = create_event_bus(app_ctxt, config) cache = build_cache(config=config, redis_client=redis_client) - filetransfer_service = build_filetransfer_service(application, event_bus, config) - task_service = build_taskjob_manager(application, config, event_bus) - login_service = build_login(application, config, event_bus=event_bus) + filetransfer_service = build_filetransfer_service(app_ctxt, event_bus, config) + task_service = build_taskjob_manager(app_ctxt, config, event_bus) + login_service = build_login(app_ctxt, config, event_bus=event_bus) matrix_service = build_matrix_service( - application, + app_ctxt, config=config, file_transfer_manager=filetransfer_service, task_service=task_service, @@ -137,7 +135,7 @@ def create_core_services( service=None, ) study_service = build_study_service( - application, + app_ctxt, config, matrix_service=matrix_service, cache=cache, @@ -159,7 +157,7 @@ def create_core_services( def create_watcher( config: Config, - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], study_service: Optional[StudyService] = None, ) -> Watcher: if study_service: @@ -169,22 +167,22 @@ def create_watcher( task_service=study_service.task_service, ) else: - _, _, task_service, _, _, _, study_service = 
create_core_services(application, config) + _, _, task_service, _, _, _, study_service = create_core_services(app_ctxt, config) watcher = Watcher( config=config, study_service=study_service, task_service=task_service, ) - if application: - application.include_router(create_watcher_routes(watcher=watcher, config=config)) + if app_ctxt: + app_ctxt.api_root.include_router(create_watcher_routes(watcher=watcher, config=config)) return watcher def create_matrix_gc( config: Config, - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], study_service: Optional[StudyService] = None, matrix_service: Optional[MatrixService] = None, ) -> MatrixGarbageCollector: @@ -195,7 +193,7 @@ def create_matrix_gc( matrix_service=matrix_service, ) else: - _, _, _, _, _, matrix_service, study_service = create_core_services(application, config) + _, _, _, _, _, matrix_service, study_service = create_core_services(app_ctxt, config) return MatrixGarbageCollector( config=config, study_service=study_service, @@ -224,7 +222,7 @@ def create_simulator_worker( return SimulatorWorker(event_bus, matrix_service, config) -def create_services(config: Config, application: Optional[FastAPI], create_all: bool = False) -> Dict[str, Any]: +def create_services(config: Config, app_ctxt: Optional[AppBuildContext], create_all: bool = False) -> Dict[str, Any]: services: Dict[str, Any] = {} ( @@ -235,12 +233,12 @@ def create_services(config: Config, application: Optional[FastAPI], create_all: user_service, matrix_service, study_service, - ) = create_core_services(application, config) + ) = create_core_services(app_ctxt, config) - maintenance_service = build_maintenance_manager(application, config=config, cache=cache, event_bus=event_bus) + maintenance_service = build_maintenance_manager(app_ctxt, config=config, cache=cache, event_bus=event_bus) launcher = build_launcher( - application, + app_ctxt, config, study_service=study_service, event_bus=event_bus, @@ -249,13 +247,13 @@ def create_services(config: Config, application: Optional[FastAPI], create_all: cache=cache, ) - watcher = create_watcher(config=config, application=application, study_service=study_service) + watcher = create_watcher(config=config, app_ctxt=app_ctxt, study_service=study_service) services["watcher"] = watcher if config.server.services and Module.MATRIX_GC.value in config.server.services or create_all: matrix_garbage_collector = create_matrix_gc( config=config, - application=application, + app_ctxt=app_ctxt, study_service=study_service, matrix_service=matrix_service, ) diff --git a/antarest/worker/archive_worker.py b/antarest/worker/archive_worker.py index f0adbf481c..4fbc6a0631 100644 --- a/antarest/worker/archive_worker.py +++ b/antarest/worker/archive_worker.py @@ -58,10 +58,10 @@ def __init__( ) def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: - logger.info(f"Executing task {task_info.json()}") + logger.info(f"Executing task {task_info.model_dump_json()}") try: # sourcery skip: extract-method - archive_args = ArchiveTaskArgs.parse_obj(task_info.task_args) + archive_args = ArchiveTaskArgs.model_validate(task_info.task_args) dest = self.translate_path(Path(archive_args.dest)) src = self.translate_path(Path(archive_args.src)) stopwatch = StopWatch() @@ -75,7 +75,7 @@ def _execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: return TaskResult(success=True, message="") except Exception as e: logger.warning( - f"Task {task_info.json()} failed", + f"Task {task_info.model_dump_json()} failed", exc_info=e, ) return 
TaskResult(success=False, message=str(e))
diff --git a/antarest/worker/simulator_worker.py b/antarest/worker/simulator_worker.py
index 7397e0ae4f..583f340e7a 100644
--- a/antarest/worker/simulator_worker.py
+++ b/antarest/worker/simulator_worker.py
@@ -81,7 +81,7 @@ def execute_kirshoff_constraint_generation_task(self, task_info: WorkerTaskComma
 
     def execute_timeseries_generation_task(self, task_info: WorkerTaskCommand) -> TaskResult:
         result = TaskResult(success=True, message="", return_value="")
-        task = GenerateTimeseriesTaskArgs.parse_obj(task_info.task_args)
+        task = GenerateTimeseriesTaskArgs.model_validate(task_info.task_args)
         binary = (
             self.binaries[task.study_version]
             if task.study_version in self.binaries
diff --git a/antarest/worker/worker.py b/antarest/worker/worker.py
index 922f10e0b2..7dc2534764 100644
--- a/antarest/worker/worker.py
+++ b/antarest/worker/worker.py
@@ -112,8 +112,8 @@ def _loop(self) -> None:
             time.sleep(1)
 
     async def _listen_for_tasks(self, event: Event) -> None:
-        logger.info(f"Accepting new task {event.json()}")
-        task_info = WorkerTaskCommand.parse_obj(event.payload)
+        logger.info(f"Accepting new task {event.model_dump_json()}")
+        task_info = WorkerTaskCommand.model_validate(event.payload)
         self.event_bus.push(
             Event(
                 type=EventType.WORKER_TASK_STARTED,
@@ -131,7 +131,7 @@ def _safe_execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
             return self._execute_task(task_info)
         except Exception as e:
             logger.error(
-                f"Unexpected error occurred when executing task {task_info.json()}",
+                f"Unexpected error occurred when executing task {task_info.model_dump_json()}",
                 exc_info=e,
             )
             return TaskResult(success=False, message=repr(e))
diff --git a/pyproject.toml b/pyproject.toml
index 1d01065859..d8210f8981 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,8 +38,14 @@ where = ["."]
 include = ["antarest*"]
 
 [tool.mypy]
+exclude = "antarest/fastapi_jwt_auth/*"
 strict = true
-files = "antarest/**/*.py"
+files = "antarest"
+plugins = "pydantic.mypy"
+
+[[tool.mypy.overrides]]
+module = ["antarest.fastapi_jwt_auth.*"]
+follow_imports = "skip"
 
 [[tool.mypy.overrides]]
 module = [
@@ -71,7 +77,7 @@ line-length = 120
 exclude = "(antares-?launcher/*|alembic/*)"
 
 [tool.coverage.run]
-omit = ["antarest/tools/cli.py", "antarest/tools/admin.py"]
+omit = ["antarest/tools/cli.py", "antarest/tools/admin.py", "antarest/fastapi_jwt_auth/*.py"]
 relative_files = true # avoids absolute path issues in CI
 
 [tool.isort]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 5ea6126c6c..973973dc76 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -2,7 +2,7 @@
 # Version of Black should match the versions set in `.github/workflows/main.yml`
 black~=23.7.0
 isort~=5.12.0
-mypy~=1.4.1
+mypy~=1.11.1
 pyinstaller==5.6.2
 pyinstaller-hooks-contrib==2024.6
diff --git a/requirements-test.txt b/requirements-test.txt
index 10a360d7f4..e72650560c 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -7,4 +7,5 @@
 pytest-mock~=3.14.0
 # In this version DataFrame conversion to Excel is done using 'xlsxwriter' library.
 # But Excel files reading is done using 'openpyxl' library, during testing only.
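The pyproject.toml hunk above also enables mypy's `pydantic.mypy` plugin. As a rough illustration (the model below is hypothetical, not part of this patch), the plugin lets mypy type-check the constructors that pydantic generates from field declarations:

    from pydantic import BaseModel

    class ArchiveArgs(BaseModel):  # hypothetical model, for illustration only
        src: str
        dest: str

    ArchiveArgs(src="input.zip")  # the plugin makes mypy flag the missing "dest" argument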
-openpyxl~=3.1.2
\ No newline at end of file
+openpyxl~=3.1.2
+jinja2~=3.1.3
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index eb42c65793..4659b879ba 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,17 +2,34 @@
 Antares-Launcher~=1.3.2
 antares-study-version~=1.0.3
 antares-timeseries-generation~=0.1.5
+# When you install `fastapi[all]`, you get FastAPI along with additional dependencies:
+# - `uvicorn`: A fast ASGI server commonly used to run FastAPI applications.
+# - `pydantic`: A data validation library integrated into FastAPI for defining data models and validating input and output of endpoints.
+# - `httpx`: A modern HTTP client library for making HTTP requests from your FastAPI application.
+# - `starlette`: The underlying ASGI framework used by FastAPI for handling HTTP requests.
+# - `uvloop` (on certain systems): A fast event loop based on libuv that improves the performance of FastAPI applications.
+# - `python-multipart`: A library for handling multipart data in HTTP requests, commonly used for processing file uploads.
+# - `watchdog`: A file watching library, commonly used to trigger automatic reloading when files are modified during development.
+# - `email-validator`: A library for email address validation, used for validating email fields in Pydantic models.
+# - `python-dotenv`: A library for managing environment variables, commonly used to load application configurations from `.env` files.
+
+# We prefer to add only the specific libraries we need for our project
+# and **manage their versions** for better control and to avoid unnecessary dependencies.
+fastapi~=0.110.3
+uvicorn[standard]~=0.30.6
+pydantic~=2.8.2
+httpx~=0.27.0
+python-multipart~=0.0.9
+
 alembic~=1.7.5
 asgi-ratelimit[redis]==0.7.0
 bcrypt~=3.2.0
 click~=8.0.3
 contextvars~=2.4
-fastapi-jwt-auth~=0.5.0
-fastapi[all]~=0.73.0
 filelock~=3.4.2
 gunicorn~=20.1.0
-Jinja2~=3.0.3
 jsonref~=0.2
+PyJWT~=2.9.0
 MarkupSafe~=2.0.1
 numpy~=1.22.1
 pandas~=1.4.0
@@ -20,18 +37,13 @@ paramiko~=2.12.0
 plyer~=2.0.0
 psycopg2-binary==2.9.4
 py7zr~=0.20.6
-pydantic~=1.9.0
 PyQt5~=5.15.6
 python-json-logger~=2.0.7
-python-multipart~=0.0.5
 PyYAML~=5.4.1; python_version <= '3.9'
 PyYAML~=5.3.1; python_version > '3.9'
 redis~=4.1.2
-requests~=2.27.1
 SQLAlchemy~=1.4.46
-starlette~=0.17.1
 tables==3.6.1; python_version <= '3.8'
 tables==3.9.2; python_version > '3.8'
-typing_extensions~=4.7.1
-uvicorn[standard]~=0.15.0
-xlsxwriter~=3.2.0
+typing_extensions~=4.12.2
+xlsxwriter~=3.2.0
\ No newline at end of file
diff --git a/resources/application.yaml b/resources/application.yaml
index 6fbdb31f9f..c962d09015 100644
--- a/resources/application.yaml
+++ b/resources/application.yaml
@@ -26,7 +26,19 @@ launcher:
     700: path/to/700
   enable_nb_cores_detection: true
 
-root_path: "api"
+# See https://fastapi.tiangolo.com/advanced/behind-a-proxy/
+# root path is used when the API is served behind a proxy which
+# adds a prefix for clients.
+# It does NOT add any prefix to the URLs which FastAPI serves.
+
+# root_path: "api"
+
+
+# Uncomment to serve the API under /api prefix
+# (used in standalone mode to emulate the effect of proxy servers
+# used in production deployments).
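To make the `root_path` / `api_prefix` distinction concrete, a hedged FastAPI sketch (the endpoint is hypothetical; the patch itself wires the prefix through `AppBuildContext.api_root`):

    from fastapi import APIRouter, FastAPI

    # root_path: the app assumes a reverse proxy strips "/api" before
    # forwarding requests; the URLs the app itself serves do not change.
    proxied_app = FastAPI(root_path="/api")

    # api_prefix: the app itself serves everything under "/api", which is
    # what the standalone (desktop) mode uses to emulate such a proxy.
    standalone_app = FastAPI()
    api_root = APIRouter(prefix="/api")

    @api_root.get("/health")  # hypothetical endpoint
    def health() -> dict:
        return {"status": "ok"}

    standalone_app.include_router(api_root)  # now served at /api/health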
+
+# api_prefix: "/api"
 
 server:
   worker_threadpool_size: 12
@@ -36,4 +48,7 @@ server:
 
 logging:
   level: INFO
-  logfile: ./tmp/antarest.log
\ No newline at end of file
+  logfile: ./tmp/antarest.log
+
+# Set to true to get SQLAlchemy logs
+debug: False
diff --git a/resources/deploy/config.prod.yaml b/resources/deploy/config.prod.yaml
index cf9087a2af..9d8d78d56e 100644
--- a/resources/deploy/config.prod.yaml
+++ b/resources/deploy/config.prod.yaml
@@ -66,6 +66,10 @@ launcher:
 
 debug: false
 
+# See https://fastapi.tiangolo.com/advanced/behind-a-proxy/
+# root path is used when the API is served behind a proxy which
+# adds a prefix for clients.
+# It does NOT add any prefix to the URLs which FastAPI serves.
 root_path: "api"
 
 #tasks:
diff --git a/resources/deploy/config.yaml b/resources/deploy/config.yaml
index 6f3fdf595f..08831198ce 100644
--- a/resources/deploy/config.yaml
+++ b/resources/deploy/config.yaml
@@ -46,7 +46,8 @@ launcher:
 
 debug: false
 
-root_path: "api"
+# Serve the API at /api
+api_prefix: "/api"
 
 server:
   worker_threadpool_size: 12
diff --git a/resources/templates/.placeholder b/resources/templates/.placeholder
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/scripts/build-front.sh b/scripts/build-front.sh
index 87c2db6a50..15c2e0e60f 100755
--- a/scripts/build-front.sh
+++ b/scripts/build-front.sh
@@ -11,4 +11,3 @@ npm run build -- --mode=desktop
 cd ..
 rm -fr resources/webapp
 cp -r ./webapp/dist/ resources/webapp
-cp ./webapp/dist/index.html resources/templates/
diff --git a/sonar-project.properties b/sonar-project.properties
index 07d21a096d..a42bdb7b79 100644
--- a/sonar-project.properties
+++ b/sonar-project.properties
@@ -7,4 +7,4 @@ sonar.python.coverage.reportPaths=coverage.xml
 sonar.python.version=3.8
 sonar.javascript.lcov.reportPaths=webapp/coverage/lcov.info
 sonar.projectVersion=2.17.5
-sonar.coverage.exclusions=antarest/gui.py,antarest/main.py,antarest/singleton_services.py,antarest/worker/archive_worker_service.py,webapp/**/*
\ No newline at end of file
+sonar.coverage.exclusions=antarest/gui.py,antarest/main.py,antarest/singleton_services.py,antarest/worker/archive_worker_service.py,webapp/**/*,antarest/fastapi_jwt_auth/**
\ No newline at end of file
diff --git a/tests/cache/test_local_cache.py b/tests/cache/test_local_cache.py
index feb066eef4..8f8643d2f4 100644
--- a/tests/cache/test_local_cache.py
+++ b/tests/cache/test_local_cache.py
@@ -41,11 +41,11 @@ def test_lifecycle():
     id = "some_id"
     duration = 3600
     timeout = int(time.time()) + duration
-    cache_element = LocalCacheElement(duration=duration, data=config.dict(), timeout=timeout)
+    cache_element = LocalCacheElement(duration=duration, data=config.model_dump(), timeout=timeout)
 
     # PUT
-    cache.put(id=id, data=config.dict(), duration=duration)
+    cache.put(id=id, data=config.model_dump(), duration=duration)
     assert cache.cache[id] == cache_element
 
     # GET
-    assert cache.get(id=id) == config.dict()
+    assert cache.get(id=id) == config.model_dump()
diff --git a/tests/cache/test_redis_cache.py b/tests/cache/test_redis_cache.py
index c5e4e0b547..211de6c689 100644
--- a/tests/cache/test_redis_cache.py
+++ b/tests/cache/test_redis_cache.py
@@ -40,7 +40,7 @@ def test_lifecycle():
     id = "some_id"
     redis_key = f"cache:{id}"
     duration = 3600
-    cache_element = RedisCacheElement(duration=duration, data=config.dict()).json()
+    cache_element = RedisCacheElement(duration=duration, data=config.model_dump()).model_dump_json()
 
     # GET
     redis_client.get.return_value = cache_element
@@ -51,7 +51,7 @@ def test_lifecycle():
 
     # PUT
     duration = 
7200 - cache_element = RedisCacheElement(duration=duration, data=config.dict()).json() - cache.put(id=id, data=config.dict(), duration=duration) + cache_element = RedisCacheElement(duration=duration, data=config.model_dump()).model_dump_json() + cache.put(id=id, data=config.model_dump(), duration=duration) redis_client.set.assert_called_once_with(redis_key, cache_element) redis_client.expire.assert_called_with(redis_key, duration) diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index 958a3946f6..89bc2ddfc2 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -119,7 +119,7 @@ def test_service(core_config: Config, event_bus: IEventBus) -> None: "status": TaskStatus.FAILED, "type": None, } - assert res.dict() == expected + assert res.model_dump() == expected # Test Case: add a task that fails and wait for it # ================================================ diff --git a/tests/eventbus/test_redis_event_bus.py b/tests/eventbus/test_redis_event_bus.py index ac130f1d81..37f69e1fad 100644 --- a/tests/eventbus/test_redis_event_bus.py +++ b/tests/eventbus/test_redis_event_bus.py @@ -29,7 +29,7 @@ def test_lifecycle(): payload="foo", permissions=PermissionInfo(public_mode=PublicMode.READ), ) - serialized = event.json() + serialized = event.model_dump_json() pubsub_mock.get_message.return_value = {"data": serialized} eventbus.push_event(event) redis_client.publish.assert_called_once_with("events", serialized) diff --git a/tests/eventbus/test_websocket_manager.py b/tests/eventbus/test_websocket_manager.py index defff4e774..94500de322 100644 --- a/tests/eventbus/test_websocket_manager.py +++ b/tests/eventbus/test_websocket_manager.py @@ -37,7 +37,7 @@ async def test_subscriptions(self): await ws_manager.connect(mock_connection, user) assert len(ws_manager.active_connections) == 1 - ws_manager.process_message(subscribe_message.json(), mock_connection) + ws_manager.process_message(subscribe_message.model_dump_json(), mock_connection) connections = ws_manager.active_connections[0] assert len(connections.channel_subscriptions) == 1 assert connections.channel_subscriptions[0] == "foo" @@ -51,7 +51,7 @@ async def test_subscriptions(self): mock_connection.send_text.assert_has_calls([call("msg1"), call("msg2")]) - ws_manager.process_message(unsubscribe_message.json(), mock_connection) + ws_manager.process_message(unsubscribe_message.model_dump_json(), mock_connection) assert len(connections.channel_subscriptions) == 0 ws_manager.disconnect(mock_connection) diff --git a/tests/integration/filesystem_blueprint/test_model.py b/tests/integration/filesystem_blueprint/test_model.py index 48742b499b..d26bdb02cb 100644 --- a/tests/integration/filesystem_blueprint/test_model.py +++ b/tests/integration/filesystem_blueprint/test_model.py @@ -30,7 +30,7 @@ def test_init(self) -> None: "common": "/path/to/workspaces/common_studies", }, } - dto = FilesystemDTO.parse_obj(example) + dto = FilesystemDTO.model_validate(example) assert dto.name == example["name"] assert dto.mount_dirs["default"] == Path(example["mount_dirs"]["default"]) assert dto.mount_dirs["common"] == Path(example["mount_dirs"]["common"]) @@ -46,7 +46,7 @@ def test_init(self) -> None: "free_bytes": 1e9 - 0.6e9, "message": f"{0.6e9 / 1e9:%} used", } - dto = MountPointDTO.parse_obj(example) + dto = MountPointDTO.model_validate(example) assert dto.name == example["name"] assert dto.path == Path(example["path"]) assert dto.total_bytes == example["total_bytes"] @@ -90,7 +90,7 @@ def test_init(self) -> None: "accessed": 
"2024-01-11T17:54:09", "message": "OK", } - dto = FileInfoDTO.parse_obj(example) + dto = FileInfoDTO.model_validate(example) assert dto.path == Path(example["path"]) assert dto.file_type == example["file_type"] assert dto.file_count == example["file_count"] diff --git a/tests/integration/launcher_blueprint/test_launcher_local.py b/tests/integration/launcher_blueprint/test_launcher_local.py index c29b1ac7e6..67fc421324 100644 --- a/tests/integration/launcher_blueprint/test_launcher_local.py +++ b/tests/integration/launcher_blueprint/test_launcher_local.py @@ -76,10 +76,8 @@ def test_get_launcher_nb_cores( ) assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY, res.json() actual = res.json() - assert actual == { - "description": "Unknown solver configuration: 'unknown'", - "exception": "UnknownSolverConfig", - } + assert actual["description"] == "Input should be 'slurm', 'local' or 'default'" + assert actual["exception"] == "RequestValidationError" def test_get_launcher_time_limit( self, @@ -130,10 +128,8 @@ def test_get_launcher_time_limit( ) assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY, res.json() actual = res.json() - assert actual == { - "description": "Unknown solver configuration: 'unknown'", - "exception": "UnknownSolverConfig", - } + assert actual["description"] == "Input should be 'slurm', 'local' or 'default'" + assert actual["exception"] == "RequestValidationError" def test_jobs_permissions( self, diff --git a/tests/integration/raw_studies_blueprint/test_aggregate_raw_data.py b/tests/integration/raw_studies_blueprint/test_aggregate_raw_data.py index c64e68d7bd..9d9f47e120 100644 --- a/tests/integration/raw_studies_blueprint/test_aggregate_raw_data.py +++ b/tests/integration/raw_studies_blueprint/test_aggregate_raw_data.py @@ -19,14 +19,7 @@ import pytest from starlette.testclient import TestClient -from antarest.study.business.aggregator_management import ( - MCAllAreasQueryFile, - MCAllLinksQueryFile, - MCIndAreasQueryFile, - MCIndLinksQueryFile, -) from antarest.study.storage.df_download import TableExportFormat -from antarest.study.storage.rawstudy.model.filesystem.matrix.matrix import MatrixFrequency from tests.integration.raw_studies_blueprint.assets import ASSETS_DIR # define the requests parameters for the `economy/mc-ind` outputs aggregation @@ -34,8 +27,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "", "areas_ids": "", "columns_names": "", @@ -45,8 +38,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.DETAILS, - "frequency": MatrixFrequency.HOURLY, + "query_file": "details", + "frequency": "hourly", "mc_years": "1", "areas_ids": "de,fr,it", "columns_names": "", @@ -56,8 +49,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.WEEKLY, + "query_file": "values", + "frequency": "weekly", "mc_years": "1,2", "areas_ids": "", "columns_names": "OP. COST,MRG. 
PRICE", @@ -67,8 +60,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "2", "areas_ids": "es,fr,de", "columns_names": "", @@ -78,8 +71,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.ANNUAL, + "query_file": "values", + "frequency": "annual", "mc_years": "", "areas_ids": "", "columns_names": "", @@ -89,8 +82,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "columns_names": "COSt,NODu", }, "test-06.result.tsv", @@ -98,8 +91,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.DETAILS, - "frequency": MatrixFrequency.HOURLY, + "query_file": "details", + "frequency": "hourly", "columns_names": "COSt,NODu", }, "test-07.result.tsv", @@ -110,8 +103,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "", "columns_names": "", }, @@ -120,8 +113,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "1", "columns_names": "", }, @@ -130,8 +123,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "1,2", "columns_names": "UCAP LIn.,FLOw qUAD.", }, @@ -140,8 +133,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "1", "links_ids": "de - fr", }, @@ -150,8 +143,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "columns_names": "MArG. COsT,CONG. 
PRoB +", }, "test-05.result.tsv", @@ -162,8 +155,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "format": "csv", }, "test-01.result.tsv", @@ -171,8 +164,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "format": "tsv", }, "test-01.result.tsv", @@ -180,8 +173,8 @@ ( { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "format": "xlsx", }, "test-01.result.tsv", @@ -192,20 +185,20 @@ INCOHERENT_REQUESTS_BODIES__IND = [ { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "mc_years": "123456789", }, { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "columns_names": "fake_col", }, { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "links_ids": "fake_id", }, ] @@ -214,17 +207,17 @@ { "output_id": "20201014-1425eco-goodbye", "query_file": "fake_query_file", - "frequency": MatrixFrequency.HOURLY, + "frequency": "hourly", }, { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, + "query_file": "values", "frequency": "fake_frequency", }, { "output_id": "20201014-1425eco-goodbye", - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "format": "fake_format", }, ] @@ -234,8 +227,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "areas_ids": "", "columns_names": "", }, @@ -244,8 +237,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.DETAILS, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "details", + "frequency": "monthly", "areas_ids": "de,fr,it", "columns_names": "", }, @@ -254,8 +247,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "areas_ids": "", "columns_names": "OP. CoST,MRG. 
PrICE", }, @@ -264,8 +257,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "areas_ids": "es,fr,de", "columns_names": "", }, @@ -274,8 +267,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "values", + "frequency": "monthly", "areas_ids": "", "columns_names": "", }, @@ -284,8 +277,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.ID, - "frequency": MatrixFrequency.DAILY, + "query_file": "id", + "frequency": "daily", "areas_ids": "", "columns_names": "", }, @@ -294,8 +287,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "columns_names": "COsT,NoDU", }, "test-07-all.result.tsv", @@ -303,8 +296,8 @@ ( { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.DETAILS, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "details", + "frequency": "monthly", "columns_names": "COsT,NoDU", }, "test-08-all.result.tsv", @@ -315,8 +308,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "columns_names": "", }, "test-01-all.result.tsv", @@ -324,8 +317,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "values", + "frequency": "monthly", "columns_names": "", }, "test-02-all.result.tsv", @@ -333,8 +326,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "columns_names": "", }, "test-03-all.result.tsv", @@ -342,8 +335,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "values", + "frequency": "monthly", "links_ids": "de - fr", }, "test-04-all.result.tsv", @@ -351,8 +344,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.ID, - "frequency": MatrixFrequency.DAILY, + "query_file": "id", + "frequency": "daily", "links_ids": "", }, "test-05-all.result.tsv", @@ -360,8 +353,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "columns_names": "MARG. COsT,CONG. 
ProB +", }, "test-06-all.result.tsv", @@ -372,8 +365,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "format": "csv", }, "test-01-all.result.tsv", @@ -381,8 +374,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "format": "tsv", }, "test-01-all.result.tsv", @@ -390,8 +383,8 @@ ( { "output_id": "20241807-1540eco-extra-outputs", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "format": "xlsx", }, "test-01-all.result.tsv", @@ -402,19 +395,19 @@ INCOHERENT_REQUESTS_BODIES__ALL = [ { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", }, { "output_id": "20201014-1427eco", - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "columns_names": "fake_col", }, { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "values", + "frequency": "monthly", "links_ids": "fake_id", }, ] @@ -423,17 +416,17 @@ { "output_id": "20201014-1427eco", "query_file": "fake_query_file", - "frequency": MatrixFrequency.MONTHLY, + "frequency": "monthly", }, { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, + "query_file": "values", "frequency": "fake_frequency", }, { "output_id": "20201014-1427eco", - "query_file": MCAllAreasQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "format": "fake_format", }, ] @@ -574,8 +567,8 @@ def test_aggregation_with_wrong_output(self, client: TestClient, user_access_tok res = client.get( f"/v1/studies/{internal_study_id}/areas/aggregate/mc-ind/unknown_id", params={ - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", }, ) assert res.status_code == 404, res.json() @@ -586,8 +579,8 @@ def test_aggregation_with_wrong_output(self, client: TestClient, user_access_tok res = client.get( f"/v1/studies/{internal_study_id}/links/aggregate/mc-ind/unknown_id", params={ - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", }, ) assert res.status_code == 404, res.json() @@ -604,8 +597,8 @@ def test_empty_columns(self, client: TestClient, user_access_token: str, interna res = client.get( f"/v1/studies/{internal_study_id}/areas/aggregate/mc-ind/20201014-1425eco-goodbye", params={ - "query_file": MCIndAreasQueryFile.DETAILS, - "frequency": MatrixFrequency.HOURLY, + "query_file": "details", + "frequency": "hourly", "columns_names": "fake_col", }, ) @@ -617,8 +610,8 @@ def test_empty_columns(self, client: TestClient, user_access_token: str, interna res = client.get( f"/v1/studies/{internal_study_id}/links/aggregate/mc-ind/20201014-1425eco-goodbye", params={ - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", "columns_names": "fake_col", }, ) @@ -640,7 +633,7 @@ def test_non_existing_folder( shutil.rmtree(mc_ind_folder) res = client.get( 
f"/v1/studies/{internal_study_id}/areas/aggregate/mc-ind/20201014-1425eco-goodbye", - params={"query_file": MCIndAreasQueryFile.VALUES, "frequency": MatrixFrequency.HOURLY}, + params={"query_file": "values", "frequency": "hourly"}, ) assert res.status_code == 404, res.json() assert "economy/mc-ind" in res.json()["description"] @@ -782,8 +775,8 @@ def test_aggregation_with_wrong_output(self, client: TestClient, user_access_tok res = client.get( f"/v1/studies/{internal_study_id}/areas/aggregate/mc-all/unknown_id", params={ - "query_file": MCIndAreasQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", }, ) assert res.status_code == 404, res.json() @@ -794,8 +787,8 @@ def test_aggregation_with_wrong_output(self, client: TestClient, user_access_tok res = client.get( f"/v1/studies/{internal_study_id}/links/aggregate/mc-all/unknown_id", params={ - "query_file": MCIndLinksQueryFile.VALUES, - "frequency": MatrixFrequency.HOURLY, + "query_file": "values", + "frequency": "hourly", }, ) assert res.status_code == 404, res.json() @@ -813,8 +806,8 @@ def test_empty_columns(self, client: TestClient, user_access_token: str, interna res = client.get( f"/v1/studies/{internal_study_id}/areas/aggregate/mc-all/20201014-1427eco", params={ - "query_file": MCAllAreasQueryFile.DETAILS, - "frequency": MatrixFrequency.MONTHLY, + "query_file": "details", + "frequency": "monthly", "columns_names": "fake_col", }, ) @@ -826,8 +819,8 @@ def test_empty_columns(self, client: TestClient, user_access_token: str, interna res = client.get( f"/v1/studies/{internal_study_id}/links/aggregate/mc-all/20241807-1540eco-extra-outputs", params={ - "query_file": MCAllLinksQueryFile.VALUES, - "frequency": MatrixFrequency.DAILY, + "query_file": "values", + "frequency": "daily", "columns_names": "fake_col", }, ) @@ -847,7 +840,7 @@ def test_non_existing_folder( shutil.rmtree(mc_all_path) res = client.get( f"/v1/studies/{internal_study_id}/links/aggregate/mc-all/20241807-1540eco-extra-outputs", - params={"query_file": MCAllLinksQueryFile.VALUES, "frequency": MatrixFrequency.DAILY}, + params={"query_file": "values", "frequency": "daily"}, ) assert res.status_code == 404, res.json() assert "economy/mc-all" in res.json()["description"] diff --git a/tests/integration/raw_studies_blueprint/test_download_matrices.py b/tests/integration/raw_studies_blueprint/test_download_matrices.py index 0e24c4d164..2228a71388 100644 --- a/tests/integration/raw_studies_blueprint/test_download_matrices.py +++ b/tests/integration/raw_studies_blueprint/test_download_matrices.py @@ -359,7 +359,7 @@ def test_download_matrices(self, client: TestClient, user_access_token: str, int for export_format in ["tsv", "xlsx"]: res = client.get( f"/v1/studies/{study_860_id}/raw/download", - params={"path": "input/hydro/series/de/mingen", "format": {export_format}}, + params={"path": "input/hydro/series/de/mingen", "format": export_format}, headers=user_headers, ) assert res.status_code == 200 diff --git a/tests/integration/studies_blueprint/assets/test_synthesis/raw_study.synthesis.json b/tests/integration/studies_blueprint/assets/test_synthesis/raw_study.synthesis.json index 80ffa29e32..1e0f3ada52 100644 --- a/tests/integration/studies_blueprint/assets/test_synthesis/raw_study.synthesis.json +++ b/tests/integration/studies_blueprint/assets/test_synthesis/raw_study.synthesis.json @@ -38,22 +38,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - 
"pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -77,22 +62,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -116,22 +86,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -155,22 +110,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -194,22 +134,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -233,22 +158,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -272,22 +182,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -311,22 +206,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -350,22 +230,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], @@ -413,22 +278,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - 
"pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -452,22 +302,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -491,22 +326,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -530,22 +350,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -569,22 +374,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -608,22 +398,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -647,22 +422,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -686,22 +446,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -725,22 +470,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], @@ -788,22 +518,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 
0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -827,22 +542,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -866,22 +566,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -905,22 +590,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -944,22 +614,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -983,22 +638,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -1022,22 +662,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -1061,22 +686,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -1100,22 +710,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], @@ -1151,22 +746,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - 
"op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -1190,22 +770,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -1229,22 +794,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -1268,22 +818,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -1307,22 +842,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -1346,22 +866,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -1385,22 +890,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -1424,22 +914,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -1463,22 +938,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], diff --git a/tests/integration/studies_blueprint/assets/test_synthesis/variant_study.synthesis.json b/tests/integration/studies_blueprint/assets/test_synthesis/variant_study.synthesis.json index a77fa18a58..7e449747e4 100644 --- 
a/tests/integration/studies_blueprint/assets/test_synthesis/variant_study.synthesis.json +++ b/tests/integration/studies_blueprint/assets/test_synthesis/variant_study.synthesis.json @@ -38,22 +38,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -77,22 +62,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -116,22 +86,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -155,22 +110,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -194,22 +134,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -233,22 +158,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -272,22 +182,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -311,22 +206,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -350,22 +230,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 
0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], @@ -413,22 +278,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -452,22 +302,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -491,22 +326,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -530,22 +350,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -569,22 +374,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -608,22 +398,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -647,22 +422,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -686,22 +446,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -725,22 +470,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - 
"op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], @@ -788,22 +518,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -827,22 +542,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -866,22 +566,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -905,22 +590,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -944,22 +614,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -983,22 +638,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -1022,22 +662,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -1061,22 +686,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -1100,22 +710,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - 
"costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], @@ -1151,22 +746,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 10.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "02_wind_on", @@ -1190,22 +770,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 20.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "03_wind_off", @@ -1229,22 +794,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 30.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "04_res", @@ -1268,22 +818,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 40.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "05_nuclear", @@ -1307,22 +842,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 50.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "06_coal", @@ -1346,22 +866,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 60.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "07_gas", @@ -1385,22 +890,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 70.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "08_non-res", @@ -1424,22 +914,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 80.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": "SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 }, { "id": "09_hydro_pump", @@ -1463,22 +938,7 @@ "fixed-cost": 0.0, "startup-cost": 0.0, "market-bid-cost": 90.0, - "co2": 0.0, - "nh3": 0.0, - "so2": 0.0, - "nox": 0.0, - "pm2_5": 0.0, - "pm5": 0.0, - "pm10": 0.0, - "nmvoc": 0.0, - "op1": 0.0, - "op2": 0.0, - "op3": 0.0, - "op4": 0.0, - "op5": 0.0, - "costgeneration": 
"SetManually", - "efficiency": 100.0, - "variableomcost": 0.0 + "co2": 0.0 } ], "renewables": [], diff --git a/tests/integration/studies_blueprint/test_comments.py b/tests/integration/studies_blueprint/test_comments.py index 68615eb813..b83cc134b7 100644 --- a/tests/integration/studies_blueprint/test_comments.py +++ b/tests/integration/studies_blueprint/test_comments.py @@ -38,12 +38,9 @@ def test_raw_study( This test verifies that we can retrieve and modify the comments of a study. It also performs performance measurements and analyzes. """ - + client.headers = {"Authorization": f"Bearer {user_access_token}"} # Get the comments of the study and compare with the expected file - res = client.get( - f"/v1/studies/{internal_study_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/comments") assert res.status_code == 200, res.json() actual = res.json() actual_xml = ElementTree.parse(io.StringIO(actual)).getroot() @@ -52,10 +49,7 @@ def test_raw_study( # Ensure the duration is relatively short start = time.time() - res = client.get( - f"/v1/studies/{internal_study_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/comments") assert res.status_code == 200, res.json() duration = time.time() - start assert 0 <= duration <= 0.1, f"Duration is {duration} seconds" @@ -63,16 +57,12 @@ def test_raw_study( # Update the comments of the study res = client.put( f"/v1/studies/{internal_study_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, json={"comments": "Ceci est un commentaire en français."}, ) assert res.status_code == 204, res.json() # Get the comments of the study and compare with the expected file - res = client.get( - f"/v1/studies/{internal_study_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{internal_study_id}/comments") assert res.status_code == 200, res.json() assert res.json() == "Ceci est un commentaire en français." @@ -86,10 +76,10 @@ def test_variant_study( This test verifies that we can retrieve and modify the comments of a VARIANT study. It also performs performance measurements and analyzes. """ + client.headers = {"Authorization": f"Bearer {user_access_token}"} # First, we create a copy of the study, and we convert it to a managed study. 
res = client.post( f"/v1/studies/{internal_study_id}/copy", - headers={"Authorization": f"Bearer {user_access_token}"}, params={"dest": "default", "with_outputs": False, "use_task": False}, # type: ignore ) assert res.status_code == 201, res.json() @@ -97,20 +87,13 @@ def test_variant_study( assert base_study_id is not None # Then, we create a new variant of the base study - res = client.post( - f"/v1/studies/{base_study_id}/variants", - headers={"Authorization": f"Bearer {user_access_token}"}, - params={"name": "Variant XYZ"}, - ) + res = client.post(f"/v1/studies/{base_study_id}/variants", params={"name": "Variant XYZ"}) assert res.status_code == 200, res.json() # should be CREATED variant_id = res.json() assert variant_id is not None # Get the comments of the study and compare with the expected file - res = client.get( - f"/v1/studies/{variant_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/comments") assert res.status_code == 200, res.json() actual = res.json() actual_xml = ElementTree.parse(io.StringIO(actual)).getroot() @@ -119,10 +102,7 @@ def test_variant_study( # Ensure the duration is relatively short start = time.time() - res = client.get( - f"/v1/studies/{variant_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/comments") assert res.status_code == 200, res.json() duration = time.time() - start assert 0 <= duration <= 0.3, f"Duration is {duration} seconds" @@ -130,15 +110,11 @@ def test_variant_study( # Update the comments of the study res = client.put( f"/v1/studies/{variant_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, json={"comments": "Ceci est un commentaire en français."}, ) assert res.status_code == 204, res.json() # Get the comments of the study and compare with the expected file - res = client.get( - f"/v1/studies/{variant_id}/comments", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/comments") assert res.status_code == 200, res.json() assert res.json() == "Ceci est un commentaire en français." 
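The hunks above show the two mechanical rewrites applied across this test suite: the bearer token is registered once on the httpx-backed test client instead of being passed to every call, and pydantic v1 methods are replaced by their v2 equivalents. A minimal self-contained sketch of both patterns (the app, model, route, and token below are illustrative placeholders, not code from this repository):

from fastapi import FastAPI
from pydantic import BaseModel
from starlette.testclient import TestClient


class TaskInfo(BaseModel):
    id: str
    status: str


app = FastAPI()


@app.get("/v1/tasks/demo", response_model=TaskInfo)
def get_task() -> TaskInfo:
    return TaskInfo(id="demo", status="COMPLETED")


client = TestClient(app)
# Starlette's TestClient is now backed by httpx: default headers set once on the
# client are merged into every subsequent request, so the per-call `headers=`
# arguments can be dropped.
client.headers = {"Authorization": "Bearer <token>"}

res = client.get("/v1/tasks/demo")
assert res.status_code == 200

# pydantic v2 renames: `Model.parse_obj(data)` becomes `Model.model_validate(data)`,
# and `json.loads(obj.json())` becomes `obj.model_dump(mode="json")`.
task = TaskInfo.model_validate(res.json())
assert task.model_dump(mode="json") == {"id": "demo", "status": "COMPLETED"}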
diff --git a/tests/integration/studies_blueprint/test_disk_usage.py b/tests/integration/studies_blueprint/test_disk_usage.py index 9fd167e0b5..6958667863 100644 --- a/tests/integration/studies_blueprint/test_disk_usage.py +++ b/tests/integration/studies_blueprint/test_disk_usage.py @@ -81,7 +81,7 @@ def test_disk_usage_endpoint( # Wait for task completion res = client.get(f"/v1/tasks/{task_id}", headers=user_headers, params={"wait_for_completion": True}) assert res.status_code == 200 - task_result = TaskDTO.parse_obj(res.json()) + task_result = TaskDTO.model_validate(res.json()) assert task_result.status == TaskStatus.COMPLETED assert task_result.result is not None assert task_result.result.success diff --git a/tests/integration/studies_blueprint/test_get_studies.py b/tests/integration/studies_blueprint/test_get_studies.py index a1a29d0d5f..8e94e15d92 100644 --- a/tests/integration/studies_blueprint/test_get_studies.py +++ b/tests/integration/studies_blueprint/test_get_studies.py @@ -1483,7 +1483,7 @@ def test_get_studies__invalid_parameters( res = client.get(STUDIES_URL, headers=headers, params={"sortBy": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"not a valid enumeration member", description), f"{description=}" + assert re.search("Input should be", description), f"{description=}" # Invalid `pageNb` parameter (negative integer) res = client.get(STUDIES_URL, headers=headers, params={"pageNb": -1}) @@ -1495,7 +1495,7 @@ def test_get_studies__invalid_parameters( res = client.get(STUDIES_URL, headers=headers, params={"pageNb": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"not a valid integer", description), f"{description=}" + assert re.search(r"should be a valid integer", description), f"{description=}" # Invalid `pageSize` parameter (negative integer) res = client.get(STUDIES_URL, headers=headers, params={"pageSize": -1}) @@ -1507,43 +1507,43 @@ def test_get_studies__invalid_parameters( res = client.get(STUDIES_URL, headers=headers, params={"pageSize": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"not a valid integer", description), f"{description=}" + assert re.search(r"should be a valid integer", description), f"{description=}" # Invalid `managed` parameter (not a boolean) res = client.get(STUDIES_URL, headers=headers, params={"managed": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"could not be parsed to a boolean", description), f"{description=}" + assert re.search(r"should be a valid boolean", description), f"{description=}" # Invalid `archived` parameter (not a boolean) res = client.get(STUDIES_URL, headers=headers, params={"archived": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"could not be parsed to a boolean", description), f"{description=}" + assert re.search(r"should be a valid boolean", description), f"{description=}" # Invalid `variant` parameter (not a boolean) res = client.get(STUDIES_URL, headers=headers, params={"variant": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"could not be parsed to a boolean", description), 
f"{description=}" + assert re.search(r"should be a valid boolean", description), f"{description=}" # Invalid `versions` parameter (not a list of integers) res = client.get(STUDIES_URL, headers=headers, params={"versions": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"string does not match regex", description), f"{description=}" + assert re.search(r"String should match pattern", description), f"{description=}" # Invalid `users` parameter (not a list of integers) res = client.get(STUDIES_URL, headers=headers, params={"users": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"string does not match regex", description), f"{description=}" + assert re.search(r"String should match pattern", description), f"{description=}" # Invalid `exists` parameter (not a boolean) res = client.get(STUDIES_URL, headers=headers, params={"exists": "invalid"}) assert res.status_code == INVALID_PARAMS_STATUS_CODE, res.json() description = res.json()["description"] - assert re.search(r"could not be parsed to a boolean", description), f"{description=}" + assert re.search(r"should be a valid boolean", description), f"{description=}" def test_studies_counting(client: TestClient, admin_access_token: str, user_access_token: str) -> None: diff --git a/tests/integration/studies_blueprint/test_update_tags.py b/tests/integration/studies_blueprint/test_update_tags.py index 0c6b5f60fb..fd9005a98a 100644 --- a/tests/integration/studies_blueprint/test_update_tags.py +++ b/tests/integration/studies_blueprint/test_update_tags.py @@ -28,14 +28,10 @@ def test_update_tags( This test verifies that we can update the tags of a study. It also tests the tags normalization. """ - + client.headers = {"Authorization": f"Bearer {user_access_token}"} # Classic usage: set some tags to a study study_tags = ["Tag1", "Tag2"] - res = client.put( - f"/v1/studies/{internal_study_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"tags": study_tags}, - ) + res = client.put(f"/v1/studies/{internal_study_id}", json={"tags": study_tags}) assert res.status_code == 200, res.json() actual = res.json() assert set(actual["tags"]) == set(study_tags) @@ -44,11 +40,7 @@ def test_update_tags( # - "Tag1" is preserved, but with the same case as the existing one. # - "Tag2" is replaced by "Tag3". study_tags = ["tag1", "Tag3"] - res = client.put( - f"/v1/studies/{internal_study_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"tags": study_tags}, - ) + res = client.put(f"/v1/studies/{internal_study_id}", json={"tags": study_tags}) assert res.status_code == 200, res.json() actual = res.json() assert set(actual["tags"]) != set(study_tags) # not the same case @@ -57,22 +49,14 @@ def test_update_tags( # String normalization: whitespaces are stripped and # consecutive whitespaces are replaced by a single one. 
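# For example, the first tag below becomes "Foo Bar" and the second becomes "Baz",
# as the assertion that follows expects.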
study_tags = [" \xa0Foo \t Bar \n ", " \t Baz\xa0\xa0"] - res = client.put( - f"/v1/studies/{internal_study_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"tags": study_tags}, - ) + res = client.put(f"/v1/studies/{internal_study_id}", json={"tags": study_tags}) assert res.status_code == 200, res.json() actual = res.json() assert set(actual["tags"]) == {"Foo Bar", "Baz"} # We can have symbols in the tags study_tags = ["Foo-Bar", ":Baz%"] - res = client.put( - f"/v1/studies/{internal_study_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"tags": study_tags}, - ) + res = client.put(f"/v1/studies/{internal_study_id}", json={"tags": study_tags}) assert res.status_code == 200, res.json() actual = res.json() assert set(actual["tags"]) == {"Foo-Bar", ":Baz%"} @@ -83,13 +67,10 @@ def test_update_tags__invalid_tags( user_access_token: str, internal_study_id: str, ) -> None: + client.headers = {"Authorization": f"Bearer {user_access_token}"} # We cannot have empty tags study_tags = [""] - res = client.put( - f"/v1/studies/{internal_study_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"tags": study_tags}, - ) + res = client.put(f"/v1/studies/{internal_study_id}", json={"tags": study_tags}) assert res.status_code == 422, res.json() description = res.json()["description"] assert "Tag cannot be empty" in description @@ -97,11 +78,7 @@ def test_update_tags__invalid_tags( # We cannot have tags longer than 40 characters study_tags = ["very long tags, very long tags, very long tags"] assert len(study_tags[0]) > 40 - res = client.put( - f"/v1/studies/{internal_study_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"tags": study_tags}, - ) + res = client.put(f"/v1/studies/{internal_study_id}", json={"tags": study_tags}) assert res.status_code == 422, res.json() description = res.json()["description"] assert "Tag is too long" in description diff --git a/tests/integration/study_data_blueprint/test_advanced_parameters.py b/tests/integration/study_data_blueprint/test_advanced_parameters.py index f20b597c66..7b76ee0c87 100644 --- a/tests/integration/study_data_blueprint/test_advanced_parameters.py +++ b/tests/integration/study_data_blueprint/test_advanced_parameters.py @@ -108,7 +108,7 @@ def test_set_advanced_parameters_values( ) assert res.status_code == 422 assert res.json()["exception"] == "RequestValidationError" - assert res.json()["description"] == "Invalid value: fake_correlation" + assert res.json()["description"] == "Value error, Invalid value: fake_correlation" obj = {"unitCommitmentMode": "milp"} res = client.put( diff --git a/tests/integration/study_data_blueprint/test_binding_constraints.py b/tests/integration/study_data_blueprint/test_binding_constraints.py index e149e8a49a..42b7a9cc57 100644 --- a/tests/integration/study_data_blueprint/test_binding_constraints.py +++ b/tests/integration/study_data_blueprint/test_binding_constraints.py @@ -15,7 +15,7 @@ import numpy as np import pandas as pd import pytest -from requests.exceptions import HTTPError +from httpx._exceptions import HTTPError from starlette.testclient import TestClient from antarest.study.business.binding_constraint_management import ClusterTerm, ConstraintTerm, LinkTerm @@ -319,7 +319,7 @@ def test_lifecycle__nominal(self, client: TestClient, user_access_token: str, st assert res.status_code == 422, res.json() assert res.json() == { "body": {"data": {}, "id": f"{area1_id}.{cluster_id}"}, - "description": "field required", + 
"description": "Field required", "exception": "RequestValidationError", } diff --git a/tests/integration/study_data_blueprint/test_config_general.py b/tests/integration/study_data_blueprint/test_config_general.py index 27b69fe557..7353814e1b 100644 --- a/tests/integration/study_data_blueprint/test_config_general.py +++ b/tests/integration/study_data_blueprint/test_config_general.py @@ -45,7 +45,7 @@ def test_get_general_form_values( "firstJanuary": "Monday", "firstMonth": "january", "firstWeekDay": "Monday", - "horizon": "2030", + "horizon": 2030, "lastDay": 7, "leapYear": False, "mcScenario": True, diff --git a/tests/integration/study_data_blueprint/test_renewable.py b/tests/integration/study_data_blueprint/test_renewable.py index 49b68faf28..8e3fb8051b 100644 --- a/tests/integration/study_data_blueprint/test_renewable.py +++ b/tests/integration/study_data_blueprint/test_renewable.py @@ -36,7 +36,6 @@ * validate the consistency of the matrices (and properties) """ -import json import re import typing as t @@ -50,7 +49,7 @@ from antarest.study.storage.rawstudy.model.filesystem.config.renewable import RenewableProperties from tests.integration.utils import wait_task_completion -DEFAULT_PROPERTIES = json.loads(RenewableProperties(name="Dummy").json()) +DEFAULT_PROPERTIES = RenewableProperties(name="Dummy").model_dump(mode="json") DEFAULT_PROPERTIES = {to_camel_case(k): v for k, v in DEFAULT_PROPERTIES.items() if k != "name"} # noinspection SpellCheckingInspection @@ -537,13 +536,10 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var In this test, we want to check that renewable clusters can be managed in the context of a "variant" study. """ + client.headers = {"Authorization": f"Bearer {user_access_token}"} # Create an area area_name = "France" - res = client.post( - f"/v1/studies/{variant_id}/areas", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"name": area_name, "type": "AREA"}, - ) + res = client.post(f"/v1/studies/{variant_id}/areas", json={"name": area_name, "type": "AREA"}) assert res.status_code in {200, 201}, res.json() area_cfg = res.json() area_id = area_cfg["id"] @@ -552,7 +548,6 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var cluster_name = "Th1" res = client.post( f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable", - headers={"Authorization": f"Bearer {user_access_token}"}, json={ "name": cluster_name, "group": "Wind Offshore", @@ -565,9 +560,7 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var # Update the renewable cluster res = client.patch( - f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable/{cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"unitCount": 15}, + f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable/{cluster_id}", json={"unitCount": 15} ) assert res.status_code == 200, res.json() cluster_cfg = res.json() @@ -577,19 +570,13 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var matrix = np.random.randint(0, 2, size=(8760, 1)).tolist() matrix_path = f"input/renewables/series/{area_id}/{cluster_id.lower()}/series" args = {"target": matrix_path, "matrix": matrix} - res = client.post( - f"/v1/studies/{variant_id}/commands", - json=[{"action": "replace_matrix", "args": args}], - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.post(f"/v1/studies/{variant_id}/commands", json=[{"action": "replace_matrix", "args": args}]) 
assert res.status_code in {200, 201}, res.json() # Duplicate the renewable cluster new_name = "Th2" res = client.post( - f"/v1/studies/{variant_id}/areas/{area_id}/renewables/{cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - params={"newName": new_name}, + f"/v1/studies/{variant_id}/areas/{area_id}/renewables/{cluster_id}", params={"newName": new_name} ) assert res.status_code in {200, 201}, res.json() cluster_cfg = res.json() @@ -597,10 +584,7 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var new_id = cluster_cfg["id"] # Check that the duplicate has the right properties - res = client.get( - f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable/{new_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable/{new_id}") assert res.status_code == 200, res.json() cluster_cfg = res.json() assert cluster_cfg["group"] == "Wind Offshore" @@ -609,27 +593,19 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var # Check that the duplicate has the right matrix new_cluster_matrix_path = f"input/renewables/series/{area_id}/{new_id.lower()}/series" - res = client.get( - f"/v1/studies/{variant_id}/raw", - params={"path": new_cluster_matrix_path}, - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/raw", params={"path": new_cluster_matrix_path}) assert res.status_code == 200 assert res.json()["data"] == matrix # Delete the renewable cluster - res = client.delete( - f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable", - headers={"Authorization": f"Bearer {user_access_token}"}, - json=[cluster_id], + # usage of request instead of delete as httpx doesn't support delete with a payload anymore. 
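+ # httpx deliberately provides no body parameters on .delete() (a request body on
+ # DELETE has no defined semantics), so the generic .request() entry point is the
+ # supported way to send one.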
+ res = client.request( + method="DELETE", url=f"/v1/studies/{variant_id}/areas/{area_id}/clusters/renewable", json=[cluster_id] ) assert res.status_code == 204, res.json() # Check the list of variant commands - res = client.get( - f"/v1/studies/{variant_id}/commands", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/commands") assert res.status_code == 200, res.json() commands = res.json() assert len(commands) == 7 diff --git a/tests/integration/study_data_blueprint/test_st_storage.py b/tests/integration/study_data_blueprint/test_st_storage.py index 7f03317659..8d83d12f37 100644 --- a/tests/integration/study_data_blueprint/test_st_storage.py +++ b/tests/integration/study_data_blueprint/test_st_storage.py @@ -31,11 +31,11 @@ _ST_STORAGE_OUTPUT_860 = create_storage_output(860, cluster_id="dummy", config={"name": "dummy"}) _ST_STORAGE_OUTPUT_880 = create_storage_output(880, cluster_id="dummy", config={"name": "dummy"}) -DEFAULT_CONFIG_860 = json.loads(_ST_STORAGE_860_CONFIG.json(by_alias=True, exclude={"id", "name"})) -DEFAULT_CONFIG_880 = json.loads(_ST_STORAGE_880_CONFIG.json(by_alias=True, exclude={"id", "name"})) +DEFAULT_CONFIG_860 = _ST_STORAGE_860_CONFIG.model_dump(mode="json", by_alias=True, exclude={"id", "name"}) +DEFAULT_CONFIG_880 = _ST_STORAGE_880_CONFIG.model_dump(mode="json", by_alias=True, exclude={"id", "name"}) -DEFAULT_OUTPUT_860 = json.loads(_ST_STORAGE_OUTPUT_860.json(by_alias=True, exclude={"id", "name"})) -DEFAULT_OUTPUT_880 = json.loads(_ST_STORAGE_OUTPUT_880.json(by_alias=True, exclude={"id", "name"})) +DEFAULT_OUTPUT_860 = _ST_STORAGE_OUTPUT_860.model_dump(mode="json", by_alias=True, exclude={"id", "name"}) +DEFAULT_OUTPUT_880 = _ST_STORAGE_OUTPUT_880.model_dump(mode="json", by_alias=True, exclude={"id", "name"}) # noinspection SpellCheckingInspection @@ -325,7 +325,7 @@ def test_lifecycle__nominal( json=[siemens_battery_id], ) assert res.status_code == 204, res.json() - assert res.text in {"", "null"} # Old FastAPI versions return 'null'. + assert not res.text # If the short-term storage list is empty, the deletion should be a no-op. res = client.request( @@ -335,7 +335,7 @@ def test_lifecycle__nominal( json=[], ) assert res.status_code == 204, res.json() - assert res.text in {"", "null"} # Old FastAPI versions return 'null'. + assert not res.text # It's possible to delete multiple short-term storages at once. # In the following example, we will create two short-term storages: @@ -395,7 +395,7 @@ def test_lifecycle__nominal( json=[grand_maison_id, duplicated_output["id"]], ) assert res.status_code == 204, res.json() - assert res.text in {"", "null"} # Old FastAPI versions return 'null'. + assert not res.text # Only one st-storage should remain. res = client.get( @@ -487,7 +487,7 @@ def test_lifecycle__nominal( assert res.status_code == 422, res.json() obj = res.json() description = obj["description"] - assert re.search(r"not a valid enumeration member", description, flags=re.IGNORECASE) + assert re.search(r"Input should be", description) # Check PATCH with the wrong `area_id` res = client.patch( @@ -589,40 +589,26 @@ def test__default_values( Then the short-term storage is created with initialLevel = 0.0, and initialLevelOptim = False. 
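Note: since the pydantic v2 migration, the generated `update_config` commands carry native JSON numbers (e.g. 0.5 or 1600.0) instead of strings, as the assertions below expect.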
""" # Create a new study in version 860 (or higher) - user_headers = {"Authorization": f"Bearer {user_access_token}"} - res = client.post( - "/v1/studies", - headers=user_headers, - params={"name": "MyStudy", "version": study_version}, - ) + client.headers = {"Authorization": f"Bearer {user_access_token}"} + res = client.post("/v1/studies", params={"name": "MyStudy", "version": study_version}) assert res.status_code in {200, 201}, res.json() study_id = res.json() if study_type == "variant": # Create Variant - res = client.post( - f"/v1/studies/{study_id}/variants", - headers=user_headers, - params={"name": "Variant 1"}, - ) + res = client.post(f"/v1/studies/{study_id}/variants", params={"name": "Variant 1"}) assert res.status_code in {200, 201}, res.json() study_id = res.json() # Create a new area named "FR" - res = client.post( - f"/v1/studies/{study_id}/areas", - headers=user_headers, - json={"name": "FR", "type": "AREA"}, - ) + res = client.post(f"/v1/studies/{study_id}/areas", json={"name": "FR", "type": "AREA"}) assert res.status_code in {200, 201}, res.json() area_id = res.json()["id"] # Create a new short-term storage named "Tesla Battery" tesla_battery = "Tesla Battery" res = client.post( - f"/v1/studies/{study_id}/areas/{area_id}/storages", - headers=user_headers, - json={"name": tesla_battery, "group": "Battery"}, + f"/v1/studies/{study_id}/areas/{area_id}/storages", json={"name": tesla_battery, "group": "Battery"} ) assert res.status_code == 200, res.json() tesla_battery_id = res.json()["id"] @@ -633,7 +619,6 @@ def test__default_values( # are properly set in the configuration file. res = client.get( f"/v1/studies/{study_id}/raw", - headers=user_headers, params={"path": f"input/st-storage/clusters/{area_id}/list/{tesla_battery_id}"}, ) assert res.status_code == 200, res.json() @@ -646,28 +631,19 @@ def test__default_values( # in the variant commands. 
# Create a variant of the study - res = client.post( - f"/v1/studies/{study_id}/variants", - headers=user_headers, - params={"name": "MyVariant"}, - ) + res = client.post(f"/v1/studies/{study_id}/variants", params={"name": "MyVariant"}) assert res.status_code in {200, 201}, res.json() variant_id = res.json() # Create a new short-term storage named "Siemens Battery" siemens_battery = "Siemens Battery" res = client.post( - f"/v1/studies/{variant_id}/areas/{area_id}/storages", - headers=user_headers, - json={"name": siemens_battery, "group": "Battery"}, + f"/v1/studies/{variant_id}/areas/{area_id}/storages", json={"name": siemens_battery, "group": "Battery"} ) assert res.status_code == 200, res.json() # Check the variant commands - res = client.get( - f"/v1/studies/{variant_id}/commands", - headers=user_headers, - ) + res = client.get(f"/v1/studies/{variant_id}/commands") assert res.status_code == 200, res.json() commands = res.json() assert len(commands) == 1 @@ -691,17 +667,12 @@ def test__default_values( # Update the initialLevel property of the "Siemens Battery" short-term storage to 0.5 siemens_battery_id = transform_name_to_id(siemens_battery) res = client.patch( - f"/v1/studies/{variant_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers=user_headers, - json={"initialLevel": 0.5}, + f"/v1/studies/{variant_id}/areas/{area_id}/storages/{siemens_battery_id}", json={"initialLevel": 0.5} ) assert res.status_code == 200, res.json() # Check the variant commands - res = client.get( - f"/v1/studies/{variant_id}/commands", - headers=user_headers, - ) + res = client.get(f"/v1/studies/{variant_id}/commands") assert res.status_code == 200, res.json() commands = res.json() assert len(commands) == 2 @@ -710,7 +681,7 @@ def test__default_values( "id": ANY, "action": "update_config", "args": { - "data": "0.5", + "data": 0.5, "target": "input/st-storage/clusters/fr/list/siemens battery/initiallevel", }, "version": 1, @@ -720,16 +691,12 @@ def test__default_values( # Update the initialLevel property of the "Siemens Battery" short-term storage back to 0 res = client.patch( f"/v1/studies/{variant_id}/areas/{area_id}/storages/{siemens_battery_id}", - headers=user_headers, json={"initialLevel": 0.0, "injectionNominalCapacity": 1600}, ) assert res.status_code == 200, res.json() # Check the variant commands - res = client.get( - f"/v1/studies/{variant_id}/commands", - headers=user_headers, - ) + res = client.get(f"/v1/studies/{variant_id}/commands") assert res.status_code == 200, res.json() commands = res.json() assert len(commands) == 3 @@ -739,11 +706,11 @@ def test__default_values( "action": "update_config", "args": [ { - "data": "1600.0", + "data": 1600.0, "target": "input/st-storage/clusters/fr/list/siemens battery/injectionnominalcapacity", }, { - "data": "0.0", + "data": 0.0, "target": "input/st-storage/clusters/fr/list/siemens battery/initiallevel", }, ], @@ -755,7 +722,6 @@ def test__default_values( # are properly set in the configuration file. res = client.get( f"/v1/studies/{variant_id}/raw", - headers=user_headers, params={"path": f"input/st-storage/clusters/{area_id}/list/{siemens_battery_id}"}, ) assert res.status_code == 200, res.json() @@ -803,13 +769,10 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var In this test, we want to check that short-term storages can be managed in the context of a "variant" study. 
""" + client.headers = {"Authorization": f"Bearer {user_access_token}"} # Create an area area_name = "France" - res = client.post( - f"/v1/studies/{variant_id}/areas", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"name": area_name, "type": "AREA"}, - ) + res = client.post(f"/v1/studies/{variant_id}/areas", json={"name": area_name, "type": "AREA"}) assert res.status_code in {200, 201}, res.json() area_cfg = res.json() area_id = area_cfg["id"] @@ -818,7 +781,6 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var cluster_name = "Tesla1" res = client.post( f"/v1/studies/{variant_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer {user_access_token}"}, json={ "name": cluster_name, "group": "Battery", @@ -832,9 +794,7 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var # Update the short-term storage res = client.patch( - f"/v1/studies/{variant_id}/areas/{area_id}/storages/{cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"reservoirCapacity": 5600}, + f"/v1/studies/{variant_id}/areas/{area_id}/storages/{cluster_id}", json={"reservoirCapacity": 5600} ) assert res.status_code == 200, res.json() cluster_cfg = res.json() @@ -844,19 +804,13 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var matrix = np.random.randint(0, 2, size=(8760, 1)).tolist() matrix_path = f"input/st-storage/series/{area_id}/{cluster_id.lower()}/pmax_injection" args = {"target": matrix_path, "matrix": matrix} - res = client.post( - f"/v1/studies/{variant_id}/commands", - json=[{"action": "replace_matrix", "args": args}], - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.post(f"/v1/studies/{variant_id}/commands", json=[{"action": "replace_matrix", "args": args}]) assert res.status_code in {200, 201}, res.json() # Duplicate the short-term storage new_name = "Tesla2" res = client.post( - f"/v1/studies/{variant_id}/areas/{area_id}/storages/{cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - params={"newName": new_name}, + f"/v1/studies/{variant_id}/areas/{area_id}/storages/{cluster_id}", params={"newName": new_name} ) assert res.status_code in {200, 201}, res.json() cluster_cfg = res.json() @@ -864,10 +818,7 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var new_id = cluster_cfg["id"] # Check that the duplicate has the right properties - res = client.get( - f"/v1/studies/{variant_id}/areas/{area_id}/storages/{new_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/areas/{area_id}/storages/{new_id}") assert res.status_code == 200, res.json() cluster_cfg = res.json() assert cluster_cfg["group"] == "Battery" @@ -877,27 +828,19 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var # Check that the duplicate has the right matrix new_cluster_matrix_path = f"input/st-storage/series/{area_id}/{new_id.lower()}/pmax_injection" - res = client.get( - f"/v1/studies/{variant_id}/raw", - params={"path": new_cluster_matrix_path}, - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/raw", params={"path": new_cluster_matrix_path}) assert res.status_code == 200 assert res.json()["data"] == matrix # Delete the short-term storage - res = client.delete( - f"/v1/studies/{variant_id}/areas/{area_id}/storages", - headers={"Authorization": f"Bearer 
{user_access_token}"}, - json=[cluster_id], + # usage of request instead of delete as httpx doesn't support delete with a payload anymore. + res = client.request( + method="DELETE", url=f"/v1/studies/{variant_id}/areas/{area_id}/storages", json=[cluster_id] ) assert res.status_code == 204, res.json() # Check the list of variant commands - res = client.get( - f"/v1/studies/{variant_id}/commands", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/commands") assert res.status_code == 200, res.json() commands = res.json() assert len(commands) == 7 diff --git a/tests/integration/study_data_blueprint/test_thermal.py b/tests/integration/study_data_blueprint/test_thermal.py index 17ce7d7037..6567fc205f 100644 --- a/tests/integration/study_data_blueprint/test_thermal.py +++ b/tests/integration/study_data_blueprint/test_thermal.py @@ -40,7 +40,6 @@ * validate the consistency of the matrices (and properties) """ import io -import json import re import typing as t @@ -54,7 +53,7 @@ from antarest.study.storage.rawstudy.model.filesystem.config.thermal import ThermalProperties from tests.integration.utils import wait_task_completion -DEFAULT_PROPERTIES = json.loads(ThermalProperties(name="Dummy").json()) +DEFAULT_PROPERTIES = ThermalProperties(name="Dummy").model_dump(mode="json") DEFAULT_PROPERTIES = {to_camel_case(k): v for k, v in DEFAULT_PROPERTIES.items() if k != "name"} # noinspection SpellCheckingInspection @@ -511,7 +510,6 @@ def test_lifecycle( json={"nox": 10.0}, ) assert res.status_code == 200 - assert res.json()["nox"] == 10.0 # Update with the field `efficiency`. Should succeed even with versions prior to v8.7 res = client.patch( @@ -520,7 +518,6 @@ def test_lifecycle( json={"efficiency": 97.0}, ) assert res.status_code == 200 - assert res.json()["efficiency"] == 97.0 # ============================= # THERMAL CLUSTER DUPLICATION @@ -949,13 +946,10 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var In this test, we want to check that thermal clusters can be managed in the context of a "variant" study. 
""" + client.headers = {"Authorization": f"Bearer {user_access_token}"} # Create an area area_name = "France" - res = client.post( - f"/v1/studies/{variant_id}/areas", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={"name": area_name, "type": "AREA"}, - ) + res = client.post(f"/v1/studies/{variant_id}/areas", json={"name": area_name, "type": "AREA"}) assert res.status_code in {200, 201}, res.json() area_cfg = res.json() area_id = area_cfg["id"] @@ -964,7 +958,6 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var cluster_name = "Th1" res = client.post( f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer {user_access_token}"}, json={ "name": cluster_name, "group": "Nuclear", @@ -978,11 +971,7 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var # Update the thermal cluster res = client.patch( - f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal/{cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - json={ - "marginalCost": 0.2, - }, + f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal/{cluster_id}", json={"marginalCost": 0.2} ) assert res.status_code == 200, res.json() cluster_cfg = res.json() @@ -992,19 +981,13 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var matrix = np.random.randint(0, 2, size=(8760, 1)).tolist() matrix_path = f"input/thermal/prepro/{area_id}/{cluster_id.lower()}/data" args = {"target": matrix_path, "matrix": matrix} - res = client.post( - f"/v1/studies/{variant_id}/commands", - json=[{"action": "replace_matrix", "args": args}], - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.post(f"/v1/studies/{variant_id}/commands", json=[{"action": "replace_matrix", "args": args}]) assert res.status_code in {200, 201}, res.json() # Duplicate the thermal cluster new_name = "Th2" res = client.post( - f"/v1/studies/{variant_id}/areas/{area_id}/thermals/{cluster_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - params={"newName": new_name}, + f"/v1/studies/{variant_id}/areas/{area_id}/thermals/{cluster_id}", params={"newName": new_name} ) assert res.status_code in {200, 201}, res.json() cluster_cfg = res.json() @@ -1012,10 +995,7 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var new_id = cluster_cfg["id"] # Check that the duplicate has the right properties - res = client.get( - f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal/{new_id}", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal/{new_id}") assert res.status_code == 200, res.json() cluster_cfg = res.json() assert cluster_cfg["group"] == "Nuclear" @@ -1025,27 +1005,19 @@ def test_variant_lifecycle(self, client: TestClient, user_access_token: str, var # Check that the duplicate has the right matrix new_cluster_matrix_path = f"input/thermal/prepro/{area_id}/{new_id.lower()}/data" - res = client.get( - f"/v1/studies/{variant_id}/raw", - params={"path": new_cluster_matrix_path}, - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/raw", params={"path": new_cluster_matrix_path}) assert res.status_code == 200 assert res.json()["data"] == matrix # Delete the thermal cluster - res = client.delete( - f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal", - headers={"Authorization": f"Bearer 
{user_access_token}"}, - json=[cluster_id], + # usage of request instead of delete as httpx doesn't support delete with a payload anymore. + res = client.request( + method="DELETE", url=f"/v1/studies/{variant_id}/areas/{area_id}/clusters/thermal", json=[cluster_id] ) assert res.status_code == 204, res.json() # Check the list of variant commands - res = client.get( - f"/v1/studies/{variant_id}/commands", - headers={"Authorization": f"Bearer {user_access_token}"}, - ) + res = client.get(f"/v1/studies/{variant_id}/commands") assert res.status_code == 200, res.json() commands = res.json() assert len(commands) == 7 @@ -1182,9 +1154,9 @@ def test_thermal_cluster_deletion(self, client: TestClient, user_access_token: s assert res.status_code == 200, res.json() # check that deleting the thermal cluster in area_1 fails - res = client.delete( - f"/v1/studies/{internal_study_id}/areas/area_1/clusters/thermal", - json=["cluster_1"], + # usage of request instead of delete as httpx doesn't support delete with a payload anymore. + res = client.request( + method="DELETE", url=f"/v1/studies/{internal_study_id}/areas/area_1/clusters/thermal", json=["cluster_1"] ) assert res.status_code == 403, res.json() @@ -1195,16 +1167,14 @@ def test_thermal_cluster_deletion(self, client: TestClient, user_access_token: s assert res.status_code == 200, res.json() # check that deleting the thermal cluster in area_1 succeeds - res = client.delete( - f"/v1/studies/{internal_study_id}/areas/area_1/clusters/thermal", - json=["cluster_1"], + res = client.request( + method="DELETE", url=f"/v1/studies/{internal_study_id}/areas/area_1/clusters/thermal", json=["cluster_1"] ) assert res.status_code == 204, res.json() # check that deleting the thermal cluster in area_2 fails - res = client.delete( - f"/v1/studies/{internal_study_id}/areas/area_2/clusters/thermal", - json=["cluster_2"], + res = client.request( + method="DELETE", url=f"/v1/studies/{internal_study_id}/areas/area_2/clusters/thermal", json=["cluster_2"] ) assert res.status_code == 403, res.json() @@ -1215,15 +1185,13 @@ def test_thermal_cluster_deletion(self, client: TestClient, user_access_token: s assert res.status_code == 200, res.json() # check that deleting the thermal cluster in area_2 succeeds - res = client.delete( - f"/v1/studies/{internal_study_id}/areas/area_2/clusters/thermal", - json=["cluster_2"], + res = client.request( + method="DELETE", url=f"/v1/studies/{internal_study_id}/areas/area_2/clusters/thermal", json=["cluster_2"] ) assert res.status_code == 204, res.json() # check that deleting the thermal cluster in area_3 succeeds - res = client.delete( - f"/v1/studies/{internal_study_id}/areas/area_3/clusters/thermal", - json=["cluster_3"], + res = client.request( + method="DELETE", url=f"/v1/studies/{internal_study_id}/areas/area_3/clusters/thermal", json=["cluster_3"] ) assert res.status_code == 204, res.json() diff --git a/tests/integration/test_apidoc.py b/tests/integration/test_apidoc.py index 605d91333d..d0eb7445d8 100644 --- a/tests/integration/test_apidoc.py +++ b/tests/integration/test_apidoc.py @@ -10,14 +10,15 @@ # # This file is part of the Antares project. 
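# Note on the rewrite below: the private helpers previously imported here
# (fastapi.openapi.utils.get_flat_models_from_routes, fastapi.utils.get_model_definitions,
# and pydantic.schema.get_model_name_map) no longer exist in recent FastAPI /
# pydantic v2 releases. The public builder fastapi.openapi.utils.get_openapi
# walks every route and serializes all response models, which is enough to
# catch schema-generation regressions.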
-from fastapi.openapi.utils import get_flat_models_from_routes -from fastapi.utils import get_model_definitions -from pydantic.schema import get_model_name_map from starlette.testclient import TestClient +from antarest import __version__ + def test_apidoc(client: TestClient) -> None: - # Asserts that the apidoc can be loaded - flat_models = get_flat_models_from_routes(client.app.routes) - model_name_map = get_model_name_map(flat_models) - get_model_definitions(flat_models=flat_models, model_name_map=model_name_map) + # Local import to avoid breaking all tests if FastAPI changes its API + from fastapi.openapi.utils import get_openapi + + routes = client.app.routes + openapi = get_openapi(title="Antares Web", version=__version__, routes=routes) + assert openapi diff --git a/tests/integration/test_core_blueprint.py b/tests/integration/test_core_blueprint.py index de555ceaf9..a60ba6dcfb 100644 --- a/tests/integration/test_core_blueprint.py +++ b/tests/integration/test_core_blueprint.py @@ -10,7 +10,6 @@ # # This file is part of the Antares project. -import http import re from unittest import mock @@ -48,23 +47,3 @@ def test_version_info(self, app: FastAPI): "dependencies": mock.ANY, } assert actual == expected - - -class TestKillWorker: - def test_kill_worker__not_granted(self, app: FastAPI): - client = TestClient(app, raise_server_exceptions=False) - res = client.get("/kill") - assert res.status_code == http.HTTPStatus.UNAUTHORIZED, res.json() - assert res.json() == {"detail": "Missing cookie access_token_cookie"} - - def test_kill_worker__nominal_case(self, app: FastAPI): - client = TestClient(app, raise_server_exceptions=False) - # login as "admin" - res = client.post("/v1/login", json={"username": "admin", "password": "admin"}) - res.raise_for_status() - credentials = res.json() - admin_access_token = credentials["access_token"] - # kill the worker - res = client.get("/kill", headers={"Authorization": f"Bearer {admin_access_token}"}) - assert res.status_code == 500, res.json() - assert not res.content diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 8e814e322a..a136184969 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -452,7 +452,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: }, ) - client.post( + res = client.post( f"/v1/studies/{study_id}/commands", json=[ { @@ -465,6 +465,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: } ], ) + res.raise_for_status() client.post( f"/v1/studies/{study_id}/commands", @@ -606,13 +607,14 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: }, ] - client.post( + res = client.post( f"/v1/studies/{study_id}/links", json={ "area1": "area 1", "area2": "area 2", }, ) + res.raise_for_status() res_links = client.get(f"/v1/studies/{study_id}/links?with_ui=true") assert res_links.json() == [ { @@ -625,15 +627,16 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: # -- `layers` integration tests res = client.get(f"/v1/studies/{study_id}/layers") - assert res.json() == [LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).dict()] + res.raise_for_status() + assert res.json() == [LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).model_dump()] res = client.post(f"/v1/studies/{study_id}/layers?name=test") assert res.json() == "1" res = client.get(f"/v1/studies/{study_id}/layers") assert res.json() == [ - LayerInfoDTO(id="0", 
name="All", areas=["area 1", "area 2"]).dict(), - LayerInfoDTO(id="1", name="test", areas=[]).dict(), + LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).model_dump(), + LayerInfoDTO(id="1", name="test", areas=[]).model_dump(), ] res = client.put(f"/v1/studies/{study_id}/layers/1?name=test2") @@ -644,8 +647,8 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: assert res.status_code in {200, 201}, res.json() res = client.get(f"/v1/studies/{study_id}/layers") assert res.json() == [ - LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).dict(), - LayerInfoDTO(id="1", name="test2", areas=["area 2"]).dict(), + LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).model_dump(), + LayerInfoDTO(id="1", name="test2", areas=["area 2"]).model_dump(), ] # Delete the layer '1' that has 1 area @@ -655,7 +658,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: # Ensure the layer is deleted res = client.get(f"/v1/studies/{study_id}/layers") assert res.json() == [ - LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).dict(), + LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).model_dump(), ] # Create the layer again without areas @@ -669,7 +672,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: # Ensure the layer is deleted res = client.get(f"/v1/studies/{study_id}/layers") assert res.json() == [ - LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).dict(), + LayerInfoDTO(id="0", name="All", areas=["area 1", "area 2"]).model_dump(), ] # Try to delete a non-existing layer @@ -755,7 +758,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: "simplexOptimizationRange": SimplexOptimizationRange.WEEK.value, } - client.put( + res = client.put( f"/v1/studies/{study_id}/config/optimization/form", json={ "strategicReserve": False, @@ -763,6 +766,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: "simplexOptimizationRange": SimplexOptimizationRange.DAY.value, }, ) + res.raise_for_status() res_optimization_config = client.get(f"/v1/studies/{study_id}/config/optimization/form") res_optimization_config_json = res_optimization_config.json() assert res_optimization_config_json == { @@ -825,7 +829,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: ) assert res.status_code == 422 assert res.json()["exception"] == "RequestValidationError" - assert res.json()["description"] == "value is not a valid integer" + assert res.json()["description"] == "Input should be a valid integer" # General form @@ -1269,7 +1273,7 @@ def test_area_management(client: TestClient, admin_access_token: str) -> None: }, "ntc": {"stochasticTsStatus": False, "intraModal": False}, } - res_ts_config = client.put( + client.put( f"/v1/studies/{study_id}/config/timeseries/form", json={ "thermal": {"stochasticTsStatus": True}, diff --git a/tests/integration/test_integration_variantmanager_tool.py b/tests/integration/test_integration_variantmanager_tool.py index 605c2659f5..85783fe789 100644 --- a/tests/integration/test_integration_variantmanager_tool.py +++ b/tests/integration/test_integration_variantmanager_tool.py @@ -28,10 +28,12 @@ COMMAND_FILE, MATRIX_STORE_DIR, RemoteVariantGenerator, + create_http_client, extract_commands, generate_diff, generate_study, parse_commands, + set_auth_token, ) from tests.integration.assets import ASSETS_DIR @@ -62,7 +64,9 @@ def generate_study_with_server( ) assert res.status_code 
== 200, res.json() variant_id = res.json() - generator = RemoteVariantGenerator(variant_id, session=client, token=admin_credentials["access_token"]) + + set_auth_token(client, admin_credentials["access_token"]) + generator = RemoteVariantGenerator(variant_id, host="", session=client) return generator.apply_commands(commands, matrices_dir), variant_id diff --git a/tests/integration/variant_blueprint/test_st_storage.py b/tests/integration/variant_blueprint/test_st_storage.py index bc19c036f2..b4092f0acb 100644 --- a/tests/integration/variant_blueprint/test_st_storage.py +++ b/tests/integration/variant_blueprint/test_st_storage.py @@ -233,18 +233,5 @@ def test_lifecycle( ) assert res.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY description = res.json()["description"] - """ - 4 validation errors for CreateSTStorage - parameters -> group - value is not a valid enumeration member […] - parameters -> injectionnominalcapacity - ensure this value is greater than or equal to 0 (type=value_error.number.not_ge; limit_value=0) - parameters -> initialleveloptim - value could not be parsed to a boolean (type=type_error.bool) - pmax_withdrawal - Matrix values should be between 0 and 1 (type=value_error) - """ - assert "parameters -> group" in description - assert "parameters -> injectionnominalcapacity" in description - assert "parameters -> initialleveloptim" in description - assert "pmax_withdrawal" in description + assert "Matrix values should be between 0 and 1" in description + assert "1 validation error for CreateSTStorage" in description diff --git a/tests/integration/variant_blueprint/test_thermal_cluster.py b/tests/integration/variant_blueprint/test_thermal_cluster.py index 3746bf42ac..db78599cee 100644 --- a/tests/integration/variant_blueprint/test_thermal_cluster.py +++ b/tests/integration/variant_blueprint/test_thermal_cluster.py @@ -140,7 +140,7 @@ def test_cascade_update( ) assert res.status_code == http.HTTPStatus.OK, res.json() task = TaskDTO(**res.json()) - assert task.dict() == { + assert task.model_dump() == { "completion_date_utc": mock.ANY, "creation_date_utc": mock.ANY, "id": task_id, diff --git a/tests/integration/variant_blueprint/test_variant_manager.py b/tests/integration/variant_blueprint/test_variant_manager.py index 8946bb5113..323b8f8d25 100644 --- a/tests/integration/variant_blueprint/test_variant_manager.py +++ b/tests/integration/variant_blueprint/test_variant_manager.py @@ -262,7 +262,7 @@ def test_variant_manager( res = client.get(f"/v1/tasks/{res.json()}?wait_for_completion=true", headers=admin_headers) assert res.status_code == 200 - task_result = TaskDTO.parse_obj(res.json()) + task_result = TaskDTO.model_validate(res.json()) assert task_result.status == TaskStatus.COMPLETED assert task_result.result.success # type: ignore @@ -309,7 +309,7 @@ def test_comments(client: TestClient, admin_access_token: str, variant_id: str) # Wait for task completion res = client.get(f"/v1/tasks/{task_id}", headers=admin_headers, params={"wait_for_completion": True}) assert res.status_code == 200 - task_result = TaskDTO.parse_obj(res.json()) + task_result = TaskDTO.model_validate(res.json()) assert task_result.status == TaskStatus.COMPLETED assert task_result.result is not None assert task_result.result.success @@ -383,7 +383,7 @@ def test_outputs(client: TestClient, admin_access_token: str, variant_id: str, t # Wait for task completion res = client.get(f"/v1/tasks/{task_id}", headers=admin_headers, params={"wait_for_completion": True}) res.raise_for_status() - task_result = 
TaskDTO.parse_obj(res.json()) + task_result = TaskDTO.model_validate(res.json()) assert task_result.status == TaskStatus.COMPLETED assert task_result.result is not None assert task_result.result.success diff --git a/tests/integration/xpansion_studies_blueprint/test_integration_xpansion.py b/tests/integration/xpansion_studies_blueprint/test_integration_xpansion.py index 061ab3e1c7..332b532680 100644 --- a/tests/integration/xpansion_studies_blueprint/test_integration_xpansion.py +++ b/tests/integration/xpansion_studies_blueprint/test_integration_xpansion.py @@ -23,7 +23,6 @@ def _create_area( client: TestClient, - headers: t.Mapping[str, str], study_id: str, area_name: str, *, @@ -31,7 +30,6 @@ def _create_area( ) -> str: res = client.post( f"/v1/studies/{study_id}/areas", - headers=headers, json={"name": area_name, "type": "AREA", "metadata": {"country": country}}, ) assert res.status_code in {200, 201}, res.json() @@ -40,32 +38,28 @@ def _create_area( def _create_link( client: TestClient, - headers: t.Mapping[str, str], study_id: str, src_area_id: str, dst_area_id: str, ) -> None: - res = client.post( - f"/v1/studies/{study_id}/links", - headers=headers, - json={"area1": src_area_id, "area2": dst_area_id}, - ) + res = client.post(f"/v1/studies/{study_id}/links", json={"area1": src_area_id, "area2": dst_area_id}) assert res.status_code in {200, 201}, res.json() def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_token: str) -> None: headers = {"Authorization": f"Bearer {admin_access_token}"} + client.headers = headers - res = client.post("/v1/studies", headers=headers, params={"name": "foo", "version": "860"}) + res = client.post("/v1/studies", params={"name": "foo", "version": "860"}) assert res.status_code == 201, res.json() study_id = res.json() - area1_id = _create_area(client, headers, study_id, "area1", country="FR") - area2_id = _create_area(client, headers, study_id, "area2", country="DE") - area3_id = _create_area(client, headers, study_id, "area3", country="DE") - _create_link(client, headers, study_id, area1_id, area2_id) + area1_id = _create_area(client, study_id, "area1", country="FR") + area2_id = _create_area(client, study_id, "area2", country="DE") + area3_id = _create_area(client, study_id, "area3", country="DE") + _create_link(client, study_id, area1_id, area2_id) - res = client.post(f"/v1/studies/{study_id}/extensions/xpansion", headers=headers) + res = client.post(f"/v1/studies/{study_id}/extensions/xpansion") assert res.status_code in {200, 201}, res.json() expansion_path = tmp_path / "internal_workspace" / study_id / "user" / "expansion" @@ -73,9 +67,9 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t # Create a client for Xpansion with the xpansion URL xpansion_base_url = f"/v1/studies/{study_id}/extensions/xpansion/" - xp_client = TestClient(client.app, base_url=urljoin(client.base_url, xpansion_base_url)) - - res = xp_client.get("settings", headers=headers) + xp_client = TestClient(client.app, base_url=urljoin(str(client.base_url), xpansion_base_url)) + xp_client.headers = headers + res = xp_client.get("settings") assert res.status_code == 200 assert res.json() == { "master": "integer", @@ -94,7 +88,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "sensitivity_config": {"epsilon": 0.0, "projection": [], "capex": False}, } - res = xp_client.put("settings", headers=headers, json={"optimality_gap": 42}) + res = xp_client.put("settings", json={"optimality_gap": 42}) assert 
res.status_code == 200 assert res.json() == { "master": "integer", @@ -113,13 +107,13 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "sensitivity_config": {"epsilon": 0.0, "projection": [], "capex": False}, } - res = xp_client.put("settings", headers=headers, json={"additional-constraints": "missing.txt"}) + res = xp_client.put("settings", json={"additional-constraints": "missing.txt"}) assert res.status_code == 404 err_obj = res.json() assert re.search(r"file 'missing.txt' does not exist", err_obj["description"]) assert err_obj["exception"] == "XpansionFileNotFoundError" - res = xp_client.put("settings/additional-constraints", headers=headers, params={"filename": "missing.txt"}) + res = xp_client.put("settings/additional-constraints", params={"filename": "missing.txt"}) assert res.status_code == 404 err_obj = res.json() assert re.search(r"file 'missing.txt' does not exist", err_obj["description"]) @@ -139,7 +133,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "image/jpeg", ) } - res = xp_client.post("resources/constraints", headers=headers, files=files) + res = xp_client.post("resources/constraints", files=files) assert res.status_code in {200, 201} actual_path = expansion_path / "constraints" / filename_constraints1 assert actual_path.read_text() == content_constraints1 @@ -152,7 +146,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t ), } - res = xp_client.post("resources/constraints", headers=headers, files=files) + res = xp_client.post("resources/constraints", files=files) assert res.status_code == 409 err_obj = res.json() assert re.search( @@ -169,7 +163,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "image/jpeg", ), } - res = xp_client.post("resources/constraints", headers=headers, files=files) + res = xp_client.post("resources/constraints", files=files) assert res.status_code in {200, 201} files = { @@ -179,14 +173,14 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "image/jpeg", ), } - res = xp_client.post("resources/constraints", headers=headers, files=files) + res = xp_client.post("resources/constraints", files=files) assert res.status_code in {200, 201} - res = xp_client.get(f"resources/constraints/{filename_constraints1}", headers=headers) + res = xp_client.get(f"resources/constraints/{filename_constraints1}") assert res.status_code == 200 assert res.json() == content_constraints1 - res = xp_client.get("resources/constraints/", headers=headers) + res = xp_client.get("resources/constraints/") assert res.status_code == 200 assert res.json() == [ filename_constraints1, @@ -194,14 +188,10 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t filename_constraints3, ] - res = xp_client.put( - "settings/additional-constraints", - headers=headers, - params={"filename": filename_constraints1}, - ) + res = xp_client.put("settings/additional-constraints", params={"filename": filename_constraints1}) assert res.status_code == 200 - res = xp_client.delete(f"resources/constraints/{filename_constraints1}", headers=headers) + res = xp_client.delete(f"resources/constraints/{filename_constraints1}") assert res.status_code == 409 err_obj = res.json() assert re.search( @@ -211,10 +201,10 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t ) assert err_obj["exception"] == "FileCurrentlyUsedInSettings" - res = xp_client.put("settings/additional-constraints", 
headers=headers) + res = xp_client.put("settings/additional-constraints") assert res.status_code == 200 - res = xp_client.delete(f"resources/constraints/{filename_constraints1}", headers=headers) + res = xp_client.delete(f"resources/constraints/{filename_constraints1}") assert res.status_code == 200 candidate1 = { @@ -223,7 +213,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "annual-cost-per-mw": 1, "max-investment": 1.0, } - res = xp_client.post("candidates", headers=headers, json=candidate1) + res = xp_client.post("candidates", json=candidate1) assert res.status_code in {200, 201} candidate2 = { @@ -232,7 +222,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "annual-cost-per-mw": 1, "max-investment": 1.0, } - res = xp_client.post("candidates", headers=headers, json=candidate2) + res = xp_client.post("candidates", json=candidate2) assert res.status_code == 404 err_obj = res.json() assert re.search( @@ -248,7 +238,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "annual-cost-per-mw": 1, "max-investment": 1.0, } - res = xp_client.post("candidates", headers=headers, json=candidate3) + res = xp_client.post("candidates", json=candidate3) assert res.status_code == 404 err_obj = res.json() assert re.search( @@ -271,12 +261,12 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "txt/csv", ) } - res = xp_client.post("resources/capacities", headers=headers, files=files) + res = xp_client.post("resources/capacities", files=files) assert res.status_code in {200, 201} actual_path = expansion_path / "capa" / filename_capa1 assert actual_path.read_text() == content_capa1 - res = xp_client.post("resources/capacities", headers=headers, files=files) + res = xp_client.post("resources/capacities", files=files) assert res.status_code == 409 err_obj = res.json() assert re.search( @@ -293,7 +283,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "txt/csv", ) } - res = xp_client.post("resources/capacities", headers=headers, files=files) + res = xp_client.post("resources/capacities", files=files) assert res.status_code in {200, 201} files = { @@ -303,11 +293,11 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "txt/csv", ) } - res = xp_client.post("resources/capacities", headers=headers, files=files) + res = xp_client.post("resources/capacities", files=files) assert res.status_code in {200, 201} # get single capa - res = xp_client.get(f"resources/capacities/{filename_capa1}", headers=headers) + res = xp_client.get(f"resources/capacities/{filename_capa1}") assert res.status_code == 200 assert res.json() == { "columns": [0], @@ -315,7 +305,7 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "index": [0], } - res = xp_client.get("resources/capacities", headers=headers) + res = xp_client.get("resources/capacities") assert res.status_code == 200 assert res.json() == [filename_capa1, filename_capa2, filename_capa3] @@ -326,21 +316,21 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "max-investment": 1.0, "link-profile": filename_capa1, } - res = xp_client.post("candidates", headers=headers, json=candidate4) + res = xp_client.post("candidates", json=candidate4) assert res.status_code in {200, 201} - res = xp_client.get(f"candidates/{candidate1['name']}", headers=headers) + res = xp_client.get(f"candidates/{candidate1['name']}") assert 
res.status_code == 200 - assert res.json() == XpansionCandidateDTO.parse_obj(candidate1).dict(by_alias=True) + assert res.json() == XpansionCandidateDTO.model_validate(candidate1).model_dump(by_alias=True) - res = xp_client.get("candidates", headers=headers) + res = xp_client.get("candidates") assert res.status_code == 200 assert res.json() == [ - XpansionCandidateDTO.parse_obj(candidate1).dict(by_alias=True), - XpansionCandidateDTO.parse_obj(candidate4).dict(by_alias=True), + XpansionCandidateDTO.model_validate(candidate1).model_dump(by_alias=True), + XpansionCandidateDTO.model_validate(candidate4).model_dump(by_alias=True), ] - res = xp_client.delete(f"resources/capacities/{filename_capa1}", headers=headers) + res = xp_client.delete(f"resources/capacities/{filename_capa1}") assert res.status_code == 409 err_obj = res.json() assert re.search( @@ -355,13 +345,13 @@ def test_integration_xpansion(client: TestClient, tmp_path: Path, admin_access_t "annual-cost-per-mw": 1, "max-investment": 1.0, } - res = xp_client.put(f"candidates/{candidate4['name']}", headers=headers, json=candidate5) + res = xp_client.put(f"candidates/{candidate4['name']}", json=candidate5) assert res.status_code == 200 - res = xp_client.delete(f"resources/capacities/{filename_capa1}", headers=headers) + res = xp_client.delete(f"resources/capacities/{filename_capa1}") assert res.status_code == 200 - res = client.delete(f"/v1/studies/{study_id}/extensions/xpansion", headers=headers) + res = client.delete(f"/v1/studies/{study_id}/extensions/xpansion") assert res.status_code == 200 assert not expansion_path.exists() diff --git a/tests/launcher/test_service.py b/tests/launcher/test_service.py index ae56766b2f..0c4904a870 100644 --- a/tests/launcher/test_service.py +++ b/tests/launcher/test_service.py @@ -27,7 +27,7 @@ from antarest.core.config import ( Config, - InvalidConfigurationError, + Launcher, LauncherConfig, LocalConfig, NbCoresConfig, @@ -95,7 +95,7 @@ def test_service_run_study(self, get_current_user_mock) -> None: study_id="study_uuid", job_status=JobStatus.PENDING, launcher="local", - launcher_params=LauncherParametersDTO().json(), + launcher_params=LauncherParametersDTO().model_dump_json(), ) repository = Mock() repository.save.return_value = pending @@ -136,12 +136,12 @@ def test_service_run_study(self, get_current_user_mock) -> None: # so we need to compare them manually. 
mock_call = repository.save.mock_calls[0] actual_obj: JobResult = mock_call.args[0] - assert actual_obj.to_dto().dict() == pending.to_dto().dict() + assert actual_obj.to_dto().model_dump() == pending.to_dto().model_dump() event_bus.push.assert_called_once_with( Event( type=EventType.STUDY_JOB_STARTED, - payload=pending.to_dto().dict(), + payload=pending.to_dto().model_dump(), permissions=PermissionInfo(owner=0), ) ) @@ -492,11 +492,6 @@ def test_service_get_solver_versions( "unknown", {}, id="local-config-unknown", - marks=pytest.mark.xfail( - reason="Configuration is not available for the 'unknown' launcher", - raises=InvalidConfigurationError, - strict=True, - ), ), pytest.param( { @@ -524,11 +519,6 @@ def test_service_get_solver_versions( "unknown", {}, id="slurm-config-unknown", - marks=pytest.mark.xfail( - reason="Configuration is not available for the 'unknown' launcher", - raises=InvalidConfigurationError, - strict=True, - ), ), pytest.param( { @@ -569,10 +559,13 @@ def test_get_nb_cores( ) # Fetch the number of cores - actual = launcher_service.get_nb_cores(solver) - - # Check the result - assert actual == NbCoresConfig(**expected) + try: + actual = launcher_service.get_nb_cores(Launcher(solver)) + except ValueError as e: + assert e.args[0] == f"'{solver}' is not a valid Launcher" + else: + # Check the result + assert actual == NbCoresConfig(**expected) @pytest.mark.unit_test def test_service_kill_job(self, tmp_path: Path) -> None: @@ -896,7 +889,7 @@ def test_save_solver_stats(self, tmp_path: Path) -> None: solver_stats=expected_saved_stats, owner_id=1, ) - assert actual_obj.to_dto().dict() == expected_obj.to_dto().dict() + assert actual_obj.to_dto().model_dump() == expected_obj.to_dto().model_dump() zip_file = tmp_path / "test.zip" with ZipFile(zip_file, "w", ZIP_DEFLATED) as output_data: @@ -913,7 +906,7 @@ def test_save_solver_stats(self, tmp_path: Path) -> None: solver_stats="0\n1", owner_id=1, ) - assert actual_obj.to_dto().dict() == expected_obj.to_dto().dict() + assert actual_obj.to_dto().model_dump() == expected_obj.to_dto().model_dump() @pytest.mark.parametrize( ["running_jobs", "expected_result", "default_launcher"], @@ -996,7 +989,7 @@ def test_get_load( job_repository.get_running.return_value = running_jobs - launcher_expected_result = LauncherLoadDTO.parse_obj(expected_result) + launcher_expected_result = LauncherLoadDTO.model_validate(expected_result) actual_result = launcher_service.get_load() assert launcher_expected_result.launcher_status == actual_result.launcher_status diff --git a/tests/launcher/test_web.py b/tests/launcher/test_web.py index b59edc783b..e1104f87af 100644 --- a/tests/launcher/test_web.py +++ b/tests/launcher/test_web.py @@ -11,7 +11,7 @@ # This file is part of the Antares project. 
import http -from typing import Dict, List, Union +from typing import List, Union from unittest.mock import Mock, call from uuid import uuid4 @@ -19,6 +19,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient +from antarest.core.application import create_app_ctxt from antarest.core.config import Config, SecurityConfig from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTGroup, JWTUser from antarest.core.requests import RequestParameters @@ -35,10 +36,9 @@ def create_app(service: Mock) -> FastAPI: - app = FastAPI(title=__name__) - + build_ctxt = create_app_ctxt(FastAPI(title=__name__)) build_launcher( - app, + build_ctxt, study_service=Mock(), file_transfer_manager=Mock(), task_service=Mock(), @@ -46,7 +46,7 @@ def create_app(service: Mock) -> FastAPI: config=Config(security=SecurityConfig(disabled=True)), cache=Mock(), ) - return app + return build_ctxt.build() @pytest.mark.unit_test @@ -86,7 +86,7 @@ def test_result() -> None: res = client.get(f"/v1/launcher/jobs/{job}") assert res.status_code == 200 - assert JobResultDTO.parse_obj(res.json()) == result.to_dto() + assert JobResultDTO.model_validate(res.json()) == result.to_dto() service.get_result.assert_called_once_with(job, RequestParameters(DEFAULT_ADMIN_USER)) @@ -110,11 +110,11 @@ def test_jobs() -> None: client = TestClient(app) res = client.get(f"/v1/launcher/jobs?study={str(study_id)}") assert res.status_code == 200 - assert [JobResultDTO.parse_obj(j) for j in res.json()] == [result.to_dto()] + assert [JobResultDTO.model_validate(j) for j in res.json()] == [result.to_dto()] res = client.get("/v1/launcher/jobs") assert res.status_code == 200 - assert [JobResultDTO.parse_obj(j) for j in res.json()] == [result.to_dto()] + assert [JobResultDTO.model_validate(j) for j in res.json()] == [result.to_dto()] service.get_jobs.assert_has_calls( [ call( @@ -148,7 +148,7 @@ def test_get_solver_versions() -> None: pytest.param( "", http.HTTPStatus.UNPROCESSABLE_ENTITY, - {"detail": "Unknown solver configuration: ''"}, + "Input should be 'slurm', 'local' or 'default'", id="empty", ), pytest.param("default", http.HTTPStatus.OK, ["1", "2", "3"], id="default"), @@ -157,7 +157,7 @@ def test_get_solver_versions() -> None: pytest.param( "remote", http.HTTPStatus.UNPROCESSABLE_ENTITY, - {"detail": "Unknown solver configuration: 'remote'"}, + "Input should be 'slurm', 'local' or 'default'", id="remote", ), ], @@ -165,7 +165,7 @@ def test_get_solver_versions() -> None: def test_get_solver_versions__with_query_string( solver: str, status_code: http.HTTPStatus, - expected: Union[List[str], Dict[str, str]], + expected: Union[List[str], str], ) -> None: service = Mock() if status_code == http.HTTPStatus.OK: @@ -177,7 +177,12 @@ def test_get_solver_versions__with_query_string( client = TestClient(app) res = client.get(f"/v1/launcher/versions?solver={solver}") assert res.status_code == status_code # OK or UNPROCESSABLE_ENTITY - assert res.json() == expected + if status_code == http.HTTPStatus.OK: + assert res.json() == expected + else: + actual = res.json()["detail"][0] + assert actual["type"] == "enum" + assert actual["msg"] == expected @pytest.mark.unit_test diff --git a/tests/login/test_login_service.py b/tests/login/test_login_service.py index f2393eb5d7..0c4c9b756c 100644 --- a/tests/login/test_login_service.py +++ b/tests/login/test_login_service.py @@ -382,7 +382,7 @@ def test_get_group_info(self, login_service: LoginService) -> None: actual = login_service.get_group_info("superman", _param) assert actual is not None assert 
actual.name == "Superman" - assert [obj.dict() for obj in actual.users] == [ + assert [obj.model_dump() for obj in actual.users] == [ {"id": 2, "name": "Clark Kent", "role": RoleType.ADMIN}, {"id": 3, "name": "Lois Lane", "role": RoleType.READER}, ] @@ -462,7 +462,7 @@ def test_get_user_info(self, login_service: LoginService) -> None: clark_id = 2 actual = login_service.get_user_info(clark_id, _param) assert actual is not None - assert actual.dict() == { + assert actual.model_dump() == { "id": clark_id, "name": "Clark Kent", "roles": [ @@ -480,7 +480,7 @@ def test_get_user_info(self, login_service: LoginService) -> None: lois_id = 3 actual = login_service.get_user_info(lois_id, _param) assert actual is not None - assert actual.dict() == { + assert actual.model_dump() == { "id": lois_id, "name": "Lois Lane", "roles": [ @@ -503,7 +503,7 @@ def test_get_user_info(self, login_service: LoginService) -> None: _param = get_user_param(login_service, user_id=lois_id, group_id="superman") actual = login_service.get_user_info(lois_id, _param) assert actual is not None - assert actual.dict() == { + assert actual.model_dump() == { "id": lois_id, "name": "Lois Lane", "roles": [ @@ -524,7 +524,7 @@ def test_get_user_info(self, login_service: LoginService) -> None: _param = get_bot_param(login_service, bot_id=bot.id) actual = login_service.get_user_info(lois_id, _param) assert actual is not None - assert actual.dict() == { + assert actual.model_dump() == { "id": lois_id, "name": "Lois Lane", "roles": [ @@ -578,13 +578,13 @@ def test_get_bot_info(self, login_service: LoginService) -> None: _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") actual = login_service.get_bot_info(joh_bot.id, _param) assert actual is not None - assert actual.dict() == {"id": 6, "isAuthor": True, "name": "Maria", "roles": []} + assert actual.model_dump() == {"id": 6, "isAuthor": True, "name": "Maria", "roles": []} # Joh Fredersen can get its own bot _param = get_user_param(login_service, user_id=joh_id, group_id="superman") actual = login_service.get_bot_info(joh_bot.id, _param) assert actual is not None - assert actual.dict() == {"id": 6, "isAuthor": True, "name": "Maria", "roles": []} + assert actual.model_dump() == {"id": 6, "isAuthor": True, "name": "Maria", "roles": []} # The bot cannot get itself _param = get_bot_param(login_service, bot_id=joh_bot.id) @@ -613,13 +613,13 @@ def test_get_all_bots_by_owner(self, login_service: LoginService) -> None: _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") actual = login_service.get_all_bots_by_owner(joh_id, _param) expected = [{"id": joh_bot.id, "is_author": True, "name": "Maria", "owner": joh_id}] - assert [obj.to_dto().dict() for obj in actual] == expected + assert [obj.to_dto().model_dump() for obj in actual] == expected # Freder Fredersen can get its own bot _param = get_user_param(login_service, user_id=joh_id, group_id="superman") actual = login_service.get_all_bots_by_owner(joh_id, _param) expected = [{"id": joh_bot.id, "is_author": True, "name": "Maria", "owner": joh_id}] - assert [obj.to_dto().dict() for obj in actual] == expected + assert [obj.to_dto().model_dump() for obj in actual] == expected # The bot cannot get itself _param = get_bot_param(login_service, bot_id=joh_bot.id) @@ -730,7 +730,7 @@ def test_get_all_groups(self, login_service: LoginService) -> None: # The site admin can get all groups _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") actual = login_service.get_all_groups(_param) - 
assert [g.dict() for g in actual] == [ + assert [g.model_dump() for g in actual] == [ {"id": "admin", "name": "X-Men"}, {"id": "superman", "name": "Superman"}, {"id": "metropolis", "name": "Metropolis"}, @@ -739,19 +739,19 @@ def test_get_all_groups(self, login_service: LoginService) -> None: # The group admin can its own groups _param = get_user_param(login_service, user_id=2, group_id="superman") actual = login_service.get_all_groups(_param) - assert [g.dict() for g in actual] == [{"id": "superman", "name": "Superman"}] + assert [g.model_dump() for g in actual] == [{"id": "superman", "name": "Superman"}] # The user can get its own groups _param = get_user_param(login_service, user_id=3, group_id="superman") actual = login_service.get_all_groups(_param) - assert [g.dict() for g in actual] == [{"id": "superman", "name": "Superman"}] + assert [g.model_dump() for g in actual] == [{"id": "superman", "name": "Superman"}] @with_db_context def test_get_all_users(self, login_service: LoginService) -> None: # The site admin can get all users _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") actual = login_service.get_all_users(_param) - assert [u.dict() for u in actual] == [ + assert [u.model_dump() for u in actual] == [ {"id": 1, "name": "Professor Xavier"}, {"id": 2, "name": "Clark Kent"}, {"id": 3, "name": "Lois Lane"}, @@ -763,7 +763,7 @@ def test_get_all_users(self, login_service: LoginService) -> None: # note: I don't know why the group admin can get all users -- Laurent _param = get_user_param(login_service, user_id=2, group_id="superman") actual = login_service.get_all_users(_param) - assert [u.dict() for u in actual] == [ + assert [u.model_dump() for u in actual] == [ {"id": 1, "name": "Professor Xavier"}, {"id": 2, "name": "Clark Kent"}, {"id": 3, "name": "Lois Lane"}, @@ -774,7 +774,7 @@ def test_get_all_users(self, login_service: LoginService) -> None: # The user can get its own users _param = get_user_param(login_service, user_id=3, group_id="superman") actual = login_service.get_all_users(_param) - assert [u.dict() for u in actual] == [ + assert [u.model_dump() for u in actual] == [ {"id": 2, "name": "Clark Kent"}, {"id": 3, "name": "Lois Lane"}, ] @@ -789,7 +789,7 @@ def test_get_all_bots(self, login_service: LoginService) -> None: # The site admin can get all bots _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") actual = login_service.get_all_bots(_param) - assert [b.to_dto().dict() for b in actual] == [ + assert [b.to_dto().model_dump() for b in actual] == [ {"id": joh_bot.id, "is_author": True, "name": "Maria", "owner": joh_id}, ] @@ -808,7 +808,7 @@ def test_get_all_roles_in_group(self, login_service: LoginService) -> None: # The site admin can get all roles in a given group _param = get_user_param(login_service, user_id=ADMIN_ID, group_id="admin") actual = login_service.get_all_roles_in_group("superman", _param) - assert [b.to_dto().dict() for b in actual] == [ + assert [b.to_dto().model_dump() for b in actual] == [ { "group": {"id": "superman", "name": "Superman"}, "identity": {"id": 2, "name": "Clark Kent"}, @@ -824,7 +824,7 @@ def test_get_all_roles_in_group(self, login_service: LoginService) -> None: # The group admin can get all roles his own group _param = get_user_param(login_service, user_id=2, group_id="superman") actual = login_service.get_all_roles_in_group("superman", _param) - assert [b.to_dto().dict() for b in actual] == [ + assert [b.to_dto().model_dump() for b in actual] == [ { "group": {"id": "superman", 
"name": "Superman"}, "identity": {"id": 2, "name": "Clark Kent"}, diff --git a/tests/login/test_web.py b/tests/login/test_web.py index fe1cda2230..389dc94cf7 100644 --- a/tests/login/test_web.py +++ b/tests/login/test_web.py @@ -19,12 +19,13 @@ import pytest from fastapi import FastAPI -from fastapi_jwt_auth import AuthJWT from starlette.testclient import TestClient +from antarest.core.application import AppBuildContext, create_app_ctxt from antarest.core.config import Config, SecurityConfig from antarest.core.jwt import JWTGroup, JWTUser from antarest.core.requests import RequestParameters +from antarest.fastapi_jwt_auth import AuthJWT from antarest.login.main import build_login from antarest.login.model import ( Bot, @@ -63,15 +64,16 @@ def get_config(): authjwt_token_location=("headers", "cookies"), ) + app_ctxt = create_app_ctxt(app) build_login( - app, + app_ctxt, service=service, config=Config( resources_path=Path(), security=SecurityConfig(disabled=auth_disabled), ), ) - return app + return app_ctxt.build() class TokenType: @@ -189,7 +191,7 @@ def test_user() -> None: client = TestClient(app) res = client.get("/v1/users", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == [User(id=1, name="user").to_dto().dict()] + assert res.json() == [User(id=1, name="user").to_dto().model_dump()] @pytest.mark.unit_test @@ -201,7 +203,7 @@ def test_user_id() -> None: client = TestClient(app) res = client.get("/v1/users/1", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == User(id=1, name="user").to_dto().dict() + assert res.json() == User(id=1, name="user").to_dto().model_dump() @pytest.mark.unit_test @@ -213,7 +215,7 @@ def test_user_id_with_details() -> None: client = TestClient(app) res = client.get("/v1/users/1?details=true", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == IdentityDTO(id=1, name="user", roles=[]).dict() + assert res.json() == IdentityDTO(id=1, name="user", roles=[]).model_dump() @pytest.mark.unit_test @@ -228,12 +230,12 @@ def test_user_create() -> None: res = client.post( "/v1/users", headers=create_auth_token(app), - json=user.dict(), + json=user.model_dump(), ) assert res.status_code == 200 service.create_user.assert_called_once_with(user, PARAMS) - assert res.json() == user_id.to_dto().dict() + assert res.json() == user_id.to_dto().model_dump() @pytest.mark.unit_test @@ -244,7 +246,7 @@ def test_user_save() -> None: app = create_app(service) client = TestClient(app) - user_obj = user.to_dto().dict() + user_obj = user.to_dto().model_dump() res = client.put( "/v1/users/0", headers=create_auth_token(app), @@ -256,7 +258,7 @@ def test_user_save() -> None: assert service.save_user.call_count == 1 call = service.save_user.call_args_list[0] - assert call[0][0].to_dto().dict() == user_obj + assert call[0][0].to_dto().model_dump() == user_obj assert call[0][1] == PARAMS @@ -281,7 +283,7 @@ def test_group() -> None: client = TestClient(app) res = client.get("/v1/groups", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == [Group(id="my-group", name="group").to_dto().dict()] + assert res.json() == [Group(id="my-group", name="group").to_dto().model_dump()] @pytest.mark.unit_test @@ -293,7 +295,7 @@ def test_group_id() -> None: client = TestClient(app) res = client.get("/v1/groups/1", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == Group(id="my-group", name="group").to_dto().dict() + assert res.json() == 
Group(id="my-group", name="group").to_dto().model_dump() @pytest.mark.unit_test @@ -311,7 +313,7 @@ def test_group_create() -> None: ) assert res.status_code == 200 - assert res.json() == group.to_dto().dict() + assert res.json() == group.to_dto().model_dump() @pytest.mark.unit_test @@ -341,7 +343,7 @@ def test_role() -> None: client = TestClient(app) res = client.get("/v1/roles/group/g", headers=create_auth_token(app)) assert res.status_code == 200 - assert [RoleDetailDTO.parse_obj(el) for el in res.json()] == [role.to_dto()] + assert [RoleDetailDTO.model_validate(el) for el in res.json()] == [role.to_dto()] @pytest.mark.unit_test @@ -363,7 +365,7 @@ def test_role_create() -> None: ) assert res.status_code == 200 - assert RoleDetailDTO.parse_obj(res.json()) == role.to_dto().dict() + assert RoleDetailDTO.model_validate(res.json()) == role.to_dto() @pytest.mark.unit_test @@ -403,10 +405,10 @@ def test_bot_create() -> None: service.save_bot.return_value = bot service.get_group.return_value = Group(id="group", name="group") - print(create.json()) + create.model_dump_json() app = create_app(service) client = TestClient(app) - res = client.post("/v1/bots", headers=create_auth_token(app), json=create.dict()) + res = client.post("/v1/bots", headers=create_auth_token(app), json=create.model_dump()) assert res.status_code == 200 assert len(res.json().split(".")) == 3 @@ -422,7 +424,7 @@ def test_bot() -> None: client = TestClient(app) res = client.get("/v1/bots/0", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == bot.to_dto().dict() + assert res.json() == bot.to_dto().model_dump() @pytest.mark.unit_test @@ -436,11 +438,11 @@ def test_all_bots() -> None: client = TestClient(app) res = client.get("/v1/bots", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == [b.to_dto().dict() for b in bots] + assert res.json() == [b.to_dto().model_dump() for b in bots] res = client.get("/v1/bots?owner=4", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == [b.to_dto().dict() for b in bots] + assert res.json() == [b.to_dto().model_dump() for b in bots] service.get_all_bots.assert_called_once() service.get_all_bots_by_owner.assert_called_once() diff --git a/tests/matrixstore/test_matrix_editor.py b/tests/matrixstore/test_matrix_editor.py index ad46b316fa..229e9e53f3 100644 --- a/tests/matrixstore/test_matrix_editor.py +++ b/tests/matrixstore/test_matrix_editor.py @@ -82,7 +82,7 @@ class TestMatrixSlice: ) def test_init(self, kwargs: Dict[str, Any], expected: Dict[str, Any]) -> None: obj = MatrixSlice(**kwargs) - assert obj.dict(by_alias=False) == expected + assert obj.model_dump(by_alias=False) == expected class TestOperation: @@ -109,12 +109,12 @@ class TestOperation: ) def test_init(self, kwargs: Dict[str, Any], expected: Dict[str, Any]) -> None: obj = Operation(**kwargs) - assert obj.dict(by_alias=False) == expected + assert obj.model_dump(by_alias=False) == expected @pytest.mark.parametrize("operation", list(OPERATIONS)) def test_init__valid_operation(self, operation: str) -> None: obj = Operation(operation=operation, value=123) - assert obj.dict(by_alias=False) == { + assert obj.model_dump(by_alias=False) == { "operation": operation, "value": 123.0, } @@ -204,4 +204,4 @@ class TestMatrixEditInstruction: ) def test_init(self, kwargs: Dict[str, Any], expected: Dict[str, Any]) -> None: obj = MatrixEditInstruction(**kwargs) - assert obj.dict(by_alias=False) == expected + assert obj.model_dump(by_alias=False) == 
expected diff --git a/tests/matrixstore/test_service.py b/tests/matrixstore/test_service.py index 584db35013..65ce952e1d 100644 --- a/tests/matrixstore/test_service.py +++ b/tests/matrixstore/test_service.py @@ -497,12 +497,8 @@ def test_dataset_lifecycle() -> None: dataset_repo.delete.assert_called_once() -def _create_upload_file(filename: str, file: io.BytesIO, content_type: str = "") -> UploadFile: - if hasattr(UploadFile, "content_type"): - # `content_type` attribute was replace by a read-ony property in starlette-v0.24. - headers = Headers(headers={"content-type": content_type}) - # noinspection PyTypeChecker,PyArgumentList - return UploadFile(filename=filename, file=file, headers=headers) - else: - # noinspection PyTypeChecker,PyArgumentList - return UploadFile(filename=filename, file=file, content_type=content_type) +def _create_upload_file(filename: str, file: t.IO = None, content_type: str = "") -> UploadFile: + # `content_type` attribute was replaced by a read-only property in starlette-v0.24. + headers = Headers(headers={"content-type": content_type}) + # noinspection PyTypeChecker,PyArgumentList + return UploadFile(filename=filename, file=file, headers=headers) diff --git a/tests/matrixstore/test_web.py b/tests/matrixstore/test_web.py index 2879b07bf4..d47fb030bc 100644 --- a/tests/matrixstore/test_web.py +++ b/tests/matrixstore/test_web.py @@ -15,10 +15,11 @@ import pytest from fastapi import FastAPI -from fastapi_jwt_auth import AuthJWT from starlette.testclient import TestClient +from antarest.core.application import create_app_ctxt from antarest.core.config import Config, SecurityConfig +from antarest.fastapi_jwt_auth import AuthJWT from antarest.main import JwtSettings from antarest.matrixstore.main import build_matrix_service from antarest.matrixstore.model import MatrixDTO, MatrixInfoDTO @@ -26,7 +27,7 @@ def create_app(service: Mock, auth_disabled=False) -> FastAPI: - app = FastAPI(title=__name__) + build_ctxt = create_app_ctxt(FastAPI(title=__name__)) @AuthJWT.load_config def get_config(): @@ -37,7 +38,7 @@ def get_config(): ) build_matrix_service( - app, + build_ctxt, user_service=Mock(), file_transfer_manager=Mock(), task_service=Mock(), @@ -47,7 +48,7 @@ def get_config(): security=SecurityConfig(disabled=auth_disabled), ), ) - return app + return build_ctxt.build() @pytest.mark.unit_test @@ -74,7 +75,7 @@ def test_create() -> None: json=matrix_data, ) assert res.status_code == 200 - assert res.json() == matrix.dict() + assert res.json() == matrix.model_dump() @pytest.mark.unit_test @@ -96,7 +97,7 @@ def test_get() -> None: client = TestClient(app) res = client.get("/v1/matrix/123", headers=create_auth_token(app)) assert res.status_code == 200 - assert res.json() == matrix.dict() + assert res.json() == matrix.model_dump() service.get.assert_called_once_with("123") @@ -126,4 +127,4 @@ def test_import() -> None: files={"file": ("Matrix.zip", bytes(5), "application/zip")}, ) assert res.status_code == 200 - assert res.json() == matrix_info + assert [MatrixInfoDTO.model_validate(res.json()[0])] == matrix_info diff --git a/tests/storage/business/test_arealink_manager.py b/tests/storage/business/test_arealink_manager.py index a5e94e983c..cb1d29c971 100644 --- a/tests/storage/business/test_arealink_manager.py +++ b/tests/storage/business/test_arealink_manager.py @@ -174,35 +174,36 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService): variant_id, [ CommandDTO( + id=None, action=CommandName.UPDATE_CONFIG.value, args=[ { "target":
"input/areas/test/ui/ui/x", - "data": "100", + "data": 100, }, { "target": "input/areas/test/ui/ui/y", - "data": "200", + "data": 200, }, { "target": "input/areas/test/ui/ui/color_r", - "data": "255", + "data": 255, }, { "target": "input/areas/test/ui/ui/color_g", - "data": "0", + "data": 0, }, { "target": "input/areas/test/ui/ui/color_b", - "data": "100", + "data": 100, }, { "target": "input/areas/test/ui/layerX/0", - "data": "100", + "data": 100, }, { "target": "input/areas/test/ui/layerY/0", - "data": "200", + "data": 200, }, { "target": "input/areas/test/ui/layerColor/0", @@ -302,7 +303,7 @@ def test_get_all_area(): area_manager.patch_service = Mock() area_manager.patch_service.get.return_value = Patch( areas={"a1": PatchArea(country="fr")}, - thermal_clusters={"a1.a": PatchCluster.parse_obj({"code-oi": "1"})}, + thermal_clusters={"a1.a": PatchCluster.model_validate({"code-oi": "1"})}, ) file_tree_mock.get.side_effect = [ { @@ -362,7 +363,7 @@ def test_get_all_area(): }, ] areas = area_manager.get_all_areas(study, AreaType.AREA) - assert expected_areas == [area.dict() for area in areas] + assert expected_areas == [area.model_dump() for area in areas] expected_clusters = [ { @@ -375,7 +376,7 @@ def test_get_all_area(): } ] clusters = area_manager.get_all_areas(study, AreaType.DISTRICT) - assert expected_clusters == [area.dict() for area in clusters] + assert expected_clusters == [area.model_dump() for area in clusters] file_tree_mock.get.side_effect = [{}, {}, {}] expected_all = [ @@ -413,14 +414,14 @@ def test_get_all_area(): }, ] all_areas = area_manager.get_all_areas(study) - assert expected_all == [area.dict() for area in all_areas] + assert expected_all == [area.model_dump() for area in all_areas] links = link_manager.get_all_links(study) assert [ {"area1": "a1", "area2": "a2", "ui": None}, {"area1": "a1", "area2": "a3", "ui": None}, {"area1": "a2", "area2": "a3", "ui": None}, - ] == [link.dict() for link in links] + ] == [link.model_dump() for link in links] def test_update_area(): @@ -463,7 +464,7 @@ def test_update_area(): new_area_info = area_manager.update_area_metadata(study, "a1", PatchArea(country="fr")) assert new_area_info.id == "a1" - assert new_area_info.metadata == {"country": "fr", "tags": []} + assert new_area_info.metadata.model_dump() == {"country": "fr", "tags": []} def test_update_clusters(): @@ -496,7 +497,7 @@ def test_update_clusters(): area_manager.patch_service = Mock() area_manager.patch_service.get.return_value = Patch( areas={"a1": PatchArea(country="fr")}, - thermal_clusters={"a1.a": PatchCluster.parse_obj({"code-oi": "1"})}, + thermal_clusters={"a1.a": PatchCluster.model_validate({"code-oi": "1"})}, ) file_tree_mock.get.side_effect = [ { diff --git a/tests/storage/business/test_config_manager.py b/tests/storage/business/test_config_manager.py index d35bcd78f2..f89a4438be 100644 --- a/tests/storage/business/test_config_manager.py +++ b/tests/storage/business/test_config_manager.py @@ -31,7 +31,7 @@ def test_thematic_trimming_config() -> None: - command_context = CommandContext.construct() + command_context = CommandContext.model_construct() command_factory_mock = Mock() command_factory_mock.command_context = command_context raw_study_service = Mock(spec=RawStudyService) @@ -66,27 +66,27 @@ def test_thematic_trimming_config() -> None: study.version = config.version = 700 actual = thematic_trimming_manager.get_field_values(study) - fields_info = get_fields_info(study.version) + fields_info = get_fields_info(int(study.version)) expected = 
ThematicTrimmingFormFields(**dict.fromkeys(fields_info, True)) assert actual == expected study.version = config.version = 800 actual = thematic_trimming_manager.get_field_values(study) - fields_info = get_fields_info(study.version) + fields_info = get_fields_info(int(study.version)) expected = ThematicTrimmingFormFields(**dict.fromkeys(fields_info, True)) expected.avl_dtg = False assert actual == expected study.version = config.version = 820 actual = thematic_trimming_manager.get_field_values(study) - fields_info = get_fields_info(study.version) + fields_info = get_fields_info(int(study.version)) expected = ThematicTrimmingFormFields(**dict.fromkeys(fields_info, True)) expected.avl_dtg = False assert actual == expected study.version = config.version = 830 actual = thematic_trimming_manager.get_field_values(study) - fields_info = get_fields_info(study.version) + fields_info = get_fields_info(int(study.version)) expected = ThematicTrimmingFormFields(**dict.fromkeys(fields_info, True)) expected.dens = False expected.profit_by_plant = False @@ -94,7 +94,7 @@ def test_thematic_trimming_config() -> None: study.version = config.version = 840 actual = thematic_trimming_manager.get_field_values(study) - fields_info = get_fields_info(study.version) + fields_info = get_fields_info(int(study.version)) expected = ThematicTrimmingFormFields(**dict.fromkeys(fields_info, False)) expected.cong_fee_alg = True assert actual == expected diff --git a/tests/storage/business/test_patch_service.py b/tests/storage/business/test_patch_service.py index 04a2b8c2dd..3abb8790f0 100644 --- a/tests/storage/business/test_patch_service.py +++ b/tests/storage/business/test_patch_service.py @@ -194,7 +194,7 @@ def test_set_output_ref(self, tmp_path: Path): additional_data=StudyAdditionalData( author="john.doe", horizon="foo-horizon", - patch=patch_outputs.json(), + patch=patch_outputs.model_dump_json(), ), archived=False, owner=None, diff --git a/tests/storage/business/test_timeseries_config_manager.py b/tests/storage/business/test_timeseries_config_manager.py index c93c3a597d..3bc7cd0ad9 100644 --- a/tests/storage/business/test_timeseries_config_manager.py +++ b/tests/storage/business/test_timeseries_config_manager.py @@ -56,7 +56,7 @@ def file_study_720(tmpdir: Path) -> FileStudy: def test_ts_field_values(file_study_820: FileStudy, file_study_720: FileStudy): command_factory_mock = Mock() - command_factory_mock.command_context = CommandContext.construct() + command_factory_mock.command_context = CommandContext.model_construct() raw_study_service = Mock(spec=RawStudyService) diff --git a/tests/storage/business/test_xpansion_manager.py b/tests/storage/business/test_xpansion_manager.py index c1b5e7de62..5c08c72242 100644 --- a/tests/storage/business/test_xpansion_manager.py +++ b/tests/storage/business/test_xpansion_manager.py @@ -209,7 +209,7 @@ def test_get_xpansion_settings(tmp_path: Path, version: int, expected_output: JS xpansion_manager.create_xpansion_configuration(study) actual = xpansion_manager.get_xpansion_settings(study) - assert actual.dict(by_alias=True) == expected_output + assert actual.model_dump(by_alias=True) == expected_output @pytest.mark.unit_test @@ -256,7 +256,7 @@ def test_update_xpansion_settings(tmp_path: Path) -> None: "timelimit": int(1e12), "sensitivity_config": {"epsilon": 10500.0, "projection": ["foo"], "capex": False}, } - assert actual.dict(by_alias=True) == expected + assert actual.model_dump(by_alias=True) == expected @pytest.mark.unit_test @@ -266,7 +266,7 @@ def 
test_add_candidate(tmp_path: Path) -> None: actual = empty_study.tree.get(["user", "expansion", "candidates"]) assert actual == {} - new_candidate = XpansionCandidateDTO.parse_obj( + new_candidate = XpansionCandidateDTO.model_validate( { "name": "candidate_1", "link": "area1 - area2", @@ -275,7 +275,7 @@ def test_add_candidate(tmp_path: Path) -> None: } ) - new_candidate2 = XpansionCandidateDTO.parse_obj( + new_candidate2 = XpansionCandidateDTO.model_validate( { "name": "candidate_2", "link": "area1 - area2", @@ -296,13 +296,13 @@ def test_add_candidate(tmp_path: Path) -> None: xpansion_manager.add_candidate(study, new_candidate) - candidates = {"1": new_candidate.dict(by_alias=True, exclude_none=True)} + candidates = {"1": new_candidate.model_dump(by_alias=True, exclude_none=True)} actual = empty_study.tree.get(["user", "expansion", "candidates"]) assert actual == candidates xpansion_manager.add_candidate(study, new_candidate2) - candidates["2"] = new_candidate2.dict(by_alias=True, exclude_none=True) + candidates["2"] = new_candidate2.model_dump(by_alias=True, exclude_none=True) actual = empty_study.tree.get(["user", "expansion", "candidates"]) assert actual == candidates @@ -314,7 +314,7 @@ def test_get_candidate(tmp_path: Path) -> None: assert empty_study.tree.get(["user", "expansion", "candidates"]) == {} - new_candidate = XpansionCandidateDTO.parse_obj( + new_candidate = XpansionCandidateDTO.model_validate( { "name": "candidate_1", "link": "area1 - area2", @@ -323,7 +323,7 @@ def test_get_candidate(tmp_path: Path) -> None: } ) - new_candidate2 = XpansionCandidateDTO.parse_obj( + new_candidate2 = XpansionCandidateDTO.model_validate( { "name": "candidate_2", "link": "area1 - area2", @@ -347,7 +347,7 @@ def test_get_candidates(tmp_path: Path) -> None: assert empty_study.tree.get(["user", "expansion", "candidates"]) == {} - new_candidate = XpansionCandidateDTO.parse_obj( + new_candidate = XpansionCandidateDTO.model_validate( { "name": "candidate_1", "link": "area1 - area2", @@ -356,7 +356,7 @@ def test_get_candidates(tmp_path: Path) -> None: } ) - new_candidate2 = XpansionCandidateDTO.parse_obj( + new_candidate2 = XpansionCandidateDTO.model_validate( { "name": "candidate_2", "link": "area1 - area2", @@ -384,7 +384,7 @@ def test_update_candidates(tmp_path: Path) -> None: make_link_and_areas(empty_study) - new_candidate = XpansionCandidateDTO.parse_obj( + new_candidate = XpansionCandidateDTO.model_validate( { "name": "candidate_1", "link": "area1 - area2", @@ -394,7 +394,7 @@ def test_update_candidates(tmp_path: Path) -> None: ) xpansion_manager.add_candidate(study, new_candidate) - new_candidate2 = XpansionCandidateDTO.parse_obj( + new_candidate2 = XpansionCandidateDTO.model_validate( { "name": "candidate_1", "link": "area1 - area2", @@ -415,7 +415,7 @@ def test_delete_candidate(tmp_path: Path) -> None: make_link_and_areas(empty_study) - new_candidate = XpansionCandidateDTO.parse_obj( + new_candidate = XpansionCandidateDTO.model_validate( { "name": "candidate_1", "link": "area1 - area2", @@ -425,7 +425,7 @@ def test_delete_candidate(tmp_path: Path) -> None: ) xpansion_manager.add_candidate(study, new_candidate) - new_candidate2 = XpansionCandidateDTO.parse_obj( + new_candidate2 = XpansionCandidateDTO.model_validate( { "name": "candidate_2", "link": "area1 - area2", @@ -500,14 +500,14 @@ def test_add_resources(tmp_path: Path) -> None: settings = xpansion_manager.get_xpansion_settings(study) settings.yearly_weights = filename3 - update_settings = UpdateXpansionSettings(**settings.dict()) + 
update_settings = UpdateXpansionSettings(**settings.model_dump()) xpansion_manager.update_xpansion_settings(study, update_settings) with pytest.raises(FileCurrentlyUsedInSettings): xpansion_manager.delete_resource(study, XpansionResourceFileType.WEIGHTS, filename3) settings.yearly_weights = "" - update_settings = UpdateXpansionSettings(**settings.dict()) + update_settings = UpdateXpansionSettings(**settings.model_dump()) xpansion_manager.update_xpansion_settings(study, update_settings) xpansion_manager.delete_resource(study, XpansionResourceFileType.WEIGHTS, filename3) diff --git a/tests/storage/integration/conftest.py b/tests/storage/integration/conftest.py index 1136f3ca8a..dcf7e5e830 100644 --- a/tests/storage/integration/conftest.py +++ b/tests/storage/integration/conftest.py @@ -109,7 +109,7 @@ def storage_service(tmp_path: Path, project_path: Path, sta_mini_zip_path: Path) ) matrix_service = SimpleMatrixService(matrix_content_repository=matrix_content_repository) storage_service = build_study_service( - application=Mock(), + app_ctxt=Mock(), cache=LocalCache(config=config.cache), file_transfer_manager=Mock(), task_service=task_service_mock, diff --git a/tests/storage/integration/test_STA_mini.py b/tests/storage/integration/test_STA_mini.py index ca7904a075..35aaa4092d 100644 --- a/tests/storage/integration/test_STA_mini.py +++ b/tests/storage/integration/test_STA_mini.py @@ -22,6 +22,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient +from antarest.core.application import create_app_ctxt from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTGroup, JWTUser from antarest.core.model import JSON from antarest.core.requests import RequestParameters @@ -46,19 +47,23 @@ ) -def assert_url_content(storage_service: StudyService, url: str, expected_output: dict) -> None: - app = FastAPI(title=__name__) +def create_test_client(service: StudyService) -> TestClient: + build_ctxt = create_app_ctxt(FastAPI(title=__name__)) build_study_service( - app, + build_ctxt, cache=Mock(), user_service=Mock(), task_service=Mock(), file_transfer_manager=Mock(), - study_service=storage_service, + study_service=service, matrix_service=Mock(spec=MatrixService), - config=storage_service.storage_service.raw_study_service.config, + config=service.storage_service.raw_study_service.config, ) - client = TestClient(app) + return TestClient(build_ctxt.build()) + + +def assert_url_content(storage_service: StudyService, url: str, expected_output: dict) -> None: + client = create_test_client(storage_service) res = client.get(url) assert_study(res.json(), expected_output) @@ -493,18 +498,7 @@ def test_sta_mini_copy(storage_service) -> None: source_study_name = UUID destination_study_name = "copy-STA-mini" - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - user_service=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - matrix_service=Mock(spec=MatrixService), - config=storage_service.storage_service.raw_study_service.config, - ) - client = TestClient(app) + client = create_test_client(storage_service) result = client.post(f"/v1/studies/{source_study_name}/copy?dest={destination_study_name}&use_task=false") assert result.status_code == HTTPStatus.CREATED.value @@ -590,18 +584,7 @@ def test_sta_mini_import(tmp_path: Path, storage_service) -> None: sta_mini_zip_filepath = shutil.make_archive(tmp_path, "zip", path_study) sta_mini_zip_path = Path(sta_mini_zip_filepath) - app = FastAPI(title=__name__) - build_study_service( - app, - 
cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - config=storage_service.storage_service.raw_study_service.config, - ) - client = TestClient(app) + client = create_test_client(storage_service) study_data = io.BytesIO(sta_mini_zip_path.read_bytes()) result = client.post("/v1/studies/_import", files={"study": study_data}) @@ -620,18 +603,7 @@ def test_sta_mini_import_output(tmp_path: Path, storage_service) -> None: sta_mini_output_zip_path = Path(sta_mini_output_zip_filepath) - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - config=storage_service.storage_service.raw_study_service.config, - ) - client = TestClient(app) + client = create_test_client(storage_service) study_output_data = io.BytesIO(sta_mini_output_zip_path.read_bytes()) result = client.post( diff --git a/tests/storage/integration/test_exporter.py b/tests/storage/integration/test_exporter.py index 2da0686302..46e077a9ca 100644 --- a/tests/storage/integration/test_exporter.py +++ b/tests/storage/integration/test_exporter.py @@ -21,6 +21,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient +from antarest.core.application import create_app_ctxt from antarest.core.config import Config, SecurityConfig, StorageConfig, WorkspaceConfig from antarest.core.filetransfer.model import FileDownloadTaskDTO from antarest.core.jwt import DEFAULT_ADMIN_USER @@ -54,10 +55,10 @@ def assert_url_content(url: str, tmp_dir: Path, sta_mini_zip_path: Path) -> byte repo = Mock() repo.get.return_value = md - app = FastAPI(title=__name__) + build_ctxt = create_app_ctxt(FastAPI(title=__name__)) ftm = SimpleFileTransferManager(Config(storage=StorageConfig(tmp_dir=tmp_dir))) build_study_service( - app, + build_ctxt, cache=Mock(), user_service=Mock(), task_service=SimpleSyncTaskService(), @@ -69,7 +70,7 @@ def assert_url_content(url: str, tmp_dir: Path, sta_mini_zip_path: Path) -> byte ) # Simulate the download of data using a streamed request - client = TestClient(app) + client = TestClient(build_ctxt.build()) if client.stream is False: # `TestClient` is based on `Requests` (old way before AntaREST-v2.15) # noinspection PyArgumentList diff --git a/tests/storage/rawstudies/test_factory.py b/tests/storage/rawstudies/test_factory.py index 62f9c70519..80b2a2e19c 100644 --- a/tests/storage/rawstudies/test_factory.py +++ b/tests/storage/rawstudies/test_factory.py @@ -65,8 +65,8 @@ def test_factory_cache() -> None: cache.get.return_value = None study = factory.create_from_fs(path, study_id) assert study.config == config - cache.put.assert_called_once_with(cache_id, FileStudyTreeConfigDTO.from_build_config(config).dict()) + cache.put.assert_called_once_with(cache_id, FileStudyTreeConfigDTO.from_build_config(config).model_dump()) - cache.get.return_value = FileStudyTreeConfigDTO.from_build_config(config).dict() + cache.get.return_value = FileStudyTreeConfigDTO.from_build_config(config).model_dump() study = factory.create_from_fs(path, study_id) assert study.config == config diff --git a/tests/storage/repository/filesystem/config/test_config_files.py b/tests/storage/repository/filesystem/config/test_config_files.py index dbd62de305..a0fdf6d4f5 100644 --- a/tests/storage/repository/filesystem/config/test_config_files.py +++ 
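The factory-cache hunk above stores the config DTO as a plain dict and rebuilds it on a cache hit. A minimal sketch of that round-trip under pydantic v2, using a hypothetical DTO:

from pydantic import BaseModel

class ConfigDTO(BaseModel):
    study_id: str
    version: int

dto = ConfigDTO(study_id="abc", version=860)
cached = dto.model_dump()                    # what cache.put() receives
restored = ConfigDTO.model_validate(cached)  # rebuild on cache hit
assert restored == dto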
b/tests/storage/repository/filesystem/config/test_config_files.py @@ -415,7 +415,7 @@ def test_parse_thermal_860(study_path: Path, version, caplog) -> None: assert not caplog.text else: expected = [ThermalConfig(id="t1", name="t1")] - assert "extra fields not permitted" in caplog.text + assert "Extra inputs are not permitted" in caplog.text assert actual == expected diff --git a/tests/storage/test_model.py b/tests/storage/test_model.py index c94e437b86..f59ce30137 100644 --- a/tests/storage/test_model.py +++ b/tests/storage/test_model.py @@ -67,5 +67,5 @@ def test_file_study_tree_config_dto(): enr_modelling="aggregated", ) config_dto = FileStudyTreeConfigDTO.from_build_config(config) - assert sorted(list(config_dto.dict()) + ["cache"]) == sorted(list(config.__dict__)) + assert sorted(list(config_dto.model_dump()) + ["cache"]) == sorted(list(config.__dict__)) assert config_dto.to_build_config() == config diff --git a/tests/storage/test_service.py b/tests/storage/test_service.py index e10cff9df6..47ba9fe49c 100644 --- a/tests/storage/test_service.py +++ b/tests/storage/test_service.py @@ -597,7 +597,7 @@ def test_download_output() -> None: name="east", type=StudyDownloadType.AREA, data={ - 1: [ + "1": [ TimeSerie(name="H. VAL", unit="Euro/MWh", data=[0.5]), TimeSerie(name="some cluster", unit="Euro/MWh", data=[0.8]), ] @@ -661,7 +661,7 @@ def test_download_output() -> None: TimeSeriesData( name="east^west", type=StudyDownloadType.LINK, - data={1: [TimeSerie(name="H. VAL", unit="Euro/MWh", data=[0.5])]}, + data={"1": [TimeSerie(name="H. VAL", unit="Euro/MWh", data=[0.5])]}, ) ], warnings=[], @@ -695,7 +695,7 @@ def test_download_output() -> None: name="north", type=StudyDownloadType.DISTRICT, data={ - 1: [ + "1": [ TimeSerie(name="H. VAL", unit="Euro/MWh", data=[0.5]), TimeSerie(name="some cluster", unit="Euro/MWh", data=[0.8]), ] @@ -1391,7 +1391,7 @@ def test_unarchive_output(tmp_path: Path) -> None: src=str(tmp_path / "output" / f"{output_id}.zip"), dest=str(tmp_path / "output" / output_id), remove_src=False, - ).dict(), + ).model_dump(), name=f"Unarchive output {study_name}/{output_id} ({study_id})", ref_id=study_id, request_params=RequestParameters(user=DEFAULT_ADMIN_USER), @@ -1522,7 +1522,7 @@ def test_archive_output_locks(tmp_path: Path) -> None: src=str(tmp_path / "output" / f"{output_id}.zip"), dest=str(tmp_path / "output" / output_id), remove_src=False, - ).dict(), + ).model_dump(), name=f"Unarchive output {study_name}/{output_id} ({study_id})", ref_id=study_id, request_params=RequestParameters(user=DEFAULT_ADMIN_USER), diff --git a/tests/storage/web/test_studies_bp.py b/tests/storage/web/test_studies_bp.py index bb538fdb51..33cedcc8df 100644 --- a/tests/storage/web/test_studies_bp.py +++ b/tests/storage/web/test_studies_bp.py @@ -24,9 +24,11 @@ from markupsafe import Markup from starlette.testclient import TestClient +from antarest.core.application import create_app_ctxt from antarest.core.config import Config, SecurityConfig, StorageConfig, WorkspaceConfig from antarest.core.exceptions import UrlNotMatchJsonDataError from antarest.core.filetransfer.model import FileDownloadDTO, FileDownloadTaskDTO +from antarest.core.filetransfer.service import FileTransferManager from antarest.core.jwt import JWTGroup, JWTUser from antarest.core.requests import RequestParameters from antarest.core.roles import RoleType @@ -48,6 +50,7 @@ TimeSerie, TimeSeriesData, ) +from antarest.study.service import StudyService from tests.storage.conftest import SimpleFileTransferManager from 
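Two pydantic v2 strictness changes drive the edits above: models that forbid extra fields now report "Extra inputs are not permitted", and int keys are no longer coerced for Dict[str, ...] fields. A sketch with hypothetical minimal models:

import typing as t

from pydantic import BaseModel, ConfigDict, ValidationError

class ThermalLike(BaseModel):
    model_config = ConfigDict(extra="forbid")
    name: str

try:
    ThermalLike(name="t1", unexpected=1)
except ValidationError as exc:
    assert "Extra inputs are not permitted" in str(exc)

class SeriesData(BaseModel):
    data: t.Dict[str, t.List[float]]

SeriesData(data={"1": [0.5]})     # ok: string keys
try:
    SeriesData(data={1: [0.5]})   # v1 coerced int keys to str; v2 rejects them
except ValidationError:
    pass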
tests.storage.integration.conftest import UUID @@ -66,23 +69,29 @@ ) -@pytest.mark.unit_test -def test_server() -> None: - mock_service = Mock() - mock_service.get.return_value = {} - - app = FastAPI(title=__name__) +def create_test_client( + service: StudyService, file_transfer_manager: FileTransferManager = Mock(), raise_server_exceptions: bool = True +) -> TestClient: + app_ctxt = create_app_ctxt(FastAPI(title=__name__)) build_study_service( - app, + app_ctxt, cache=Mock(), task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_service, + file_transfer_manager=file_transfer_manager, + study_service=service, config=CONFIG, user_service=Mock(), matrix_service=Mock(spec=MatrixService), ) - client = TestClient(app) + return TestClient(app_ctxt.build(), raise_server_exceptions=raise_server_exceptions) + + +@pytest.mark.unit_test +def test_server() -> None: + mock_service = Mock() + mock_service.get.return_value = {} + + client = create_test_client(mock_service) client.get("/v1/studies/study1/raw?path=settings/general/params") mock_service.get.assert_called_once_with( @@ -95,18 +104,7 @@ def test_404() -> None: mock_storage_service = Mock() mock_storage_service.get.side_effect = UrlNotMatchJsonDataError("Test") - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app, raise_server_exceptions=False) + client = create_test_client(mock_storage_service, raise_server_exceptions=False) result = client.get("/v1/studies/study1/raw?path=settings/general/params") assert result.status_code == HTTPStatus.NOT_FOUND @@ -119,18 +117,7 @@ def test_server_with_parameters() -> None: mock_storage_service = Mock() mock_storage_service.get.return_value = {} - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(mock_storage_service) result = client.get("/v1/studies/study1/raw?depth=4") parameters = RequestParameters(user=ADMIN) @@ -158,18 +145,7 @@ def test_create_study(tmp_path: str, project_path) -> None: storage_service = Mock() storage_service.create_study.return_value = "my-uuid" - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(storage_service) result_right = client.post("/v1/studies?name=study2") @@ -193,18 +169,7 @@ def test_import_study_zipped(tmp_path: Path, project_path) -> None: study_uuid = str(uuid.uuid4()) mock_storage_service.import_study.return_value = study_uuid - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(mock_storage_service) result = client.post("/v1/studies") @@ -223,18 +188,7 @@ def test_copy_study(tmp_path: Path) -> None: storage_service = Mock() 
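Several of these clients are built with raise_server_exceptions=False, so that unhandled server-side errors come back as HTTP responses the test can assert on instead of being re-raised into the test process. A minimal sketch:

from fastapi import FastAPI
from starlette.testclient import TestClient

app = FastAPI(title=__name__)

@app.get("/boom")
def boom() -> None:
    raise RuntimeError("unhandled failure")

client = TestClient(app, raise_server_exceptions=False)
assert client.get("/boom").status_code == 500  # surfaced as a response, not an exception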
storage_service.copy_study.return_value = "/studies/study-copied" - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(storage_service) result = client.post(f"/v1/studies/{UUID}/copy?dest=study-copied") @@ -285,21 +239,10 @@ def test_list_studies(tmp_path: str) -> None: storage_service = Mock() storage_service.get_studies_information.return_value = studies - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(storage_service) result = client.get("/v1/studies") - assert {k: StudyMetadataDTO.parse_obj(v) for k, v in result.json().items()} == studies + assert {k: StudyMetadataDTO.model_validate(v) for k, v in result.json().items()} == studies def test_study_metadata(tmp_path: str) -> None: @@ -320,21 +263,10 @@ def test_study_metadata(tmp_path: str) -> None: storage_service = Mock() storage_service.get_study_information.return_value = study - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(storage_service) result = client.get("/v1/studies/1") - assert StudyMetadataDTO.parse_obj(result.json()) == study + assert StudyMetadataDTO.model_validate(result.json()) == study @pytest.mark.unit_test @@ -352,20 +284,8 @@ def test_export_files(tmp_path: Path) -> None: ) mock_storage_service.export_study.return_value = expected - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - # Simulate the download of data using a streamed request - client = TestClient(app) + client = create_test_client(mock_storage_service) if client.stream is False: # `TestClient` is based on `Requests` (old way before AntaREST-v2.15) # noinspection PyArgumentList @@ -382,7 +302,7 @@ def test_export_files(tmp_path: Path) -> None: res.raise_for_status() result = json.loads(data.getvalue()) - assert FileDownloadTaskDTO(**result).json() == expected.json() + assert FileDownloadTaskDTO(**result).model_dump_json() == expected.model_dump_json() mock_storage_service.export_study.assert_called_once_with(UUID, PARAMS, True) @@ -402,18 +322,7 @@ def test_export_params(tmp_path: Path) -> None: ) mock_storage_service.export_study.return_value = expected - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(mock_storage_service) client.get(f"/v1/studies/{UUID}/export?no_output=true") client.get(f"/v1/studies/{UUID}/export?no_output=false") mock_storage_service.export_study.assert_has_calls( @@ -428,18 +337,7 @@ def 
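Comparing serialized forms, as test_export_files now does, sidesteps mismatches between a model instance and a decoded response body; `DownloadTask` below is a hypothetical stand-in for FileDownloadTaskDTO:

from pydantic import BaseModel

class DownloadTask(BaseModel):
    file: str
    ready: bool

expected = DownloadTask(file="study.zip", ready=False)
result = {"file": "study.zip", "ready": False}  # e.g. a decoded JSON response

# v1: DownloadTask(**result).json() == expected.json()
# v2: .json() is replaced by .model_dump_json()
assert DownloadTask(**result).model_dump_json() == expected.model_dump_json()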
test_export_params(tmp_path: Path) -> None: def test_delete_study() -> None: mock_storage_service = Mock() - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(mock_storage_service) study_uuid = "8319b5f8-2a35-4984-9ace-2ab072bd6eef" client.delete(f"/v1/studies/{study_uuid}") @@ -452,18 +350,7 @@ def test_edit_study() -> None: mock_storage_service = Mock() mock_storage_service.edit_study.return_value = {} - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(mock_storage_service) client.post("/v1/studies/my-uuid/raw?path=url/to/change", json={"Hello": "World"}) mock_storage_service.edit_study.assert_called_once_with("my-uuid", "url/to/change", {"Hello": "World"}, PARAMS) @@ -497,18 +384,7 @@ def test_validate() -> None: mock_service = Mock() mock_service.check_errors.return_value = ["Hello"] - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app, raise_server_exceptions=False) + client = create_test_client(mock_service, raise_server_exceptions=False) res = client.get("/v1/studies/my-uuid/raw/validate") assert res.json() == ["Hello"] @@ -551,24 +427,13 @@ def test_output_download(tmp_path: Path) -> None: synthesis=False, includeClusters=True, ) - - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=SimpleFileTransferManager(Config(storage=StorageConfig(tmp_dir=tmp_path))), - study_service=mock_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app, raise_server_exceptions=False) + ftm = SimpleFileTransferManager(Config(storage=StorageConfig(tmp_dir=tmp_path))) + client = create_test_client(mock_service, ftm, raise_server_exceptions=False) res = client.post( f"/v1/studies/{UUID}/outputs/my-output-id/download", - json=study_download.dict(), + json=study_download.model_dump(), ) - assert res.json() == output_data.dict() + assert res.json() == output_data.model_dump() @pytest.mark.unit_test @@ -588,18 +453,8 @@ def test_output_whole_download(tmp_path: Path) -> None: ) mock_service.export_output.return_value = expected - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=SimpleFileTransferManager(Config(storage=StorageConfig(tmp_dir=tmp_path))), - study_service=mock_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app, raise_server_exceptions=False) + ftm = SimpleFileTransferManager(Config(storage=StorageConfig(tmp_dir=tmp_path))) + client = create_test_client(mock_service, ftm, raise_server_exceptions=False) res = client.get( f"/v1/studies/{UUID}/outputs/{output_id}/export", ) @@ -612,18 +467,7 @@ def test_sim_reference() -> None: study_id = str(uuid.uuid4()) output_id = "my-output-id" - 
app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app, raise_server_exceptions=False) + client = create_test_client(mock_service, raise_server_exceptions=False) res = client.put(f"/v1/studies/{study_id}/outputs/{output_id}/reference") mock_service.set_sim_reference.assert_called_once_with(study_id, output_id, True, PARAMS) assert res.status_code == HTTPStatus.OK @@ -656,38 +500,17 @@ def test_sim_result() -> None: ) ] mock_service.get_study_sim_result.return_value = result_data - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app, raise_server_exceptions=False) + + client = create_test_client(mock_service, raise_server_exceptions=False) res = client.get(f"/v1/studies/{study_id}/outputs") - assert res.json() == result_data + actual_object = [StudySimResultDTO.parse_obj(res.json()[0])] + assert actual_object == result_data @pytest.mark.unit_test def test_study_permission_management(tmp_path: Path) -> None: storage_service = Mock() - - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - config=CONFIG, - ) - client = TestClient(app, raise_server_exceptions=False) + client = create_test_client(storage_service, raise_server_exceptions=False) result = client.put(f"/v1/studies/{UUID}/owner/2") storage_service.change_owner.assert_called_with( @@ -727,18 +550,7 @@ def test_study_permission_management(tmp_path: Path) -> None: @pytest.mark.unit_test def test_get_study_versions(tmp_path: Path) -> None: - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=Mock(), - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - config=CONFIG, - ) - client = TestClient(app, raise_server_exceptions=False) + client = create_test_client(Mock(), raise_server_exceptions=False) result = client.get("/v1/studies/_versions") assert result.json() == list(STUDY_REFERENCE_TEMPLATES.keys()) diff --git a/tests/study/business/areas/test_st_storage_management.py b/tests/study/business/areas/test_st_storage_management.py index c48768427c..56a42d88a0 100644 --- a/tests/study/business/areas/test_st_storage_management.py +++ b/tests/study/business/areas/test_st_storage_management.py @@ -76,6 +76,8 @@ "east": {"list": {}}, } +GEN = np.random.default_rng(1000) + class TestSTStorageManager: @pytest.fixture(name="study_storage_service") @@ -147,7 +149,7 @@ def test_get_all_storages__nominal_case( # Check actual = { - area_id: [form.dict(by_alias=True) for form in clusters_by_ids.values()] + area_id: [form.model_dump(by_alias=True) for form in clusters_by_ids.values()] for area_id, clusters_by_ids in all_storages.items() } expected = { @@ -253,7 +255,7 @@ def test_get_st_storages__nominal_case( groups = manager.get_storages(study, area_id="West") # Check - actual = [form.dict(by_alias=True) for form in groups] + actual = [form.model_dump(by_alias=True) for form in groups] expected = [ { "efficiency": 0.94, @@ -365,7 
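The storage tests replace scattered np.random.rand calls with one module-level seeded generator, which makes the random fixtures reproducible across runs. A sketch of the idiom:

import numpy as np

GEN = np.random.default_rng(1000)  # one seeded generator per test module

inflows = GEN.random((8760, 1)) * 1000   # note: the shape is passed as a tuple
lower = GEN.random((8760, 1)) / 2
upper = GEN.random((8760, 1)) / 2 + 0.5  # guarantees lower <= upper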
+367,7 @@ def test_get_st_storage__nominal_case( edit_form = manager.get_storage(study, area_id="West", storage_id="storage1") # Assert that the returned storage fields match the expected fields - actual = edit_form.dict(by_alias=True) + actual = edit_form.model_dump(by_alias=True) expected = { "efficiency": 0.94, "group": STStorageGroup.BATTERY, @@ -455,11 +457,11 @@ def test_update_storage__nominal_case( actual = {(call_args[0][0], tuple(call_args[0][1])) for call_args in file_study.tree.save.call_args_list} expected = { ( - str(0.0), + 0.0, ("input", "st-storage", "clusters", "West", "list", "storage1", "initiallevel"), ), ( - str(False), + False, ("input", "st-storage", "clusters", "West", "list", "storage1", "initialleveloptim"), ), } @@ -528,16 +530,15 @@ def test_get_matrix__nominal_case( # Prepare the mocks storage = study_storage_service.get_storage(study) file_study = storage.get_raw(study) - array = np.random.rand(8760, 1) * 1000 + array = GEN.random((8760, 1)) * 1000 + expected = { + "index": list(range(8760)), + "columns": [0], + "data": array.tolist(), + } file_study.tree = Mock( spec=FileStudyTree, - get=Mock( - return_value={ - "index": list(range(8760)), - "columns": [0], - "data": array.tolist(), - } - ), + get=Mock(return_value=expected), ) # Given the following arguments @@ -547,8 +548,8 @@ def test_get_matrix__nominal_case( matrix = manager.get_matrix(study, area_id="West", storage_id="storage1", ts_name="inflows") # Assert that the returned storage fields match the expected fields - actual = matrix.dict(by_alias=True) - assert actual == matrix + actual = matrix.model_dump(by_alias=True) + assert actual == expected def test_get_matrix__config_not_found( self, @@ -616,7 +617,7 @@ def test_get_matrix__invalid_matrix( # Prepare the mocks storage = study_storage_service.get_storage(study) file_study = storage.get_raw(study) - array = np.random.rand(365, 1) * 1000 + array = GEN.random((365, 1)) * 1000 matrix = { "index": list(range(365)), "columns": [0], @@ -649,11 +650,11 @@ def test_validate_matrices__nominal( # prepare some random matrices, insuring `lower_rule_curve` <= `upper_rule_curve` matrices = { - "pmax_injection": np.random.rand(8760, 1), - "pmax_withdrawal": np.random.rand(8760, 1), - "lower_rule_curve": np.random.rand(8760, 1) / 2, - "upper_rule_curve": np.random.rand(8760, 1) / 2 + 0.5, - "inflows": np.random.rand(8760, 1) * 1000, + "pmax_injection": GEN.random((8760, 1)), + "pmax_withdrawal": GEN.random((8760, 1)), + "lower_rule_curve": GEN.random((8760, 1)) / 2, + "upper_rule_curve": GEN.random((8760, 1)) / 2 + 0.5, + "inflows": GEN.random((8760, 1)) * 1000, } # Prepare the mocks @@ -686,11 +687,11 @@ def test_validate_matrices__out_of_bound( # prepare some random matrices, insuring `lower_rule_curve` <= `upper_rule_curve` matrices = { - "pmax_injection": np.random.rand(8760, 1) * 2 - 0.5, # out of bound - "pmax_withdrawal": np.random.rand(8760, 1) * 2 - 0.5, # out of bound - "lower_rule_curve": np.random.rand(8760, 1) * 2 - 0.5, # out of bound - "upper_rule_curve": np.random.rand(8760, 1) * 2 - 0.5, # out of bound - "inflows": np.random.rand(8760, 1) * 1000, + "pmax_injection": GEN.random((8760, 1)) * 2 - 0.5, # out of bound + "pmax_withdrawal": GEN.random((8760, 1)) * 2 - 0.5, # out of bound + "lower_rule_curve": GEN.random((8760, 1)) * 2 - 0.5, # out of bound + "upper_rule_curve": GEN.random((8760, 1)) * 2 - 0.5, # out of bound + "inflows": GEN.random((8760, 1)) * 1000, } # Prepare the mocks @@ -707,7 +708,6 @@ def tree_get(url: t.Sequence[str], **_: 
t.Any) -> t.MutableMapping[str, t.Any]: file_study = storage.get_raw(study) file_study.tree = Mock(spec=FileStudyTree, get=tree_get) - # Given the following arguments, the validation shouldn't raise any exception manager = STStorageManager(study_storage_service) # Run the method being tested and expect an exception @@ -716,29 +716,11 @@ def tree_get(url: t.Sequence[str], **_: t.Any) -> t.MutableMapping[str, t.Any]: match=re.escape("4 validation errors"), ) as ctx: manager.validate_matrices(study, area_id="West", storage_id="storage1") - errors = ctx.value.errors() - assert errors == [ - { - "loc": ("pmax_injection",), - "msg": "Matrix values should be between 0 and 1", - "type": "value_error", - }, - { - "loc": ("pmax_withdrawal",), - "msg": "Matrix values should be between 0 and 1", - "type": "value_error", - }, - { - "loc": ("lower_rule_curve",), - "msg": "Matrix values should be between 0 and 1", - "type": "value_error", - }, - { - "loc": ("upper_rule_curve",), - "msg": "Matrix values should be between 0 and 1", - "type": "value_error", - }, - ] + assert ctx.value.error_count() == 4 + for error in ctx.value.errors(): + assert error["type"] == "value_error" + assert error["msg"] == "Value error, Matrix values should be between 0 and 1" + assert error["loc"][0] in ["upper_rule_curve", "lower_rule_curve", "pmax_withdrawal", "pmax_injection"] # noinspection SpellCheckingInspection def test_validate_matrices__rule_curve( @@ -750,13 +732,15 @@ def test_validate_matrices__rule_curve( # The study must be fetched from the database study: RawStudy = db_session.query(Study).get(study_uuid) - # prepare some random matrices, insuring `lower_rule_curve` <= `upper_rule_curve` + # prepare some random matrices, not respecting `lower_rule_curve` <= `upper_rule_curve` + upper_curve = np.zeros((8760, 1)) + lower_curve = np.ones((8760, 1)) matrices = { - "pmax_injection": np.random.rand(8760, 1), - "pmax_withdrawal": np.random.rand(8760, 1), - "lower_rule_curve": np.random.rand(8760, 1), - "upper_rule_curve": np.random.rand(8760, 1), - "inflows": np.random.rand(8760, 1) * 1000, + "pmax_injection": GEN.random((8760, 1)), + "pmax_withdrawal": GEN.random((8760, 1)), + "lower_rule_curve": lower_curve, + "upper_rule_curve": upper_curve, + "inflows": GEN.random((8760, 1)) * 1000, } # Prepare the mocks @@ -773,7 +757,7 @@ def tree_get(url: t.Sequence[str], **_: t.Any) -> t.MutableMapping[str, t.Any]: file_study = storage.get_raw(study) file_study.tree = Mock(spec=FileStudyTree, get=tree_get) - # Given the following arguments, the validation shouldn't raise any exception + # Given the following arguments manager = STStorageManager(study_storage_service) # Run the method being tested and expect an exception @@ -783,6 +767,8 @@ def tree_get(url: t.Sequence[str], **_: t.Any) -> t.MutableMapping[str, t.Any]: ) as ctx: manager.validate_matrices(study, area_id="West", storage_id="storage1") error = ctx.value.errors()[0] - assert error["loc"] == ("__root__",) - assert "lower_rule_curve" in error["msg"] - assert "upper_rule_curve" in error["msg"] + assert error["type"] == "value_error" + assert ( + error["msg"] + == "Value error, Each 'lower_rule_curve' value must be lower or equal to each 'upper_rule_curve'" + ) diff --git a/tests/study/business/areas/test_thermal_management.py b/tests/study/business/areas/test_thermal_management.py index e1276286a2..6bb2e34d7e 100644 --- a/tests/study/business/areas/test_thermal_management.py +++ b/tests/study/business/areas/test_thermal_management.py @@ -144,7 +144,7 @@ def 
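pydantic v2 reshapes ValidationError, which is why these assertions change: the exception gains an error_count() helper, validator messages are prefixed with "Value error, ", and loc names the field directly instead of ("__root__",). A sketch of inspecting one, with a hypothetical validator:

from pydantic import BaseModel, ValidationError, field_validator

class Matrices(BaseModel):
    pmax_injection: float

    @field_validator("pmax_injection")
    @classmethod
    def check_bounds(cls, value: float) -> float:
        if not 0 <= value <= 1:
            raise ValueError("Matrix values should be between 0 and 1")
        return value

try:
    Matrices(pmax_injection=2)
except ValidationError as exc:
    assert exc.error_count() == 1
    error = exc.errors()[0]
    assert error["type"] == "value_error"
    assert error["msg"] == "Value error, Matrix values should be between 0 and 1"
    assert error["loc"] == ("pmax_injection",)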
test_get_cluster__study_legacy( form = manager.get_cluster(study, area_id="north", cluster_id="2 avail and must 1") # Assert that the returned fields match the expected fields - actual = form.dict(by_alias=True) + actual = form.model_dump(by_alias=True) expected = { "id": "2 avail and must 1", "group": ThermalClusterGroup.GAS, @@ -210,7 +210,7 @@ def test_get_clusters__study_legacy( groups = manager.get_clusters(study, area_id="north") # Assert that the returned fields match the expected fields - actual = [form.dict(by_alias=True) for form in groups] + actual = [form.model_dump(by_alias=True) for form in groups] expected = [ { "id": "2 avail and must 1", @@ -366,7 +366,7 @@ def test_create_cluster__study_legacy( form = manager.create_cluster(study, area_id="north", cluster_data=cluster_data) # Assert that the returned fields match the expected fields - actual = form.dict(by_alias=True) + actual = form.model_dump(by_alias=True) expected = { "co2": 12.59, "enabled": True, @@ -426,7 +426,7 @@ def test_update_cluster( # Assert that the returned fields match the expected fields form = manager.get_cluster(study, area_id="north", cluster_id="2 avail and must 1") - actual = form.dict(by_alias=True) + actual = form.model_dump(by_alias=True) expected = { "id": "2 avail and must 1", "group": ThermalClusterGroup.GAS, diff --git a/tests/study/business/test_all_optional_metaclass.py b/tests/study/business/test_all_optional_metaclass.py index 741b16ccc8..b8d1197c5e 100644 --- a/tests/study/business/test_all_optional_metaclass.py +++ b/tests/study/business/test_all_optional_metaclass.py @@ -10,352 +10,51 @@ # # This file is part of the Antares project. -import typing as t +from pydantic import BaseModel, Field -import pytest -from pydantic import BaseModel, Field, ValidationError +from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model -from antarest.study.business.all_optional_meta import AllOptionalMetaclass -# ============================================== -# Classic way to use default and optional values -# ============================================== +class Model(BaseModel): + float_with_default: float = 1 + float_without_default: float + boolean_with_default: bool = True + boolean_without_default: bool + field_with_alias: str = Field(default="default", alias="field-with-alias") -class ClassicModel(BaseModel): - mandatory: float = Field(ge=0, le=1) - mandatory_with_default: float = Field(ge=0, le=1, default=0.2) - mandatory_with_none: float = Field(ge=0, le=1, default=None) - optional: t.Optional[float] = Field(ge=0, le=1) - optional_with_default: t.Optional[float] = Field(ge=0, le=1, default=0.2) - optional_with_none: t.Optional[float] = Field(ge=0, le=1, default=None) - - -class ClassicSubModel(ClassicModel): - pass - - -class TestClassicModel: - """ - Test that default and optional values work as expected. 
- """ - - @pytest.mark.parametrize("cls", [ClassicModel, ClassicSubModel]) - def test_classes(self, cls: t.Type[BaseModel]) -> None: - assert cls.__fields__["mandatory"].required is True - assert cls.__fields__["mandatory"].allow_none is False - assert cls.__fields__["mandatory"].default is None - assert cls.__fields__["mandatory"].default_factory is None # undefined - - assert cls.__fields__["mandatory_with_default"].required is False - assert cls.__fields__["mandatory_with_default"].allow_none is False - assert cls.__fields__["mandatory_with_default"].default == 0.2 - assert cls.__fields__["mandatory_with_default"].default_factory is None # undefined - - assert cls.__fields__["mandatory_with_none"].required is False - assert cls.__fields__["mandatory_with_none"].allow_none is True - assert cls.__fields__["mandatory_with_none"].default is None - assert cls.__fields__["mandatory_with_none"].default_factory is None # undefined - - assert cls.__fields__["optional"].required is False - assert cls.__fields__["optional"].allow_none is True - assert cls.__fields__["optional"].default is None - assert cls.__fields__["optional"].default_factory is None # undefined - - assert cls.__fields__["optional_with_default"].required is False - assert cls.__fields__["optional_with_default"].allow_none is True - assert cls.__fields__["optional_with_default"].default == 0.2 - assert cls.__fields__["optional_with_default"].default_factory is None # undefined - - assert cls.__fields__["optional_with_none"].required is False - assert cls.__fields__["optional_with_none"].allow_none is True - assert cls.__fields__["optional_with_none"].default is None - assert cls.__fields__["optional_with_none"].default_factory is None # undefined - - @pytest.mark.parametrize("cls", [ClassicModel, ClassicSubModel]) - def test_initialization(self, cls: t.Type[ClassicModel]) -> None: - # We can build a model without providing optional or default values. - # The initialized value will be the default value or `None` for optional values. - obj = cls(mandatory=0.5) - assert obj.mandatory == 0.5 - assert obj.mandatory_with_default == 0.2 - assert obj.mandatory_with_none is None - assert obj.optional is None - assert obj.optional_with_default == 0.2 - assert obj.optional_with_none is None - - # We must provide a value for mandatory fields. - with pytest.raises(ValidationError): - cls() - - @pytest.mark.parametrize("cls", [ClassicModel, ClassicSubModel]) - def test_validation(self, cls: t.Type[ClassicModel]) -> None: - # We CANNOT use `None` as a value for a field with a default value. - with pytest.raises(ValidationError): - cls(mandatory=0.5, mandatory_with_default=None) - - # We can use `None` as a value for optional fields with default value. - cls(mandatory=0.5, optional_with_default=None) - - # We can validate a model with valid values. - cls( - mandatory=0.5, - mandatory_with_default=0.2, - mandatory_with_none=0.2, - optional=0.5, - optional_with_default=0.2, - optional_with_none=0.2, - ) - - # We CANNOT validate a model with invalid values. 
- with pytest.raises(ValidationError): - cls(mandatory=2) - - with pytest.raises(ValidationError): - cls(mandatory=0.5, mandatory_with_default=2) - - with pytest.raises(ValidationError): - cls(mandatory=0.5, mandatory_with_none=2) - - with pytest.raises(ValidationError): - cls(mandatory=0.5, optional=2) - - with pytest.raises(ValidationError): - cls(mandatory=0.5, optional_with_default=2) - - with pytest.raises(ValidationError): - cls(mandatory=0.5, optional_with_none=2) - - -# ========================== -# Using AllOptionalMetaclass -# ========================== - - -class AllOptionalModel(BaseModel, metaclass=AllOptionalMetaclass): - mandatory: float = Field(ge=0, le=1) - mandatory_with_default: float = Field(ge=0, le=1, default=0.2) - mandatory_with_none: float = Field(ge=0, le=1, default=None) - optional: t.Optional[float] = Field(ge=0, le=1) - optional_with_default: t.Optional[float] = Field(ge=0, le=1, default=0.2) - optional_with_none: t.Optional[float] = Field(ge=0, le=1, default=None) - - -class AllOptionalSubModel(AllOptionalModel): +@all_optional_model +class OptionalModel(Model): pass -class InheritedAllOptionalModel(ClassicModel, metaclass=AllOptionalMetaclass): +@all_optional_model +@camel_case_model +class OptionalCamelCaseModel(Model): pass -class TestAllOptionalModel: - """ - Test that AllOptionalMetaclass works with base classes. - """ - - @pytest.mark.parametrize("cls", [AllOptionalModel, AllOptionalSubModel, InheritedAllOptionalModel]) - def test_classes(self, cls: t.Type[BaseModel]) -> None: - assert cls.__fields__["mandatory"].required is False - assert cls.__fields__["mandatory"].allow_none is True - assert cls.__fields__["mandatory"].default is None - assert cls.__fields__["mandatory"].default_factory is None # undefined - - assert cls.__fields__["mandatory_with_default"].required is False - assert cls.__fields__["mandatory_with_default"].allow_none is True - assert cls.__fields__["mandatory_with_default"].default == 0.2 - assert cls.__fields__["mandatory_with_default"].default_factory is None # undefined - - assert cls.__fields__["mandatory_with_none"].required is False - assert cls.__fields__["mandatory_with_none"].allow_none is True - assert cls.__fields__["mandatory_with_none"].default is None - assert cls.__fields__["mandatory_with_none"].default_factory is None # undefined - - assert cls.__fields__["optional"].required is False - assert cls.__fields__["optional"].allow_none is True - assert cls.__fields__["optional"].default is None - assert cls.__fields__["optional"].default_factory is None # undefined - - assert cls.__fields__["optional_with_default"].required is False - assert cls.__fields__["optional_with_default"].allow_none is True - assert cls.__fields__["optional_with_default"].default == 0.2 - assert cls.__fields__["optional_with_default"].default_factory is None # undefined - - assert cls.__fields__["optional_with_none"].required is False - assert cls.__fields__["optional_with_none"].allow_none is True - assert cls.__fields__["optional_with_none"].default is None - assert cls.__fields__["optional_with_none"].default_factory is None # undefined - - @pytest.mark.parametrize("cls", [AllOptionalModel, AllOptionalSubModel, InheritedAllOptionalModel]) - def test_initialization(self, cls: t.Type[AllOptionalModel]) -> None: - # We can build a model without providing values. - # The initialized value will be the default value or `None` for optional values. - # Note that the mandatory fields are not required anymore, and can be `None`. 
- obj = cls() - assert obj.mandatory is None - assert obj.mandatory_with_default == 0.2 - assert obj.mandatory_with_none is None - assert obj.optional is None - assert obj.optional_with_default == 0.2 - assert obj.optional_with_none is None - - # If we convert the model to a dictionary, without `None` values, - # we should have a dictionary with default values only. - actual = obj.dict(exclude_none=True) - expected = { - "mandatory_with_default": 0.2, - "optional_with_default": 0.2, - } - assert actual == expected - - @pytest.mark.parametrize("cls", [AllOptionalModel, AllOptionalSubModel, InheritedAllOptionalModel]) - def test_validation(self, cls: t.Type[AllOptionalModel]) -> None: - # We can use `None` as a value for all fields. - cls(mandatory=None) - cls(mandatory_with_default=None) - cls(mandatory_with_none=None) - cls(optional=None) - cls(optional_with_default=None) - cls(optional_with_none=None) - - # We can validate a model with valid values. - cls( - mandatory=0.5, - mandatory_with_default=0.2, - mandatory_with_none=0.2, - optional=0.5, - optional_with_default=0.2, - optional_with_none=0.2, - ) - - # We CANNOT validate a model with invalid values. - with pytest.raises(ValidationError): - cls(mandatory=2) - - with pytest.raises(ValidationError): - cls(mandatory_with_default=2) - - with pytest.raises(ValidationError): - cls(mandatory_with_none=2) - - with pytest.raises(ValidationError): - cls(optional=2) - - with pytest.raises(ValidationError): - cls(optional_with_default=2) - - with pytest.raises(ValidationError): - cls(optional_with_none=2) - - -# The `use_none` keyword argument is set to `True` to allow the use of `None` -# as a default value for the fields of the model. - - -class UseNoneModel(BaseModel, metaclass=AllOptionalMetaclass, use_none=True): - mandatory: float = Field(ge=0, le=1) - mandatory_with_default: float = Field(ge=0, le=1, default=0.2) - mandatory_with_none: float = Field(ge=0, le=1, default=None) - optional: t.Optional[float] = Field(ge=0, le=1) - optional_with_default: t.Optional[float] = Field(ge=0, le=1, default=0.2) - optional_with_none: t.Optional[float] = Field(ge=0, le=1, default=None) - - -class UseNoneSubModel(UseNoneModel): - pass - - -class InheritedUseNoneModel(ClassicModel, metaclass=AllOptionalMetaclass, use_none=True): - pass - - -class TestUseNoneModel: - @pytest.mark.parametrize("cls", [UseNoneModel, UseNoneSubModel, InheritedUseNoneModel]) - def test_classes(self, cls: t.Type[BaseModel]) -> None: - assert cls.__fields__["mandatory"].required is False - assert cls.__fields__["mandatory"].allow_none is True - assert cls.__fields__["mandatory"].default is None - assert cls.__fields__["mandatory"].default_factory is None # undefined - - assert cls.__fields__["mandatory_with_default"].required is False - assert cls.__fields__["mandatory_with_default"].allow_none is True - assert cls.__fields__["mandatory_with_default"].default is None - assert cls.__fields__["mandatory_with_default"].default_factory is None # undefined - - assert cls.__fields__["mandatory_with_none"].required is False - assert cls.__fields__["mandatory_with_none"].allow_none is True - assert cls.__fields__["mandatory_with_none"].default is None - assert cls.__fields__["mandatory_with_none"].default_factory is None # undefined - - assert cls.__fields__["optional"].required is False - assert cls.__fields__["optional"].allow_none is True - assert cls.__fields__["optional"].default is None - assert cls.__fields__["optional"].default_factory is None # undefined - - assert 
cls.__fields__["optional_with_default"].required is False - assert cls.__fields__["optional_with_default"].allow_none is True - assert cls.__fields__["optional_with_default"].default is None - assert cls.__fields__["optional_with_default"].default_factory is None # undefined - - assert cls.__fields__["optional_with_none"].required is False - assert cls.__fields__["optional_with_none"].allow_none is True - assert cls.__fields__["optional_with_none"].default is None - assert cls.__fields__["optional_with_none"].default_factory is None # undefined - - @pytest.mark.parametrize("cls", [UseNoneModel, UseNoneSubModel, InheritedUseNoneModel]) - def test_initialization(self, cls: t.Type[UseNoneModel]) -> None: - # We can build a model without providing values. - # The initialized value will be the default value or `None` for optional values. - # Note that the mandatory fields are not required anymore, and can be `None`. - obj = cls() - assert obj.mandatory is None - assert obj.mandatory_with_default is None - assert obj.mandatory_with_none is None - assert obj.optional is None - assert obj.optional_with_default is None - assert obj.optional_with_none is None - - # If we convert the model to a dictionary, without `None` values, - # we should have an empty dictionary. - actual = obj.dict(exclude_none=True) - expected = {} - assert actual == expected - - @pytest.mark.parametrize("cls", [UseNoneModel, UseNoneSubModel, InheritedUseNoneModel]) - def test_validation(self, cls: t.Type[UseNoneModel]) -> None: - # We can use `None` as a value for all fields. - cls(mandatory=None) - cls(mandatory_with_default=None) - cls(mandatory_with_none=None) - cls(optional=None) - cls(optional_with_default=None) - cls(optional_with_none=None) - - # We can validate a model with valid values. - cls( - mandatory=0.5, - mandatory_with_default=0.2, - mandatory_with_none=0.2, - optional=0.5, - optional_with_default=0.2, - optional_with_none=0.2, - ) - - # We CANNOT validate a model with invalid values. 
- with pytest.raises(ValidationError): - cls(mandatory=2) - - with pytest.raises(ValidationError): - cls(mandatory_with_default=2) - - with pytest.raises(ValidationError): - cls(mandatory_with_none=2) - - with pytest.raises(ValidationError): - cls(optional=2) - - with pytest.raises(ValidationError): - cls(optional_with_default=2) - - with pytest.raises(ValidationError): - cls(optional_with_none=2) +def test_model() -> None: + optional_model = OptionalModel() + assert optional_model.float_with_default is None + assert optional_model.float_without_default is None + assert optional_model.boolean_with_default is None + assert optional_model.boolean_without_default is None + assert optional_model.field_with_alias is None + + optional_model = OptionalModel(boolean_with_default=False) + assert optional_model.float_with_default is None + assert optional_model.float_without_default is None + assert optional_model.boolean_with_default is False + assert optional_model.boolean_without_default is None + assert optional_model.field_with_alias is None + + # build with alias should succeed + args = {"field-with-alias": "test"} + optional_model = OptionalModel(**args) + assert optional_model.field_with_alias == "test" + + # build with camel_case should succeed + args = {"fieldWithAlias": "test"} + camel_case_model = OptionalCamelCaseModel(**args) + assert camel_case_model.field_with_alias == "test" diff --git a/tests/study/business/test_allocation_manager.py b/tests/study/business/test_allocation_manager.py index 8c49dd07c3..a7fec22f2e 100644 --- a/tests/study/business/test_allocation_manager.py +++ b/tests/study/business/test_allocation_manager.py @@ -47,7 +47,7 @@ def test_base(self): def test_camel_case(self): field = AllocationField(areaId="NORTH", coefficient=1) - assert field.dict(by_alias=True) == { + assert field.model_dump(by_alias=True) == { "areaId": "NORTH", "coefficient": 1, } diff --git a/tests/study/storage/variantstudy/model/test_dbmodel.py b/tests/study/storage/variantstudy/model/test_dbmodel.py index 162c9a269b..aff5686dcb 100644 --- a/tests/study/storage/variantstudy/model/test_dbmodel.py +++ b/tests/study/storage/variantstudy/model/test_dbmodel.py @@ -164,7 +164,7 @@ def test_init(self, db_session: Session, variant_study_id: str) -> None: # check CommandBlock.to_dto() dto = obj.to_dto() # note: it is easier to compare the dict representation of the DTO - assert dto.dict() == { + assert dto.model_dump() == { "id": command_id, "action": command, "args": json.loads(args), diff --git a/tests/study/storage/variantstudy/test_snapshot_generator.py b/tests/study/storage/variantstudy/test_snapshot_generator.py index 2cf975403b..0f2d4775ef 100644 --- a/tests/study/storage/variantstudy/test_snapshot_generator.py +++ b/tests/study/storage/variantstudy/test_snapshot_generator.py @@ -863,7 +863,7 @@ def test_generate__nominal_case( assert len(db_recorder.sql_statements) == 5, str(db_recorder) # Check: the variant generation must succeed. 
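The rewritten test pins down the observable behaviour of the two decorators that replace AllOptionalMetaclass: all_optional_model makes every field optional with a None default, and camel_case_model accepts camelCase aliases on input. A usage sketch grounded only in what this test asserts (the decorator internals are not part of this patch):

from pydantic import BaseModel, Field

from antarest.study.business.all_optional_meta import all_optional_model, camel_case_model

class Base(BaseModel):
    float_without_default: float
    field_with_alias: str = Field(default="default", alias="field-with-alias")

@all_optional_model
class OptionalBase(Base):
    pass

assert OptionalBase().float_without_default is None  # required field became optional
assert OptionalBase(**{"field-with-alias": "x"}).field_with_alias == "x"

@all_optional_model
@camel_case_model
class CamelCaseBase(Base):
    pass

assert CamelCaseBase(**{"fieldWithAlias": "x"}).field_with_alias == "x"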
- assert results.dict() == { + assert results.model_dump() == { "success": True, "details": [ { @@ -1048,7 +1048,7 @@ def test_generate__with_denormalize_true( ) # Check the results - assert results.dict() == { + assert results.model_dump() == { "success": True, "details": [ { @@ -1171,7 +1171,7 @@ def test_generate__notification_failure( ) # Check the results - assert results.dict() == { + assert results.model_dump() == { "success": True, "details": [ { @@ -1253,7 +1253,7 @@ def test_generate__variant_of_variant( ) # Check the results - assert results.dict() == { + assert results.model_dump() == { "success": True, "details": [ { diff --git a/tests/study/test_repository.py b/tests/study/test_repository.py index 0be20a0bc8..bcdcea759b 100644 --- a/tests/study/test_repository.py +++ b/tests/study/test_repository.py @@ -917,9 +917,9 @@ def test_get_all__non_admin_permissions_filter( user_2 = User(id=102, name="user2") user_3 = User(id=103, name="user3") - group_1 = Group(id=101, name="group1") - group_2 = Group(id=102, name="group2") - group_3 = Group(id=103, name="group3") + group_1 = Group(id="101", name="group1") + group_2 = Group(id="102", name="group2") + group_3 = Group(id="103", name="group3") user_groups_mapping = {101: [group_2.id], 102: [group_1.id], 103: []} @@ -1191,23 +1191,23 @@ def test_update_tags( (None, [], False, ["5", "6"]), (None, [], True, ["1", "2", "3", "4", "7", "8"]), (None, [], None, ["1", "2", "3", "4", "5", "6", "7", "8"]), - (None, [1, 3, 5, 7], False, ["5"]), - (None, [1, 3, 5, 7], True, ["1", "3", "7"]), - (None, [1, 3, 5, 7], None, ["1", "3", "5", "7"]), + (None, ["1", "3", "5", "7"], False, ["5"]), + (None, ["1", "3", "5", "7"], True, ["1", "3", "7"]), + (None, ["1", "3", "5", "7"], None, ["1", "3", "5", "7"]), (True, [], False, ["5"]), (True, [], True, ["1", "2", "3", "4", "8"]), (True, [], None, ["1", "2", "3", "4", "5", "8"]), - (True, [1, 3, 5, 7], False, ["5"]), - (True, [1, 3, 5, 7], True, ["1", "3"]), - (True, [1, 3, 5, 7], None, ["1", "3", "5"]), - (True, [2, 4, 6, 8], True, ["2", "4", "8"]), - (True, [2, 4, 6, 8], None, ["2", "4", "8"]), + (True, ["1", "3", "5", "7"], False, ["5"]), + (True, ["1", "3", "5", "7"], True, ["1", "3"]), + (True, ["1", "3", "5", "7"], None, ["1", "3", "5"]), + (True, ["2", "4", "6", "8"], True, ["2", "4", "8"]), + (True, ["2", "4", "6", "8"], None, ["2", "4", "8"]), (False, [], False, ["6"]), (False, [], True, ["7"]), (False, [], None, ["6", "7"]), - (False, [1, 3, 5, 7], False, []), - (False, [1, 3, 5, 7], True, ["7"]), - (False, [1, 3, 5, 7], None, ["7"]), + (False, ["1", "3", "5", "7"], False, []), + (False, ["1", "3", "5", "7"], True, ["7"]), + (False, ["1", "3", "5", "7"], None, ["7"]), ], ) def test_count_studies__general_case( @@ -1221,14 +1221,14 @@ def test_count_studies__general_case( icache: Mock = Mock(spec=ICache) repository = StudyMetadataRepository(cache_service=icache, session=db_session) - study_1 = VariantStudy(id=1, name="study-1") - study_2 = VariantStudy(id=2, name="study-2") - study_3 = VariantStudy(id=3, name="study-3") - study_4 = VariantStudy(id=4, name="study-4") - study_5 = RawStudy(id=5, name="study-5", missing=datetime.datetime.now(), workspace=DEFAULT_WORKSPACE_NAME) - study_6 = RawStudy(id=6, name="study-6", missing=datetime.datetime.now(), workspace=test_workspace) - study_7 = RawStudy(id=7, name="study-7", missing=None, workspace=test_workspace) - study_8 = RawStudy(id=8, name="study-8", missing=None, workspace=DEFAULT_WORKSPACE_NAME) + study_1 = VariantStudy(id="1", name="study-1") + 
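The repository fixtures switch to string primary keys so the test data matches the string-typed id columns; passing ints only ever worked through implicit coercion. An illustrative sketch, with column definitions assumed rather than taken from this patch:

from sqlalchemy import Column, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Group(Base):
    __tablename__ = "groups"
    # the id column is string-typed, so fixtures should pass str values
    id = Column(String(36), primary_key=True)
    name = Column(String(255))

group = Group(id="101", name="group1")  # not Group(id=101, ...)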
study_2 = VariantStudy(id="2", name="study-2") + study_3 = VariantStudy(id="3", name="study-3") + study_4 = VariantStudy(id="4", name="study-4") + study_5 = RawStudy(id="5", name="study-5", missing=datetime.datetime.now(), workspace=DEFAULT_WORKSPACE_NAME) + study_6 = RawStudy(id="6", name="study-6", missing=datetime.datetime.now(), workspace=test_workspace) + study_7 = RawStudy(id="7", name="study-7", missing=None, workspace=test_workspace) + study_8 = RawStudy(id="8", name="study-8", missing=None, workspace=DEFAULT_WORKSPACE_NAME) db_session.add_all([study_1, study_2, study_3, study_4, study_5, study_6, study_7, study_8]) db_session.commit() diff --git a/tests/test_front.py b/tests/test_front.py new file mode 100644 index 0000000000..5046a868cf --- /dev/null +++ b/tests/test_front.py @@ -0,0 +1,111 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from pathlib import Path + +import pytest +from fastapi import FastAPI +from starlette.testclient import TestClient + +from antarest.front import RedirectMiddleware, add_front_app + + +@pytest.fixture +def base_back_app() -> FastAPI: + """ + A simple app which has only one backend endpoint + """ + app = FastAPI(title=__name__) + + @app.get(path="/api/a-backend-endpoint") + def get_from_api() -> str: + return "back" + + return app + + +@pytest.fixture +def resources_dir(tmp_path: Path) -> Path: + resource_dir = tmp_path / "resources" + resource_dir.mkdir() + webapp_dir = resource_dir / "webapp" + webapp_dir.mkdir() + with open(webapp_dir / "index.html", mode="w") as f: + f.write("index") + with open(webapp_dir / "front.css", mode="w") as f: + f.write("css") + return resource_dir + + +@pytest.fixture +def app_with_home(base_back_app) -> FastAPI: + """ + A simple app which has only a home endpoint and one backend endpoint + """ + + @base_back_app.get(path="/") + def home() -> str: + return "home" + + return base_back_app + + +@pytest.fixture +def redirect_app(app_with_home: FastAPI) -> FastAPI: + """ + Same as app with redirect middleware + """ + route_paths = [r.path for r in app_with_home.routes] # type: ignore + app_with_home.add_middleware(RedirectMiddleware, route_paths=route_paths) + return app_with_home + + +def test_redirect_middleware_does_not_modify_home(redirect_app: FastAPI) -> None: + client = TestClient(redirect_app) + response = client.get("/") + assert response.status_code == 200 + assert response.json() == "home" + + +def test_redirect_middleware_redirects_unknown_routes_to_home(redirect_app: FastAPI) -> None: + client = TestClient(redirect_app) + response = client.get("/a-front-route") + assert response.status_code == 200 + assert response.json() == "home" + + +def test_redirect_middleware_does_not_redirect_backend_routes(redirect_app: FastAPI) -> None: + client = TestClient(redirect_app) + response = client.get("/api/a-backend-endpoint") + assert response.status_code == 200 + assert response.json() == "back" + + +def test_frontend_paths(base_back_app, resources_dir: Path) -> None: + add_front_app(base_back_app, resources_dir, "/api") + client = TestClient(base_back_app) + + config_response = client.get("/config.json") + assert config_response.status_code == 200 + assert config_response.json() == 
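The new front tests pin down the single-page-app behaviour: requests to unknown, non-API paths must fall back to the front-end entry point so the client-side router can resolve them. The real RedirectMiddleware lives in antarest.front and its implementation is not shown in this patch; this is a hypothetical equivalent of the behaviour the tests assert:

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

class SpaFallbackMiddleware(BaseHTTPMiddleware):
    """Hypothetical sketch: rewrite paths outside the known routes to '/'."""

    def __init__(self, app, route_paths):
        super().__init__(app)
        self.known_paths = set(route_paths)  # e.g. [r.path for r in app.routes]

    async def dispatch(self, request: Request, call_next) -> Response:
        if request.url.path not in self.known_paths:
            # Serve the app shell; the front-end router handles the path.
            # (The real middleware may additionally special-case API paths.)
            request.scope["path"] = "/"
        return await call_next(request)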
{"restEndpoint": "/api", "wsEndpoint": "/api/ws"} + + index_response = client.get("/index.html") + assert index_response.status_code == 200 + assert index_response.text == "index" + + front_route_response = client.get("/any-route") + assert front_route_response.status_code == 200 + assert front_route_response.text == "index" + + front_static_file_response = client.get("/static/front.css") + assert front_static_file_response.status_code == 200 + assert front_static_file_response.text == "css" diff --git a/tests/variantstudy/model/command/test_create_cluster.py b/tests/variantstudy/model/command/test_create_cluster.py index a0285f06a2..35c21e91ab 100644 --- a/tests/variantstudy/model/command/test_create_cluster.py +++ b/tests/variantstudy/model/command/test_create_cluster.py @@ -28,11 +28,13 @@ from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig from antarest.study.storage.variantstudy.model.command_context import CommandContext +GEN = np.random.default_rng(1000) + class TestCreateCluster: def test_init(self, command_context: CommandContext): - prepro = np.random.rand(365, 6).tolist() - modulation = np.random.rand(8760, 4).tolist() + prepro = GEN.random((365, 6)).tolist() + modulation = GEN.random((8760, 4)).tolist() cl = CreateCluster( area_id="foo", cluster_name="Cluster1", @@ -52,7 +54,7 @@ def test_init(self, command_context: CommandContext): modulation_id = command_context.matrix_service.create(modulation) assert cl.area_id == "foo" assert cl.cluster_name == "Cluster1" - assert cl.parameters == {"group": "Nuclear", "nominalcapacity": "2400", "unitcount": "2"} + assert cl.parameters == {"group": "Nuclear", "nominalcapacity": 2400, "unitcount": 2} assert cl.prepro == f"matrix://{prepro_id}" assert cl.modulation == f"matrix://{modulation_id}" @@ -85,8 +87,8 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): "market-bid-cost": "30", } - prepro = np.random.rand(365, 6).tolist() - modulation = np.random.rand(8760, 4).tolist() + prepro = GEN.random((365, 6)).tolist() + modulation = GEN.random((8760, 4)).tolist() command = CreateCluster( area_id=area_id, cluster_name=cluster_name, @@ -147,8 +149,8 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): ) def test_to_dto(self, command_context: CommandContext): - prepro = np.random.rand(365, 6).tolist() - modulation = np.random.rand(8760, 4).tolist() + prepro = GEN.random((365, 6)).tolist() + modulation = GEN.random((8760, 4)).tolist() command = CreateCluster( area_id="foo", cluster_name="Cluster1", @@ -160,12 +162,12 @@ def test_to_dto(self, command_context: CommandContext): prepro_id = command_context.matrix_service.create(prepro) modulation_id = command_context.matrix_service.create(modulation) dto = command.to_dto() - assert dto.dict() == { + assert dto.model_dump() == { "action": "create_cluster", "args": { "area_id": "foo", "cluster_name": "Cluster1", - "parameters": {"group": "Nuclear", "nominalcapacity": "2400", "unitcount": "2"}, + "parameters": {"group": "Nuclear", "nominalcapacity": 2400, "unitcount": 2}, "prepro": prepro_id, "modulation": modulation_id, }, @@ -175,8 +177,8 @@ def test_to_dto(self, command_context: CommandContext): def test_match(command_context: CommandContext): - prepro = np.random.rand(365, 6).tolist() - modulation = np.random.rand(8760, 4).tolist() + prepro = GEN.random((365, 6)).tolist() + modulation = GEN.random((8760, 4)).tolist() base = CreateCluster( area_id="foo", cluster_name="foo", @@ -235,8 +237,8 @@ def 
+GEN = np.random.default_rng(1000)
+
 
 class TestCreateCluster:
     def test_init(self, command_context: CommandContext):
-        prepro = np.random.rand(365, 6).tolist()
-        modulation = np.random.rand(8760, 4).tolist()
+        prepro = GEN.random((365, 6)).tolist()
+        modulation = GEN.random((8760, 4)).tolist()
         cl = CreateCluster(
             area_id="foo",
             cluster_name="Cluster1",
@@ -52,7 +54,7 @@ def test_init(self, command_context: CommandContext):
         modulation_id = command_context.matrix_service.create(modulation)
         assert cl.area_id == "foo"
         assert cl.cluster_name == "Cluster1"
-        assert cl.parameters == {"group": "Nuclear", "nominalcapacity": "2400", "unitcount": "2"}
+        assert cl.parameters == {"group": "Nuclear", "nominalcapacity": 2400, "unitcount": 2}
         assert cl.prepro == f"matrix://{prepro_id}"
         assert cl.modulation == f"matrix://{modulation_id}"
@@ -85,8 +87,8 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext):
             "market-bid-cost": "30",
         }
 
-        prepro = np.random.rand(365, 6).tolist()
-        modulation = np.random.rand(8760, 4).tolist()
+        prepro = GEN.random((365, 6)).tolist()
+        modulation = GEN.random((8760, 4)).tolist()
         command = CreateCluster(
             area_id=area_id,
             cluster_name=cluster_name,
@@ -147,8 +149,8 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext):
         )
 
     def test_to_dto(self, command_context: CommandContext):
-        prepro = np.random.rand(365, 6).tolist()
-        modulation = np.random.rand(8760, 4).tolist()
+        prepro = GEN.random((365, 6)).tolist()
+        modulation = GEN.random((8760, 4)).tolist()
         command = CreateCluster(
             area_id="foo",
             cluster_name="Cluster1",
@@ -160,12 +162,12 @@ def test_to_dto(self, command_context: CommandContext):
         prepro_id = command_context.matrix_service.create(prepro)
         modulation_id = command_context.matrix_service.create(modulation)
         dto = command.to_dto()
-        assert dto.dict() == {
+        assert dto.model_dump() == {
             "action": "create_cluster",
             "args": {
                 "area_id": "foo",
                 "cluster_name": "Cluster1",
-                "parameters": {"group": "Nuclear", "nominalcapacity": "2400", "unitcount": "2"},
+                "parameters": {"group": "Nuclear", "nominalcapacity": 2400, "unitcount": 2},
                 "prepro": prepro_id,
                 "modulation": modulation_id,
             },
@@ -175,8 +177,8 @@ def test_to_dto(self, command_context: CommandContext):
 
 
 def test_match(command_context: CommandContext):
-    prepro = np.random.rand(365, 6).tolist()
-    modulation = np.random.rand(8760, 4).tolist()
+    prepro = GEN.random((365, 6)).tolist()
+    modulation = GEN.random((8760, 4)).tolist()
     base = CreateCluster(
         area_id="foo",
         cluster_name="foo",
@@ -235,8 +237,8 @@ def test_revert(command_context: CommandContext):
 
 
 def test_create_diff(command_context: CommandContext):
-    prepro_a = np.random.rand(365, 6).tolist()
-    modulation_a = np.random.rand(8760, 4).tolist()
+    prepro_a = GEN.random((365, 6)).tolist()
+    modulation_a = GEN.random((8760, 4)).tolist()
     base = CreateCluster(
         area_id="foo",
         cluster_name="foo",
@@ -246,8 +248,8 @@ def test_create_diff(command_context: CommandContext):
         command_context=command_context,
     )
 
-    prepro_b = np.random.rand(365, 6).tolist()
-    modulation_b = np.random.rand(8760, 4).tolist()
+    prepro_b = GEN.random((365, 6)).tolist()
+    modulation_b = GEN.random((8760, 4)).tolist()
     other_match = CreateCluster(
         area_id="foo",
         cluster_name="foo",
diff --git a/tests/variantstudy/model/command/test_create_link.py b/tests/variantstudy/model/command/test_create_link.py
index 36c91d198b..9d847f0f24 100644
--- a/tests/variantstudy/model/command/test_create_link.py
+++ b/tests/variantstudy/model/command/test_create_link.py
@@ -37,16 +37,15 @@ def test_validation(self, empty_study: FileStudy, command_context: CommandContex
         area1_id = transform_name_to_id(area1)
         area2 = "Area2"
-        area2_id = transform_name_to_id(area2)
 
-        CreateArea.parse_obj(
+        CreateArea.model_validate(
             {
                 "area_name": area1,
                 "command_context": command_context,
             }
         ).apply(empty_study)
 
-        CreateArea.parse_obj(
+        CreateArea.model_validate(
             {
                 "area_name": area2,
                 "command_context": command_context,
@@ -54,7 +53,7 @@ def test_validation(self, empty_study: FileStudy, command_context: CommandContex
         ).apply(empty_study)
 
         with pytest.raises(ValidationError):
-            create_link_command: ICommand = CreateLink(
+            CreateLink(
                 area1=area1_id,
                 area2=area1_id,
                 parameters={},
@@ -73,21 +72,21 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext):
         area3 = "Area3"
         area3_id = transform_name_to_id(area3)
 
-        CreateArea.parse_obj(
+        CreateArea.model_validate(
             {
                 "area_name": area1,
                 "command_context": command_context,
             }
         ).apply(empty_study)
 
-        CreateArea.parse_obj(
+        CreateArea.model_validate(
             {
                 "area_name": area2,
                 "command_context": command_context,
             }
         ).apply(empty_study)
 
-        CreateArea.parse_obj(
+        CreateArea.model_validate(
             {
                 "area_name": area3,
                 "command_context": command_context,
@@ -145,7 +144,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext):
 
         # TODO:assert matrix default content : 1 column, 8760 rows, value = 1
 
-        output = CreateLink.parse_obj(
+        output = CreateLink.model_validate(
             {
                 "area1": area1_id,
                 "area2": area2_id,
@@ -173,7 +172,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext):
             "filter-year-by-year": "hourly",
         }
 
-        create_link_command: ICommand = CreateLink.parse_obj(
+        create_link_command: ICommand = CreateLink.model_validate(
             {
                 "area1": area3_id,
                 "area2": area1_id,
diff --git a/tests/variantstudy/model/command/test_create_renewables_cluster.py b/tests/variantstudy/model/command/test_create_renewables_cluster.py
index 91d3ffc3a9..b23d6ba553 100644
--- a/tests/variantstudy/model/command/test_create_renewables_cluster.py
+++ b/tests/variantstudy/model/command/test_create_renewables_cluster.py
@@ -46,7 +46,7 @@ def test_init(self, command_context: CommandContext) -> None:
         # Check the command data
         assert cl.area_id == "foo"
         assert cl.cluster_name == "Cluster1"
-        assert cl.parameters == {"group": "Solar Thermal", "nominalcapacity": "2400", "unitcount": "2"}
+        assert cl.parameters == {"group": "Solar Thermal", "nominalcapacity": 2400, "unitcount": 2}
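+        # The assertion above reflects pydantic v2 behaviour: numeric values
+        # now stay numbers, where v1 coerced them to the strings "2400"/"2".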
 
     def test_validate_cluster_name(self, command_context: CommandContext) -> None:
         with pytest.raises(ValidationError, match="cluster_name"):
@@ -131,12 +131,12 @@ def test_to_dto(self, command_context: CommandContext) -> None:
             command_context=command_context,
         )
         dto = command.to_dto()
-        assert dto.dict() == {
+        assert dto.model_dump() == {
             "action": "create_renewables_cluster",  # "renewables" with a final "s".
             "args": {
                 "area_id": "foo",
                 "cluster_name": "Cluster1",
-                "parameters": {"group": "Solar Thermal", "nominalcapacity": "2400", "unitcount": "2"},
+                "parameters": {"group": "Solar Thermal", "nominalcapacity": 2400, "unitcount": 2},
             },
             "id": None,
             "version": 1,
diff --git a/tests/variantstudy/model/command/test_create_st_storage.py b/tests/variantstudy/model/command/test_create_st_storage.py
index bc0b8e80d0..14fbef9b8b 100644
--- a/tests/variantstudy/model/command/test_create_st_storage.py
+++ b/tests/variantstudy/model/command/test_create_st_storage.py
@@ -29,6 +29,8 @@
 from antarest.study.storage.variantstudy.model.command_context import CommandContext
 from antarest.study.storage.variantstudy.model.model import CommandDTO
 
+GEN = np.random.default_rng(1000)
+
 
 @pytest.fixture(name="recent_study")
 def recent_study_fixture(empty_study: FileStudy) -> FileStudy:
@@ -75,8 +77,8 @@ def recent_study_fixture(empty_study: FileStudy) -> FileStudy:
 
 class TestCreateSTStorage:
     # noinspection SpellCheckingInspection
     def test_init(self, command_context: CommandContext):
-        pmax_injection = np.random.rand(8760, 1)
-        inflows = np.random.uniform(0, 1000, size=(8760, 1))
+        pmax_injection = GEN.random((8760, 1))
+        inflows = GEN.uniform(0, 1000, size=(8760, 1))
         cmd = CreateSTStorage(
             command_context=command_context,
             area_id="area_fr",
@@ -112,17 +114,22 @@ def test_init__invalid_storage_name(self, recent_study: FileStudy, command_conte
             parameters=STStorageConfig(**parameters),
         )
-        # We get 2 errors because the `storage_name` is duplicated in the `parameters`:
-        assert ctx.value.errors() == [
-            {
-                "loc": ("__root__",),
-                "msg": "Invalid name '?%$$'.",
-                "type": "value_error",
-            }
-        ]
+        # Pydantic v2 reports a single, self-describing error for the invalid name:
+        assert ctx.value.error_count() == 1
+        raised_error = ctx.value.errors()[0]
+        assert raised_error["type"] == "value_error"
+        assert raised_error["msg"] == "Value error, Invalid name '?%$$'."
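+        # v2 error entries also expose the offending `input`, so we can check
+        # that the validator received the full parameter set: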
+ assert raised_error["input"] == { + "efficiency": 0.94, + "group": "Battery", + "initialleveloptim": True, + "injectionnominalcapacity": 1500, + "name": "?%$$", + "reservoircapacity": 20000, + "withdrawalnominalcapacity": 1500, + } - # noinspection SpellCheckingInspection def test_init__invalid_matrix_values(self, command_context: CommandContext): - array = np.random.rand(8760, 1) # OK + array = GEN.random((8760, 1)) array[10] = 25 # BAD with pytest.raises(ValidationError) as ctx: CreateSTStorage( @@ -131,17 +138,15 @@ def test_init__invalid_matrix_values(self, command_context: CommandContext): parameters=STStorageConfig(**PARAMETERS), pmax_injection=array.tolist(), # type: ignore ) - assert ctx.value.errors() == [ - { - "loc": ("pmax_injection",), - "msg": "Matrix values should be between 0 and 1", - "type": "value_error", - } - ] + assert ctx.value.error_count() == 1 + raised_error = ctx.value.errors()[0] + assert raised_error["type"] == "value_error" + assert raised_error["msg"] == "Value error, Matrix values should be between 0 and 1" + assert "pmax_injection" in raised_error["input"] # noinspection SpellCheckingInspection def test_init__invalid_matrix_shape(self, command_context: CommandContext): - array = np.random.rand(24, 1) # BAD SHAPE + array = GEN.random((24, 1)) # BAD SHAPE with pytest.raises(ValidationError) as ctx: CreateSTStorage( command_context=command_context, @@ -149,18 +154,14 @@ def test_init__invalid_matrix_shape(self, command_context: CommandContext): parameters=STStorageConfig(**PARAMETERS), pmax_injection=array.tolist(), # type: ignore ) - assert ctx.value.errors() == [ - { - "loc": ("pmax_injection",), - "msg": "Invalid matrix shape (24, 1), expected (8760, 1)", - "type": "value_error", - } - ] - - # noinspection SpellCheckingInspection + assert ctx.value.error_count() == 1 + raised_error = ctx.value.errors()[0] + assert raised_error["type"] == "value_error" + assert raised_error["msg"] == "Value error, Invalid matrix shape (24, 1), expected (8760, 1)" + assert "pmax_injection" in raised_error["input"] def test_init__invalid_nan_value(self, command_context: CommandContext): - array = np.random.rand(8760, 1) # OK + array = GEN.random((8760, 1)) # OK array[20] = np.nan # BAD with pytest.raises(ValidationError) as ctx: CreateSTStorage( @@ -169,37 +170,25 @@ def test_init__invalid_nan_value(self, command_context: CommandContext): parameters=STStorageConfig(**PARAMETERS), pmax_injection=array.tolist(), # type: ignore ) - assert ctx.value.errors() == [ - { - "loc": ("pmax_injection",), - "msg": "Matrix values cannot contain NaN", - "type": "value_error", - } - ] - - # noinspection SpellCheckingInspection + assert ctx.value.error_count() == 1 + raised_error = ctx.value.errors()[0] + assert raised_error["type"] == "value_error" + assert raised_error["msg"] == "Value error, Matrix values cannot contain NaN" + assert "pmax_injection" in raised_error["input"] def test_init__invalid_matrix_type(self, command_context: CommandContext): - array = {"data": [1, 2, 3]} with pytest.raises(ValidationError) as ctx: CreateSTStorage( command_context=command_context, area_id="area_fr", parameters=STStorageConfig(**PARAMETERS), - pmax_injection=array, # type: ignore + pmax_injection=[1, 2, 3], ) - assert ctx.value.errors() == [ - { - "loc": ("pmax_injection",), - "msg": "value is not a valid list", - "type": "type_error.list", - }, - { - "loc": ("pmax_injection",), - "msg": "str type expected", - "type": "type_error.str", - }, - ] + assert ctx.value.error_count() == 1 + raised_error = 
+        assert ctx.value.error_count() == 1
+        raised_error = ctx.value.errors()[0]
+        assert raised_error["type"] == "value_error"
+        assert raised_error["msg"] == "Value error, Invalid matrix shape (3,), expected (8760, 1)"
+        assert "pmax_injection" in raised_error["input"]
 
     def test_apply_config__invalid_version(self, empty_study: FileStudy, command_context: CommandContext):
         # Given an old study in version 720
@@ -305,8 +294,8 @@ def test_apply__nominal_case(self, recent_study: FileStudy, command_context: Com
         create_area.apply(recent_study)
 
         # Then, apply the command to create a new ST Storage
-        pmax_injection = np.random.rand(8760, 1)
-        inflows = np.random.uniform(0, 1000, size=(8760, 1))
+        pmax_injection = GEN.random((8760, 1))
+        inflows = GEN.uniform(0, 1000, size=(8760, 1))
         cmd = CreateSTStorage(
             command_context=command_context,
             area_id=transform_name_to_id(create_area.area_name),
@@ -439,8 +428,8 @@ def test_create_diff__not_equals(self, command_context: CommandContext):
             area_id="area_fr",
             parameters=STStorageConfig(**PARAMETERS),
         )
-        upper_rule_curve = np.random.rand(8760, 1)
-        inflows = np.random.uniform(0, 1000, size=(8760, 1))
+        upper_rule_curve = GEN.random((8760, 1))
+        inflows = GEN.uniform(0, 1000, size=(8760, 1))
         other = CreateSTStorage(
             command_context=command_context,
             area_id=cmd.area_id,
diff --git a/tests/variantstudy/model/command/test_manage_district.py b/tests/variantstudy/model/command/test_manage_district.py
index cd944dfdba..47b6e55c42 100644
--- a/tests/variantstudy/model/command/test_manage_district.py
+++ b/tests/variantstudy/model/command/test_manage_district.py
@@ -26,7 +26,6 @@
 
 
 def test_manage_district(empty_study: FileStudy, command_context: CommandContext):
-    study_path = empty_study.config.study_path
     area1 = "Area1"
     area1_id = transform_name_to_id(area1)
@@ -34,23 +33,22 @@ def test_manage_district(empty_study: FileStudy, command_context: CommandContext
     area2_id = transform_name_to_id(area2)
 
     area3 = "Area3"
-    area3_id = transform_name_to_id(area3)
 
-    CreateArea.parse_obj(
+    CreateArea.model_validate(
        {
             "area_name": area1,
             "command_context": command_context,
         }
     ).apply(empty_study)
 
-    CreateArea.parse_obj(
+    CreateArea.model_validate(
         {
             "area_name": area2,
             "command_context": command_context,
         }
     ).apply(empty_study)
 
-    CreateArea.parse_obj(
+    CreateArea.model_validate(
         {
             "area_name": area3,
             "command_context": command_context,
diff --git a/tests/variantstudy/model/command/test_remove_area.py b/tests/variantstudy/model/command/test_remove_area.py
index 94d3a8f3e5..474ce51a42 100644
--- a/tests/variantstudy/model/command/test_remove_area.py
+++ b/tests/variantstudy/model/command/test_remove_area.py
@@ -212,7 +212,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext):
             if empty_study.config.version >= 810:
                 default_ruleset[f"r,{area_id2},0,{renewable_id.lower()}"] = 1
             if empty_study.config.version >= 870:
-                default_ruleset[f"bc,bd 2,0"] = 1
+                default_ruleset["bc,bd 2,0"] = 1
             if empty_study.config.version >= 920:
                 default_ruleset[f"hfl,{area_id2},0"] = 1
             if empty_study.config.version >= 910:
diff --git a/tests/variantstudy/model/command/test_remove_link.py b/tests/variantstudy/model/command/test_remove_link.py
index 1a3f5f0428..e7e46bb6af 100644
--- a/tests/variantstudy/model/command/test_remove_link.py
+++ b/tests/variantstudy/model/command/test_remove_link.py
@@ -70,7 +70,7 @@ def test_remove_link__validation(self, area1: str, area2: str, expected: t.Dict[
         and that the areas are well-ordered in alphabetical order (Antares Solver convention).
         """
         command = RemoveLink(area1=area1, area2=area2, command_context=Mock(spec=CommandContext))
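+        # `.model_dump()` is the pydantic v2 spelling of `.dict()`; `include`
+        # restricts the dump to the two ordered area fields: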
""" command = RemoveLink(area1=area1, area2=area2, command_context=Mock(spec=CommandContext)) - actual = command.dict(include={"area1", "area2"}) + actual = command.model_dump(include={"area1", "area2"}) assert actual == expected @staticmethod diff --git a/tests/variantstudy/model/command/test_remove_st_storage.py b/tests/variantstudy/model/command/test_remove_st_storage.py index d8d2048bf5..032a303650 100644 --- a/tests/variantstudy/model/command/test_remove_st_storage.py +++ b/tests/variantstudy/model/command/test_remove_st_storage.py @@ -83,9 +83,11 @@ def test_init__invalid_storage_id(self, recent_study: FileStudy, command_context assert ctx.value.errors() == [ { "ctx": {"pattern": "[a-z0-9_(),& -]+"}, + "input": "?%$$", "loc": ("storage_id",), - "msg": 'string does not match regex "[a-z0-9_(),& -]+"', - "type": "value_error.str.regex", + "msg": "String should match pattern '[a-z0-9_(),& -]+'", + "type": "string_pattern_mismatch", + "url": "https://errors.pydantic.dev/2.8/v/string_pattern_mismatch", } ] diff --git a/tests/variantstudy/model/command/test_replace_matrix.py b/tests/variantstudy/model/command/test_replace_matrix.py index f30e05ff7a..8a215692f5 100644 --- a/tests/variantstudy/model/command/test_replace_matrix.py +++ b/tests/variantstudy/model/command/test_replace_matrix.py @@ -32,7 +32,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): area1 = "Area1" area1_id = transform_name_to_id(area1) - CreateArea.parse_obj( + CreateArea.model_validate( { "area_name": area1, "command_context": command_context, @@ -40,7 +40,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): ).apply(empty_study) target_element = f"input/hydro/common/capacity/maxpower_{area1_id}" - replace_matrix = ReplaceMatrix.parse_obj( + replace_matrix = ReplaceMatrix.model_validate( { "target": target_element, "matrix": [[0]], @@ -56,7 +56,7 @@ def test_apply(self, empty_study: FileStudy, command_context: CommandContext): assert matrix_id in target_path.read_text() target_element = "fake/matrix/path" - replace_matrix = ReplaceMatrix.parse_obj( + replace_matrix = ReplaceMatrix.model_validate( { "target": target_element, "matrix": [[0]], diff --git a/tests/variantstudy/model/command/test_update_config.py b/tests/variantstudy/model/command/test_update_config.py index 28800ad188..bbce0a45c1 100644 --- a/tests/variantstudy/model/command/test_update_config.py +++ b/tests/variantstudy/model/command/test_update_config.py @@ -32,7 +32,7 @@ def test_update_config(empty_study: FileStudy, command_context: CommandContext): area1 = "Area1" area1_id = transform_name_to_id(area1) - CreateArea.parse_obj( + CreateArea.model_validate( { "area_name": area1, "command_context": command_context, diff --git a/tests/variantstudy/model/test_variant_model.py b/tests/variantstudy/model/test_variant_model.py index 3b714c046a..b51dc8509d 100644 --- a/tests/variantstudy/model/test_variant_model.py +++ b/tests/variantstudy/model/test_variant_model.py @@ -153,7 +153,7 @@ def test_commands_service( repository=variant_study_service.repository, ) results = generator.generate_snapshot(saved_id, jwt_user, denormalize=False) - assert results.dict() == { + assert results.model_dump() == { "success": True, "details": [ {