From d97c11006e294686e03e6244714d4e3a2b27e12d Mon Sep 17 00:00:00 2001 From: F-G Fernandez <26927750+frgfm@users.noreply.github.com> Date: Wed, 6 Mar 2024 16:38:51 +0100 Subject: [PATCH] feat(login): add token validation route (#115) * feat(login): add a basic token validation route * perf(dependencies): speed up token & scope validation --- src/app/api/api_v1/endpoints/code.py | 9 +++--- src/app/api/api_v1/endpoints/guidelines.py | 35 ++++++++++++---------- src/app/api/api_v1/endpoints/login.py | 13 ++++++-- src/app/api/api_v1/endpoints/repos.py | 18 ++++++----- src/app/api/api_v1/endpoints/users.py | 27 +++++++++-------- src/app/api/dependencies.py | 26 +++++++++------- 6 files changed, 74 insertions(+), 54 deletions(-) diff --git a/src/app/api/api_v1/endpoints/code.py b/src/app/api/api_v1/endpoints/code.py index ebccc18..c72c644 100644 --- a/src/app/api/api_v1/endpoints/code.py +++ b/src/app/api/api_v1/endpoints/code.py @@ -7,9 +7,10 @@ from fastapi import APIRouter, Security, status from fastapi.responses import StreamingResponse -from app.api.dependencies import get_current_user -from app.models import User, UserScope +from app.api.dependencies import get_token_payload +from app.models import UserScope from app.schemas.code import ChatHistory +from app.schemas.login import TokenPayload from app.services.ollama import ollama_client from app.services.telemetry import telemetry_client @@ -19,9 +20,9 @@ @router.post("/chat", status_code=status.HTTP_200_OK, summary="Chat with our code model") async def chat( payload: ChatHistory, - user: User = Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> StreamingResponse: - telemetry_client.capture(user.id, event="compute-chat") + telemetry_client.capture(token_payload.user_id, event="compute-chat") # Run analysis return StreamingResponse( ollama_client.chat(payload.model_dump()["messages"]).iter_content(chunk_size=8192), diff --git a/src/app/api/api_v1/endpoints/guidelines.py b/src/app/api/api_v1/endpoints/guidelines.py index f1040c5..7ff5b25 100644 --- a/src/app/api/api_v1/endpoints/guidelines.py +++ b/src/app/api/api_v1/endpoints/guidelines.py @@ -7,13 +7,14 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Security, status -from app.api.dependencies import get_current_user, get_guideline_crud +from app.api.dependencies import get_guideline_crud, get_token_payload from app.crud import GuidelineCRUD -from app.models import Guideline, User, UserScope +from app.models import Guideline, UserScope from app.schemas.guidelines import ( ContentUpdate, GuidelineContent, ) +from app.schemas.login import TokenPayload from app.services.telemetry import telemetry_client router = APIRouter() @@ -23,28 +24,28 @@ async def create_guideline( payload: GuidelineContent, guidelines: GuidelineCRUD = Depends(get_guideline_crud), - user: User = Security(get_current_user, scopes=[UserScope.USER, UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> Guideline: - telemetry_client.capture(user.id, event="guideline-creation") - return await guidelines.create(Guideline(creator_id=user.id, **payload.model_dump())) + telemetry_client.capture(token_payload.user_id, event="guideline-creation") + return await guidelines.create(Guideline(creator_id=token_payload.user_id, **payload.model_dump())) @router.get("/{guideline_id}", status_code=status.HTTP_200_OK, summary="Read a specific guideline") async def get_guideline( guideline_id: int = Path(..., gt=0), guidelines: GuidelineCRUD = Depends(get_guideline_crud), - user: User = Security(get_current_user, scopes=[UserScope.USER, UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> Guideline: - telemetry_client.capture(user.id, event="guideline-get", properties={"guideline_id": guideline_id}) + telemetry_client.capture(token_payload.user_id, event="guideline-get", properties={"guideline_id": guideline_id}) return cast(Guideline, await guidelines.get(guideline_id, strict=True)) @router.get("/", status_code=status.HTTP_200_OK, summary="Fetch all the guidelines") async def fetch_guidelines( guidelines: GuidelineCRUD = Depends(get_guideline_crud), - user: User = Security(get_current_user, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), ) -> List[Guideline]: - telemetry_client.capture(user.id, event="guideline-fetch") + telemetry_client.capture(token_payload.user_id, event="guideline-fetch") return [elt for elt in await guidelines.fetch_all()] @@ -53,11 +54,13 @@ async def update_guideline_content( payload: GuidelineContent, guideline_id: int = Path(..., gt=0), guidelines: GuidelineCRUD = Depends(get_guideline_crud), - user: User = Security(get_current_user, scopes=[UserScope.USER, UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> Guideline: - telemetry_client.capture(user.id, event="guideline-update-content", properties={"guideline_id": guideline_id}) + telemetry_client.capture( + token_payload.user_id, event="guideline-update-content", properties={"guideline_id": guideline_id} + ) guideline = cast(Guideline, await guidelines.get(guideline_id, strict=True)) - if user.scope != UserScope.ADMIN and user.id != guideline.creator_id: + if UserScope.ADMIN not in token_payload.scopes and token_payload.user_id != guideline.creator_id: raise HTTPException(status.HTTP_403_FORBIDDEN, "Insufficient permissions.") return await guidelines.update(guideline_id, ContentUpdate(**payload.model_dump())) @@ -66,11 +69,13 @@ async def update_guideline_content( async def delete_guideline( guideline_id: int = Path(..., gt=0), guidelines: GuidelineCRUD = Depends(get_guideline_crud), - user: User = Security(get_current_user, scopes=[UserScope.USER, UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> None: - telemetry_client.capture(user.id, event="guideline-deletion", properties={"guideline_id": guideline_id}) + telemetry_client.capture( + token_payload.user_id, event="guideline-deletion", properties={"guideline_id": guideline_id} + ) guideline = cast(Guideline, await guidelines.get(guideline_id, strict=True)) - if user.scope != UserScope.ADMIN and user.id != guideline.creator_id: + if UserScope.ADMIN not in token_payload.scopes and token_payload.user_id != guideline.creator_id: raise HTTPException(status.HTTP_403_FORBIDDEN, "Insufficient permissions.") await guidelines.delete(guideline_id) diff --git a/src/app/api/api_v1/endpoints/login.py b/src/app/api/api_v1/endpoints/login.py index 1d30e05..e856231 100644 --- a/src/app/api/api_v1/endpoints/login.py +++ b/src/app/api/api_v1/endpoints/login.py @@ -4,18 +4,18 @@ # See LICENSE or go to for full license details. -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Security, status from fastapi.responses import RedirectResponse from fastapi.security import OAuth2PasswordRequestForm from pydantic import HttpUrl from app.api.api_v1.endpoints.users import _create_user -from app.api.dependencies import get_user_crud +from app.api.dependencies import get_token_payload, get_user_crud from app.core.config import settings from app.core.security import create_access_token, verify_password from app.crud import UserCRUD from app.models import UserScope -from app.schemas.login import GHAccessToken, Token, TokenRequest +from app.schemas.login import GHAccessToken, Token, TokenPayload, TokenRequest from app.schemas.services import GHToken from app.schemas.users import UserCreate from app.services.github import gh_client @@ -97,3 +97,10 @@ async def login_with_github_token( token = await create_access_token(token_data, settings.ACCESS_TOKEN_UNLIMITED_MINUTES) return Token(access_token=token, token_type="bearer") # noqa S106 + + +@router.get("/validate", status_code=status.HTTP_200_OK, summary="Check token validity") +async def check_token_validity( + payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.USER, UserScope.ADMIN]), +) -> TokenPayload: + return payload diff --git a/src/app/api/api_v1/endpoints/repos.py b/src/app/api/api_v1/endpoints/repos.py index 6d314bc..c2c2742 100644 --- a/src/app/api/api_v1/endpoints/repos.py +++ b/src/app/api/api_v1/endpoints/repos.py @@ -8,9 +8,10 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Security, status -from app.api.dependencies import get_current_user, get_repo_crud +from app.api.dependencies import get_current_user, get_repo_crud, get_token_payload from app.crud import RepositoryCRUD from app.models import Provider, Repository, User, UserScope +from app.schemas.login import TokenPayload from app.schemas.repos import RepoRegistration from app.services.github import gh_client from app.services.slack import slack_client @@ -39,7 +40,8 @@ async def register_repo( detail="Expected `github_token` to check access.", ) # Check provider link - if not isinstance(user.provider_user_id, str) or user.provider_user_id.partition("|")[0] != Provider.GITHUB: + # if not isinstance(user.provider_user_id, str) or user.provider_user_id.partition("|")[0] != Provider.GITHUB: + if not isinstance(user.provider_user_id, int): raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No GitHub profile linked to your account.") # Finally, check GH permission gh_user = gh_client.get_my_user(payload.github_token) @@ -78,18 +80,18 @@ async def register_repo( async def get_repo( repo_id: int = Path(..., gt=0), repos: RepositoryCRUD = Depends(get_repo_crud), - user: User = Security(get_current_user, scopes=[UserScope.USER, UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> Repository: - telemetry_client.capture(user.id, event="repo-get", properties={"repo_id": repo_id}) + telemetry_client.capture(token_payload.user_id, event="repo-get", properties={"repo_id": repo_id}) return cast(Repository, await repos.get(repo_id, strict=True)) @router.get("/", status_code=status.HTTP_200_OK, summary="Fetch all repositories") async def fetch_repos( repos: RepositoryCRUD = Depends(get_repo_crud), - user: User = Security(get_current_user, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), ) -> List[Repository]: - telemetry_client.capture(user.id, event="repo-fetch") + telemetry_client.capture(token_payload.user_id, event="repo-fetch") return [elt for elt in await repos.fetch_all()] @@ -97,9 +99,9 @@ async def fetch_repos( async def delete_repo( repo_id: int = Path(..., gt=0), repos: RepositoryCRUD = Depends(get_repo_crud), - user: User = Security(get_current_user, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), ) -> None: - telemetry_client.capture(user.id, event="repo-delete", properties={"repo_id": repo_id}) + telemetry_client.capture(token_payload.user_id, event="repo-delete", properties={"repo_id": repo_id}) await repos.delete(repo_id) diff --git a/src/app/api/api_v1/endpoints/users.py b/src/app/api/api_v1/endpoints/users.py index ff8ba3c..84953f7 100644 --- a/src/app/api/api_v1/endpoints/users.py +++ b/src/app/api/api_v1/endpoints/users.py @@ -7,10 +7,11 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Security, status -from app.api.dependencies import get_current_user, get_user_crud +from app.api.dependencies import get_token_payload, get_user_crud from app.core.security import hash_password from app.crud import UserCRUD from app.models import Provider, User, UserScope +from app.schemas.login import TokenPayload from app.schemas.users import Cred, CredHash, UserCreate from app.services.github import gh_client from app.services.slack import slack_client @@ -19,7 +20,7 @@ router = APIRouter() -async def _create_user(payload: UserCreate, users: UserCRUD, requester: Union[User, None] = None) -> User: +async def _create_user(payload: UserCreate, users: UserCRUD, requester_id: Union[int, None] = None) -> User: valid_creds = False user_props = {"login": payload.login, "provider_login": None, "name": None, "twitter_username": None} notif_info = [] @@ -89,7 +90,7 @@ async def _create_user(payload: UserCreate, users: UserCRUD, requester: Union[Us # Assume the requester is the new user if none was specified telemetry_client.capture( - requester.id if isinstance(requester, User) else user.id, + requester_id if isinstance(requester_id, int) else user.id, event="user-creation", properties={"created_user_id": user.id}, ) @@ -102,27 +103,27 @@ async def _create_user(payload: UserCreate, users: UserCRUD, requester: Union[Us async def create_user( payload: UserCreate, users: UserCRUD = Depends(get_user_crud), - user: User = Security(get_current_user, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), ) -> User: - return await _create_user(payload, users, user) + return await _create_user(payload, users, token_payload.user_id) @router.get("/{user_id}", status_code=status.HTTP_200_OK, summary="Fetch the information of a specific user") async def get_user( user_id: int = Path(..., gt=0), users: UserCRUD = Depends(get_user_crud), - user: User = Security(get_current_user, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), ) -> User: - telemetry_client.capture(user.id, event="user-get", properties={"user_id": user_id}) + telemetry_client.capture(token_payload.user_id, event="user-get", properties={"user_id": user_id}) return cast(User, await users.get(user_id, strict=True)) @router.get("/", status_code=status.HTTP_200_OK, summary="Fetch all the users") async def fetch_users( users: UserCRUD = Depends(get_user_crud), - user: User = Security(get_current_user, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), ) -> List[User]: - telemetry_client.capture(user.id, event="user-fetch") + telemetry_client.capture(token_payload.user_id, event="user-fetch") return [elt for elt in await users.fetch_all()] @@ -131,9 +132,9 @@ async def update_user_password( payload: Cred, user_id: int = Path(..., gt=0), users: UserCRUD = Depends(get_user_crud), - user: User = Security(get_current_user, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), ) -> User: - telemetry_client.capture(user.id, event="user-pwd", properties={"user_id": user_id}) + telemetry_client.capture(token_payload.user_id, event="user-pwd", properties={"user_id": user_id}) pwd = await hash_password(payload.password) return await users.update(user_id, CredHash(hashed_password=pwd)) @@ -142,7 +143,7 @@ async def update_user_password( async def delete_user( user_id: int = Path(..., gt=0), users: UserCRUD = Depends(get_user_crud), - _: User = Security(get_current_user, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), ) -> None: - telemetry_client.capture(user_id, event="user-deletion", properties={"user_id": user_id}) + telemetry_client.capture(token_payload.user_id, event="user-deletion", properties={"user_id": user_id}) await users.delete(user_id) diff --git a/src/app/api/dependencies.py b/src/app/api/dependencies.py index ae99c11..93d9d79 100644 --- a/src/app/api/dependencies.py +++ b/src/app/api/dependencies.py @@ -18,7 +18,7 @@ from app.models import User, UserScope from app.schemas.login import TokenPayload -__all__ = ["get_current_user", "get_guideline_crud", "get_repo_crud", "get_user_crud"] +__all__ = ["get_current_user", "get_guideline_crud", "get_repo_crud", "get_token_payload", "get_user_crud"] # Scope definition oauth2_scheme = OAuth2PasswordBearer( @@ -42,17 +42,11 @@ async def get_guideline_crud(session: AsyncSession = Depends(get_session)) -> Gu return GuidelineCRUD(session=session) -async def get_current_user( +async def get_token_payload( security_scopes: SecurityScopes, token: str = Depends(oauth2_scheme), - users: UserCRUD = Depends(get_user_crud), -) -> User: - """Dependency to use as fastapi.security.Security with scopes. - - >>> @app.get("/users/me") - >>> async def read_users_me(current_user: User = Security(get_current_user, scopes=["me"])): - >>> return current_user - """ +) -> TokenPayload: + """Dependency to use as fastapi.security.Security with scopes.""" authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' if security_scopes.scopes else "Bearer" try: @@ -82,4 +76,14 @@ async def get_current_user( headers={"WWW-Authenticate": authenticate_value}, ) - return cast(User, await users.get(user_id, strict=True)) + return token_data + + +async def get_current_user( + security_scopes: SecurityScopes, + token: str = Depends(oauth2_scheme), + users: UserCRUD = Depends(get_user_crud), +) -> User: + """Dependency to use as fastapi.security.Security with scopes""" + token_payload = await get_token_payload(security_scopes, token) + return cast(User, await users.get(token_payload.user_id, strict=True))