Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic sort argument for SQLAlchemyInterface #400

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

from sqlalchemy import types as sqa_types
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import (
ColumnProperty,
RelationshipProperty,
class_mapper,
interfaces,
strategies,
)
from sqlalchemy.ext.hybrid import hybrid_property

import graphene
from graphene.types.json import JSONString
Expand Down Expand Up @@ -159,7 +159,7 @@ def convert_sqlalchemy_relationship(
):
"""
:param sqlalchemy.RelationshipProperty relationship_prop:
:param SQLAlchemyObjectType obj_type:
:param SQLAlchemyBase obj_type:
:param function|None connection_field_factory:
:param bool batching:
:param str orm_field_name:
Expand Down Expand Up @@ -202,7 +202,7 @@ def _convert_o2o_or_m2o_relationship(
Convert one-to-one or many-to-one relationshsip. Return an object field.

:param sqlalchemy.RelationshipProperty relationship_prop:
:param SQLAlchemyObjectType obj_type:
:param SQLAlchemyBase obj_type:
:param bool batching:
:param str orm_field_name:
:param dict field_kwargs:
Expand Down Expand Up @@ -230,7 +230,7 @@ def _convert_o2m_or_m2m_relationship(
Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field.

:param sqlalchemy.RelationshipProperty relationship_prop:
:param SQLAlchemyObjectType obj_type:
:param SQLAlchemyBase obj_type:
:param bool batching:
:param function|None connection_field_factory:
:param dict field_kwargs:
Expand Down Expand Up @@ -362,7 +362,7 @@ def get_type_from_registry():
raise TypeError(
"No model found in Registry for type %s. "
"Only references to SQLAlchemy Models mapped to "
"SQLAlchemyObjectTypes are allowed." % type_arg
"SQLAlchemyBase types are allowed." % type_arg
)

return get_type_from_registry()
Expand Down Expand Up @@ -680,7 +680,7 @@ def forward_reference_solver():
raise TypeError(
"No model found in Registry for forward reference for type %s. "
"Only forward references to other SQLAlchemy Models mapped to "
"SQLAlchemyObjectTypes are allowed." % type_arg
"SQLAlchemyBase types are allowed." % type_arg
)
# Always fall back to string if no ForwardRef type found.
return get_global_registry().get_type_for_model(model)
Expand Down
14 changes: 7 additions & 7 deletions graphene_sqlalchemy/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def enum_for_sa_enum(sa_enum, registry):

def enum_for_field(obj_type, field_name):
"""Return the Graphene Enum type for the specified Graphene field."""
from .types import SQLAlchemyObjectType
from .types import SQLAlchemyBase

if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType):
raise TypeError("Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type))
if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyBase):
raise TypeError("Expected SQLAlchemyBase, but got: {!r}".format(obj_type))
if not field_name or not isinstance(field_name, str):
raise TypeError("Expected a field name, but got: {!r}".format(field_name))
registry = obj_type._meta.registry
Expand Down Expand Up @@ -88,10 +88,10 @@ def _default_sort_enum_symbol_name(column_name, sort_asc=True):
def sort_enum_for_object_type(
obj_type, name=None, only_fields=None, only_indexed=None, get_symbol_name=None
):
"""Return Graphene Enum for sorting the given SQLAlchemyObjectType.
"""Return Graphene Enum for sorting the given SQLAlchemyBase.

Parameters
- obj_type : SQLAlchemyObjectType
- obj_type : SQLAlchemyBase
The object type for which the sort Enum shall be generated.
- name : str, optional, default None
Name to use for the sort Enum.
Expand Down Expand Up @@ -160,10 +160,10 @@ def sort_argument_for_object_type(
get_symbol_name=None,
has_default=True,
):
""" "Returns Graphene Argument for sorting the given SQLAlchemyObjectType.
""" "Returns Graphene Argument for sorting the given SQLAlchemyBase.

Parameters
- obj_type : SQLAlchemyObjectType
- obj_type : SQLAlchemyBase
The object type for which the sort Argument shall be generated.
- enum_name : str, optional, default None
Name to use for the sort Enum.
Expand Down
16 changes: 8 additions & 8 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,21 @@
class SQLAlchemyConnectionField(ConnectionField):
@property
def type(self):
from .types import SQLAlchemyObjectType
from .types import SQLAlchemyBase

type_ = super(ConnectionField, self).type
nullable_type = get_nullable_type(type_)
if issubclass(nullable_type, Connection):
return type_
assert issubclass(nullable_type, SQLAlchemyObjectType), (
"SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
assert issubclass(nullable_type, SQLAlchemyBase), (
"SQLALchemyConnectionField only accepts SQLAlchemyBase types, not {}"
).format(nullable_type.__name__)
assert nullable_type.connection, "The type {} doesn't have a connection".format(
nullable_type.__name__
)
assert type_ == nullable_type, (
"Passing a SQLAlchemyObjectType instance is deprecated. "
"Pass the connection type instead accessible via SQLAlchemyObjectType.connection"
"Passing a SQLAlchemyBase instance is deprecated. "
"Pass the connection type instead accessible via SQLAlchemyBase.connection"
)
return nullable_type.connection

Expand Down Expand Up @@ -266,7 +266,7 @@ def default_connection_field_factory(relationship, registry, **field_kwargs):
def createConnectionField(type_, **field_kwargs):
warnings.warn(
"createConnectionField is deprecated and will be removed in the next "
"major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.",
"major version. Use SQLAlchemyBase.Meta.connection_field_factory instead.",
DeprecationWarning,
)
return __connectionFactory(type_, **field_kwargs)
Expand All @@ -275,7 +275,7 @@ def createConnectionField(type_, **field_kwargs):
def registerConnectionFieldFactory(factoryMethod):
warnings.warn(
"registerConnectionFieldFactory is deprecated and will be removed in the next "
"major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.",
"major version. Use SQLAlchemyBase.Meta.connection_field_factory instead.",
DeprecationWarning,
)
global __connectionFactory
Expand All @@ -285,7 +285,7 @@ def registerConnectionFieldFactory(factoryMethod):
def unregisterConnectionFieldFactory():
warnings.warn(
"registerConnectionFieldFactory is deprecated and will be removed in the next "
"major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.",
"major version. Use SQLAlchemyBase.Meta.connection_field_factory instead.",
DeprecationWarning,
)
global __connectionFactory
Expand Down
11 changes: 3 additions & 8 deletions graphene_sqlalchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,10 @@ def get_graphene_enum_for_sa_enum(self, sa_enum: SQLAlchemyEnumType):
return self._registry_enums.get(sa_enum)

def register_sort_enum(self, obj_type, sort_enum: Enum):
from .types import SQLAlchemyBase

from .types import SQLAlchemyObjectType

if not isinstance(obj_type, type) or not issubclass(
obj_type, SQLAlchemyObjectType
):
raise TypeError(
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
)
if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyBase):
raise TypeError("Expected SQLAlchemyBase, but got: {!r}".format(obj_type))
if not isinstance(sort_enum, type(Enum)):
raise TypeError("Expected Graphene Enum, but got: {!r}".format(sort_enum))
self._registry_sort_enums[obj_type] = sort_enum
Expand Down
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_attr_resolver(obj_type, model_attr):
In order to support field renaming via `ORMField.model_attr`,
we need to define resolver functions for each field.

:param SQLAlchemyObjectType obj_type:
:param SQLAlchemyBase obj_type:
:param str model_attr: the name of the SQLAlchemy attribute
:rtype: Callable
"""
Expand Down
40 changes: 40 additions & 0 deletions graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,46 @@ class Employee(Person):
}


class Owner(Base):
id = Column(Integer(), primary_key=True)
name = Column(String())

accounts = relationship(lambda: Account, back_populates="owner", lazy="selectin")

__tablename__ = "owner"


class Account(Base):
id = Column(Integer(), primary_key=True)
type = Column(String())

owner_id = Column(Integer(), ForeignKey(Owner.__table__.c.id))
owner = relationship(Owner, back_populates="accounts")

balance = Column(Integer())

__tablename__ = "account"
__mapper_args__ = {
"polymorphic_on": type,
}


class CurrentAccount(Account):
overdraft = Column(Integer())

__mapper_args__ = {
"polymorphic_identity": "current",
}


class SavingsAccount(Account):
interest_rate = Column(Integer())

__mapper_args__ = {
"polymorphic_identity": "savings",
}


############################################
# Custom Test Models
############################################
Expand Down
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def hybrid_prop(self) -> "MyTypeNotInRegistry":
with pytest.raises(
TypeError,
match=r"(.*)Only forward references to other SQLAlchemy Models mapped to "
"SQLAlchemyObjectTypes are allowed.(.*)",
"SQLAlchemyBase types are allowed.(.*)",
):
get_hybrid_property_type(hybrid_prop).type

Expand Down
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,6 @@ class Meta:
with pytest.raises(TypeError, match=re_err):
PetType.enum_for_field(None)

re_err = "Expected SQLAlchemyObjectType, but got: None"
re_err = "Expected SQLAlchemyBase, but got: None"
with pytest.raises(TypeError, match=re_err):
enum_for_field(None, "other_kind")
25 changes: 23 additions & 2 deletions graphene_sqlalchemy/tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from graphene.relay import Connection, Node

from ..fields import SQLAlchemyConnectionField, UnsortedSQLAlchemyConnectionField
from ..types import SQLAlchemyObjectType
from ..types import SQLAlchemyInterface, SQLAlchemyObjectType
from .models import Editor as EditorModel
from .models import Employee as EmployeeModel
from .models import Person as PersonModel
from .models import Pet as PetModel


Expand All @@ -21,6 +23,18 @@ class Meta:
model = EditorModel


class Person(SQLAlchemyInterface):
class Meta:
model = PersonModel
use_connection = True


class Employee(SQLAlchemyObjectType):
class Meta:
model = EmployeeModel
interfaces = (Person, Node)


##
# SQLAlchemyConnectionField
##
Expand Down Expand Up @@ -51,7 +65,7 @@ def resolver(_obj, _info):


def test_type_assert_sqlalchemy_object_type():
with pytest.raises(AssertionError, match="only accepts SQLAlchemyObjectType"):
with pytest.raises(AssertionError, match="only accepts SQLAlchemyBase types"):
SQLAlchemyConnectionField(ObjectType).type


Expand Down Expand Up @@ -91,3 +105,10 @@ def test_custom_sort():
def test_sort_init_raises():
with pytest.raises(TypeError, match="Cannot create sort"):
SQLAlchemyConnectionField(Connection)


def test_interface_required_sqlalachemy_connection():
field = SQLAlchemyConnectionField(Person.connection, required=True)
assert isinstance(field.type, NonNull)
assert issubclass(field.type.of_type, Connection)
assert field.type.of_type._meta.node is Person
Loading