From 6442d2deb868a59edd74959b14523aa9aad2571b Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 29 Jul 2024 20:19:00 +0530 Subject: [PATCH 1/4] Refactor auth code to output auth scheme in OpenAPI spec Helpful for generating SDKs --- auth/token_authentication.py | 91 +++++++++++++++++++----------- daras_ai_v2/api_examples_widget.py | 38 ++++++------- 2 files changed, 78 insertions(+), 51 deletions(-) diff --git a/auth/token_authentication.py b/auth/token_authentication.py index b33bbbbd0..483e291b6 100644 --- a/auth/token_authentication.py +++ b/auth/token_authentication.py @@ -1,39 +1,27 @@ -import threading - -from fastapi import Header +from fastapi import Request from fastapi.exceptions import HTTPException +from fastapi.openapi.models import HTTPBase as HTTPBaseModel, SecuritySchemeType +from fastapi.security.base import SecurityBase +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from app_users.models import AppUser from auth.auth_backend import authlocal from daras_ai_v2 import db from daras_ai_v2.crypto import PBKDF2PasswordHasher -auth_keyword = "Bearer" +class AuthenticationError(HTTPException): + status_code = HTTP_401_UNAUTHORIZED + + def __init__(self, msg: str): + super().__init__(status_code=self.status_code, detail={"error": msg}) -def api_auth_header( - authorization: str = Header( - alias="Authorization", - description=f"{auth_keyword} $GOOEY_API_KEY", - ), -) -> AppUser: - if authlocal: - return authlocal[0] - return authenticate(authorization) +class AuthorizationError(HTTPException): + status_code = HTTP_403_FORBIDDEN -def authenticate(auth_token: str) -> AppUser: - auth = auth_token.split() - if not auth or auth[0].lower() != auth_keyword.lower(): - msg = "Invalid Authorization header." - raise HTTPException(status_code=401, detail={"error": msg}) - if len(auth) == 1: - msg = "Invalid Authorization header. No credentials provided." - raise HTTPException(status_code=401, detail={"error": msg}) - elif len(auth) > 2: - msg = "Invalid Authorization header. Token string should not contain spaces." - raise HTTPException(status_code=401, detail={"error": msg}) - return authenticate_credentials(auth[1]) + def __init__(self, msg: str): + super().__init__(status_code=self.status_code, detail={"error": msg}) def authenticate_credentials(token: str) -> AppUser: @@ -48,12 +36,7 @@ def authenticate_credentials(token: str) -> AppUser: .get()[0] ) except IndexError: - raise HTTPException( - status_code=403, - detail={ - "error": "Invalid API Key.", - }, - ) + raise AuthorizationError("Invalid API Key.") uid = doc.get("uid") user = AppUser.objects.get_or_create_from_uid(uid)[0] @@ -62,6 +45,50 @@ def authenticate_credentials(token: str) -> AppUser: "Your Gooey.AI account has been disabled for violating our Terms of Service. " "Contact us at support@gooey.ai if you think this is a mistake." ) - raise HTTPException(status_code=401, detail={"error": msg}) + raise AuthenticationError(msg) return user + + +class APIAuth(SecurityBase): + """ + ### Usage: + + ```python + api_auth = APIAuth(scheme_name="Bearer", description="Bearer $GOOEY_API_KEY") + + @app.get("/api/users") + def get_users(authenticated_user: AppUser = Depends(api_auth)): + ... + ``` + """ + + def __init__(self, scheme_name: str, description: str): + self.model = HTTPBaseModel( + type=SecuritySchemeType.http, scheme=scheme_name, description=description + ) + self.scheme_name = scheme_name + self.description = description + + def __call__(self, request: Request) -> AppUser: + if authlocal: # testing only! + return authlocal[0] + + auth = request.headers.get("Authorization", "").split() + if not auth or auth[0].lower() != self.scheme_name.lower(): + raise AuthenticationError("Invalid Authorization header.") + if len(auth) == 1: + raise AuthenticationError( + "Invalid Authorization header. No credentials provided." + ) + elif len(auth) > 2: + raise AuthenticationError( + "Invalid Authorization header. Token string should not contain spaces." + ) + return authenticate_credentials(auth[1]) + + +auth_scheme = "Bearer" +api_auth_header = APIAuth( + scheme_name=auth_scheme, description=f"{auth_scheme} $GOOEY_API_KEY" +) diff --git a/daras_ai_v2/api_examples_widget.py b/daras_ai_v2/api_examples_widget.py index 44086ba02..53d7ac8c4 100644 --- a/daras_ai_v2/api_examples_widget.py +++ b/daras_ai_v2/api_examples_widget.py @@ -6,7 +6,7 @@ from furl import furl import gooey_ui as st -from auth.token_authentication import auth_keyword +from auth.token_authentication import auth_scheme from daras_ai_v2 import settings from daras_ai_v2.doc_search_settings_widgets import is_user_uploaded_url @@ -48,12 +48,12 @@ def api_example_generator( if as_form_data: curl_code = r""" curl %(api_url)s \ - -H "Authorization: %(auth_keyword)s $GOOEY_API_KEY" \ + -H "Authorization: %(auth_scheme)s $GOOEY_API_KEY" \ %(files)s \ -F json=%(json)s """ % dict( api_url=shlex.quote(api_url), - auth_keyword=auth_keyword, + auth_scheme=auth_scheme, files=" \\\n ".join( f"-F {key}=@{shlex.quote(filename)}" for key, filename in filenames ), @@ -62,12 +62,12 @@ def api_example_generator( else: curl_code = r""" curl %(api_url)s \ - -H "Authorization: %(auth_keyword)s $GOOEY_API_KEY" \ + -H "Authorization: %(auth_scheme)s $GOOEY_API_KEY" \ -H 'Content-Type: application/json' \ -d %(json)s """ % dict( api_url=shlex.quote(api_url), - auth_keyword=auth_keyword, + auth_scheme=auth_scheme, json=shlex.quote(json.dumps(request_body, indent=2)), ) if as_async: @@ -77,7 +77,7 @@ def api_example_generator( ) while true; do - result=$(curl $status_url -H "Authorization: %(auth_keyword)s $GOOEY_API_KEY") + result=$(curl $status_url -H "Authorization: %(auth_scheme)s $GOOEY_API_KEY") status=$(echo $result | jq -r '.status') if [ "$status" = "completed" ]; then echo $result @@ -91,7 +91,7 @@ def api_example_generator( """ % dict( curl_code=indent(curl_code.strip(), " " * 2), api_url=shlex.quote(api_url), - auth_keyword=auth_keyword, + auth_scheme=auth_scheme, json=shlex.quote(json.dumps(request_body, indent=2)), ) @@ -128,7 +128,7 @@ def api_example_generator( response = requests.post( "%(api_url)s", headers={ - "Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"], + "Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"], }, files=files, data={"json": json.dumps(payload)}, @@ -140,7 +140,7 @@ def api_example_generator( ), json=repr(request_body), api_url=api_url, - auth_keyword=auth_keyword, + auth_scheme=auth_scheme, ) else: py_code = r""" @@ -152,14 +152,14 @@ def api_example_generator( response = requests.post( "%(api_url)s", headers={ - "Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"], + "Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"], }, json=payload, ) assert response.ok, response.content """ % dict( api_url=api_url, - auth_keyword=auth_keyword, + auth_scheme=auth_scheme, json=repr(request_body), ) if as_async: @@ -168,7 +168,7 @@ def api_example_generator( status_url = response.headers["Location"] while True: - response = requests.get(status_url, headers={"Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"]}) + response = requests.get(status_url, headers={"Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"]}) assert response.ok, response.content result = response.json() if result["status"] == "completed": @@ -181,7 +181,7 @@ def api_example_generator( sleep(3) """ % dict( api_url=api_url, - auth_keyword=auth_keyword, + auth_scheme=auth_scheme, ) else: py_code += r""" @@ -229,7 +229,7 @@ def api_example_generator( const response = await fetch("%(api_url)s", { method: "POST", headers: { - "Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"], + "Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"], }, body: formData, }); @@ -243,7 +243,7 @@ def api_example_generator( " " * 2, ), api_url=api_url, - auth_keyword=auth_keyword, + auth_scheme=auth_scheme, ) else: @@ -256,14 +256,14 @@ def api_example_generator( const response = await fetch("%(api_url)s", { method: "POST", headers: { - "Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"], + "Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"], "Content-Type": "application/json", }, body: JSON.stringify(payload), }); """ % dict( api_url=api_url, - auth_keyword=auth_keyword, + auth_scheme=auth_scheme, json=json.dumps(request_body, indent=2), ) @@ -280,7 +280,7 @@ def api_example_generator( const response = await fetch(status_url, { method: "GET", headers: { - "Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"], + "Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"], }, }); if (!response.ok) { @@ -299,7 +299,7 @@ def api_example_generator( } }""" % dict( api_url=api_url, - auth_keyword=auth_keyword, + auth_scheme=auth_scheme, ) else: js_code += """ From 1b5f46f95f5138cc23669e3f08953917ad0b9b27 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 1 Aug 2024 18:28:42 +0530 Subject: [PATCH 2/4] Add openapi params for fern bearer auth, hide healthcheck from fern --- auth/token_authentication.py | 19 ++++++++++++++----- server.py | 2 +- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/auth/token_authentication.py b/auth/token_authentication.py index 483e291b6..a1281faa7 100644 --- a/auth/token_authentication.py +++ b/auth/token_authentication.py @@ -1,3 +1,5 @@ +from typing import Any + from fastapi import Request from fastapi.exceptions import HTTPException from fastapi.openapi.models import HTTPBase as HTTPBaseModel, SecuritySchemeType @@ -55,7 +57,7 @@ class APIAuth(SecurityBase): ### Usage: ```python - api_auth = APIAuth(scheme_name="Bearer", description="Bearer $GOOEY_API_KEY") + api_auth = APIAuth(scheme_name="bearer", description="Bearer $GOOEY_API_KEY") @app.get("/api/users") def get_users(authenticated_user: AppUser = Depends(api_auth)): @@ -63,9 +65,14 @@ def get_users(authenticated_user: AppUser = Depends(api_auth)): ``` """ - def __init__(self, scheme_name: str, description: str): + def __init__( + self, scheme_name: str, description: str, openapi_extra: dict[str, Any] = None + ): self.model = HTTPBaseModel( - type=SecuritySchemeType.http, scheme=scheme_name, description=description + type=SecuritySchemeType.http, + scheme=scheme_name, + description=description, + **(openapi_extra or {}), ) self.scheme_name = scheme_name self.description = description @@ -88,7 +95,9 @@ def __call__(self, request: Request) -> AppUser: return authenticate_credentials(auth[1]) -auth_scheme = "Bearer" +auth_scheme = "bearer" api_auth_header = APIAuth( - scheme_name=auth_scheme, description=f"{auth_scheme} $GOOEY_API_KEY" + scheme_name=auth_scheme, + description=f"{auth_scheme} $GOOEY_API_KEY", + openapi_extra={"x-fern-bearer": {"name": "apiKey", "env": "GOOEY_API_KEY"}}, ) diff --git a/server.py b/server.py index 2a8744c2a..7879adab2 100644 --- a/server.py +++ b/server.py @@ -92,7 +92,7 @@ async def startup(): limiter.total_tokens = config("MAX_THREADS", default=limiter.total_tokens, cast=int) -@app.get("/", tags=["Misc"]) +@app.get("/", tags=["Misc"], openapi_extra={"x-fern-ignore": True}) async def health(): return "OK" From 63deccfc10cb890a0a9b66d7eca6644235f26e14 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 2 Aug 2024 03:36:15 +0530 Subject: [PATCH 3/4] Add x-fern-sdk-return-value for all status routes --- routers/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/routers/api.py b/routers/api.py index aab15806c..58a0870c3 100644 --- a/routers/api.py +++ b/routers/api.py @@ -283,6 +283,7 @@ def run_api_form_async( operation_id="status__" + page_cls.slug_versions[0], tags=[page_cls.title], name=page_cls.title + " (v3 status)", + openapi_extra={"x-fern-sdk-return-value": "output"}, ) @app.get( os.path.join(endpoint, "status"), From 53a8036e409b7f1bd155f0e854b3f5e771af38f8 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 2 Aug 2024 03:37:35 +0530 Subject: [PATCH 4/4] fern: ignore v2 sync APIs --- routers/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/routers/api.py b/routers/api.py index 58a0870c3..04ce9f84c 100644 --- a/routers/api.py +++ b/routers/api.py @@ -148,6 +148,7 @@ def script_to_api(page_cls: typing.Type[BasePage]): operation_id=page_cls.slug_versions[0], tags=[page_cls.title], name=page_cls.title + " (v2 sync)", + openapi_extra={"x-fern-ignore": True}, ) @app.post( endpoint,