Skip to content

Commit

Permalink
Merge pull request #72 from bento-platform/quart-support
Browse files Browse the repository at this point in the history
Quart support
  • Loading branch information
davidlougheed authored Nov 4, 2022
2 parents 8bbfd71 + 5db8c41 commit 495521a
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .idea/bento_lib.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 45 additions & 0 deletions bento_lib/auth/quart_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os

from functools import wraps
from quart import request
from typing import Callable, Union

from bento_lib.auth.headers import BENTO_USER_HEADER, BENTO_USER_ROLE_HEADER
from bento_lib.auth.roles import ROLE_OWNER, ROLE_USER
from bento_lib.responses.quart_errors import quart_forbidden_error


__all__ = [
"quart_permissions",
"quart_permissions_any_user",
"quart_permissions_owner",
]


# TODO: Centralize this
BENTO_DEBUG = os.environ.get("CHORD_DEBUG", "true").lower() == "true"
BENTO_PERMISSIONS = os.environ.get("CHORD_PERMISSIONS", str(not BENTO_DEBUG)).lower() == "true"


def _check_roles(headers, roles: Union[set, dict]) -> bool:
method_roles = roles if not isinstance(roles, dict) else roles.get(request.method, set())
return (
not BENTO_PERMISSIONS or
len(method_roles) == 0 or
(BENTO_USER_HEADER in headers and headers.get(BENTO_USER_ROLE_HEADER, "") in method_roles)
)


def quart_permissions(method_roles: Union[set, dict]) -> Callable:
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
if not _check_roles(request.headers, method_roles):
return quart_forbidden_error()
return await func(*args, **kwargs)
return wrapper
return decorator


quart_permissions_any_user = quart_permissions({ROLE_USER, ROLE_OWNER})
quart_permissions_owner = quart_permissions({ROLE_OWNER})
2 changes: 1 addition & 1 deletion bento_lib/package.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[package]
name = bento_lib
version = 4.0.0
version = 4.1.0
authors = David Lougheed, Paul Pillot
author_emails = [email protected], [email protected]
5 changes: 3 additions & 2 deletions bento_lib/responses/flask_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ def flask_error_wrap_with_traceback(fn: Callable, *args, **kwargs) -> Callable:
service_name = kwargs.pop("service_name", "Bento Service")

# TODO: pass exception?
def handle_error(_e):
def handle_error(e):
print(f"[{service_name}] Encountered error:", file=sys.stderr)
traceback.print_exc()
# TODO: py3.10: print_exception(e)
traceback.print_exception(type(e), e, e.__traceback__)
return fn(*args, **kwargs)
return handle_error

Expand Down
71 changes: 71 additions & 0 deletions bento_lib/responses/quart_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import sys
import traceback

from quart import jsonify
from functools import partial
from typing import Callable

from bento_lib.responses import errors


__all__ = [
"quart_error_wrap_with_traceback",
"quart_error_wrap",

"quart_error",

"quart_bad_request_error",
"quart_unauthorized_error",
"quart_forbidden_error",
"quart_not_found_error",

"quart_internal_server_error",
"quart_not_implemented_error",
]


# noinspection PyIncorrectDocstring
def quart_error_wrap_with_traceback(fn: Callable, *args, **kwargs) -> Callable:
"""
Function to wrap quart_* error creators with something that supports the application.register_error_handler method,
while also printing a traceback. Optionally, the keyword argument service_name can be passed in to make the error
logging more precise.
:param fn: The quart error-generating function to wrap
:return: The wrapped function
"""

service_name = kwargs.pop("service_name", "Bento Service")

# TODO: pass exception?
def handle_error(e):
print(f"[{service_name}] Encountered error:", file=sys.stderr)
# TODO: py3.10: print_exception(e)
traceback.print_exception(type(e), e, e.__traceback__)
return fn(*args, **kwargs)
return handle_error


def quart_error_wrap(fn: Callable, *args, **kwargs) -> Callable:
"""
Function to wrap quart_* error creators with something that supports the application.register_error_handler method.
:param fn: The quart error-generating function to wrap
:return: The wrapped function
"""
return lambda _e: fn(*args, **kwargs)


def quart_error(code: int, *errs, drs_compat: bool = False, sr_compat: bool = False):
return jsonify(errors.http_error(code, *errs, drs_compat=drs_compat, sr_compat=sr_compat)), code


def _quart_error(code: int) -> Callable:
return partial(quart_error, code)


quart_bad_request_error = _quart_error(400)
quart_unauthorized_error = _quart_error(401)
quart_forbidden_error = _quart_error(403)
quart_not_found_error = _quart_error(404)

quart_internal_server_error = _quart_error(500)
quart_not_implemented_error = _quart_error(501)
2 changes: 1 addition & 1 deletion bento_lib/search/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def _resolve(args: q.Args, params: tuple, schema: JSONSchema, _internal: bool =
sql.Identifier(f_id) if f_id is not None else sql.SQL("*")), params


def _list(args: q.Args, params: tuple, schema: JSONSchema, _internal: bool = False) -> SQLComposableWithParams:
def _list(args: q.Args, params: tuple, _schema: JSONSchema, _internal: bool = False) -> SQLComposableWithParams:
# :param args: a tuple of query.Literal objects to be used in an IN clause.
# with psycopg2, it must be passed as a tuple of tuples, hence the enclosing
# parentheses in the following statement.
Expand Down
49 changes: 31 additions & 18 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,52 +1,65 @@
aiofiles==22.1.0
appdirs==1.4.4
asgiref==3.5.2
attrs==21.2.0
attrs==22.1.0
backports.entry-points-selectable==1.1.0
blinker==1.5
certifi==2022.9.24
chardet==4.0.0
charset-normalizer==2.0.4
click==8.0.1
codecov==2.1.12
coverage==5.5
distlib==0.3.2
Django==4.1.1
djangorestframework==3.12.4
distlib==0.3.6
Django==4.1.3
djangorestframework==3.13.1
entrypoints==0.3
filelock==3.0.12
filelock==3.8.0
flake8==3.9.2
Flask==2.0.1
Flask==2.2.2
h11==0.14.0
h2==4.1.0
hpack==4.0.0
hypercorn==0.14.3
hyperframe==6.0.1
idna==2.10
importlib-metadata==4.6.1
importlib-metadata==5.0.0
iniconfig==1.1.1
itsdangerous==2.0.1
Jinja2==3.0.1
jsonschema==3.2.0
MarkupSafe==2.0.1
jsonschema==4.16.0
MarkupSafe==2.1.1
mccabe==0.6.1
more-itertools==8.9.0
packaging==21.0
platformdirs==2.3.0
platformdirs==2.5.2
pluggy==0.13.1
psycopg2==2.9.3
priority==2.0.0
psycopg2==2.9.5
py==1.10.0
pycodestyle==2.7.0
pyflakes==2.3.1
pyparsing==2.4.7
pyrsistent==0.18.0
pytest==6.2.5
pytest-asyncio==0.20.1
pytest-cov==2.12.1
pytest-django==4.4.0
python-dateutil==2.8.2
pytz==2021.1
quart==0.18.3
redis==3.5.3
requests==2.26.0
responses==0.13.4
requests==2.28.1
responses==0.22.0
six==1.16.0
sqlparse==0.4.2
toml==0.10.2
tox==3.24.3
urllib3==1.26.6
virtualenv==20.7.2
tomli==2.0.1
tox==3.27.0
types-toml==0.10.8
urllib3==1.26.12
virtualenv==20.16.6
wcwidth==0.2.5
Werkzeug==2.0.1
zipp==3.5.0
Werkzeug==2.2.2
wsproto==1.2.0
zipp==3.10.0
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

python_requires=">=3.8",
install_requires=[
"jsonschema>=3.2.0,<4",
"jsonschema>=3.2.0,<5",
"psycopg2-binary>=2.8.6,<3.0",
"redis>=3.5.3,<4.0",
"Werkzeug>=2.0.1,<3",
],
extras_require={
"flask": ["Flask>=2.0.1,<3"],
"django": ["Django>=4.1.1,<5", "djangorestframework>=3.13.1,<3.15"]
"django": ["Django>=4.1.1,<5", "djangorestframework>=3.13.1,<3.15"],
"quart": ["quart>=0.18.3,<0.19"],
},

author=config["package"]["authors"],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_platform_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test3():
yield client


def test_flask_forbidden_error(flask_client):
def test_flask_errors(flask_client):
# Turn CHORD permissions mode on to make sure we're getting real permissions checks
fd.BENTO_PERMISSIONS = True

Expand Down
107 changes: 107 additions & 0 deletions tests/test_platform_quart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import asyncio
import bento_lib.auth.quart_decorators as qd
import bento_lib.responses.quart_errors as qe
import pytest
import pytest_asyncio

from quart import Quart
from werkzeug.exceptions import BadRequest, NotFound


@pytest_asyncio.fixture
async def quart_client():
application = Quart(__name__)

application.register_error_handler(Exception, qe.quart_error_wrap_with_traceback(qe.quart_internal_server_error))
application.register_error_handler(BadRequest, qe.quart_error_wrap(qe.quart_bad_request_error))
application.register_error_handler(NotFound, qe.quart_error_wrap(qe.quart_not_found_error, drs_compat=True))

@application.route("/500")
async def r500():
await asyncio.sleep(0.5)
raise Exception("help")

@application.route("/test1")
@qd.quart_permissions_any_user
async def test1():
await asyncio.sleep(0.5)
return "test1"

@application.route("/test2")
@qd.quart_permissions_owner
async def test2():
await asyncio.sleep(0.5)
return "test2"

@application.route("/test3", methods=["GET", "POST"])
@qd.quart_permissions({"POST": {"owner"}})
async def test3():
await asyncio.sleep(0.5)
return "test3"

yield application.test_client()


@pytest.mark.asyncio
async def test_quart_errors(quart_client):
# Turn CHORD permissions mode on to make sure we're getting real permissions checks
qd.BENTO_PERMISSIONS = True

# non-existent endpoint

r = await quart_client.get("/non-existent")
assert r.status_code == 404
rj = await r.get_json()
assert rj["code"] == 404

# - We passed drs_compat=True to this, so check for DRS-specific fields
assert rj["status_code"] == rj["code"]
assert rj["msg"] == rj["message"]

# server error endpoint

r = await quart_client.get("/500")
assert r.status_code == 500
assert (await r.get_json())["code"] == 500

# /test1

r = await quart_client.get("/test1")
assert r.status_code == 403
assert (await r.get_json())["code"] == 403

r = await quart_client.get("/test1", headers={"X-User": "test", "X-User-Role": "user"})
assert r.status_code == 200
assert (await r.get_data()).decode("utf-8") == "test1"

r = await quart_client.get("/test1", headers={"X-User": "test", "X-User-Role": "owner"})
assert r.status_code == 200
assert (await r.get_data()).decode("utf-8") == "test1"

# /test2

r = await quart_client.get("/test2")
assert r.status_code == 403
assert (await r.get_json())["code"] == 403

r = await quart_client.get("/test2", headers={"X-User": "test", "X-User-Role": "user"})
assert r.status_code == 403
assert (await r.get_json())["code"] == 403

r = await quart_client.get("/test2", headers={"X-User": "test", "X-User-Role": "owner"})
assert r.status_code == 200
assert (await r.get_data()).decode("utf-8") == "test2"

# /test3

r = await quart_client.get("/test3")
assert r.status_code == 200
assert (await r.get_data()).decode("utf-8") == "test3"

r = await quart_client.post("/test3")
assert r.status_code == 403
assert (await r.get_json())["code"] == 403

r = await quart_client.get("/test3", headers={"X-User": "test", "X-User-Role": "owner"})
assert r.status_code == 200
assert (await r.get_data()).decode("utf-8") == "test3"

0 comments on commit 495521a

Please sign in to comment.