From 1e40ad1e6ae718bd0d299d7c3303d625de6fc083 Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Wed, 21 Feb 2024 12:25:56 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=AA=9B=20Moving=20to=20SQLAlchemy=202.0?= =?UTF-8?q?=20(#540)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🪛 Added support for SQLAlchemy 2.0 * Added common and dialects packages to handle the new SQLAlchemy 2.0+ * 🪲 Fix specific asyncpg oriented test --------- Co-authored-by: ansipunk --- .github/workflows/test-suite.yml | 2 +- README.md | 4 +- databases/backends/aiopg.py | 69 +++++++----- databases/backends/asyncmy.py | 63 +++++++---- databases/backends/common/__init__.py | 0 databases/backends/common/records.py | 137 +++++++++++++++++++++++ databases/backends/compilers/__init__.py | 0 databases/backends/compilers/psycopg.py | 17 +++ databases/backends/dialects/__init__.py | 0 databases/backends/dialects/psycopg.py | 46 ++++++++ databases/backends/mysql.py | 60 ++++++---- databases/backends/postgres.py | 122 ++------------------ databases/backends/sqlite.py | 68 +++++------ databases/core.py | 2 +- docs/index.md | 4 +- scripts/clean | 6 + setup.cfg | 5 + setup.py | 3 +- tests/test_databases.py | 61 ++-------- 19 files changed, 394 insertions(+), 275 deletions(-) create mode 100644 databases/backends/common/__init__.py create mode 100644 databases/backends/common/records.py create mode 100644 databases/backends/compilers/__init__.py create mode 100644 databases/backends/compilers/psycopg.py create mode 100644 databases/backends/dialects/__init__.py create mode 100644 databases/backends/dialects/psycopg.py diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index bc271a65..3c01b801 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] services: mysql: diff --git a/README.md b/README.md index ba16a104..89107f2c 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ values = [ ] await database.execute_many(query=query, values=values) -# Run a database query. +# Run a database query. query = "SELECT * FROM HighScores" rows = await database.fetch_all(query=query) print('High Scores:', rows) @@ -115,4 +115,4 @@ for examples of how to start using databases together with SQLAlchemy core expre [quart]: https://gitlab.com/pgjones/quart [aiohttp]: https://github.com/aio-libs/aiohttp [tornado]: https://github.com/tornadoweb/tornado -[fastapi]: https://github.com/tiangolo/fastapi +[fastapi]: https://github.com/tiangolo/fastapi \ No newline at end of file diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 8668b2b9..0b4d95a3 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -5,19 +5,20 @@ import uuid import aiopg -from aiopg.sa.engine import APGCompiler_psycopg2 -from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement -from databases.core import DatabaseURL +from databases.backends.common.records import Record, Row, create_column_maps +from databases.backends.compilers.psycopg import PGCompiler_psycopg +from databases.backends.dialects.psycopg import PGDialect_psycopg +from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, DatabaseBackend, - Record, + Record as RecordInterface, TransactionBackend, ) @@ -34,10 +35,10 @@ def __init__( self._pool: typing.Union[aiopg.Pool, None] = None def _get_dialect(self) -> Dialect: - dialect = PGDialect_psycopg2( + dialect = PGDialect_psycopg( json_serializer=json.dumps, json_deserializer=lambda x: x ) - dialect.statement_compiler = APGCompiler_psycopg2 + dialect.statement_compiler = PGCompiler_psycopg dialect.implicit_returning = True dialect.supports_native_enum = True dialect.supports_smallserial = True # 9.2+ @@ -117,30 +118,35 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: + async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect + cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) rows = await cursor.fetchall() metadata = CursorResultMetaData(context, cursor.description) - return [ + rows = [ Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) for row in rows ] + return [Record(row, result_columns, dialect, column_maps) for row in rows] finally: cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) @@ -148,19 +154,19 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: if row is None: return None metadata = CursorResultMetaData(context, cursor.description) - return Row( + row = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + return Record(row, result_columns, dialect, column_maps) finally: cursor.close() async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, _, _ = self._compile(query) cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) @@ -173,7 +179,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: cursor = await self._connection.cursor() try: for single_query in queries: - single_query, args, context = self._compile(single_query) + single_query, args, _, _ = self._compile(single_query) await cursor.execute(single_query, args) finally: cursor.close() @@ -182,36 +188,37 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) metadata = CursorResultMetaData(context, cursor.description) async for row in cursor: - yield Row( + record = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + yield Record(record, result_columns, dialect, column_maps) finally: cursor.close() def transaction(self) -> TransactionBackend: return AiopgTransaction(self) - def _compile( - self, query: ClauseElement - ) -> typing.Tuple[str, dict, CompilationContext]: + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - execution_context = self._dialect.execution_ctx_cls() execution_context.dialect = self._dialect if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + args = compiled.construct_params() for key, val in args.items(): if key in compiled._bind_processors: @@ -224,11 +231,23 @@ def _compile( compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) + + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + result_map = compiled._result_columns + else: args = {} + result_map = None + compiled_query = compiled.string - logger.debug("Query: %s\nArgs: %s", compiled.string, args) - return compiled.string, args, CompilationContext(execution_context) + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") + logger.debug( + "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA + ) + return compiled.string, args, result_map, CompilationContext(execution_context) @property def raw_connection(self) -> aiopg.connection.Connection: diff --git a/databases/backends/asyncmy.py b/databases/backends/asyncmy.py index 0811ef21..040a4346 100644 --- a/databases/backends/asyncmy.py +++ b/databases/backends/asyncmy.py @@ -7,15 +7,15 @@ from sqlalchemy.dialects.mysql import pymysql from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext -from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement +from databases.backends.common.records import Record, Row, create_column_maps from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, DatabaseBackend, - Record, + Record as RecordInterface, TransactionBackend, ) @@ -108,30 +108,37 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: + async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect + async with self._connection.cursor() as cursor: try: await cursor.execute(query_str, args) rows = await cursor.fetchall() metadata = CursorResultMetaData(context, cursor.description) - return [ + rows = [ Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) for row in rows ] + return [ + Record(row, result_columns, dialect, column_maps) for row in rows + ] finally: await cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect async with self._connection.cursor() as cursor: try: await cursor.execute(query_str, args) @@ -139,19 +146,19 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: if row is None: return None metadata = CursorResultMetaData(context, cursor.description) - return Row( + row = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + return Record(row, result_columns, dialect, column_maps) finally: await cursor.close() async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, _, _ = self._compile(query) async with self._connection.cursor() as cursor: try: await cursor.execute(query_str, args) @@ -166,7 +173,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: async with self._connection.cursor() as cursor: try: for single_query in queries: - single_query, args, context = self._compile(single_query) + single_query, args, _, _ = self._compile(single_query) await cursor.execute(single_query, args) finally: await cursor.close() @@ -175,36 +182,37 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect async with self._connection.cursor() as cursor: try: await cursor.execute(query_str, args) metadata = CursorResultMetaData(context, cursor.description) async for row in cursor: - yield Row( + record = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + yield Record(record, result_columns, dialect, column_maps) finally: await cursor.close() def transaction(self) -> TransactionBackend: return AsyncMyTransaction(self) - def _compile( - self, query: ClauseElement - ) -> typing.Tuple[str, dict, CompilationContext]: + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - execution_context = self._dialect.execution_ctx_cls() execution_context.dialect = self._dialect if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + args = compiled.construct_params() for key, val in args.items(): if key in compiled._bind_processors: @@ -217,12 +225,23 @@ def _compile( compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) + + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + result_map = compiled._result_columns + else: args = {} + result_map = None + compiled_query = compiled.string - query_message = compiled.string.replace(" \n", " ").replace("\n", " ") - logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA) - return compiled.string, args, CompilationContext(execution_context) + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") + logger.debug( + "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA + ) + return compiled.string, args, result_map, CompilationContext(execution_context) @property def raw_connection(self) -> asyncmy.connection.Connection: diff --git a/databases/backends/common/__init__.py b/databases/backends/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/databases/backends/common/records.py b/databases/backends/common/records.py new file mode 100644 index 00000000..1d8a2fd4 --- /dev/null +++ b/databases/backends/common/records.py @@ -0,0 +1,137 @@ +import json +import typing +from datetime import date, datetime + +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.engine.row import Row as SQLRow +from sqlalchemy.sql.compiler import _CompileLabel +from sqlalchemy.sql.schema import Column +from sqlalchemy.types import TypeEngine + +from databases.interfaces import Record as RecordInterface + +DIALECT_EXCLUDE = {"postgresql"} + + +class Record(RecordInterface): + __slots__ = ( + "_row", + "_result_columns", + "_dialect", + "_column_map", + "_column_map_int", + "_column_map_full", + ) + + def __init__( + self, + row: typing.Any, + result_columns: tuple, + dialect: Dialect, + column_maps: typing.Tuple[ + typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], + typing.Mapping[int, typing.Tuple[int, TypeEngine]], + typing.Mapping[str, typing.Tuple[int, TypeEngine]], + ], + ) -> None: + self._row = row + self._result_columns = result_columns + self._dialect = dialect + self._column_map, self._column_map_int, self._column_map_full = column_maps + + @property + def _mapping(self) -> typing.Mapping: + return self._row + + def keys(self) -> typing.KeysView: + return self._mapping.keys() + + def values(self) -> typing.ValuesView: + return self._mapping.values() + + def __getitem__(self, key: typing.Any) -> typing.Any: + if len(self._column_map) == 0: + return self._row[key] + elif isinstance(key, Column): + idx, datatype = self._column_map_full[str(key)] + elif isinstance(key, int): + idx, datatype = self._column_map_int[key] + else: + idx, datatype = self._column_map[key] + + raw = self._row[idx] + processor = datatype._cached_result_processor(self._dialect, None) + + if self._dialect.name not in DIALECT_EXCLUDE: + if isinstance(raw, dict): + raw = json.dumps(raw) + + if processor is not None and (not isinstance(raw, (datetime, date))): + return processor(raw) + return raw + + def __iter__(self) -> typing.Iterator: + return iter(self._row.keys()) + + def __len__(self) -> int: + return len(self._row) + + def __getattr__(self, name: str) -> typing.Any: + try: + return self.__getitem__(name) + except KeyError as e: + raise AttributeError(e.args[0]) from e + + +class Row(SQLRow): + def __getitem__(self, key: typing.Any) -> typing.Any: + """ + An instance of a Row in SQLAlchemy allows the access + to the Row._fields as tuple and the Row._mapping for + the values. + """ + if isinstance(key, int): + return super().__getitem__(key) + + idx = self._key_to_index[key][0] + return super().__getitem__(idx) + + def keys(self): + return self._mapping.keys() + + def values(self): + return self._mapping.values() + + +def create_column_maps( + result_columns: typing.Any, +) -> typing.Tuple[ + typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], + typing.Mapping[int, typing.Tuple[int, TypeEngine]], + typing.Mapping[str, typing.Tuple[int, TypeEngine]], +]: + """ + Generate column -> datatype mappings from the column definitions. + + These mappings are used throughout PostgresConnection methods + to initialize Record-s. The underlying DB driver does not do type + conversion for us so we have wrap the returned asyncpg.Record-s. + + :return: Three mappings from different ways to address a column to \ + corresponding column indexes and datatypes: \ + 1. by column identifier; \ + 2. by column index; \ + 3. by column name in Column sqlalchemy objects. + """ + column_map, column_map_int, column_map_full = {}, {}, {} + for idx, (column_name, _, column, datatype) in enumerate(result_columns): + column_map[column_name] = (idx, datatype) + column_map_int[idx] = (idx, datatype) + + # Added in SQLA 2.0 and _CompileLabels do not have _annotations + # When this happens, the mapping is on the second position + if isinstance(column[0], _CompileLabel): + column_map_full[str(column[2])] = (idx, datatype) + else: + column_map_full[str(column[0])] = (idx, datatype) + return column_map, column_map_int, column_map_full diff --git a/databases/backends/compilers/__init__.py b/databases/backends/compilers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/databases/backends/compilers/psycopg.py b/databases/backends/compilers/psycopg.py new file mode 100644 index 00000000..654c22a1 --- /dev/null +++ b/databases/backends/compilers/psycopg.py @@ -0,0 +1,17 @@ +from sqlalchemy.dialects.postgresql.psycopg import PGCompiler_psycopg + + +class APGCompiler_psycopg2(PGCompiler_psycopg): + def construct_params(self, *args, **kwargs): + pd = super().construct_params(*args, **kwargs) + + for column in self.prefetch: + pd[column.key] = self._exec_default(column.default) + + return pd + + def _exec_default(self, default): + if default.is_callable: + return default.arg(self.dialect) + else: + return default.arg diff --git a/databases/backends/dialects/__init__.py b/databases/backends/dialects/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/databases/backends/dialects/psycopg.py b/databases/backends/dialects/psycopg.py new file mode 100644 index 00000000..07bd1880 --- /dev/null +++ b/databases/backends/dialects/psycopg.py @@ -0,0 +1,46 @@ +""" +All the unique changes for the databases package +with the custom Numeric as the deprecated pypostgresql +for backwards compatibility and to make sure the +package can go to SQLAlchemy 2.0+. +""" + +import typing + +from sqlalchemy import types, util +from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext +from sqlalchemy.engine import processors +from sqlalchemy.types import Float, Numeric + + +class PGExecutionContext_psycopg(PGExecutionContext): + ... + + +class PGNumeric(Numeric): + def bind_processor( + self, dialect: typing.Any + ) -> typing.Union[str, None]: # pragma: no cover + return processors.to_str + + def result_processor( + self, dialect: typing.Any, coltype: typing.Any + ) -> typing.Union[float, None]: # pragma: no cover + if self.asdecimal: + return None + else: + return processors.to_float + + +class PGDialect_psycopg(PGDialect): + colspecs = util.update_copy( + PGDialect.colspecs, + { + types.Numeric: PGNumeric, + types.Float: Float, + }, + ) + execution_ctx_cls = PGExecutionContext_psycopg + + +dialect = PGDialect_psycopg diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 630f7cd3..792f3685 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -7,15 +7,15 @@ from sqlalchemy.dialects.mysql import pymysql from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext -from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement +from databases.backends.common.records import Record, Row, create_column_maps from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, DatabaseBackend, - Record, + Record as RecordInterface, TransactionBackend, ) @@ -108,30 +108,34 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: + async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) rows = await cursor.fetchall() metadata = CursorResultMetaData(context, cursor.description) - return [ + rows = [ Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) for row in rows ] + return [Record(row, result_columns, dialect, column_maps) for row in rows] finally: await cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) @@ -139,19 +143,19 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: if row is None: return None metadata = CursorResultMetaData(context, cursor.description) - return Row( + row = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + return Record(row, result_columns, dialect, column_maps) finally: await cursor.close() async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, _, _ = self._compile(query) cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) @@ -166,7 +170,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: cursor = await self._connection.cursor() try: for single_query in queries: - single_query, args, context = self._compile(single_query) + single_query, args, _, _ = self._compile(single_query) await cursor.execute(single_query, args) finally: await cursor.close() @@ -175,36 +179,37 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect cursor = await self._connection.cursor() try: await cursor.execute(query_str, args) metadata = CursorResultMetaData(context, cursor.description) async for row in cursor: - yield Row( + record = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + yield Record(record, result_columns, dialect, column_maps) finally: await cursor.close() def transaction(self) -> TransactionBackend: return MySQLTransaction(self) - def _compile( - self, query: ClauseElement - ) -> typing.Tuple[str, dict, CompilationContext]: + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - execution_context = self._dialect.execution_ctx_cls() execution_context.dialect = self._dialect if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + args = compiled.construct_params() for key, val in args.items(): if key in compiled._bind_processors: @@ -217,12 +222,23 @@ def _compile( compiled._ad_hoc_textual, compiled._loose_column_name_matching, ) + + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + result_map = compiled._result_columns + else: args = {} + result_map = None + compiled_query = compiled.string - query_message = compiled.string.replace(" \n", " ").replace("\n", " ") - logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA) - return compiled.string, args, CompilationContext(execution_context) + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") + logger.debug( + "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA + ) + return compiled.string, args, result_map, CompilationContext(execution_context) @property def raw_connection(self) -> aiomysql.connection.Connection: diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 85972c3d..c42688e1 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -2,13 +2,12 @@ import typing import asyncpg -from sqlalchemy.dialects.postgresql import pypostgresql from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement -from sqlalchemy.sql.schema import Column -from sqlalchemy.types import TypeEngine +from databases.backends.common.records import Record, create_column_maps +from databases.backends.dialects.psycopg import dialect as psycopg_dialect from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, @@ -30,7 +29,7 @@ def __init__( self._pool = None def _get_dialect(self) -> Dialect: - dialect = pypostgresql.dialect(paramstyle="pyformat") + dialect = psycopg_dialect(paramstyle="pyformat") dialect.implicit_returning = True dialect.supports_native_enum = True @@ -83,82 +82,6 @@ def connection(self) -> "PostgresConnection": return PostgresConnection(self, self._dialect) -class Record(RecordInterface): - __slots__ = ( - "_row", - "_result_columns", - "_dialect", - "_column_map", - "_column_map_int", - "_column_map_full", - ) - - def __init__( - self, - row: asyncpg.Record, - result_columns: tuple, - dialect: Dialect, - column_maps: typing.Tuple[ - typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], - typing.Mapping[int, typing.Tuple[int, TypeEngine]], - typing.Mapping[str, typing.Tuple[int, TypeEngine]], - ], - ) -> None: - self._row = row - self._result_columns = result_columns - self._dialect = dialect - self._column_map, self._column_map_int, self._column_map_full = column_maps - - @property - def _mapping(self) -> typing.Mapping: - return self._row - - def keys(self) -> typing.KeysView: - import warnings - - warnings.warn( - "The `Row.keys()` method is deprecated to mimic SQLAlchemy behaviour, " - "use `Row._mapping.keys()` instead.", - DeprecationWarning, - ) - return self._mapping.keys() - - def values(self) -> typing.ValuesView: - import warnings - - warnings.warn( - "The `Row.values()` method is deprecated to mimic SQLAlchemy behaviour, " - "use `Row._mapping.values()` instead.", - DeprecationWarning, - ) - return self._mapping.values() - - def __getitem__(self, key: typing.Any) -> typing.Any: - if len(self._column_map) == 0: # raw query - return self._row[key] - elif isinstance(key, Column): - idx, datatype = self._column_map_full[str(key)] - elif isinstance(key, int): - idx, datatype = self._column_map_int[key] - else: - idx, datatype = self._column_map[key] - raw = self._row[idx] - processor = datatype._cached_result_processor(self._dialect, None) - - if processor is not None: - return processor(raw) - return raw - - def __iter__(self) -> typing.Iterator: - return iter(self._row.keys()) - - def __len__(self) -> int: - return len(self._row) - - def __getattr__(self, name: str) -> typing.Any: - return self._mapping.get(name) - - class PostgresConnection(ConnectionBackend): def __init__(self, database: PostgresBackend, dialect: Dialect): self._database = database @@ -181,7 +104,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: query_str, args, result_columns = self._compile(query) rows = await self._connection.fetch(query_str, *args) dialect = self._dialect - column_maps = self._create_column_maps(result_columns) + column_maps = create_column_maps(result_columns) return [Record(row, result_columns, dialect, column_maps) for row in rows] async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: @@ -194,7 +117,7 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterfa row, result_columns, self._dialect, - self._create_column_maps(result_columns), + create_column_maps(result_columns), ) async def fetch_val( @@ -214,7 +137,7 @@ async def fetch_val( async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, result_columns = self._compile(query) + query_str, args, _ = self._compile(query) return await self._connection.fetchval(query_str, *args) async def execute_many(self, queries: typing.List[ClauseElement]) -> None: @@ -223,7 +146,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: # loop through multiple executes here, which should all end up # using the same prepared statement. for single_query in queries: - single_query, args, result_columns = self._compile(single_query) + single_query, args, _ = self._compile(single_query) await self._connection.execute(single_query, *args) async def iterate( @@ -231,7 +154,7 @@ async def iterate( ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns = self._compile(query) - column_maps = self._create_column_maps(result_columns) + column_maps = create_column_maps(result_columns) async for row in self._connection.cursor(query_str, *args): yield Record(row, result_columns, self._dialect, column_maps) @@ -256,7 +179,6 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: processors[key](val) if key in processors else val for key, val in compiled_params ] - result_map = compiled._result_columns else: compiled_query = compiled.string @@ -269,34 +191,6 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: ) return compiled_query, args, result_map - @staticmethod - def _create_column_maps( - result_columns: tuple, - ) -> typing.Tuple[ - typing.Mapping[typing.Any, typing.Tuple[int, TypeEngine]], - typing.Mapping[int, typing.Tuple[int, TypeEngine]], - typing.Mapping[str, typing.Tuple[int, TypeEngine]], - ]: - """ - Generate column -> datatype mappings from the column definitions. - - These mappings are used throughout PostgresConnection methods - to initialize Record-s. The underlying DB driver does not do type - conversion for us so we have wrap the returned asyncpg.Record-s. - - :return: Three mappings from different ways to address a column to \ - corresponding column indexes and datatypes: \ - 1. by column identifier; \ - 2. by column index; \ - 3. by column name in Column sqlalchemy objects. - """ - column_map, column_map_int, column_map_full = {}, {}, {} - for idx, (column_name, _, column, datatype) in enumerate(result_columns): - column_map[column_name] = (idx, datatype) - column_map_int[idx] = (idx, datatype) - column_map_full[str(column[0])] = (idx, datatype) - return column_map, column_map_int, column_map_full - @property def raw_connection(self) -> asyncpg.connection.Connection: assert self._connection is not None, "Connection is not acquired" diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 1267fe20..16e17e9e 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -8,17 +8,12 @@ from sqlalchemy.dialects.sqlite import pysqlite from sqlalchemy.engine.cursor import CursorResultMetaData from sqlalchemy.engine.interfaces import Dialect, ExecutionContext -from sqlalchemy.engine.row import Row from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement +from databases.backends.common.records import Record, Row, create_column_maps from databases.core import LOG_EXTRA, DatabaseURL -from databases.interfaces import ( - ConnectionBackend, - DatabaseBackend, - Record, - TransactionBackend, -) +from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend logger = logging.getLogger("databases") @@ -35,16 +30,7 @@ def __init__( self._pool = SQLitePool(self._database_url, **self._options) async def connect(self) -> None: - pass - # assert self._pool is None, "DatabaseBackend is already running" - # self._pool = await aiomysql.create_pool( - # host=self._database_url.hostname, - # port=self._database_url.port or 3306, - # user=self._database_url.username or getpass.getuser(), - # password=self._database_url.password, - # db=self._database_url.database, - # autocommit=True, - # ) + ... async def disconnect(self) -> None: # if it extsis, remove reference to connection to cached in-memory database on disconnect @@ -105,42 +91,46 @@ async def release(self) -> None: async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect async with self._connection.execute(query_str, args) as cursor: rows = await cursor.fetchall() metadata = CursorResultMetaData(context, cursor.description) - return [ + rows = [ Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) for row in rows ] + return [Record(row, result_columns, dialect, column_maps) for row in rows] async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect async with self._connection.execute(query_str, args) as cursor: row = await cursor.fetchone() if row is None: return None metadata = CursorResultMetaData(context, cursor.description) - return Row( + row = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + return Record(row, result_columns, dialect, column_maps) async def execute(self, query: ClauseElement) -> typing.Any: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) async with self._connection.cursor() as cursor: await cursor.execute(query_str, args) if cursor.lastrowid == 0: @@ -156,34 +146,37 @@ async def iterate( self, query: ClauseElement ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" - query_str, args, context = self._compile(query) + query_str, args, result_columns, context = self._compile(query) + column_maps = create_column_maps(result_columns) + dialect = self._dialect + async with self._connection.execute(query_str, args) as cursor: metadata = CursorResultMetaData(context, cursor.description) async for row in cursor: - yield Row( + record = Row( metadata, metadata._processors, metadata._keymap, - Row._default_key_style, row, ) + yield Record(record, result_columns, dialect, column_maps) def transaction(self) -> TransactionBackend: return SQLiteTransaction(self) - def _compile( - self, query: ClauseElement - ) -> typing.Tuple[str, list, CompilationContext]: + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( dialect=self._dialect, compile_kwargs={"render_postcompile": True} ) - execution_context = self._dialect.execution_ctx_cls() execution_context.dialect = self._dialect args = [] + result_map = None if not isinstance(query, DDLElement): + compiled_params = sorted(compiled.params.items()) + params = compiled.construct_params() for key in compiled.positiontup: raw_val = params[key] @@ -201,11 +194,20 @@ def _compile( compiled._loose_column_name_matching, ) - query_message = compiled.string.replace(" \n", " ").replace("\n", " ") + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + result_map = compiled._result_columns + + else: + compiled_query = compiled.string + + query_message = compiled_query.replace(" \n", " ").replace("\n", " ") logger.debug( "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA ) - return compiled.string, args, CompilationContext(execution_context) + return compiled.string, args, result_map, CompilationContext(execution_context) @property def raw_connection(self) -> aiosqlite.core.Connection: diff --git a/databases/core.py b/databases/core.py index 795609ea..d55dd3c8 100644 --- a/databases/core.py +++ b/databases/core.py @@ -356,7 +356,7 @@ def _build_query( return query.bindparams(**values) if values is not None else query elif values: - return query.values(**values) + return query.values(**values) # type: ignore return query diff --git a/docs/index.md b/docs/index.md index b18de817..fba3f147 100644 --- a/docs/index.md +++ b/docs/index.md @@ -83,7 +83,7 @@ values = [ ] await database.execute_many(query=query, values=values) -# Run a database query. +# Run a database query. query = "SELECT * FROM HighScores" rows = await database.fetch_all(query=query) print('High Scores:', rows) @@ -113,4 +113,4 @@ for examples of how to start using databases together with SQLAlchemy core expre [quart]: https://gitlab.com/pgjones/quart [aiohttp]: https://github.com/aio-libs/aiohttp [tornado]: https://github.com/tornadoweb/tornado -[fastapi]: https://github.com/tiangolo/fastapi +[fastapi]: https://github.com/tiangolo/fastapi \ No newline at end of file diff --git a/scripts/clean b/scripts/clean index f01cc831..d7388629 100755 --- a/scripts/clean +++ b/scripts/clean @@ -9,6 +9,12 @@ fi if [ -d 'databases.egg-info' ] ; then rm -r databases.egg-info fi +if [ -d '.mypy_cache' ] ; then + rm -r .mypy_cache +fi +if [ -d '.pytest_cache' ] ; then + rm -r .pytest_cache +fi find databases -type f -name "*.py[co]" -delete find databases -type d -name __pycache__ -delete diff --git a/setup.cfg b/setup.cfg index da1831fd..b4182c83 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,6 +2,11 @@ disallow_untyped_defs = True ignore_missing_imports = True no_implicit_optional = True +disallow_any_generics = false +disallow_untyped_decorators = true +implicit_reexport = true +disallow_incomplete_defs = true +exclude = databases/backends [tool:isort] profile = black diff --git a/setup.py b/setup.py index 3725cab9..a6bb8965 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def get_packages(package): author_email="tom@tomchristie.com", packages=get_packages("databases"), package_data={"databases": ["py.typed"]}, - install_requires=["sqlalchemy>=1.4.42,<1.5"], + install_requires=["sqlalchemy>=2.0.7"], extras_require={ "postgresql": ["asyncpg"], "asyncpg": ["asyncpg"], @@ -70,6 +70,7 @@ def get_packages(package): "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3 :: Only", ], zip_safe=False, diff --git a/tests/test_databases.py b/tests/test_databases.py index 144691b6..cd907fd1 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -5,7 +5,6 @@ import gc import itertools import os -import re import sqlite3 from typing import MutableMapping from unittest.mock import MagicMock, patch @@ -174,24 +173,24 @@ async def test_queries(database_url): assert result["completed"] == True # fetch_val() - query = sqlalchemy.sql.select([notes.c.text]) + query = sqlalchemy.sql.select(*[notes.c.text]) result = await database.fetch_val(query=query) assert result == "example1" # fetch_val() with no rows - query = sqlalchemy.sql.select([notes.c.text]).where( + query = sqlalchemy.sql.select(*[notes.c.text]).where( notes.c.text == "impossible" ) result = await database.fetch_val(query=query) assert result is None # fetch_val() with a different column - query = sqlalchemy.sql.select([notes.c.id, notes.c.text]) + query = sqlalchemy.sql.select(*[notes.c.id, notes.c.text]) result = await database.fetch_val(query=query, column=1) assert result == "example1" # row access (needed to maintain test coverage for Record.__getitem__ in postgres backend) - query = sqlalchemy.sql.select([notes.c.text]) + query = sqlalchemy.sql.select(*[notes.c.text]) result = await database.fetch_one(query=query) assert result["text"] == "example1" assert result[0] == "example1" @@ -251,6 +250,7 @@ async def test_queries_raw(database_url): query = "SELECT completed FROM notes WHERE text = :text" result = await database.fetch_val(query=query, values={"text": "example1"}) assert result == True + query = "SELECT * FROM notes WHERE text = :text" result = await database.fetch_val( query=query, values={"text": "example1"}, column="completed" @@ -361,7 +361,7 @@ async def test_results_support_column_reference(database_url): await database.execute(query, values) # fetch_all() - query = sqlalchemy.select([articles, custom_date]) + query = sqlalchemy.select(*[articles, custom_date]) results = await database.fetch_all(query=query) assert len(results) == 1 assert results[0][articles.c.title] == "Hello, world Article" @@ -753,6 +753,7 @@ def insert_independently(): query = notes.insert().values(text="example1", completed=True) conn.execute(query) + conn.close() def delete_independently(): engine = sqlalchemy.create_engine(str(database_url)) @@ -760,6 +761,7 @@ def delete_independently(): query = notes.delete() conn.execute(query) + conn.close() async with Database(database_url) as database: async with database.transaction(force_rollback=True, isolation="serializable"): @@ -971,6 +973,7 @@ async def test_json_field(database_url): # fetch_all() query = session.select() results = await database.fetch_all(query=query) + assert len(results) == 1 assert results[0]["data"] == {"text": "hello", "boolean": True, "int": 1} @@ -1455,52 +1458,6 @@ async def test_column_names(database_url, select_query): assert results[0]["completed"] == True -@pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter -async def test_posgres_interface(database_url): - """ - Since SQLAlchemy 1.4, `Row.values()` is removed and `Row.keys()` is deprecated. - Custom postgres interface mimics more or less this behaviour by deprecating those - two methods - """ - database_url = DatabaseURL(database_url) - - if database_url.scheme not in ["postgresql", "postgresql+asyncpg"]: - pytest.skip("Test is only for asyncpg") - - async with Database(database_url) as database: - async with database.transaction(force_rollback=True): - query = notes.insert() - values = {"text": "example1", "completed": True} - await database.execute(query, values) - - query = notes.select() - result = await database.fetch_one(query=query) - - with pytest.warns( - DeprecationWarning, - match=re.escape( - "The `Row.keys()` method is deprecated to mimic SQLAlchemy behaviour, " - "use `Row._mapping.keys()` instead." - ), - ): - assert ( - list(result.keys()) - == [k for k in result] - == ["id", "text", "completed"] - ) - - with pytest.warns( - DeprecationWarning, - match=re.escape( - "The `Row.values()` method is deprecated to mimic SQLAlchemy behaviour, " - "use `Row._mapping.values()` instead." - ), - ): - # avoid checking `id` at index 0 since it may change depending on the launched tests - assert list(result.values())[1:] == ["example1", True] - - @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_postcompile_queries(database_url):