generated from owl-corp/python-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add debug endpoints and implement token auth
Co-authored-by: Joe Banks <[email protected]>
- Loading branch information
1 parent
47f248c
commit fe6abd8
Showing
25 changed files
with
419 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
51 changes: 0 additions & 51 deletions
51
thallium-backend/migrations/versions/1723831312-ac28edf8dd84_users_and_products.py
This file was deleted.
Oops, something went wrong.
70 changes: 70 additions & 0 deletions
70
thallium-backend/migrations/versions/1724025217-bd897d0f21e1_add_products_users_vouchers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
""" | ||
Add products, users and vouchers. | ||
Revision ID: bd897d0f21e1 | ||
Revises: | ||
Create Date: 2024-08-18 23:53:37.211777+00:00 | ||
""" | ||
|
||
import sqlalchemy as sa | ||
from alembic import op | ||
|
||
# revision identifiers, used by Alembic. | ||
revision = "bd897d0f21e1" | ||
down_revision = None | ||
branch_labels = None | ||
depends_on = None | ||
|
||
|
||
def upgrade() -> None: | ||
"""Apply this migration.""" | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.create_table( | ||
"products", | ||
sa.Column("product_id", sa.Integer(), nullable=False), | ||
sa.Column("name", sa.String(), nullable=False), | ||
sa.Column("description", sa.String(), nullable=False), | ||
sa.Column("price", sa.Numeric(), nullable=False), | ||
sa.Column("image", sa.LargeBinary(), nullable=False), | ||
sa.Column("id", sa.Uuid(), server_default=sa.text("gen_random_uuid()"), nullable=False), | ||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), | ||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), | ||
sa.PrimaryKeyConstraint("product_id", "id", name=op.f("products_pk")), | ||
) | ||
op.create_table( | ||
"users", | ||
sa.Column("permissions", sa.Integer(), nullable=False), | ||
sa.Column("id", sa.Uuid(), server_default=sa.text("gen_random_uuid()"), nullable=False), | ||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), | ||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), | ||
sa.PrimaryKeyConstraint("id", name=op.f("users_pk")), | ||
) | ||
op.create_table( | ||
"vouchers", | ||
sa.Column("voucher_code", sa.String(), nullable=False), | ||
sa.Column("active", sa.Boolean(), nullable=False), | ||
sa.Column("balance", sa.Numeric(), nullable=False), | ||
sa.Column("id", sa.Uuid(), server_default=sa.text("gen_random_uuid()"), nullable=False), | ||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), | ||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), | ||
sa.PrimaryKeyConstraint("id", name=op.f("vouchers_pk")), | ||
) | ||
op.create_index( | ||
"ix_unique_active_voucher_code", | ||
"vouchers", | ||
["voucher_code"], | ||
unique=True, | ||
postgresql_where=sa.text("active"), | ||
) | ||
# ### end Alembic commands ### | ||
|
||
|
||
def downgrade() -> None: | ||
"""Revert this migration.""" | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.drop_index(op.f("vouchers_voucher_code_ix"), table_name="vouchers") | ||
op.drop_index("ix_unique_active_voucher_code", table_name="vouchers", postgresql_where=sa.text("active")) | ||
op.drop_table("vouchers") | ||
op.drop_table("users") | ||
op.drop_table("products") | ||
# ### end Alembic commands ### |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import asyncio | ||
|
||
from src.orm import Voucher | ||
from src.settings import Connections | ||
|
||
|
||
async def main() -> None: | ||
"""Seed the database with some test data.""" | ||
async with Connections.DB_SESSION_MAKER() as session, session.begin(): | ||
session.add_all( | ||
[ | ||
Voucher(voucher_code="k1p", balance="13.37", active=False), | ||
Voucher(voucher_code="k1p", balance="13.37", active=False), | ||
Voucher(voucher_code="k1p", balance="13.37"), | ||
] | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import logging | ||
import typing as t | ||
from datetime import UTC, datetime, timedelta | ||
from enum import IntFlag | ||
from uuid import uuid4 | ||
|
||
import jwt | ||
from fastapi import HTTPException, Request | ||
from fastapi.security import HTTPAuthorizationCredentials | ||
from fastapi.security.http import HTTPBase | ||
|
||
from src.dto.users import User, UserPermission | ||
from src.settings import CONFIG | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class UserTypes(IntFlag): | ||
"""All types of users.""" | ||
|
||
VOUCHER_USER = 2**0 | ||
REGULAR_USER = 2**1 | ||
|
||
|
||
class TokenAuth(HTTPBase): | ||
"""Ensure all requests with this auth enabled include an auth header with the expected token.""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
auto_error: bool = True, | ||
allow_vouchers: bool = False, | ||
allow_regular_users: bool = False, | ||
) -> None: | ||
super().__init__(scheme="token", auto_error=auto_error) | ||
self.allow_vouchers = allow_vouchers | ||
self.allow_regular_users = allow_regular_users | ||
|
||
async def __call__(self, request: Request) -> HTTPAuthorizationCredentials: | ||
"""Parse the token in the auth header, and check it matches with the expected token.""" | ||
creds: HTTPAuthorizationCredentials = await super().__call__(request) | ||
if creds.scheme.lower() != "token": | ||
raise HTTPException( | ||
status_code=401, | ||
detail="Incorrect scheme passed", | ||
) | ||
if self.allow_regular_users and creds.credentials == CONFIG.super_admin_token.get_secret_value(): | ||
request.state.user = User(user_id=uuid4(), permissions=~UserPermission(0)) | ||
return | ||
|
||
jwt_data = verify_jwt( | ||
creds.credentials, | ||
allow_vouchers=self.allow_vouchers, | ||
allow_regular_users=self.allow_regular_users, | ||
) | ||
if not jwt_data: | ||
raise HTTPException( | ||
status_code=403, | ||
detail="Invalid authentication credentials", | ||
) | ||
if jwt_data["iss"] == "thallium:user": | ||
request.state.user_id = jwt_data["sub"] | ||
else: | ||
request.state.voucher_id = jwt_data["sub"] | ||
|
||
|
||
def build_jwt( | ||
identifier: str, | ||
user_type: t.Literal["voucher", "user"], | ||
) -> str: | ||
"""Build & sign a jwt.""" | ||
return jwt.encode( | ||
payload={ | ||
"sub": identifier, | ||
"iss": f"thallium:{user_type}", | ||
"exp": datetime.now(tz=UTC) + timedelta(minutes=30), | ||
"nbf": datetime.now(tz=UTC) - timedelta(minutes=1), | ||
}, | ||
key=CONFIG.signing_key.get_secret_value(), | ||
) | ||
|
||
|
||
def verify_jwt( | ||
jwt_data: str, | ||
*, | ||
allow_vouchers: bool, | ||
allow_regular_users: bool, | ||
) -> dict | None: | ||
"""Return and verify the given JWT.""" | ||
issuers = [] | ||
if allow_vouchers: | ||
issuers.append("thallium:voucher") | ||
if allow_regular_users: | ||
issuers.append("thallium:user") | ||
try: | ||
return jwt.decode( | ||
jwt_data, | ||
key=CONFIG.signing_key.get_secret_value(), | ||
issuer=issuers, | ||
algorithms=("HS256",), | ||
options={"require": ["exp", "iss", "sub", "nbf"]}, | ||
) | ||
except jwt.InvalidIssuerError as e: | ||
raise HTTPException(403, "Your user type does not have access to this resource") from e | ||
except jwt.InvalidSignatureError as e: | ||
raise HTTPException(401, "Invalid JWT signature") from e | ||
except (jwt.DecodeError, jwt.MissingRequiredClaimError, jwt.InvalidAlgorithmError) as e: | ||
raise HTTPException(401, "Invalid JWT passed") from e | ||
except (jwt.ImmatureSignatureError, jwt.ExpiredSignatureError) as e: | ||
raise HTTPException(401, "JWT not valid for current time") from e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .login import VoucherClaim, VoucherLogin | ||
from .users import User | ||
from .vouchers import Voucher | ||
|
||
__all__ = ("LoginData", "User", "Voucher", "VoucherClaim", "VoucherLogin") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
class VoucherLogin(BaseModel): | ||
"""The data needed to login with a voucher.""" | ||
|
||
voucher_code: str | ||
|
||
|
||
class VoucherClaim(VoucherLogin): | ||
"""A JWT for a verified voucher.""" | ||
|
||
jwt: str |
Oops, something went wrong.