diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 5a7b9ad212..91523a511d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -53,7 +53,7 @@ jobs: pip install -r requirements-dev.txt - name: Test with pytest run: | - pytest --cov antarest --cov-report xml + pytest --cov antarest --cov-report xml -n auto - name: Archive code coverage results if: matrix.os == 'ubuntu-20.04' uses: actions/upload-artifact@v4 diff --git a/antarest/core/application.py b/antarest/core/application.py new file mode 100644 index 0000000000..3f09bd4102 --- /dev/null +++ b/antarest/core/application.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from dataclasses import dataclass +from typing import Optional + +from fastapi import APIRouter, FastAPI + + +@dataclass(frozen=True) +class AppBuildContext: + """ + Base elements of the application, for use at construction time: + - app: the actual fastapi application, where middlewares, exception handlers, etc. may be added + - api_root: the route under which all API and WS endpoints must be registered + + API routes should not be added straight to app, but under api_root instead, + so that they are correctly prefixed if needed (/api for standalone mode). + + Warning: the inclusion of api_root must happen AFTER all subroutes + have been registered, hence the build method. + """ + + app: FastAPI + api_root: APIRouter + + def build(self) -> FastAPI: + """ + Finalizes the app construction by including the API route. + Must be performed AFTER all subroutes have been added. + """ + self.app.include_router(self.api_root) + return self.app + + +def create_app_ctxt(app: FastAPI, api_root: Optional[APIRouter] = None) -> AppBuildContext: + if not api_root: + api_root = APIRouter() + return AppBuildContext(app, api_root) diff --git a/antarest/core/config.py b/antarest/core/config.py index 3378444d92..4b85331d6a 100644 --- a/antarest/core/config.py +++ b/antarest/core/config.py @@ -593,6 +593,7 @@ class Config: cache: CacheConfig = CacheConfig() tasks: TaskConfig = TaskConfig() root_path: str = "" + api_prefix: str = "" @classmethod def from_dict(cls, data: JSON) -> "Config": @@ -611,6 +612,7 @@ def from_dict(cls, data: JSON) -> "Config": cache=CacheConfig.from_dict(data["cache"]) if "cache" in data else defaults.cache, tasks=TaskConfig.from_dict(data["tasks"]) if "tasks" in data else defaults.tasks, root_path=data.get("root_path", defaults.root_path), + api_prefix=data.get("api_prefix", defaults.api_prefix), ) @classmethod diff --git a/antarest/core/filetransfer/main.py b/antarest/core/filetransfer/main.py index cbf732ef22..3583dc5701 100644 --- a/antarest/core/filetransfer/main.py +++ b/antarest/core/filetransfer/main.py @@ -12,8 +12,9 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.filetransfer.repository import FileDownloadRepository from antarest.core.filetransfer.service import FileTransferManager @@ -22,10 +23,10 @@ def build_filetransfer_service( - application: Optional[FastAPI], event_bus: IEventBus, config: Config + app_ctxt: Optional[AppBuildContext], event_bus: IEventBus, config: Config ) -> FileTransferManager: ftm = FileTransferManager(repository=FileDownloadRepository(), event_bus=event_bus, config=config) - if application: - application.include_router(create_file_transfer_api(ftm, config)) + if app_ctxt: + app_ctxt.api_root.include_router(create_file_transfer_api(ftm, config)) return ftm diff --git a/antarest/core/maintenance/main.py b/antarest/core/maintenance/main.py index b129d85691..8717150d07 100644 --- a/antarest/core/maintenance/main.py +++ b/antarest/core/maintenance/main.py @@ -12,8 +12,9 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.interfaces.cache import ICache from antarest.core.interfaces.eventbus import DummyEventBusService, IEventBus @@ -23,7 +24,7 @@ def build_maintenance_manager( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, cache: ICache, event_bus: IEventBus = DummyEventBusService(), @@ -31,7 +32,7 @@ def build_maintenance_manager( repository = MaintenanceRepository() service = MaintenanceService(config, repository, event_bus, cache) - if application: - application.include_router(create_maintenance_api(service, config)) + if app_ctxt: + app_ctxt.api_root.include_router(create_maintenance_api(service, config)) return service diff --git a/antarest/core/tasks/main.py b/antarest/core/tasks/main.py index ae3d7dafb8..74685ba836 100644 --- a/antarest/core/tasks/main.py +++ b/antarest/core/tasks/main.py @@ -12,8 +12,9 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.interfaces.eventbus import DummyEventBusService, IEventBus from antarest.core.tasks.repository import TaskJobRepository @@ -22,14 +23,14 @@ def build_taskjob_manager( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, event_bus: IEventBus = DummyEventBusService(), ) -> ITaskService: repository = TaskJobRepository() service = TaskJobService(config, repository, event_bus) - if application: - application.include_router(create_tasks_api(service, config)) + if app_ctxt: + app_ctxt.api_root.include_router(create_tasks_api(service, config)) return service diff --git a/antarest/eventbus/main.py b/antarest/eventbus/main.py index 7c53f5cbce..6ccf56e644 100644 --- a/antarest/eventbus/main.py +++ b/antarest/eventbus/main.py @@ -12,9 +12,10 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI from redis import Redis +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.eventbus.business.local_eventbus import LocalEventBus from antarest.eventbus.business.redis_eventbus import RedisEventBus @@ -23,7 +24,7 @@ def build_eventbus( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, autostart: bool = True, redis_client: Optional[Redis] = None, # type: ignore @@ -33,6 +34,6 @@ def build_eventbus( autostart, ) - if application: - configure_websockets(application, config, eventbus) + if app_ctxt: + configure_websockets(app_ctxt, config, eventbus) return eventbus diff --git a/antarest/eventbus/web.py b/antarest/eventbus/web.py index c5d635cc99..ba363db6c4 100644 --- a/antarest/eventbus/web.py +++ b/antarest/eventbus/web.py @@ -17,10 +17,11 @@ from http import HTTPStatus from typing import List, Optional -from fastapi import Depends, FastAPI, HTTPException, Query +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Query from pydantic import BaseModel from starlette.websockets import WebSocket, WebSocketDisconnect +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.interfaces.eventbus import Event, IEventBus from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTUser @@ -91,7 +92,7 @@ async def broadcast(self, message: str, permissions: PermissionInfo, channel: st await connection.websocket.send_text(message) -def configure_websockets(application: FastAPI, config: Config, event_bus: IEventBus) -> None: +def configure_websockets(app_ctxt: AppBuildContext, config: Config, event_bus: IEventBus) -> None: manager = ConnectionManager() async def send_event_to_ws(event: Event) -> None: @@ -100,7 +101,7 @@ async def send_event_to_ws(event: Event) -> None: del event_data["channel"] await manager.broadcast(json.dumps(event_data), event.permissions, event.channel) - @application.websocket("/ws") + @app_ctxt.api_root.websocket("/ws") async def connect( websocket: WebSocket, token: str = Query(...), diff --git a/antarest/front.py b/antarest/front.py new file mode 100644 index 0000000000..8de0f05e82 --- /dev/null +++ b/antarest/front.py @@ -0,0 +1,139 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. +""" +This module contains the logic necessary to serve both +the front-end application and the backend HTTP application. + +This includes: + - serving static frontend files + - redirecting "not found" requests to home, which itself redirects to index.html + - providing the endpoint /config.json, which the front-end uses to know + what are the API and websocket prefixes +""" + +import re +from pathlib import Path +from typing import Any, Optional, Sequence + +from fastapi import FastAPI +from pydantic import BaseModel +from starlette.middleware.base import BaseHTTPMiddleware, DispatchFunction, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import FileResponse +from starlette.staticfiles import StaticFiles +from starlette.types import ASGIApp + +from antarest.core.utils.string import to_camel_case + + +class RedirectMiddleware(BaseHTTPMiddleware): + """ + Middleware that rewrites the URL path to "/" for incoming requests + that do not match the known end points. This is useful for redirecting requests + to the main page of a ReactJS application when the user refreshes the browser. + """ + + def __init__( + self, + app: ASGIApp, + dispatch: Optional[DispatchFunction] = None, + route_paths: Sequence[str] = (), + ) -> None: + """ + Initializes an instance of the URLRewriterMiddleware. + + Args: + app: The ASGI application to which the middleware is applied. + dispatch: The dispatch function to use. + route_paths: The known route paths of the application. + Requests that do not match any of these paths will be rewritten to the root path. + + Note: + The `route_paths` should contain all the known endpoints of the application. + """ + dispatch = self.dispatch if dispatch is None else dispatch + super().__init__(app, dispatch) + self.known_prefixes = {re.findall(r"/(?:(?!/).)*", p)[0] for p in route_paths if p != "/"} + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Any: + """ + Intercepts the incoming request and rewrites the URL path if necessary. + Passes the modified or original request to the next middleware or endpoint handler. + """ + url_path = request.scope["path"] + if url_path in {"", "/"}: + pass + elif not any(url_path.startswith(ep) for ep in self.known_prefixes): + request.scope["path"] = "/" + return await call_next(request) + + +class BackEndConfig(BaseModel): + """ + Configuration about backend URLs served to the frontend. + """ + + rest_endpoint: str + ws_endpoint: str + + class Config: + populate_by_name = True + alias_generator = to_camel_case + + +def create_backend_config(api_prefix: str) -> BackEndConfig: + if not api_prefix.startswith("/"): + api_prefix = "/" + api_prefix + return BackEndConfig(rest_endpoint=f"{api_prefix}", ws_endpoint=f"{api_prefix}/ws") + + +def add_front_app(application: FastAPI, resources_dir: Path, api_prefix: str) -> None: + """ + This functions adds the logic necessary to serve both + the front-end application and the backend HTTP application. + + This includes: + - serving static frontend files + - redirecting "not found" requests to home, which itself redirects to index.html + - providing the endpoint /config.json, which the front-end uses to know + what are the API and websocket prefixes + """ + backend_config = create_backend_config(api_prefix) + + front_app_dir = resources_dir / "webapp" + + # Serve front-end files + application.mount( + "/static", + StaticFiles(directory=front_app_dir), + name="static", + ) + + # Redirect home to index.html + @application.get("/", include_in_schema=False) + def home(request: Request) -> Any: + return FileResponse(front_app_dir / "index.html", 200) + + # Serve config for the front-end at /config.json + @application.get("/config.json", include_in_schema=False) + def get_api_paths_config(request: Request) -> BackEndConfig: + return backend_config + + # When the web application is running in Desktop mode, the ReactJS web app + # is served at the `/static` entry point. Any requests that are not API + # requests should be redirected to the `index.html` file, which will handle + # the route provided by the URL. + route_paths = [r.path for r in application.routes] # type: ignore + application.add_middleware( + RedirectMiddleware, + route_paths=route_paths, + ) diff --git a/antarest/gui.py b/antarest/gui.py index 12e2761c6c..f36904a060 100644 --- a/antarest/gui.py +++ b/antarest/gui.py @@ -18,7 +18,7 @@ from multiprocessing import Process from pathlib import Path -import httpx # TODO SL :check its ok on windows +import httpx import uvicorn from PyQt5.QtGui import QIcon from PyQt5.QtWidgets import QAction, QApplication, QMenu, QSystemTrayIcon diff --git a/antarest/launcher/main.py b/antarest/launcher/main.py index 3370936f54..1916b9607b 100644 --- a/antarest/launcher/main.py +++ b/antarest/launcher/main.py @@ -12,8 +12,9 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.filetransfer.service import FileTransferManager from antarest.core.interfaces.cache import ICache @@ -26,7 +27,7 @@ def build_launcher( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, study_service: StudyService, file_transfer_manager: FileTransferManager, @@ -49,7 +50,7 @@ def build_launcher( cache=cache, ) - if service_launcher and application: - application.include_router(create_launcher_api(service_launcher, config)) + if service_launcher and app_ctxt: + app_ctxt.api_root.include_router(create_launcher_api(service_launcher, config)) return service_launcher diff --git a/antarest/login/ldap.py b/antarest/login/ldap.py index 4333049ab2..1635efe09d 100644 --- a/antarest/login/ldap.py +++ b/antarest/login/ldap.py @@ -14,7 +14,6 @@ from dataclasses import dataclass from typing import Dict, List, Optional -# TODO SL: check this works on windows import httpx from antarest.core.config import Config diff --git a/antarest/login/main.py b/antarest/login/main.py index f39c5c36dd..09fac37ee5 100644 --- a/antarest/login/main.py +++ b/antarest/login/main.py @@ -14,10 +14,11 @@ from http import HTTPStatus from typing import Any, Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI from starlette.requests import Request from starlette.responses import JSONResponse +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.interfaces.eventbus import DummyEventBusService, IEventBus from antarest.core.utils.fastapi_sqlalchemy import db @@ -30,7 +31,7 @@ def build_login( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, service: Optional[LoginService] = None, event_bus: IEventBus = DummyEventBusService(), @@ -39,7 +40,7 @@ def build_login( Login module linking dependency Args: - application: flask application + app_ctxt: application config: server configuration service: used by testing to inject mock. Let None to use true instantiation event_bus: used by testing to inject mock. Let None to use true instantiation @@ -66,9 +67,9 @@ def build_login( event_bus=event_bus, ) - if application: + if app_ctxt: - @application.exception_handler(AuthJWTException) + @app_ctxt.app.exception_handler(AuthJWTException) def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> Any: return JSONResponse( status_code=HTTPStatus.UNAUTHORIZED, @@ -83,6 +84,6 @@ def check_if_token_is_revoked(decrypted_token: Any) -> bool: with db(): return token_type == "bots" and service is not None and not service.exists_bot(user_id) - if application: - application.include_router(create_login_api(service, config)) + if app_ctxt: + app_ctxt.api_root.include_router(create_login_api(service, config)) return service diff --git a/antarest/main.py b/antarest/main.py index 93bbc50bf0..b1fd1480b6 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -13,29 +13,25 @@ import argparse import copy import logging -import re from contextlib import asynccontextmanager from pathlib import Path -from typing import Any, AsyncGenerator, Dict, Optional, Sequence, Tuple, cast +from typing import Any, AsyncGenerator, Dict, Optional, Tuple, cast import pydantic import uvicorn import uvicorn.config -from fastapi import FastAPI, HTTPException +from fastapi import APIRouter, FastAPI, HTTPException from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError from ratelimit import RateLimitMiddleware # type: ignore from ratelimit.backends.redis import RedisBackend # type: ignore from ratelimit.backends.simple import MemoryBackend # type: ignore -from starlette.middleware.base import BaseHTTPMiddleware, DispatchFunction, RequestResponseEndpoint from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import JSONResponse -from starlette.staticfiles import StaticFiles -from starlette.templating import Jinja2Templates -from starlette.types import ASGIApp from antarest import __version__ +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.core_blueprint import create_utils_routes from antarest.core.filesystem_blueprint import create_file_system_blueprint @@ -47,6 +43,7 @@ from antarest.core.utils.utils import get_local_path from antarest.core.utils.web import tags_metadata from antarest.fastapi_jwt_auth import AuthJWT +from antarest.front import add_front_app from antarest.login.auth import Auth, JwtSettings from antarest.login.model import init_admin_user from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector @@ -190,55 +187,6 @@ def parse_arguments() -> argparse.Namespace: return parser.parse_args() -class URLRewriterMiddleware(BaseHTTPMiddleware): - """ - Middleware that rewrites the URL path to "/" (root path) for incoming requests - that do not match the known end points. This is useful for redirecting requests - to the main page of a ReactJS application when the user refreshes the browser. - """ - - def __init__( - self, - app: ASGIApp, - dispatch: Optional[DispatchFunction] = None, - root_path: str = "", - route_paths: Sequence[str] = (), - ) -> None: - """ - Initializes an instance of the URLRewriterMiddleware. - - Args: - app: The ASGI application to which the middleware is applied. - dispatch: The dispatch function to use. - root_path: The root path of the application. - The URL path will be rewritten relative to this root path. - route_paths: The known route paths of the application. - Requests that do not match any of these paths will be rewritten to the root path. - - Note: - The `root_path` can be set to a specific component of the URL path, such as "api". - The `route_paths` should contain all the known endpoints of the application. - """ - dispatch = self.dispatch if dispatch is None else dispatch - super().__init__(app, dispatch) - self.root_path = f"/{root_path}" if root_path else "" - self.known_prefixes = {re.findall(r"/(?:(?!/).)*", p)[0] for p in route_paths if p != "/"} - - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Any: - """ - Intercepts the incoming request and rewrites the URL path if necessary. - Passes the modified or original request to the next middleware or endpoint handler. - """ - url_path = request.scope["path"] - if url_path in {"", "/"}: - pass - elif self.root_path and url_path.startswith(self.root_path): - request.scope["path"] = url_path[len(self.root_path) :] - elif not any(url_path.startswith(ep) for ep in self.known_prefixes): - request.scope["path"] = "/" - return await call_next(request) - - def fastapi_app( config_file: Path, resource_path: Optional[Path] = None, @@ -267,8 +215,13 @@ async def set_default_executor(app: FastAPI) -> AsyncGenerator[None, None]: root_path=config.root_path, openapi_tags=tags_metadata, lifespan=set_default_executor, + openapi_url=f"{config.api_prefix}/openapi.json", ) + api_root = APIRouter(prefix=config.api_prefix) + + app_ctxt = AppBuildContext(application, api_root) + # Database engine = init_db_engine(config_file, config, auto_upgrade_db) application.add_middleware(DBSessionMiddleware, custom_engine=engine, session_args=SESSION_ARGS) @@ -279,24 +232,6 @@ async def set_default_executor(app: FastAPI) -> AsyncGenerator[None, None]: application.add_middleware(LoggingMiddleware) - if mount_front: - application.mount( - "/static", - StaticFiles(directory=str(res / "webapp")), - name="static", - ) - templates = Jinja2Templates(directory=str(res / "templates")) - - @application.get("/", include_in_schema=False) - def home(request: Request) -> Any: - return templates.TemplateResponse("index.html", {"request": request}) - - else: - # noinspection PyUnusedLocal - @application.get("/", include_in_schema=False) - def home(request: Request) -> Any: - return "" - # TODO move that elsewhere @AuthJWT.load_config # type: ignore def get_config() -> JwtSettings: @@ -315,8 +250,8 @@ def get_config() -> JwtSettings: allow_methods=["*"], allow_headers=["*"], ) - application.include_router(create_utils_routes(config)) - application.include_router(create_file_system_blueprint(config)) + api_root.include_router(create_utils_routes(config)) + api_root.include_router(create_file_system_blueprint(config)) # noinspection PyUnusedLocal @application.exception_handler(HTTPException) @@ -426,19 +361,9 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: ) init_admin_user(engine=engine, session_args=SESSION_ARGS, admin_password=config.security.admin_pwd) - services = create_services(config, application) + services = create_services(config, app_ctxt) - if mount_front: - # When the web application is running in Desktop mode, the ReactJS web app - # is served at the `/static` entry point. Any requests that are not API - # requests should be redirected to the `index.html` file, which will handle - # the route provided by the URL. - route_paths = [r.path for r in application.routes] # type: ignore - application.add_middleware( - URLRewriterMiddleware, - root_path=application.root_path, - route_paths=route_paths, - ) + application.include_router(api_root) if config.server.services and Module.WATCHER.value in config.server.services: watcher = cast(Watcher, services["watcher"]) @@ -453,6 +378,15 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: auto_archiver.start() customize_openapi(application) + + if mount_front: + add_front_app(application, res, config.api_prefix) + else: + # noinspection PyUnusedLocal + @application.get("/", include_in_schema=False) + def home(request: Request) -> Any: + return "" + cancel_orphan_tasks(engine=engine, session_args=SESSION_ARGS) return application, services diff --git a/antarest/matrixstore/main.py b/antarest/matrixstore/main.py index eaec4f956a..d8eaf0390a 100644 --- a/antarest/matrixstore/main.py +++ b/antarest/matrixstore/main.py @@ -12,8 +12,9 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.filetransfer.service import FileTransferManager from antarest.core.tasks.service import ITaskService @@ -24,7 +25,7 @@ def build_matrix_service( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, file_transfer_manager: FileTransferManager, task_service: ITaskService, @@ -35,7 +36,7 @@ def build_matrix_service( Matrix module linking dependency Args: - application: flask application + app_ctxt: application config: server configuration file_transfer_manager: File transfer manager task_service: Task manager @@ -60,7 +61,7 @@ def build_matrix_service( config=config, ) - if application: - application.include_router(create_matrix_api(service, file_transfer_manager, config)) + if app_ctxt: + app_ctxt.api_root.include_router(create_matrix_api(service, file_transfer_manager, config)) return service diff --git a/antarest/singleton_services.py b/antarest/singleton_services.py index 039c3d00b5..13395a439a 100644 --- a/antarest/singleton_services.py +++ b/antarest/singleton_services.py @@ -56,13 +56,13 @@ def _init(config_file: Path, services_list: List[Module]) -> Dict[Module, IServi services: Dict[Module, IService] = {} if Module.WATCHER in services_list: - watcher = create_watcher(config=config, application=None, study_service=study_service) + watcher = create_watcher(config=config, app_ctxt=None, study_service=study_service) services[Module.WATCHER] = watcher if Module.MATRIX_GC in services_list: matrix_gc = create_matrix_gc( config=config, - application=None, + app_ctxt=None, study_service=study_service, matrix_service=matrix_service, ) diff --git a/antarest/study/business/all_optional_meta.py b/antarest/study/business/all_optional_meta.py index c580e8fa00..7a8440226b 100644 --- a/antarest/study/business/all_optional_meta.py +++ b/antarest/study/business/all_optional_meta.py @@ -17,8 +17,10 @@ from antarest.core.utils.string import to_camel_case +ModelClass = t.TypeVar("ModelClass", bound=BaseModel) -def all_optional_model(model: t.Type[BaseModel]) -> t.Type[BaseModel]: + +def all_optional_model(model: t.Type[ModelClass]) -> t.Type[ModelClass]: """ This decorator can be used to make all fields of a pydantic model optionals. diff --git a/antarest/study/main.py b/antarest/study/main.py index e80a353493..1efa9cb00b 100644 --- a/antarest/study/main.py +++ b/antarest/study/main.py @@ -12,8 +12,9 @@ from typing import Optional -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI +from antarest.core.application import AppBuildContext from antarest.core.config import Config from antarest.core.filetransfer.service import FileTransferManager from antarest.core.interfaces.cache import ICache @@ -39,7 +40,7 @@ def build_study_service( - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], config: Config, user_service: LoginService, matrix_service: ISimpleMatrixService, @@ -124,16 +125,17 @@ def build_study_service( config=config, ) - if application: - application.include_router(create_study_routes(study_service, file_transfer_manager, config)) - application.include_router(create_raw_study_routes(study_service, config)) - application.include_router(create_study_data_routes(study_service, config)) - application.include_router( + if app_ctxt: + api_root = app_ctxt.api_root + api_root.include_router(create_study_routes(study_service, file_transfer_manager, config)) + api_root.include_router(create_raw_study_routes(study_service, config)) + api_root.include_router(create_study_data_routes(study_service, config)) + api_root.include_router( create_study_variant_routes( study_service=study_service, config=config, ) ) - application.include_router(create_xpansion_routes(study_service, config)) + api_root.include_router(create_xpansion_routes(study_service, config)) return study_service diff --git a/antarest/tools/lib.py b/antarest/tools/lib.py index dbd58944e6..e2e876f723 100644 --- a/antarest/tools/lib.py +++ b/antarest/tools/lib.py @@ -20,7 +20,7 @@ from zipfile import ZipFile import numpy as np -from httpx import Client # TODO SL: check it work on windows +from httpx import Client from antarest.core.cache.business.local_chache import LocalCache from antarest.core.config import CacheConfig diff --git a/antarest/utils.py b/antarest/utils.py index 4ae8bf3fe7..1f7f4b7594 100644 --- a/antarest/utils.py +++ b/antarest/utils.py @@ -16,7 +16,7 @@ from typing import Any, Dict, Mapping, Optional, Tuple import redis -from fastapi import FastAPI +from fastapi import APIRouter, FastAPI from ratelimit import RateLimitMiddleware # type: ignore from ratelimit.backends.redis import RedisBackend # type: ignore from ratelimit.backends.simple import MemoryBackend # type: ignore @@ -24,6 +24,7 @@ from sqlalchemy.engine.base import Engine # type: ignore from sqlalchemy.pool import NullPool # type: ignore +from antarest.core.application import AppBuildContext from antarest.core.cache.main import build_cache from antarest.core.config import Config from antarest.core.filetransfer.main import build_filetransfer_service @@ -109,24 +110,24 @@ def init_db_engine( return engine -def create_event_bus(application: Optional[FastAPI], config: Config) -> Tuple[IEventBus, Optional[redis.Redis]]: # type: ignore +def create_event_bus(app_ctxt: Optional[AppBuildContext], config: Config) -> Tuple[IEventBus, Optional[redis.Redis]]: # type: ignore redis_client = new_redis_instance(config.redis) if config.redis is not None else None return ( - build_eventbus(application, config, True, redis_client), + build_eventbus(app_ctxt, config, True, redis_client), redis_client, ) def create_core_services( - application: Optional[FastAPI], config: Config + app_ctxt: Optional[AppBuildContext], config: Config ) -> Tuple[ICache, IEventBus, ITaskService, FileTransferManager, LoginService, MatrixService, StudyService,]: - event_bus, redis_client = create_event_bus(application, config) + event_bus, redis_client = create_event_bus(app_ctxt, config) cache = build_cache(config=config, redis_client=redis_client) - filetransfer_service = build_filetransfer_service(application, event_bus, config) - task_service = build_taskjob_manager(application, config, event_bus) - login_service = build_login(application, config, event_bus=event_bus) + filetransfer_service = build_filetransfer_service(app_ctxt, event_bus, config) + task_service = build_taskjob_manager(app_ctxt, config, event_bus) + login_service = build_login(app_ctxt, config, event_bus=event_bus) matrix_service = build_matrix_service( - application, + app_ctxt, config=config, file_transfer_manager=filetransfer_service, task_service=task_service, @@ -134,7 +135,7 @@ def create_core_services( service=None, ) study_service = build_study_service( - application, + app_ctxt, config, matrix_service=matrix_service, cache=cache, @@ -156,7 +157,7 @@ def create_core_services( def create_watcher( config: Config, - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], study_service: Optional[StudyService] = None, ) -> Watcher: if study_service: @@ -166,22 +167,22 @@ def create_watcher( task_service=study_service.task_service, ) else: - _, _, task_service, _, _, _, study_service = create_core_services(application, config) + _, _, task_service, _, _, _, study_service = create_core_services(app_ctxt, config) watcher = Watcher( config=config, study_service=study_service, task_service=task_service, ) - if application: - application.include_router(create_watcher_routes(watcher=watcher, config=config)) + if app_ctxt: + app_ctxt.api_root.include_router(create_watcher_routes(watcher=watcher, config=config)) return watcher def create_matrix_gc( config: Config, - application: Optional[FastAPI], + app_ctxt: Optional[AppBuildContext], study_service: Optional[StudyService] = None, matrix_service: Optional[MatrixService] = None, ) -> MatrixGarbageCollector: @@ -192,7 +193,7 @@ def create_matrix_gc( matrix_service=matrix_service, ) else: - _, _, _, _, _, matrix_service, study_service = create_core_services(application, config) + _, _, _, _, _, matrix_service, study_service = create_core_services(app_ctxt, config) return MatrixGarbageCollector( config=config, study_service=study_service, @@ -221,7 +222,7 @@ def create_simulator_worker( return SimulatorWorker(event_bus, matrix_service, config) -def create_services(config: Config, application: Optional[FastAPI], create_all: bool = False) -> Dict[str, Any]: +def create_services(config: Config, app_ctxt: Optional[AppBuildContext], create_all: bool = False) -> Dict[str, Any]: services: Dict[str, Any] = {} ( @@ -232,12 +233,12 @@ def create_services(config: Config, application: Optional[FastAPI], create_all: user_service, matrix_service, study_service, - ) = create_core_services(application, config) + ) = create_core_services(app_ctxt, config) - maintenance_service = build_maintenance_manager(application, config=config, cache=cache, event_bus=event_bus) + maintenance_service = build_maintenance_manager(app_ctxt, config=config, cache=cache, event_bus=event_bus) launcher = build_launcher( - application, + app_ctxt, config, study_service=study_service, event_bus=event_bus, @@ -246,13 +247,13 @@ def create_services(config: Config, application: Optional[FastAPI], create_all: cache=cache, ) - watcher = create_watcher(config=config, application=application, study_service=study_service) + watcher = create_watcher(config=config, app_ctxt=app_ctxt, study_service=study_service) services["watcher"] = watcher if config.server.services and Module.MATRIX_GC.value in config.server.services or create_all: matrix_garbage_collector = create_matrix_gc( config=config, - application=application, + app_ctxt=app_ctxt, study_service=study_service, matrix_service=matrix_service, ) diff --git a/requirements-test.txt b/requirements-test.txt index 9d58c1bb5b..e72650560c 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,7 +1,9 @@ -r requirements.txt checksumdir~=1.2.0 -pytest~=6.2.5 +pytest~=8.3.0 +pytest-xdist~=3.6.0 pytest-cov~=4.0.0 +pytest-mock~=3.14.0 # In this version DataFrame conversion to Excel is done using 'xlsxwriter' library. # But Excel files reading is done using 'openpyxl' library, during testing only. diff --git a/resources/application.yaml b/resources/application.yaml index 6fbdb31f9f..c962d09015 100644 --- a/resources/application.yaml +++ b/resources/application.yaml @@ -26,7 +26,19 @@ launcher: 700: path/to/700 enable_nb_cores_detection: true -root_path: "api" +# See https://fastapi.tiangolo.com/advanced/behind-a-proxy/ +# root path is used when the API is served behind a proxy which +# adds a prefix for clients. +# It does NOT add any prefix to the URLs which fastapi serve. + +# root_path: "api" + + +# Uncomment to serve the API under /api prefix +# (used in standalone mode to emulate the effect of proxy servers +# used in production deployments). + +# api_prefix: "/api" server: worker_threadpool_size: 12 @@ -36,4 +48,7 @@ server: logging: level: INFO - logfile: ./tmp/antarest.log \ No newline at end of file + logfile: ./tmp/antarest.log + +# True to get sqlalchemy logs +debug: False diff --git a/resources/deploy/config.prod.yaml b/resources/deploy/config.prod.yaml index cf9087a2af..9d8d78d56e 100644 --- a/resources/deploy/config.prod.yaml +++ b/resources/deploy/config.prod.yaml @@ -66,6 +66,10 @@ launcher: debug: false +# See https://fastapi.tiangolo.com/advanced/behind-a-proxy/ +# root path is used when the API is served behind a proxy which +# adds a prefix for clients. +# It does NOT add any prefix to the URLs which fastapi serve. root_path: "api" #tasks: diff --git a/resources/deploy/config.yaml b/resources/deploy/config.yaml index 6f3fdf595f..08831198ce 100644 --- a/resources/deploy/config.yaml +++ b/resources/deploy/config.yaml @@ -46,7 +46,8 @@ launcher: debug: false -root_path: "api" +# Serve the API at /api +api_prefix: "/api" server: worker_threadpool_size: 12 diff --git a/resources/templates/.placeholder b/resources/templates/.placeholder deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/scripts/build-front.sh b/scripts/build-front.sh index 87c2db6a50..15c2e0e60f 100755 --- a/scripts/build-front.sh +++ b/scripts/build-front.sh @@ -11,4 +11,3 @@ npm run build -- --mode=desktop cd .. rm -fr resources/webapp cp -r ./webapp/dist/ resources/webapp -cp ./webapp/dist/index.html resources/templates/ diff --git a/tests/integration/filesystem_blueprint/test_filesystem_endpoints.py b/tests/integration/filesystem_blueprint/test_filesystem_endpoints.py index d1e344e73d..ac1e9a36ff 100644 --- a/tests/integration/filesystem_blueprint/test_filesystem_endpoints.py +++ b/tests/integration/filesystem_blueprint/test_filesystem_endpoints.py @@ -13,10 +13,10 @@ import datetime import operator import re -import shutil import typing as t from pathlib import Path +from pytest_mock import MockerFixture from starlette.testclient import TestClient from tests.integration.conftest import RESOURCES_DIR @@ -93,6 +93,7 @@ def test_lifecycle( client: TestClient, user_access_token: str, admin_access_token: str, + mocker: MockerFixture, ) -> None: """ Test the lifecycle of the filesystem endpoints. @@ -102,7 +103,7 @@ def test_lifecycle( caplog: pytest caplog fixture. client: test client (tests.integration.conftest.client_fixture). user_access_token: access token of a classic user (tests.integration.conftest.user_access_token_fixture). - admin_access_token: access token of an admin user (tests.integration.conftest.admin_access_token_fixture). + admin_access_token: access token of an admin user (tests.integration.conftestin_access_token_fixture). """ # NOTE: all the following paths are based on the configuration defined in the app_fixture. archive_dir = tmp_path / "archive_dir" @@ -165,26 +166,25 @@ def test_lifecycle( err_count += 1 # Known filesystem + mocker.patch("shutil.disk_usage", return_value=(100, 200, 300)) res = client.get("/v1/filesystem/ws", headers=user_headers) assert res.status_code == 200, res.json() actual = sorted(res.json(), key=operator.itemgetter("name")) - # Both mount point are in the same filesystem, which is the `tmp_path` filesystem - total_bytes, used_bytes, free_bytes = shutil.disk_usage(tmp_path) expected = [ { "name": "default", "path": str(default_workspace), - "total_bytes": total_bytes, - "used_bytes": used_bytes, - "free_bytes": free_bytes, + "total_bytes": 100, + "used_bytes": 200, + "free_bytes": 300, "message": AnyDiskUsagePercent(), }, { "name": "ext", "path": str(ext_workspace_path), - "total_bytes": total_bytes, - "used_bytes": used_bytes, - "free_bytes": free_bytes, + "total_bytes": 100, + "used_bytes": 200, + "free_bytes": 300, "message": AnyDiskUsagePercent(), }, ] @@ -206,9 +206,9 @@ def test_lifecycle( expected = { "name": "default", "path": str(default_workspace), - "total_bytes": total_bytes, - "used_bytes": used_bytes, - "free_bytes": free_bytes, + "total_bytes": 100, + "used_bytes": 200, + "free_bytes": 300, "message": AnyDiskUsagePercent(), } assert actual == expected diff --git a/tests/integration/filesystem_blueprint/test_model.py b/tests/integration/filesystem_blueprint/test_model.py index 4f184c4242..d26bdb02cb 100644 --- a/tests/integration/filesystem_blueprint/test_model.py +++ b/tests/integration/filesystem_blueprint/test_model.py @@ -16,6 +16,8 @@ import shutil from pathlib import Path +from pytest_mock import MockerFixture + from antarest.core.filesystem_blueprint import FileInfoDTO, FilesystemDTO, MountPointDTO @@ -63,15 +65,16 @@ def test_from_path__missing_file(self) -> None: assert dto.free_bytes == 0 assert dto.message.startswith("N/A:"), dto.message - def test_from_path__file(self, tmp_path: Path) -> None: + def test_from_path__file(self, tmp_path: Path, mocker: MockerFixture) -> None: + mocker.patch("shutil.disk_usage", return_value=(100, 200, 300)) + name = "foo" dto = asyncio.run(MountPointDTO.from_path(name, tmp_path)) - total_bytes, used_bytes, free_bytes = shutil.disk_usage(tmp_path) assert dto.name == name assert dto.path == tmp_path - assert dto.total_bytes == total_bytes - assert dto.used_bytes == used_bytes - assert dto.free_bytes == free_bytes + assert dto.total_bytes == 100 + assert dto.used_bytes == 200 + assert dto.free_bytes == 300 assert re.fullmatch(r"\d+(?:\.\d+)?% used", dto.message), dto.message diff --git a/tests/launcher/test_web.py b/tests/launcher/test_web.py index 8e39e9a910..e1104f87af 100644 --- a/tests/launcher/test_web.py +++ b/tests/launcher/test_web.py @@ -19,6 +19,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient +from antarest.core.application import create_app_ctxt from antarest.core.config import Config, SecurityConfig from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTGroup, JWTUser from antarest.core.requests import RequestParameters @@ -35,10 +36,9 @@ def create_app(service: Mock) -> FastAPI: - app = FastAPI(title=__name__) - + build_ctxt = create_app_ctxt(FastAPI(title=__name__)) build_launcher( - app, + build_ctxt, study_service=Mock(), file_transfer_manager=Mock(), task_service=Mock(), @@ -46,7 +46,7 @@ def create_app(service: Mock) -> FastAPI: config=Config(security=SecurityConfig(disabled=True)), cache=Mock(), ) - return app + return build_ctxt.build() @pytest.mark.unit_test diff --git a/tests/login/test_web.py b/tests/login/test_web.py index c5ae051f69..389dc94cf7 100644 --- a/tests/login/test_web.py +++ b/tests/login/test_web.py @@ -21,6 +21,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient +from antarest.core.application import AppBuildContext, create_app_ctxt from antarest.core.config import Config, SecurityConfig from antarest.core.jwt import JWTGroup, JWTUser from antarest.core.requests import RequestParameters @@ -63,15 +64,16 @@ def get_config(): authjwt_token_location=("headers", "cookies"), ) + app_ctxt = create_app_ctxt(app) build_login( - app, + app_ctxt, service=service, config=Config( resources_path=Path(), security=SecurityConfig(disabled=auth_disabled), ), ) - return app + return app_ctxt.build() class TokenType: diff --git a/tests/matrixstore/test_web.py b/tests/matrixstore/test_web.py index 1e10e2602b..d47fb030bc 100644 --- a/tests/matrixstore/test_web.py +++ b/tests/matrixstore/test_web.py @@ -17,6 +17,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient +from antarest.core.application import create_app_ctxt from antarest.core.config import Config, SecurityConfig from antarest.fastapi_jwt_auth import AuthJWT from antarest.main import JwtSettings @@ -26,7 +27,7 @@ def create_app(service: Mock, auth_disabled=False) -> FastAPI: - app = FastAPI(title=__name__) + build_ctxt = create_app_ctxt(FastAPI(title=__name__)) @AuthJWT.load_config def get_config(): @@ -37,7 +38,7 @@ def get_config(): ) build_matrix_service( - app, + build_ctxt, user_service=Mock(), file_transfer_manager=Mock(), task_service=Mock(), @@ -47,7 +48,7 @@ def get_config(): security=SecurityConfig(disabled=auth_disabled), ), ) - return app + return build_ctxt.build() @pytest.mark.unit_test diff --git a/tests/storage/integration/conftest.py b/tests/storage/integration/conftest.py index 1136f3ca8a..dcf7e5e830 100644 --- a/tests/storage/integration/conftest.py +++ b/tests/storage/integration/conftest.py @@ -109,7 +109,7 @@ def storage_service(tmp_path: Path, project_path: Path, sta_mini_zip_path: Path) ) matrix_service = SimpleMatrixService(matrix_content_repository=matrix_content_repository) storage_service = build_study_service( - application=Mock(), + app_ctxt=Mock(), cache=LocalCache(config=config.cache), file_transfer_manager=Mock(), task_service=task_service_mock, diff --git a/tests/storage/integration/test_STA_mini.py b/tests/storage/integration/test_STA_mini.py index ca7904a075..35aaa4092d 100644 --- a/tests/storage/integration/test_STA_mini.py +++ b/tests/storage/integration/test_STA_mini.py @@ -22,6 +22,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient +from antarest.core.application import create_app_ctxt from antarest.core.jwt import DEFAULT_ADMIN_USER, JWTGroup, JWTUser from antarest.core.model import JSON from antarest.core.requests import RequestParameters @@ -46,19 +47,23 @@ ) -def assert_url_content(storage_service: StudyService, url: str, expected_output: dict) -> None: - app = FastAPI(title=__name__) +def create_test_client(service: StudyService) -> TestClient: + build_ctxt = create_app_ctxt(FastAPI(title=__name__)) build_study_service( - app, + build_ctxt, cache=Mock(), user_service=Mock(), task_service=Mock(), file_transfer_manager=Mock(), - study_service=storage_service, + study_service=service, matrix_service=Mock(spec=MatrixService), - config=storage_service.storage_service.raw_study_service.config, + config=service.storage_service.raw_study_service.config, ) - client = TestClient(app) + return TestClient(build_ctxt.build()) + + +def assert_url_content(storage_service: StudyService, url: str, expected_output: dict) -> None: + client = create_test_client(storage_service) res = client.get(url) assert_study(res.json(), expected_output) @@ -493,18 +498,7 @@ def test_sta_mini_copy(storage_service) -> None: source_study_name = UUID destination_study_name = "copy-STA-mini" - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - user_service=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - matrix_service=Mock(spec=MatrixService), - config=storage_service.storage_service.raw_study_service.config, - ) - client = TestClient(app) + client = create_test_client(storage_service) result = client.post(f"/v1/studies/{source_study_name}/copy?dest={destination_study_name}&use_task=false") assert result.status_code == HTTPStatus.CREATED.value @@ -590,18 +584,7 @@ def test_sta_mini_import(tmp_path: Path, storage_service) -> None: sta_mini_zip_filepath = shutil.make_archive(tmp_path, "zip", path_study) sta_mini_zip_path = Path(sta_mini_zip_filepath) - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - config=storage_service.storage_service.raw_study_service.config, - ) - client = TestClient(app) + client = create_test_client(storage_service) study_data = io.BytesIO(sta_mini_zip_path.read_bytes()) result = client.post("/v1/studies/_import", files={"study": study_data}) @@ -620,18 +603,7 @@ def test_sta_mini_import_output(tmp_path: Path, storage_service) -> None: sta_mini_output_zip_path = Path(sta_mini_output_zip_filepath) - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - config=storage_service.storage_service.raw_study_service.config, - ) - client = TestClient(app) + client = create_test_client(storage_service) study_output_data = io.BytesIO(sta_mini_output_zip_path.read_bytes()) result = client.post( diff --git a/tests/storage/integration/test_exporter.py b/tests/storage/integration/test_exporter.py index 2da0686302..46e077a9ca 100644 --- a/tests/storage/integration/test_exporter.py +++ b/tests/storage/integration/test_exporter.py @@ -21,6 +21,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient +from antarest.core.application import create_app_ctxt from antarest.core.config import Config, SecurityConfig, StorageConfig, WorkspaceConfig from antarest.core.filetransfer.model import FileDownloadTaskDTO from antarest.core.jwt import DEFAULT_ADMIN_USER @@ -54,10 +55,10 @@ def assert_url_content(url: str, tmp_dir: Path, sta_mini_zip_path: Path) -> byte repo = Mock() repo.get.return_value = md - app = FastAPI(title=__name__) + build_ctxt = create_app_ctxt(FastAPI(title=__name__)) ftm = SimpleFileTransferManager(Config(storage=StorageConfig(tmp_dir=tmp_dir))) build_study_service( - app, + build_ctxt, cache=Mock(), user_service=Mock(), task_service=SimpleSyncTaskService(), @@ -69,7 +70,7 @@ def assert_url_content(url: str, tmp_dir: Path, sta_mini_zip_path: Path) -> byte ) # Simulate the download of data using a streamed request - client = TestClient(app) + client = TestClient(build_ctxt.build()) if client.stream is False: # `TestClient` is based on `Requests` (old way before AntaREST-v2.15) # noinspection PyArgumentList diff --git a/tests/storage/web/test_studies_bp.py b/tests/storage/web/test_studies_bp.py index c9a3c18be2..33cedcc8df 100644 --- a/tests/storage/web/test_studies_bp.py +++ b/tests/storage/web/test_studies_bp.py @@ -24,9 +24,11 @@ from markupsafe import Markup from starlette.testclient import TestClient +from antarest.core.application import create_app_ctxt from antarest.core.config import Config, SecurityConfig, StorageConfig, WorkspaceConfig from antarest.core.exceptions import UrlNotMatchJsonDataError from antarest.core.filetransfer.model import FileDownloadDTO, FileDownloadTaskDTO +from antarest.core.filetransfer.service import FileTransferManager from antarest.core.jwt import JWTGroup, JWTUser from antarest.core.requests import RequestParameters from antarest.core.roles import RoleType @@ -48,6 +50,7 @@ TimeSerie, TimeSeriesData, ) +from antarest.study.service import StudyService from tests.storage.conftest import SimpleFileTransferManager from tests.storage.integration.conftest import UUID @@ -66,23 +69,29 @@ ) -@pytest.mark.unit_test -def test_server() -> None: - mock_service = Mock() - mock_service.get.return_value = {} - - app = FastAPI(title=__name__) +def create_test_client( + service: StudyService, file_transfer_manager: FileTransferManager = Mock(), raise_server_exceptions: bool = True +) -> TestClient: + app_ctxt = create_app_ctxt(FastAPI(title=__name__)) build_study_service( - app, + app_ctxt, cache=Mock(), task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_service, + file_transfer_manager=file_transfer_manager, + study_service=service, config=CONFIG, user_service=Mock(), matrix_service=Mock(spec=MatrixService), ) - client = TestClient(app) + return TestClient(app_ctxt.build(), raise_server_exceptions=raise_server_exceptions) + + +@pytest.mark.unit_test +def test_server() -> None: + mock_service = Mock() + mock_service.get.return_value = {} + + client = create_test_client(mock_service) client.get("/v1/studies/study1/raw?path=settings/general/params") mock_service.get.assert_called_once_with( @@ -95,18 +104,7 @@ def test_404() -> None: mock_storage_service = Mock() mock_storage_service.get.side_effect = UrlNotMatchJsonDataError("Test") - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app, raise_server_exceptions=False) + client = create_test_client(mock_storage_service, raise_server_exceptions=False) result = client.get("/v1/studies/study1/raw?path=settings/general/params") assert result.status_code == HTTPStatus.NOT_FOUND @@ -119,18 +117,7 @@ def test_server_with_parameters() -> None: mock_storage_service = Mock() mock_storage_service.get.return_value = {} - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(mock_storage_service) result = client.get("/v1/studies/study1/raw?depth=4") parameters = RequestParameters(user=ADMIN) @@ -158,18 +145,7 @@ def test_create_study(tmp_path: str, project_path) -> None: storage_service = Mock() storage_service.create_study.return_value = "my-uuid" - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(storage_service) result_right = client.post("/v1/studies?name=study2") @@ -193,18 +169,7 @@ def test_import_study_zipped(tmp_path: Path, project_path) -> None: study_uuid = str(uuid.uuid4()) mock_storage_service.import_study.return_value = study_uuid - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(mock_storage_service) result = client.post("/v1/studies") @@ -223,18 +188,7 @@ def test_copy_study(tmp_path: Path) -> None: storage_service = Mock() storage_service.copy_study.return_value = "/studies/study-copied" - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(storage_service) result = client.post(f"/v1/studies/{UUID}/copy?dest=study-copied") @@ -285,18 +239,7 @@ def test_list_studies(tmp_path: str) -> None: storage_service = Mock() storage_service.get_studies_information.return_value = studies - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(storage_service) result = client.get("/v1/studies") assert {k: StudyMetadataDTO.model_validate(v) for k, v in result.json().items()} == studies @@ -320,18 +263,7 @@ def test_study_metadata(tmp_path: str) -> None: storage_service = Mock() storage_service.get_study_information.return_value = study - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(storage_service) result = client.get("/v1/studies/1") assert StudyMetadataDTO.model_validate(result.json()) == study @@ -352,20 +284,8 @@ def test_export_files(tmp_path: Path) -> None: ) mock_storage_service.export_study.return_value = expected - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - # Simulate the download of data using a streamed request - client = TestClient(app) + client = create_test_client(mock_storage_service) if client.stream is False: # `TestClient` is based on `Requests` (old way before AntaREST-v2.15) # noinspection PyArgumentList @@ -402,18 +322,7 @@ def test_export_params(tmp_path: Path) -> None: ) mock_storage_service.export_study.return_value = expected - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(mock_storage_service) client.get(f"/v1/studies/{UUID}/export?no_output=true") client.get(f"/v1/studies/{UUID}/export?no_output=false") mock_storage_service.export_study.assert_has_calls( @@ -428,18 +337,7 @@ def test_export_params(tmp_path: Path) -> None: def test_delete_study() -> None: mock_storage_service = Mock() - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(mock_storage_service) study_uuid = "8319b5f8-2a35-4984-9ace-2ab072bd6eef" client.delete(f"/v1/studies/{study_uuid}") @@ -452,18 +350,7 @@ def test_edit_study() -> None: mock_storage_service = Mock() mock_storage_service.edit_study.return_value = {} - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_storage_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app) + client = create_test_client(mock_storage_service) client.post("/v1/studies/my-uuid/raw?path=url/to/change", json={"Hello": "World"}) mock_storage_service.edit_study.assert_called_once_with("my-uuid", "url/to/change", {"Hello": "World"}, PARAMS) @@ -497,18 +384,7 @@ def test_validate() -> None: mock_service = Mock() mock_service.check_errors.return_value = ["Hello"] - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app, raise_server_exceptions=False) + client = create_test_client(mock_service, raise_server_exceptions=False) res = client.get("/v1/studies/my-uuid/raw/validate") assert res.json() == ["Hello"] @@ -551,19 +427,8 @@ def test_output_download(tmp_path: Path) -> None: synthesis=False, includeClusters=True, ) - - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=SimpleFileTransferManager(Config(storage=StorageConfig(tmp_dir=tmp_path))), - study_service=mock_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app, raise_server_exceptions=False) + ftm = SimpleFileTransferManager(Config(storage=StorageConfig(tmp_dir=tmp_path))) + client = create_test_client(mock_service, ftm, raise_server_exceptions=False) res = client.post( f"/v1/studies/{UUID}/outputs/my-output-id/download", json=study_download.model_dump(), @@ -588,18 +453,8 @@ def test_output_whole_download(tmp_path: Path) -> None: ) mock_service.export_output.return_value = expected - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=SimpleFileTransferManager(Config(storage=StorageConfig(tmp_dir=tmp_path))), - study_service=mock_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app, raise_server_exceptions=False) + ftm = SimpleFileTransferManager(Config(storage=StorageConfig(tmp_dir=tmp_path))) + client = create_test_client(mock_service, ftm, raise_server_exceptions=False) res = client.get( f"/v1/studies/{UUID}/outputs/{output_id}/export", ) @@ -612,18 +467,7 @@ def test_sim_reference() -> None: study_id = str(uuid.uuid4()) output_id = "my-output-id" - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app, raise_server_exceptions=False) + client = create_test_client(mock_service, raise_server_exceptions=False) res = client.put(f"/v1/studies/{study_id}/outputs/{output_id}/reference") mock_service.set_sim_reference.assert_called_once_with(study_id, output_id, True, PARAMS) assert res.status_code == HTTPStatus.OK @@ -656,18 +500,8 @@ def test_sim_result() -> None: ) ] mock_service.get_study_sim_result.return_value = result_data - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=mock_service, - config=CONFIG, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - ) - client = TestClient(app, raise_server_exceptions=False) + + client = create_test_client(mock_service, raise_server_exceptions=False) res = client.get(f"/v1/studies/{study_id}/outputs") actual_object = [StudySimResultDTO.parse_obj(res.json()[0])] assert actual_object == result_data @@ -676,19 +510,7 @@ def test_sim_result() -> None: @pytest.mark.unit_test def test_study_permission_management(tmp_path: Path) -> None: storage_service = Mock() - - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=storage_service, - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - config=CONFIG, - ) - client = TestClient(app, raise_server_exceptions=False) + client = create_test_client(storage_service, raise_server_exceptions=False) result = client.put(f"/v1/studies/{UUID}/owner/2") storage_service.change_owner.assert_called_with( @@ -728,18 +550,7 @@ def test_study_permission_management(tmp_path: Path) -> None: @pytest.mark.unit_test def test_get_study_versions(tmp_path: Path) -> None: - app = FastAPI(title=__name__) - build_study_service( - app, - cache=Mock(), - task_service=Mock(), - file_transfer_manager=Mock(), - study_service=Mock(), - user_service=Mock(), - matrix_service=Mock(spec=MatrixService), - config=CONFIG, - ) - client = TestClient(app, raise_server_exceptions=False) + client = create_test_client(Mock(), raise_server_exceptions=False) result = client.get("/v1/studies/_versions") assert result.json() == list(STUDY_REFERENCE_TEMPLATES.keys()) diff --git a/tests/test_front.py b/tests/test_front.py new file mode 100644 index 0000000000..5046a868cf --- /dev/null +++ b/tests/test_front.py @@ -0,0 +1,111 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from pathlib import Path + +import pytest +from fastapi import FastAPI +from starlette.testclient import TestClient + +from antarest.front import RedirectMiddleware, add_front_app + + +@pytest.fixture +def base_back_app() -> FastAPI: + """ + A simple app which has only one backend endpoint + """ + app = FastAPI(title=__name__) + + @app.get(path="/api/a-backend-endpoint") + def get_from_api() -> str: + return "back" + + return app + + +@pytest.fixture +def resources_dir(tmp_path: Path) -> Path: + resource_dir = tmp_path / "resources" + resource_dir.mkdir() + webapp_dir = resource_dir / "webapp" + webapp_dir.mkdir() + with open(webapp_dir / "index.html", mode="w") as f: + f.write("index") + with open(webapp_dir / "front.css", mode="w") as f: + f.write("css") + return resource_dir + + +@pytest.fixture +def app_with_home(base_back_app) -> FastAPI: + """ + A simple app which has only a home endpoint and one backend endpoint + """ + + @base_back_app.get(path="/") + def home() -> str: + return "home" + + return base_back_app + + +@pytest.fixture +def redirect_app(app_with_home: FastAPI) -> FastAPI: + """ + Same as app with redirect middleware + """ + route_paths = [r.path for r in app_with_home.routes] # type: ignore + app_with_home.add_middleware(RedirectMiddleware, route_paths=route_paths) + return app_with_home + + +def test_redirect_middleware_does_not_modify_home(redirect_app: FastAPI) -> None: + client = TestClient(redirect_app) + response = client.get("/") + assert response.status_code == 200 + assert response.json() == "home" + + +def test_redirect_middleware_redirects_unknown_routes_to_home(redirect_app: FastAPI) -> None: + client = TestClient(redirect_app) + response = client.get("/a-front-route") + assert response.status_code == 200 + assert response.json() == "home" + + +def test_redirect_middleware_does_not_redirect_backend_routes(redirect_app: FastAPI) -> None: + client = TestClient(redirect_app) + response = client.get("/api/a-backend-endpoint") + assert response.status_code == 200 + assert response.json() == "back" + + +def test_frontend_paths(base_back_app, resources_dir: Path) -> None: + add_front_app(base_back_app, resources_dir, "/api") + client = TestClient(base_back_app) + + config_response = client.get("/config.json") + assert config_response.status_code == 200 + assert config_response.json() == {"restEndpoint": "/api", "wsEndpoint": "/api/ws"} + + index_response = client.get("/index.html") + assert index_response.status_code == 200 + assert index_response.text == "index" + + front_route_response = client.get("/any-route") + assert front_route_response.status_code == 200 + assert front_route_response.text == "index" + + front_static_file_response = client.get("/static/front.css") + assert front_static_file_response.status_code == 200 + assert front_static_file_response.text == "css" diff --git a/tests/variantstudy/test_command_factory.py b/tests/variantstudy/test_command_factory.py index f20985f808..b78ba393e5 100644 --- a/tests/variantstudy/test_command_factory.py +++ b/tests/variantstudy/test_command_factory.py @@ -13,6 +13,7 @@ import importlib import itertools import pkgutil +from typing import List, Set from unittest.mock import Mock import pytest @@ -25,18 +26,385 @@ from antarest.study.storage.variantstudy.model.command.icommand import ICommand from antarest.study.storage.variantstudy.model.model import CommandDTO +COMMANDS: List[CommandDTO] = [ + CommandDTO( + action=CommandName.CREATE_AREA.value, + args={"area_name": "area_name"}, + ), + CommandDTO( + action=CommandName.CREATE_AREA.value, + args=[ + {"area_name": "area_name"}, + {"area_name": "area2"}, + ], + ), + CommandDTO( + action=CommandName.REMOVE_AREA.value, + args={"id": "id"}, + ), + CommandDTO( + action=CommandName.REMOVE_AREA.value, + args=[{"id": "id"}], + ), + CommandDTO( + action=CommandName.CREATE_DISTRICT.value, + args={ + "name": "id", + "filter_items": ["a"], + "output": True, + "comments": "", + }, + ), + CommandDTO( + action=CommandName.CREATE_DISTRICT.value, + args=[ + { + "name": "id", + "base_filter": "add-all", + "output": True, + "comments": "", + } + ], + ), + CommandDTO( + action=CommandName.REMOVE_DISTRICT.value, + args={"id": "id"}, + ), + CommandDTO( + action=CommandName.REMOVE_DISTRICT.value, + args=[{"id": "id"}], + ), + CommandDTO( + action=CommandName.CREATE_LINK.value, + args={ + "area1": "area1", + "area2": "area2", + "parameters": {}, + "series": "series", + }, + ), + CommandDTO( + action=CommandName.CREATE_LINK.value, + args=[ + { + "area1": "area1", + "area2": "area2", + "parameters": {}, + "series": "series", + } + ], + ), + CommandDTO( + action=CommandName.REMOVE_LINK.value, + args={ + "area1": "area1", + "area2": "area2", + }, + ), + CommandDTO( + action=CommandName.REMOVE_LINK.value, + args=[ + { + "area1": "area1", + "area2": "area2", + } + ], + ), + CommandDTO( + action=CommandName.CREATE_BINDING_CONSTRAINT.value, + args={"name": "name"}, + ), + CommandDTO( + action=CommandName.CREATE_BINDING_CONSTRAINT.value, + args=[ + { + "name": "name", + "enabled": True, + "time_step": "hourly", + "operator": "equal", + "values": "values", + "group": "group_1", + }, + ], + ), + CommandDTO( + action=CommandName.UPDATE_BINDING_CONSTRAINT.value, + args={ + "id": "id", + "enabled": True, + "time_step": "hourly", + "operator": "equal", + "values": "values", + }, + ), + CommandDTO( + action=CommandName.UPDATE_BINDING_CONSTRAINT.value, + args=[ + { + "id": "id", + "enabled": True, + "time_step": "hourly", + "operator": "equal", + } + ], + ), + CommandDTO( + action=CommandName.REMOVE_BINDING_CONSTRAINT.value, + args={"id": "id"}, + ), + CommandDTO( + action=CommandName.REMOVE_BINDING_CONSTRAINT.value, + args=[{"id": "id"}], + ), + CommandDTO( + action=CommandName.CREATE_THERMAL_CLUSTER.value, + args={ + "area_id": "area_name", + "cluster_name": "cluster_name", + "parameters": { + "group": "group", + "unitcount": "unitcount", + "nominalcapacity": "nominalcapacity", + "marginal-cost": "marginal-cost", + "market-bid-cost": "market-bid-cost", + }, + "prepro": "prepro", + "modulation": "modulation", + }, + ), + CommandDTO( + action=CommandName.CREATE_THERMAL_CLUSTER.value, + args=[ + { + "area_id": "area_name", + "cluster_name": "cluster_name", + "parameters": { + "group": "group", + "unitcount": "unitcount", + "nominalcapacity": "nominalcapacity", + "marginal-cost": "marginal-cost", + "market-bid-cost": "market-bid-cost", + }, + "prepro": "prepro", + "modulation": "modulation", + } + ], + ), + CommandDTO( + action=CommandName.REMOVE_THERMAL_CLUSTER.value, + args={"area_id": "area_name", "cluster_id": "cluster_name"}, + ), + CommandDTO( + action=CommandName.REMOVE_THERMAL_CLUSTER.value, + args=[{"area_id": "area_name", "cluster_id": "cluster_name"}], + ), + CommandDTO( + action=CommandName.CREATE_RENEWABLES_CLUSTER.value, + args={ + "area_id": "area_name", + "cluster_name": "cluster_name", + "parameters": { + "name": "name", + "ts-interpretation": "power-generation", + }, + }, + ), + CommandDTO( + action=CommandName.CREATE_RENEWABLES_CLUSTER.value, + args=[ + { + "area_id": "area_name", + "cluster_name": "cluster_name", + "parameters": { + "name": "name", + "ts-interpretation": "power-generation", + }, + } + ], + ), + CommandDTO( + action=CommandName.REMOVE_RENEWABLES_CLUSTER.value, + args={"area_id": "area_name", "cluster_id": "cluster_name"}, + ), + CommandDTO( + action=CommandName.REMOVE_RENEWABLES_CLUSTER.value, + args=[{"area_id": "area_name", "cluster_id": "cluster_name"}], + ), + CommandDTO( + action=CommandName.REPLACE_MATRIX.value, + args={"target": "target_element", "matrix": "matrix"}, + ), + CommandDTO( + action=CommandName.REPLACE_MATRIX.value, + args=[{"target": "target_element", "matrix": "matrix"}], + ), + CommandDTO( + action=CommandName.UPDATE_CONFIG.value, + args={"target": "target", "data": {}}, + ), + CommandDTO( + action=CommandName.UPDATE_CONFIG.value, + args=[{"target": "target", "data": {}}], + ), + CommandDTO( + action=CommandName.UPDATE_COMMENTS.value, + args={"comments": "comments"}, + ), + CommandDTO( + action=CommandName.UPDATE_COMMENTS.value, + args=[{"comments": "comments"}], + ), + CommandDTO( + action=CommandName.UPDATE_FILE.value, + args={ + "target": "settings/resources/study", + "b64Data": "", + }, + ), + CommandDTO( + action=CommandName.UPDATE_DISTRICT.value, + args={"id": "id", "filter_items": ["a"]}, + ), + CommandDTO( + action=CommandName.UPDATE_DISTRICT.value, + args=[{"id": "id", "base_filter": "add-all"}], + ), + CommandDTO( + action=CommandName.UPDATE_PLAYLIST.value, + args=[{"active": True, "items": [1, 3], "reverse": False}], + ), + CommandDTO( + action=CommandName.UPDATE_PLAYLIST.value, + args={ + "active": True, + "items": [1, 3], + "weights": {1: 5.0}, + "reverse": False, + }, + ), + CommandDTO( + action=CommandName.UPDATE_SCENARIO_BUILDER.value, + args={ + "data": { + "ruleset test": { + "l": {"area1": {"0": 1}}, + "ntc": {"area1 / area2": {"1": 23}}, + "t": {"area1": {"thermal": {"1": 2}}}, + }, + } + }, + ), + CommandDTO( + action=CommandName.CREATE_ST_STORAGE.value, + args={ + "area_id": "area 1", + "parameters": { + "name": "Storage 1", + "group": "Battery", + "injectionnominalcapacity": 0, + "withdrawalnominalcapacity": 0, + "reservoircapacity": 0, + "efficiency": 1, + "initiallevel": 0, + "initialleveloptim": False, + }, + "pmax_injection": "matrix://59ea6c83-6348-466d-9530-c35c51ca4c37", + "pmax_withdrawal": "matrix://5f988548-dadc-4bbb-8ce8-87a544dbf756", + "lower_rule_curve": "matrix://8ce4fcea-cc97-4d2c-b641-a27a53454612", + "upper_rule_curve": "matrix://8ce614c8-c687-41af-8b24-df8a49cc52af", + "inflows": "matrix://df9b25e1-e3f7-4a57-8182-0ff9791439e5", + }, + ), + CommandDTO( + action=CommandName.CREATE_ST_STORAGE.value, + args=[ + { + "area_id": "area 1", + "parameters": { + "efficiency": 1, + "group": "Battery", + "initiallevel": 0, + "initialleveloptim": False, + "injectionnominalcapacity": 0, + "name": "Storage 1", + "reservoircapacity": 0, + "withdrawalnominalcapacity": 0, + }, + "pmax_injection": "matrix://59ea6c83-6348-466d-9530-c35c51ca4c37", + "pmax_withdrawal": "matrix://5f988548-dadc-4bbb-8ce8-87a544dbf756", + "lower_rule_curve": "matrix://8ce4fcea-cc97-4d2c-b641-a27a53454612", + "upper_rule_curve": "matrix://8ce614c8-c687-41af-8b24-df8a49cc52af", + "inflows": "matrix://df9b25e1-e3f7-4a57-8182-0ff9791439e5", + }, + { + "area_id": "area 1", + "parameters": { + "efficiency": 0.94, + "group": "Battery", + "initiallevel": 0, + "initialleveloptim": False, + "injectionnominalcapacity": 0, + "name": "Storage 2", + "reservoircapacity": 0, + "withdrawalnominalcapacity": 0, + }, + "pmax_injection": "matrix://3f5b3746-3995-49b7-a6da-622633472e05", + "pmax_withdrawal": "matrix://4b64a31f-927b-4887-b4cd-adcddd39bdcd", + "lower_rule_curve": "matrix://16c7c3ae-9824-4ef2-aa68-51145884b025", + "upper_rule_curve": "matrix://9a6104e9-990a-415f-a6e2-57507e13b58c", + "inflows": "matrix://e8923768-9bdd-40c2-a6ea-2da2523be727", + }, + ], + ), + CommandDTO( + action=CommandName.REMOVE_ST_STORAGE.value, + args={ + "area_id": "area 1", + "storage_id": "storage 1", + }, + ), + CommandDTO( + action=CommandName.REMOVE_ST_STORAGE.value, + args=[ + { + "area_id": "area 1", + "storage_id": "storage 1", + }, + { + "area_id": "area 1", + "storage_id": "storage 2", + }, + ], + ), + CommandDTO( + action=CommandName.GENERATE_THERMAL_CLUSTER_TIMESERIES.value, + args=[{}], + ), +] + + +@pytest.fixture +def command_factory() -> CommandFactory: + def get_matrix_id(matrix: str) -> str: + # str.removeprefix() is not available in Python 3.8 + prefix = "matrix://" + if matrix.startswith(prefix): + return matrix[len(prefix) :] + return matrix + + return CommandFactory( + generator_matrix_constants=Mock(spec=GeneratorMatrixConstants), + matrix_service=Mock(spec=MatrixService, get_matrix_id=get_matrix_id), + patch_service=Mock(spec=PatchService), + ) + class TestCommandFactory: - # noinspection SpellCheckingInspection - def setup_class(self): + def _get_command_classes(self) -> Set[str]: """ - Set up the test class. - Imports all modules from the `antarest.study.storage.variantstudy.model.command` package and creates a set of command class names derived from the `ICommand` abstract class. The objective is to ensure that the unit test covers all commands in this package. - - This method is executed once before any tests in the test class are run. """ for module_loader, name, ispkg in pkgutil.iter_modules(["antarest/study/storage/variantstudy/model/command"]): importlib.import_module( @@ -44,383 +412,21 @@ def setup_class(self): package="antarest.study.storage.variantstudy.model.command", ) abstract_commands = {"AbstractBindingConstraintCommand"} - self.command_class_set = { - cmd.__name__ for cmd in ICommand.__subclasses__() if cmd.__name__ not in abstract_commands - } + return {cmd.__name__ for cmd in ICommand.__subclasses__() if cmd.__name__ not in abstract_commands} + + def test_all_commands_are_tested(self, command_factory: CommandFactory): + commands = sum([command_factory.to_command(command_dto=cmd) for cmd in COMMANDS], []) + tested_classes = {c.__class__.__name__ for c in commands} + + assert self._get_command_classes().issubset(tested_classes) # noinspection SpellCheckingInspection @pytest.mark.parametrize( "command_dto", - [ - CommandDTO( - action=CommandName.CREATE_AREA.value, - args={"area_name": "area_name"}, - ), - CommandDTO( - action=CommandName.CREATE_AREA.value, - args=[ - {"area_name": "area_name"}, - {"area_name": "area2"}, - ], - ), - CommandDTO( - action=CommandName.REMOVE_AREA.value, - args={"id": "id"}, - ), - CommandDTO( - action=CommandName.REMOVE_AREA.value, - args=[{"id": "id"}], - ), - CommandDTO( - action=CommandName.CREATE_DISTRICT.value, - args={ - "name": "id", - "filter_items": ["a"], - "output": True, - "comments": "", - }, - ), - CommandDTO( - action=CommandName.CREATE_DISTRICT.value, - args=[ - { - "name": "id", - "base_filter": "add-all", - "output": True, - "comments": "", - } - ], - ), - CommandDTO( - action=CommandName.REMOVE_DISTRICT.value, - args={"id": "id"}, - ), - CommandDTO( - action=CommandName.REMOVE_DISTRICT.value, - args=[{"id": "id"}], - ), - CommandDTO( - action=CommandName.CREATE_LINK.value, - args={ - "area1": "area1", - "area2": "area2", - "parameters": {}, - "series": "series", - }, - ), - CommandDTO( - action=CommandName.CREATE_LINK.value, - args=[ - { - "area1": "area1", - "area2": "area2", - "parameters": {}, - "series": "series", - } - ], - ), - CommandDTO( - action=CommandName.REMOVE_LINK.value, - args={ - "area1": "area1", - "area2": "area2", - }, - ), - CommandDTO( - action=CommandName.REMOVE_LINK.value, - args=[ - { - "area1": "area1", - "area2": "area2", - } - ], - ), - CommandDTO( - action=CommandName.CREATE_BINDING_CONSTRAINT.value, - args={"name": "name"}, - ), - CommandDTO( - action=CommandName.CREATE_BINDING_CONSTRAINT.value, - args=[ - { - "name": "name", - "enabled": True, - "time_step": "hourly", - "operator": "equal", - "values": "values", - "group": "group_1", - }, - ], - ), - CommandDTO( - action=CommandName.UPDATE_BINDING_CONSTRAINT.value, - args={ - "id": "id", - "enabled": True, - "time_step": "hourly", - "operator": "equal", - "values": "values", - }, - ), - CommandDTO( - action=CommandName.UPDATE_BINDING_CONSTRAINT.value, - args=[ - { - "id": "id", - "enabled": True, - "time_step": "hourly", - "operator": "equal", - } - ], - ), - CommandDTO( - action=CommandName.REMOVE_BINDING_CONSTRAINT.value, - args={"id": "id"}, - ), - CommandDTO( - action=CommandName.REMOVE_BINDING_CONSTRAINT.value, - args=[{"id": "id"}], - ), - CommandDTO( - action=CommandName.CREATE_THERMAL_CLUSTER.value, - args={ - "area_id": "area_name", - "cluster_name": "cluster_name", - "parameters": { - "group": "group", - "unitcount": "unitcount", - "nominalcapacity": "nominalcapacity", - "marginal-cost": "marginal-cost", - "market-bid-cost": "market-bid-cost", - }, - "prepro": "prepro", - "modulation": "modulation", - }, - ), - CommandDTO( - action=CommandName.CREATE_THERMAL_CLUSTER.value, - args=[ - { - "area_id": "area_name", - "cluster_name": "cluster_name", - "parameters": { - "group": "group", - "unitcount": "unitcount", - "nominalcapacity": "nominalcapacity", - "marginal-cost": "marginal-cost", - "market-bid-cost": "market-bid-cost", - }, - "prepro": "prepro", - "modulation": "modulation", - } - ], - ), - CommandDTO( - action=CommandName.REMOVE_THERMAL_CLUSTER.value, - args={"area_id": "area_name", "cluster_id": "cluster_name"}, - ), - CommandDTO( - action=CommandName.REMOVE_THERMAL_CLUSTER.value, - args=[{"area_id": "area_name", "cluster_id": "cluster_name"}], - ), - CommandDTO( - action=CommandName.CREATE_RENEWABLES_CLUSTER.value, - args={ - "area_id": "area_name", - "cluster_name": "cluster_name", - "parameters": { - "name": "name", - "ts-interpretation": "power-generation", - }, - }, - ), - CommandDTO( - action=CommandName.CREATE_RENEWABLES_CLUSTER.value, - args=[ - { - "area_id": "area_name", - "cluster_name": "cluster_name", - "parameters": { - "name": "name", - "ts-interpretation": "power-generation", - }, - } - ], - ), - CommandDTO( - action=CommandName.REMOVE_RENEWABLES_CLUSTER.value, - args={"area_id": "area_name", "cluster_id": "cluster_name"}, - ), - CommandDTO( - action=CommandName.REMOVE_RENEWABLES_CLUSTER.value, - args=[{"area_id": "area_name", "cluster_id": "cluster_name"}], - ), - CommandDTO( - action=CommandName.REPLACE_MATRIX.value, - args={"target": "target_element", "matrix": "matrix"}, - ), - CommandDTO( - action=CommandName.REPLACE_MATRIX.value, - args=[{"target": "target_element", "matrix": "matrix"}], - ), - CommandDTO( - action=CommandName.UPDATE_CONFIG.value, - args={"target": "target", "data": {}}, - ), - CommandDTO( - action=CommandName.UPDATE_CONFIG.value, - args=[{"target": "target", "data": {}}], - ), - CommandDTO( - action=CommandName.UPDATE_COMMENTS.value, - args={"comments": "comments"}, - ), - CommandDTO( - action=CommandName.UPDATE_COMMENTS.value, - args=[{"comments": "comments"}], - ), - CommandDTO( - action=CommandName.UPDATE_FILE.value, - args={ - "target": "settings/resources/study", - "b64Data": "", - }, - ), - CommandDTO( - action=CommandName.UPDATE_DISTRICT.value, - args={"id": "id", "filter_items": ["a"]}, - ), - CommandDTO( - action=CommandName.UPDATE_DISTRICT.value, - args=[{"id": "id", "base_filter": "add-all"}], - ), - CommandDTO( - action=CommandName.UPDATE_PLAYLIST.value, - args=[{"active": True, "items": [1, 3], "reverse": False}], - ), - CommandDTO( - action=CommandName.UPDATE_PLAYLIST.value, - args={ - "active": True, - "items": [1, 3], - "weights": {1: 5.0}, - "reverse": False, - }, - ), - CommandDTO( - action=CommandName.UPDATE_SCENARIO_BUILDER.value, - args={ - "data": { - "ruleset test": { - "l": {"area1": {"0": 1}}, - "ntc": {"area1 / area2": {"1": 23}}, - "t": {"area1": {"thermal": {"1": 2}}}, - }, - } - }, - ), - CommandDTO( - action=CommandName.CREATE_ST_STORAGE.value, - args={ - "area_id": "area 1", - "parameters": { - "name": "Storage 1", - "group": "Battery", - "injectionnominalcapacity": 0, - "withdrawalnominalcapacity": 0, - "reservoircapacity": 0, - "efficiency": 1, - "initiallevel": 0, - "initialleveloptim": False, - }, - "pmax_injection": "matrix://59ea6c83-6348-466d-9530-c35c51ca4c37", - "pmax_withdrawal": "matrix://5f988548-dadc-4bbb-8ce8-87a544dbf756", - "lower_rule_curve": "matrix://8ce4fcea-cc97-4d2c-b641-a27a53454612", - "upper_rule_curve": "matrix://8ce614c8-c687-41af-8b24-df8a49cc52af", - "inflows": "matrix://df9b25e1-e3f7-4a57-8182-0ff9791439e5", - }, - ), - CommandDTO( - action=CommandName.CREATE_ST_STORAGE.value, - args=[ - { - "area_id": "area 1", - "parameters": { - "efficiency": 1, - "group": "Battery", - "initiallevel": 0, - "initialleveloptim": False, - "injectionnominalcapacity": 0, - "name": "Storage 1", - "reservoircapacity": 0, - "withdrawalnominalcapacity": 0, - }, - "pmax_injection": "matrix://59ea6c83-6348-466d-9530-c35c51ca4c37", - "pmax_withdrawal": "matrix://5f988548-dadc-4bbb-8ce8-87a544dbf756", - "lower_rule_curve": "matrix://8ce4fcea-cc97-4d2c-b641-a27a53454612", - "upper_rule_curve": "matrix://8ce614c8-c687-41af-8b24-df8a49cc52af", - "inflows": "matrix://df9b25e1-e3f7-4a57-8182-0ff9791439e5", - }, - { - "area_id": "area 1", - "parameters": { - "efficiency": 0.94, - "group": "Battery", - "initiallevel": 0, - "initialleveloptim": False, - "injectionnominalcapacity": 0, - "name": "Storage 2", - "reservoircapacity": 0, - "withdrawalnominalcapacity": 0, - }, - "pmax_injection": "matrix://3f5b3746-3995-49b7-a6da-622633472e05", - "pmax_withdrawal": "matrix://4b64a31f-927b-4887-b4cd-adcddd39bdcd", - "lower_rule_curve": "matrix://16c7c3ae-9824-4ef2-aa68-51145884b025", - "upper_rule_curve": "matrix://9a6104e9-990a-415f-a6e2-57507e13b58c", - "inflows": "matrix://e8923768-9bdd-40c2-a6ea-2da2523be727", - }, - ], - ), - CommandDTO( - action=CommandName.REMOVE_ST_STORAGE.value, - args={ - "area_id": "area 1", - "storage_id": "storage 1", - }, - ), - CommandDTO( - action=CommandName.REMOVE_ST_STORAGE.value, - args=[ - { - "area_id": "area 1", - "storage_id": "storage 1", - }, - { - "area_id": "area 1", - "storage_id": "storage 2", - }, - ], - ), - CommandDTO( - action=CommandName.GENERATE_THERMAL_CLUSTER_TIMESERIES.value, - args=[{}], - ), - ], + COMMANDS, ) @pytest.mark.unit_test - def test_command_factory(self, command_dto: CommandDTO): - def get_matrix_id(matrix: str) -> str: - # str.removeprefix() is not available in Python 3.8 - prefix = "matrix://" - if matrix.startswith(prefix): - return matrix[len(prefix) :] - return matrix - - command_factory = CommandFactory( - generator_matrix_constants=Mock(spec=GeneratorMatrixConstants), - matrix_service=Mock(spec=MatrixService, get_matrix_id=get_matrix_id), - patch_service=Mock(spec=PatchService), - ) + def test_command_factory(self, command_dto: CommandDTO, command_factory: CommandFactory): commands = command_factory.to_command(command_dto=command_dto) if isinstance(command_dto.args, dict): @@ -440,12 +446,6 @@ def get_matrix_id(matrix: str) -> str: assert actual_args == expected_args assert actual_version == expected_version - self.command_class_set.discard(type(commands[0]).__name__) - - def teardown_class(self): - # Check that all command classes have been tested - assert not self.command_class_set - @pytest.mark.unit_test def test_unknown_command(): diff --git a/webapp/src/components/common/MatrixInput/index.tsx b/webapp/src/components/common/MatrixInput/index.tsx index c48e67d508..d035a94cce 100644 --- a/webapp/src/components/common/MatrixInput/index.tsx +++ b/webapp/src/components/common/MatrixInput/index.tsx @@ -183,7 +183,7 @@ function MatrixInput({ isPercentDisplayEnabled={enablePercentDisplay} /> ) : ( - !isLoading && + !isLoading && )} {openImportDialog && ( diff --git a/webapp/src/components/common/page/SimpleContent.tsx b/webapp/src/components/common/page/SimpleContent.tsx index 6be0cd51b0..03c8525eec 100644 --- a/webapp/src/components/common/page/SimpleContent.tsx +++ b/webapp/src/components/common/page/SimpleContent.tsx @@ -16,6 +16,7 @@ function EmptyView(props: EmptyViewProps) {