diff --git a/antarest/front.py b/antarest/front.py index a0699812bf..16cd92544f 100644 --- a/antarest/front.py +++ b/antarest/front.py @@ -22,7 +22,7 @@ import re from pathlib import Path -from typing import Any, Optional, Sequence +from typing import Any, List, Optional, Sequence from fastapi import FastAPI from starlette.middleware.base import BaseHTTPMiddleware, DispatchFunction, RequestResponseEndpoint @@ -46,7 +46,8 @@ def __init__( self, app: ASGIApp, dispatch: Optional[DispatchFunction] = None, - route_paths: Sequence[str] = (), + protected_roots: Optional[List[str]] = None, + protected_paths: Optional[List[str]] = None, ) -> None: """ Initializes an instance of the URLRewriterMiddleware. @@ -54,15 +55,26 @@ def __init__( Args: app: The ASGI application to which the middleware is applied. dispatch: The dispatch function to use. - route_paths: The known route paths of the application. - Requests that do not match any of these paths will be rewritten to the root path. + protected_roots: URL starting at those roots will not be redirected + protected_paths: those URLs will not be redirected Note: The `route_paths` should contain all the known endpoints of the application. """ dispatch = self.dispatch if dispatch is None else dispatch super().__init__(app, dispatch) - self.known_prefixes = {re.findall(r"/(?:(?!/).)*", p)[0] for p in route_paths if p != "/"} + + self.protected_paths = protected_paths or [] + protected_roots = protected_roots or [] + self.protected_roots = [r.rstrip("/") for r in protected_roots] + + def _path_matches_protected_paths(self, path: str) -> bool: + if path in self.protected_paths: + return True + for root in self.protected_roots: + if path == root or path.startswith(f"{root}/"): + return True + return False async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Any: """ @@ -72,7 +84,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - url_path = request.scope["path"] if url_path in {"", "/"}: pass - elif not any(url_path.startswith(ep) for ep in self.known_prefixes): + elif not self._path_matches_protected_paths(url_path): request.scope["path"] = "/" return await call_next(request) @@ -112,8 +124,9 @@ def add_front_app(application: FastAPI, resources_dir: Path, api_prefix: str) -> front_app_dir = resources_dir / "webapp" # Serve front-end files + static_files_root = "/static" application.mount( - "/static", + static_files_root, StaticFiles(directory=front_app_dir), name="static", ) @@ -132,8 +145,6 @@ def get_api_paths_config(request: Request) -> BackEndConfig: # is served at the `/static` entry point. Any requests that are not API # requests should be redirected to the `index.html` file, which will handle # the route provided by the URL. - route_paths = [r.path for r in application.routes] # type: ignore application.add_middleware( - RedirectMiddleware, - route_paths=route_paths, + RedirectMiddleware, protected_roots=[static_files_root, api_prefix], protected_paths=["/config.json"] ) diff --git a/tests/test_front.py b/tests/test_front.py index 5046a868cf..afe58f5fd3 100644 --- a/tests/test_front.py +++ b/tests/test_front.py @@ -64,8 +64,9 @@ def redirect_app(app_with_home: FastAPI) -> FastAPI: """ Same as app with redirect middleware """ - route_paths = [r.path for r in app_with_home.routes] # type: ignore - app_with_home.add_middleware(RedirectMiddleware, route_paths=route_paths) + app_with_home.add_middleware( + RedirectMiddleware, protected_roots=["/api", "static"], protected_paths=["/config.json"] + ) return app_with_home @@ -106,6 +107,10 @@ def test_frontend_paths(base_back_app, resources_dir: Path) -> None: assert front_route_response.status_code == 200 assert front_route_response.text == "index" + front_route_response = client.get("/apidoc") + assert front_route_response.status_code == 200 + assert front_route_response.text == "index" + front_static_file_response = client.get("/static/front.css") assert front_static_file_response.status_code == 200 assert front_static_file_response.text == "css"