Skip to content

Commit

Permalink
fix(mypy,doc): restore more annotations, fix deprecated on_event
Browse files Browse the repository at this point in the history
Signed-off-by: Sylvain Leclerc <[email protected]>
  • Loading branch information
sylvlecl committed Sep 5, 2024
1 parent b4ed725 commit f332df1
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 60 deletions.
26 changes: 22 additions & 4 deletions antarest/launcher/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Any, Dict, List, Optional
from uuid import UUID

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Query
from fastapi.exceptions import HTTPException

from antarest.core.config import Config, InvalidConfigurationError, Launcher
Expand Down Expand Up @@ -53,7 +53,25 @@ def __init__(self, solver: str) -> None:
)


# TODO SL: restore query example ? for launcher
LauncherQuery = Query(
default=Launcher.DEFAULT,
openapi_examples={
"Default launcher": {
"description": "Default solver (auto-detected)",
"value": "default",
},
"SLURM launcher": {
"description": "SLURM solver configuration",
"value": "slurm",
},
"Local launcher": {
"description": "Local solver configuration",
"value": "local",
},
},
)


def create_launcher_api(service: LauncherService, config: Config) -> APIRouter:
bp = APIRouter(prefix="/v1/launcher")

Expand Down Expand Up @@ -268,8 +286,8 @@ def get_nb_cores(launcher: Launcher = Launcher.DEFAULT) -> Dict[str, int]:
"/time-limit",
tags=[APITag.launcher],
summary="Retrieve the time limit for a job (in hours)",
) # TODO SL: check this annotation
def get_time_limit(launcher: Launcher = Launcher.DEFAULT) -> Dict[str, int]:
)
def get_time_limit(launcher: Launcher = LauncherQuery) -> Dict[str, int]:
"""
Retrieve the time limit for a job (in hours) of the given launcher: "local" or "slurm".
Expand Down
25 changes: 12 additions & 13 deletions antarest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
import copy
import logging
import re
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Tuple, cast
from typing import Any, AsyncGenerator, Dict, Optional, Sequence, Tuple, cast

import pydantic
import uvicorn
Expand Down Expand Up @@ -250,12 +251,22 @@ def fastapi_app(

logger.info("Initiating application")

@asynccontextmanager
async def set_default_executor(app: FastAPI) -> AsyncGenerator[None, None]:
import asyncio
from concurrent.futures import ThreadPoolExecutor

loop = asyncio.get_running_loop()
loop.set_default_executor(ThreadPoolExecutor(max_workers=config.server.worker_threadpool_size))
yield

application = FastAPI(
title="AntaREST",
version=__version__,
docs_url=None,
root_path=config.root_path,
openapi_tags=tags_metadata,
lifespan=set_default_executor,
)

# Database
Expand Down Expand Up @@ -286,18 +297,6 @@ def home(request: Request) -> Any:
def home(request: Request) -> Any:
return ""

# TODO SL: on_event is deprecated, use lifespan event handlers instead.
#
# Read more about it in the
# [FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/).
@application.on_event("startup")
def set_default_executor() -> None:
import asyncio
from concurrent.futures import ThreadPoolExecutor

loop = asyncio.get_running_loop()
loop.set_default_executor(ThreadPoolExecutor(max_workers=config.server.worker_threadpool_size))

# TODO move that elsewhere
@AuthJWT.load_config # type: ignore
def get_config() -> JwtSettings:
Expand Down
9 changes: 9 additions & 0 deletions antarest/study/business/all_optional_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@


def all_optional_model(model: t.Type[BaseModel]) -> t.Type[BaseModel]:
"""
This decorator can be used to make all fields of a pydantic model optionals.
Args:
model: The pydantic model to modify.
Returns:
The modified model.
"""
kwargs = {}
for field_name, field_info in model.model_fields.items():
new = copy.deepcopy(field_info)
Expand Down
40 changes: 8 additions & 32 deletions antarest/study/business/areas/hydro_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,47 +37,23 @@ class InflowStructure(FormFieldsBaseModel):
)


# TODO SL: why field validation is moved to method ?
@all_optional_model
class ManagementOptionsFormFields(FormFieldsBaseModel):
inter_daily_breakdown: float
intra_daily_modulation: float
inter_monthly_breakdown: float
inter_daily_breakdown: float = Field(ge=0)
intra_daily_modulation: float = Field(ge=1)
inter_monthly_breakdown: float = Field(ge=0)
reservoir: bool
reservoir_capacity: float
reservoir_capacity: float = Field(ge=0)
follow_load: bool
use_water: bool
hard_bounds: bool
initialize_reservoir_date: int
initialize_reservoir_date: int = Field(ge=0, le=11)
use_heuristic: bool
power_to_level: bool
use_leeway: bool
leeway_low: float
leeway_up: float
pumping_efficiency: float

@model_validator(mode="before")
def check_type_validity(cls, values: Dict[str, Any]) -> Dict[str, Optional[Any]]:
cls.validate_ge("inter_daily_breakdown", values.get("inter_daily_breakdown", 0), 0)
cls.validate_ge("intra_daily_modulation", values.get("intra_daily_modulation", 1), 1)
cls.validate_ge("inter_monthly_breakdown", values.get("inter_monthly_breakdown", 0), 0)
cls.validate_ge("reservoir_capacity", values.get("reservoir_capacity", 0), 0)
cls.validate_ge("initialize_reservoir_date", values.get("initialize_reservoir_date", 0), 0)
cls.validate_le("initialize_reservoir_date", values.get("initialize_reservoir_date", 11), 11)
cls.validate_ge("leeway_low", values.get("leeway_low", 0), 0)
cls.validate_ge("leeway_up", values.get("leeway_up", 0), 0)
cls.validate_ge("pumping_efficiency", values.get("pumping_efficiency", 0), 0)
return values

@staticmethod
def validate_ge(field: str, value: Union[int, float], ge: int) -> None:
if value < ge:
raise ValueError(f"Field {field} must be greater than or equal to {ge}")

@staticmethod
def validate_le(field: str, value: Union[int, float], le: int) -> None:
if value > le:
raise ValueError(f"Field {field} must be lower than or equal to {le}")
leeway_low: float = Field(ge=0)
leeway_up: float = Field(ge=0)
pumping_efficiency: float = Field(ge=0)


HYDRO_PATH = "input/hydro/hydro"
Expand Down
1 change: 0 additions & 1 deletion antarest/study/business/xpansion_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ class UpdateXpansionSettings(XpansionSettings):
)


# TODO SL: None in non optional field ?
class XpansionCandidateDTO(BaseModel):
# The id of the candidate is irrelevant, so it should stay hidden for the user
# The names should be the section titles of the file, and the id should be removed
Expand Down
11 changes: 4 additions & 7 deletions antarest/study/storage/variantstudy/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import uuid

import typing_extensions as te
from pydantic import BaseModel
from pydantic import BaseModel, Field

from antarest.core.model import JSON
from antarest.study.model import StudyMetadataDTO
Expand Down Expand Up @@ -92,7 +92,7 @@ class CommandResultDTO(BaseModel):
message: str


class VariantTreeDTO:
class VariantTreeDTO(BaseModel):
"""
This class represents a variant tree structure.
Expand All @@ -101,8 +101,5 @@ class VariantTreeDTO:
children: A list of variant children.
"""

def __init__(self, node: StudyMetadataDTO, children: t.MutableSequence["VariantTreeDTO"]) -> None:
# We are intentionally not using Pydantic’s `BaseModel` here to prevent potential
# `RecursionError` exceptions that can occur with Pydantic versions before v2.
self.node = node
self.children = children or []
node: StudyMetadataDTO
children: t.List["VariantTreeDTO"] = Field(default_factory=list)
3 changes: 1 addition & 2 deletions antarest/study/web/study_data_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,6 @@ def get_inflow_structure(
"/studies/{uuid}/areas/{area_id}/hydro/inflow-structure",
tags=[APITag.study_data],
summary="Update inflow structure values",
response_model=None, # TODO SL: was InflowStructure
)
def update_inflow_structure(
uuid: str,
Expand All @@ -523,7 +522,7 @@ def update_inflow_structure(
)
params = RequestParameters(user=current_user)
study = study_service.check_study_access(uuid, StudyPermissionType.WRITE, params)
return study_service.hydro_manager.update_inflow_structure(study, area_id, values)
study_service.hydro_manager.update_inflow_structure(study, area_id, values)

@bp.put(
"/studies/{uuid}/matrix",
Expand Down
1 change: 0 additions & 1 deletion antarest/study/web/variant_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def create_variant(
"/studies/{uuid}/variants",
tags=[APITag.study_variant_management],
summary="Get children variants",
response_model=None, # TODO SL: check this, there was a responses also
)
def get_variants(
uuid: str,
Expand Down

0 comments on commit f332df1

Please sign in to comment.