diff --git a/casbin_databases_adapter/adapter.py b/casbin_databases_adapter/adapter.py index 2a7d785..98cde2b 100644 --- a/casbin_databases_adapter/adapter.py +++ b/casbin_databases_adapter/adapter.py @@ -4,8 +4,6 @@ from databases import Database from sqlalchemy import Table -from casbin_databases_adapter.utils import to_sync - class Filter: ptype: List[str] = [] @@ -26,7 +24,6 @@ def __init__(self, db: Database, table: Table, filtered=False): self.table: Table = table self.filtered: bool = filtered - @to_sync() async def load_policy(self, model: Model): query = self.table.select() rows = await self.db.fetch_all(query) @@ -35,7 +32,6 @@ async def load_policy(self, model: Model): line = [v for k, v in row.items() if k in self.cols and v is not None] persist.load_policy_line(", ".join(line), model) - @to_sync() async def save_policy(self, model: Model): await self.db.execute(self.table.delete()) query = self.table.insert() @@ -54,12 +50,10 @@ async def save_policy(self, model: Model): await self.db.execute_many(query, values) return True - @to_sync() async def add_policy(self, sec, p_type, rule): row = self._policy_to_dict(p_type, rule) await self.db.execute(self.table.insert(), row) - @to_sync() async def remove_policy(self, sec, p_type, rule): query = self.table.delete().where(self.table.columns.ptype == p_type) for i, value in enumerate(rule): @@ -69,7 +63,6 @@ async def remove_policy(self, sec, p_type, rule): return True if result > 0 else False - @to_sync() async def remove_filtered_policy(self, sec, ptype, field_index, *field_values): query = self.table.delete().where(self.table.columns.ptype == ptype) if not (0 <= field_index <= 5): @@ -82,7 +75,6 @@ async def remove_filtered_policy(self, sec, ptype, field_index, *field_values): result = await self.db.execute(query) return True if result else False - @to_sync() async def load_filtered_policy(self, model: Model, filter_: Filter) -> None: query = self.table.select().order_by(self.table.columns.id) for att, value in filter_.__dict__.items(): diff --git a/casbin_databases_adapter/utils.py b/casbin_databases_adapter/utils.py deleted file mode 100644 index 0aa5c1d..0000000 --- a/casbin_databases_adapter/utils.py +++ /dev/null @@ -1,58 +0,0 @@ -import asyncio -import functools -import threading -from asyncio import Task -from typing import Callable, Coroutine - - -class RunThread(threading.Thread): - def __init__(self, func, args, kwargs): - self.func = func - self.args = args - self.kwargs = kwargs - super().__init__() - - def run(self): - self.result = asyncio.run(self.func(*self.args, **self.kwargs)) - - -def to_sync(as_task: bool = True): - """ - A better implementation of `asyncio.run`. - - :param as_task: Forces the future to be scheduled as task (needed for e.g. aiohttp). - - Link: https://stackoverflow.com/a/63593888 - """ - - def _run_async(func: Callable[..., Coroutine]): - """ - :param func: A function that return future or task or call of an async method. - :return: wrapped function - """ - - @functools.wraps(func) - def func_wrapper(*args, **kwargs): - try: - loop = asyncio.get_running_loop() - except RuntimeError: # no event loop running: - loop = asyncio.new_event_loop() - return loop.run_until_complete( - _to_task(func(*args, **kwargs), as_task, loop) - ) - else: - # handle nested event loop with thread - thread = RunThread(func, args, kwargs) - thread.start() - thread.join() - return thread.result - - return func_wrapper - - return _run_async - - -def _to_task(future, as_task, loop): - if not as_task or isinstance(future, Task): - return future - return loop.create_task(future) diff --git a/requirements.txt b/requirements.txt index 5a930ba..8d877fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -casbin>=0.8.1 SQLAlchemy>=1.2.18 -databases>=0.2.6 \ No newline at end of file +databases>=0.2.6 +asynccasbin>=1.1.7 diff --git a/tests/conftest.py b/tests/conftest.py index 741ef44..ac21f92 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,4 +63,6 @@ async def enforcer( db: Database, setup_policies, casbin_rule_table: Table, model_conf_path ) -> Enforcer: adapter = DatabasesAdapter(db, table=casbin_rule_table) - return Enforcer(model_conf_path, adapter) + enforcer = Enforcer(model_conf_path, adapter) + await enforcer.load_policy() + return enforcer diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 6aad4b8..7e8efd3 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -4,7 +4,7 @@ from casbin_databases_adapter.adapter import Filter -def test_load_policy(db: Database, enforcer: Enforcer): +async def test_load_policy(db: Database, enforcer: Enforcer): assert enforcer.enforce("alice", "data1", "read") == True assert enforcer.enforce("bob", "data2", "write") == True @@ -16,14 +16,14 @@ def test_load_policy(db: Database, enforcer: Enforcer): assert enforcer.enforce("bob", "data2", "read") == False -def test_add_policy(db: Database, enforcer: Enforcer): +async def test_add_policy(db: Database, enforcer: Enforcer): assert not enforcer.enforce("eve", "data3", "read") - result = enforcer.add_permission_for_user("eve", "data3", "read") + result = await enforcer.add_permission_for_user("eve", "data3", "read") assert result assert enforcer.enforce("eve", "data3", "read") -def test_save_policy(db: Database, enforcer: Enforcer): +async def test_save_policy(db: Database, enforcer: Enforcer): assert not enforcer.enforce("alice", "data4", "read") model: Model = enforcer.get_model() @@ -31,45 +31,45 @@ def test_save_policy(db: Database, enforcer: Enforcer): model.add_policy("p", "p", ["alice", "data4", "read"]) adapter: Adapter = enforcer.get_adapter() - adapter.save_policy(model) + await adapter.save_policy(model) assert enforcer.enforce("alice", "data4", "read") -def test_remove_policy(db: Database, enforcer: Enforcer): +async def test_remove_policy(db: Database, enforcer: Enforcer): assert not (enforcer.enforce("alice", "data5", "read")) - enforcer.add_permission_for_user("alice", "data5", "read") + await enforcer.add_permission_for_user("alice", "data5", "read") assert enforcer.enforce("alice", "data5", "read") - enforcer.delete_permission_for_user("alice", "data5", "read") + await enforcer.delete_permission_for_user("alice", "data5", "read") assert not (enforcer.enforce("alice", "data5", "read")) -def test_remove_filtered_policy(db: Database, enforcer: Enforcer): +async def test_remove_filtered_policy(db: Database, enforcer: Enforcer): assert enforcer.enforce("alice", "data1", "read") - enforcer.remove_filtered_policy(1, "data1") + await enforcer.remove_filtered_policy(1, "data1") assert not (enforcer.enforce("alice", "data1", "read")) assert enforcer.enforce("bob", "data2", "write") assert enforcer.enforce("alice", "data2", "read") assert enforcer.enforce("alice", "data2", "write") - enforcer.remove_filtered_policy(1, "data2", "read") + await enforcer.remove_filtered_policy(1, "data2", "read") assert enforcer.enforce("bob", "data2", "write") assert not (enforcer.enforce("alice", "data2", "read")) assert enforcer.enforce("alice", "data2", "write") - enforcer.remove_filtered_policy(2, "write") + await enforcer.remove_filtered_policy(2, "write") assert not (enforcer.enforce("bob", "data2", "write")) assert not (enforcer.enforce("alice", "data2", "write")) -def test_filtered_policy(db: Database, enforcer: Enforcer): +async def test_filtered_policy(db: Database, enforcer: Enforcer): filter = Filter() filter.ptype = ["p"] - enforcer.load_filtered_policy(filter) + await enforcer.load_filtered_policy(filter) assert enforcer.enforce("alice", "data1", "read") assert not (enforcer.enforce("alice", "data1", "write")) assert not (enforcer.enforce("alice", "data2", "read")) @@ -81,7 +81,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer): filter.ptype = [] filter.v0 = ["alice"] - enforcer.load_filtered_policy(filter) + await enforcer.load_filtered_policy(filter) assert enforcer.enforce("alice", "data1", "read") assert not (enforcer.enforce("alice", "data1", "write")) assert not (enforcer.enforce("alice", "data2", "read")) @@ -94,7 +94,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer): assert not (enforcer.enforce("data2_admin", "data2", "write")) filter.v0 = ["bob"] - enforcer.load_filtered_policy(filter) + await enforcer.load_filtered_policy(filter) assert not (enforcer.enforce("alice", "data1", "read")) assert not (enforcer.enforce("alice", "data1", "write")) assert not (enforcer.enforce("alice", "data2", "read")) @@ -107,7 +107,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer): assert not (enforcer.enforce("data2_admin", "data2", "write")) filter.v0 = ["data2_admin"] - enforcer.load_filtered_policy(filter) + await enforcer.load_filtered_policy(filter) assert enforcer.enforce("data2_admin", "data2", "read") assert enforcer.enforce("data2_admin", "data2", "read") assert not (enforcer.enforce("alice", "data1", "read")) @@ -120,7 +120,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer): assert not (enforcer.enforce("bob", "data2", "write")) filter.v0 = ["alice", "bob"] - enforcer.load_filtered_policy(filter) + await enforcer.load_filtered_policy(filter) assert enforcer.enforce("alice", "data1", "read") assert not (enforcer.enforce("alice", "data1", "write")) assert not (enforcer.enforce("alice", "data2", "read")) @@ -134,7 +134,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer): filter.v0 = [] filter.v1 = ["data1"] - enforcer.load_filtered_policy(filter) + await enforcer.load_filtered_policy(filter) assert enforcer.enforce("alice", "data1", "read") assert not (enforcer.enforce("alice", "data1", "write")) assert not (enforcer.enforce("alice", "data2", "read")) @@ -147,7 +147,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer): assert not (enforcer.enforce("data2_admin", "data2", "write")) filter.v1 = ["data2"] - enforcer.load_filtered_policy(filter) + await enforcer.load_filtered_policy(filter) assert not (enforcer.enforce("alice", "data1", "read")) assert not (enforcer.enforce("alice", "data1", "write")) assert not (enforcer.enforce("alice", "data2", "read")) @@ -161,7 +161,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer): filter.v1 = [] filter.v2 = ["read"] - enforcer.load_filtered_policy(filter) + await enforcer.load_filtered_policy(filter) assert enforcer.enforce("alice", "data1", "read") assert not (enforcer.enforce("alice", "data1", "write")) assert not (enforcer.enforce("alice", "data2", "read")) @@ -174,7 +174,7 @@ def test_filtered_policy(db: Database, enforcer: Enforcer): assert not (enforcer.enforce("data2_admin", "data2", "write")) filter.v2 = ["write"] - enforcer.load_filtered_policy(filter) + await enforcer.load_filtered_policy(filter) assert not (enforcer.enforce("alice", "data1", "read")) assert not (enforcer.enforce("alice", "data1", "write")) assert not (enforcer.enforce("alice", "data2", "read"))