diff --git a/.idea/bento_lib.iml b/.idea/bento_lib.iml index b39e681..ef906d8 100644 --- a/.idea/bento_lib.iml +++ b/.idea/bento_lib.iml @@ -4,7 +4,7 @@ - + diff --git a/bento_lib/auth/quart_decorators.py b/bento_lib/auth/quart_decorators.py new file mode 100644 index 0000000..a3d08c2 --- /dev/null +++ b/bento_lib/auth/quart_decorators.py @@ -0,0 +1,45 @@ +import os + +from functools import wraps +from quart import request +from typing import Callable, Union + +from bento_lib.auth.headers import BENTO_USER_HEADER, BENTO_USER_ROLE_HEADER +from bento_lib.auth.roles import ROLE_OWNER, ROLE_USER +from bento_lib.responses.quart_errors import quart_forbidden_error + + +__all__ = [ + "quart_permissions", + "quart_permissions_any_user", + "quart_permissions_owner", +] + + +# TODO: Centralize this +BENTO_DEBUG = os.environ.get("CHORD_DEBUG", "true").lower() == "true" +BENTO_PERMISSIONS = os.environ.get("CHORD_PERMISSIONS", str(not BENTO_DEBUG)).lower() == "true" + + +def _check_roles(headers, roles: Union[set, dict]) -> bool: + method_roles = roles if not isinstance(roles, dict) else roles.get(request.method, set()) + return ( + not BENTO_PERMISSIONS or + len(method_roles) == 0 or + (BENTO_USER_HEADER in headers and headers.get(BENTO_USER_ROLE_HEADER, "") in method_roles) + ) + + +def quart_permissions(method_roles: Union[set, dict]) -> Callable: + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + if not _check_roles(request.headers, method_roles): + return quart_forbidden_error() + return await func(*args, **kwargs) + return wrapper + return decorator + + +quart_permissions_any_user = quart_permissions({ROLE_USER, ROLE_OWNER}) +quart_permissions_owner = quart_permissions({ROLE_OWNER}) diff --git a/bento_lib/package.cfg b/bento_lib/package.cfg index ce86a5a..be28f5b 100644 --- a/bento_lib/package.cfg +++ b/bento_lib/package.cfg @@ -1,5 +1,5 @@ [package] name = bento_lib -version = 4.0.0 +version = 4.1.0 authors = David Lougheed, Paul Pillot author_emails = david.lougheed@mail.mcgill.ca, paul.pillot@computationalgenomics.ca diff --git a/bento_lib/responses/flask_errors.py b/bento_lib/responses/flask_errors.py index e2d1230..b6796dd 100644 --- a/bento_lib/responses/flask_errors.py +++ b/bento_lib/responses/flask_errors.py @@ -37,9 +37,10 @@ def flask_error_wrap_with_traceback(fn: Callable, *args, **kwargs) -> Callable: service_name = kwargs.pop("service_name", "Bento Service") # TODO: pass exception? - def handle_error(_e): + def handle_error(e): print(f"[{service_name}] Encountered error:", file=sys.stderr) - traceback.print_exc() + # TODO: py3.10: print_exception(e) + traceback.print_exception(type(e), e, e.__traceback__) return fn(*args, **kwargs) return handle_error diff --git a/bento_lib/responses/quart_errors.py b/bento_lib/responses/quart_errors.py new file mode 100644 index 0000000..7a046b9 --- /dev/null +++ b/bento_lib/responses/quart_errors.py @@ -0,0 +1,71 @@ +import sys +import traceback + +from quart import jsonify +from functools import partial +from typing import Callable + +from bento_lib.responses import errors + + +__all__ = [ + "quart_error_wrap_with_traceback", + "quart_error_wrap", + + "quart_error", + + "quart_bad_request_error", + "quart_unauthorized_error", + "quart_forbidden_error", + "quart_not_found_error", + + "quart_internal_server_error", + "quart_not_implemented_error", +] + + +# noinspection PyIncorrectDocstring +def quart_error_wrap_with_traceback(fn: Callable, *args, **kwargs) -> Callable: + """ + Function to wrap quart_* error creators with something that supports the application.register_error_handler method, + while also printing a traceback. Optionally, the keyword argument service_name can be passed in to make the error + logging more precise. + :param fn: The quart error-generating function to wrap + :return: The wrapped function + """ + + service_name = kwargs.pop("service_name", "Bento Service") + + # TODO: pass exception? + def handle_error(e): + print(f"[{service_name}] Encountered error:", file=sys.stderr) + # TODO: py3.10: print_exception(e) + traceback.print_exception(type(e), e, e.__traceback__) + return fn(*args, **kwargs) + return handle_error + + +def quart_error_wrap(fn: Callable, *args, **kwargs) -> Callable: + """ + Function to wrap quart_* error creators with something that supports the application.register_error_handler method. + :param fn: The quart error-generating function to wrap + :return: The wrapped function + """ + return lambda _e: fn(*args, **kwargs) + + +def quart_error(code: int, *errs, drs_compat: bool = False, sr_compat: bool = False): + return jsonify(errors.http_error(code, *errs, drs_compat=drs_compat, sr_compat=sr_compat)), code + + +def _quart_error(code: int) -> Callable: + return partial(quart_error, code) + + +quart_bad_request_error = _quart_error(400) +quart_unauthorized_error = _quart_error(401) +quart_forbidden_error = _quart_error(403) +quart_not_found_error = _quart_error(404) + +quart_internal_server_error = _quart_error(500) +quart_not_implemented_error = _quart_error(501) diff --git a/bento_lib/search/postgres.py b/bento_lib/search/postgres.py index 6454db5..03499de 100644 --- a/bento_lib/search/postgres.py +++ b/bento_lib/search/postgres.py @@ -348,7 +348,7 @@ def _resolve(args: q.Args, params: tuple, schema: JSONSchema, _internal: bool = sql.Identifier(f_id) if f_id is not None else sql.SQL("*")), params -def _list(args: q.Args, params: tuple, schema: JSONSchema, _internal: bool = False) -> SQLComposableWithParams: +def _list(args: q.Args, params: tuple, _schema: JSONSchema, _internal: bool = False) -> SQLComposableWithParams: # :param args: a tuple of query.Literal objects to be used in an IN clause. # with psycopg2, it must be passed as a tuple of tuples, hence the enclosing # parentheses in the following statement. diff --git a/requirements.txt b/requirements.txt index d3ed9a1..66e0645 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,52 +1,65 @@ +aiofiles==22.1.0 appdirs==1.4.4 asgiref==3.5.2 -attrs==21.2.0 +attrs==22.1.0 backports.entry-points-selectable==1.1.0 +blinker==1.5 certifi==2022.9.24 chardet==4.0.0 charset-normalizer==2.0.4 click==8.0.1 codecov==2.1.12 coverage==5.5 -distlib==0.3.2 -Django==4.1.1 -djangorestframework==3.12.4 +distlib==0.3.6 +Django==4.1.3 +djangorestframework==3.13.1 entrypoints==0.3 -filelock==3.0.12 +filelock==3.8.0 flake8==3.9.2 -Flask==2.0.1 +Flask==2.2.2 +h11==0.14.0 +h2==4.1.0 +hpack==4.0.0 +hypercorn==0.14.3 +hyperframe==6.0.1 idna==2.10 -importlib-metadata==4.6.1 +importlib-metadata==5.0.0 iniconfig==1.1.1 itsdangerous==2.0.1 Jinja2==3.0.1 -jsonschema==3.2.0 -MarkupSafe==2.0.1 +jsonschema==4.16.0 +MarkupSafe==2.1.1 mccabe==0.6.1 more-itertools==8.9.0 packaging==21.0 -platformdirs==2.3.0 +platformdirs==2.5.2 pluggy==0.13.1 -psycopg2==2.9.3 +priority==2.0.0 +psycopg2==2.9.5 py==1.10.0 pycodestyle==2.7.0 pyflakes==2.3.1 pyparsing==2.4.7 pyrsistent==0.18.0 pytest==6.2.5 +pytest-asyncio==0.20.1 pytest-cov==2.12.1 pytest-django==4.4.0 python-dateutil==2.8.2 pytz==2021.1 +quart==0.18.3 redis==3.5.3 -requests==2.26.0 -responses==0.13.4 +requests==2.28.1 +responses==0.22.0 six==1.16.0 sqlparse==0.4.2 toml==0.10.2 -tox==3.24.3 -urllib3==1.26.6 -virtualenv==20.7.2 +tomli==2.0.1 +tox==3.27.0 +types-toml==0.10.8 +urllib3==1.26.12 +virtualenv==20.16.6 wcwidth==0.2.5 -Werkzeug==2.0.1 -zipp==3.5.0 +Werkzeug==2.2.2 +wsproto==1.2.0 +zipp==3.10.0 diff --git a/setup.py b/setup.py index ea396ae..77532bc 100644 --- a/setup.py +++ b/setup.py @@ -16,14 +16,15 @@ python_requires=">=3.8", install_requires=[ - "jsonschema>=3.2.0,<4", + "jsonschema>=3.2.0,<5", "psycopg2-binary>=2.8.6,<3.0", "redis>=3.5.3,<4.0", "Werkzeug>=2.0.1,<3", ], extras_require={ "flask": ["Flask>=2.0.1,<3"], - "django": ["Django>=4.1.1,<5", "djangorestframework>=3.13.1,<3.15"] + "django": ["Django>=4.1.1,<5", "djangorestframework>=3.13.1,<3.15"], + "quart": ["quart>=0.18.3,<0.19"], }, author=config["package"]["authors"], diff --git a/tests/test_platform_flask.py b/tests/test_platform_flask.py index 484b149..6e6fae1 100644 --- a/tests/test_platform_flask.py +++ b/tests/test_platform_flask.py @@ -37,7 +37,7 @@ def test3(): yield client -def test_flask_forbidden_error(flask_client): +def test_flask_errors(flask_client): # Turn CHORD permissions mode on to make sure we're getting real permissions checks fd.BENTO_PERMISSIONS = True diff --git a/tests/test_platform_quart.py b/tests/test_platform_quart.py new file mode 100644 index 0000000..f59e65d --- /dev/null +++ b/tests/test_platform_quart.py @@ -0,0 +1,107 @@ +import asyncio +import bento_lib.auth.quart_decorators as qd +import bento_lib.responses.quart_errors as qe +import pytest +import pytest_asyncio + +from quart import Quart +from werkzeug.exceptions import BadRequest, NotFound + + +@pytest_asyncio.fixture +async def quart_client(): + application = Quart(__name__) + + application.register_error_handler(Exception, qe.quart_error_wrap_with_traceback(qe.quart_internal_server_error)) + application.register_error_handler(BadRequest, qe.quart_error_wrap(qe.quart_bad_request_error)) + application.register_error_handler(NotFound, qe.quart_error_wrap(qe.quart_not_found_error, drs_compat=True)) + + @application.route("/500") + async def r500(): + await asyncio.sleep(0.5) + raise Exception("help") + + @application.route("/test1") + @qd.quart_permissions_any_user + async def test1(): + await asyncio.sleep(0.5) + return "test1" + + @application.route("/test2") + @qd.quart_permissions_owner + async def test2(): + await asyncio.sleep(0.5) + return "test2" + + @application.route("/test3", methods=["GET", "POST"]) + @qd.quart_permissions({"POST": {"owner"}}) + async def test3(): + await asyncio.sleep(0.5) + return "test3" + + yield application.test_client() + + +@pytest.mark.asyncio +async def test_quart_errors(quart_client): + # Turn CHORD permissions mode on to make sure we're getting real permissions checks + qd.BENTO_PERMISSIONS = True + + # non-existent endpoint + + r = await quart_client.get("/non-existent") + assert r.status_code == 404 + rj = await r.get_json() + assert rj["code"] == 404 + + # - We passed drs_compat=True to this, so check for DRS-specific fields + assert rj["status_code"] == rj["code"] + assert rj["msg"] == rj["message"] + + # server error endpoint + + r = await quart_client.get("/500") + assert r.status_code == 500 + assert (await r.get_json())["code"] == 500 + + # /test1 + + r = await quart_client.get("/test1") + assert r.status_code == 403 + assert (await r.get_json())["code"] == 403 + + r = await quart_client.get("/test1", headers={"X-User": "test", "X-User-Role": "user"}) + assert r.status_code == 200 + assert (await r.get_data()).decode("utf-8") == "test1" + + r = await quart_client.get("/test1", headers={"X-User": "test", "X-User-Role": "owner"}) + assert r.status_code == 200 + assert (await r.get_data()).decode("utf-8") == "test1" + + # /test2 + + r = await quart_client.get("/test2") + assert r.status_code == 403 + assert (await r.get_json())["code"] == 403 + + r = await quart_client.get("/test2", headers={"X-User": "test", "X-User-Role": "user"}) + assert r.status_code == 403 + assert (await r.get_json())["code"] == 403 + + r = await quart_client.get("/test2", headers={"X-User": "test", "X-User-Role": "owner"}) + assert r.status_code == 200 + assert (await r.get_data()).decode("utf-8") == "test2" + + # /test3 + + r = await quart_client.get("/test3") + assert r.status_code == 200 + assert (await r.get_data()).decode("utf-8") == "test3" + + r = await quart_client.post("/test3") + assert r.status_code == 403 + assert (await r.get_json())["code"] == 403 + + r = await quart_client.get("/test3", headers={"X-User": "test", "X-User-Role": "owner"}) + assert r.status_code == 200 + assert (await r.get_data()).decode("utf-8") == "test3"