diff --git a/.env.example b/.env.example index 71c920b..0952ffa 100644 --- a/.env.example +++ b/.env.example @@ -22,7 +22,7 @@ OLLAMA_MODEL='dolphin-mistral:7b-v2.6-dpo-laser-q4_K_M' # OLLAMA_MODEL='tinydolphin:1.1b-v2.8-q4_K_M' OLLAMA_TIMEOUT=120 LLM_TEMPERATURE=0 -SECRET_KEY= +JWT_SECRET= SENTRY_DSN= SERVER_NAME= POSTHOG_HOST='https://eu.posthog.com' diff --git a/auth/docker-compose.yml b/auth/docker-compose.yml new file mode 100644 index 0000000..e140eee --- /dev/null +++ b/auth/docker-compose.yml @@ -0,0 +1,105 @@ +name: quack +version: '3.8' + +services: + auth: + image: supabase/gotrue:v2.143.0 + depends_on: + auth_db: + condition: service_healthy + healthcheck: + test: + [ + "CMD", + "wget", + "--no-verbose", + "--tries=1", + "--spider", + "http://localhost:9999/health" + ] + timeout: 5s + interval: 5s + retries: 3 + restart: unless-stopped + ports: + - 9999:9999 + environment: + GOTRUE_API_HOST: 0.0.0.0 + GOTRUE_API_PORT: 9999 + API_EXTERNAL_URL: ${API_EXTERNAL_URL} + + GOTRUE_DB_DRIVER: postgres + GOTRUE_DB_DATABASE_URL: postgres://supabase_auth_admin:${AUTH_PG_PW}@auth_db:${AUTH_PG_PORT}/${AUTH_PG_DB} + + GOTRUE_SITE_URL: ${GOTRUE_SITE_URL} + GOTRUE_URI_ALLOW_LIST: ${ADDITIONAL_REDIRECT_URLS} + GOTRUE_DISABLE_SIGNUP: false + + GOTRUE_JWT_ADMIN_ROLES: service_role + GOTRUE_JWT_AUD: authenticated + GOTRUE_JWT_DEFAULT_GROUP_NAME: authenticated + GOTRUE_JWT_EXP: ${JWT_EXPIRY} + GOTRUE_JWT_SECRET: ${JWT_SECRET} + + GOTRUE_EXTERNAL_EMAIL_ENABLED: true + GOTRUE_MAILER_AUTOCONFIRM: true + # GOTRUE_MAILER_SECURE_EMAIL_CHANGE_ENABLED: true + # GOTRUE_SMTP_MAX_FREQUENCY: 1s + GOTRUE_SMTP_ADMIN_EMAIL: ${SMTP_ADMIN_EMAIL} + GOTRUE_SMTP_HOST: ${SMTP_HOST} + GOTRUE_SMTP_PORT: ${SMTP_PORT} + GOTRUE_SMTP_USER: ${SMTP_USER} + GOTRUE_SMTP_PASS: ${SMTP_PASS} + GOTRUE_SMTP_SENDER_NAME: ${SMTP_SENDER_NAME} + GOTRUE_MAILER_URLPATHS_INVITE: ${MAILER_URLPATHS_INVITE} + GOTRUE_MAILER_URLPATHS_CONFIRMATION: ${MAILER_URLPATHS_CONFIRMATION} + GOTRUE_MAILER_URLPATHS_RECOVERY: ${MAILER_URLPATHS_RECOVERY} + GOTRUE_MAILER_URLPATHS_EMAIL_CHANGE: ${MAILER_URLPATHS_EMAIL_CHANGE} + + GOTRUE_EXTERNAL_PHONE_ENABLED: false + GOTRUE_SMS_AUTOCONFIRM: true + + auth_db: + image: supabase/postgres:15.1.0.147 + healthcheck: + test: pg_isready -U postgres -h localhost + interval: 5s + timeout: 5s + retries: 10 + command: + - postgres + - -c + - config_file=/etc/postgresql/postgresql.conf + - -c + - log_min_messages=fatal # prevents Realtime polling queries from appearing in logs + restart: unless-stopped + ports: + # Pass down internal port because it's set dynamically by other services + - ${AUTH_PG_PORT}:${AUTH_PG_PORT} + environment: + POSTGRES_HOST: /var/run/postgresql + PGPORT: ${AUTH_PG_PORT} + POSTGRES_PORT: ${AUTH_PG_PORT} + PGPASSWORD: ${AUTH_PG_PW} + POSTGRES_PASSWORD: ${AUTH_PG_PW} + PGDATABASE: ${AUTH_PG_DB} + POSTGRES_DB: ${AUTH_PG_DB} + JWT_SECRET: ${JWT_SECRET} + JWT_EXP: ${JWT_EXPIRY} + volumes: + - ./auth/volumes/db/realtime.sql:/docker-entrypoint-initdb.d/migrations/99-realtime.sql:Z + # Must be superuser to create event trigger + - ./auth/volumes/db/webhooks.sql:/docker-entrypoint-initdb.d/init-scripts/98-webhooks.sql:Z + # Must be superuser to alter reserved role + - ./auth/volumes/db/roles.sql:/docker-entrypoint-initdb.d/init-scripts/99-roles.sql:Z + # Initialize the database settings with JWT_SECRET and JWT_EXP + - ./auth/volumes/db/jwt.sql:/docker-entrypoint-initdb.d/init-scripts/99-jwt.sql:Z + # PGDATA directory is persisted between restarts + - ./auth/volumes/db/data:/var/lib/postgresql/data:Z + # Changes required for Analytics support + - ./auth/volumes/db/logs.sql:/docker-entrypoint-initdb.d/migrations/99-logs.sql:Z + # Use named volume to persist pgsodium decryption key between restarts + - db-config:/etc/postgresql-custom + +volumes: + db-config: diff --git a/demo/Dockerfile b/demo/Dockerfile index a11eec7..4a3b26b 100644 --- a/demo/Dockerfile +++ b/demo/Dockerfile @@ -12,6 +12,7 @@ COPY demo/requirements.txt /app/requirements.txt # install dependencies RUN set -eux \ + && apk add --no-cache curl \ && pip install --no-cache-dir uv \ && uv pip install --no-cache --system -r /app/requirements.txt \ && rm -rf /root/.cache diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 87fa290..e7349fd 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -1,3 +1,4 @@ +name: quack version: '3.8' services: @@ -55,10 +56,7 @@ services: - ./src/:/app/ command: "sh -c 'python app/db.py && uvicorn app.main:app --reload --host 0.0.0.0 --port 5050 --proxy-headers'" healthcheck: - test: ["CMD-SHELL", "nc -vz localhost 5050"] + test: ["CMD-SHELL", "curl http://localhost:5050/status"] interval: 10s timeout: 3s retries: 3 - -volumes: - ollama: diff --git a/docker-compose.override.yml b/docker-compose.override.yml index d0799aa..0aa67d6 100644 --- a/docker-compose.override.yml +++ b/docker-compose.override.yml @@ -1,3 +1,4 @@ +name: quack version: '3.8' services: diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml index 0c7383d..b57e497 100644 --- a/docker-compose.prod.yml +++ b/docker-compose.prod.yml @@ -1,3 +1,4 @@ +name: quack version: '3.8' services: @@ -91,7 +92,7 @@ services: - "traefik.http.services.backend.loadbalancer.server.port=5050" command: "sh -c 'alembic upgrade head && python app/db.py && uvicorn app.main:app --reload --host 0.0.0.0 --port 5050 --proxy-headers'" healthcheck: - test: ["CMD-SHELL", "nc -vz localhost 5050"] + test: ["CMD-SHELL", "curl http://localhost:5050/status"] interval: 10s timeout: 3s retries: 3 @@ -117,7 +118,7 @@ services: - "traefik.http.services.grafana.loadbalancer.server.port=7860" command: python main.py --server-name 0.0.0.0 --auth healthcheck: - test: ["CMD-SHELL", "nc -vz localhost 7860"] + test: ["CMD-SHELL", "curl http://localhost:7860"] interval: 10s timeout: 3s retries: 3 @@ -159,6 +160,3 @@ services: interval: 10s timeout: 3s retries: 3 - -volumes: - ollama: diff --git a/docker-compose.yml b/docker-compose.yml index aaf8267..d014a3a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,3 +1,4 @@ +name: quack version: '3.8' services: @@ -49,7 +50,7 @@ services: - POSTGRES_URL=postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@db/${POSTGRES_DB} - SUPERADMIN_LOGIN=${SUPERADMIN_LOGIN} - SUPERADMIN_PWD=${SUPERADMIN_PWD} - - SECRET_KEY=${SECRET_KEY} + - JWT_SECRET=${JWT_SECRET} - OLLAMA_ENDPOINT=http://ollama:11434 - OLLAMA_MODEL=${OLLAMA_MODEL} - OLLAMA_TIMEOUT=${OLLAMA_TIMEOUT:-60} @@ -60,7 +61,7 @@ services: - ./src/:/app/ command: "sh -c 'alembic upgrade head && python app/db.py && uvicorn app.main:app --reload --host 0.0.0.0 --port 5050 --proxy-headers'" healthcheck: - test: ["CMD-SHELL", "nc -vz localhost 5050"] + test: ["CMD-SHELL", "curl http://localhost:5050/status"] interval: 10s timeout: 3s retries: 3 @@ -83,7 +84,7 @@ services: - ./demo/:/app/ command: python main.py --server-name 0.0.0.0 healthcheck: - test: ["CMD-SHELL", "nc -vz localhost 7860"] + test: ["CMD-SHELL", "curl http://localhost:7860"] interval: 10s timeout: 3s retries: 3 @@ -126,4 +127,3 @@ services: volumes: postgres_data: - ollama: diff --git a/pyproject.toml b/pyproject.toml index d46f4f1..5fbb314 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,7 +131,6 @@ known-third-party = ["fastapi"] "**/__init__.py" = ["I001", "F401", "CPY001"] "scripts/**.py" = ["D", "T201", "S101", "ANN"] ".github/**.py" = ["D", "T201", "ANN"] -"client/docs/**.py" = ["E402"] "src/tests/**.py" = ["D103", "CPY001", "S101", "T201", "ANN001", "ANN201", "ARG001"] "src/migrations/versions/**.py" = ["CPY001"] "src/migrations/**.py" = ["ANN"] diff --git a/src/Dockerfile b/src/Dockerfile index 211051e..33ff062 100644 --- a/src/Dockerfile +++ b/src/Dockerfile @@ -12,6 +12,7 @@ COPY requirements.txt /app/requirements.txt # install dependencies RUN set -eux \ + && apk add --no-cache curl \ && pip install --no-cache-dir uv \ && uv pip install --no-cache --system -r /app/requirements.txt \ && rm -rf /root/.cache diff --git a/src/app/api/api_v1/endpoints/code.py b/src/app/api/api_v1/endpoints/code.py index e150ab4..cbcaf4e 100644 --- a/src/app/api/api_v1/endpoints/code.py +++ b/src/app/api/api_v1/endpoints/code.py @@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends, Security, status from fastapi.responses import StreamingResponse -from app.api.dependencies import get_guideline_crud, get_token_payload +from app.api.dependencies import get_guideline_crud, get_quack_jwt from app.core.config import settings from app.crud.crud_guideline import GuidelineCRUD from app.models import UserScope @@ -23,11 +23,11 @@ async def chat( payload: ChatHistory, guidelines: GuidelineCRUD = Depends(get_guideline_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> StreamingResponse: - telemetry_client.capture(token_payload.user_id, event="compute-chat") + telemetry_client.capture(token_payload.sub, event="compute-chat") # Retrieve the guidelines of this user - user_guidelines = [g.content for g in await guidelines.fetch_all(filter_pair=("creator_id", token_payload.user_id))] + user_guidelines = [g.content for g in await guidelines.fetch_all(filter_pair=("creator_id", token_payload.sub))] # Run analysis return StreamingResponse( ollama_client.chat( diff --git a/src/app/api/api_v1/endpoints/guidelines.py b/src/app/api/api_v1/endpoints/guidelines.py index 9a5d00d..21024f6 100644 --- a/src/app/api/api_v1/endpoints/guidelines.py +++ b/src/app/api/api_v1/endpoints/guidelines.py @@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Security, status -from app.api.dependencies import get_guideline_crud, get_token_payload +from app.api.dependencies import get_guideline_crud, get_quack_jwt from app.crud import GuidelineCRUD from app.models import Guideline, UserScope from app.schemas.guidelines import ( @@ -24,29 +24,29 @@ async def create_guideline( payload: GuidelineContent, guidelines: GuidelineCRUD = Depends(get_guideline_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> Guideline: - telemetry_client.capture(token_payload.user_id, event="guideline-creation") - return await guidelines.create(Guideline(creator_id=token_payload.user_id, **payload.model_dump())) + telemetry_client.capture(token_payload.sub, event="guideline-creation") + return await guidelines.create(Guideline(creator_id=token_payload.sub, **payload.model_dump())) @router.get("/{guideline_id}", status_code=status.HTTP_200_OK, summary="Read a specific guideline") async def get_guideline( guideline_id: int = Path(..., gt=0), guidelines: GuidelineCRUD = Depends(get_guideline_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> Guideline: - telemetry_client.capture(token_payload.user_id, event="guideline-get", properties={"guideline_id": guideline_id}) + telemetry_client.capture(token_payload.sub, event="guideline-get", properties={"guideline_id": guideline_id}) return cast(Guideline, await guidelines.get(guideline_id, strict=True)) @router.get("/", status_code=status.HTTP_200_OK, summary="Fetch all the guidelines") async def fetch_guidelines( guidelines: GuidelineCRUD = Depends(get_guideline_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.USER, UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.USER, UserScope.ADMIN]), ) -> List[Guideline]: - telemetry_client.capture(token_payload.user_id, event="guideline-fetch") - filter_pair = ("creator_id", token_payload.user_id) if UserScope.ADMIN not in token_payload.scopes else None + telemetry_client.capture(token_payload.sub, event="guideline-fetch") + filter_pair = ("creator_id", token_payload.sub) if UserScope.ADMIN not in token_payload.scopes else None return [elt for elt in await guidelines.fetch_all(filter_pair=filter_pair)] @@ -55,13 +55,13 @@ async def update_guideline_content( payload: GuidelineContent, guideline_id: int = Path(..., gt=0), guidelines: GuidelineCRUD = Depends(get_guideline_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> Guideline: telemetry_client.capture( - token_payload.user_id, event="guideline-update-content", properties={"guideline_id": guideline_id} + token_payload.sub, event="guideline-update-content", properties={"guideline_id": guideline_id} ) guideline = cast(Guideline, await guidelines.get(guideline_id, strict=True)) - if UserScope.ADMIN not in token_payload.scopes and token_payload.user_id != guideline.creator_id: + if UserScope.ADMIN not in token_payload.scopes and token_payload.sub != guideline.creator_id: raise HTTPException(status.HTTP_403_FORBIDDEN, "Insufficient permissions.") return await guidelines.update(guideline_id, ContentUpdate(**payload.model_dump())) @@ -70,13 +70,11 @@ async def update_guideline_content( async def delete_guideline( guideline_id: int = Path(..., gt=0), guidelines: GuidelineCRUD = Depends(get_guideline_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> None: - telemetry_client.capture( - token_payload.user_id, event="guideline-deletion", properties={"guideline_id": guideline_id} - ) + telemetry_client.capture(token_payload.sub, event="guideline-deletion", properties={"guideline_id": guideline_id}) guideline = cast(Guideline, await guidelines.get(guideline_id, strict=True)) - if UserScope.ADMIN not in token_payload.scopes and token_payload.user_id != guideline.creator_id: + if UserScope.ADMIN not in token_payload.scopes and token_payload.sub != guideline.creator_id: raise HTTPException(status.HTTP_403_FORBIDDEN, "Insufficient permissions.") await guidelines.delete(guideline_id) diff --git a/src/app/api/api_v1/endpoints/login.py b/src/app/api/api_v1/endpoints/login.py index 3127b91..5a61567 100644 --- a/src/app/api/api_v1/endpoints/login.py +++ b/src/app/api/api_v1/endpoints/login.py @@ -10,7 +10,7 @@ from pydantic import HttpUrl from app.api.api_v1.endpoints.users import _create_user -from app.api.dependencies import get_token_payload, get_user_crud +from app.api.dependencies import get_quack_jwt, get_user_crud from app.core.config import settings from app.core.security import create_access_token, verify_password from app.crud import UserCRUD @@ -96,6 +96,6 @@ async def login_with_github_token( @router.get("/validate", status_code=status.HTTP_200_OK, summary="Check token validity") def check_token_validity( - payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.USER, UserScope.ADMIN]), + payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.USER, UserScope.ADMIN]), ) -> TokenPayload: return payload diff --git a/src/app/api/api_v1/endpoints/repos.py b/src/app/api/api_v1/endpoints/repos.py index 5aaef8e..e1c4936 100644 --- a/src/app/api/api_v1/endpoints/repos.py +++ b/src/app/api/api_v1/endpoints/repos.py @@ -8,7 +8,7 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Security, status -from app.api.dependencies import get_current_user, get_repo_crud, get_token_payload +from app.api.dependencies import get_current_user, get_quack_jwt, get_repo_crud from app.crud import RepositoryCRUD from app.models import Provider, Repository, User, UserScope from app.schemas.login import TokenPayload @@ -80,18 +80,18 @@ async def register_repo( async def get_repo( repo_id: int = Path(..., gt=0), repos: RepositoryCRUD = Depends(get_repo_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN, UserScope.USER]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN, UserScope.USER]), ) -> Repository: - telemetry_client.capture(token_payload.user_id, event="repo-get", properties={"repo_id": repo_id}) + telemetry_client.capture(token_payload.sub, event="repo-get", properties={"repo_id": repo_id}) return cast(Repository, await repos.get(repo_id, strict=True)) @router.get("/", status_code=status.HTTP_200_OK, summary="Fetch all repositories") async def fetch_repos( repos: RepositoryCRUD = Depends(get_repo_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN]), ) -> List[Repository]: - telemetry_client.capture(token_payload.user_id, event="repo-fetch") + telemetry_client.capture(token_payload.sub, event="repo-fetch") return [elt for elt in await repos.fetch_all()] @@ -99,9 +99,9 @@ async def fetch_repos( async def delete_repo( repo_id: int = Path(..., gt=0), repos: RepositoryCRUD = Depends(get_repo_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN]), ) -> None: - telemetry_client.capture(token_payload.user_id, event="repo-delete", properties={"repo_id": repo_id}) + telemetry_client.capture(token_payload.sub, event="repo-delete", properties={"repo_id": repo_id}) await repos.delete(repo_id) diff --git a/src/app/api/api_v1/endpoints/users.py b/src/app/api/api_v1/endpoints/users.py index 4234bfd..078ea0e 100644 --- a/src/app/api/api_v1/endpoints/users.py +++ b/src/app/api/api_v1/endpoints/users.py @@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Security, status -from app.api.dependencies import get_token_payload, get_user_crud +from app.api.dependencies import get_quack_jwt, get_user_crud from app.core.security import hash_password from app.crud import UserCRUD from app.models import Provider, User, UserScope @@ -103,27 +103,27 @@ async def _create_user(payload: UserCreate, users: UserCRUD, requester_id: Union async def create_user( payload: UserCreate, users: UserCRUD = Depends(get_user_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN]), ) -> User: - return await _create_user(payload, users, token_payload.user_id) + return await _create_user(payload, users, token_payload.sub) @router.get("/{user_id}", status_code=status.HTTP_200_OK, summary="Fetch the information of a specific user") async def get_user( user_id: int = Path(..., gt=0), users: UserCRUD = Depends(get_user_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN]), ) -> User: - telemetry_client.capture(token_payload.user_id, event="user-get", properties={"user_id": user_id}) + telemetry_client.capture(token_payload.sub, event="user-get", properties={"user_id": user_id}) return cast(User, await users.get(user_id, strict=True)) @router.get("/", status_code=status.HTTP_200_OK, summary="Fetch all the users") async def fetch_users( users: UserCRUD = Depends(get_user_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN]), ) -> List[User]: - telemetry_client.capture(token_payload.user_id, event="user-fetch") + telemetry_client.capture(token_payload.sub, event="user-fetch") return [elt for elt in await users.fetch_all()] @@ -132,9 +132,9 @@ async def update_user_password( payload: Cred, user_id: int = Path(..., gt=0), users: UserCRUD = Depends(get_user_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN]), ) -> User: - telemetry_client.capture(token_payload.user_id, event="user-pwd", properties={"user_id": user_id}) + telemetry_client.capture(token_payload.sub, event="user-pwd", properties={"user_id": user_id}) pwd = hash_password(payload.password) return await users.update(user_id, CredHash(hashed_password=pwd)) @@ -143,7 +143,7 @@ async def update_user_password( async def delete_user( user_id: int = Path(..., gt=0), users: UserCRUD = Depends(get_user_crud), - token_payload: TokenPayload = Security(get_token_payload, scopes=[UserScope.ADMIN]), + token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN]), ) -> None: - telemetry_client.capture(token_payload.user_id, event="user-deletion", properties={"user_id": user_id}) + telemetry_client.capture(token_payload.sub, event="user-deletion", properties={"user_id": user_id}) await users.delete(user_id) diff --git a/src/app/api/dependencies.py b/src/app/api/dependencies.py index d668028..778f584 100644 --- a/src/app/api/dependencies.py +++ b/src/app/api/dependencies.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import cast +from typing import Dict, Type, TypeVar, Union, cast from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, SecurityScopes @@ -17,8 +17,11 @@ from app.db import get_session from app.models import User, UserScope from app.schemas.login import TokenPayload +from app.services.auth.supabase import SupaJWT -__all__ = ["get_current_user", "get_guideline_crud", "get_repo_crud", "get_token_payload", "get_user_crud"] +JWTTemplate = TypeVar("JWTTemplate") + +__all__ = ["get_guideline_crud", "get_repo_crud", "get_user_crud"] # Scope definition oauth2_scheme = OAuth2PasswordBearer( @@ -42,47 +45,68 @@ def get_guideline_crud(session: AsyncSession = Depends(get_session)) -> Guidelin return GuidelineCRUD(session=session) -def get_token_payload( - security_scopes: SecurityScopes, - token: str = Depends(oauth2_scheme), -) -> TokenPayload: - """Dependency to use as fastapi.security.Security with scopes.""" - authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' if security_scopes.scopes else "Bearer" - +def decode_token(token: str, authenticate_value: Union[str, None] = None) -> Dict[str, str]: try: - payload = jwt_decode(token, settings.SECRET_KEY, algorithms=[settings.JWT_ENCODING_ALGORITHM]) + payload = jwt_decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ENCODING_ALGORITHM]) except (ExpiredSignatureError, InvalidSignatureError): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired.", - headers={"WWW-Authenticate": authenticate_value}, + headers={"WWW-Authenticate": authenticate_value} if authenticate_value else None, ) except DecodeError: raise HTTPException( status_code=status.HTTP_406_NOT_ACCEPTABLE, detail="Invalid token.", - headers={"WWW-Authenticate": authenticate_value}, + headers={"WWW-Authenticate": authenticate_value} if authenticate_value else None, ) + return payload + +def process_token( + token: str, jwt_template: Type[JWTTemplate], authenticate_value: Union[str, None] = None +) -> JWTTemplate: + payload = decode_token(token) + # Verify the JWT template try: - user_id = int(payload["sub"]) - token_scopes = payload.get("scopes", []) - token_data = TokenPayload(user_id=user_id, scopes=token_scopes) - except (KeyError, ValidationError): + return jwt_template(**payload) + except ValidationError: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid token payload.", - headers={"WWW-Authenticate": authenticate_value}, + headers={"WWW-Authenticate": authenticate_value} if authenticate_value else None, ) - if set(token_data.scopes).isdisjoint(security_scopes.scopes): + +def get_supa_jwt( + security_scopes: SecurityScopes, + token: str = Depends(oauth2_scheme), +) -> SupaJWT: + authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' if security_scopes.scopes else "Bearer" + jwt_payload = process_token(token, SupaJWT, authenticate_value=authenticate_value) + # Retrieve the actual role + if set(jwt_payload.user_metadata.quack_role).isdisjoint(security_scopes.scopes): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Incompatible token scope.", headers={"WWW-Authenticate": authenticate_value}, ) + return jwt_payload + - return token_data +def get_quack_jwt( + security_scopes: SecurityScopes, + token: str = Depends(oauth2_scheme), +) -> TokenPayload: + authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' if security_scopes.scopes else "Bearer" + jwt_payload = process_token(token, TokenPayload) + if set(jwt_payload.scopes).isdisjoint(security_scopes.scopes): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Incompatible token scope.", + headers={"WWW-Authenticate": authenticate_value}, + ) + return jwt_payload async def get_current_user( @@ -91,5 +115,5 @@ async def get_current_user( users: UserCRUD = Depends(get_user_crud), ) -> User: """Dependency to use as fastapi.security.Security with scopes""" - token_payload = get_token_payload(security_scopes, token) - return cast(User, await users.get(token_payload.user_id, strict=True)) + token_payload = get_quack_jwt(security_scopes, token) + return cast(User, await users.get(token_payload.sub, strict=True)) diff --git a/src/app/core/config.py b/src/app/core/config.py index 54b233a..c65e34e 100644 --- a/src/app/core/config.py +++ b/src/app/core/config.py @@ -41,7 +41,7 @@ def sqlachmey_uri(cls, v: str) -> str: return v # Security - SECRET_KEY: str = os.environ.get("SECRET_KEY", secrets.token_urlsafe(32)) + JWT_SECRET: str = os.environ.get("JWT_SECRET", secrets.token_urlsafe(32)) ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 ACCESS_TOKEN_UNLIMITED_MINUTES: int = 60 * 24 * 365 JWT_ENCODING_ALGORITHM: str = "HS256" diff --git a/src/app/core/security.py b/src/app/core/security.py index 021a9d5..6c4a002 100644 --- a/src/app/core/security.py +++ b/src/app/core/security.py @@ -20,7 +20,7 @@ def create_access_token(content: Dict[str, Any], expires_minutes: Optional[int] """Encode content dict using security algorithm, setting expiration.""" expire_delta = timedelta(minutes=expires_minutes or settings.ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.utcnow() + expire_delta - return jwt.encode({**content, "exp": expire}, settings.SECRET_KEY, algorithm=settings.JWT_ENCODING_ALGORITHM) + return jwt.encode({**content, "exp": expire}, settings.JWT_SECRET, algorithm=settings.JWT_ENCODING_ALGORITHM) def verify_password(plain_password: str, hashed_password: str) -> bool: diff --git a/src/app/schemas/login.py b/src/app/schemas/login.py index 5ebfa53..0a73ff8 100644 --- a/src/app/schemas/login.py +++ b/src/app/schemas/login.py @@ -24,7 +24,7 @@ class Token(BaseModel): class TokenPayload(BaseModel): - user_id: int = Field(..., gt=0) + sub: int = Field(..., gt=0) scopes: List[UserScope] = [] diff --git a/src/app/services/auth/supabase.py b/src/app/services/auth/supabase.py new file mode 100755 index 0000000..ae607b5 --- /dev/null +++ b/src/app/services/auth/supabase.py @@ -0,0 +1,187 @@ +# Copyright (C) 2024, Quack AI. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import logging +from enum import Enum +from typing import Dict, List, Union +from urllib.parse import urljoin + +import jwt +import requests +from fastapi import HTTPException +from pydantic import BaseModel + +logger = logging.getLogger("uvicorn.error") + +__all__ = ["SupaJWT"] + + +class Login(BaseModel): + email: str + password: str + + +class OauthProvider(str, Enum): + GITHUB: str = "github" + GOOGLE: str = "google" + TWITTER: str = "twitter" + + +class Provider(str, Enum): + GITHUB: str = "github" + GOOGLE: str = "google" + TWITTER: str = "twitter" + EMAIL: str = "email" + + +class OIDCProvider(str, Enum): + GOOGLE: str = "google" + APPLE: str = "apple" + AZURE: str = "azure" + FACEBOOK: str = "facebook" + + +class IDToken(BaseModel): + provider: OIDCProvider + id_token: str + + +class AppMetaData(BaseModel): + provider: str + providers: List[str] + + +class UserRole(str, Enum): + ADMIN: str = "admin" + USER: str = "user" + + +class UserMetaData(BaseModel): + quack_role: UserRole = UserRole.USER + sub: Union[str, None] = None + user_name: Union[str, None] = None + full_name: Union[str, None] = None + iss: Union[str, None] = None + + +class SupaJWT(BaseModel): + sub: str + email: str + iat: int + exp: int + app_metadata: AppMetaData + user_metadata: UserMetaData + + +class SupaUser(BaseModel): + id: str + email: str + created_at: str + updated_at: str + # Check providers to see connected Oauth + app_metadata: AppMetaData + # Check username and sub for Github username & ID (+ quack_role) + user_metadata: UserMetaData + + +class LoginResponse(BaseModel): + access_token: str + token_type: str + expires_in: int + expires_at: int + refresh_token: str + user: SupaUser + + +def issue_admin_token(secret_key: str, role: str = "service_role", algorithm: str = "HS256") -> str: + return jwt.encode({"role": role}, secret_key, algorithm=algorithm) + + +class SupaClient: + ENDPOINT: str = "https://api.clerk.com/v1" + + def __init__(self, endpoint: str, api_key: str, service_token: str) -> None: + self.endpoint = endpoint + self.api_key = api_key + self.token = service_token + self.headers = self._get_headers(api_key) + # Validate token + self._request("get", "/health", headers=self.headers) + logger.info("Using Supabase authentication service") + + @staticmethod + def _get_headers(api_key: str) -> Dict[str, str]: + return {"apiKey": api_key, "Content-Type": "application/json"} + + def _request( + self, operation: str, route: str, expected_status_code: int = 200, timeout: int = 2, **kwargs + ) -> Dict[str, str]: + response = getattr(requests, operation)(urljoin(self.endpoint, route), timeout=timeout, **kwargs) + json_response = response.json() + if response.status_code != expected_status_code: + raise HTTPException(status_code=response.status_code, detail=json_response["errors"][0]["message"]) + + return json_response + + def sign_up(self, payload: Login, metadata: Union[Dict[str, str], None] = None) -> LoginResponse: + json_payload = ( + {**payload.model_dump_json(), "data": metadata} if isinstance(metadata, dict) else payload.model_dump_json() # type: ignore[dict-item] + ) + return self._request("post", "/signup", json=json_payload, headers=self.headers) # type: ignore[return-value] + + def login_with_password(self, payload: Login) -> LoginResponse: + return self._request( # type: ignore[return-value] + "post", "/token", params={"grant_type": "password"}, json=payload.model_dump_json(), headers=self.headers + ) + + def login_with_idtoken(self, payload: IDToken) -> LoginResponse: + return self._request( # type: ignore[return-value] + "post", "/token", params={"grant_type": "id_token"}, json=payload.model_dump_json(), headers=self.headers + ) + + def magic_link(self, email: str) -> Dict[str, str]: + return self._request("post", "/magiclink", json={"email": email}, headers=self.headers) + + def authorize(self, provider: Provider = Provider.GITHUB) -> Dict[str, str]: + return self._request( + "get", + "/authorize", + params={"provider": provider}, + headers={**self.headers, "Authorization": f"Bearer {self.token}"}, + ) + + def get_authenticated_user(self, token: str) -> SupaUser: + return self._request("get", "/user", headers={**self.headers, "Authorization": f"Bearer {token}"}) # type: ignore[return-value] + + def get_user(self, uid: str) -> SupaUser: + return self._request( # type: ignore[return-value] + "get", f"/admin/users/{uid}", headers={**self.headers, "Authorization": f"Bearer {self.token}"} + ) + + def update_user(self, uid: str, payload: Dict[str, str]) -> SupaUser: + return self._request( # type: ignore[return-value] + "put", + f"/admin/users/{uid}", + json=payload, + headers={**self.headers, "Authorization": f"Bearer {self.token}"}, + ) + + def delete_user(self, uid: str) -> Dict[str, str]: + return self._request( + "delete", f"/admin/users/{uid}", headers={**self.headers, "Authorization": f"Bearer {self.token}"} + ) + + def recover(self, email: str) -> Dict[str, str]: + return self._request("post", "/recover", json={"email": email}, headers=self.headers) + + def invite_user(self, email: str) -> Dict[str, str]: + return self._request( + "post", "/invite", json={"email": email}, headers={**self.headers, "Authorization": f"Bearer {self.token}"} + ) + + +# supabase_client = SupaClient( +# settings.SUPABASE_AUTH_ENDPOINT, settings.SUPABASE_API_KEY, issue_admin_token(settings.JWT_SECRET) +# ) diff --git a/src/tests/test_dependencies.py b/src/tests/test_dependencies.py index c8e8c3d..b3402ef 100644 --- a/src/tests/test_dependencies.py +++ b/src/tests/test_dependencies.py @@ -2,7 +2,7 @@ from fastapi import HTTPException from fastapi.security import SecurityScopes -from app.api.dependencies import get_token_payload +from app.api.dependencies import get_quack_jwt from app.core.security import create_access_token @@ -12,16 +12,16 @@ (["admin"], "", None, 406, None), (["admin"], {"user_id": "123", "scopes": ["admin"]}, None, 422, None), (["admin"], {"sub": "123", "scopes": ["admin"]}, -1, 401, None), - (["admin"], {"sub": "123", "scopes": ["admin"]}, None, None, {"user_id": 123, "scopes": ["admin"]}), + (["admin"], {"sub": "123", "scopes": ["admin"]}, None, None, {"sub": 123, "scopes": ["admin"]}), (["admin"], {"sub": "123", "scopes": ["user"]}, None, 403, None), ], ) -def test_get_token_payload(scopes, token, expires_minutes, error_code, expected_payload): +def test_get_quack_jwt(scopes, token, expires_minutes, error_code, expected_payload): _token = create_access_token(token, expires_minutes) if isinstance(token, dict) else token if isinstance(error_code, int): with pytest.raises(HTTPException): - get_token_payload(SecurityScopes(scopes), _token) + get_quack_jwt(SecurityScopes(scopes), _token) else: - payload = get_token_payload(SecurityScopes(scopes), _token) + payload = get_quack_jwt(SecurityScopes(scopes), _token) if expected_payload is not None: assert payload.model_dump() == expected_payload diff --git a/src/tests/test_security.py b/src/tests/test_security.py index 2634a07..1aa1828 100644 --- a/src/tests/test_security.py +++ b/src/tests/test_security.py @@ -36,7 +36,7 @@ def test_create_access_token(content, expires_minutes, expected_delta): payload = create_access_token(content, expires_minutes) after = datetime.utcnow() assert isinstance(payload, str) - decoded_data = jwt.decode(payload, settings.SECRET_KEY, algorithms=[settings.JWT_ENCODING_ALGORITHM]) + decoded_data = jwt.decode(payload, settings.JWT_SECRET, algorithms=[settings.JWT_ENCODING_ALGORITHM]) # Verify data integrity assert all(v == decoded_data[k] for k, v in content.items()) # Check expiration