Skip to content

Commit

Permalink
* Update JWT implementation
Browse files Browse the repository at this point in the history
* Made nonce optional
* Added more docs
* Refactors
  • Loading branch information
FastestMolasses committed Nov 4, 2023
1 parent 0d7227d commit 9a0f9f8
Show file tree
Hide file tree
Showing 11 changed files with 312 additions and 113 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ DOMAIN=localhost
SECRET_KEY=
REFRESH_KEY=
PROFILING=0
JWT_USE_NONCE=0

# Backend
BACKEND_CORS_ORIGINS=["http://localhost:8000","http://localhost:5000"]
Expand Down
64 changes: 63 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
* 📝 [Loguru](https://github.com/Delgan/loguru) + [picologging](https://github.com/microsoft/picologging) for simplified and performant logging
* 🐳 Dockerized and includes AWS deployment flow
* 🗃️ Several database implementations with sample ORM models (MySQL, Postgres, Timescale) & migrations
* 🔐 JWT authentication and authorization
* 🔐 Optional JWT authentication and authorization
* 🌐 AWS Lambda functions support
* 🧩 Modularized features
* 📊 Prometheus metrics
Expand All @@ -44,6 +44,9 @@
* [Shell](#shell)
* [Migrations](#migrations)
* [Downgrade Migration](#downgrade-migration)
* [JWT Auth](#jwt-auth)
* [JWT Overview](#jwt-overview)
* [Modifying JWT Payload Fields](#modifying-jwt-payload-fields)
* [Project Structure](#project-structure)
* [Makefile Commands](#makefile-commands)
* [Contributing](#contributing)
Expand Down Expand Up @@ -220,6 +223,65 @@ Run this command to revert every migration back to the beginning.
alembic downgrade base
```
## JWT Implementation
In this FastAPI template, JSON Web Tokens (JWT) can be optionally utilized for authentication. This documentation section elucidates the JWT implementation and related functionalities.
### JWT Overview
The JWT implementation can be found in the module: app/auth/jwt.py. The primary functions include:
- Creating access and refresh JWT tokens.
- Verifying and decoding a given JWT token.
- Handling JWT-based authentication for FastAPI routes.
#### User Management
If a user associated with a JWT token is not found in the database, a new user will be created. This is managed by the get_or_create_user function. When a token is decoded and the corresponding user ID (sub field in the token) is not found, the system will attempt to create a new user with that ID.
#### Nonce Usage
A nonce is an arbitrary number that can be used just once. It's an optional field in the JWT token to ensure additional security. If a nonce is used:

- It is stored in Redis for the duration of the refresh token's validity.
- It must match between access and refresh tokens to ensure their pairing.
- Its presence in Redis is verified before the token is considered valid.
Enabling nonce usage provides an additional layer of security against token reuse, but requires Redis to function.
### Modifying JWT Payload Fields
The JWT token payload structure is defined in `app/types/jwt.py`` under the JWTPayload class. If you wish to add more fields to the JWT token payload:
1. Update the TokenData and JWTPayload class in `app/types/jwt.py`` by adding the desired fields.
```python
class JWTPayload(BaseModel):
# ... existing fields ...
new_field: Type
class TokenData(BaseModel):
# ... existing fields ...
new_field: Type
```
TokenData is separated from JWTPayload to make it clear what is automatically filled in and what is manually added. Both classes must be updated to include the new fields.
2. Wherever the token is created, update the payload to include the new fields.
```python
from app.auth.jwt import create_jwt
from app.types.jwt import TokenData
payload = TokenData(
sub='user_id_1',
field1='value1',
# ... all fields ...
)
access_token, refresh_token = create_jwt(payload)
```
Remember, the JWT token has a size limit. The more data you include, the bigger your token becomes, so ensure that you only include essential data in the token payload.
## Project Structure
```
Expand Down
26 changes: 14 additions & 12 deletions app/api/endpoints/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from jose import JWTError
from loguru import logger
from fastapi import APIRouter, Depends
from fastapi.responses import ORJSONResponse

Expand All @@ -14,16 +15,16 @@


@router.get('/login')
async def login(address: str, response: ORJSONResponse) -> ServerResponse[str]:
session = MySqlSession()
token = TokenData(sub=address)
async def login(response: ORJSONResponse) -> ServerResponse[str]:
# TODO: Look up the user here, or create one if they don't exist
# session = MySqlSession()
token = TokenData(sub='example_user_id')

try:
accessToken, refreshToken = create_jwt(token, session)
accessToken, refreshToken = create_jwt(token)
except JWTError as e:
return ServerResponse(status='error', message=f'JWT Error: {e}')
finally:
session.close()
logger.error(f'JWT Error during login: {e}')
return ServerResponse(status='error', message='JWT Error, try again')

# Save the refresh token in an HTTPOnly cookie
response.set_cookie(
Expand All @@ -37,15 +38,16 @@ async def login(address: str, response: ORJSONResponse) -> ServerResponse[str]:


@router.get('/refresh')
async def refresh(response: ORJSONResponse,
payload: JWTPayload = Depends(RequireRefreshToken)) -> ServerResponse[str]:
async def refresh(
response: ORJSONResponse, payload: JWTPayload = Depends(RequireRefreshToken)
) -> ServerResponse[str]:
token = TokenData(sub=payload.sub)

try:
accessToken, refreshToken = create_jwt(
token, userID=payload.id)
accessToken, refreshToken = create_jwt(token)
except JWTError as e:
return ServerResponse(status='error', message=f'JWT Error: {e}')
logger.error(f'JWT Error during login: {e}')
return ServerResponse(status='error', message='JWT Error, try again.')

# Save the refresh token in an HTTPOnly cookie
response.set_cookie(
Expand Down
117 changes: 117 additions & 0 deletions app/api/endpoints/discord_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import jwt
import json
import base64

from fastapi import FastAPI, Depends, HTTPException, Request, Response
from fastapi.security import OAuth2PasswordBearer
from fastapi.templating import Jinja2Templates
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from typing import Optional

app = FastAPI()
templates = Jinja2Templates(directory="templates") # Assuming your templates are in a 'templates' directory

# JWT secret key
SECRET_KEY = "your_secret_key_here"

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

def createOAuthSession():
# Your logic for creating an OAuth session
pass

def generateKey():
# Your logic for generating a key
pass

def removeStripeCookies():
# Your logic to remove Stripe cookies
pass

def userHasDiscordAuthToken(token: str = Depends(oauth2_scheme)) -> bool:
# Decode JWT token and verify user has Discord auth token
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
return bool(payload.get("discord_auth"))
except:
return False

@app.get("/signin")
async def signin(request: Request, token: str = Depends(oauth2_scheme)):
if userHasDiscordAuthToken(token):
# Redirect to the license page
return RedirectResponse(url_for('auth.license'))

oauth = createOAuthSession()
state = {
'nonce': generateKey(),
}
nextUrl = request.query_params.get('next')
if nextUrl:
state.update({'redirect': nextUrl})

state = saveAsState(state)
loginUrl, state = oauth.authorization_url(
Config.DISCORD_AUTHORIZE_URL, state=state)
response = templates.TemplateResponse('auth/authWithDiscord.html', {
"request": request,
"title": 'Waffler Sign In',
"loginUrl": loginUrl
})
response.set_cookie(key=Cookies.DISCORD_STATE, value=state)

removeStripeCookies()
return response

@app.get("/logout")
def logout(request: Request):
response = RedirectResponse(url=request.query_params.get('next', '/'))
# Your logic to remove all cookies, for example:
response.delete_cookie(key="your_cookie_name")
return response

@app.get("/oauth_callback")
def discordOAuthCallback(request: Request):
state = request.query_params.get('state', '')
if not state or state != request.cookies.get(Cookies.DISCORD_STATE):
return RedirectResponse('/')

oauth = createOAuthSession()
try:
token = oauth.fetch_token(
Config.DISCORD_TOKEN_URL,
client_secret=Config.DISCORD_CLIENT_SECRET,
authorization_response=request.url,
)
except Exception:
return RedirectResponse('/')

jwt_token = jwt.encode({"discord_auth": token}, SECRET_KEY, algorithm="HS256")

state_dict = getState(request)
if state_dict.get('redirect'):
params = state_dict.get('params', {})
redir = state_dict.get('redirect')
if params:
redir += '?' + urlencode(params)

response = RedirectResponse(redir)
else:
# Go to the user profile on default
response = RedirectResponse(url_for('userDashboard.userProfile'))

response.set_cookie(key=Cookies.DISCORD_TOKEN, value=jwt_token)
return response

def saveAsState(state: dict) -> str:
state = json.dumps(state)
return base64.b64encode(state.encode()).decode()

def getState(request: Request) -> dict:
state = request.cookies.get(Cookies.DISCORD_STATE, '')
if not state:
return {}

state = base64.b64decode(state).decode()
return json.loads(state)
Loading

0 comments on commit 9a0f9f8

Please sign in to comment.