Skip to content

Commit

Permalink
feat: SQLAlchemy 2.0 support (#368)
Browse files Browse the repository at this point in the history
This PR updates the dataloader and unit tests to be compatible with sqlalchemy 2.0
  • Loading branch information
erikwrede authored May 14, 2023
1 parent 882205d commit d0668cc
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 31 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ jobs:
strategy:
max-parallel: 10
matrix:
sql-alchemy: ["1.2", "1.3", "1.4"]
python-version: ["3.7", "3.8", "3.9", "3.10"]
sql-alchemy: [ "1.2", "1.3", "1.4","2.0" ]
python-version: [ "3.7", "3.8", "3.9", "3.10" ]

steps:
- uses: actions/checkout@v3
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ __pycache__/
.Python
env/
.venv/
venv/
build/
develop-eggs/
dist/
Expand Down
20 changes: 18 additions & 2 deletions graphene_sqlalchemy/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
import sqlalchemy
from sqlalchemy.orm import Session, strategies
from sqlalchemy.orm.query import QueryContext
from sqlalchemy.util import immutabledict

from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_graphene_version_less_than
from .utils import (
SQL_VERSION_HIGHER_EQUAL_THAN_1_4,
SQL_VERSION_HIGHER_EQUAL_THAN_2,
is_graphene_version_less_than,
)


def get_data_loader_impl() -> Any: # pragma: no cover
Expand Down Expand Up @@ -76,7 +81,18 @@ async def batch_load_fn(self, parents):
query_context = parent_mapper_query._compile_context()
else:
query_context = QueryContext(session.query(parent_mapper.entity))
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
if SQL_VERSION_HIGHER_EQUAL_THAN_2: # pragma: no cover
self.selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper,
None,
None, # recursion depth can be none
immutabledict(), # default value for selectinload->lazyload
)
elif SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
self.selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
Expand Down
23 changes: 18 additions & 5 deletions graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,23 @@
String,
Table,
func,
select,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import backref, column_property, composite, mapper, relationship
from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter
from sqlalchemy.sql.type_api import TypeEngine

from graphene_sqlalchemy.tests.utils import wrap_select_func
from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, SQL_VERSION_HIGHER_EQUAL_THAN_2

# fmt: off
import sqlalchemy
if SQL_VERSION_HIGHER_EQUAL_THAN_2:
from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip
else:
from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip
# fmt: on

PetKind = Enum("cat", "dog", name="pet_kind")


Expand Down Expand Up @@ -119,7 +128,7 @@ def hybrid_prop_list(self) -> List[int]:
return [1, 2, 3]

column_prop = column_property(
select([func.cast(func.count(id), Integer)]), doc="Column property"
wrap_select_func(func.cast(func.count(id), Integer)), doc="Column property"
)

composite_prop = composite(
Expand Down Expand Up @@ -163,7 +172,11 @@ def __subclasses__(cls):

editor_table = Table("editors", Base.metadata, autoload=True)

mapper(ReflectedEditor, editor_table)
# TODO Remove when switching min sqlalchemy version to SQLAlchemy 1.4
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
Base.registry.map_imperatively(ReflectedEditor, editor_table)
else:
mapper(ReflectedEditor, editor_table)


############################################
Expand Down Expand Up @@ -337,7 +350,7 @@ class Employee(Person):
############################################


class CustomIntegerColumn(_LookupExpressionAdapter, TypeEngine):
class CustomIntegerColumn(HasExpressionLookup, TypeEngine):
"""
Custom Column Type that our converters don't recognize
Adapted from sqlalchemy.Integer
Expand Down
5 changes: 3 additions & 2 deletions graphene_sqlalchemy/tests/models_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
String,
Table,
func,
select,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import column_property, relationship

from graphene_sqlalchemy.tests.utils import wrap_select_func

PetKind = Enum("cat", "dog", name="pet_kind")


Expand Down Expand Up @@ -61,7 +62,7 @@ class Reporter(Base):
favorite_article = relationship("Article", uselist=False)

column_prop = column_property(
select([func.cast(func.count(id), Integer)]), doc="Column property"
wrap_select_func(func.cast(func.count(id), Integer)), doc="Column property"
)


Expand Down
47 changes: 32 additions & 15 deletions graphene_sqlalchemy/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,28 @@
import sys
from typing import Dict, Tuple, Union

import graphene
import pytest
import sqlalchemy
import sqlalchemy_utils as sqa_utils
from sqlalchemy import Column, func, select, types
from graphene.relay import Node
from graphene.types.structures import Structure
from sqlalchemy import Column, func, types
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import column_property, composite

import graphene
from graphene.relay import Node
from graphene.types.structures import Structure

from .models import (
Article,
CompositeFullName,
Pet,
Reporter,
ShoppingCart,
ShoppingCartItem,
)
from .utils import wrap_select_func
from ..converter import (
convert_sqlalchemy_column,
convert_sqlalchemy_composite,
Expand All @@ -27,6 +35,7 @@
from ..fields import UnsortedSQLAlchemyConnectionField, default_connection_field_factory
from ..registry import Registry, get_global_registry
from ..types import ORMField, SQLAlchemyObjectType
from ..utils import is_sqlalchemy_version_less_than
from .models import (
Article,
CompositeFullName,
Expand Down Expand Up @@ -204,9 +213,9 @@ def prop_method() -> int | str:
return "not allowed in gql schema"

with pytest.raises(
ValueError,
match=r"Cannot convert hybrid_property Union to "
r"graphene.Union: the Union contains scalars. \.*",
ValueError,
match=r"Cannot convert hybrid_property Union to "
r"graphene.Union: the Union contains scalars. \.*",
):
get_hybrid_property_type(prop_method)

Expand Down Expand Up @@ -460,7 +469,7 @@ class TestEnum(enum.IntEnum):

def test_should_columproperty_convert():
field = get_field_from_column(
column_property(select([func.sum(func.cast(id, types.Integer))]).where(id == 1))
column_property(wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1))
)

assert field.type == graphene.Int
Expand All @@ -477,10 +486,18 @@ def test_should_jsontype_convert_jsonstring():
assert get_field(types.JSON).type == graphene.JSONString


@pytest.mark.skipif(
(not is_sqlalchemy_version_less_than("2.0.0b1")),
reason="SQLAlchemy >=2.0 does not support this: Variant is no longer used in SQLAlchemy",
)
def test_should_variant_int_convert_int():
assert get_field(types.Variant(types.Integer(), {})).type == graphene.Int


@pytest.mark.skipif(
(not is_sqlalchemy_version_less_than("2.0.0b1")),
reason="SQLAlchemy >=2.0 does not support this: Variant is no longer used in SQLAlchemy",
)
def test_should_variant_string_convert_string():
assert get_field(types.Variant(types.String(), {})).type == graphene.String

Expand Down Expand Up @@ -811,8 +828,8 @@ class Meta:
)

for (
hybrid_prop_name,
hybrid_prop_expected_return_type,
hybrid_prop_name,
hybrid_prop_expected_return_type,
) in shopping_cart_item_expected_types.items():
hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name]

Expand All @@ -823,7 +840,7 @@ class Meta:
str(hybrid_prop_expected_return_type),
)
assert (
hybrid_prop_field.description is None
hybrid_prop_field.description is None
) # "doc" is ignored by hybrid property

###################################################
Expand Down Expand Up @@ -870,8 +887,8 @@ class Meta:
)

for (
hybrid_prop_name,
hybrid_prop_expected_return_type,
hybrid_prop_name,
hybrid_prop_expected_return_type,
) in shopping_cart_expected_types.items():
hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name]

Expand All @@ -882,5 +899,5 @@ class Meta:
str(hybrid_prop_expected_return_type),
)
assert (
hybrid_prop_field.description is None
hybrid_prop_field.description is None
) # "doc" is ignored by hybrid property
13 changes: 12 additions & 1 deletion graphene_sqlalchemy/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import inspect
import re

from sqlalchemy import select

from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4


def to_std_dicts(value):
"""Convert nested ordered dicts to normal dicts for better comparison."""
Expand All @@ -18,8 +22,15 @@ def remove_cache_miss_stat(message):
return re.sub(r"\[generated in \d+.?\d*s\]\s", "", message)


async def eventually_await_session(session, func, *args):
def wrap_select_func(query):
# TODO remove this when we drop support for sqa < 2.0
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
return select(query)
else:
return select([query])


async def eventually_await_session(session, func, *args):
if inspect.iscoroutinefunction(getattr(session, func)):
await getattr(session, func)(*args)
else:
Expand Down
8 changes: 7 additions & 1 deletion graphene_sqlalchemy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,18 @@ def is_graphene_version_less_than(version_string): # pragma: no cover

SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = False

if not is_sqlalchemy_version_less_than("1.4"):
if not is_sqlalchemy_version_less_than("1.4"): # pragma: no cover
from sqlalchemy.ext.asyncio import AsyncSession

SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = True


SQL_VERSION_HIGHER_EQUAL_THAN_2 = False

if not is_sqlalchemy_version_less_than("2.0.0b1"): # pragma: no cover
SQL_VERSION_HIGHER_EQUAL_THAN_2 = True


def get_session(context):
return context.get("session")

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# To keep things simple, we only support newer versions of Graphene
"graphene>=3.0.0b7",
"promise>=2.3",
"SQLAlchemy>=1.1,<2",
"SQLAlchemy>=1.1",
"aiodataloader>=0.2.0,<1.0",
]

Expand Down
8 changes: 6 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tox]
envlist = pre-commit,py{37,38,39,310}-sql{12,13,14}
envlist = pre-commit,py{37,38,39,310}-sql{12,13,14,20}
skipsdist = true
minversion = 3.7.0

Expand All @@ -15,6 +15,7 @@ SQLALCHEMY =
1.2: sql12
1.3: sql13
1.4: sql14
2.0: sql20

[testenv]
passenv = GITHUB_*
Expand All @@ -23,8 +24,11 @@ deps =
sql12: sqlalchemy>=1.2,<1.3
sql13: sqlalchemy>=1.3,<1.4
sql14: sqlalchemy>=1.4,<1.5
sql20: sqlalchemy>=2.0.0b3
setenv =
SQLALCHEMY_WARN_20 = 1
commands =
pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs}
python -W always -m pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs}

[testenv:pre-commit]
basepython=python3.10
Expand Down

0 comments on commit d0668cc

Please sign in to comment.