diff --git a/cruds/users/auth.py b/cruds/users/auth.py index f589e7c..9a12a53 100644 --- a/cruds/users/auth.py +++ b/cruds/users/auth.py @@ -1,59 +1,33 @@ -from fastapi.params import Security -from fastapi.security import HTTPBearer -from fastapi.security.http import HTTPAuthorizationCredentials -from db.models import Token, User +from db.models import User from db import get_db -from os import stat, environ -from schemas.user import TokenData, TokenResponse +from os import environ +from schemas.user import TokenData from fastapi.exceptions import HTTPException -from fastapi.security import OAuth2PasswordBearer -from fastapi import Depends, status +from fastapi import Depends, status, Cookie from cruds.users import get_user from sqlalchemy.orm.session import Session from jose import JWTError, jwt from passlib.context import CryptContext -from typing import Optional +from typing import Optional, Union from datetime import datetime, timedelta -from schemas.user import Token as TokenSchema from utils.discord import ( DiscordAccessTokenResponse, discord_fetch_user, - discord_refresh_token, discord_verify_user_belongs_to_valid_guild, ) pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/token") +AUTH_COOKIE_KEY = environ.get("AUTH_COOKIE_KEY") SECRET_KEY = environ.get("TOKEN_SECRET_KEY") ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 -def verify_password(plain_password: str, hashed_password: str): - return pwd_context.verify(plain_password, hashed_password) - - -def get_password_hash(password: str) -> str: - return pwd_context.hash(password) - - -def authenticate_user(db: Session, email: str, password: str): - user = get_user(db, email) - - if not user: - raise HTTPException(status_code=403, detail="Authorization failed") - - if not verify_password(password, user.password_hash): - raise HTTPException(status_code=403, detail="Password incorrect") - - return user - - def authenticate_discord_user( discord_token: DiscordAccessTokenResponse, db: Session = Depends(get_db) -) -> TokenResponse: +) -> str: discord_user = discord_fetch_user(discord_token.access_token) u = db.query(User).filter(User.discord_user_id == discord_user.id).first() @@ -83,82 +57,20 @@ def authenticate_discord_user( ) db.commit() - token_response = create_tokens(u, db) - return token_response - - -def create_tokens(user: User, db: Session = Depends(get_db)) -> TokenResponse: - new_refresh_token = create_refresh_token(user, db) - new_access_token = create_access_token(new_refresh_token.user) - - t = TokenResponse( - refresh_token=new_refresh_token.refresh_token, - access_token=new_access_token, - expired_at=new_refresh_token.expired_at.isoformat(), - ) - return t - - -def renew_token(refresh_token_str: str, db: Session = Depends(get_db)) -> TokenResponse: - token, old_t = verify_refresh_token(refresh_token_str, db) - - # force expire old refresh token - old_t.expired_at = datetime.now().isoformat() - db.commit() - - u = db.query(User).filter(User.id == token.user.id).first() - - return create_tokens(u, db) - - -def create_refresh_token( - user: User, - db: Session = Depends(get_db), - expired_delta: timedelta = timedelta(days=15), -) -> TokenSchema: - t = Token(refresh_token=None, user_id=user.id) - if expired_delta: - t.expired_at = datetime.now() + expired_delta - db.add(t) - db.commit() - token = TokenSchema.from_orm(t) + token = create_access_token(u) return token -def verify_refresh_token( - refresh_token: str, db: Session = Depends(get_db) -) -> tuple[TokenSchema, Token]: - t = db.query(Token).filter(Token.refresh_token == refresh_token).first() - if t == None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Specified refresh token not found", - ) - - token = TokenSchema.from_orm(t) - if token.has_expired(): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Specified refresh token has expired", - ) - - return [token, t] - - -def create_access_token(user: User, expires_delta: Optional[timedelta] = None): +def create_access_token( + user: User, expires_delta: Optional[timedelta] = timedelta(minutes=15) +) -> str: to_encode = {"sub": user.email, "token_type": "bearer"} - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=15) + expire = datetime.utcnow() + expires_delta to_encode["exp"] = expire encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt -security = HTTPBearer(auto_error=False) - - class GetCurrentUser: def __init__(self, auto_error: bool = True) -> None: self.auto_error = auto_error @@ -166,21 +78,18 @@ def __init__(self, auto_error: bool = True) -> None: def __call__( self, db: Session = Depends(get_db), - credentials: HTTPAuthorizationCredentials = Security(security), + token: Union[str, None] = Cookie(default=None, alias=AUTH_COOKIE_KEY), ): try: - if credentials == None: + if token == None: return self.handle_error(detail="Credential is missing") - if credentials.scheme != "Bearer": - return self.handle_error(detail="Invalid scheme") - payload = jwt.decode( - credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM] - ) + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) email: str = payload.get("sub") if email is None: return self.handle_error(detail="Email is missing") token_data = TokenData(email=email) + print("token", token_data) except JWTError: return self.handle_error(detail="JWT error") user = get_user(db, token_data.email) diff --git a/docker-compose.yml b/docker-compose.yml index 6e43f5e..f6afc0e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -14,7 +14,7 @@ services: volumes: - .:/api ports: - - "5000:8000" + - "8000:8000" env_file: - .env depends_on: diff --git a/routers/auth/__init__.py b/routers/auth/__init__.py index ca1b30d..351bcf7 100644 --- a/routers/auth/__init__.py +++ b/routers/auth/__init__.py @@ -1,13 +1,7 @@ -from os import access import os -from starlette.responses import RedirectResponse -from cruds.users.auth import authenticate_discord_user, get_password_hash, renew_token -from fastapi.exceptions import HTTPException -from cruds.users import create_user, get_user -from schemas.user import RefreshTokenExchangeRequest, TokenResponse, UserCreateRequest +from starlette.responses import Response, RedirectResponse +from cruds.users.auth import authenticate_discord_user from fastapi.params import Depends -from db.models import User as UserModel -from schemas.user import User from db import get_db from sqlalchemy.orm import Session from fastapi import APIRouter @@ -18,53 +12,32 @@ CLIENT_ID = os.environ.get("DISCORD_CLIENT_ID") CLIENT_SECRET = os.environ.get("DISCORD_CLIENT_SECRET") HOST_URL = os.environ.get("HOST_URL") +AUTH_COOKIE_KEY = os.environ.get("AUTH_COOKIE_KEY") auth_router = APIRouter() oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -@auth_router.post("/sign_up", response_model=User) -def sign_up(user_request: UserCreateRequest, db: Session = Depends(get_db)): - existing_user = get_user(db, user_request.email) - if existing_user: - raise HTTPException(status_code=400, detail="User has already exists") - - hashed_password = get_password_hash(user_request.password) - user = UserModel( - name=user_request.name, - email=user_request.email, - password_hash=hashed_password, - display_name=user_request.display_name, - avatar_url=user_request.avatar_url, - ) - - created_user = create_user(db, user) - - if not created_user: - raise HTTPException(status_code=500, detail="Couldn't create user") - - return User.from_orm(created_user) - - -@auth_router.post("/token", response_model=TokenResponse) -async def refresh_token_exchange( - token_request: RefreshTokenExchangeRequest, db: Session = Depends(get_db) -): - token = renew_token(token_request.refresh_token, db) - return token - - @auth_router.get("/discord") async def discord_login_redirect(): redirect_url = f"https://discord.com/api/oauth2/authorize?client_id={CLIENT_ID}&redirect_uri={HOST_URL}/api/v1/auth/discord/callback&response_type=code&scope=identify email guilds" return RedirectResponse(url=redirect_url) -@auth_router.get("/discord/callback", response_model=TokenResponse) +@auth_router.get("/discord/callback") async def discord_callback(code: str = "", db: Session = Depends(get_db)): r = discord_exchange_code(code) - token_response = authenticate_discord_user(r, db) - return RedirectResponse( - f"{FRONTEND_HOST_URL}/discord?access_token={token_response.access_token}&refresh_token={token_response.refresh_token}&expired_at={token_response.expired_at}" + token = authenticate_discord_user(r, db) + res = RedirectResponse( + FRONTEND_HOST_URL, ) + res.set_cookie(AUTH_COOKIE_KEY, token, httponly=True, samesite="strict") + return res + + +@auth_router.post("/logout") +async def logout(): + response = Response() + response.set_cookie(AUTH_COOKIE_KEY, "") + return response