Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor auth code to output auth scheme in OpenAPI spec #418

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 59 additions & 32 deletions auth/token_authentication.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,27 @@
import threading

from fastapi import Header
from fastapi import Request
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel, SecuritySchemeType
from fastapi.security.base import SecurityBase
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN

from app_users.models import AppUser
from auth.auth_backend import authlocal
from daras_ai_v2 import db
from daras_ai_v2.crypto import PBKDF2PasswordHasher

auth_keyword = "Bearer"

class AuthenticationError(HTTPException):
status_code = HTTP_401_UNAUTHORIZED

def __init__(self, msg: str):
super().__init__(status_code=self.status_code, detail={"error": msg})

def api_auth_header(
authorization: str = Header(
alias="Authorization",
description=f"{auth_keyword} $GOOEY_API_KEY",
),
) -> AppUser:
if authlocal:
return authlocal[0]
return authenticate(authorization)

class AuthorizationError(HTTPException):
status_code = HTTP_403_FORBIDDEN

def authenticate(auth_token: str) -> AppUser:
auth = auth_token.split()
if not auth or auth[0].lower() != auth_keyword.lower():
msg = "Invalid Authorization header."
raise HTTPException(status_code=401, detail={"error": msg})
if len(auth) == 1:
msg = "Invalid Authorization header. No credentials provided."
raise HTTPException(status_code=401, detail={"error": msg})
elif len(auth) > 2:
msg = "Invalid Authorization header. Token string should not contain spaces."
raise HTTPException(status_code=401, detail={"error": msg})
return authenticate_credentials(auth[1])
def __init__(self, msg: str):
super().__init__(status_code=self.status_code, detail={"error": msg})


def authenticate_credentials(token: str) -> AppUser:
Expand All @@ -48,12 +36,7 @@ def authenticate_credentials(token: str) -> AppUser:
.get()[0]
)
except IndexError:
raise HTTPException(
status_code=403,
detail={
"error": "Invalid API Key.",
},
)
raise AuthorizationError("Invalid API Key.")

uid = doc.get("uid")
user = AppUser.objects.get_or_create_from_uid(uid)[0]
Expand All @@ -62,6 +45,50 @@ def authenticate_credentials(token: str) -> AppUser:
"Your Gooey.AI account has been disabled for violating our Terms of Service. "
"Contact us at [email protected] if you think this is a mistake."
)
raise HTTPException(status_code=401, detail={"error": msg})
raise AuthenticationError(msg)

return user


class APIAuth(SecurityBase):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created a new class instead of using FastAPI's default HTTPBearer because it has terrible error messages and uses 403 for all auth problems.

https://github.com/fastapi/fastapi/blob/0.85.0/fastapi/security/http.py#L113-L133

My changes are compatible with what we already do - nothing has changed behaviourally for the API.

"""
### Usage:

```python
api_auth = APIAuth(scheme_name="Bearer", description="Bearer $GOOEY_API_KEY")

@app.get("/api/users")
def get_users(authenticated_user: AppUser = Depends(api_auth)):
...
```
"""

def __init__(self, scheme_name: str, description: str):
self.model = HTTPBaseModel(
type=SecuritySchemeType.http, scheme=scheme_name, description=description
)
self.scheme_name = scheme_name
self.description = description

def __call__(self, request: Request) -> AppUser:
if authlocal: # testing only!
return authlocal[0]

auth = request.headers.get("Authorization", "").split()
if not auth or auth[0].lower() != self.scheme_name.lower():
raise AuthenticationError("Invalid Authorization header.")
if len(auth) == 1:
raise AuthenticationError(
"Invalid Authorization header. No credentials provided."
)
elif len(auth) > 2:
raise AuthenticationError(
"Invalid Authorization header. Token string should not contain spaces."
)
return authenticate_credentials(auth[1])


auth_scheme = "Bearer"
api_auth_header = APIAuth(
scheme_name=auth_scheme, description=f"{auth_scheme} $GOOEY_API_KEY"
)
38 changes: 19 additions & 19 deletions daras_ai_v2/api_examples_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from furl import furl

import gooey_ui as st
from auth.token_authentication import auth_keyword
from auth.token_authentication import auth_scheme
from daras_ai_v2 import settings
from daras_ai_v2.doc_search_settings_widgets import is_user_uploaded_url

Expand Down Expand Up @@ -48,12 +48,12 @@ def api_example_generator(
if as_form_data:
curl_code = r"""
curl %(api_url)s \
-H "Authorization: %(auth_keyword)s $GOOEY_API_KEY" \
-H "Authorization: %(auth_scheme)s $GOOEY_API_KEY" \
%(files)s \
-F json=%(json)s
""" % dict(
api_url=shlex.quote(api_url),
auth_keyword=auth_keyword,
auth_scheme=auth_scheme,
files=" \\\n ".join(
f"-F {key}=@{shlex.quote(filename)}" for key, filename in filenames
),
Expand All @@ -62,12 +62,12 @@ def api_example_generator(
else:
curl_code = r"""
curl %(api_url)s \
-H "Authorization: %(auth_keyword)s $GOOEY_API_KEY" \
-H "Authorization: %(auth_scheme)s $GOOEY_API_KEY" \
-H 'Content-Type: application/json' \
-d %(json)s
""" % dict(
api_url=shlex.quote(api_url),
auth_keyword=auth_keyword,
auth_scheme=auth_scheme,
json=shlex.quote(json.dumps(request_body, indent=2)),
)
if as_async:
Expand All @@ -77,7 +77,7 @@ def api_example_generator(
)

while true; do
result=$(curl $status_url -H "Authorization: %(auth_keyword)s $GOOEY_API_KEY")
result=$(curl $status_url -H "Authorization: %(auth_scheme)s $GOOEY_API_KEY")
status=$(echo $result | jq -r '.status')
if [ "$status" = "completed" ]; then
echo $result
Expand All @@ -91,7 +91,7 @@ def api_example_generator(
""" % dict(
curl_code=indent(curl_code.strip(), " " * 2),
api_url=shlex.quote(api_url),
auth_keyword=auth_keyword,
auth_scheme=auth_scheme,
json=shlex.quote(json.dumps(request_body, indent=2)),
)

Expand Down Expand Up @@ -128,7 +128,7 @@ def api_example_generator(
response = requests.post(
"%(api_url)s",
headers={
"Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"],
"Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"],
},
files=files,
data={"json": json.dumps(payload)},
Expand All @@ -140,7 +140,7 @@ def api_example_generator(
),
json=repr(request_body),
api_url=api_url,
auth_keyword=auth_keyword,
auth_scheme=auth_scheme,
)
else:
py_code = r"""
Expand All @@ -152,14 +152,14 @@ def api_example_generator(
response = requests.post(
"%(api_url)s",
headers={
"Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"],
"Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"],
},
json=payload,
)
assert response.ok, response.content
""" % dict(
api_url=api_url,
auth_keyword=auth_keyword,
auth_scheme=auth_scheme,
json=repr(request_body),
)
if as_async:
Expand All @@ -168,7 +168,7 @@ def api_example_generator(

status_url = response.headers["Location"]
while True:
response = requests.get(status_url, headers={"Authorization": "%(auth_keyword)s " + os.environ["GOOEY_API_KEY"]})
response = requests.get(status_url, headers={"Authorization": "%(auth_scheme)s " + os.environ["GOOEY_API_KEY"]})
assert response.ok, response.content
result = response.json()
if result["status"] == "completed":
Expand All @@ -181,7 +181,7 @@ def api_example_generator(
sleep(3)
""" % dict(
api_url=api_url,
auth_keyword=auth_keyword,
auth_scheme=auth_scheme,
)
else:
py_code += r"""
Expand Down Expand Up @@ -229,7 +229,7 @@ def api_example_generator(
const response = await fetch("%(api_url)s", {
method: "POST",
headers: {
"Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"],
"Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"],
},
body: formData,
});
Expand All @@ -243,7 +243,7 @@ def api_example_generator(
" " * 2,
),
api_url=api_url,
auth_keyword=auth_keyword,
auth_scheme=auth_scheme,
)

else:
Expand All @@ -256,14 +256,14 @@ def api_example_generator(
const response = await fetch("%(api_url)s", {
method: "POST",
headers: {
"Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"],
"Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"],
"Content-Type": "application/json",
},
body: JSON.stringify(payload),
});
""" % dict(
api_url=api_url,
auth_keyword=auth_keyword,
auth_scheme=auth_scheme,
json=json.dumps(request_body, indent=2),
)

Expand All @@ -280,7 +280,7 @@ def api_example_generator(
const response = await fetch(status_url, {
method: "GET",
headers: {
"Authorization": "%(auth_keyword)s " + process.env["GOOEY_API_KEY"],
"Authorization": "%(auth_scheme)s " + process.env["GOOEY_API_KEY"],
},
});
if (!response.ok) {
Expand All @@ -299,7 +299,7 @@ def api_example_generator(
}
}""" % dict(
api_url=api_url,
auth_keyword=auth_keyword,
auth_scheme=auth_scheme,
)
else:
js_code += """
Expand Down