Skip to content

Commit

Permalink
Use cookie to store JWT
Browse files Browse the repository at this point in the history
  • Loading branch information
rkun123 authored and PigeonsHouse committed Apr 15, 2023
1 parent 6d0da33 commit b74c157
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 151 deletions.
123 changes: 16 additions & 107 deletions cruds/users/auth.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,33 @@
from fastapi.params import Security
from fastapi.security import HTTPBearer
from fastapi.security.http import HTTPAuthorizationCredentials
from db.models import Token, User
from db.models import User
from db import get_db
from os import stat, environ
from schemas.user import TokenData, TokenResponse
from os import environ
from schemas.user import TokenData
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2PasswordBearer
from fastapi import Depends, status
from fastapi import Depends, status, Cookie
from cruds.users import get_user
from sqlalchemy.orm.session import Session
from jose import JWTError, jwt
from passlib.context import CryptContext
from typing import Optional
from typing import Optional, Union
from datetime import datetime, timedelta
from schemas.user import Token as TokenSchema
from utils.discord import (
DiscordAccessTokenResponse,
discord_fetch_user,
discord_refresh_token,
discord_verify_user_belongs_to_valid_guild,
)

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/token")
AUTH_COOKIE_KEY = environ.get("AUTH_COOKIE_KEY")

SECRET_KEY = environ.get("TOKEN_SECRET_KEY")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30


def verify_password(plain_password: str, hashed_password: str):
return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
return pwd_context.hash(password)


def authenticate_user(db: Session, email: str, password: str):
user = get_user(db, email)

if not user:
raise HTTPException(status_code=403, detail="Authorization failed")

if not verify_password(password, user.password_hash):
raise HTTPException(status_code=403, detail="Password incorrect")

return user


def authenticate_discord_user(
discord_token: DiscordAccessTokenResponse, db: Session = Depends(get_db)
) -> TokenResponse:
) -> str:
discord_user = discord_fetch_user(discord_token.access_token)

u = db.query(User).filter(User.discord_user_id == discord_user.id).first()
Expand Down Expand Up @@ -83,104 +57,39 @@ def authenticate_discord_user(
)
db.commit()

token_response = create_tokens(u, db)
return token_response


def create_tokens(user: User, db: Session = Depends(get_db)) -> TokenResponse:
new_refresh_token = create_refresh_token(user, db)
new_access_token = create_access_token(new_refresh_token.user)

t = TokenResponse(
refresh_token=new_refresh_token.refresh_token,
access_token=new_access_token,
expired_at=new_refresh_token.expired_at.isoformat(),
)
return t


def renew_token(refresh_token_str: str, db: Session = Depends(get_db)) -> TokenResponse:
token, old_t = verify_refresh_token(refresh_token_str, db)

# force expire old refresh token
old_t.expired_at = datetime.now().isoformat()
db.commit()

u = db.query(User).filter(User.id == token.user.id).first()

return create_tokens(u, db)


def create_refresh_token(
user: User,
db: Session = Depends(get_db),
expired_delta: timedelta = timedelta(days=15),
) -> TokenSchema:
t = Token(refresh_token=None, user_id=user.id)
if expired_delta:
t.expired_at = datetime.now() + expired_delta
db.add(t)
db.commit()
token = TokenSchema.from_orm(t)
token = create_access_token(u)
return token


def verify_refresh_token(
refresh_token: str, db: Session = Depends(get_db)
) -> tuple[TokenSchema, Token]:
t = db.query(Token).filter(Token.refresh_token == refresh_token).first()
if t == None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Specified refresh token not found",
)

token = TokenSchema.from_orm(t)
if token.has_expired():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Specified refresh token has expired",
)

return [token, t]


def create_access_token(user: User, expires_delta: Optional[timedelta] = None):
def create_access_token(
user: User, expires_delta: Optional[timedelta] = timedelta(minutes=15)
) -> str:
to_encode = {"sub": user.email, "token_type": "bearer"}
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
expire = datetime.utcnow() + expires_delta
to_encode["exp"] = expire
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


security = HTTPBearer(auto_error=False)


class GetCurrentUser:
def __init__(self, auto_error: bool = True) -> None:
self.auto_error = auto_error

def __call__(
self,
db: Session = Depends(get_db),
credentials: HTTPAuthorizationCredentials = Security(security),
token: Union[str, None] = Cookie(default=None, alias=AUTH_COOKIE_KEY),
):
try:
if credentials == None:
if token == None:
return self.handle_error(detail="Credential is missing")
if credentials.scheme != "Bearer":
return self.handle_error(detail="Invalid scheme")

payload = jwt.decode(
credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM]
)
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
email: str = payload.get("sub")
if email is None:
return self.handle_error(detail="Email is missing")
token_data = TokenData(email=email)
print("token", token_data)
except JWTError:
return self.handle_error(detail="JWT error")
user = get_user(db, token_data.email)
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ services:
volumes:
- .:/api
ports:
- "5000:8000"
- "8000:8000"
env_file:
- .env
depends_on:
Expand Down
59 changes: 16 additions & 43 deletions routers/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
from os import access
import os
from starlette.responses import RedirectResponse
from cruds.users.auth import authenticate_discord_user, get_password_hash, renew_token
from fastapi.exceptions import HTTPException
from cruds.users import create_user, get_user
from schemas.user import RefreshTokenExchangeRequest, TokenResponse, UserCreateRequest
from starlette.responses import Response, RedirectResponse
from cruds.users.auth import authenticate_discord_user
from fastapi.params import Depends
from db.models import User as UserModel
from schemas.user import User
from db import get_db
from sqlalchemy.orm import Session
from fastapi import APIRouter
Expand All @@ -18,53 +12,32 @@
CLIENT_ID = os.environ.get("DISCORD_CLIENT_ID")
CLIENT_SECRET = os.environ.get("DISCORD_CLIENT_SECRET")
HOST_URL = os.environ.get("HOST_URL")
AUTH_COOKIE_KEY = os.environ.get("AUTH_COOKIE_KEY")

auth_router = APIRouter()

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


@auth_router.post("/sign_up", response_model=User)
def sign_up(user_request: UserCreateRequest, db: Session = Depends(get_db)):
existing_user = get_user(db, user_request.email)
if existing_user:
raise HTTPException(status_code=400, detail="User has already exists")

hashed_password = get_password_hash(user_request.password)
user = UserModel(
name=user_request.name,
email=user_request.email,
password_hash=hashed_password,
display_name=user_request.display_name,
avatar_url=user_request.avatar_url,
)

created_user = create_user(db, user)

if not created_user:
raise HTTPException(status_code=500, detail="Couldn't create user")

return User.from_orm(created_user)


@auth_router.post("/token", response_model=TokenResponse)
async def refresh_token_exchange(
token_request: RefreshTokenExchangeRequest, db: Session = Depends(get_db)
):
token = renew_token(token_request.refresh_token, db)
return token


@auth_router.get("/discord")
async def discord_login_redirect():
redirect_url = f"https://discord.com/api/oauth2/authorize?client_id={CLIENT_ID}&redirect_uri={HOST_URL}/api/v1/auth/discord/callback&response_type=code&scope=identify email guilds"
return RedirectResponse(url=redirect_url)


@auth_router.get("/discord/callback", response_model=TokenResponse)
@auth_router.get("/discord/callback")
async def discord_callback(code: str = "", db: Session = Depends(get_db)):
r = discord_exchange_code(code)
token_response = authenticate_discord_user(r, db)
return RedirectResponse(
f"{FRONTEND_HOST_URL}/discord?access_token={token_response.access_token}&refresh_token={token_response.refresh_token}&expired_at={token_response.expired_at}"
token = authenticate_discord_user(r, db)
res = RedirectResponse(
FRONTEND_HOST_URL,
)
res.set_cookie(AUTH_COOKIE_KEY, token, httponly=True, samesite="strict")
return res


@auth_router.post("/logout")
async def logout():
response = Response()
response.set_cookie(AUTH_COOKIE_KEY, "")
return response

0 comments on commit b74c157

Please sign in to comment.