Skip to content

Commit

Permalink
databasez 0.8 compatibility (#161)
Browse files Browse the repository at this point in the history
- fix compatibility with databasez >=0.8
- some small test utility related fixes
- fix pre-commit for python != 3.10
- Adapt Databasez >= 0.8
  • Loading branch information
devkral authored Aug 7, 2024
1 parent cd6c711 commit 6432879
Show file tree
Hide file tree
Showing 16 changed files with 92 additions and 149 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ node_modules/
results/
site/
target/
venv

# files
**/*.so
Expand All @@ -25,4 +26,4 @@ target/
.coverage
.python-version
coverage.*
example.sqlite
example.sqlite
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# See https://pre-commit.com for more information.
# See https://pre-commit.com/hooks.html for more hooks.
default_language_version:
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
Expand Down
11 changes: 6 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ dependencies = [
"click>=8.1.3,<9.0.0",
"dymmond-settings>=1.0.4",
"loguru>=0.6.0,<0.10.0",
"databases>=0.7.2",
"databasez>=0.8.5",
"orjson >=3.8.5,<4.0.0",
"pydantic>=2.5.3,<3.0.0",
"rich>=13.3.1,<14.0.0",
"nest_asyncio",
]
keywords = [
"api",
Expand Down Expand Up @@ -115,15 +116,15 @@ doc = [
]

testing = ["sqlalchemy_utils>=0.40.0"]
postgres = ["databasez[postgresql]==0.7.2"]
mysql = ["databasez[mysql]==0.7.2"]
sqlite = ["databasez[sqlite]==0.7.2"]
postgres = ["databasez[postgresql]>=0.8.2"]
mysql = ["databasez[mysql]>=0.8.2"]
sqlite = ["databasez[sqlite]>=0.8.2"]

ptpython = ["ptpython>=3.0.23,<4.0.0"]
ipython = ["ipython>=8.10.0,<9.0.0"]

all = [
"databasez[postgresql,mysql,sqlite]==0.7.2",
"databasez[postgresql,mysql,sqlite]>=0.8.2",
"orjson>=3.8.5,<4.0.0",
"ptpython>=3.0.23,<4.0.0",
"ipython>=8.10.0,<9.0.0",
Expand Down
32 changes: 2 additions & 30 deletions saffier/core/connection/database.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,3 @@
from databasez import Database as Databasez # noqa
from databasez import DatabaseURL as DatabaseURL # noqa
from databasez import Database, DatabaseURL


class Database(Databasez):
"""
An abstraction on the top of the EncodeORM databases.Database object.
This object allows to pass also a configuration dictionary in the format of
```python
DATABASEZ_CONFIG = {
"connection": {
"credentials": {
"scheme": 'sqlite', "postgres"...
"host": ...,
"port": ...,
"user": ...,
"password": ...,
"database": ...,
"options": {
"driver": ...
"ssl": ...
}
}
}
}
```
"""

...
__all__ = ["Database", "DatabaseURL"]
51 changes: 10 additions & 41 deletions saffier/core/connection/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
from typing import Any, Dict, Mapping, Type

import sqlalchemy
from sqlalchemy import Engine, create_engine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.orm import declarative_base as sa_declarative_base

from saffier.conf import settings
from saffier.core.connection.database import Database
from saffier.core.connection.schemas import Schema
from saffier.exceptions import ImproperlyConfigured


class Registry:
Expand Down Expand Up @@ -44,27 +41,6 @@ def metadata(self) -> Any:
def metadata(self, value: sqlalchemy.MetaData) -> None:
self._metadata = value

def _get_database_url(self) -> str:
url = self.database.url
if not url.driver:
if url.dialect in settings.postgres_dialects:
url = url.replace(driver="asyncpg")
elif url.dialect in settings.mysql_dialects:
url = url.replace(driver="aiomysql")
elif url.dialect in settings.sqlite_dialects:
url = url.replace(driver="aiosqlite")
elif url.dialect in settings.mssql_dialects:
raise ImproperlyConfigured("Saffier does not support MSSQL at the moment.")
elif url.driver in settings.mssql_drivers: # type: ignore
raise ImproperlyConfigured("Saffier does not support MSSQL at the moment.")
return str(url)

@cached_property
def _get_engine(self) -> AsyncEngine:
url = self._get_database_url()
engine = create_async_engine(url)
return engine

@cached_property
def declarative_base(self) -> Any:
if self.db_schema:
Expand All @@ -75,30 +51,23 @@ def declarative_base(self) -> Any:

@property
def engine(self) -> AsyncEngine:
return self._get_engine

@cached_property
def _get_sync_engine(self) -> Engine:
url = self._get_database_url()
engine = create_engine(url)
return engine
assert self.database.engine, "database not started, no engine found."
return self.database.engine

@property
def sync_engine(self) -> Engine:
return self._get_sync_engine
return self.engine.sync_engine

async def create_all(self) -> None:
if self.db_schema:
await self.schema.create_schema(self.db_schema, True)
async with self.database:
async with self.engine.begin() as connection:
await connection.run_sync(self.metadata.create_all)
await self.engine.dispose()
async with Database(self.database, force_rollback=False) as database:
async with database.transaction():
await database.create_all(self.metadata)

async def drop_all(self) -> None:
if self.db_schema:
await self.schema.drop_schema(self.db_schema, True, True)
async with self.database:
async with self.engine.begin() as conn:
await conn.run_sync(self.metadata.drop_all)
await self.engine.dispose()
async with Database(self.database, force_rollback=False) as database:
async with database.transaction():
await database.drop_all(self.metadata)
12 changes: 6 additions & 6 deletions saffier/core/connection/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def execute_create(connection: sqlalchemy.Connection) -> None:
except ProgrammingError as e:
raise SchemaError(detail=e.orig.args[0]) from e # type: ignore

async with self.registry.engine.begin() as connection:
await connection.run_sync(execute_create)
await self.registry.engine.dispose()
async with Database(self.registry.database, force_rollback=False) as database:
async with database.transaction():
await database.run_sync(execute_create)

async def drop_schema(
self, schema: str, cascade: bool = False, if_exists: bool = False
Expand All @@ -69,6 +69,6 @@ def execute_drop(connection: sqlalchemy.Connection) -> None:
except DBAPIError as e:
raise SchemaError(detail=e.orig.args[0]) from e # type: ignore

async with self.registry.engine.begin() as connection:
await connection.run_sync(execute_drop)
await self.registry.engine.dispose()
async with Database(self.registry.database, force_rollback=False) as database:
async with database.transaction():
await database.run_sync(execute_drop)
16 changes: 12 additions & 4 deletions saffier/core/db/models/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import typing
from typing import Any, Type, Union

from sqlalchemy.engine.result import Row

from saffier.core.db.models.base import SaffierBaseReflectModel
from saffier.core.db.models.mixins.generics import DeclarativeMixin
from saffier.core.db.models.row import ModelRow
Expand Down Expand Up @@ -113,10 +115,16 @@ async def _save(self, **kwargs: typing.Any) -> "Model":
Performs the save instruction.
"""
expression = self.table.insert().values(**kwargs)
awaitable = await self.database.execute(expression)
if not awaitable:
awaitable = kwargs.get(self.pkname)
saffier_setattr(self, self.pkname, awaitable)
autoincrement_value = await self.database.execute(expression)
# sqlalchemy supports only one autoincrement column
if autoincrement_value:
if isinstance(autoincrement_value, Row):
assert len(autoincrement_value) == 1
autoincrement_value = autoincrement_value[0]
column = self.table.autoincrement_column
# can be explicit set, which causes an invalid value returned
if column is not None and column.key not in kwargs:
saffier_setattr(self, column.key, autoincrement_value)
return self

async def save(
Expand Down
21 changes: 12 additions & 9 deletions saffier/core/db/models/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def from_query_result(
if column.name not in cls.fields.keys():
continue
elif related not in child_item:
if row[related] is not None:
child_item[column.name] = row[related]
if getattr(row, related) is not None:
child_item[column.name] = getattr(row, related)

# Make sure we generate a temporary reduced model
# For the related fields. We simply chnage the structure of the model
Expand All @@ -101,7 +101,7 @@ def from_query_result(
# Check for the only_fields
if is_only_fields or is_defer_fields:
mapping_fields = (
[str(field) for field in only_fields] if is_only_fields else list(row.keys()) # type: ignore
[str(field) for field in only_fields] if is_only_fields else list(row._mapping.keys()) # type: ignore
)

for column, value in row._mapping.items():
Expand All @@ -128,13 +128,16 @@ def from_query_result(
else:
# Pull out the regular column values.
for column in cls.table.columns:
if column.name in secret_fields:
if column.key in secret_fields:
continue
# Making sure when a table is reflected, maps the right fields of the ReflectModel
if column.name not in cls.fields.keys():
if column.key not in cls.fields:
continue
elif column.name not in item:
item[column.name] = row[column]
elif column.key not in item:
if column in row._mapping:
item[column.key] = row._mapping[column]
elif column.key in row._mapping:
item[column.key] = row._mapping[column.key]

model = (
cast("Type[Model]", cls(**item))
Expand Down Expand Up @@ -230,7 +233,7 @@ def __handle_prefetch_related(

# Check for individual not nested querysets
elif related.queryset is not None and not is_nested:
filter_by_pk = row[cls.pkname]
filter_by_pk = getattr(row, cls.pkname)
extra = {f"{related.related_name}__id": filter_by_pk}
related.queryset.extra = extra

Expand Down Expand Up @@ -281,7 +284,7 @@ def __process_nested_prefetch_related(
query = "__".join(query_split)

# Extact foreign key value
filter_by_pk = row[parent_cls.pkname]
filter_by_pk = getattr(row, parent_cls.pkname)

extra = {f"{query}__id": filter_by_pk}

Expand Down
13 changes: 9 additions & 4 deletions saffier/core/db/querysets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ async def bulk_update(self, objs: List[SaffierModel], fields: List[str]) -> None
new_obj = {}
for key, value in obj.__dict__.items():
if key in fields:
new_obj[key] = queryset._resolve_value(value)
new_obj[key] = value
new_objs.append(new_obj)

new_objs = [
Expand All @@ -970,19 +970,24 @@ async def bulk_update(self, objs: List[SaffierModel], fields: List[str]) -> None
]

pk = getattr(queryset.table.c, queryset.pkname)
expression = queryset.table.update().where(pk == sqlalchemy.bindparam(queryset.pkname))
expression = queryset.table.update().where(
pk == sqlalchemy.bindparam("__id" if queryset.pkname == "id" else queryset.pkname)
)
kwargs: Dict[str, Any] = {
field: sqlalchemy.bindparam(field) for obj in new_objs for field in obj.keys()
}
pks = [{queryset.pkname: getattr(obj, queryset.pkname)} for obj in objs]
pks = [
{"__id" if queryset.pkname == "id" else queryset.pkname: getattr(obj, queryset.pkname)}
for obj in objs
]

query_list = []
for pk, value in zip(pks, new_objs): # noqa
query_list.append({**pk, **value})

expression = expression.values(kwargs)
queryset._set_query_expression(expression)
await queryset.database.execute_many(str(expression), query_list)
await queryset.database.execute(expression, query_list)

async def delete(self) -> None:
queryset: "QuerySet" = self._clone()
Expand Down
2 changes: 2 additions & 0 deletions saffier/utils/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def inspect(self) -> None:
Starts the InspectDB and passes all the configurations.
"""
registry = Registry(database=self.database)
execsync(registry.database.connect)()

# Get the engine to connect
engine: AsyncEngine = registry.engine
Expand All @@ -91,6 +92,7 @@ def inspect(self) -> None:

for line in self.write_output(tables, self.database.url._url):
sys.stdout.writelines(line) # type: ignore
execsync(registry.database.disconnect)()

def generate_table_information(
self, metadata: sqlalchemy.MetaData
Expand Down
18 changes: 11 additions & 7 deletions tests/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import os
import shlex
import subprocess


def run_cmd(app, cmd, is_app=True):
env = dict(os.environ)
if is_app:
os.environ["SAFFIER_DEFAULT_APP"] = app
process = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
(stdout, stderr) = process.communicate()
env["SAFFIER_DEFAULT_APP"] = app
# CI uses something different as workdir and we aren't hatch test yet.
if "VIRTUAL_ENV" not in env:
basedir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if os.path.isdir(f"{basedir}/venv/bin/"):
cmd = f"{basedir}/venv/bin/{cmd}"
result = subprocess.run(cmd, capture_output=True, env=env, shell=True)
print("\n$ " + cmd)
print(stdout.decode("utf-8"))
print(stderr.decode("utf-8"))
return stdout, stderr, process.wait()
print(result.stdout.decode("utf-8"))
print(result.stderr.decode("utf-8"))
return result.stdout, result.stderr, result.returncode
Loading

0 comments on commit 6432879

Please sign in to comment.