From c9f8b0f7087f56c1656a5f883e511d9cc2952979 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Sat, 18 Jul 2020 14:07:38 -0700 Subject: [PATCH] Allow using custom Record class Add the new `record_class` parameter to the `create_pool()` and `connect()` functions, as well as to the `cursor()`, `prepare()`, `fetch()` and `fetchrow()` connection methods. This not only allows adding custom functionality to the returned objects, but also assists with typing (see #577 for discussion). Fixes: #40. --- .flake8 | 2 +- asyncpg/_testbase/__init__.py | 3 + asyncpg/connect_utils.py | 27 ++- asyncpg/connection.py | 303 ++++++++++++++++++++++----- asyncpg/cursor.py | 82 ++++++-- asyncpg/pool.py | 25 ++- asyncpg/prepared_stmt.py | 12 +- asyncpg/protocol/codecs/base.pyx | 3 +- asyncpg/protocol/prepared_stmt.pxd | 2 + asyncpg/protocol/prepared_stmt.pyx | 11 +- asyncpg/protocol/protocol.pxd | 1 + asyncpg/protocol/protocol.pyx | 15 +- asyncpg/protocol/record/__init__.pxd | 2 +- asyncpg/protocol/record/recordobj.c | 47 +++-- asyncpg/protocol/record/recordobj.h | 2 +- tests/test_record.py | 174 +++++++++++++++ tests/test_timeout.py | 4 +- 17 files changed, 610 insertions(+), 105 deletions(-) diff --git a/.flake8 b/.flake8 index 7cf64d1f..9697fc96 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,3 @@ [flake8] -ignore = E402,E731,W504,E252 +ignore = E402,E731,W503,W504,E252 exclude = .git,__pycache__,build,dist,.eggs,.github,.local diff --git a/asyncpg/_testbase/__init__.py b/asyncpg/_testbase/__init__.py index baf55c1b..ce7f827f 100644 --- a/asyncpg/_testbase/__init__.py +++ b/asyncpg/_testbase/__init__.py @@ -19,6 +19,7 @@ import unittest +import asyncpg from asyncpg import cluster as pg_cluster from asyncpg import connection as pg_connection from asyncpg import pool as pg_pool @@ -266,6 +267,7 @@ def create_pool(dsn=None, *, loop=None, pool_class=pg_pool.Pool, connection_class=pg_connection.Connection, + record_class=asyncpg.Record, **connect_kwargs): return pool_class( dsn, @@ -273,6 +275,7 @@ def create_pool(dsn=None, *, max_queries=max_queries, loop=loop, setup=setup, init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, connection_class=connection_class, + record_class=record_class, **connect_kwargs) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 2678b358..e5feebc2 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -594,8 +594,16 @@ async def _create_ssl_connection(protocol_factory, host, port, *, raise -async def _connect_addr(*, addr, loop, timeout, params, config, - connection_class): +async def _connect_addr( + *, + addr, + loop, + timeout, + params, + config, + connection_class, + record_class +): assert loop is not None if timeout <= 0: @@ -613,7 +621,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config, params = params._replace(password=password) proto_factory = lambda: protocol.Protocol( - addr, connected, params, loop) + addr, connected, params, record_class, loop) if isinstance(addr, str): # UNIX socket @@ -649,7 +657,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config, return con -async def _connect(*, loop, timeout, connection_class, **kwargs): +async def _connect(*, loop, timeout, connection_class, record_class, **kwargs): if loop is None: loop = asyncio.get_event_loop() @@ -661,9 +669,14 @@ async def _connect(*, loop, timeout, connection_class, **kwargs): before = time.monotonic() try: con = await _connect_addr( - addr=addr, loop=loop, timeout=timeout, - params=params, config=config, - connection_class=connection_class) + addr=addr, + loop=loop, + timeout=timeout, + params=params, + config=config, + connection_class=connection_class, + record_class=record_class, + ) except (OSError, asyncio.TimeoutError, ConnectionError) as ex: last_error = ex else: diff --git a/asyncpg/connection.py b/asyncpg/connection.py index a78aafa7..7f502ace 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -50,7 +50,7 @@ class Connection(metaclass=ConnectionMeta): '_source_traceback', '__weakref__') def __init__(self, protocol, transport, loop, - addr: (str, int) or str, + addr, config: connect_utils._ClientConfiguration, params: connect_utils._ConnectionParameters): self._protocol = protocol @@ -294,7 +294,13 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: if not args: return await self._protocol.query(query, timeout) - _, status, _ = await self._execute(query, args, 0, timeout, True) + _, status, _ = await self._execute( + query, + args, + 0, + timeout, + return_status=True, + ) return status.decode() async def executemany(self, command: str, args, *, timeout: float=None): @@ -327,10 +333,20 @@ async def executemany(self, command: str, args, *, timeout: float=None): self._check_open() return await self._executemany(command, args, timeout) - async def _get_statement(self, query, timeout, *, named: bool=False, - use_cache: bool=True): + async def _get_statement( + self, + query, + timeout, + *, + named: bool=False, + use_cache: bool=True, + record_class=None + ): + if record_class is None: + record_class = self._protocol.get_record_class() + if use_cache: - statement = self._stmt_cache.get(query) + statement = self._stmt_cache.get((query, record_class)) if statement is not None: return statement @@ -348,7 +364,12 @@ async def _get_statement(self, query, timeout, *, named: bool=False, else: stmt_name = '' - statement = await self._protocol.prepare(stmt_name, query, timeout) + statement = await self._protocol.prepare( + stmt_name, + query, + timeout, + record_class=record_class, + ) need_reprepare = False types_with_missing_codecs = statement._init_types() tries = 0 @@ -384,10 +405,15 @@ async def _get_statement(self, query, timeout, *, named: bool=False, if need_reprepare: await self._protocol.prepare( - stmt_name, query, timeout, state=statement) + stmt_name, + query, + timeout, + state=statement, + record_class=record_class, + ) if use_cache: - self._stmt_cache.put(query, statement) + self._stmt_cache.put((query, record_class), statement) # If we've just created a new statement object, check if there # are any statements for GC. @@ -400,47 +426,124 @@ async def _introspect_types(self, typeoids, timeout): return await self.__execute( self._intro_query, (list(typeoids),), 0, timeout) - def cursor(self, query, *args, prefetch=None, timeout=None): + def cursor( + self, + query, + *args, + prefetch=None, + timeout=None, + record_class=None + ): """Return a *cursor factory* for the specified query. - :param args: Query arguments. - :param int prefetch: The number of rows the *cursor iterator* - will prefetch (defaults to ``50``.) - :param float timeout: Optional timeout in seconds. + :param args: + Query arguments. + :param int prefetch: + The number of rows the *cursor iterator* + will prefetch (defaults to ``50``.) + :param float timeout: + Optional timeout in seconds. + :param type record_class: + If specified, the class to use for records returned by this cursor. + Must be a subclass of :class:`~asyncpg.Record`. If not specified, + a per-connection *record_class* is used. + + :return: + A :class:`~cursor.CursorFactory` object. - :return: A :class:`~cursor.CursorFactory` object. + .. versionchanged:: 0.21.0 + Added the *record_class* parameter. """ self._check_open() - return cursor.CursorFactory(self, query, None, args, - prefetch, timeout) + return cursor.CursorFactory( + self, + query, + None, + args, + prefetch, + timeout, + record_class, + ) - async def prepare(self, query, *, timeout=None): + async def prepare(self, query, *, timeout=None, record_class=None): """Create a *prepared statement* for the specified query. - :param str query: Text of the query to create a prepared statement for. - :param float timeout: Optional timeout value in seconds. + :param str query: + Text of the query to create a prepared statement for. + :param float timeout: + Optional timeout value in seconds. + :param type record_class: + If specified, the class to use for records returned by the + prepared statement. Must be a subclass of + :class:`~asyncpg.Record`. If not specified, a per-connection + *record_class* is used. + + :return: + A :class:`~prepared_stmt.PreparedStatement` instance. - :return: A :class:`~prepared_stmt.PreparedStatement` instance. + .. versionchanged:: 0.21.0 + Added the *record_class* parameter. """ - return await self._prepare(query, timeout=timeout, use_cache=False) + return await self._prepare( + query, + timeout=timeout, + use_cache=False, + record_class=record_class, + ) - async def _prepare(self, query, *, timeout=None, use_cache: bool=False): + async def _prepare( + self, + query, + *, + timeout=None, + use_cache: bool=False, + record_class=None + ): self._check_open() - stmt = await self._get_statement(query, timeout, named=True, - use_cache=use_cache) + stmt = await self._get_statement( + query, + timeout, + named=True, + use_cache=use_cache, + record_class=record_class, + ) return prepared_stmt.PreparedStatement(self, query, stmt) - async def fetch(self, query, *args, timeout=None) -> list: + async def fetch( + self, + query, + *args, + timeout=None, + record_class=None + ) -> list: """Run a query and return the results as a list of :class:`Record`. - :param str query: Query text. - :param args: Query arguments. - :param float timeout: Optional timeout value in seconds. + :param str query: + Query text. + :param args: + Query arguments. + :param float timeout: + Optional timeout value in seconds. + :param type record_class: + If specified, the class to use for records returned by this method. + Must be a subclass of :class:`~asyncpg.Record`. If not specified, + a per-connection *record_class* is used. + + :return list: + A list of :class:`~asyncpg.Record` instances. If specified, the + actual type of list elements would be *record_class*. - :return list: A list of :class:`Record` instances. + .. versionchanged:: 0.21.0 + Added the *record_class* parameter. """ self._check_open() - return await self._execute(query, args, 0, timeout) + return await self._execute( + query, + args, + 0, + timeout, + record_class=record_class, + ) async def fetchval(self, query, *args, column=0, timeout=None): """Run a query and return a value in the first row. @@ -463,18 +566,42 @@ async def fetchval(self, query, *args, column=0, timeout=None): return None return data[0][column] - async def fetchrow(self, query, *args, timeout=None): + async def fetchrow( + self, + query, + *args, + timeout=None, + record_class=None + ): """Run a query and return the first row. - :param str query: Query text - :param args: Query arguments - :param float timeout: Optional timeout value in seconds. - - :return: The first row as a :class:`Record` instance, or None if - no records were returned by the query. + :param str query: + Query text + :param args: + Query arguments + :param float timeout: + Optional timeout value in seconds. + :param type record_class: + If specified, the class to use for the value returned by this + method. Must be a subclass of :class:`~asyncpg.Record`. + If not specified, a per-connection *record_class* is used. + + :return: + The first row as a :class:`~asyncpg.Record` instance, or None if + no records were returned by the query. If specified, + *record_class* is used as the type for the result value. + + .. versionchanged:: 0.21.0 + Added the *record_class* parameter. """ self._check_open() - data = await self._execute(query, args, 1, timeout) + data = await self._execute( + query, + args, + 1, + timeout, + record_class=record_class, + ) if not data: return None return data[0] @@ -1185,7 +1312,10 @@ def _mark_stmts_as_closed(self): self._stmts_to_close.clear() def _maybe_gc_stmt(self, stmt): - if stmt.refs == 0 and not self._stmt_cache.has(stmt.query): + if ( + stmt.refs == 0 + and not self._stmt_cache.has((stmt.query, stmt.record_class)) + ): # If low-level `stmt` isn't referenced from any high-level # `PreparedStatement` object and is not in the `_stmt_cache`: # @@ -1440,18 +1570,46 @@ async def reload_schema_state(self): self._drop_global_type_cache() self._drop_global_statement_cache() - async def _execute(self, query, args, limit, timeout, return_status=False): + async def _execute( + self, + query, + args, + limit, + timeout, + *, + return_status=False, + record_class=None + ): with self._stmt_exclusive_section: result, _ = await self.__execute( - query, args, limit, timeout, return_status=return_status) + query, + args, + limit, + timeout, + return_status=return_status, + record_class=record_class, + ) return result - async def __execute(self, query, args, limit, timeout, - return_status=False): + async def __execute( + self, + query, + args, + limit, + timeout, + *, + return_status=False, + record_class=None + ): executor = lambda stmt, timeout: self._protocol.bind_execute( stmt, args, '', limit, return_status, timeout) timeout = self._protocol._get_timeout(timeout) - return await self._do_execute(query, executor, timeout) + return await self._do_execute( + query, + executor, + timeout, + record_class=record_class, + ) async def _executemany(self, query, args, timeout): executor = lambda stmt, timeout: self._protocol.bind_execute_many( @@ -1461,12 +1619,28 @@ async def _executemany(self, query, args, timeout): result, _ = await self._do_execute(query, executor, timeout) return result - async def _do_execute(self, query, executor, timeout, retry=True): + async def _do_execute( + self, + query, + executor, + timeout, + retry=True, + *, + record_class=None + ): if timeout is None: - stmt = await self._get_statement(query, None) + stmt = await self._get_statement( + query, + None, + record_class=record_class, + ) else: before = time.monotonic() - stmt = await self._get_statement(query, timeout) + stmt = await self._get_statement( + query, + timeout, + record_class=record_class, + ) after = time.monotonic() timeout -= after - before before = after @@ -1535,6 +1709,7 @@ async def connect(dsn=None, *, command_timeout=None, ssl=None, connection_class=Connection, + record_class=protocol.Record, server_settings=None): r"""A coroutine to establish a connection to a PostgreSQL server. @@ -1654,10 +1829,15 @@ async def connect(dsn=None, *, PostgreSQL documentation for a `list of supported options `_. - :param Connection connection_class: + :param type connection_class: Class of the returned connection object. Must be a subclass of :class:`~asyncpg.connection.Connection`. + :param type record_class: + If specified, the class to use for records returned by queries on + this connection object. Must be a subclass of + :class:`~asyncpg.Record`. + :return: A :class:`~asyncpg.connection.Connection` instance. Example: @@ -1696,6 +1876,9 @@ async def connect(dsn=None, *, .. versionchanged:: 0.21.0 The *password* argument now accepts a callable or an async function. + .. versionchanged:: 0.21.0 + Added the *record_class* parameter. + .. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext .. _create_default_context: https://docs.python.org/3/library/ssl.html#ssl.create_default_context @@ -1712,19 +1895,33 @@ async def connect(dsn=None, *, 'connection_class is expected to be a subclass of ' 'asyncpg.Connection, got {!r}'.format(connection_class)) + if not issubclass(record_class, protocol.Record): + raise TypeError( + 'record_class is expected to be a subclass of ' + 'asyncpg.Record, got {!r}'.format(record_class)) + if loop is None: loop = asyncio.get_event_loop() return await connect_utils._connect( - loop=loop, timeout=timeout, connection_class=connection_class, - dsn=dsn, host=host, port=port, user=user, - password=password, passfile=passfile, - ssl=ssl, database=database, + loop=loop, + timeout=timeout, + connection_class=connection_class, + record_class=record_class, + dsn=dsn, + host=host, + port=port, + user=user, + password=password, + passfile=passfile, + ssl=ssl, + database=database, server_settings=server_settings, command_timeout=command_timeout, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, - max_cacheable_statement_size=max_cacheable_statement_size) + max_cacheable_statement_size=max_cacheable_statement_size, + ) class _StatementCacheEntry: diff --git a/asyncpg/cursor.py b/asyncpg/cursor.py index 030def0e..978824c3 100644 --- a/asyncpg/cursor.py +++ b/asyncpg/cursor.py @@ -19,15 +19,32 @@ class CursorFactory(connresource.ConnectionResource): results of a large query. """ - __slots__ = ('_state', '_args', '_prefetch', '_query', '_timeout') - - def __init__(self, connection, query, state, args, prefetch, timeout): + __slots__ = ( + '_state', + '_args', + '_prefetch', + '_query', + '_timeout', + '_record_class', + ) + + def __init__( + self, + connection, + query, + state, + args, + prefetch, + timeout, + record_class + ): super().__init__(connection) self._args = args self._prefetch = prefetch self._query = query self._timeout = timeout self._state = state + self._record_class = record_class if state is not None: state.attach() @@ -35,18 +52,28 @@ def __init__(self, connection, query, state, args, prefetch, timeout): @connresource.guarded def __aiter__(self): prefetch = 50 if self._prefetch is None else self._prefetch - return CursorIterator(self._connection, - self._query, self._state, - self._args, prefetch, - self._timeout) + return CursorIterator( + self._connection, + self._query, + self._state, + self._args, + self._record_class, + prefetch, + self._timeout, + ) @connresource.guarded def __await__(self): if self._prefetch is not None: raise exceptions.InterfaceError( 'prefetch argument can only be specified for iterable cursor') - cursor = Cursor(self._connection, self._query, - self._state, self._args) + cursor = Cursor( + self._connection, + self._query, + self._state, + self._args, + self._record_class, + ) return cursor._init(self._timeout).__await__() def __del__(self): @@ -57,9 +84,16 @@ def __del__(self): class BaseCursor(connresource.ConnectionResource): - __slots__ = ('_state', '_args', '_portal_name', '_exhausted', '_query') + __slots__ = ( + '_state', + '_args', + '_portal_name', + '_exhausted', + '_query', + '_record_class', + ) - def __init__(self, connection, query, state, args): + def __init__(self, connection, query, state, args, record_class): super().__init__(connection) self._args = args self._state = state @@ -68,6 +102,7 @@ def __init__(self, connection, query, state, args): self._portal_name = None self._exhausted = False self._query = query + self._record_class = record_class def _check_ready(self): if self._state is None: @@ -151,8 +186,17 @@ class CursorIterator(BaseCursor): __slots__ = ('_buffer', '_prefetch', '_timeout') - def __init__(self, connection, query, state, args, prefetch, timeout): - super().__init__(connection, query, state, args) + def __init__( + self, + connection, + query, + state, + args, + record_class, + prefetch, + timeout + ): + super().__init__(connection, query, state, args, record_class) if prefetch <= 0: raise exceptions.InterfaceError( @@ -171,7 +215,11 @@ def __aiter__(self): async def __anext__(self): if self._state is None: self._state = await self._connection._get_statement( - self._query, self._timeout, named=True) + self._query, + self._timeout, + named=True, + record_class=self._record_class, + ) self._state.attach() if not self._portal_name: @@ -196,7 +244,11 @@ class Cursor(BaseCursor): async def _init(self, timeout): if self._state is None: self._state = await self._connection._get_statement( - self._query, timeout, named=True) + self._query, + timeout, + named=True, + record_class=self._record_class, + ) self._state.attach() self._check_ready() await self._bind(timeout) diff --git a/asyncpg/pool.py b/asyncpg/pool.py index ec42f816..bf61c9ce 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -15,6 +15,7 @@ from . import connection from . import connect_utils from . import exceptions +from . import protocol logger = logging.getLogger(__name__) @@ -309,7 +310,7 @@ class Pool: '_init', '_connect_args', '_connect_kwargs', '_working_addr', '_working_config', '_working_params', '_holders', '_initialized', '_initializing', '_closing', - '_closed', '_connection_class', '_generation', + '_closed', '_connection_class', '_record_class', '_generation', '_setup', '_max_queries', '_max_inactive_connection_lifetime' ) @@ -322,6 +323,7 @@ def __init__(self, *connect_args, init, loop, connection_class, + record_class, **connect_kwargs): if len(connect_args) > 1: @@ -359,6 +361,11 @@ def __init__(self, *connect_args, 'connection_class is expected to be a subclass of ' 'asyncpg.Connection, got {!r}'.format(connection_class)) + if not issubclass(record_class, protocol.Record): + raise TypeError( + 'record_class is expected to be a subclass of ' + 'asyncpg.Record, got {!r}'.format(record_class)) + self._minsize = min_size self._maxsize = max_size @@ -372,6 +379,7 @@ def __init__(self, *connect_args, self._working_params = None self._connection_class = connection_class + self._record_class = record_class self._closing = False self._closed = False @@ -469,6 +477,7 @@ async def _get_new_connection(self): *self._connect_args, loop=self._loop, connection_class=self._connection_class, + record_class=self._record_class, **self._connect_kwargs) self._working_addr = con._addr @@ -484,7 +493,9 @@ async def _get_new_connection(self): timeout=self._working_params.connect_timeout, config=self._working_config, params=self._working_params, - connection_class=self._connection_class) + connection_class=self._connection_class, + record_class=self._record_class, + ) if self._init is not None: try: @@ -793,6 +804,7 @@ def create_pool(dsn=None, *, init=None, loop=None, connection_class=connection.Connection, + record_class=protocol.Record, **connect_kwargs): r"""Create a connection pool. @@ -851,6 +863,11 @@ def create_pool(dsn=None, *, The class to use for connections. Must be a subclass of :class:`~asyncpg.connection.Connection`. + :param type record_class: + If specified, the class to use for records returned by queries on + the connections in this pool. Must be a subclass of + :class:`~asyncpg.Record`. + :param int min_size: Number of connection the pool will be initialized with. @@ -901,10 +918,14 @@ def create_pool(dsn=None, *, or :meth:`Connection.add_log_listener() `) present on the connection at the moment of its release to the pool. + + .. versionchanged:: 0.21.0 + Added the *record_class* parameter. """ return Pool( dsn, connection_class=connection_class, + record_class=record_class, min_size=min_size, max_size=max_size, max_queries=max_queries, loop=loop, setup=setup, init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index 09a0a2ec..5df6b674 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -103,9 +103,15 @@ def cursor(self, *args, prefetch=None, :return: A :class:`~cursor.CursorFactory` object. """ - return cursor.CursorFactory(self._connection, self._query, - self._state, args, prefetch, - timeout) + return cursor.CursorFactory( + self._connection, + self._query, + self._state, + args, + prefetch, + timeout, + self._state.record_class, + ) @connresource.guarded async def explain(self, *args, analyze=False): diff --git a/asyncpg/protocol/codecs/base.pyx b/asyncpg/protocol/codecs/base.pyx index 5d3ccc4b..238fa280 100644 --- a/asyncpg/protocol/codecs/base.pyx +++ b/asyncpg/protocol/codecs/base.pyx @@ -7,6 +7,7 @@ from collections.abc import Mapping as MappingABC +import asyncpg from asyncpg import exceptions @@ -232,7 +233,7 @@ cdef class Codec: schema=self.schema, data_type=self.name, ) - result = record.ApgRecord_New(self.record_desc, elem_count) + result = record.ApgRecord_New(asyncpg.Record, self.record_desc, elem_count) for i in range(elem_count): elem_typ = self.element_type_oids[i] received_elem_typ = hton.unpack_int32(frb_read(buf, 4)) diff --git a/asyncpg/protocol/prepared_stmt.pxd b/asyncpg/protocol/prepared_stmt.pxd index 0d3f8d3b..90944c1a 100644 --- a/asyncpg/protocol/prepared_stmt.pxd +++ b/asyncpg/protocol/prepared_stmt.pxd @@ -11,6 +11,8 @@ cdef class PreparedStatementState: readonly str query readonly bint closed readonly int refs + readonly type record_class + list row_desc list parameters_desc diff --git a/asyncpg/protocol/prepared_stmt.pyx b/asyncpg/protocol/prepared_stmt.pyx index b69f76be..60094be6 100644 --- a/asyncpg/protocol/prepared_stmt.pyx +++ b/asyncpg/protocol/prepared_stmt.pyx @@ -11,7 +11,13 @@ from asyncpg import exceptions @cython.final cdef class PreparedStatementState: - def __cinit__(self, str name, str query, BaseProtocol protocol): + def __cinit__( + self, + str name, + str query, + BaseProtocol protocol, + type record_class + ): self.name = name self.query = query self.settings = protocol.settings @@ -21,6 +27,7 @@ cdef class PreparedStatementState: self.cols_desc = None self.closed = False self.refs = 0 + self.record_class = record_class def _get_parameters(self): cdef Codec codec @@ -264,7 +271,7 @@ cdef class PreparedStatementState: 'different from what was described ({})'.format( fnum, self.cols_num)) - dec_row = record.ApgRecord_New(self.cols_desc, fnum) + dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum) for i in range(fnum): flen = hton.unpack_int32(frb_read(&rbuf, 4)) diff --git a/asyncpg/protocol/protocol.pxd b/asyncpg/protocol/protocol.pxd index 14a7ecc6..772d6432 100644 --- a/asyncpg/protocol/protocol.pxd +++ b/asyncpg/protocol/protocol.pxd @@ -42,6 +42,7 @@ cdef class BaseProtocol(CoreProtocol): object timeout_callback object completed_callback object conref + type record_class bint is_reading str last_query diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 857fb4cc..4f7ce675 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -73,7 +73,7 @@ NO_TIMEOUT = object() cdef class BaseProtocol(CoreProtocol): - def __init__(self, addr, connected_fut, con_params, loop): + def __init__(self, addr, connected_fut, con_params, record_class: type, loop): # type of `con_params` is `_ConnectionParameters` CoreProtocol.__init__(self, con_params) @@ -85,6 +85,7 @@ cdef class BaseProtocol(CoreProtocol): self.address = addr self.settings = ConnectionSettings((self.address, con_params.database)) + self.record_class = record_class self.statement = None self.return_extra = False @@ -122,6 +123,9 @@ cdef class BaseProtocol(CoreProtocol): def get_settings(self): return self.settings + def get_record_class(self): + return self.record_class + def is_in_transaction(self): # PQTRANS_INTRANS = idle, within transaction block # PQTRANS_INERROR = idle, within failed transaction @@ -139,7 +143,9 @@ cdef class BaseProtocol(CoreProtocol): @cython.iterable_coroutine async def prepare(self, stmt_name, query, timeout, - PreparedStatementState state=None): + *, + PreparedStatementState state=None, + record_class): if self.cancel_waiter is not None: await self.cancel_waiter if self.cancel_sent_waiter is not None: @@ -154,7 +160,8 @@ cdef class BaseProtocol(CoreProtocol): self._prepare(stmt_name, query) # network op self.last_query = query if state is None: - state = PreparedStatementState(stmt_name, query, self) + state = PreparedStatementState( + stmt_name, query, self, record_class) self.statement = state except Exception as ex: waiter.set_exception(ex) @@ -955,7 +962,7 @@ def _create_record(object mapping, tuple elems): desc = record.ApgRecordDesc_New( mapping, tuple(mapping) if mapping else ()) - rec = record.ApgRecord_New(desc, len(elems)) + rec = record.ApgRecord_New(Record, desc, len(elems)) for i in range(len(elems)): elem = elems[i] cpython.Py_INCREF(elem) diff --git a/asyncpg/protocol/record/__init__.pxd b/asyncpg/protocol/record/__init__.pxd index 3d6b5fd7..43ac5e33 100644 --- a/asyncpg/protocol/record/__init__.pxd +++ b/asyncpg/protocol/record/__init__.pxd @@ -13,7 +13,7 @@ cdef extern from "record/recordobj.h": cpython.PyTypeObject *ApgRecord_InitTypes() except NULL int ApgRecord_CheckExact(object) - object ApgRecord_New(object, int) + object ApgRecord_New(type, object, int) void ApgRecord_SET_ITEM(object, int, object) object ApgRecordDesc_New(object, object) diff --git a/asyncpg/protocol/record/recordobj.c b/asyncpg/protocol/record/recordobj.c index 2f468a33..412e8174 100644 --- a/asyncpg/protocol/record/recordobj.c +++ b/asyncpg/protocol/record/recordobj.c @@ -15,9 +15,14 @@ static PyObject * record_new_items_iter(PyObject *); static ApgRecordObject *free_list[ApgRecord_MAXSAVESIZE]; static int numfree[ApgRecord_MAXSAVESIZE]; +static size_t MAX_RECORD_SIZE = ( + ((size_t)PY_SSIZE_T_MAX - sizeof(ApgRecordObject) - sizeof(PyObject *)) + / sizeof(PyObject *) +); + PyObject * -ApgRecord_New(PyObject *desc, Py_ssize_t size) +ApgRecord_New(PyTypeObject *type, PyObject *desc, Py_ssize_t size) { ApgRecordObject *o; Py_ssize_t i; @@ -27,19 +32,36 @@ ApgRecord_New(PyObject *desc, Py_ssize_t size) return NULL; } - if (size < ApgRecord_MAXSAVESIZE && (o = free_list[size]) != NULL) { - free_list[size] = (ApgRecordObject *) o->ob_item[0]; - numfree[size]--; - _Py_NewReference((PyObject *)o); - } - else { - /* Check for overflow */ - if ((size_t)size > ((size_t)PY_SSIZE_T_MAX - sizeof(ApgRecordObject) - - sizeof(PyObject *)) / sizeof(PyObject *)) { + if (type == &ApgRecord_Type) { + if (size < ApgRecord_MAXSAVESIZE && (o = free_list[size]) != NULL) { + free_list[size] = (ApgRecordObject *) o->ob_item[0]; + numfree[size]--; + _Py_NewReference((PyObject *)o); + } + else { + /* Check for overflow */ + if ((size_t)size > MAX_RECORD_SIZE) { + return PyErr_NoMemory(); + } + o = PyObject_GC_NewVar(ApgRecordObject, &ApgRecord_Type, size); + if (o == NULL) { + return NULL; + } + } + + PyObject_GC_Track(o); + } else { + assert(PyType_IsSubtype(type, &ApgRecord_Type)); + + if ((size_t)size > MAX_RECORD_SIZE) { return PyErr_NoMemory(); } - o = PyObject_GC_NewVar(ApgRecordObject, &ApgRecord_Type, size); - if (o == NULL) { + o = (ApgRecordObject *)type->tp_alloc(type, size); + if (!_PyObject_GC_IS_TRACKED(o)) { + PyErr_SetString( + PyExc_TypeError, + "record subclass is not tracked by GC" + ); return NULL; } } @@ -51,7 +73,6 @@ ApgRecord_New(PyObject *desc, Py_ssize_t size) Py_INCREF(desc); o->desc = (ApgRecordDescObject*)desc; o->self_hash = -1; - PyObject_GC_Track(o); return (PyObject *) o; } diff --git a/asyncpg/protocol/record/recordobj.h b/asyncpg/protocol/record/recordobj.h index d329f57e..2c6c1f1c 100644 --- a/asyncpg/protocol/record/recordobj.h +++ b/asyncpg/protocol/record/recordobj.h @@ -46,7 +46,7 @@ extern PyTypeObject ApgRecordDesc_Type; (((ApgRecordObject *)(op))->ob_item[i]) PyTypeObject *ApgRecord_InitTypes(void); -PyObject *ApgRecord_New(PyObject *, Py_ssize_t); +PyObject *ApgRecord_New(PyTypeObject *, PyObject *, Py_ssize_t); PyObject *ApgRecordDesc_New(PyObject *, PyObject *); #endif diff --git a/tests/test_record.py b/tests/test_record.py index e9abab45..8abe90ee 100644 --- a/tests/test_record.py +++ b/tests/test_record.py @@ -22,6 +22,14 @@ R_ABC = collections.OrderedDict([('a', 0), ('b', 1), ('c', 2)]) +class CustomRecord(asyncpg.Record): + pass + + +class AnotherCustomRecord(asyncpg.Record): + pass + + class TestRecord(tb.ConnectedTestCase): @contextlib.contextmanager @@ -339,3 +347,169 @@ async def test_record_no_new(self): with self.assertRaisesRegex( TypeError, "cannot create 'asyncpg.Record' instances"): asyncpg.Record() + + @tb.with_connection_options(record_class=CustomRecord) + async def test_record_subclass_01(self): + r = await self.con.fetchrow("SELECT 1 as a, '2' as b") + self.assertIsInstance(r, CustomRecord) + + r = await self.con.fetch("SELECT 1 as a, '2' as b") + self.assertIsInstance(r[0], CustomRecord) + + async with self.con.transaction(): + cur = await self.con.cursor("SELECT 1 as a, '2' as b") + r = await cur.fetchrow() + self.assertIsInstance(r, CustomRecord) + + cur = await self.con.cursor("SELECT 1 as a, '2' as b") + r = await cur.fetch(1) + self.assertIsInstance(r[0], CustomRecord) + + async with self.con.transaction(): + cur = self.con.cursor("SELECT 1 as a, '2' as b") + async for r in cur: + self.assertIsInstance(r, CustomRecord) + + ps = await self.con.prepare("SELECT 1 as a, '2' as b") + r = await ps.fetchrow() + self.assertIsInstance(r, CustomRecord) + + async def test_record_subclass_02(self): + r = await self.con.fetchrow( + "SELECT 1 as a, '2' as b", + record_class=CustomRecord, + ) + self.assertIsInstance(r, CustomRecord) + + r = await self.con.fetch( + "SELECT 1 as a, '2' as b", + record_class=CustomRecord, + ) + self.assertIsInstance(r[0], CustomRecord) + + async with self.con.transaction(): + cur = await self.con.cursor( + "SELECT 1 as a, '2' as b", + record_class=CustomRecord, + ) + r = await cur.fetchrow() + self.assertIsInstance(r, CustomRecord) + + cur = await self.con.cursor( + "SELECT 1 as a, '2' as b", + record_class=CustomRecord, + ) + r = await cur.fetch(1) + self.assertIsInstance(r[0], CustomRecord) + + async with self.con.transaction(): + cur = self.con.cursor( + "SELECT 1 as a, '2' as b", + record_class=CustomRecord, + ) + async for r in cur: + self.assertIsInstance(r, CustomRecord) + + ps = await self.con.prepare( + "SELECT 1 as a, '2' as b", + record_class=CustomRecord, + ) + r = await ps.fetchrow() + self.assertIsInstance(r, CustomRecord) + + r = await ps.fetch() + self.assertIsInstance(r[0], CustomRecord) + + @tb.with_connection_options(record_class=AnotherCustomRecord) + async def test_record_subclass_03(self): + r = await self.con.fetchrow( + "SELECT 1 as a, '2' as b", + record_class=CustomRecord, + ) + self.assertIsInstance(r, CustomRecord) + + r = await self.con.fetch( + "SELECT 1 as a, '2' as b", + record_class=CustomRecord, + ) + self.assertIsInstance(r[0], CustomRecord) + + async with self.con.transaction(): + cur = await self.con.cursor( + "SELECT 1 as a, '2' as b", + record_class=CustomRecord, + ) + r = await cur.fetchrow() + self.assertIsInstance(r, CustomRecord) + + cur = await self.con.cursor( + "SELECT 1 as a, '2' as b", + record_class=CustomRecord, + ) + r = await cur.fetch(1) + self.assertIsInstance(r[0], CustomRecord) + + async with self.con.transaction(): + cur = self.con.cursor( + "SELECT 1 as a, '2' as b", + record_class=CustomRecord, + ) + async for r in cur: + self.assertIsInstance(r, CustomRecord) + + ps = await self.con.prepare( + "SELECT 1 as a, '2' as b", + record_class=CustomRecord, + ) + r = await ps.fetchrow() + self.assertIsInstance(r, CustomRecord) + + r = await ps.fetch() + self.assertIsInstance(r[0], CustomRecord) + + @tb.with_connection_options(record_class=CustomRecord) + async def test_record_subclass_04(self): + r = await self.con.fetchrow( + "SELECT 1 as a, '2' as b", + record_class=asyncpg.Record, + ) + self.assertIs(type(r), asyncpg.Record) + + r = await self.con.fetch( + "SELECT 1 as a, '2' as b", + record_class=asyncpg.Record, + ) + self.assertIs(type(r[0]), asyncpg.Record) + + async with self.con.transaction(): + cur = await self.con.cursor( + "SELECT 1 as a, '2' as b", + record_class=asyncpg.Record, + ) + r = await cur.fetchrow() + self.assertIs(type(r), asyncpg.Record) + + cur = await self.con.cursor( + "SELECT 1 as a, '2' as b", + record_class=asyncpg.Record, + ) + r = await cur.fetch(1) + self.assertIs(type(r[0]), asyncpg.Record) + + async with self.con.transaction(): + cur = self.con.cursor( + "SELECT 1 as a, '2' as b", + record_class=asyncpg.Record, + ) + async for r in cur: + self.assertIs(type(r), asyncpg.Record) + + ps = await self.con.prepare( + "SELECT 1 as a, '2' as b", + record_class=asyncpg.Record, + ) + r = await ps.fetchrow() + self.assertIs(type(r), asyncpg.Record) + + r = await ps.fetch() + self.assertIs(type(r[0]), asyncpg.Record) diff --git a/tests/test_timeout.py b/tests/test_timeout.py index c2bca631..152a504a 100644 --- a/tests/test_timeout.py +++ b/tests/test_timeout.py @@ -138,9 +138,9 @@ async def test_command_timeout_01(self): class SlowPrepareConnection(pg_connection.Connection): """Connection class to test timeouts.""" - async def _get_statement(self, query, timeout): + async def _get_statement(self, query, timeout, **kwargs): await asyncio.sleep(0.3) - return await super()._get_statement(query, timeout) + return await super()._get_statement(query, timeout, **kwargs) class TestTimeoutCoversPrepare(tb.ConnectedTestCase):