diff --git a/.idea/bento_lib.iml b/.idea/bento_lib.iml
index b39e681..ef906d8 100644
--- a/.idea/bento_lib.iml
+++ b/.idea/bento_lib.iml
@@ -4,7 +4,7 @@
-
+
diff --git a/bento_lib/auth/quart_decorators.py b/bento_lib/auth/quart_decorators.py
new file mode 100644
index 0000000..a3d08c2
--- /dev/null
+++ b/bento_lib/auth/quart_decorators.py
@@ -0,0 +1,45 @@
+import os
+
+from functools import wraps
+from quart import request
+from typing import Callable, Union
+
+from bento_lib.auth.headers import BENTO_USER_HEADER, BENTO_USER_ROLE_HEADER
+from bento_lib.auth.roles import ROLE_OWNER, ROLE_USER
+from bento_lib.responses.quart_errors import quart_forbidden_error
+
+
+__all__ = [
+ "quart_permissions",
+ "quart_permissions_any_user",
+ "quart_permissions_owner",
+]
+
+
+# TODO: Centralize this
+BENTO_DEBUG = os.environ.get("CHORD_DEBUG", "true").lower() == "true"
+BENTO_PERMISSIONS = os.environ.get("CHORD_PERMISSIONS", str(not BENTO_DEBUG)).lower() == "true"
+
+
+def _check_roles(headers, roles: Union[set, dict]) -> bool:
+ method_roles = roles if not isinstance(roles, dict) else roles.get(request.method, set())
+ return (
+ not BENTO_PERMISSIONS or
+ len(method_roles) == 0 or
+ (BENTO_USER_HEADER in headers and headers.get(BENTO_USER_ROLE_HEADER, "") in method_roles)
+ )
+
+
+def quart_permissions(method_roles: Union[set, dict]) -> Callable:
+ def decorator(func):
+ @wraps(func)
+ async def wrapper(*args, **kwargs):
+ if not _check_roles(request.headers, method_roles):
+ return quart_forbidden_error()
+ return await func(*args, **kwargs)
+ return wrapper
+ return decorator
+
+
+quart_permissions_any_user = quart_permissions({ROLE_USER, ROLE_OWNER})
+quart_permissions_owner = quart_permissions({ROLE_OWNER})
diff --git a/bento_lib/package.cfg b/bento_lib/package.cfg
index ce86a5a..be28f5b 100644
--- a/bento_lib/package.cfg
+++ b/bento_lib/package.cfg
@@ -1,5 +1,5 @@
[package]
name = bento_lib
-version = 4.0.0
+version = 4.1.0
authors = David Lougheed, Paul Pillot
author_emails = david.lougheed@mail.mcgill.ca, paul.pillot@computationalgenomics.ca
diff --git a/bento_lib/responses/flask_errors.py b/bento_lib/responses/flask_errors.py
index e2d1230..b6796dd 100644
--- a/bento_lib/responses/flask_errors.py
+++ b/bento_lib/responses/flask_errors.py
@@ -37,9 +37,10 @@ def flask_error_wrap_with_traceback(fn: Callable, *args, **kwargs) -> Callable:
service_name = kwargs.pop("service_name", "Bento Service")
# TODO: pass exception?
- def handle_error(_e):
+ def handle_error(e):
print(f"[{service_name}] Encountered error:", file=sys.stderr)
- traceback.print_exc()
+ # TODO: py3.10: print_exception(e)
+ traceback.print_exception(type(e), e, e.__traceback__)
return fn(*args, **kwargs)
return handle_error
diff --git a/bento_lib/responses/quart_errors.py b/bento_lib/responses/quart_errors.py
new file mode 100644
index 0000000..7a046b9
--- /dev/null
+++ b/bento_lib/responses/quart_errors.py
@@ -0,0 +1,71 @@
+import sys
+import traceback
+
+from quart import jsonify
+from functools import partial
+from typing import Callable
+
+from bento_lib.responses import errors
+
+
+__all__ = [
+ "quart_error_wrap_with_traceback",
+ "quart_error_wrap",
+
+ "quart_error",
+
+ "quart_bad_request_error",
+ "quart_unauthorized_error",
+ "quart_forbidden_error",
+ "quart_not_found_error",
+
+ "quart_internal_server_error",
+ "quart_not_implemented_error",
+]
+
+
+# noinspection PyIncorrectDocstring
+def quart_error_wrap_with_traceback(fn: Callable, *args, **kwargs) -> Callable:
+ """
+ Function to wrap quart_* error creators with something that supports the application.register_error_handler method,
+ while also printing a traceback. Optionally, the keyword argument service_name can be passed in to make the error
+ logging more precise.
+ :param fn: The quart error-generating function to wrap
+ :return: The wrapped function
+ """
+
+ service_name = kwargs.pop("service_name", "Bento Service")
+
+ # TODO: pass exception?
+ def handle_error(e):
+ print(f"[{service_name}] Encountered error:", file=sys.stderr)
+ # TODO: py3.10: print_exception(e)
+ traceback.print_exception(type(e), e, e.__traceback__)
+ return fn(*args, **kwargs)
+ return handle_error
+
+
+def quart_error_wrap(fn: Callable, *args, **kwargs) -> Callable:
+ """
+ Function to wrap quart_* error creators with something that supports the application.register_error_handler method.
+ :param fn: The quart error-generating function to wrap
+ :return: The wrapped function
+ """
+ return lambda _e: fn(*args, **kwargs)
+
+
+def quart_error(code: int, *errs, drs_compat: bool = False, sr_compat: bool = False):
+ return jsonify(errors.http_error(code, *errs, drs_compat=drs_compat, sr_compat=sr_compat)), code
+
+
+def _quart_error(code: int) -> Callable:
+ return partial(quart_error, code)
+
+
+quart_bad_request_error = _quart_error(400)
+quart_unauthorized_error = _quart_error(401)
+quart_forbidden_error = _quart_error(403)
+quart_not_found_error = _quart_error(404)
+
+quart_internal_server_error = _quart_error(500)
+quart_not_implemented_error = _quart_error(501)
diff --git a/bento_lib/search/postgres.py b/bento_lib/search/postgres.py
index 6454db5..03499de 100644
--- a/bento_lib/search/postgres.py
+++ b/bento_lib/search/postgres.py
@@ -348,7 +348,7 @@ def _resolve(args: q.Args, params: tuple, schema: JSONSchema, _internal: bool =
sql.Identifier(f_id) if f_id is not None else sql.SQL("*")), params
-def _list(args: q.Args, params: tuple, schema: JSONSchema, _internal: bool = False) -> SQLComposableWithParams:
+def _list(args: q.Args, params: tuple, _schema: JSONSchema, _internal: bool = False) -> SQLComposableWithParams:
# :param args: a tuple of query.Literal objects to be used in an IN clause.
# with psycopg2, it must be passed as a tuple of tuples, hence the enclosing
# parentheses in the following statement.
diff --git a/requirements.txt b/requirements.txt
index d3ed9a1..66e0645 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,52 +1,65 @@
+aiofiles==22.1.0
appdirs==1.4.4
asgiref==3.5.2
-attrs==21.2.0
+attrs==22.1.0
backports.entry-points-selectable==1.1.0
+blinker==1.5
certifi==2022.9.24
chardet==4.0.0
charset-normalizer==2.0.4
click==8.0.1
codecov==2.1.12
coverage==5.5
-distlib==0.3.2
-Django==4.1.1
-djangorestframework==3.12.4
+distlib==0.3.6
+Django==4.1.3
+djangorestframework==3.13.1
entrypoints==0.3
-filelock==3.0.12
+filelock==3.8.0
flake8==3.9.2
-Flask==2.0.1
+Flask==2.2.2
+h11==0.14.0
+h2==4.1.0
+hpack==4.0.0
+hypercorn==0.14.3
+hyperframe==6.0.1
idna==2.10
-importlib-metadata==4.6.1
+importlib-metadata==5.0.0
iniconfig==1.1.1
itsdangerous==2.0.1
Jinja2==3.0.1
-jsonschema==3.2.0
-MarkupSafe==2.0.1
+jsonschema==4.16.0
+MarkupSafe==2.1.1
mccabe==0.6.1
more-itertools==8.9.0
packaging==21.0
-platformdirs==2.3.0
+platformdirs==2.5.2
pluggy==0.13.1
-psycopg2==2.9.3
+priority==2.0.0
+psycopg2==2.9.5
py==1.10.0
pycodestyle==2.7.0
pyflakes==2.3.1
pyparsing==2.4.7
pyrsistent==0.18.0
pytest==6.2.5
+pytest-asyncio==0.20.1
pytest-cov==2.12.1
pytest-django==4.4.0
python-dateutil==2.8.2
pytz==2021.1
+quart==0.18.3
redis==3.5.3
-requests==2.26.0
-responses==0.13.4
+requests==2.28.1
+responses==0.22.0
six==1.16.0
sqlparse==0.4.2
toml==0.10.2
-tox==3.24.3
-urllib3==1.26.6
-virtualenv==20.7.2
+tomli==2.0.1
+tox==3.27.0
+types-toml==0.10.8
+urllib3==1.26.12
+virtualenv==20.16.6
wcwidth==0.2.5
-Werkzeug==2.0.1
-zipp==3.5.0
+Werkzeug==2.2.2
+wsproto==1.2.0
+zipp==3.10.0
diff --git a/setup.py b/setup.py
index ea396ae..77532bc 100644
--- a/setup.py
+++ b/setup.py
@@ -16,14 +16,15 @@
python_requires=">=3.8",
install_requires=[
- "jsonschema>=3.2.0,<4",
+ "jsonschema>=3.2.0,<5",
"psycopg2-binary>=2.8.6,<3.0",
"redis>=3.5.3,<4.0",
"Werkzeug>=2.0.1,<3",
],
extras_require={
"flask": ["Flask>=2.0.1,<3"],
- "django": ["Django>=4.1.1,<5", "djangorestframework>=3.13.1,<3.15"]
+ "django": ["Django>=4.1.1,<5", "djangorestframework>=3.13.1,<3.15"],
+ "quart": ["quart>=0.18.3,<0.19"],
},
author=config["package"]["authors"],
diff --git a/tests/test_platform_flask.py b/tests/test_platform_flask.py
index 484b149..6e6fae1 100644
--- a/tests/test_platform_flask.py
+++ b/tests/test_platform_flask.py
@@ -37,7 +37,7 @@ def test3():
yield client
-def test_flask_forbidden_error(flask_client):
+def test_flask_errors(flask_client):
# Turn CHORD permissions mode on to make sure we're getting real permissions checks
fd.BENTO_PERMISSIONS = True
diff --git a/tests/test_platform_quart.py b/tests/test_platform_quart.py
new file mode 100644
index 0000000..f59e65d
--- /dev/null
+++ b/tests/test_platform_quart.py
@@ -0,0 +1,107 @@
+import asyncio
+import bento_lib.auth.quart_decorators as qd
+import bento_lib.responses.quart_errors as qe
+import pytest
+import pytest_asyncio
+
+from quart import Quart
+from werkzeug.exceptions import BadRequest, NotFound
+
+
+@pytest_asyncio.fixture
+async def quart_client():
+ application = Quart(__name__)
+
+ application.register_error_handler(Exception, qe.quart_error_wrap_with_traceback(qe.quart_internal_server_error))
+ application.register_error_handler(BadRequest, qe.quart_error_wrap(qe.quart_bad_request_error))
+ application.register_error_handler(NotFound, qe.quart_error_wrap(qe.quart_not_found_error, drs_compat=True))
+
+ @application.route("/500")
+ async def r500():
+ await asyncio.sleep(0.5)
+ raise Exception("help")
+
+ @application.route("/test1")
+ @qd.quart_permissions_any_user
+ async def test1():
+ await asyncio.sleep(0.5)
+ return "test1"
+
+ @application.route("/test2")
+ @qd.quart_permissions_owner
+ async def test2():
+ await asyncio.sleep(0.5)
+ return "test2"
+
+ @application.route("/test3", methods=["GET", "POST"])
+ @qd.quart_permissions({"POST": {"owner"}})
+ async def test3():
+ await asyncio.sleep(0.5)
+ return "test3"
+
+ yield application.test_client()
+
+
+@pytest.mark.asyncio
+async def test_quart_errors(quart_client):
+ # Turn CHORD permissions mode on to make sure we're getting real permissions checks
+ qd.BENTO_PERMISSIONS = True
+
+ # non-existent endpoint
+
+ r = await quart_client.get("/non-existent")
+ assert r.status_code == 404
+ rj = await r.get_json()
+ assert rj["code"] == 404
+
+ # - We passed drs_compat=True to this, so check for DRS-specific fields
+ assert rj["status_code"] == rj["code"]
+ assert rj["msg"] == rj["message"]
+
+ # server error endpoint
+
+ r = await quart_client.get("/500")
+ assert r.status_code == 500
+ assert (await r.get_json())["code"] == 500
+
+ # /test1
+
+ r = await quart_client.get("/test1")
+ assert r.status_code == 403
+ assert (await r.get_json())["code"] == 403
+
+ r = await quart_client.get("/test1", headers={"X-User": "test", "X-User-Role": "user"})
+ assert r.status_code == 200
+ assert (await r.get_data()).decode("utf-8") == "test1"
+
+ r = await quart_client.get("/test1", headers={"X-User": "test", "X-User-Role": "owner"})
+ assert r.status_code == 200
+ assert (await r.get_data()).decode("utf-8") == "test1"
+
+ # /test2
+
+ r = await quart_client.get("/test2")
+ assert r.status_code == 403
+ assert (await r.get_json())["code"] == 403
+
+ r = await quart_client.get("/test2", headers={"X-User": "test", "X-User-Role": "user"})
+ assert r.status_code == 403
+ assert (await r.get_json())["code"] == 403
+
+ r = await quart_client.get("/test2", headers={"X-User": "test", "X-User-Role": "owner"})
+ assert r.status_code == 200
+ assert (await r.get_data()).decode("utf-8") == "test2"
+
+ # /test3
+
+ r = await quart_client.get("/test3")
+ assert r.status_code == 200
+ assert (await r.get_data()).decode("utf-8") == "test3"
+
+ r = await quart_client.post("/test3")
+ assert r.status_code == 403
+ assert (await r.get_json())["code"] == 403
+
+ r = await quart_client.get("/test3", headers={"X-User": "test", "X-User-Role": "owner"})
+ assert r.status_code == 200
+ assert (await r.get_data()).decode("utf-8") == "test3"