diff --git a/.env.example b/.env.example index 5a8ba4b..27b2c6e 100644 --- a/.env.example +++ b/.env.example @@ -4,6 +4,7 @@ DOMAIN=localhost SECRET_KEY= REFRESH_KEY= PROFILING=0 +JWT_USE_NONCE=0 # Backend BACKEND_CORS_ORIGINS=["http://localhost:8000","http://localhost:5000"] diff --git a/README.md b/README.md index 8181c8c..819d2ef 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ * 📝 [Loguru](https://github.com/Delgan/loguru) + [picologging](https://github.com/microsoft/picologging) for simplified and performant logging * 🐳 Dockerized and includes AWS deployment flow * 🗃️ Several database implementations with sample ORM models (MySQL, Postgres, Timescale) & migrations -* 🔐 JWT authentication and authorization +* 🔐 Optional JWT authentication and authorization * 🌐 AWS Lambda functions support * 🧩 Modularized features * 📊 Prometheus metrics @@ -44,6 +44,9 @@ * [Shell](#shell) * [Migrations](#migrations) * [Downgrade Migration](#downgrade-migration) +* [JWT Auth](#jwt-auth) + * [JWT Overview](#jwt-overview) + * [Modifying JWT Payload Fields](#modifying-jwt-payload-fields) * [Project Structure](#project-structure) * [Makefile Commands](#makefile-commands) * [Contributing](#contributing) @@ -220,6 +223,65 @@ Run this command to revert every migration back to the beginning. alembic downgrade base ``` +## JWT Implementation + +In this FastAPI template, JSON Web Tokens (JWT) can be optionally utilized for authentication. This documentation section elucidates the JWT implementation and related functionalities. + +### JWT Overview + +The JWT implementation can be found in the module: app/auth/jwt.py. The primary functions include: + +- Creating access and refresh JWT tokens. +- Verifying and decoding a given JWT token. +- Handling JWT-based authentication for FastAPI routes. + +#### User Management + +If a user associated with a JWT token is not found in the database, a new user will be created. This is managed by the get_or_create_user function. When a token is decoded and the corresponding user ID (sub field in the token) is not found, the system will attempt to create a new user with that ID. + +#### Nonce Usage + +A nonce is an arbitrary number that can be used just once. It's an optional field in the JWT token to ensure additional security. If a nonce is used: + +- It is stored in Redis for the duration of the refresh token's validity. +- It must match between access and refresh tokens to ensure their pairing. +- Its presence in Redis is verified before the token is considered valid. + +Enabling nonce usage provides an additional layer of security against token reuse, but requires Redis to function. + +### Modifying JWT Payload Fields + +The JWT token payload structure is defined in `app/types/jwt.py`` under the JWTPayload class. If you wish to add more fields to the JWT token payload: + +1. Update the TokenData and JWTPayload class in `app/types/jwt.py`` by adding the desired fields. + ```python + class JWTPayload(BaseModel): + # ... existing fields ... + new_field: Type + + class TokenData(BaseModel): + # ... existing fields ... + new_field: Type + ``` + + TokenData is separated from JWTPayload to make it clear what is automatically filled in and what is manually added. Both classes must be updated to include the new fields. + +2. Wherever the token is created, update the payload to include the new fields. + ```python + from app.auth.jwt import create_jwt + from app.types.jwt import TokenData + + payload = TokenData( + sub='user_id_1', + field1='value1', + # ... all fields ... + ) + access_token, refresh_token = create_jwt(payload) + ``` + +Remember, the JWT token has a size limit. The more data you include, the bigger your token becomes, so ensure that you only include essential data in the token payload. + + ## Project Structure ``` diff --git a/app/api/endpoints/auth.py b/app/api/endpoints/auth.py index 139d776..8da8608 100644 --- a/app/api/endpoints/auth.py +++ b/app/api/endpoints/auth.py @@ -1,4 +1,5 @@ from jose import JWTError +from loguru import logger from fastapi import APIRouter, Depends from fastapi.responses import ORJSONResponse @@ -14,16 +15,16 @@ @router.get('/login') -async def login(address: str, response: ORJSONResponse) -> ServerResponse[str]: - session = MySqlSession() - token = TokenData(sub=address) +async def login(response: ORJSONResponse) -> ServerResponse[str]: + # TODO: Look up the user here, or create one if they don't exist + # session = MySqlSession() + token = TokenData(sub='example_user_id') try: - accessToken, refreshToken = create_jwt(token, session) + accessToken, refreshToken = create_jwt(token) except JWTError as e: - return ServerResponse(status='error', message=f'JWT Error: {e}') - finally: - session.close() + logger.error(f'JWT Error during login: {e}') + return ServerResponse(status='error', message='JWT Error, try again') # Save the refresh token in an HTTPOnly cookie response.set_cookie( @@ -37,15 +38,16 @@ async def login(address: str, response: ORJSONResponse) -> ServerResponse[str]: @router.get('/refresh') -async def refresh(response: ORJSONResponse, - payload: JWTPayload = Depends(RequireRefreshToken)) -> ServerResponse[str]: +async def refresh( + response: ORJSONResponse, payload: JWTPayload = Depends(RequireRefreshToken) +) -> ServerResponse[str]: token = TokenData(sub=payload.sub) try: - accessToken, refreshToken = create_jwt( - token, userID=payload.id) + accessToken, refreshToken = create_jwt(token) except JWTError as e: - return ServerResponse(status='error', message=f'JWT Error: {e}') + logger.error(f'JWT Error during login: {e}') + return ServerResponse(status='error', message='JWT Error, try again.') # Save the refresh token in an HTTPOnly cookie response.set_cookie( diff --git a/app/api/endpoints/discord_auth.py b/app/api/endpoints/discord_auth.py new file mode 100644 index 0000000..678968e --- /dev/null +++ b/app/api/endpoints/discord_auth.py @@ -0,0 +1,117 @@ +import jwt +import json +import base64 + +from fastapi import FastAPI, Depends, HTTPException, Request, Response +from fastapi.security import OAuth2PasswordBearer +from fastapi.templating import Jinja2Templates +from fastapi.responses import RedirectResponse +from pydantic import BaseModel +from typing import Optional + +app = FastAPI() +templates = Jinja2Templates(directory="templates") # Assuming your templates are in a 'templates' directory + +# JWT secret key +SECRET_KEY = "your_secret_key_here" + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +def createOAuthSession(): + # Your logic for creating an OAuth session + pass + +def generateKey(): + # Your logic for generating a key + pass + +def removeStripeCookies(): + # Your logic to remove Stripe cookies + pass + +def userHasDiscordAuthToken(token: str = Depends(oauth2_scheme)) -> bool: + # Decode JWT token and verify user has Discord auth token + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) + return bool(payload.get("discord_auth")) + except: + return False + +@app.get("/signin") +async def signin(request: Request, token: str = Depends(oauth2_scheme)): + if userHasDiscordAuthToken(token): + # Redirect to the license page + return RedirectResponse(url_for('auth.license')) + + oauth = createOAuthSession() + state = { + 'nonce': generateKey(), + } + nextUrl = request.query_params.get('next') + if nextUrl: + state.update({'redirect': nextUrl}) + + state = saveAsState(state) + loginUrl, state = oauth.authorization_url( + Config.DISCORD_AUTHORIZE_URL, state=state) + response = templates.TemplateResponse('auth/authWithDiscord.html', { + "request": request, + "title": 'Waffler Sign In', + "loginUrl": loginUrl + }) + response.set_cookie(key=Cookies.DISCORD_STATE, value=state) + + removeStripeCookies() + return response + +@app.get("/logout") +def logout(request: Request): + response = RedirectResponse(url=request.query_params.get('next', '/')) + # Your logic to remove all cookies, for example: + response.delete_cookie(key="your_cookie_name") + return response + +@app.get("/oauth_callback") +def discordOAuthCallback(request: Request): + state = request.query_params.get('state', '') + if not state or state != request.cookies.get(Cookies.DISCORD_STATE): + return RedirectResponse('/') + + oauth = createOAuthSession() + try: + token = oauth.fetch_token( + Config.DISCORD_TOKEN_URL, + client_secret=Config.DISCORD_CLIENT_SECRET, + authorization_response=request.url, + ) + except Exception: + return RedirectResponse('/') + + jwt_token = jwt.encode({"discord_auth": token}, SECRET_KEY, algorithm="HS256") + + state_dict = getState(request) + if state_dict.get('redirect'): + params = state_dict.get('params', {}) + redir = state_dict.get('redirect') + if params: + redir += '?' + urlencode(params) + + response = RedirectResponse(redir) + else: + # Go to the user profile on default + response = RedirectResponse(url_for('userDashboard.userProfile')) + + response.set_cookie(key=Cookies.DISCORD_TOKEN, value=jwt_token) + return response + +def saveAsState(state: dict) -> str: + state = json.dumps(state) + return base64.b64encode(state.encode()).decode() + +def getState(request: Request) -> dict: + state = request.cookies.get(Cookies.DISCORD_STATE, '') + if not state: + return {} + + state = base64.b64decode(state).decode() + return json.loads(state) diff --git a/app/auth/jwt.py b/app/auth/jwt.py index 1058c90..fd27889 100644 --- a/app/auth/jwt.py +++ b/app/auth/jwt.py @@ -1,139 +1,149 @@ from jose import jwt -from sqlalchemy.orm import Session +from loguru import logger +from jose.constants import Algorithms from fastapi.security import HTTPBearer from datetime import datetime, timedelta from fastapi import Request, HTTPException from jose.exceptions import JWTClaimsError, JWTError, ExpiredSignatureError -from app.models.mysql import User from app.core.config import settings from app.cache.redis import SessionStore from app.util.common import generateNonce from app.types.server import Cookie -from app.types.cache import UserKey from app.types.jwt import TokenData, JWTPayload +from app.types.cache import RedisTokenPrefix, UserKey -ALGORITHM = 'HS256' +ALGORITHM = Algorithms.HS256 -def create_jwt(data: TokenData, - session: Session | None = None, - userID: int | None = None) -> tuple[str, str]: +def create_jwt(data: TokenData) -> tuple[str, str]: """ - Create a JWT token that expires in 30 min. If the user does not exist - in the database, a new user will be created. - - Raises: - JWTError: If there is an error encoding the claims. + Create access and refresh JWT tokens. + If the user ID is provided, the database won't be queried. + If the user does not exist in the database and no user ID is provided, a new user will be created. + The nonce is stored in the cache for refresh token invalidation. """ - # If no user info was provided, we need to query it - if session and userID is None: - # Query the database to get the user object. Create a new user if needed. - user: User | None = session.query( - User).filter_by(address=data.sub).first() - if not user: - user = User(address=data.sub) - session.add(user) - session.commit() - userID = user.id - - if userID is None: - raise JWTError('Could not get user info.') - - nonce = generateNonce() - - # Create the access token - payload = JWTPayload( - exp=datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES), - id=userID, - nonce=nonce, - iat=datetime.utcnow(), - **data.model_dump(), - ).model_dump() - accessToken = jwt.encode(payload, settings.SECRET_KEY, algorithm=ALGORITHM) + nonce = None + if settings.JWT_USE_NONCE: + nonce = generateNonce() - # Create the refresh token - payload['exp'] = datetime.utcnow( - ) + timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES) - refreshToken = jwt.encode( - payload, settings.SECRET_KEY, algorithm=ALGORITHM) + access_token = create_token(data, nonce, settings.ACCESS_TOKEN_EXPIRE_MINUTES) + refresh_token = create_token(data, nonce, settings.REFRESH_TOKEN_EXPIRE_MINUTES) - # Save the nonce in redis - cache = SessionStore(data.sub, settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60) - cache.set(UserKey.NONCE, nonce) + # Save the nonce in the cache for refresh token invalidation, only if using nonce + if nonce: + set_nonce_in_cache(data.sub, nonce, settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60) - return accessToken, refreshToken + return access_token, refresh_token def verify_token(token: str) -> JWTPayload | None: """ - Decode a JWT token. + Decode a JWT token. """ try: - return JWTPayload(**jwt.decode( - token, - settings.SECRET_KEY, - algorithms=[ALGORITHM], - options={ - 'require_iat': True, - 'require_exp': True, - 'require_sub': True, - } - )) - except (JWTError, ExpiredSignatureError, JWTClaimsError): + payload = JWTPayload( + **jwt.decode( + token, + settings.SECRET_KEY, + algorithms=[ALGORITHM], + options={ + 'require_iat': True, + 'require_exp': True, + 'require_sub': True, + }, + ) + ) + if settings.JWT_USE_NONCE and not payload.nonce: + logger.error('Nonce not found in JWT payload.') + return None + return payload + except (JWTError, ExpiredSignatureError, JWTClaimsError) as e: + logger.error(f'Error while verifying JWT: {e}') return None class RequireJWT(HTTPBearer): """ - Custom FastAPI dependency for JWT authentication. - Returns the decoded JWT payload if the token is valid. + Custom FastAPI dependency for JWT authentication. + Returns the decoded JWT payload if the token is valid. """ + async def __call__(self, request: Request): credentials = await super(RequireJWT, self).__call__(request) refreshToken = request.cookies.get(Cookie.REFRESH_TOKEN, '') - if credentials: + if credentials and credentials.credentials: payload = verify_token(credentials.credentials) refreshTokenPayload = verify_token(refreshToken) - if not payload or not refreshTokenPayload or \ - not payload.nonce == refreshTokenPayload.nonce: - raise HTTPException( - status_code=403, detail='Invalid token or expired token.') - # Make sure the nonce is in the cache - # This is the invalidation mechanism for refresh tokens - if not _nonceInCache(refreshTokenPayload.sub, refreshTokenPayload.nonce): - raise HTTPException( - status_code=403, detail='Invalid token or expired token.') + print(payload) + print(refreshTokenPayload) - return payload + # Check if payloads are valid + if not (payload and refreshTokenPayload): + raise HTTPException(status_code=403, detail='Invalid token or expired token.') + # If using nonce, check nonce + if not validate_nonce(payload, refreshTokenPayload): + raise HTTPException(status_code=403, detail='Nonce validation failed.') + + return payload else: - raise HTTPException( - status_code=403, detail='Invalid authorization code.') + raise HTTPException(status_code=403, detail='Invalid authorization code.') def RequireRefreshToken(request: Request) -> JWTPayload: refreshToken = request.cookies.get(Cookie.REFRESH_TOKEN, '') refreshTokenPayload = verify_token(refreshToken) if not refreshTokenPayload: - raise HTTPException( - status_code=403, detail='Invalid token or expired token.') + raise HTTPException(status_code=403, detail='Invalid token or expired token.') - # Make sure the nonce is in the cache - # This is the invalidation mechanism for refresh tokens - if not _nonceInCache(refreshTokenPayload.sub, refreshTokenPayload.nonce): - raise HTTPException( - status_code=403, detail='Invalid token or expired token.') + # If using nonce, ensure it's in the cache + if settings.JWT_USE_NONCE and not is_nonce_in_cache( + refreshTokenPayload.sub, refreshTokenPayload.nonce + ): + raise HTTPException(status_code=403, detail='Invalid token or expired token.') return refreshTokenPayload -def _nonceInCache(address: str, nonce: str) -> bool: +def create_token(data: TokenData, nonce: str | None, expire_minutes: int) -> str: + payload = { + **data.model_dump(), + 'exp': datetime.utcnow() + timedelta(minutes=expire_minutes), + 'iat': datetime.utcnow(), + } + if settings.JWT_USE_NONCE and nonce: + payload['nonce'] = nonce + return jwt.encode(payload, settings.SECRET_KEY, algorithm=ALGORITHM) + + +def validate_nonce(payload: JWTPayload, refreshTokenPayload: JWTPayload) -> bool: + if not settings.JWT_USE_NONCE: + return True + if payload.nonce != refreshTokenPayload.nonce: + raise HTTPException(status_code=403, detail='Nonce mismatch.') + if not is_nonce_in_cache(refreshTokenPayload.sub, refreshTokenPayload.nonce): + raise HTTPException(status_code=403, detail='Invalid token or expired token.') + return True + + +def set_nonce_in_cache(user_id: str, nonce: str, expiration_time: int): + """ + Store nonce in cache with a specified expiration time. + """ + if settings.JWT_USE_NONCE: + cache = SessionStore(RedisTokenPrefix.USER, user_id, ttl=expiration_time) + cache.set(UserKey.NONCE, nonce) + + +def is_nonce_in_cache(user_id: str, nonce: str | None) -> bool: """ - Check if the nonce is in the cache. + Check if the nonce is in the cache. """ - cache = SessionStore(address) + if not nonce: + return False + cache = SessionStore(RedisTokenPrefix.USER, user_id) return cache.get(UserKey.NONCE) == nonce diff --git a/app/cache/redis.py b/app/cache/redis.py index 2ec23ee..f6cd7f9 100644 --- a/app/cache/redis.py +++ b/app/cache/redis.py @@ -9,18 +9,22 @@ class SessionStore: @classmethod def get_pool(cls): if cls._pool is None: - cls._pool = redis.ConnectionPool(host=settings.REDIS_HOST, - port=settings.REDIS_PORT, - password=settings.REDIS_PASSWORD, - connection_class=redis.SSLConnection) + cls._pool = redis.ConnectionPool( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + password=settings.REDIS_PASSWORD, + connection_class=redis.SSLConnection, + ) return cls._pool - def __init__(self, token: str, ttl: int = (60 * 60 * 4)): + def __init__(self, *tokens: str, ttl: int = (60 * 60 * 4)): """ - :param token: Used to create a session in redis. Key/value pairs are unique to this token. - :param ttl: Time to live in seconds. Defaults to 4 hours. + Params:\n + tokens - Used to create a session in redis. Key/value pairs are unique to this token. + Pass in multiple tokens and they will be joined into one.\n + ttl - Time to live in seconds. Defaults to 4 hours. """ - self.token = token + self.token = ':'.join(tokens) self.redis = redis.StrictRedis(connection_pool=self.get_pool()) self.ttl = ttl diff --git a/app/core/config.py b/app/core/config.py index 249943a..e8bf5b8 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -12,6 +12,7 @@ class EnvConfigSettings(BaseSettings): SECRET_KEY: str REFRESH_KEY: str PROFILING: bool = False + JWT_USE_NONCE: bool ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 # 30 minutes REFRESH_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 3 # 3 days BACKEND_CORS_ORIGINS: list[AnyHttpUrl] = [] diff --git a/app/types/cache.py b/app/types/cache.py index 5f3c56f..170f3ba 100644 --- a/app/types/cache.py +++ b/app/types/cache.py @@ -2,11 +2,16 @@ # string values, not Enum values. -class AuthSession: - ADDRESS = 'address' - SESSION_ID = 'session_id' - NONCE = 'nonce' +class RedisTokenPrefix: + """ + When creating a SessionStore, it is useful to have a prefix to avoid + collisions with other keys in Redis. + """ + USER = 'user' class UserKey: + """ + Keys used to store user data in Redis. + """ NONCE = 'nonce' diff --git a/app/types/jwt.py b/app/types/jwt.py index d091882..4b1af78 100644 --- a/app/types/jwt.py +++ b/app/types/jwt.py @@ -13,9 +13,7 @@ class JWTPayload(BaseModel): # Issued at iat: datetime # Unique hex number - nonce: str - # User ID - id: int + nonce: str | None = None class TokenData(BaseModel): diff --git a/app/types/server.py b/app/types/server.py index 2e0aca1..9ea89ed 100644 --- a/app/types/server.py +++ b/app/types/server.py @@ -33,4 +33,4 @@ def dict(self, *args, **kwargs) -> dict[str, Any]: Override the default dict method to exclude None values in the response """ kwargs.pop('exclude_none', None) - return super().dict(*args, exclude_none=True, **kwargs) + return super().model_dump(*args, exclude_none=True, **kwargs) diff --git a/loguru/_better_exceptions.py b/loguru/_better_exceptions.py index e5ea4ad..9338734 100644 --- a/loguru/_better_exceptions.py +++ b/loguru/_better_exceptions.py @@ -441,7 +441,6 @@ def _format_exception( # on the indentation; the preliminary context for "SyntaxError" is always indented, while # the Exception itself is not. This allows us to identify the correct index for the # exception message. - error_message_index = 0 for error_message_index, part in enumerate(exception_only): # noqa: B007 if not part.startswith(" "): break