Skip to content

Commit

Permalink
S01E07
Browse files Browse the repository at this point in the history
  • Loading branch information
ansipunk committed Mar 3, 2024
1 parent e08eb4f commit 1c58a73
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
6 changes: 4 additions & 2 deletions databases/backends/common/records.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import typing
from collections import namedtuple

from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.engine.row import Row as SQLRow
Expand Down Expand Up @@ -53,7 +52,10 @@ def values(self) -> typing.ValuesView:

def __getitem__(self, key: typing.Any) -> typing.Any:
if len(self._column_map) == 0:
return self._row[key]
try:
return self._row[key]
except TypeError:
return self._mapping[key]
elif isinstance(key, Column):
idx, datatype = self._column_map_full[str(key)]
elif isinstance(key, int):
Expand Down
28 changes: 24 additions & 4 deletions databases/backends/psycopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import psycopg
import psycopg_pool
from psycopg.rows import namedtuple_row
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.schema import Column

from databases.backends.common.records import Record, create_column_maps
from databases.core import DatabaseURL
Expand All @@ -31,6 +32,7 @@ def __init__(
self._database_url = DatabaseURL(database_url)
self._options = options
self._dialect = PGDialect_psycopg()
self._dialect.implicit_returning = True
self._pool = None

async def connect(self) -> None:
Expand Down Expand Up @@ -94,7 +96,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
rows = await cursor.fetchall()

column_maps = create_column_maps(result_columns)
return [Record(row, result_columns, self._dialect, column_maps) for row in rows]
return [PsycopgRecord(row, result_columns, self._dialect, column_maps) for row in rows]

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
if self._connection is None:
Expand All @@ -109,7 +111,7 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterfa
if row is None:
return None

return Record(
return PsycopgRecord(
row,
result_columns,
self._dialect,
Expand Down Expand Up @@ -154,7 +156,7 @@ async def iterate(
if row is None:
break

yield Record(row, result_columns, self._dialect, column_maps)
yield PsycopgRecord(row, result_columns, self._dialect, column_maps)

def transaction(self) -> "TransactionBackend":
return PsycopgTransaction(connection=self)
Expand Down Expand Up @@ -214,3 +216,21 @@ async def rollback(self) -> None:

async with self._transaction._conn.lock:
await self._transaction._conn.wait(self._transaction._rollback_gen(None))


class PsycopgRecord(Record):
@property
def _mapping(self) -> typing.Mapping:
return self._row._asdict()

def __getitem__(self, key: typing.Any) -> typing.Any:
if len(self._column_map) == 0:
return self._mapping[key]
elif isinstance(key, Column):
idx, datatype = self._column_map_full[str(key)]
elif isinstance(key, int):
idx, datatype = self._column_map_int[key]
else:
idx, datatype = self._column_map[key]

return self._row[idx]

0 comments on commit 1c58a73

Please sign in to comment.