Skip to content

Commit

Permalink
Allow using custom Record class
Browse files Browse the repository at this point in the history
Add the new `record_class` parameter to the `create_pool()` and
`connect()` functions, as well as to the `cursor()`, `prepare()`,
`fetch()` and `fetchrow()` connection methods.

This not only allows adding custom functionality to the returned
objects, but also assists with typing (see #577 for discussion).

Fixes: #40.
  • Loading branch information
elprans committed Jul 19, 2020
1 parent 39040b3 commit 6aca86e
Show file tree
Hide file tree
Showing 16 changed files with 602 additions and 106 deletions.
3 changes: 3 additions & 0 deletions asyncpg/_testbase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import unittest


import asyncpg
from asyncpg import cluster as pg_cluster
from asyncpg import connection as pg_connection
from asyncpg import pool as pg_pool
Expand Down Expand Up @@ -266,13 +267,15 @@ def create_pool(dsn=None, *,
loop=None,
pool_class=pg_pool.Pool,
connection_class=pg_connection.Connection,
record_class=asyncpg.Record,
**connect_kwargs):
return pool_class(
dsn,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
connection_class=connection_class,
record_class=record_class,
**connect_kwargs)


Expand Down
27 changes: 20 additions & 7 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,16 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
raise


async def _connect_addr(*, addr, loop, timeout, params, config,
connection_class):
async def _connect_addr(
*,
addr,
loop,
timeout,
params,
config,
connection_class,
record_class,
):
assert loop is not None

if timeout <= 0:
Expand All @@ -613,7 +621,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
params = params._replace(password=password)

proto_factory = lambda: protocol.Protocol(
addr, connected, params, loop)
addr, connected, params, record_class, loop)

if isinstance(addr, str):
# UNIX socket
Expand Down Expand Up @@ -649,7 +657,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
return con


async def _connect(*, loop, timeout, connection_class, **kwargs):
async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
if loop is None:
loop = asyncio.get_event_loop()

Expand All @@ -661,9 +669,14 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
before = time.monotonic()
try:
con = await _connect_addr(
addr=addr, loop=loop, timeout=timeout,
params=params, config=config,
connection_class=connection_class)
addr=addr,
loop=loop,
timeout=timeout,
params=params,
config=config,
connection_class=connection_class,
record_class=record_class,
)
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
last_error = ex
else:
Expand Down
Loading

0 comments on commit 6aca86e

Please sign in to comment.