Skip to content

Commit

Permalink
multithread safety (#45)
Browse files Browse the repository at this point in the history
Changes:

- make databasez threadsafe
- bump version
  • Loading branch information
devkral authored Aug 20, 2024
1 parent d024474 commit 46d5d2a
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 8 deletions.
2 changes: 1 addition & 1 deletion databasez/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from databasez.core import Database, DatabaseURL

__version__ = "0.9.5"
__version__ = "0.9.6"

__all__ = ["Database", "DatabaseURL"]
19 changes: 19 additions & 0 deletions databasez/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlalchemy import text

from databasez import interfaces
from databasez.utils import multiloop_protector

from .transaction import Transaction

Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(
self._force_rollback = force_rollback
self.connection_transaction: typing.Optional[Transaction] = None

@multiloop_protector(True)
async def __aenter__(self) -> Connection:
async with self._connection_lock:
self._connection_counter += 1
Expand Down Expand Up @@ -83,6 +85,11 @@ async def __aexit__(
await self._connection.release()
self._database._connection = None

@property
def _loop(self) -> typing.Any:
return self._database._loop

@multiloop_protector(True)
async def fetch_all(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -92,6 +99,7 @@ async def fetch_all(
async with self._query_lock:
return await self._connection.fetch_all(built_query)

@multiloop_protector(True)
async def fetch_one(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -102,6 +110,7 @@ async def fetch_one(
async with self._query_lock:
return await self._connection.fetch_one(built_query, pos=pos)

@multiloop_protector(True)
async def fetch_val(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -113,6 +122,7 @@ async def fetch_val(
async with self._query_lock:
return await self._connection.fetch_val(built_query, column, pos=pos)

@multiloop_protector(True)
async def execute(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -126,6 +136,7 @@ async def execute(
async with self._query_lock:
return await self._connection.execute(query, values)

@multiloop_protector(True)
async def execute_many(
self, query: typing.Union[ClauseElement, str], values: typing.Any = None
) -> typing.Union[typing.Sequence[interfaces.Record], int]:
Expand All @@ -137,6 +148,7 @@ async def execute_many(
async with self._query_lock:
return await self._connection.execute_many(query, values)

@multiloop_protector(True)
async def iterate(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -148,6 +160,7 @@ async def iterate(
async for record in self._connection.iterate(built_query, batch_size):
yield record

@multiloop_protector(True)
async def batched_iterate(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -159,6 +172,7 @@ async def batched_iterate(
async for records in self._connection.batched_iterate(built_query, batch_size):
yield records

@multiloop_protector(True)
async def run_sync(
self,
fn: typing.Callable[..., typing.Any],
Expand All @@ -168,20 +182,25 @@ async def run_sync(
async with self._query_lock:
return await self._connection.run_sync(fn, *args, **kwargs)

@multiloop_protector(True)
async def create_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
await self.run_sync(meta.create_all, **kwargs)

@multiloop_protector(True)
async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
await self.run_sync(meta.drop_all, **kwargs)

@multiloop_protector(True)
def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction":
return Transaction(weakref.ref(self), force_rollback, **kwargs)

@property
@multiloop_protector(True)
def async_connection(self) -> typing.Any:
"""The first layer (sqlalchemy)."""
return self._connection.async_connection

@multiloop_protector(True)
async def get_raw_connection(self) -> typing.Any:
"""The real raw connection (driver)."""
return await self.async_connection.get_raw_connection()
Expand Down
57 changes: 51 additions & 6 deletions databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from types import TracebackType

from databasez import interfaces
from databasez.utils import multiloop_protector

from .connection import Connection
from .databaseurl import DatabaseURL
Expand Down Expand Up @@ -62,6 +63,12 @@ def init() -> None:
)


# we need a dict to ensure the references are kept
ACTIVE_DATABASES: ContextVar[typing.Optional[typing.Dict[typing.Any, Database]]] = ContextVar(
"ACTIVE_DATABASES", default=None
)


ACTIVE_FORCE_ROLLBACKS: ContextVar[
typing.Optional[weakref.WeakKeyDictionary[ForceRollback, bool]]
] = ContextVar("ACTIVE_FORCE_ROLLBACKS", default=None)
Expand Down Expand Up @@ -142,6 +149,7 @@ class Database:
"""

_connection_map: weakref.WeakKeyDictionary[asyncio.Task, Connection]
_loop: typing.Any = None
backend: interfaces.DatabaseBackend
url: DatabaseURL
options: typing.Any
Expand Down Expand Up @@ -234,6 +242,7 @@ async def decr_refcount(self) -> bool:
async def connect_hook(self) -> None:
"""Refcount protected connect hook. Executed begore engine and global connection setup."""

@multiloop_protector(True)
async def connect(self) -> bool:
"""
Establish the connection pool.
Expand All @@ -246,6 +255,7 @@ async def connect(self) -> bool:
except BaseException as exc:
await self.decr_refcount()
raise exc
self._loop = asyncio.get_event_loop()

await self.backend.connect(self.url, **self.options)
logger.info("Connected to database %s", self.url.obscure_password, extra=CONNECT_EXTRA)
Expand All @@ -259,6 +269,7 @@ async def connect(self) -> bool:
async def disconnect_hook(self) -> None:
"""Refcount protected disconnect hook. Executed after connection, engine cleanup."""

@multiloop_protector(True)
async def disconnect(self, force: bool = False) -> bool:
"""
Close all connections in the connection pool.
Expand All @@ -285,21 +296,41 @@ async def disconnect(self, force: bool = False) -> bool:
)
self.is_connected = False
await self.backend.disconnect()
self._loop = None
await self.disconnect_hook()
return True

async def __aenter__(self) -> "Database":
await self.connect()
return self
loop = asyncio.get_running_loop()
database = self
if self._loop is not None and loop != self._loop:
dbs = ACTIVE_DATABASES.get()
if dbs is None:
dbs = {}
else:
dbs = dbs.copy()
database = self.__copy__()
dbs[loop] = database
# it is always a copy required to prevent sideeffects between the contexts
ACTIVE_DATABASES.set(dbs)
await database.connect()
return database

async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
exc_value: typing.Optional[BaseException] = None,
traceback: typing.Optional[TracebackType] = None,
) -> None:
await self.disconnect()

loop = asyncio.get_running_loop()
database = self
if self._loop is not None and loop != self._loop:
dbs = ACTIVE_DATABASES.get()
if dbs is not None:
database = dbs.pop(loop, database)
await database.disconnect()

@multiloop_protector(False)
async def fetch_all(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -308,6 +339,7 @@ async def fetch_all(
async with self.connection() as connection:
return await connection.fetch_all(query, values)

@multiloop_protector(False)
async def fetch_one(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -316,7 +348,9 @@ async def fetch_one(
) -> typing.Optional[interfaces.Record]:
async with self.connection() as connection:
return await connection.fetch_one(query, values, pos=pos)
assert connection._connection_counter == 1

@multiloop_protector(False)
async def fetch_val(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -327,6 +361,7 @@ async def fetch_val(
async with self.connection() as connection:
return await connection.fetch_val(query, values, column=column, pos=pos)

@multiloop_protector(False)
async def execute(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -335,12 +370,14 @@ async def execute(
async with self.connection() as connection:
return await connection.execute(query, values)

@multiloop_protector(False)
async def execute_many(
self, query: typing.Union[ClauseElement, str], values: typing.Any = None
) -> typing.Union[typing.Sequence[interfaces.Record], int]:
async with self.connection() as connection:
return await connection.execute_many(query, values)

@multiloop_protector(False)
async def iterate(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -351,6 +388,7 @@ async def iterate(
async for record in connection.iterate(query, values, chunk_size):
yield record

@multiloop_protector(False)
async def batched_iterate(
self,
query: typing.Union[ClauseElement, str],
Expand All @@ -362,9 +400,11 @@ async def batched_iterate(
async for records in connection.batched_iterate(query, values, batch_size):
yield batch_wrapper(records)

@multiloop_protector(True)
def transaction(self, *, force_rollback: bool = False, **kwargs: typing.Any) -> "Transaction":
return Transaction(self.connection, force_rollback=force_rollback, **kwargs)

@multiloop_protector(False)
async def run_sync(
self,
fn: typing.Callable[..., typing.Any],
Expand All @@ -374,23 +414,28 @@ async def run_sync(
async with self.connection() as connection:
return await connection.run_sync(fn, *args, **kwargs)

@multiloop_protector(False)
async def create_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
async with self.connection() as connection:
await connection.create_all(meta, **kwargs)

@multiloop_protector(False)
async def drop_all(self, meta: MetaData, **kwargs: typing.Any) -> None:
async with self.connection() as connection:
await connection.drop_all(meta, **kwargs)

@multiloop_protector(False, wrap_context_manager=True)
def connection(self) -> Connection:
if self.force_rollback:
return typing.cast(Connection, self._global_connection)

if not self._connection:
self._connection = Connection(self, self.backend)
if self._connection is None:
_connection = self._connection = Connection(self, self.backend)
return _connection
return self._connection

@property
@multiloop_protector(True)
def engine(self) -> typing.Optional[AsyncEngine]:
return self.backend.engine

Expand Down
50 changes: 49 additions & 1 deletion databasez/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import inspect
import typing
from functools import partial
from contextlib import asynccontextmanager
from functools import partial, wraps
from threading import Thread

async_wrapper_slots = (
Expand Down Expand Up @@ -134,3 +135,50 @@ def join(self, timeout: typing.Union[float, int, None] = None) -> None:
super().join(timeout=timeout)
if self._exc_raised:
raise self._exc_raised


MultiloopProtectorCallable = typing.TypeVar("MultiloopProtectorCallable", bound=typing.Callable)


async def _async_helper(
database: typing.Any, fn: MultiloopProtectorCallable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
# copy
async with database.__class__(database) as new_database:
return await fn(new_database, *args, **kwargs)


@asynccontextmanager
async def _contextmanager_helper(
database: typing.Any, fn: MultiloopProtectorCallable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
async with database.__copy__() as new_database:
async with fn(new_database, *args, **kwargs) as result:
yield result


def multiloop_protector(
fail_with_different_loop: bool, wrap_context_manager: bool = False
) -> typing.Callable[[MultiloopProtectorCallable], MultiloopProtectorCallable]:
"""For multiple threads or other reasons why the loop changes"""

# True works with all methods False only for methods of Database
# needs _loop attribute to check against
def _decorator(fn: MultiloopProtectorCallable) -> MultiloopProtectorCallable:
@wraps(fn)
def wrapper(self: typing.Any, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and self._loop is not None and loop != self._loop:
if fail_with_different_loop:
raise RuntimeError("Different loop used")
if wrap_context_manager:
return _contextmanager_helper(self, fn, *args, **kwargs)
return _async_helper(self, fn, *args, **kwargs)
return fn(self, *args, **kwargs)

return typing.cast(MultiloopProtectorCallable, wrapper)

return _decorator
7 changes: 7 additions & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# Release Notes


## 0.9.6

### Fixed

- Databasez is now threadsafe (and multiloop safe).


## 0.9.5

### Changed
Expand Down
Loading

0 comments on commit 46d5d2a

Please sign in to comment.