Skip to content

Commit

Permalink
Changes:
Browse files Browse the repository at this point in the history
- improve thread-safety
- add parameter poll_interval, we poll now loop friendly
- remove knob DATABASEZ_WRAP_IN_THREAD
  • Loading branch information
devkral committed Sep 4, 2024
1 parent 98c1d01 commit 49a801a
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 67 deletions.
97 changes: 64 additions & 33 deletions databasez/core/connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from __future__ import annotations

import asyncio
import sys
import typing
import weakref
from contextvars import copy_context
from threading import Event, Lock, Thread, current_thread
from types import TracebackType

Expand All @@ -28,15 +26,17 @@ async def _startup(database: Database, is_initialized: Event) -> None:
await database.connect()
_global_connection = typing.cast(Connection, database._global_connection)
await _global_connection._aenter()
if sys.version_info < (3, 10):
# for old python versions <3.10 the locks must be created in the same event loop
_global_connection._query_lock = asyncio.Lock()
_global_connection._connection_lock = asyncio.Lock()
_global_connection._transaction_lock = asyncio.Lock()
# we ensure fresh locks
_global_connection._query_lock = asyncio.Lock()
_global_connection._connection_lock = asyncio.Lock()
_global_connection._transaction_lock = asyncio.Lock()
is_initialized.set()


def _init_thread(database: Database, is_initialized: Event) -> None:
def _init_thread(database: Database, is_initialized: Event, is_cleared: Event) -> None:
is_cleared.wait()
# now set the flag so new init_threads have to wait
is_cleared.clear()
loop = asyncio.new_event_loop()
# keep reference
task = loop.create_task(_startup(database, is_initialized))
Expand All @@ -45,12 +45,17 @@ def _init_thread(database: Database, is_initialized: Event) -> None:
loop.run_forever()
except RuntimeError:
pass
finally:
# now all inits wait
is_initialized.clear()
loop.run_until_complete(database.disconnect())
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
del task
loop.close()
database._loop = None
database._global_connection._isolation_thread = None # type: ignore
is_cleared.set()


class Connection:
Expand All @@ -60,14 +65,20 @@ def __init__(
self._orig_database = self._database = database
self._full_isolation = full_isolation
self._connection_thread_lock: typing.Optional[Lock] = None
self._connection_thread_is_initialized: typing.Optional[Event] = None
self._connection_thread_is_cleared: typing.Optional[Event] = None
self._isolation_thread: typing.Optional[Thread] = None
if self._full_isolation:
self._connection_thread_lock = Lock()
self._connection_thread_is_initialized = Event()
self._connection_thread_is_cleared = Event()
# initially it is cleared
self._connection_thread_is_cleared.set()
self._database = database.__class__(
database, force_rollback=force_rollback, full_isolation=False
database, force_rollback=force_rollback, full_isolation=False, poll_interval=-1
)
self._database._call_hooks = False
self._database._global_connection = self
self._connection_thread_lock = Lock()
# the asyncio locks are overwritten in python versions < 3.10 when using full_isolation
self._query_lock = asyncio.Lock()
self._connection_lock = asyncio.Lock()
Expand Down Expand Up @@ -108,32 +119,47 @@ async def _aenter(self) -> None:
raise e

async def __aenter__(self) -> Connection:
initialized = False
initialized: bool = False
if self._full_isolation:
assert self._connection_thread_lock is not None
thread: typing.Optional[Thread] = None
assert self._connection_thread_lock is not None
assert self._connection_thread_is_initialized is not None
assert self._connection_thread_is_cleared is not None
with self._connection_thread_lock:
if self._isolation_thread is None:
thread = self._isolation_thread
if thread is None:
initialized = True
is_initialized = Event()
ctx = copy_context()
self._isolation_thread = thread = Thread(
target=ctx.run,
target=_init_thread,
args=[
_init_thread,
self._database,
is_initialized,
self._connection_thread_is_initialized,
self._connection_thread_is_cleared,
],
daemon=True,
)
# must be started with lock held, for setting is_alive
thread.start()
while not is_initialized.wait(1):
assert thread is not None
# bypass for full_isolated
if thread is not current_thread():
if initialized:
while not self._connection_thread_is_initialized.is_set():
if not thread.is_alive():
with self._connection_thread_lock:
self._isolation_thread = None
self._connection_thread_is_initialized.clear()
thread.join(1)
raise Exception("Cannot start full isolation thread")
await asyncio.sleep(self.poll_interval)

else:
# ensure to be not in the isolation thread itself
while not self._connection_thread_is_initialized.is_set():
if not thread.is_alive():
self._isolation_thread = None
break
if thread is not None and not thread.is_alive():
thread.join(10)
raise Exception("Cannot start full isolation thread")
raise Exception("Isolation thread is dead")
await asyncio.sleep(self.poll_interval)

if not initialized:
await self._aenter()
return self
Expand All @@ -159,12 +185,12 @@ async def _aexit_raw(self) -> bool:
async def _aexit(self) -> typing.Optional[Thread]:
if self._full_isolation:
assert self._connection_thread_lock is not None
# the lock must be held on exit
with self._connection_thread_lock:
if await self._aexit_raw():
loop = self._database._loop
thread = self._isolation_thread
# after stopping the _isolation_thread is removed in thread
if loop is not None:
if loop is not None and loop.is_running():
loop.stop()
else:
self._isolation_thread = None
Expand All @@ -180,17 +206,22 @@ async def __aexit__(
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
) -> None:
thread = None
try:
thread = await self._aexit()
finally:
if thread is not None and thread is not current_thread():
thread.join()
thread = await self._aexit()
if thread is not None and thread is not current_thread():
while thread.is_alive():
await asyncio.sleep(self.poll_interval)
thread.join(1)

@property
def _loop(self) -> typing.Any:
def _loop(self) -> typing.Optional[asyncio.AbstractEventLoop]:
return self._database._loop

@property
def poll_interval(self) -> float:
if self._orig_database.poll_interval < 0:
raise RuntimeError("Not supposed to run in the poll path")
return self._orig_database.poll_interval

@property
def _backend(self) -> interfaces.DatabaseBackend:
return self._database.backend
Expand Down
16 changes: 13 additions & 3 deletions databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from types import TracebackType

from databasez import interfaces
from databasez.utils import arun_coroutine_threadsafe, multiloop_protector
from databasez.utils import DATABASEZ_POLL_INTERVAL, arun_coroutine_threadsafe, multiloop_protector

from .connection import Connection
from .databaseurl import DatabaseURL
Expand Down Expand Up @@ -144,13 +144,14 @@ class Database:

_connection_map: weakref.WeakKeyDictionary[asyncio.Task, Connection]
_databases_map: typing.Dict[typing.Any, Database]
_loop: typing.Any = None
_loop: typing.Optional[asyncio.AbstractEventLoop] = None
backend: interfaces.DatabaseBackend
url: DatabaseURL
options: typing.Any
is_connected: bool = False
_call_hooks: bool = True
_full_isolation: bool = False
poll_interval: float
_force_rollback: ForceRollback
# descriptor
force_rollback = ForceRollbackDescriptor()
Expand All @@ -162,6 +163,8 @@ def __init__(
force_rollback: typing.Union[bool, None] = None,
config: typing.Optional["DictAny"] = None,
full_isolation: typing.Union[bool, None] = None,
# for
poll_interval: typing.Union[float, None] = None,
**options: typing.Any,
):
init()
Expand All @@ -172,6 +175,8 @@ def __init__(
self.url = url.url
self.options = url.options
self._call_hooks = url._call_hooks
if poll_interval is None:
poll_interval = url.poll_interval
if force_rollback is None:
force_rollback = bool(url.force_rollback)
if full_isolation is None:
Expand All @@ -190,6 +195,9 @@ def __init__(
force_rollback = False
if full_isolation is None:
full_isolation = False
if poll_interval is None:
poll_interval = DATABASEZ_POLL_INTERVAL
self.poll_interval = poll_interval
self._full_isolation = full_isolation
self._force_rollback = ForceRollback(force_rollback)
self.backend.owner = self
Expand Down Expand Up @@ -269,6 +277,8 @@ async def connect(self) -> bool:
"""
loop = asyncio.get_running_loop()
if self._loop is not None and loop != self._loop:
if self.poll_interval < 0:
raise RuntimeError("Subdatabases and polling are disabled")
# copy when not in map
if loop not in self._databases_map:
assert (
Expand Down Expand Up @@ -332,7 +342,7 @@ async def disconnect(
assert not self._databases_map, "sub databases still active, force terminate them"
for sub_database in self._databases_map.values():
await arun_coroutine_threadsafe(
sub_database.disconnect(True), sub_database._loop
sub_database.disconnect(True), sub_database._loop, self.poll_interval
)
self._databases_map = {}
assert not self._databases_map, "sub databases still active"
Expand Down
16 changes: 13 additions & 3 deletions databasez/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sqlalchemy_utils.functions.orm import quote

from databasez import Database, DatabaseURL
from databasez.utils import ThreadPassingExceptions
from databasez.utils import DATABASEZ_POLL_INTERVAL, ThreadPassingExceptions


async def _get_scalar_result(engine: typing.Any, sql: typing.Any) -> Any:
Expand All @@ -36,6 +36,7 @@ class DatabaseTestClient(Database):
# hooks for overwriting defaults of args with None
testclient_default_full_isolation: bool = True
testclient_default_force_rollback: bool = False
testclient_default_poll_interval: float = DATABASEZ_POLL_INTERVAL
testclient_default_lazy_setup: bool = False
# customization hooks
testclient_default_use_existing: bool = False
Expand All @@ -47,8 +48,9 @@ def __init__(
url: typing.Union[str, "DatabaseURL", "sa.URL", Database],
*,
force_rollback: typing.Union[bool, None] = None,
use_existing: typing.Union[bool, None] = None,
full_isolation: typing.Union[bool, None] = None,
poll_interval: typing.Union[float, None] = None,
use_existing: typing.Union[bool, None] = None,
drop_database: typing.Union[bool, None] = None,
lazy_setup: typing.Union[bool, None] = None,
test_prefix: typing.Union[str, None] = None,
Expand Down Expand Up @@ -86,6 +88,8 @@ def __init__(
lazy_setup = self.testclient_default_lazy_setup
if force_rollback is None:
force_rollback = self.testclient_default_force_rollback
if poll_interval is None:
poll_interval = self.testclient_default_poll_interval
url = url if isinstance(url, DatabaseURL) else DatabaseURL(url)
test_database_url = (
url.replace(database=f"{test_prefix}{url.database}") if test_prefix else url
Expand All @@ -98,7 +102,13 @@ def __init__(
self.setup_protected(self.testclient_operation_timeout_init)
self._setup_executed_init = True

super().__init__(test_database_url, force_rollback=force_rollback, **options)
super().__init__(
test_database_url,
force_rollback=force_rollback,
full_isolation=full_isolation,
poll_interval=poll_interval,
**options,
)

async def setup(self) -> None:
"""
Expand Down
Loading

0 comments on commit 49a801a

Please sign in to comment.