Skip to content

Commit

Permalink
Refactor: Use BaseSettings of pydantic for configuration
Browse files Browse the repository at this point in the history
This is the recommended way as per the documentation of `fastapi` as it
can be easily made available to routes.
  • Loading branch information
sphuber committed May 8, 2023
1 parent 3f14ea2 commit 520bd52
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 17 deletions.
33 changes: 27 additions & 6 deletions aiida_restapi/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,31 @@
# -*- coding: utf-8 -*-
"""Configuration of API"""
# to get a string like this run:
# openssl rand -hex 32
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
"""Configuration settings for the application."""
from functools import lru_cache

from pydantic import BaseSettings


class Settings(BaseSettings):
"""Configuration settings for the application."""

secret_key: str = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
"""The secret key used to create access tokens."""

secret_key_algoritm: str = "HS256"
"""The algorithm used to create access tokens."""

access_token_expire_minutes: int = 30
"""The number of minutes an access token remains valid."""


@lru_cache()
def get_settings():
"""Return the configuration settings for the application.
This function is cached and should be used preferentially over constructing ``Settings`` directly.
"""
return Settings()


fake_users_db = {
"[email protected]": {
Expand Down
30 changes: 19 additions & 11 deletions aiida_restapi/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from passlib.context import CryptContext
from pydantic import BaseModel

from aiida_restapi import config
from aiida_restapi.models import User

from ..config import Settings, fake_users_db, get_settings


class Token(BaseModel):
access_token: str
Expand Down Expand Up @@ -67,32 +68,40 @@ def authenticate_user(fake_db: dict, email: str, password: str) -> Optional[User
return user


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
def create_access_token(
settings: Settings, data: dict, expires_delta: Optional[timedelta] = None
) -> str:
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, config.SECRET_KEY, algorithm=config.ALGORITHM)
encoded_jwt = jwt.encode(
to_encode, settings.secret_key, algorithm=settings.secret_key_algoritm
)
return encoded_jwt


async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
async def get_current_user(
token: str = Depends(oauth2_scheme), settings: Settings = Depends(get_settings)
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
payload = jwt.decode(
token, settings.secret_key, algorithms=[settings.secret_key_algoritm]
)
email: str = payload.get("sub")
if email is None:
raise credentials_exception
token_data = TokenData(email=email)
except JWTError:
raise credentials_exception # pylint: disable=raise-missing-from
user = get_user(config.fake_users_db, email=token_data.email)
user = get_user(fake_users_db, email=token_data.email)
if user is None:
raise credentials_exception
return user
Expand All @@ -109,19 +118,18 @@ async def get_current_active_user(
@router.post("/token", response_model=Token)
async def login_for_access_token(
form_data: OAuth2PasswordRequestForm = Depends(),
settings: Settings = Depends(get_settings),
) -> Dict[str, Any]:
user = authenticate_user(
config.fake_users_db, form_data.username, form_data.password
)
user = authenticate_user(fake_users_db, form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES)
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": user.email}, expires_delta=access_token_expires
settings, data={"sub": user.email}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}

Expand Down

0 comments on commit 520bd52

Please sign in to comment.