Skip to content

Commit

Permalink
Changes:
Browse files Browse the repository at this point in the history
- join thread outside of lock and other changes to mitigate a deadlock
  possibility
- add configuration knobs in utils for debugging
  • Loading branch information
devkral committed Sep 3, 2024
1 parent c39f40a commit e1b5980
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 26 deletions.
41 changes: 22 additions & 19 deletions databasez/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typing
import weakref
from contextvars import copy_context
from threading import Event, RLock, Thread, current_thread
from threading import Event, Lock, Thread, current_thread
from types import TracebackType

from sqlalchemy import text
Expand Down Expand Up @@ -38,21 +38,19 @@ async def _startup(database: Database, is_initialized: Event) -> None:

def _init_thread(database: Database, is_initialized: Event) -> None:
loop = asyncio.new_event_loop()
# keep reference
task = loop.create_task(_startup(database, is_initialized))
try:
loop.run_forever()
except RuntimeError:
pass
try:
task.result()
finally:
try:
loop.run_until_complete(database.disconnect())
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
del task
loop.close()
database._loop = None
loop.run_forever()
except RuntimeError:
pass
loop.run_until_complete(database.disconnect())
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
del task
loop.close()
database._loop = None


class Connection:
Expand All @@ -61,15 +59,15 @@ def __init__(
) -> None:
self._orig_database = self._database = database
self._full_isolation = full_isolation
self._connection_thread_lock: typing.Optional[RLock] = None
self._connection_thread_lock: typing.Optional[Lock] = None
self._isolation_thread: typing.Optional[Thread] = None
if self._full_isolation:
self._database = database.__class__(
database, force_rollback=force_rollback, full_isolation=False
)
self._database._call_hooks = False
self._database._global_connection = self
self._connection_thread_lock = RLock()
self._connection_thread_lock = Lock()
# the asyncio locks are overwritten in python versions < 3.10 when using full_isolation
self._query_lock = asyncio.Lock()
self._connection_lock = asyncio.Lock()
Expand Down Expand Up @@ -113,6 +111,7 @@ async def __aenter__(self) -> Connection:
initialized = False
if self._full_isolation:
assert self._connection_thread_lock is not None
thread: typing.Optional[Thread] = None
with self._connection_thread_lock:
if self._isolation_thread is None:
initialized = True
Expand All @@ -128,10 +127,13 @@ async def __aenter__(self) -> Connection:
daemon=True,
)
thread.start()
is_initialized.wait()
if not thread.is_alive():
self._isolation_thread = None
thread.join()
while not is_initialized.wait(1):
if not thread.is_alive():
self._isolation_thread = None
break
if thread is not None and not thread.is_alive():
thread.join(10)
raise Exception("Cannot start full isolation thread")
if not initialized:
await self._aenter()
return self
Expand Down Expand Up @@ -161,6 +163,7 @@ async def _aexit(self) -> typing.Optional[Thread]:
if await self._aexit_raw():
loop = self._database._loop
thread = self._isolation_thread
# after stopping the _isolation_thread is removed in thread
if loop is not None:
loop.stop()
else:
Expand Down
4 changes: 2 additions & 2 deletions databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from types import TracebackType

from databasez import interfaces
from databasez.utils import multiloop_protector
from databasez.utils import arun_coroutine_threadsafe, multiloop_protector

from .connection import Connection
from .databaseurl import DatabaseURL
Expand Down Expand Up @@ -331,7 +331,7 @@ async def disconnect(
if self._databases_map:
assert not self._databases_map, "sub databases still active, force terminate them"
for sub_database in self._databases_map.values():
asyncio.run_coroutine_threadsafe(
await arun_coroutine_threadsafe(
sub_database.disconnect(True), sub_database._loop
)
self._databases_map = {}
Expand Down
44 changes: 39 additions & 5 deletions databasez/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,46 @@
import asyncio
import contextvars
import inspect
import typing
from functools import partial, wraps
from threading import Thread
from types import TracebackType

DATABASEZ_RESULT_TIMEOUT: typing.Optional[float] = None
DATABASEZ_WRAP_IN_THREAD: bool = False

try:
to_thread = asyncio.to_thread
except AttributeError:
# for py <= 3.8
async def to_thread(
func: typing.Any, /, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
loop = asyncio.get_running_loop()
ctx = contextvars.copy_context()
func_call = partial(ctx.run, func, *args, **kwargs)
return await loop.run_in_executor(None, func_call)


def _run_coroutine_threadsafe_result_shim(
coro: typing.Coroutine, loop: asyncio.BaseEventLoop
) -> typing.Any:
assert loop.is_running(), "loop is closed"
return asyncio.run_coroutine_threadsafe(coro, loop).result(DATABASEZ_RESULT_TIMEOUT)


async def arun_coroutine_threadsafe(
coro: typing.Coroutine, loop: asyncio.BaseEventLoop
) -> typing.Any:
running_loop = asyncio.get_running_loop()
if running_loop is loop:
return await coro
elif not DATABASEZ_WRAP_IN_THREAD:
return _run_coroutine_threadsafe_result_shim(coro, loop)
else:
return await to_thread(_run_coroutine_threadsafe_result_shim, coro, loop)


async_wrapper_slots = (
"_async_wrapped",
"_async_pool",
Expand Down Expand Up @@ -220,7 +256,7 @@ async def call(self) -> typing.Any:
return result

async def acall(self) -> typing.Any:
return asyncio.run_coroutine_threadsafe(self.call(), self.connection._loop).result()
return await arun_coroutine_threadsafe(self.call(), self.connection._loop)

def __await__(self) -> typing.Any:
return self.acall().__await__()
Expand All @@ -239,9 +275,7 @@ async def exit_intern(self) -> typing.Any:
await self.connection.__aexit__()

async def __aenter__(self) -> typing.Any:
return asyncio.run_coroutine_threadsafe(
self.enter_intern(), self.connection._loop
).result()
return await arun_coroutine_threadsafe(self.enter_intern(), self.connection._loop)

async def __aexit__(
self,
Expand All @@ -250,7 +284,7 @@ async def __aexit__(
traceback: typing.Optional[TracebackType] = None,
) -> None:
assert self.ctm is not None
asyncio.run_coroutine_threadsafe(self.exit_intern(), self.connection._loop).result()
await arun_coroutine_threadsafe(self.exit_intern(), self.connection._loop)


def multiloop_protector(
Expand Down

0 comments on commit e1b5980

Please sign in to comment.