Skip to content

Commit

Permalink
Changes:
Browse files Browse the repository at this point in the history
- don't wait for thread in another thread
- fix testclient constructor interface matching database constructor
  interface
- lock down jdbc to use only one connection per engine, this should fix
  the segfaults
- don't execute concurrent tests with mssql on github
  • Loading branch information
devkral committed Sep 11, 2024
1 parent 228dd59 commit 8f4f124
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,4 @@ jobs:
run: "hatch run test:check_types"
- name: "Run tests"
if: steps.filters.outputs.src == 'true' || steps.filters.outputs.workflows == 'true' || github.event.schedule != ''
run: hatch test
run: env TEST_NO_RISK_SEGFAULTS=true hatch test
4 changes: 2 additions & 2 deletions databasez/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ async def _aexit_raw(self) -> bool:
self._database._connection = None
return closing

@multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout
async def _aexit(self) -> typing.Optional[Thread]:
if self._full_isolation:
assert self._connection_thread_lock is not None
Expand All @@ -248,7 +249,6 @@ async def _aexit(self) -> typing.Optional[Thread]:
await self._aexit_raw()
return None

@multiloop_protector(False, passthrough_timeout=True) # fail when specifying timeout
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]] = None,
Expand All @@ -257,7 +257,7 @@ async def __aexit__(
) -> None:
thread = await self._aexit()
if thread is not None and thread is not current_thread():
while thread.is_alive(): # noqa: ASYNC110
while thread.is_alive():
await asyncio.sleep(self.poll_interval)
thread.join(1)

Expand Down
2 changes: 1 addition & 1 deletion databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class Database:

def __init__(
self,
url: typing.Optional[typing.Union[str, DatabaseURL, URL, Database]] = None,
url: typing.Union[str, DatabaseURL, URL, Database, None] = None,
*,
force_rollback: typing.Union[bool, None] = None,
config: typing.Optional["DictAny"] = None,
Expand Down
65 changes: 30 additions & 35 deletions databasez/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing
from typing import Any

import sqlalchemy as sa
import sqlalchemy
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy_utils.functions.database import _sqlite_file_exists
from sqlalchemy_utils.functions.orm import quote
Expand Down Expand Up @@ -45,7 +45,7 @@ class DatabaseTestClient(Database):

def __init__(
self,
url: typing.Union[str, "DatabaseURL", "sa.URL", Database],
url: typing.Union[str, DatabaseURL, sqlalchemy.URL, Database, None] = None,
*,
force_rollback: typing.Union[bool, None] = None,
full_isolation: typing.Union[bool, None] = None,
Expand All @@ -66,49 +66,42 @@ def __init__(
test_prefix = self.testclient_default_test_prefix
self._setup_executed_init = False
if isinstance(url, Database):
test_database_url = (
url.url.replace(database=f"{test_prefix}{url.url.database}")
if test_prefix
else url.url
)
# replace only if not cloning a DatabaseTestClient
self.test_db_url = str(getattr(url, "test_db_url", test_database_url))
self.use_existing = getattr(url, "use_existing", use_existing)
self.drop = getattr(url, "drop", drop_database)
# only if explicit set to False
if lazy_setup is False:
self.setup_protected(self.testclient_operation_timeout_init)
self._setup_executed_init = True
super().__init__(url, force_rollback=force_rollback, **options)
# fix url
if str(self.url) != self.test_db_url:
self.url = test_database_url
if hasattr(url, "test_db_url"):
self.test_db_url = url.test_db_url
else:
if test_prefix:
self.url = self.url.replace(database=f"{test_prefix}{self.url.database}")
self.test_db_url = str(self.url)
else:
if lazy_setup is None:
lazy_setup = self.testclient_default_lazy_setup
if force_rollback is None:
force_rollback = self.testclient_default_force_rollback
if poll_interval is None:
poll_interval = self.testclient_default_poll_interval
url = url if isinstance(url, DatabaseURL) else DatabaseURL(url)
test_database_url = (
url.replace(database=f"{test_prefix}{url.database}") if test_prefix else url
)
self.test_db_url = str(test_database_url)
self.use_existing = use_existing
self.drop = drop_database
# if None or False
if not lazy_setup:
self.setup_protected(self.testclient_operation_timeout_init)
self._setup_executed_init = True

super().__init__(
test_database_url,
url,
force_rollback=force_rollback,
full_isolation=full_isolation,
poll_interval=poll_interval,
**options,
)
if test_prefix:
self.url = self.url.replace(database=f"{test_prefix}{self.url.database}")
self.test_db_url = str(self.url)
# if None or False
if not lazy_setup:
self.setup_protected(self.testclient_operation_timeout_init)
self._setup_executed_init = True

async def setup(self) -> None:
"""
Expand Down Expand Up @@ -150,7 +143,7 @@ async def is_database_exist(self) -> Any:
return await self.database_exists(self.test_db_url)

@classmethod
async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> bool:
async def database_exists(cls, url: typing.Union[str, "sqlalchemy.URL", DatabaseURL]) -> bool:
url = url if isinstance(url, DatabaseURL) else DatabaseURL(url)
database = url.database
dialect_name = url.sqla_url.get_dialect(True).name
Expand All @@ -160,7 +153,9 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
url = url.replace(database=db)
async with Database(url, full_isolation=False, force_rollback=False) as db_client:
try:
return bool(await _get_scalar_result(db_client.engine, sa.text(text)))
return bool(
await _get_scalar_result(db_client.engine, sqlalchemy.text(text))
)
except (ProgrammingError, OperationalError):
pass
return False
Expand All @@ -172,7 +167,7 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
"WHERE SCHEMA_NAME = '%s'" % database
)
async with Database(url, full_isolation=False, force_rollback=False) as db_client:
return bool(await _get_scalar_result(db_client.engine, sa.text(text)))
return bool(await _get_scalar_result(db_client.engine, sqlalchemy.text(text)))

elif dialect_name == "sqlite":
if database:
Expand All @@ -185,14 +180,14 @@ async def database_exists(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) ->
text = "SELECT 1"
async with Database(url, full_isolation=False, force_rollback=False) as db_client:
try:
return bool(await _get_scalar_result(db_client.engine, sa.text(text)))
return bool(await _get_scalar_result(db_client.engine, sqlalchemy.text(text)))
except (ProgrammingError, OperationalError):
return False

@classmethod
async def create_database(
cls,
url: typing.Union[str, "sa.URL", DatabaseURL],
url: typing.Union[str, "sqlalchemy.URL", DatabaseURL],
encoding: str = "utf8",
template: typing.Any = None,
) -> None:
Expand Down Expand Up @@ -229,29 +224,29 @@ async def create_database(
text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format(
quote(conn, database), encoding, quote(conn, template)
)
await conn.execute(sa.text(text))
await conn.execute(sqlalchemy.text(text))

elif dialect_name == "mysql":
async with db_client.engine.begin() as conn: # type: ignore
text = "CREATE DATABASE {} CHARACTER SET = '{}'".format(
quote(conn, database), encoding
)
await conn.execute(sa.text(text))
await conn.execute(sqlalchemy.text(text))

elif dialect_name == "sqlite" and database != ":memory:":
if database:
# create a sqlite file
async with db_client.engine.begin() as conn: # type: ignore
await conn.execute(sa.text("CREATE TABLE DB(id int)"))
await conn.execute(sa.text("DROP TABLE DB"))
await conn.execute(sqlalchemy.text("CREATE TABLE DB(id int)"))
await conn.execute(sqlalchemy.text("DROP TABLE DB"))

else:
async with db_client.engine.begin() as conn: # type: ignore
text = f"CREATE DATABASE {quote(conn, database)}"
await conn.execute(sa.text(text))
await conn.execute(sqlalchemy.text(text))

@classmethod
async def drop_database(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> None:
async def drop_database(cls, url: typing.Union[str, "sqlalchemy.URL", DatabaseURL]) -> None:
url = url if isinstance(url, DatabaseURL) else DatabaseURL(url)
database = url.database
dialect = url.sqla_url.get_dialect(True)
Expand Down Expand Up @@ -310,7 +305,7 @@ async def drop_database(cls, url: typing.Union[str, "sa.URL", DatabaseURL]) -> N
else:
async with db_client.connection() as conn:
text = f"DROP DATABASE {quote(conn.async_connection, database)}"
await conn.execute(sa.text(text))
await conn.execute(sqlalchemy.text(text))

def drop_db_protected(self) -> None:
thread = ThreadPassingExceptions(
Expand Down
1 change: 1 addition & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- The global connection is now entered lazily despite sub-databases.
- Fix deadlock with full_isolation off.
- Fix database.transaction() failing because of AsyncDatabaseHelper.
- Fix DatabaseTestClient not able to use config initialization.

## 0.10.2

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ select = [
]

ignore = [
"ASYNC110", # use anio.Event
"B008", # do not perform function calls in argument defaults
"C901", # too complex
"E712", # Comparison to True should be cond is True
Expand Down
12 changes: 10 additions & 2 deletions tests/shared_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,14 @@ async def database_client(url: typing.Union[dict, str], meta=None) -> DatabaseTe
url, test_prefix="", use_existing=not is_sqlite, drop_database=is_sqlite
)
else:
database = Database(config=url)
scheme = url["connection"]["credentials"]["scheme"]
is_sqlite = scheme.startswith("sqlite")
database = DatabaseTestClient(
config=url,
test_prefix="",
use_existing=not is_sqlite,
drop_database=is_sqlite,
)
await database.connect()
await database.create_all(meta)
return database
Expand All @@ -88,5 +95,6 @@ async def database_client(url: typing.Union[dict, str], meta=None) -> DatabaseTe
async def stop_database_client(database: Database, meta=None):
if meta is None:
meta = metadata
await database.drop_all(meta)
if not getattr(database, "drop", False):
await database.drop_all(meta)
await database.disconnect()
4 changes: 3 additions & 1 deletion tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")]

if not any((x.endswith(" for SQL Server") for x in pyodbc.drivers())):
if os.environ.get("TEST_NO_RISK_SEGFAULTS") or not any(
(x.endswith(" for SQL Server") for x in pyodbc.drivers())
):
DATABASE_URLS = list(filter(lambda x: "mssql" not in x, DATABASE_URLS))


Expand Down
6 changes: 3 additions & 3 deletions tests/test_really_old_jdbc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import sqlalchemy
from sqlalchemy.pool import NullPool
from sqlalchemy.pool import StaticPool

from databasez import Database

Expand All @@ -25,7 +25,7 @@ async def test_jdbc_connect():
"""
async with Database(
"jdbc+sqlite://testsuite.sqlite3?classpath=tests/sqlite-jdbc-3.6.13.jar&jdbc_driver=org.sqlite.JDBC",
poolclass=NullPool,
poolclass=StaticPool,
) as database:
async with database.connection():
pass
Expand All @@ -39,7 +39,7 @@ async def test_jdbc_queries():
"""
async with Database(
"jdbc+sqlite://testsuite.sqlite3?classpath=tests/sqlite-jdbc-3.6.13.jar&jdbc_driver=org.sqlite.JDBC",
poolclass=NullPool,
poolclass=StaticPool,
) as database:
async with database.connection() as connection:
await connection.create_all(metadata)
Expand Down

0 comments on commit 8f4f124

Please sign in to comment.