Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
fix(dependencies): handle cases where the token isn't JWT compliant (#…
Browse files Browse the repository at this point in the history
…123)

* fix(dependencies): catch the cases where the token isn't JWT

* test(dependencies): add test cases for dependencies
  • Loading branch information
frgfm authored Mar 14, 2024
1 parent c699719 commit 37a4f7f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, SecurityScopes
from jwt import ExpiredSignatureError, InvalidSignatureError
from jwt import DecodeError, ExpiredSignatureError, InvalidSignatureError
from jwt import decode as jwt_decode
from pydantic import ValidationError
from sqlmodel.ext.asyncio.session import AsyncSession
Expand Down Expand Up @@ -57,6 +57,12 @@ async def get_token_payload(
detail="Token has expired.",
headers={"WWW-Authenticate": authenticate_value},
)
except DecodeError:
raise HTTPException(
status_code=status.HTTP_406_NOT_ACCEPTABLE,
detail="Invalid token.",
headers={"WWW-Authenticate": authenticate_value},
)

try:
user_id = int(payload["sub"])
Expand Down
28 changes: 28 additions & 0 deletions src/tests/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
from fastapi import HTTPException
from fastapi.security import SecurityScopes

from app.api.dependencies import get_token_payload
from app.core.security import create_access_token


@pytest.mark.parametrize(
("scopes", "token", "expires_minutes", "error_code", "expected_payload"),
[
(["admin"], "", None, 406, None),
(["admin"], {"user_id": "123", "scopes": ["admin"]}, None, 422, None),
(["admin"], {"sub": "123", "scopes": ["admin"]}, -1, 401, None),
(["admin"], {"sub": "123", "scopes": ["admin"]}, None, None, {"user_id": 123, "scopes": ["admin"]}),
(["admin"], {"sub": "123", "scopes": ["user"]}, None, 403, None),
],
)
@pytest.mark.asyncio()
async def test_get_token_payload(scopes, token, expires_minutes, error_code, expected_payload):
_token = await create_access_token(token, expires_minutes) if isinstance(token, dict) else token
if isinstance(error_code, int):
with pytest.raises(HTTPException):
await get_token_payload(SecurityScopes(scopes), _token)
else:
payload = await get_token_payload(SecurityScopes(scopes), _token)
if expected_payload is not None:
assert payload.model_dump() == expected_payload

0 comments on commit 37a4f7f

Please sign in to comment.