From 34d5555d361daa716a433b50c7ecfab9fa1f201c Mon Sep 17 00:00:00 2001 From: Yondon Fu Date: Tue, 13 Feb 2024 16:27:50 +0000 Subject: [PATCH] runner+worker: Add bearer auth --- runner/app/routes/image_to_image.py | 13 +++++++- runner/app/routes/image_to_video.py | 24 ++++++++++----- runner/app/routes/text_to_image.py | 16 ++++++++-- runner/openapi.json | 27 ++++++++++++++-- worker/runner.gen.go | 48 +++++++++++++++++------------ 5 files changed, 96 insertions(+), 32 deletions(-) diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index fb8a33ba..33c82616 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -1,5 +1,6 @@ from fastapi import Depends, APIRouter, UploadFile, File, Form from fastapi.responses import JSONResponse +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from app.pipelines import ImageToImagePipeline from app.dependencies import get_pipeline from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error @@ -7,7 +8,7 @@ from typing import Annotated import logging import random -from typing import List +import os router = APIRouter() @@ -35,7 +36,17 @@ async def image_to_image( seed: Annotated[int, Form()] = None, num_images_per_prompt: Annotated[int, Form()] = 1, pipeline: ImageToImagePipeline = Depends(get_pipeline), + token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): + auth_token = os.environ.get("AUTH_TOKEN") + if auth_token: + if not token or token.credentials != auth_token: + return JSONResponse( + status_code=401, + headers={"WWW-Authenticate": "Bearer"}, + content=http_error("Invalid bearer token"), + ) + if model_id != "" and model_id != pipeline.model_id: return JSONResponse( status_code=400, diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index 2bf5576c..d90f1306 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -1,12 +1,14 @@ from fastapi import Depends, APIRouter, UploadFile, File, Form from fastapi.responses import JSONResponse +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from app.pipelines import ImageToVideoPipeline from app.dependencies import get_pipeline -from app.routes.util import image_to_data_url, VideoResponse, HTTPError +from app.routes.util import image_to_data_url, VideoResponse, HTTPError, http_error import PIL from typing import Annotated import logging import random +import os router = APIRouter() @@ -34,15 +36,23 @@ async def image_to_video( noise_aug_strength: Annotated[float, Form()] = 0.02, seed: Annotated[int, Form()] = None, pipeline: ImageToVideoPipeline = Depends(get_pipeline), + token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): + auth_token = os.environ.get("AUTH_TOKEN") + if auth_token: + if not token or token.credentials != auth_token: + return JSONResponse( + status_code=401, + headers={"WWW-Authenticate": "Bearer"}, + content=http_error("Invalid bearer token"), + ) + if model_id != "" and model_id != pipeline.model_id: return JSONResponse( status_code=400, - content={ - "detail": { - "msg": f"pipeline configured with {pipeline.model_id} but called with {model_id}" - } - }, + content=http_error( + f"pipeline configured with {pipeline.model_id} but called with {model_id}" + ), ) if seed is None: @@ -62,7 +72,7 @@ async def image_to_video( logger.error(f"ImageToVideoPipeline error: {e}") logger.exception(e) return JSONResponse( - status_code=500, content={"detail": {"msg": "ImageToVideoPipeline error"}} + status_code=500, content=http_error("ImageToVideoPipeline error") ) output_frames = [] diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index 3a4028ef..4d00c8be 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -1,12 +1,13 @@ from pydantic import BaseModel from fastapi import Depends, APIRouter from fastapi.responses import JSONResponse +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from app.pipelines import TextToImagePipeline from app.dependencies import get_pipeline from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error import logging -from typing import List import random +import os router = APIRouter() @@ -32,8 +33,19 @@ class TextToImageParams(BaseModel): @router.post("/text-to-image", response_model=ImageResponse, responses=responses) @router.post("/text-to-image/", response_model=ImageResponse, include_in_schema=False) async def text_to_image( - params: TextToImageParams, pipeline: TextToImagePipeline = Depends(get_pipeline) + params: TextToImageParams, + pipeline: TextToImagePipeline = Depends(get_pipeline), + token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): + auth_token = os.environ.get("AUTH_TOKEN") + if auth_token: + if not token or token.credentials != auth_token: + return JSONResponse( + status_code=401, + headers={"WWW-Authenticate": "Bearer"}, + content=http_error("Invalid bearer token"), + ) + if params.model_id != "" and params.model_id != pipeline.model_id: return JSONResponse( status_code=400, diff --git a/runner/openapi.json b/runner/openapi.json index a494c110..57dcc4a0 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -79,7 +79,12 @@ } } } - } + }, + "security": [ + { + "HTTPBearer": [] + } + ] } }, "/image-to-image": { @@ -137,7 +142,12 @@ } } } - } + }, + "security": [ + { + "HTTPBearer": [] + } + ] } }, "/image-to-video": { @@ -195,7 +205,12 @@ } } } - } + }, + "security": [ + { + "HTTPBearer": [] + } + ] } } }, @@ -477,6 +492,12 @@ ], "title": "VideoResponse" } + }, + "securitySchemes": { + "HTTPBearer": { + "type": "http", + "scheme": "bearer" + } } } } \ No newline at end of file diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 87d27693..9e2e74b3 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -22,6 +22,10 @@ import ( openapi_types "github.com/oapi-codegen/runtime/types" ) +const ( + HTTPBearerScopes = "HTTPBearer.Scopes" +) + // APIError defines model for APIError. type APIError struct { Msg string `json:"msg"` @@ -894,6 +898,8 @@ func (siw *ServerInterfaceWrapper) Health(w http.ResponseWriter, r *http.Request func (siw *ServerInterfaceWrapper) ImageToImage(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { siw.Handler.ImageToImage(w, r) })) @@ -909,6 +915,8 @@ func (siw *ServerInterfaceWrapper) ImageToImage(w http.ResponseWriter, r *http.R func (siw *ServerInterfaceWrapper) ImageToVideo(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { siw.Handler.ImageToVideo(w, r) })) @@ -924,6 +932,8 @@ func (siw *ServerInterfaceWrapper) ImageToVideo(w http.ResponseWriter, r *http.R func (siw *ServerInterfaceWrapper) TextToImage(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { siw.Handler.TextToImage(w, r) })) @@ -1067,25 +1077,25 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xX32/bNhD+Vwhuj07seM0y+C3Zrxpb1iD2uoeiMBjpLLOVSI4/vBqG//eBpC1RojTJ", - "Q5thQJ5iSce77+6++3jZ44QXgjNgWuHZHqtkAwVxP28f5j9KyaX9LSQXIDUF96VQmf2jqc4Bz/C9yvAI", - "652wD0pLyjJ8OIywhD8NlZDi2Tt35P2oPFL6Ls/xpw+QaHwY4Tue7la0IBmsND/+aDwKrnQMKzM0JSyB", - "lUqIjbLHKayJyTWe3VxeV8F/PtqhhbMrITBTPIG0EFwU62DNZUE0nuEnyojc4crJ3JlEaY9wwVPIVzSt", - "xcfByXtrgOZp22EGGdF0CysheSF0p4/fjnbowdu1uTKFr5ZaCZBtDq8Cf6ZALiOFHkBGXinTkPnSVH5O", - "Z7shKIA0tFzY5zanSktgmd7U4E0uv6sALk4WUbcaRBMnNL6HAeeG8qqXkluaAm8+tlNyLVSdhxWcn4Rq", - "rcUGaLapN+r65tvq3Gv/ve3of0bbgmvK2erJJB9BN51cTW9CL9YS3TnLmrcgD8apghUx2aqDGJNpQF1r", - "jG5Nhro5cgYV/6JpI9zVZPqqCveH+x6fbNCwh33dFGph3+vl8qFDiVPQhOb219cS1niGvxpXej4+ivm4", - "VNsmyuPxAGYVqwPIW5LTlNgm9kKiGgrVh63p71Bh+cF7KoEQKcnO5RCibTpoww0k15vvN5B8jPEqTbSp", - "Tyl+8wsOpccZtN1w1UxWAVriu6F7BCU4UxAj8Co9uGL3kFIS1skLd1udIkaqsNd1WC24faS4YkNnycg8", - "tPtd5r17gnE2LkKA1ANpQbiET3rJXSIPRBJfvC+1FVTKPECLX9aA89eAUnvPFNsjmIAwMS9ayNMrZTlP", - "alNJ2O7NGs/e7aMc9xHE98GA/soTFyYa0VG0SoNSHRe0f1GZOsxoad/2DZXNw4c6WgaVGiCfb+3t1C1f", - "a0mKhnydqWONmpQbknfco2vH8GFKNbxRQtYBZWvuh0AlkgrXnBm+ZYgIkVPfLaQ5koah2zkSVEBOmQdz", - "airdggCQ9vujYQxs7bYglfc1uby6nNhsuABGBMUz/I17NcKC6I2rznjj7g2nUuCGydbVBZ+n5bWCbb4+", - "GXdqOpnYPwlnGpg7FYAef1A2/Okfub4ehBeXK0y9IAuTJKDU2uSorKe1UqYo7F5ZQrQvx05mLjS/KPfQ", - "01JcT8uN5XE6sW8mKG0XpEZehck1FUTqsV1oL1KiyfDUhq77hzqhtDRw+IIVr1+6Q2s+wq8+Z9fLJa8l", - "/h1J0aNviYs7nX7WuNG+FyOoTFC5E14/V/pzpkEykqMFyC1IVC3OFeldD9GS+7uyQX63y/eS32nUc5G/", - "+7+NZyZ/XZlfyP+/Jr+nsCO/hk96gPAHW9k/Uv/fZxfvfS/y/sLwMxluSRSq++HwdwAAAP//Jp1yTCYX", - "AAA=", + "H4sIAAAAAAAC/+xXS2/jNhD+KwTbozd23E1T+Jb0tUabbhC720MQGIw0lrkrkSw5TNcI/N8LkrZEvSqn", + "yKZAkZNew5lvZr556JEmslBSgEBDZ4/UJBsomL+9uJ7/qLXU7l5pqUAjB/+lMJm7IMcc6IxemYyOKG6V", + "ezCoucjobjeiGv60XENKZ7f+yN2oPFLqLs/J+4+QIN2N6KVMtytesAxWKPc3jUclDbZhZZanTCSwMglz", + "Vh5pCmtmc6Sz85OzyvjPezmy8HIlBGGLe9AOgrfiFKylLhjSGb3nguktrZTMvUjL7REtZAr5iqc1+zQ6", + "eeUEyDztOiwgY8gfYKW0LBT26vhtL0eug1yXKluEaJmVAt2l8DTSZwviPTLkGnRLKxcIWQhNpedwth+C", + "AUhjyYV77lJqUIPIcFODNzn5rgK4OEi0stUgmjqgCTmMOHcsrwYp+cBTkM3HbkqulanzsILzkzKdsdgA", + "zzb1RJ2df1udexe+dx39z2hbSORSrO5t8gmwqeR0eh5rcZLk0kvWtEV+CMkNrJjNVj3EmEwj6jphcmEz", + "0s+RJ1DxL542zJ1Opm8rc3/47+2TDRoOsK+fQh3se7dcXvd04hSQ8dzdfa1hTWf0q3HVz8f7Zj4uu20T", + "5f54BLOy1QPkA8t5ylwSByFxhMIMYWvq21VYfgiaSiBMa7b1PsRomwq6cAPLcfP9BpJPbbwGGdp6ldL3", + "v9C49XiBrglX1WRloMO+L7obMEoKA20EoUsfHbErSDmL4xQad1ecWow0ca7rsDpwB0vtiB1bS1bnsdzv", + "Oh/cE6yX8RYipAFIB8IlfMal9I5cM81C8L7UVlB15iN68esa8PQ1oOy9T2y2ezARYdq86CDPYCvLZVKr", + "Sia279d0dvvY8vGxBfEuKtBfZeLNtEp01FqlwZieAR1eVKIeM1m6t0NF5fwIpvaSUaSOaJ8f3HTqb19r", + "zYpG+3piH2vEpNyQguKBvrY3H7tUw9tyyDMysZrjduGgBOxulFwC06DL3yB36D68KpVsEBXdOR1crGWo", + "I5Nornx+Z/RCEKZUzkPCCUqirSAXc6K4gpyL4M+BF/wBFIB232+sEN7QA2gTdE1OTk8mLiBSgWCK0xn9", + "xr8aUcVw42GPN370+EYHvh5darzxeVpOJupCFuLhT00nE3dJpEAQ/lQEevzROPOHf8GhNMazzwemHpCF", + "TRIwZm1zUqbEp8AWhVtNS4ju5dh3qjco35Sr7GGvrrvlK3tf4DTwAQy6HavhV2Fz5IppHLud+E3KkB3v", + "2rF/DLs6J1Fb2H3BiNfn9rExH9G3z5n1ck/ssH/JUnITUuLtTqfPare1MrYRVCKkXCvPXsr9uUDQguVk", + "AfoBNKl270Pf8TMk7ji3d7u7uCZ8islShmncqA3/tzBYG74LvlRt9P/PvHBt1Hv/a238n2sjMNzXBsJn", + "PGJsRGvhP1bGv3e+vXi+DofXAnjeAnAci2fDbvd3AAAA//92bsyPxhcAAA==", } // GetSwagger returns the content of the embedded swagger specification file