From 04ca5b1c540b25bb2844bd58b4fdd6fc1ee5e2c5 Mon Sep 17 00:00:00 2001 From: Chris Lovering Date: Tue, 27 Aug 2024 22:21:52 +0100 Subject: [PATCH] Add debug endpoint for creating users --- thallium-backend/src/routes/debug.py | 43 ++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/thallium-backend/src/routes/debug.py b/thallium-backend/src/routes/debug.py index 4d81741..691ac3f 100644 --- a/thallium-backend/src/routes/debug.py +++ b/thallium-backend/src/routes/debug.py @@ -1,15 +1,19 @@ import logging +from datetime import UTC, datetime -from fastapi import APIRouter +import argon2 +from fastapi import APIRouter, HTTPException from sqlalchemy import select +from sqlalchemy.exc import IntegrityError from src.auth import build_jwt -from src.dto import Voucher -from src.orm import Voucher as DBVoucher +from src.dto import UserPermission, Voucher +from src.orm import User as DBUser, Voucher as DBVoucher from src.settings import DBSession, PrintfulClient router = APIRouter(tags=["debug"], prefix="/debug") log = logging.getLogger(__name__) +ph = argon2.PasswordHasher() @router.get("/templates") @@ -54,3 +58,36 @@ async def get_vouchers(db: DBSession, *, only_active: bool = True) -> list[Vouch async def get_user_jwt(user_id: str) -> str: """Return the user_id's JWT.""" return build_jwt(user_id, "user") + + +@router.post("/user") +async def create_user( # noqa: PLR0913 + db: DBSession, + username: str, + password: str, + *, + require_password_change: bool = True, + password_reset_code: str | None = None, + active: bool = True, + permissions: int = ~UserPermission(0), +) -> dict: + """Create a user with the given username & pass.""" + db_user = DBUser( + username=username, + password_hash=ph.hash(password), + permissions=permissions, + require_password_change=require_password_change, + password_reset_code=password_reset_code, + active=active, + password_set_at=datetime.now(UTC), + ) + db.add(db_user) + + try: + await db.flush() + except IntegrityError as e: + raise HTTPException(400, detail=str(e)) from e + + stmt = select(DBUser).where(DBUser.username == username) + db_user = await db.scalar(stmt) + return {key: val for key, val in db_user.__dict__.items() if not key.startswith("_")}