diff --git a/.gitignore b/.gitignore index 1ca72ce..c72513c 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ node_modules/ results/ site/ target/ +venv # files **/*.so @@ -25,4 +26,4 @@ target/ .coverage .python-version coverage.* -example.sqlite \ No newline at end of file +example.sqlite diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 024e1e2..0a54cd9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,5 @@ # See https://pre-commit.com for more information. # See https://pre-commit.com/hooks.html for more hooks. -default_language_version: - python: python3.10 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 diff --git a/pyproject.toml b/pyproject.toml index 089a3d6..8a289ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,10 +44,11 @@ dependencies = [ "click>=8.1.3,<9.0.0", "dymmond-settings>=1.0.4", "loguru>=0.6.0,<0.10.0", - "databases>=0.7.2", + "databasez>=0.8.5", "orjson >=3.8.5,<4.0.0", "pydantic>=2.5.3,<3.0.0", "rich>=13.3.1,<14.0.0", + "nest_asyncio", ] keywords = [ "api", @@ -115,15 +116,15 @@ doc = [ ] testing = ["sqlalchemy_utils>=0.40.0"] -postgres = ["databasez[postgresql]==0.7.2"] -mysql = ["databasez[mysql]==0.7.2"] -sqlite = ["databasez[sqlite]==0.7.2"] +postgres = ["databasez[postgresql]>=0.8.2"] +mysql = ["databasez[mysql]>=0.8.2"] +sqlite = ["databasez[sqlite]>=0.8.2"] ptpython = ["ptpython>=3.0.23,<4.0.0"] ipython = ["ipython>=8.10.0,<9.0.0"] all = [ - "databasez[postgresql,mysql,sqlite]==0.7.2", + "databasez[postgresql,mysql,sqlite]>=0.8.2", "orjson>=3.8.5,<4.0.0", "ptpython>=3.0.23,<4.0.0", "ipython>=8.10.0,<9.0.0", diff --git a/saffier/core/connection/database.py b/saffier/core/connection/database.py index 82d6f6b..b6167b6 100644 --- a/saffier/core/connection/database.py +++ b/saffier/core/connection/database.py @@ -1,31 +1,3 @@ -from databasez import Database as Databasez # noqa -from databasez import DatabaseURL as DatabaseURL # noqa +from databasez import Database, DatabaseURL - -class Database(Databasez): - """ - An abstraction on the top of the EncodeORM databases.Database object. - - This object allows to pass also a configuration dictionary in the format of - - ```python - DATABASEZ_CONFIG = { - "connection": { - "credentials": { - "scheme": 'sqlite', "postgres"... - "host": ..., - "port": ..., - "user": ..., - "password": ..., - "database": ..., - "options": { - "driver": ... - "ssl": ... - } - } - } - } - ``` - """ - - ... +__all__ = ["Database", "DatabaseURL"] diff --git a/saffier/core/connection/registry.py b/saffier/core/connection/registry.py index ed9b8c3..b9b402c 100644 --- a/saffier/core/connection/registry.py +++ b/saffier/core/connection/registry.py @@ -2,15 +2,12 @@ from typing import Any, Dict, Mapping, Type import sqlalchemy -from sqlalchemy import Engine, create_engine -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy import Engine from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.orm import declarative_base as sa_declarative_base -from saffier.conf import settings from saffier.core.connection.database import Database from saffier.core.connection.schemas import Schema -from saffier.exceptions import ImproperlyConfigured class Registry: @@ -44,27 +41,6 @@ def metadata(self) -> Any: def metadata(self, value: sqlalchemy.MetaData) -> None: self._metadata = value - def _get_database_url(self) -> str: - url = self.database.url - if not url.driver: - if url.dialect in settings.postgres_dialects: - url = url.replace(driver="asyncpg") - elif url.dialect in settings.mysql_dialects: - url = url.replace(driver="aiomysql") - elif url.dialect in settings.sqlite_dialects: - url = url.replace(driver="aiosqlite") - elif url.dialect in settings.mssql_dialects: - raise ImproperlyConfigured("Saffier does not support MSSQL at the moment.") - elif url.driver in settings.mssql_drivers: # type: ignore - raise ImproperlyConfigured("Saffier does not support MSSQL at the moment.") - return str(url) - - @cached_property - def _get_engine(self) -> AsyncEngine: - url = self._get_database_url() - engine = create_async_engine(url) - return engine - @cached_property def declarative_base(self) -> Any: if self.db_schema: @@ -75,30 +51,23 @@ def declarative_base(self) -> Any: @property def engine(self) -> AsyncEngine: - return self._get_engine - - @cached_property - def _get_sync_engine(self) -> Engine: - url = self._get_database_url() - engine = create_engine(url) - return engine + assert self.database.engine, "database not started, no engine found." + return self.database.engine @property def sync_engine(self) -> Engine: - return self._get_sync_engine + return self.engine.sync_engine async def create_all(self) -> None: if self.db_schema: await self.schema.create_schema(self.db_schema, True) - async with self.database: - async with self.engine.begin() as connection: - await connection.run_sync(self.metadata.create_all) - await self.engine.dispose() + async with Database(self.database, force_rollback=False) as database: + async with database.transaction(): + await database.create_all(self.metadata) async def drop_all(self) -> None: if self.db_schema: await self.schema.drop_schema(self.db_schema, True, True) - async with self.database: - async with self.engine.begin() as conn: - await conn.run_sync(self.metadata.drop_all) - await self.engine.dispose() + async with Database(self.database, force_rollback=False) as database: + async with database.transaction(): + await database.drop_all(self.metadata) diff --git a/saffier/core/connection/schemas.py b/saffier/core/connection/schemas.py index 398f4cb..0c8c6a6 100644 --- a/saffier/core/connection/schemas.py +++ b/saffier/core/connection/schemas.py @@ -50,9 +50,9 @@ def execute_create(connection: sqlalchemy.Connection) -> None: except ProgrammingError as e: raise SchemaError(detail=e.orig.args[0]) from e # type: ignore - async with self.registry.engine.begin() as connection: - await connection.run_sync(execute_create) - await self.registry.engine.dispose() + async with Database(self.registry.database, force_rollback=False) as database: + async with database.transaction(): + await database.run_sync(execute_create) async def drop_schema( self, schema: str, cascade: bool = False, if_exists: bool = False @@ -69,6 +69,6 @@ def execute_drop(connection: sqlalchemy.Connection) -> None: except DBAPIError as e: raise SchemaError(detail=e.orig.args[0]) from e # type: ignore - async with self.registry.engine.begin() as connection: - await connection.run_sync(execute_drop) - await self.registry.engine.dispose() + async with Database(self.registry.database, force_rollback=False) as database: + async with database.transaction(): + await database.run_sync(execute_drop) diff --git a/saffier/core/db/models/model.py b/saffier/core/db/models/model.py index db40fa3..65c18ab 100644 --- a/saffier/core/db/models/model.py +++ b/saffier/core/db/models/model.py @@ -1,6 +1,8 @@ import typing from typing import Any, Type, Union +from sqlalchemy.engine.result import Row + from saffier.core.db.models.base import SaffierBaseReflectModel from saffier.core.db.models.mixins.generics import DeclarativeMixin from saffier.core.db.models.row import ModelRow @@ -113,10 +115,16 @@ async def _save(self, **kwargs: typing.Any) -> "Model": Performs the save instruction. """ expression = self.table.insert().values(**kwargs) - awaitable = await self.database.execute(expression) - if not awaitable: - awaitable = kwargs.get(self.pkname) - saffier_setattr(self, self.pkname, awaitable) + autoincrement_value = await self.database.execute(expression) + # sqlalchemy supports only one autoincrement column + if autoincrement_value: + if isinstance(autoincrement_value, Row): + assert len(autoincrement_value) == 1 + autoincrement_value = autoincrement_value[0] + column = self.table.autoincrement_column + # can be explicit set, which causes an invalid value returned + if column is not None and column.key not in kwargs: + saffier_setattr(self, column.key, autoincrement_value) return self async def save( diff --git a/saffier/core/db/models/row.py b/saffier/core/db/models/row.py index 1faf497..ba66cab 100644 --- a/saffier/core/db/models/row.py +++ b/saffier/core/db/models/row.py @@ -89,8 +89,8 @@ def from_query_result( if column.name not in cls.fields.keys(): continue elif related not in child_item: - if row[related] is not None: - child_item[column.name] = row[related] + if getattr(row, related) is not None: + child_item[column.name] = getattr(row, related) # Make sure we generate a temporary reduced model # For the related fields. We simply chnage the structure of the model @@ -101,7 +101,7 @@ def from_query_result( # Check for the only_fields if is_only_fields or is_defer_fields: mapping_fields = ( - [str(field) for field in only_fields] if is_only_fields else list(row.keys()) # type: ignore + [str(field) for field in only_fields] if is_only_fields else list(row._mapping.keys()) # type: ignore ) for column, value in row._mapping.items(): @@ -128,13 +128,16 @@ def from_query_result( else: # Pull out the regular column values. for column in cls.table.columns: - if column.name in secret_fields: + if column.key in secret_fields: continue # Making sure when a table is reflected, maps the right fields of the ReflectModel - if column.name not in cls.fields.keys(): + if column.key not in cls.fields: continue - elif column.name not in item: - item[column.name] = row[column] + elif column.key not in item: + if column in row._mapping: + item[column.key] = row._mapping[column] + elif column.key in row._mapping: + item[column.key] = row._mapping[column.key] model = ( cast("Type[Model]", cls(**item)) @@ -230,7 +233,7 @@ def __handle_prefetch_related( # Check for individual not nested querysets elif related.queryset is not None and not is_nested: - filter_by_pk = row[cls.pkname] + filter_by_pk = getattr(row, cls.pkname) extra = {f"{related.related_name}__id": filter_by_pk} related.queryset.extra = extra @@ -281,7 +284,7 @@ def __process_nested_prefetch_related( query = "__".join(query_split) # Extact foreign key value - filter_by_pk = row[parent_cls.pkname] + filter_by_pk = getattr(row, parent_cls.pkname) extra = {f"{query}__id": filter_by_pk} diff --git a/saffier/core/db/querysets/base.py b/saffier/core/db/querysets/base.py index f3bc38e..4827baa 100644 --- a/saffier/core/db/querysets/base.py +++ b/saffier/core/db/querysets/base.py @@ -961,7 +961,7 @@ async def bulk_update(self, objs: List[SaffierModel], fields: List[str]) -> None new_obj = {} for key, value in obj.__dict__.items(): if key in fields: - new_obj[key] = queryset._resolve_value(value) + new_obj[key] = value new_objs.append(new_obj) new_objs = [ @@ -970,11 +970,16 @@ async def bulk_update(self, objs: List[SaffierModel], fields: List[str]) -> None ] pk = getattr(queryset.table.c, queryset.pkname) - expression = queryset.table.update().where(pk == sqlalchemy.bindparam(queryset.pkname)) + expression = queryset.table.update().where( + pk == sqlalchemy.bindparam("__id" if queryset.pkname == "id" else queryset.pkname) + ) kwargs: Dict[str, Any] = { field: sqlalchemy.bindparam(field) for obj in new_objs for field in obj.keys() } - pks = [{queryset.pkname: getattr(obj, queryset.pkname)} for obj in objs] + pks = [ + {"__id" if queryset.pkname == "id" else queryset.pkname: getattr(obj, queryset.pkname)} + for obj in objs + ] query_list = [] for pk, value in zip(pks, new_objs): # noqa @@ -982,7 +987,7 @@ async def bulk_update(self, objs: List[SaffierModel], fields: List[str]) -> None expression = expression.values(kwargs) queryset._set_query_expression(expression) - await queryset.database.execute_many(str(expression), query_list) + await queryset.database.execute(expression, query_list) async def delete(self) -> None: queryset: "QuerySet" = self._clone() diff --git a/saffier/utils/inspect.py b/saffier/utils/inspect.py index a129b8a..01d33ba 100644 --- a/saffier/utils/inspect.py +++ b/saffier/utils/inspect.py @@ -74,6 +74,7 @@ def inspect(self) -> None: Starts the InspectDB and passes all the configurations. """ registry = Registry(database=self.database) + execsync(registry.database.connect)() # Get the engine to connect engine: AsyncEngine = registry.engine @@ -91,6 +92,7 @@ def inspect(self) -> None: for line in self.write_output(tables, self.database.url._url): sys.stdout.writelines(line) # type: ignore + execsync(registry.database.disconnect)() def generate_table_information( self, metadata: sqlalchemy.MetaData diff --git a/tests/cli/utils.py b/tests/cli/utils.py index 7bd78b1..2dc3a0f 100644 --- a/tests/cli/utils.py +++ b/tests/cli/utils.py @@ -1,14 +1,18 @@ import os -import shlex import subprocess def run_cmd(app, cmd, is_app=True): + env = dict(os.environ) if is_app: - os.environ["SAFFIER_DEFAULT_APP"] = app - process = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE) - (stdout, stderr) = process.communicate() + env["SAFFIER_DEFAULT_APP"] = app + # CI uses something different as workdir and we aren't hatch test yet. + if "VIRTUAL_ENV" not in env: + basedir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + if os.path.isdir(f"{basedir}/venv/bin/"): + cmd = f"{basedir}/venv/bin/{cmd}" + result = subprocess.run(cmd, capture_output=True, env=env, shell=True) print("\n$ " + cmd) - print(stdout.decode("utf-8")) - print(stderr.decode("utf-8")) - return stdout, stderr, process.wait() + print(result.stdout.decode("utf-8")) + print(result.stderr.decode("utf-8")) + return result.stdout, result.stderr, result.returncode diff --git a/tests/foreign_keys/test_foreignkey.py b/tests/foreign_keys/test_foreignkey.py index 99b4293..71153f1 100644 --- a/tests/foreign_keys/test_foreignkey.py +++ b/tests/foreign_keys/test_foreignkey.py @@ -1,8 +1,5 @@ -import sqlite3 - -import asyncpg -import pymysql import pytest +from sqlalchemy.exc import IntegrityError import saffier from saffier.testclient import DatabaseTestClient as Database @@ -227,12 +224,7 @@ async def test_on_delete_retstrict(): organisation = await Organisation.query.create(ident="Encode") await Team.query.create(org=organisation, name="Maintainers") - exceptions = ( - asyncpg.exceptions.ForeignKeyViolationError, - pymysql.err.IntegrityError, - ) - - with pytest.raises(exceptions): + with pytest.raises(IntegrityError): await organisation.delete() @@ -258,13 +250,7 @@ async def test_one_to_one_field_crud(): await person.profile.load() assert person.profile.website == "https://saffier.com" - exceptions = ( - asyncpg.exceptions.UniqueViolationError, - pymysql.err.IntegrityError, - sqlite3.IntegrityError, - ) - - with pytest.raises(exceptions): + with pytest.raises(IntegrityError): await Person.query.create(email="contact@saffier.com", profile=profile) @@ -278,13 +264,7 @@ async def test_one_to_one_crud(): await person.profile.load() assert person.profile.website == "https://saffier.com" - exceptions = ( - asyncpg.exceptions.UniqueViolationError, - pymysql.err.IntegrityError, - sqlite3.IntegrityError, - ) - - with pytest.raises(exceptions): + with pytest.raises(IntegrityError): await AnotherPerson.query.create(email="contact@saffier.com", profile=profile) diff --git a/tests/tenancy/test_mt_bulk_create.py b/tests/tenancy/test_mt_bulk_create.py index c883a17..2b4f6d3 100644 --- a/tests/tenancy/test_mt_bulk_create.py +++ b/tests/tenancy/test_mt_bulk_create.py @@ -2,7 +2,7 @@ from enum import Enum import pytest -from asyncpg.exceptions import UndefinedTableError +from sqlalchemy.exc import ProgrammingError import saffier from saffier.core.db import fields @@ -49,7 +49,7 @@ async def rollback_transactions(): async def test_bulk_create_another_tenant(): - with pytest.raises(UndefinedTableError): + with pytest.raises(ProgrammingError): await Product.query.using("another").bulk_create( [ {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, diff --git a/tests/uniques/test_unique.py b/tests/uniques/test_unique.py index 0005ad5..6fb65fd 100644 --- a/tests/uniques/test_unique.py +++ b/tests/uniques/test_unique.py @@ -2,7 +2,7 @@ from enum import Enum import pytest -from asyncpg.exceptions import UniqueViolationError +from sqlalchemy.exc import IntegrityError import saffier from saffier.testclient import DatabaseTestClient as Database @@ -51,5 +51,5 @@ async def rollback_transactions(): async def test_unique(): await User.query.create(name="Tiago", email="test@example.com") - with pytest.raises(UniqueViolationError): + with pytest.raises(IntegrityError): await User.query.create(name="Tiago", email="test2@example.come") diff --git a/tests/uniques/test_unique_constraint.py b/tests/uniques/test_unique_constraint.py index 0849675..56630a0 100644 --- a/tests/uniques/test_unique_constraint.py +++ b/tests/uniques/test_unique_constraint.py @@ -2,7 +2,7 @@ from enum import Enum import pytest -from asyncpg.exceptions import UniqueViolationError +from sqlalchemy.exc import IntegrityError import saffier from saffier.core.db.datastructures import UniqueConstraint @@ -84,7 +84,7 @@ async def test_unique_together(): await User.query.create(name="Test", email="test@example.com") await User.query.create(name="Test", email="test2@example.come") - with pytest.raises(UniqueViolationError): + with pytest.raises(IntegrityError): await User.query.create(name="Test", email="test@example.com") @@ -93,7 +93,7 @@ async def test_unique_together_multiple(): await HubUser.query.create(name="Test", email="test@example.com") await HubUser.query.create(name="Test", email="test2@example.come") - with pytest.raises(UniqueViolationError): + with pytest.raises(IntegrityError): await HubUser.query.create(name="Test", email="test@example.com") @@ -101,7 +101,7 @@ async def test_unique_together_multiple(): async def test_unique_together_multiple_name_age(): await HubUser.query.create(name="NewTest", email="test@example.com", age=18) - with pytest.raises(UniqueViolationError): + with pytest.raises(IntegrityError): await HubUser.query.create(name="Test", email="test@example.com", age=18) @@ -109,7 +109,7 @@ async def test_unique_together_multiple_name_age(): async def test_unique_together_multiple_single_string(): await Product.query.create(name="android", sku="12345") - with pytest.raises(UniqueViolationError): + with pytest.raises(IntegrityError): await Product.query.create(name="android", sku="12345") @@ -117,5 +117,5 @@ async def test_unique_together_multiple_single_string(): async def test_unique_together_multiple_single_string_two(): await Product.query.create(name="android", sku="12345") - with pytest.raises(UniqueViolationError): + with pytest.raises(IntegrityError): await Product.query.create(name="iphone", sku="12345") diff --git a/tests/uniques/test_unique_together.py b/tests/uniques/test_unique_together.py index b1a6944..323786b 100644 --- a/tests/uniques/test_unique_together.py +++ b/tests/uniques/test_unique_together.py @@ -2,7 +2,7 @@ from enum import Enum import pytest -from asyncpg.exceptions import UniqueViolationError +from sqlalchemy.exc import IntegrityError import saffier from saffier.testclient import DatabaseTestClient as Database @@ -80,7 +80,7 @@ async def test_unique_together(): await User.query.create(name="Test", email="test@example.com") await User.query.create(name="Test", email="test2@example.come") - with pytest.raises(UniqueViolationError): + with pytest.raises(IntegrityError): await User.query.create(name="Test", email="test@example.com") @@ -89,7 +89,7 @@ async def test_unique_together_multiple(): await HubUser.query.create(name="Test", email="test@example.com") await HubUser.query.create(name="Test", email="test2@example.come") - with pytest.raises(UniqueViolationError): + with pytest.raises(IntegrityError): await HubUser.query.create(name="Test", email="test@example.com") @@ -97,7 +97,7 @@ async def test_unique_together_multiple(): async def test_unique_together_multiple_name_age(): await HubUser.query.create(name="NewTest", email="test@example.com", age=18) - with pytest.raises(UniqueViolationError): + with pytest.raises(IntegrityError): await HubUser.query.create(name="Test", email="test@example.com", age=18) @@ -105,7 +105,7 @@ async def test_unique_together_multiple_name_age(): async def test_unique_together_multiple_single_string(): await Product.query.create(name="android", sku="12345") - with pytest.raises(UniqueViolationError): + with pytest.raises(IntegrityError): await Product.query.create(name="android", sku="12345") @@ -113,5 +113,5 @@ async def test_unique_together_multiple_single_string(): async def test_unique_together_multiple_single_string_two(): await Product.query.create(name="android", sku="12345") - with pytest.raises(UniqueViolationError): + with pytest.raises(IntegrityError): await Product.query.create(name="iphone", sku="12345")