Skip to content

Commit

Permalink
S01E09
Browse files Browse the repository at this point in the history
  • Loading branch information
ansipunk committed Mar 3, 2024
1 parent 3f26f76 commit cdbf97f
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 131 deletions.
6 changes: 1 addition & 5 deletions databases/backends/aiopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
import uuid

import aiopg
from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.row import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.backends.common.records import Record, Row, create_column_maps
from databases.backends.compilers.psycopg import PGCompiler_psycopg
from databases.backends.dialects.psycopg import PGDialect_psycopg
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
Expand All @@ -38,12 +36,10 @@ def _get_dialect(self) -> Dialect:
dialect = PGDialect_psycopg(
json_serializer=json.dumps, json_deserializer=lambda x: x
)
dialect.statement_compiler = PGCompiler_psycopg
dialect.implicit_returning = True
dialect.supports_native_enum = True
dialect.supports_smallserial = True # 9.2+
dialect._backslash_escapes = False
dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+
dialect._has_native_hstore = True
dialect.supports_native_decimal = True

Expand Down
4 changes: 2 additions & 2 deletions databases/backends/asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import typing

import asyncpg
from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.backends.common.records import Record, create_column_maps
from databases.backends.dialects.psycopg import dialect as psycopg_dialect
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
Expand All @@ -29,7 +29,7 @@ def __init__(
self._pool = None

def _get_dialect(self) -> Dialect:
dialect = psycopg_dialect(paramstyle="pyformat")
dialect = PGDialect_psycopg(paramstyle="pyformat")
dialect.implicit_returning = True
dialect.supports_native_enum = True
dialect.supports_smallserial = True # 9.2+
Expand Down
Empty file.
17 changes: 0 additions & 17 deletions databases/backends/compilers/psycopg.py

This file was deleted.

Empty file.
46 changes: 0 additions & 46 deletions databases/backends/dialects/psycopg.py

This file was deleted.

15 changes: 10 additions & 5 deletions databases/backends/psycopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class PsycopgBackend(DatabaseBackend):
_database_url: DatabaseURL
_options: typing.Dict[str, typing.Any]
_dialect: Dialect
_pool: typing.Optional[psycopg_pool.AsyncConnectionPool]
_pool: typing.Optional[psycopg_pool.AsyncConnectionPool] = None

def __init__(
self,
Expand All @@ -33,7 +33,6 @@ def __init__(
self._options = options
self._dialect = PGDialect_psycopg()
self._dialect.implicit_returning = True
self._pool = None

async def connect(self) -> None:
if self._pool is not None:
Expand Down Expand Up @@ -95,7 +94,10 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
rows = await cursor.fetchall()

column_maps = create_column_maps(result_columns)
return [PsycopgRecord(row, result_columns, self._dialect, column_maps) for row in rows]
return [
PsycopgRecord(row, result_columns, self._dialect, column_maps)
for row in rows
]

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
if self._connection is None:
Expand Down Expand Up @@ -167,7 +169,8 @@ def raw_connection(self) -> typing.Any:
return self._connection

def _compile(
self, query: ClauseElement,
self,
query: ClauseElement,
) -> typing.Tuple[str, typing.Mapping[str, typing.Any], tuple]:
compiled = query.compile(
dialect=self._dialect,
Expand Down Expand Up @@ -224,7 +227,9 @@ def _mapping(self) -> typing.Mapping:

def __getitem__(self, key: typing.Any) -> typing.Any:
if len(self._column_map) == 0:
return self._mapping[key]
if isinstance(key, str):
return self._mapping[key]
return self._row[key]
elif isinstance(key, Column):
idx, datatype = self._column_map_full[str(key)]
elif isinstance(key, int):
Expand Down
86 changes: 30 additions & 56 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,17 @@ async def test_queries(database_url):

assert len(results) == 3
assert results[0]["text"] == "example1"
assert results[0]["completed"] == True
assert results[0]["completed"] is True
assert results[1]["text"] == "example2"
assert results[1]["completed"] == False
assert results[1]["completed"] is False
assert results[2]["text"] == "example3"
assert results[2]["completed"] == True
assert results[2]["completed"] is True

# fetch_one()
query = notes.select()
result = await database.fetch_one(query=query)
assert result["text"] == "example1"
assert result["completed"] == True
assert result["completed"] is True

# fetch_val()
query = sqlalchemy.sql.select(*[notes.c.text])
Expand Down Expand Up @@ -246,11 +246,11 @@ async def test_queries(database_url):
iterate_results.append(result)
assert len(iterate_results) == 3
assert iterate_results[0]["text"] == "example1"
assert iterate_results[0]["completed"] == True
assert iterate_results[0]["completed"] is True
assert iterate_results[1]["text"] == "example2"
assert iterate_results[1]["completed"] == False
assert iterate_results[1]["completed"] is False
assert iterate_results[2]["text"] == "example3"
assert iterate_results[2]["completed"] == True
assert iterate_results[2]["completed"] is True


@pytest.mark.parametrize("database_url", DATABASE_URLS)
Expand Down Expand Up @@ -280,26 +280,26 @@ async def test_queries_raw(database_url):
results = await database.fetch_all(query=query, values={"completed": True})
assert len(results) == 2
assert results[0]["text"] == "example1"
assert results[0]["completed"] == True
assert results[0]["completed"] is True
assert results[1]["text"] == "example3"
assert results[1]["completed"] == True
assert results[1]["completed"] is True

# fetch_one()
query = "SELECT * FROM notes WHERE completed = :completed"
result = await database.fetch_one(query=query, values={"completed": False})
assert result["text"] == "example2"
assert result["completed"] == False
assert result["completed"] is False

# fetch_val()
query = "SELECT completed FROM notes WHERE text = :text"
result = await database.fetch_val(query=query, values={"text": "example1"})
assert result == True
assert result is True

query = "SELECT * FROM notes WHERE text = :text"
result = await database.fetch_val(
query=query, values={"text": "example1"}, column="completed"
)
assert result == True
assert result is True

# iterate()
query = "SELECT * FROM notes"
Expand All @@ -308,11 +308,11 @@ async def test_queries_raw(database_url):
iterate_results.append(result)
assert len(iterate_results) == 3
assert iterate_results[0]["text"] == "example1"
assert iterate_results[0]["completed"] == True
assert iterate_results[0]["completed"] is True
assert iterate_results[1]["text"] == "example2"
assert iterate_results[1]["completed"] == False
assert iterate_results[1]["completed"] is False
assert iterate_results[2]["text"] == "example3"
assert iterate_results[2]["completed"] == True
assert iterate_results[2]["completed"] is True


@pytest.mark.parametrize("database_url", DATABASE_URLS)
Expand Down Expand Up @@ -380,7 +380,7 @@ async def test_results_support_mapping_interface(database_url):

assert isinstance(results_as_dicts[0]["id"], int)
assert results_as_dicts[0]["text"] == "example1"
assert results_as_dicts[0]["completed"] == True
assert results_as_dicts[0]["completed"] is True


@pytest.mark.parametrize("database_url", DATABASE_URLS)
Expand Down Expand Up @@ -467,7 +467,7 @@ async def test_execute_return_val(database_url):
query = notes.select().where(notes.c.id == pk)
result = await database.fetch_one(query)
assert result["text"] == "example1"
assert result["completed"] == True
assert result["completed"] is True


@pytest.mark.parametrize("database_url", DATABASE_URLS)
Expand Down Expand Up @@ -857,7 +857,7 @@ async def test_transaction_commit_low_level(database_url):
try:
query = notes.insert().values(text="example1", completed=True)
await database.execute(query)
except: # pragma: no cover
except Exception: # pragma: no cover
await transaction.rollback()
else:
await transaction.commit()
Expand All @@ -881,7 +881,7 @@ async def test_transaction_rollback_low_level(database_url):
query = notes.insert().values(text="example1", completed=True)
await database.execute(query)
raise RuntimeError()
except:
except Exception:
await transaction.rollback()
else: # pragma: no cover
await transaction.commit()
Expand Down Expand Up @@ -1354,13 +1354,12 @@ async def test_queries_with_expose_backend_connection(database_url):
]:
cursor = await raw_connection.cursor()
await cursor.execute(insert_query, values)
elif database.url.scheme == "mysql+asyncmy":
elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]:
async with raw_connection.cursor() as cursor:
await cursor.execute(insert_query, values)
elif database.url.scheme in [
"postgresql",
"postgresql+asyncpg",
"postgresql+psycopg",
]:
await raw_connection.execute(insert_query, *values)
elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]:
Expand All @@ -1372,7 +1371,7 @@ async def test_queries_with_expose_backend_connection(database_url):
if database.url.scheme in ["mysql", "mysql+aiomysql"]:
cursor = await raw_connection.cursor()
await cursor.executemany(insert_query, values)
elif database.url.scheme == "mysql+asyncmy":
elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]:
async with raw_connection.cursor() as cursor:
await cursor.executemany(insert_query, values)
elif database.url.scheme == "postgresql+aiopg":
Expand All @@ -1395,36 +1394,28 @@ async def test_queries_with_expose_backend_connection(database_url):
cursor = await raw_connection.cursor()
await cursor.execute(select_query)
results = await cursor.fetchall()
elif database.url.scheme == "mysql+asyncmy":
elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]:
async with raw_connection.cursor() as cursor:
await cursor.execute(select_query)
results = await cursor.fetchall()
elif database.url.scheme in [
"postgresql",
"postgresql+asyncpg",
"postgresql+psycopg",
]:
elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]:
results = await raw_connection.fetch(select_query)
elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]:
results = await raw_connection.execute_fetchall(select_query)

assert len(results) == 3
# Raw output for the raw request
assert results[0][1] == "example1"
assert results[0][2] == True
assert results[0][2] is True
assert results[1][1] == "example2"
assert results[1][2] == False
assert results[1][2] is False
assert results[2][1] == "example3"
assert results[2][2] == True
assert results[2][2] is True

# fetch_one()
if database.url.scheme in [
"postgresql",
"postgresql+asyncpg",
"postgresql+psycopg",
]:
if database.url.scheme in ["postgresql", "postgresql+asyncpg"]:
result = await raw_connection.fetchrow(select_query)
elif database.url.scheme == "mysql+asyncmy":
elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]:
async with raw_connection.cursor() as cursor:
await cursor.execute(select_query)
result = await cursor.fetchone()
Expand All @@ -1435,7 +1426,7 @@ async def test_queries_with_expose_backend_connection(database_url):

# Raw output for the raw request
assert result[1] == "example1"
assert result[2] == True
assert result[2] is True


@pytest.mark.parametrize("database_url", DATABASE_URLS)
Expand Down Expand Up @@ -1606,7 +1597,7 @@ async def test_column_names(database_url, select_query):

assert sorted(results[0]._mapping.keys()) == ["completed", "id", "text"]
assert results[0]["text"] == "example1"
assert results[0]["completed"] == True
assert results[0]["completed"] is True


@pytest.mark.parametrize("database_url", DATABASE_URLS)
Expand Down Expand Up @@ -1641,23 +1632,6 @@ async def test_result_named_access(database_url):
assert result.completed is True


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_mapping_property_interface(database_url):
"""
Test that all connections implement interface with `_mapping` property
"""
async with Database(database_url) as database:
query = notes.select()
single_result = await database.fetch_one(query=query)
assert single_result._mapping["text"] == "example1"
assert single_result._mapping["completed"] is True

list_result = await database.fetch_all(query=query)
assert list_result[0]._mapping["text"] == "example1"
assert list_result[0]._mapping["completed"] is True


@async_adapter
async def test_should_not_maintain_ref_when_no_cache_param():
async with Database(
Expand Down

0 comments on commit cdbf97f

Please sign in to comment.