-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #72 from bento-platform/quart-support
Quart support
- Loading branch information
Showing
10 changed files
with
264 additions
and
26 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import os | ||
|
||
from functools import wraps | ||
from quart import request | ||
from typing import Callable, Union | ||
|
||
from bento_lib.auth.headers import BENTO_USER_HEADER, BENTO_USER_ROLE_HEADER | ||
from bento_lib.auth.roles import ROLE_OWNER, ROLE_USER | ||
from bento_lib.responses.quart_errors import quart_forbidden_error | ||
|
||
|
||
__all__ = [ | ||
"quart_permissions", | ||
"quart_permissions_any_user", | ||
"quart_permissions_owner", | ||
] | ||
|
||
|
||
# TODO: Centralize this | ||
BENTO_DEBUG = os.environ.get("CHORD_DEBUG", "true").lower() == "true" | ||
BENTO_PERMISSIONS = os.environ.get("CHORD_PERMISSIONS", str(not BENTO_DEBUG)).lower() == "true" | ||
|
||
|
||
def _check_roles(headers, roles: Union[set, dict]) -> bool: | ||
method_roles = roles if not isinstance(roles, dict) else roles.get(request.method, set()) | ||
return ( | ||
not BENTO_PERMISSIONS or | ||
len(method_roles) == 0 or | ||
(BENTO_USER_HEADER in headers and headers.get(BENTO_USER_ROLE_HEADER, "") in method_roles) | ||
) | ||
|
||
|
||
def quart_permissions(method_roles: Union[set, dict]) -> Callable: | ||
def decorator(func): | ||
@wraps(func) | ||
async def wrapper(*args, **kwargs): | ||
if not _check_roles(request.headers, method_roles): | ||
return quart_forbidden_error() | ||
return await func(*args, **kwargs) | ||
return wrapper | ||
return decorator | ||
|
||
|
||
quart_permissions_any_user = quart_permissions({ROLE_USER, ROLE_OWNER}) | ||
quart_permissions_owner = quart_permissions({ROLE_OWNER}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
[package] | ||
name = bento_lib | ||
version = 4.0.0 | ||
version = 4.1.0 | ||
authors = David Lougheed, Paul Pillot | ||
author_emails = [email protected], [email protected] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import sys | ||
import traceback | ||
|
||
from quart import jsonify | ||
from functools import partial | ||
from typing import Callable | ||
|
||
from bento_lib.responses import errors | ||
|
||
|
||
__all__ = [ | ||
"quart_error_wrap_with_traceback", | ||
"quart_error_wrap", | ||
|
||
"quart_error", | ||
|
||
"quart_bad_request_error", | ||
"quart_unauthorized_error", | ||
"quart_forbidden_error", | ||
"quart_not_found_error", | ||
|
||
"quart_internal_server_error", | ||
"quart_not_implemented_error", | ||
] | ||
|
||
|
||
# noinspection PyIncorrectDocstring | ||
def quart_error_wrap_with_traceback(fn: Callable, *args, **kwargs) -> Callable: | ||
""" | ||
Function to wrap quart_* error creators with something that supports the application.register_error_handler method, | ||
while also printing a traceback. Optionally, the keyword argument service_name can be passed in to make the error | ||
logging more precise. | ||
:param fn: The quart error-generating function to wrap | ||
:return: The wrapped function | ||
""" | ||
|
||
service_name = kwargs.pop("service_name", "Bento Service") | ||
|
||
# TODO: pass exception? | ||
def handle_error(e): | ||
print(f"[{service_name}] Encountered error:", file=sys.stderr) | ||
# TODO: py3.10: print_exception(e) | ||
traceback.print_exception(type(e), e, e.__traceback__) | ||
return fn(*args, **kwargs) | ||
return handle_error | ||
|
||
|
||
def quart_error_wrap(fn: Callable, *args, **kwargs) -> Callable: | ||
""" | ||
Function to wrap quart_* error creators with something that supports the application.register_error_handler method. | ||
:param fn: The quart error-generating function to wrap | ||
:return: The wrapped function | ||
""" | ||
return lambda _e: fn(*args, **kwargs) | ||
|
||
|
||
def quart_error(code: int, *errs, drs_compat: bool = False, sr_compat: bool = False): | ||
return jsonify(errors.http_error(code, *errs, drs_compat=drs_compat, sr_compat=sr_compat)), code | ||
|
||
|
||
def _quart_error(code: int) -> Callable: | ||
return partial(quart_error, code) | ||
|
||
|
||
quart_bad_request_error = _quart_error(400) | ||
quart_unauthorized_error = _quart_error(401) | ||
quart_forbidden_error = _quart_error(403) | ||
quart_not_found_error = _quart_error(404) | ||
|
||
quart_internal_server_error = _quart_error(500) | ||
quart_not_implemented_error = _quart_error(501) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,52 +1,65 @@ | ||
aiofiles==22.1.0 | ||
appdirs==1.4.4 | ||
asgiref==3.5.2 | ||
attrs==21.2.0 | ||
attrs==22.1.0 | ||
backports.entry-points-selectable==1.1.0 | ||
blinker==1.5 | ||
certifi==2022.9.24 | ||
chardet==4.0.0 | ||
charset-normalizer==2.0.4 | ||
click==8.0.1 | ||
codecov==2.1.12 | ||
coverage==5.5 | ||
distlib==0.3.2 | ||
Django==4.1.1 | ||
djangorestframework==3.12.4 | ||
distlib==0.3.6 | ||
Django==4.1.3 | ||
djangorestframework==3.13.1 | ||
entrypoints==0.3 | ||
filelock==3.0.12 | ||
filelock==3.8.0 | ||
flake8==3.9.2 | ||
Flask==2.0.1 | ||
Flask==2.2.2 | ||
h11==0.14.0 | ||
h2==4.1.0 | ||
hpack==4.0.0 | ||
hypercorn==0.14.3 | ||
hyperframe==6.0.1 | ||
idna==2.10 | ||
importlib-metadata==4.6.1 | ||
importlib-metadata==5.0.0 | ||
iniconfig==1.1.1 | ||
itsdangerous==2.0.1 | ||
Jinja2==3.0.1 | ||
jsonschema==3.2.0 | ||
MarkupSafe==2.0.1 | ||
jsonschema==4.16.0 | ||
MarkupSafe==2.1.1 | ||
mccabe==0.6.1 | ||
more-itertools==8.9.0 | ||
packaging==21.0 | ||
platformdirs==2.3.0 | ||
platformdirs==2.5.2 | ||
pluggy==0.13.1 | ||
psycopg2==2.9.3 | ||
priority==2.0.0 | ||
psycopg2==2.9.5 | ||
py==1.10.0 | ||
pycodestyle==2.7.0 | ||
pyflakes==2.3.1 | ||
pyparsing==2.4.7 | ||
pyrsistent==0.18.0 | ||
pytest==6.2.5 | ||
pytest-asyncio==0.20.1 | ||
pytest-cov==2.12.1 | ||
pytest-django==4.4.0 | ||
python-dateutil==2.8.2 | ||
pytz==2021.1 | ||
quart==0.18.3 | ||
redis==3.5.3 | ||
requests==2.26.0 | ||
responses==0.13.4 | ||
requests==2.28.1 | ||
responses==0.22.0 | ||
six==1.16.0 | ||
sqlparse==0.4.2 | ||
toml==0.10.2 | ||
tox==3.24.3 | ||
urllib3==1.26.6 | ||
virtualenv==20.7.2 | ||
tomli==2.0.1 | ||
tox==3.27.0 | ||
types-toml==0.10.8 | ||
urllib3==1.26.12 | ||
virtualenv==20.16.6 | ||
wcwidth==0.2.5 | ||
Werkzeug==2.0.1 | ||
zipp==3.5.0 | ||
Werkzeug==2.2.2 | ||
wsproto==1.2.0 | ||
zipp==3.10.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import asyncio | ||
import bento_lib.auth.quart_decorators as qd | ||
import bento_lib.responses.quart_errors as qe | ||
import pytest | ||
import pytest_asyncio | ||
|
||
from quart import Quart | ||
from werkzeug.exceptions import BadRequest, NotFound | ||
|
||
|
||
@pytest_asyncio.fixture | ||
async def quart_client(): | ||
application = Quart(__name__) | ||
|
||
application.register_error_handler(Exception, qe.quart_error_wrap_with_traceback(qe.quart_internal_server_error)) | ||
application.register_error_handler(BadRequest, qe.quart_error_wrap(qe.quart_bad_request_error)) | ||
application.register_error_handler(NotFound, qe.quart_error_wrap(qe.quart_not_found_error, drs_compat=True)) | ||
|
||
@application.route("/500") | ||
async def r500(): | ||
await asyncio.sleep(0.5) | ||
raise Exception("help") | ||
|
||
@application.route("/test1") | ||
@qd.quart_permissions_any_user | ||
async def test1(): | ||
await asyncio.sleep(0.5) | ||
return "test1" | ||
|
||
@application.route("/test2") | ||
@qd.quart_permissions_owner | ||
async def test2(): | ||
await asyncio.sleep(0.5) | ||
return "test2" | ||
|
||
@application.route("/test3", methods=["GET", "POST"]) | ||
@qd.quart_permissions({"POST": {"owner"}}) | ||
async def test3(): | ||
await asyncio.sleep(0.5) | ||
return "test3" | ||
|
||
yield application.test_client() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_quart_errors(quart_client): | ||
# Turn CHORD permissions mode on to make sure we're getting real permissions checks | ||
qd.BENTO_PERMISSIONS = True | ||
|
||
# non-existent endpoint | ||
|
||
r = await quart_client.get("/non-existent") | ||
assert r.status_code == 404 | ||
rj = await r.get_json() | ||
assert rj["code"] == 404 | ||
|
||
# - We passed drs_compat=True to this, so check for DRS-specific fields | ||
assert rj["status_code"] == rj["code"] | ||
assert rj["msg"] == rj["message"] | ||
|
||
# server error endpoint | ||
|
||
r = await quart_client.get("/500") | ||
assert r.status_code == 500 | ||
assert (await r.get_json())["code"] == 500 | ||
|
||
# /test1 | ||
|
||
r = await quart_client.get("/test1") | ||
assert r.status_code == 403 | ||
assert (await r.get_json())["code"] == 403 | ||
|
||
r = await quart_client.get("/test1", headers={"X-User": "test", "X-User-Role": "user"}) | ||
assert r.status_code == 200 | ||
assert (await r.get_data()).decode("utf-8") == "test1" | ||
|
||
r = await quart_client.get("/test1", headers={"X-User": "test", "X-User-Role": "owner"}) | ||
assert r.status_code == 200 | ||
assert (await r.get_data()).decode("utf-8") == "test1" | ||
|
||
# /test2 | ||
|
||
r = await quart_client.get("/test2") | ||
assert r.status_code == 403 | ||
assert (await r.get_json())["code"] == 403 | ||
|
||
r = await quart_client.get("/test2", headers={"X-User": "test", "X-User-Role": "user"}) | ||
assert r.status_code == 403 | ||
assert (await r.get_json())["code"] == 403 | ||
|
||
r = await quart_client.get("/test2", headers={"X-User": "test", "X-User-Role": "owner"}) | ||
assert r.status_code == 200 | ||
assert (await r.get_data()).decode("utf-8") == "test2" | ||
|
||
# /test3 | ||
|
||
r = await quart_client.get("/test3") | ||
assert r.status_code == 200 | ||
assert (await r.get_data()).decode("utf-8") == "test3" | ||
|
||
r = await quart_client.post("/test3") | ||
assert r.status_code == 403 | ||
assert (await r.get_json())["code"] == 403 | ||
|
||
r = await quart_client.get("/test3", headers={"X-User": "test", "X-User-Role": "owner"}) | ||
assert r.status_code == 200 | ||
assert (await r.get_data()).decode("utf-8") == "test3" |