Skip to content

Commit

Permalink
Merge pull request writer#507 from FabienArcellier/501-redirect-after…
Browse files Browse the repository at this point in the history
…-authentication

fix: Redirect after authentication on http://localhost without /
  • Loading branch information
ramedina86 authored Aug 12, 2024
2 parents bbe184b + 447f581 commit 514705a
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 22 deletions.
67 changes: 45 additions & 22 deletions src/writer/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import dataclasses
import logging
import os.path
import time
from abc import ABCMeta, abstractmethod
Expand All @@ -16,6 +17,8 @@
from writer.serve import WriterFastAPI
from writer.ss_types import InitSessionRequestPayload

logger = logging.getLogger('writer')

# Dictionary for storing failed attempts {ip_address: timestamp}
failed_attempts: Dict[str, float] = {}

Expand Down Expand Up @@ -181,11 +184,23 @@ def register(self,
callback: Optional[Callable[[Request, str, dict], None]] = None,
unauthorized_action: Optional[Callable[[Request, Unauthorized], Response]] = None
):

redirect_url = urljoin(self.host_url, self.callback_authorize)
host_url_path = urlpath(self.host_url)
callback_authorize_path = urljoin(host_url_path, self.callback_authorize)
asset_assets_path = urljoin(host_url_path, "assets")

logger.debug(f"[auth] oidc - url redirect: {redirect_url}")
logger.debug(f"[auth] oidc - endpoint authorize: {self.url_authorize}")
logger.debug(f"[auth] oidc - endpoint token: {self.url_oauthtoken}")
logger.debug(f"[auth] oidc - path: {host_url_path}")
logger.debug(f"[auth] oidc - authorize path: {callback_authorize_path}")
logger.debug(f"[auth] oidc - asset path: {asset_assets_path}")
self.authlib = OAuth2Session(
client_id=self.client_id,
client_secret=self.client_secret,
scope=self.scope.split(" "),
redirect_uri=_urljoin(self.host_url, self.callback_authorize),
redirect_uri=redirect_url,
authorization_endpoint=self.url_authorize,
token_endpoint=self.url_oauthtoken,
)
Expand All @@ -195,22 +210,20 @@ def register(self,
@asgi_app.middleware("http")
async def oidc_middleware(request: Request, call_next):
session = request.cookies.get('session')
host_url_path = _urlpath(self.host_url)
full_callback_authorize = '/' + _urljoin(host_url_path, self.callback_authorize)
full_assets = '/' + _urljoin(host_url_path, '/assets')
if session is not None or request.url.path in [full_callback_authorize] or request.url.path.startswith(full_assets):

if session is not None or request.url.path in [callback_authorize_path] or request.url.path.startswith(asset_assets_path):
response: Response = await call_next(request)
return response
else:
url = self.authlib.create_authorization_url(self.url_authorize)
response = RedirectResponse(url=url[0])
return response

@asgi_app.get('/' + _urlstrip(self.callback_authorize))
@asgi_app.get('/' + urlstrip(self.callback_authorize))
async def route_callback(request: Request):
self.authlib.fetch_token(url=self.url_oauthtoken, authorization_response=str(request.url))
try:
host_url_path = _urlpath(self.host_url)
host_url_path = urlpath(self.host_url)
response = RedirectResponse(url=host_url_path)
session_id = session_manager.generate_session_id()

Expand Down Expand Up @@ -300,44 +313,54 @@ def Auth0(client_id: str, client_secret: str, domain: str, host_url: str) -> Oid
url_oauthtoken=f"https://{domain}/oauth/token",
url_userinfo=f"https://{domain}/userinfo")

def _urlpath(url: str):
def urlpath(url: str):
"""
>>> _urlpath("http://localhost/app1")
>>> urlpath("http://localhost/app1")
>>> "/app1"
>>> urlpath("http://localhost")
>>> "/"
"""
return urlparse(url).path
path = urlparse(url).path
if len(path) == 0:
return "/"
else:
return path

def _urljoin(*args):
def urljoin(*args):
"""
>>> _urljoin("http://localhost/app1", "edit")
>>> urljoin("http://localhost/app1", "edit")
>>> "http://localhost/app1/edit"
>>> _urljoin("app1/", "edit")
>>> urljoin("app1/", "edit")
>>> "app1/edit"
>>> _urljoin("app1", "edit")
>>> urljoin("app1", "edit")
>>> "app1/edit"
>>> _urljoin("/app1/", "/edit")
>>> "app1/edit"
>>> urljoin("/app1/", "/edit")
>>> "/app1/edit"
"""
root_part = args[0]
root_part_is_root_path = root_part.startswith('/') and len(root_part) > 1

url_strip_parts = []
for part in args:
if part:
url_strip_parts.append(_urlstrip(part))
url_strip_parts.append(urlstrip(part))

return '/'.join(url_strip_parts)
return '/'.join(url_strip_parts) if root_part_is_root_path is False else '/' + '/'.join(url_strip_parts)

def _urlstrip(url_path: str):
def urlstrip(url_path: str):
"""
>>> _urlstrip("/app1/")
>>> urlstrip("/app1/")
>>> "app1"
>>> _urlstrip("http://localhost/app1")
>>> urlstrip("http://localhost/app1")
>>> "http://localhost/app1"
>>> _urlstrip("http://localhost/app1/")
>>> urlstrip("http://localhost/app1/")
>>> "http://localhost/app1"
"""
return url_path.strip('/')
Expand Down
37 changes: 37 additions & 0 deletions tests/backend/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import fastapi
import fastapi.testclient
import pytest
import writer.serve
from writer import auth

from tests.backend import test_basicauth_dir

Expand Down Expand Up @@ -35,3 +37,38 @@ def test_basicauth_authentication_module_disabled_when_server_setup_hook_is_disa
with fastapi.testclient.TestClient(asgi_app) as client:
res = client.get("/api/init")
assert res.status_code == 405

@pytest.mark.parametrize("path,expected_path", [
("", "/"),
("http://localhost", "/"),
("http://localhost/", "/"),
("http://localhost/any", "/any"),
("http://localhost/any/", "/any/"),
("/any/yolo", "/any/yolo")
])
def test_url_path_scenarios(self, path: str, expected_path: str):
assert auth.urlpath(path) == expected_path

@pytest.mark.parametrize("path,expected_path", [
("/", ""),
("/yolo", "yolo"),
("/yolo/", "yolo"),
("http://localhost", "http://localhost"),
("http://localhost/", "http://localhost"),
("http://localhost/any", "http://localhost/any"),
("http://localhost/any/", "http://localhost/any")
])
def test_url_split_scenarios(self, path: str, expected_path: str):
assert auth.urlstrip(path) == expected_path

@pytest.mark.parametrize("path1,path2,expected_path", [
("/", "any", "/any"),
("", "any", "any"),
("/yolo", "any", "/yolo/any"),
("/yolo", "/any", "/yolo/any"),
("http://localhost", "any", "http://localhost/any"),
("http://localhost/", "/any", "http://localhost/any"),
("http://localhost/yolo", "/any", "http://localhost/yolo/any"),
])
def test_urljoin_scenarios(self, path1: str, path2, expected_path: str):
assert auth.urljoin(path1, path2) == expected_path

0 comments on commit 514705a

Please sign in to comment.