Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
feat(login): add token validation route (#115)
Browse files Browse the repository at this point in the history
* feat(login): add a basic token validation route

* perf(dependencies): speed up token & scope validation
  • Loading branch information
frgfm authored Mar 6, 2024
1 parent 5692948 commit d97c110
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 54 deletions.
9 changes: 5 additions & 4 deletions src/app/api/api_v1/endpoints/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from fastapi import APIRouter, Security, status
from fastapi.responses import StreamingResponse

from app.api.dependencies import get_current_user
from app.models import User, UserScope
from app.api.dependencies import get_token_payload
from app.models import UserScope
from app.schemas.code import ChatHistory
from app.schemas.login import TokenPayload
from app.services.ollama import ollama_client
from app.services.telemetry import telemetry_client

Expand All @@ -19,9 +20,9 @@
@router.post("/chat", status_code=status.HTTP_200_OK, summary="Chat with our code model")
async def chat(
payload: ChatHistory,
user: User = Security(get_current_user, scopes=[UserScope.ADMIN, UserScope.USER]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]),
) -> StreamingResponse:
telemetry_client.capture(user.id, event="compute-chat")
telemetry_client.capture(token_payload.user_id, event="compute-chat")
# Run analysis
return StreamingResponse(
ollama_client.chat(payload.model_dump()["messages"]).iter_content(chunk_size=8192),
Expand Down
35 changes: 20 additions & 15 deletions src/app/api/api_v1/endpoints/guidelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@

from fastapi import APIRouter, Depends, HTTPException, Path, Security, status

from app.api.dependencies import get_current_user, get_guideline_crud
from app.api.dependencies import get_guideline_crud, get_token_payload
from app.crud import GuidelineCRUD
from app.models import Guideline, User, UserScope
from app.models import Guideline, UserScope
from app.schemas.guidelines import (
ContentUpdate,
GuidelineContent,
)
from app.schemas.login import TokenPayload
from app.services.telemetry import telemetry_client

router = APIRouter()
Expand All @@ -23,28 +24,28 @@
async def create_guideline(
payload: GuidelineContent,
guidelines: GuidelineCRUD = Depends(get_guideline_crud),
user: User = Security(get_current_user, scopes=[UserScope.USER, UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]),
) -> Guideline:
telemetry_client.capture(user.id, event="guideline-creation")
return await guidelines.create(Guideline(creator_id=user.id, **payload.model_dump()))
telemetry_client.capture(token_payload.user_id, event="guideline-creation")
return await guidelines.create(Guideline(creator_id=token_payload.user_id, **payload.model_dump()))


@router.get("/{guideline_id}", status_code=status.HTTP_200_OK, summary="Read a specific guideline")
async def get_guideline(
guideline_id: int = Path(..., gt=0),
guidelines: GuidelineCRUD = Depends(get_guideline_crud),
user: User = Security(get_current_user, scopes=[UserScope.USER, UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]),
) -> Guideline:
telemetry_client.capture(user.id, event="guideline-get", properties={"guideline_id": guideline_id})
telemetry_client.capture(token_payload.user_id, event="guideline-get", properties={"guideline_id": guideline_id})
return cast(Guideline, await guidelines.get(guideline_id, strict=True))


@router.get("/", status_code=status.HTTP_200_OK, summary="Fetch all the guidelines")
async def fetch_guidelines(
guidelines: GuidelineCRUD = Depends(get_guideline_crud),
user: User = Security(get_current_user, scopes=[UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]),
) -> List[Guideline]:
telemetry_client.capture(user.id, event="guideline-fetch")
telemetry_client.capture(token_payload.user_id, event="guideline-fetch")
return [elt for elt in await guidelines.fetch_all()]


Expand All @@ -53,11 +54,13 @@ async def update_guideline_content(
payload: GuidelineContent,
guideline_id: int = Path(..., gt=0),
guidelines: GuidelineCRUD = Depends(get_guideline_crud),
user: User = Security(get_current_user, scopes=[UserScope.USER, UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]),
) -> Guideline:
telemetry_client.capture(user.id, event="guideline-update-content", properties={"guideline_id": guideline_id})
telemetry_client.capture(
token_payload.user_id, event="guideline-update-content", properties={"guideline_id": guideline_id}
)
guideline = cast(Guideline, await guidelines.get(guideline_id, strict=True))
if user.scope != UserScope.ADMIN and user.id != guideline.creator_id:
if UserScope.ADMIN not in token_payload.scopes and token_payload.user_id != guideline.creator_id:
raise HTTPException(status.HTTP_403_FORBIDDEN, "Insufficient permissions.")
return await guidelines.update(guideline_id, ContentUpdate(**payload.model_dump()))

Expand All @@ -66,11 +69,13 @@ async def update_guideline_content(
async def delete_guideline(
guideline_id: int = Path(..., gt=0),
guidelines: GuidelineCRUD = Depends(get_guideline_crud),
user: User = Security(get_current_user, scopes=[UserScope.USER, UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]),
) -> None:
telemetry_client.capture(user.id, event="guideline-deletion", properties={"guideline_id": guideline_id})
telemetry_client.capture(
token_payload.user_id, event="guideline-deletion", properties={"guideline_id": guideline_id}
)
guideline = cast(Guideline, await guidelines.get(guideline_id, strict=True))
if user.scope != UserScope.ADMIN and user.id != guideline.creator_id:
if UserScope.ADMIN not in token_payload.scopes and token_payload.user_id != guideline.creator_id:
raise HTTPException(status.HTTP_403_FORBIDDEN, "Insufficient permissions.")
await guidelines.delete(guideline_id)

Expand Down
13 changes: 10 additions & 3 deletions src/app/api/api_v1/endpoints/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.


from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Security, status
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import HttpUrl

from app.api.api_v1.endpoints.users import _create_user
from app.api.dependencies import get_user_crud
from app.api.dependencies import get_token_payload, get_user_crud
from app.core.config import settings
from app.core.security import create_access_token, verify_password
from app.crud import UserCRUD
from app.models import UserScope
from app.schemas.login import GHAccessToken, Token, TokenRequest
from app.schemas.login import GHAccessToken, Token, TokenPayload, TokenRequest
from app.schemas.services import GHToken
from app.schemas.users import UserCreate
from app.services.github import gh_client
Expand Down Expand Up @@ -97,3 +97,10 @@ async def login_with_github_token(
token = await create_access_token(token_data, settings.ACCESS_TOKEN_UNLIMITED_MINUTES)

return Token(access_token=token, token_type="bearer") # noqa S106


@router.get("/validate", status_code=status.HTTP_200_OK, summary="Check token validity")
async def check_token_validity(
payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.USER, UserScope.ADMIN]),
) -> TokenPayload:
return payload
18 changes: 10 additions & 8 deletions src/app/api/api_v1/endpoints/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

from fastapi import APIRouter, Depends, HTTPException, Path, Security, status

from app.api.dependencies import get_current_user, get_repo_crud
from app.api.dependencies import get_current_user, get_repo_crud, get_token_payload
from app.crud import RepositoryCRUD
from app.models import Provider, Repository, User, UserScope
from app.schemas.login import TokenPayload
from app.schemas.repos import RepoRegistration
from app.services.github import gh_client
from app.services.slack import slack_client
Expand Down Expand Up @@ -39,7 +40,8 @@ async def register_repo(
detail="Expected `github_token` to check access.",
)
# Check provider link
if not isinstance(user.provider_user_id, str) or user.provider_user_id.partition("|")[0] != Provider.GITHUB:
# if not isinstance(user.provider_user_id, str) or user.provider_user_id.partition("|")[0] != Provider.GITHUB:
if not isinstance(user.provider_user_id, int):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No GitHub profile linked to your account.")
# Finally, check GH permission
gh_user = gh_client.get_my_user(payload.github_token)
Expand Down Expand Up @@ -78,28 +80,28 @@ async def register_repo(
async def get_repo(
repo_id: int = Path(..., gt=0),
repos: RepositoryCRUD = Depends(get_repo_crud),
user: User = Security(get_current_user, scopes=[UserScope.USER, UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]),
) -> Repository:
telemetry_client.capture(user.id, event="repo-get", properties={"repo_id": repo_id})
telemetry_client.capture(token_payload.user_id, event="repo-get", properties={"repo_id": repo_id})
return cast(Repository, await repos.get(repo_id, strict=True))


@router.get("/", status_code=status.HTTP_200_OK, summary="Fetch all repositories")
async def fetch_repos(
repos: RepositoryCRUD = Depends(get_repo_crud),
user: User = Security(get_current_user, scopes=[UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]),
) -> List[Repository]:
telemetry_client.capture(user.id, event="repo-fetch")
telemetry_client.capture(token_payload.user_id, event="repo-fetch")
return [elt for elt in await repos.fetch_all()]


@router.delete("/{repo_id}", status_code=status.HTTP_200_OK, summary="Delete a specific repository")
async def delete_repo(
repo_id: int = Path(..., gt=0),
repos: RepositoryCRUD = Depends(get_repo_crud),
user: User = Security(get_current_user, scopes=[UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]),
) -> None:
telemetry_client.capture(user.id, event="repo-delete", properties={"repo_id": repo_id})
telemetry_client.capture(token_payload.user_id, event="repo-delete", properties={"repo_id": repo_id})
await repos.delete(repo_id)


Expand Down
27 changes: 14 additions & 13 deletions src/app/api/api_v1/endpoints/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

from fastapi import APIRouter, Depends, HTTPException, Path, Security, status

from app.api.dependencies import get_current_user, get_user_crud
from app.api.dependencies import get_token_payload, get_user_crud
from app.core.security import hash_password
from app.crud import UserCRUD
from app.models import Provider, User, UserScope
from app.schemas.login import TokenPayload
from app.schemas.users import Cred, CredHash, UserCreate
from app.services.github import gh_client
from app.services.slack import slack_client
Expand All @@ -19,7 +20,7 @@
router = APIRouter()


async def _create_user(payload: UserCreate, users: UserCRUD, requester: Union[User, None] = None) -> User:
async def _create_user(payload: UserCreate, users: UserCRUD, requester_id: Union[int, None] = None) -> User:
valid_creds = False
user_props = {"login": payload.login, "provider_login": None, "name": None, "twitter_username": None}
notif_info = []
Expand Down Expand Up @@ -89,7 +90,7 @@ async def _create_user(payload: UserCreate, users: UserCRUD, requester: Union[Us

# Assume the requester is the new user if none was specified
telemetry_client.capture(
requester.id if isinstance(requester, User) else user.id,
requester_id if isinstance(requester_id, int) else user.id,
event="user-creation",
properties={"created_user_id": user.id},
)
Expand All @@ -102,27 +103,27 @@ async def _create_user(payload: UserCreate, users: UserCRUD, requester: Union[Us
async def create_user(
payload: UserCreate,
users: UserCRUD = Depends(get_user_crud),
user: User = Security(get_current_user, scopes=[UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]),
) -> User:
return await _create_user(payload, users, user)
return await _create_user(payload, users, token_payload.user_id)


@router.get("/{user_id}", status_code=status.HTTP_200_OK, summary="Fetch the information of a specific user")
async def get_user(
user_id: int = Path(..., gt=0),
users: UserCRUD = Depends(get_user_crud),
user: User = Security(get_current_user, scopes=[UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]),
) -> User:
telemetry_client.capture(user.id, event="user-get", properties={"user_id": user_id})
telemetry_client.capture(token_payload.user_id, event="user-get", properties={"user_id": user_id})
return cast(User, await users.get(user_id, strict=True))


@router.get("/", status_code=status.HTTP_200_OK, summary="Fetch all the users")
async def fetch_users(
users: UserCRUD = Depends(get_user_crud),
user: User = Security(get_current_user, scopes=[UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]),
) -> List[User]:
telemetry_client.capture(user.id, event="user-fetch")
telemetry_client.capture(token_payload.user_id, event="user-fetch")
return [elt for elt in await users.fetch_all()]


Expand All @@ -131,9 +132,9 @@ async def update_user_password(
payload: Cred,
user_id: int = Path(..., gt=0),
users: UserCRUD = Depends(get_user_crud),
user: User = Security(get_current_user, scopes=[UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]),
) -> User:
telemetry_client.capture(user.id, event="user-pwd", properties={"user_id": user_id})
telemetry_client.capture(token_payload.user_id, event="user-pwd", properties={"user_id": user_id})
pwd = await hash_password(payload.password)
return await users.update(user_id, CredHash(hashed_password=pwd))

Expand All @@ -142,7 +143,7 @@ async def update_user_password(
async def delete_user(
user_id: int = Path(..., gt=0),
users: UserCRUD = Depends(get_user_crud),
_: User = Security(get_current_user, scopes=[UserScope.ADMIN]),
token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]),
) -> None:
telemetry_client.capture(user_id, event="user-deletion", properties={"user_id": user_id})
telemetry_client.capture(token_payload.user_id, event="user-deletion", properties={"user_id": user_id})
await users.delete(user_id)
26 changes: 15 additions & 11 deletions src/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from app.models import User, UserScope
from app.schemas.login import TokenPayload

__all__ = ["get_current_user", "get_guideline_crud", "get_repo_crud", "get_user_crud"]
__all__ = ["get_current_user", "get_guideline_crud", "get_repo_crud", "get_token_payload", "get_user_crud"]

# Scope definition
oauth2_scheme = OAuth2PasswordBearer(
Expand All @@ -42,17 +42,11 @@ async def get_guideline_crud(session: AsyncSession = Depends(get_session)) -> Gu
return GuidelineCRUD(session=session)


async def get_current_user(
async def get_token_payload(
security_scopes: SecurityScopes,
token: str = Depends(oauth2_scheme),
users: UserCRUD = Depends(get_user_crud),
) -> User:
"""Dependency to use as fastapi.security.Security with scopes.
>>> @app.get("/users/me")
>>> async def read_users_me(current_user: User = Security(get_current_user, scopes=["me"])):
>>> return current_user
"""
) -> TokenPayload:
"""Dependency to use as fastapi.security.Security with scopes."""
authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' if security_scopes.scopes else "Bearer"

try:
Expand Down Expand Up @@ -82,4 +76,14 @@ async def get_current_user(
headers={"WWW-Authenticate": authenticate_value},
)

return cast(User, await users.get(user_id, strict=True))
return token_data


async def get_current_user(
security_scopes: SecurityScopes,
token: str = Depends(oauth2_scheme),
users: UserCRUD = Depends(get_user_crud),
) -> User:
"""Dependency to use as fastapi.security.Security with scopes"""
token_payload = await get_token_payload(security_scopes, token)
return cast(User, await users.get(token_payload.user_id, strict=True))

0 comments on commit d97c110

Please sign in to comment.