Skip to content

Commit

Permalink
Merge branch 'dev' into feature/745-load-balancing
Browse files Browse the repository at this point in the history
  • Loading branch information
sylvlecl authored Oct 14, 2024
2 parents 3997248 + 5421a83 commit 96c82f3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
31 changes: 21 additions & 10 deletions antarest/front.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import re
from pathlib import Path
from typing import Any, Optional, Sequence
from typing import Any, List, Optional, Sequence

from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware, DispatchFunction, RequestResponseEndpoint
Expand All @@ -46,23 +46,35 @@ def __init__(
self,
app: ASGIApp,
dispatch: Optional[DispatchFunction] = None,
route_paths: Sequence[str] = (),
protected_roots: Optional[List[str]] = None,
protected_paths: Optional[List[str]] = None,
) -> None:
"""
Initializes an instance of the URLRewriterMiddleware.
Args:
app: The ASGI application to which the middleware is applied.
dispatch: The dispatch function to use.
route_paths: The known route paths of the application.
Requests that do not match any of these paths will be rewritten to the root path.
protected_roots: URL starting at those roots will not be redirected
protected_paths: those URLs will not be redirected
Note:
The `route_paths` should contain all the known endpoints of the application.
"""
dispatch = self.dispatch if dispatch is None else dispatch
super().__init__(app, dispatch)
self.known_prefixes = {re.findall(r"/(?:(?!/).)*", p)[0] for p in route_paths if p != "/"}

self.protected_paths = protected_paths or []
protected_roots = protected_roots or []
self.protected_roots = [r.rstrip("/") for r in protected_roots]

def _path_matches_protected_paths(self, path: str) -> bool:
if path in self.protected_paths:
return True
for root in self.protected_roots:
if path == root or path.startswith(f"{root}/"):
return True
return False

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Any:
"""
Expand All @@ -72,7 +84,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
url_path = request.scope["path"]
if url_path in {"", "/"}:
pass
elif not any(url_path.startswith(ep) for ep in self.known_prefixes):
elif not self._path_matches_protected_paths(url_path):
request.scope["path"] = "/"
return await call_next(request)

Expand Down Expand Up @@ -112,8 +124,9 @@ def add_front_app(application: FastAPI, resources_dir: Path, api_prefix: str) ->
front_app_dir = resources_dir / "webapp"

# Serve front-end files
static_files_root = "/static"
application.mount(
"/static",
static_files_root,
StaticFiles(directory=front_app_dir),
name="static",
)
Expand All @@ -132,8 +145,6 @@ def get_api_paths_config(request: Request) -> BackEndConfig:
# is served at the `/static` entry point. Any requests that are not API
# requests should be redirected to the `index.html` file, which will handle
# the route provided by the URL.
route_paths = [r.path for r in application.routes] # type: ignore
application.add_middleware(
RedirectMiddleware,
route_paths=route_paths,
RedirectMiddleware, protected_roots=[static_files_root, api_prefix], protected_paths=["/config.json"]
)
9 changes: 7 additions & 2 deletions tests/test_front.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ def redirect_app(app_with_home: FastAPI) -> FastAPI:
"""
Same as app with redirect middleware
"""
route_paths = [r.path for r in app_with_home.routes] # type: ignore
app_with_home.add_middleware(RedirectMiddleware, route_paths=route_paths)
app_with_home.add_middleware(
RedirectMiddleware, protected_roots=["/api", "static"], protected_paths=["/config.json"]
)
return app_with_home


Expand Down Expand Up @@ -106,6 +107,10 @@ def test_frontend_paths(base_back_app, resources_dir: Path) -> None:
assert front_route_response.status_code == 200
assert front_route_response.text == "index"

front_route_response = client.get("/apidoc")
assert front_route_response.status_code == 200
assert front_route_response.text == "index"

front_static_file_response = client.get("/static/front.css")
assert front_static_file_response.status_code == 200
assert front_static_file_response.text == "css"

0 comments on commit 96c82f3

Please sign in to comment.