Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
feat(auth): add database method for retrieving by field value
Browse files Browse the repository at this point in the history
Signed-off-by: CyberFlame <[email protected]>
  • Loading branch information
CyberFlameGO committed Oct 18, 2023
1 parent 68d8499 commit f7d3767
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 55 deletions.
48 changes: 43 additions & 5 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from datetime import UTC, datetime
from datetime import UTC, datetime, timedelta
from enum import Enum
from typing import Annotated, Tuple, Type

import sentry_sdk
from fastapi import Depends, FastAPI, Request, status
from fastapi import Depends, FastAPI, Request, status, HTTPException
from fastapi.responses import HTMLResponse, RedirectResponse, Response
from fastapi.staticfiles import StaticFiles
from fastapi.security import OAuth2PasswordBearer
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel

import models
import utils
from models import Base, Game
from models import Base, Game, PydanticUser, Token
from utils import ACCESS_TOKEN_EXPIRE_MINUTES, authenticate_user, create_access_token, get_current_active_user

sentry_sdk.init(
dsn="https://[email protected]/4504873492480000",
Expand Down Expand Up @@ -206,5 +207,42 @@ async def get_match(request: Request, match_id: int, session: Session):
:param match_id:
:return:
"""
match = await db.retrieve(session, models.Match, match_id)
try:
match = await db.retrieve(session, models.Match, match_id)
except NoResultFound:
# TODO: adjust with a proper page regarding no match found with id
return Response(status_code=status.HTTP_404_NOT_FOUND)
return templates.TemplateResponse("matches.html", {"request": request, "id": match_id, "match": match})


@app.post("/token", response_model=Token)
async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
session: Session,
):
user = await authenticate_user(session, form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}


@app.get("/users/me/", response_model=PydanticUser)
async def read_users_me(
current_user: Annotated[PydanticUser, Depends(get_current_active_user)]
):
return current_user


@app.get("/users/me/items/")
async def read_own_items(
current_user: Annotated[PydanticUser, Depends(get_current_active_user)]
):
return [{"item_id": "Foo", "owner": current_user.username}]
6 changes: 3 additions & 3 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
"""

from .db_utils import * # noqa F401
from .auth import * # noqa F401
57 changes: 15 additions & 42 deletions utils/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from jose import JWTError, jwt
from passlib.context import CryptContext

from models import User
from models.pydantic import PydanticUser, Token, TokenData, UserInDB
from utils import Database

# to get a string like this run:
# openssl rand -hex 32
Expand Down Expand Up @@ -41,14 +43,17 @@ def get_password_hash(password):
return pwd_context.hash(password)


def get_user(db, username: str):
if username in db:
user_dict = db[username]
return UserInDB(**user_dict)
async def get_user(session, username: str):
data = await Database.retrieve_by_field(session, User, User.username, username)
data_dict = data.__dict__
if username in data_dict:
return UserInDB(**data_dict)
# user_dict = db[username]
# return UserInDB(**user_dict)


def authenticate_user(fake_db, username: str, password: str):
user = get_user(fake_db, username)
async def authenticate_user(session, username: str, password: str):
user = await get_user(session, username)
if not user:
return False
if not verify_password(password, user.hashed_password):
Expand All @@ -59,15 +64,15 @@ def authenticate_user(fake_db, username: str, password: str):
def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
expire = datetime.now(UTC) + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
expire = datetime.now(UTC) + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
async def get_current_user(session, token: Annotated[str, Depends(oauth2_scheme)]):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
Expand All @@ -81,7 +86,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
token_data = TokenData(username=username)
except JWTError:
raise credentials_exception
user = get_user(fake_users_db, username=token_data.username)
user = get_user(session, username=token_data.username)
if user is None:
raise credentials_exception
return user
Expand All @@ -93,35 +98,3 @@ async def get_current_active_user(
if current_user.disabled:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user


@app.post("/token", response_model=Token)
async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
):
user = authenticate_user(fake_users_db, form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}


@app.get("/users/me/", response_model=PydanticUser)
async def read_users_me(
current_user: Annotated[PydanticUser, Depends(get_current_active_user)]
):
return current_user


@app.get("/users/me/items/")
async def read_own_items(
current_user: Annotated[PydanticUser, Depends(get_current_active_user)]
):
return [{"item_id": "Foo", "owner": current_user.username}]
28 changes: 23 additions & 5 deletions utils/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from models import Base

type base_type = Type[Base]

class Database(object):
"""
Expand All @@ -35,6 +36,8 @@ def __init__(self, db_name: str):
async def connect(self) -> None:
"""
Initializes the database connection - creates sessionmaker and engine
'connect' is a misleading term here, but calling it 'init' would be even more misleading so I've opted for this
terminology.
:return:
"""
self.engine: AsyncEngine = create_async_engine(f"sqlite+aiosqlite:///{self._db_name}")
Expand Down Expand Up @@ -105,7 +108,7 @@ async def update(session: AsyncSession, model: Type[Base], identifier: int, data
raise DatabaseError(f"Exception encountered whilst executing: {e}")

@staticmethod
async def retrieve(session: AsyncSession, model: Type[Base], identifier: int) -> Optional[Base]:
async def retrieve(session: AsyncSession, model: base_type, identifier: int) -> base_type:
"""
Retrieves a record by primary key from a table in the database TODO: adjust return type
TODO: Perhaps use get_one instead of get -
Expand All @@ -115,10 +118,25 @@ async def retrieve(session: AsyncSession, model: Type[Base], identifier: int) ->
:param model:
:return:
"""
return await session.get(model, identifier)
return await session.get_one(model, identifier)

@staticmethod
async def dump_all(session: AsyncSession, model: Type[Base]) -> Sequence[Base]:
async def retrieve_by_field(session: AsyncSession, model: base_type, field, identifier) -> base_type:
"""
Retrieves a record by field value from a table in the database
Intended to retrieve only the first value (use-case: table with UNIQUE constraint)
:param session:
:param model:
:param field: The field in model.field format
:param identifier:
:return:
"""
statement = select(model).where(field == identifier)
executed = await session.execute(statement)
return executed.scalar_one()

@staticmethod
async def dump_all(session: AsyncSession, model: base_type) -> Sequence[Base]:
"""
Dumps all records for a model in the database TODO: finish writing function
:param session:
Expand Down Expand Up @@ -146,7 +164,7 @@ async def dump_by_field_descending(session: AsyncSession, field, label, limit: O
return executed.all()

@staticmethod
async def remove_record(session: AsyncSession, model: Type[Base], identifier: int):
async def remove_record(session: AsyncSession, model: base_type, identifier: int):
"""
Removes a record from the database
:param session:
Expand All @@ -166,7 +184,7 @@ async def remove_record(session: AsyncSession, model: Type[Base], identifier: in
raise DatabaseError(f"Exception encountered whilst executing: {e}")

@staticmethod
async def has_existed(session: AsyncSession, model: Type[Base], identifier: int) -> bool:
async def has_existed(session: AsyncSession, model: base_type, identifier: int) -> bool:
"""
Checks if a record has existed in the database, relying on the auto-incrementing primary key.
There are various flaws with this method/implementation, but they are tolerable considering the method is
Expand Down

0 comments on commit f7d3767

Please sign in to comment.