From f332df168fabe5a6af0fcfaa4cad72051c20a75f Mon Sep 17 00:00:00 2001 From: Sylvain Leclerc Date: Thu, 5 Sep 2024 13:49:08 +0200 Subject: [PATCH] fix(mypy,doc): restore more annotations, fix deprecated on_event Signed-off-by: Sylvain Leclerc --- antarest/launcher/web.py | 26 ++++++++++-- antarest/main.py | 25 ++++++------ antarest/study/business/all_optional_meta.py | 9 +++++ .../study/business/areas/hydro_management.py | 40 ++++--------------- .../study/business/xpansion_management.py | 1 - .../study/storage/variantstudy/model/model.py | 11 ++--- antarest/study/web/study_data_blueprint.py | 3 +- antarest/study/web/variant_blueprint.py | 1 - 8 files changed, 56 insertions(+), 60 deletions(-) diff --git a/antarest/launcher/web.py b/antarest/launcher/web.py index 1510936921..c07261384a 100644 --- a/antarest/launcher/web.py +++ b/antarest/launcher/web.py @@ -15,7 +15,7 @@ from typing import Any, Dict, List, Optional from uuid import UUID -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from fastapi.exceptions import HTTPException from antarest.core.config import Config, InvalidConfigurationError, Launcher @@ -53,7 +53,25 @@ def __init__(self, solver: str) -> None: ) -# TODO SL: restore query example ? for launcher +LauncherQuery = Query( + default=Launcher.DEFAULT, + openapi_examples={ + "Default launcher": { + "description": "Default solver (auto-detected)", + "value": "default", + }, + "SLURM launcher": { + "description": "SLURM solver configuration", + "value": "slurm", + }, + "Local launcher": { + "description": "Local solver configuration", + "value": "local", + }, + }, +) + + def create_launcher_api(service: LauncherService, config: Config) -> APIRouter: bp = APIRouter(prefix="/v1/launcher") @@ -268,8 +286,8 @@ def get_nb_cores(launcher: Launcher = Launcher.DEFAULT) -> Dict[str, int]: "/time-limit", tags=[APITag.launcher], summary="Retrieve the time limit for a job (in hours)", - ) # TODO SL: check this annotation - def get_time_limit(launcher: Launcher = Launcher.DEFAULT) -> Dict[str, int]: + ) + def get_time_limit(launcher: Launcher = LauncherQuery) -> Dict[str, int]: """ Retrieve the time limit for a job (in hours) of the given launcher: "local" or "slurm". diff --git a/antarest/main.py b/antarest/main.py index 2fe6f3c5e9..93bbc50bf0 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -14,8 +14,9 @@ import copy import logging import re +from contextlib import asynccontextmanager from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Tuple, cast +from typing import Any, AsyncGenerator, Dict, Optional, Sequence, Tuple, cast import pydantic import uvicorn @@ -250,12 +251,22 @@ def fastapi_app( logger.info("Initiating application") + @asynccontextmanager + async def set_default_executor(app: FastAPI) -> AsyncGenerator[None, None]: + import asyncio + from concurrent.futures import ThreadPoolExecutor + + loop = asyncio.get_running_loop() + loop.set_default_executor(ThreadPoolExecutor(max_workers=config.server.worker_threadpool_size)) + yield + application = FastAPI( title="AntaREST", version=__version__, docs_url=None, root_path=config.root_path, openapi_tags=tags_metadata, + lifespan=set_default_executor, ) # Database @@ -286,18 +297,6 @@ def home(request: Request) -> Any: def home(request: Request) -> Any: return "" - # TODO SL: on_event is deprecated, use lifespan event handlers instead. - # - # Read more about it in the - # [FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/). - @application.on_event("startup") - def set_default_executor() -> None: - import asyncio - from concurrent.futures import ThreadPoolExecutor - - loop = asyncio.get_running_loop() - loop.set_default_executor(ThreadPoolExecutor(max_workers=config.server.worker_threadpool_size)) - # TODO move that elsewhere @AuthJWT.load_config # type: ignore def get_config() -> JwtSettings: diff --git a/antarest/study/business/all_optional_meta.py b/antarest/study/business/all_optional_meta.py index 9f6d1b4f86..c580e8fa00 100644 --- a/antarest/study/business/all_optional_meta.py +++ b/antarest/study/business/all_optional_meta.py @@ -19,6 +19,15 @@ def all_optional_model(model: t.Type[BaseModel]) -> t.Type[BaseModel]: + """ + This decorator can be used to make all fields of a pydantic model optionals. + + Args: + model: The pydantic model to modify. + + Returns: + The modified model. + """ kwargs = {} for field_name, field_info in model.model_fields.items(): new = copy.deepcopy(field_info) diff --git a/antarest/study/business/areas/hydro_management.py b/antarest/study/business/areas/hydro_management.py index c42438bfda..d464336d93 100644 --- a/antarest/study/business/areas/hydro_management.py +++ b/antarest/study/business/areas/hydro_management.py @@ -37,47 +37,23 @@ class InflowStructure(FormFieldsBaseModel): ) -# TODO SL: why field validation is moved to method ? @all_optional_model class ManagementOptionsFormFields(FormFieldsBaseModel): - inter_daily_breakdown: float - intra_daily_modulation: float - inter_monthly_breakdown: float + inter_daily_breakdown: float = Field(ge=0) + intra_daily_modulation: float = Field(ge=1) + inter_monthly_breakdown: float = Field(ge=0) reservoir: bool - reservoir_capacity: float + reservoir_capacity: float = Field(ge=0) follow_load: bool use_water: bool hard_bounds: bool - initialize_reservoir_date: int + initialize_reservoir_date: int = Field(ge=0, le=11) use_heuristic: bool power_to_level: bool use_leeway: bool - leeway_low: float - leeway_up: float - pumping_efficiency: float - - @model_validator(mode="before") - def check_type_validity(cls, values: Dict[str, Any]) -> Dict[str, Optional[Any]]: - cls.validate_ge("inter_daily_breakdown", values.get("inter_daily_breakdown", 0), 0) - cls.validate_ge("intra_daily_modulation", values.get("intra_daily_modulation", 1), 1) - cls.validate_ge("inter_monthly_breakdown", values.get("inter_monthly_breakdown", 0), 0) - cls.validate_ge("reservoir_capacity", values.get("reservoir_capacity", 0), 0) - cls.validate_ge("initialize_reservoir_date", values.get("initialize_reservoir_date", 0), 0) - cls.validate_le("initialize_reservoir_date", values.get("initialize_reservoir_date", 11), 11) - cls.validate_ge("leeway_low", values.get("leeway_low", 0), 0) - cls.validate_ge("leeway_up", values.get("leeway_up", 0), 0) - cls.validate_ge("pumping_efficiency", values.get("pumping_efficiency", 0), 0) - return values - - @staticmethod - def validate_ge(field: str, value: Union[int, float], ge: int) -> None: - if value < ge: - raise ValueError(f"Field {field} must be greater than or equal to {ge}") - - @staticmethod - def validate_le(field: str, value: Union[int, float], le: int) -> None: - if value > le: - raise ValueError(f"Field {field} must be lower than or equal to {le}") + leeway_low: float = Field(ge=0) + leeway_up: float = Field(ge=0) + pumping_efficiency: float = Field(ge=0) HYDRO_PATH = "input/hydro/hydro" diff --git a/antarest/study/business/xpansion_management.py b/antarest/study/business/xpansion_management.py index 06dc0446a0..6c02724afd 100644 --- a/antarest/study/business/xpansion_management.py +++ b/antarest/study/business/xpansion_management.py @@ -233,7 +233,6 @@ class UpdateXpansionSettings(XpansionSettings): ) -# TODO SL: None in non optional field ? class XpansionCandidateDTO(BaseModel): # The id of the candidate is irrelevant, so it should stay hidden for the user # The names should be the section titles of the file, and the id should be removed diff --git a/antarest/study/storage/variantstudy/model/model.py b/antarest/study/storage/variantstudy/model/model.py index 0be3c75353..171a1f4a78 100644 --- a/antarest/study/storage/variantstudy/model/model.py +++ b/antarest/study/storage/variantstudy/model/model.py @@ -14,7 +14,7 @@ import uuid import typing_extensions as te -from pydantic import BaseModel +from pydantic import BaseModel, Field from antarest.core.model import JSON from antarest.study.model import StudyMetadataDTO @@ -92,7 +92,7 @@ class CommandResultDTO(BaseModel): message: str -class VariantTreeDTO: +class VariantTreeDTO(BaseModel): """ This class represents a variant tree structure. @@ -101,8 +101,5 @@ class VariantTreeDTO: children: A list of variant children. """ - def __init__(self, node: StudyMetadataDTO, children: t.MutableSequence["VariantTreeDTO"]) -> None: - # We are intentionally not using Pydantic’s `BaseModel` here to prevent potential - # `RecursionError` exceptions that can occur with Pydantic versions before v2. - self.node = node - self.children = children or [] + node: StudyMetadataDTO + children: t.List["VariantTreeDTO"] = Field(default_factory=list) diff --git a/antarest/study/web/study_data_blueprint.py b/antarest/study/web/study_data_blueprint.py index 355901830e..9b1cca0fc5 100644 --- a/antarest/study/web/study_data_blueprint.py +++ b/antarest/study/web/study_data_blueprint.py @@ -508,7 +508,6 @@ def get_inflow_structure( "/studies/{uuid}/areas/{area_id}/hydro/inflow-structure", tags=[APITag.study_data], summary="Update inflow structure values", - response_model=None, # TODO SL: was InflowStructure ) def update_inflow_structure( uuid: str, @@ -523,7 +522,7 @@ def update_inflow_structure( ) params = RequestParameters(user=current_user) study = study_service.check_study_access(uuid, StudyPermissionType.WRITE, params) - return study_service.hydro_manager.update_inflow_structure(study, area_id, values) + study_service.hydro_manager.update_inflow_structure(study, area_id, values) @bp.put( "/studies/{uuid}/matrix", diff --git a/antarest/study/web/variant_blueprint.py b/antarest/study/web/variant_blueprint.py index cb6eaf840f..2f83223557 100644 --- a/antarest/study/web/variant_blueprint.py +++ b/antarest/study/web/variant_blueprint.py @@ -107,7 +107,6 @@ def create_variant( "/studies/{uuid}/variants", tags=[APITag.study_variant_management], summary="Get children variants", - response_model=None, # TODO SL: check this, there was a responses also ) def get_variants( uuid: str,