Skip to content

Commit

Permalink
rename attribute and fix removed global connection
Browse files Browse the repository at this point in the history
  • Loading branch information
devkral committed Sep 4, 2024
1 parent dc3e0fe commit c1ec61f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 28 deletions.
53 changes: 26 additions & 27 deletions databasez/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,30 @@ async def _startup(database: Database, is_initialized: Event) -> None:
is_initialized.set()


def _init_thread(database: Database, is_initialized: Event, is_cleared: Event) -> None:
is_cleared.wait()
# now set the flag so new init_threads have to wait
is_cleared.clear()
loop = asyncio.new_event_loop()
# keep reference
task = loop.create_task(_startup(database, is_initialized))
try:
def _init_thread(
database: Database, is_initialized: Event, _connection_thread_running_lock: Lock
) -> None:
# ensure only thread manages the connection thread at the same time
# this is only relevant when starting up after a shutdown
with _connection_thread_running_lock:
loop = asyncio.new_event_loop()
# keep reference
task = loop.create_task(_startup(database, is_initialized))
try:
loop.run_forever()
except RuntimeError:
pass
try:
loop.run_forever()
except RuntimeError:
pass
finally:
# now all inits wait
is_initialized.clear()
loop.run_until_complete(database.disconnect())
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
# now all inits wait
is_initialized.clear()
loop.run_until_complete(database.disconnect())
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
del task
loop.close()
database._loop = None
database._global_connection._isolation_thread = None # type: ignore
is_cleared.set()
del task
loop.close()
database._loop = None
database._global_connection._isolation_thread = None # type: ignore


class Connection:
Expand All @@ -66,14 +67,12 @@ def __init__(
self._full_isolation = full_isolation
self._connection_thread_lock: typing.Optional[Lock] = None
self._connection_thread_is_initialized: typing.Optional[Event] = None
self._connection_thread_is_cleared: typing.Optional[Event] = None
self._connection_thread_running_lock: typing.Optional[Lock] = None
self._isolation_thread: typing.Optional[Thread] = None
if self._full_isolation:
self._connection_thread_lock = Lock()
self._connection_thread_is_initialized = Event()
self._connection_thread_is_cleared = Event()
# initially it is cleared
self._connection_thread_is_cleared.set()
self._connection_thread_running_lock = Lock()
self._database = database.__class__(
database, force_rollback=force_rollback, full_isolation=False, poll_interval=-1
)
Expand Down Expand Up @@ -124,7 +123,7 @@ async def __aenter__(self) -> Connection:
thread: typing.Optional[Thread] = None
assert self._connection_thread_lock is not None
assert self._connection_thread_is_initialized is not None
assert self._connection_thread_is_cleared is not None
assert self._connection_thread_running_lock is not None
with self._connection_thread_lock:
thread = self._isolation_thread
if thread is None:
Expand All @@ -134,7 +133,7 @@ async def __aenter__(self) -> Connection:
args=[
self._database,
self._connection_thread_is_initialized,
self._connection_thread_is_cleared,
self._connection_thread_running_lock,
],
daemon=True,
)
Expand Down
6 changes: 5 additions & 1 deletion databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class Database:
options: typing.Any
is_connected: bool = False
_call_hooks: bool = True
_remove_global_connection: bool = True
_full_isolation: bool = False
poll_interval: float
_force_rollback: ForceRollback
Expand Down Expand Up @@ -313,6 +314,8 @@ async def connect(self) -> bool:
if self._global_connection is None:
connection = Connection(self, force_rollback=True, full_isolation=self._full_isolation)
self._global_connection = connection
else:
self._remove_global_connection = False
return True

async def disconnect_hook(self) -> None:
Expand Down Expand Up @@ -350,7 +353,8 @@ async def disconnect(
try:
assert self._global_connection is not None
await self._global_connection.__aexit__()
self._global_connection = None
if self._remove_global_connection:
self._global_connection = None
self._connection = None
finally:
logger.info(
Expand Down

0 comments on commit c1ec61f

Please sign in to comment.