diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index faedb8d2..72089613 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -146,12 +146,14 @@ def convert_column_to_float(type, column, registry=None): @convert_sqlalchemy_type.register(types.Enum) def convert_enum_to_enum(type, column, registry=None): - try: - items = type.enum_class.__members__.items() - except AttributeError: + enum_class = getattr(type, 'enum_class', None) + if enum_class: # Check if an enum.Enum type is used + graphene_type = Enum.from_enum(enum_class) + else: # Nope, just a list of string options items = zip(type.enums, type.enums) + graphene_type = Enum(type.name, items) return Field( - Enum(type.name, items), + graphene_type, description=get_column_doc(column), required=not (is_column_nullable(column)), ) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index be7d5cd2..3ba23a8a 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -6,6 +6,12 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import mapper, relationship + +class Hairkind(enum.Enum): + LONG = 'long' + SHORT = 'short' + + Base = declarative_base() association_table = Table( @@ -27,6 +33,7 @@ class Pet(Base): id = Column(Integer(), primary_key=True) name = Column(String(30)) pet_kind = Column(Enum("cat", "dog", name="pet_kind"), nullable=False) + hair_kind = Column(Enum(Hairkind, name="hair_kind"), nullable=False) reporter_id = Column(Integer(), ForeignKey("reporters.id")) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 9055f7fa..c2ec3e49 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -271,7 +271,7 @@ def test_should_postgresql_uuid_convert(): def test_should_postgresql_enum_convert(): field = assert_column_conversion( - postgresql.ENUM(enum.Enum("one", "two"), name="two_numbers"), graphene.Field + postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field ) field_type = field.type() assert field_type.__class__.__name__ == "two_numbers" @@ -279,6 +279,16 @@ def test_should_postgresql_enum_convert(): assert hasattr(field_type, "two") +def test_should_postgresql_py_enum_convert(): + field = assert_column_conversion( + postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), graphene.Field + ) + field_type = field.type() + assert field_type.__class__.__name__ == "TwoNumbers" + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, "two") + + def test_should_postgresql_array_convert(): assert_column_conversion(postgresql.ARRAY(types.Integer), graphene.List) diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index f8bc8403..f1116d9d 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -9,7 +9,7 @@ from ..fields import SQLAlchemyConnectionField from ..types import SQLAlchemyObjectType from ..utils import sort_argument_for_model, sort_enum_for_model -from .models import Article, Base, Editor, Pet, Reporter +from .models import Article, Base, Editor, Pet, Reporter, Hairkind db = create_engine("sqlite:///test_sqlalchemy.sqlite3") @@ -34,7 +34,7 @@ def session(): def setup_fixtures(session): - pet = Pet(name="Lassie", pet_kind="dog") + pet = Pet(name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG) session.add(pet) reporter = Reporter(first_name="ABA", last_name="X") session.add(reporter) @@ -105,16 +105,88 @@ def resolve_pet(self, *args, **kwargs): pet { name, petKind + hairKind } } """ - expected = {"pet": {"name": "Lassie", "petKind": "dog"}} + expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors assert result.data == expected, result.data +def test_enum_parameter(session): + setup_fixtures(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['pet_kind'].type.of_type)) + + def resolve_pet(self, info, kind=None, *args, **kwargs): + query = session.query(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind) + return query.first() + + query = """ + query PetQuery($kind: pet_kind) { + pet(kind: $kind) { + name, + petKind + hairKind + } + } + """ + expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} + schema = graphene.Schema(query=Query) + result = schema.execute(query, variables={"kind": "cat"}) + assert not result.errors + assert result.data == {"pet": None} + result = schema.execute(query, variables={"kind": "dog"}) + assert not result.errors + assert result.data == expected, result.data + + +def test_py_enum_parameter(session): + setup_fixtures(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['hair_kind'].type.of_type)) + + def resolve_pet(self, info, kind=None, *args, **kwargs): + query = session.query(Pet) + if kind: + # XXX Why kind passed in as a str instead of a Hairkind instance? + query = query.filter(Pet.hair_kind == Hairkind(kind)) + return query.first() + + query = """ + query PetQuery($kind: Hairkind) { + pet(kind: $kind) { + name, + petKind + hairKind + } + } + """ + expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} + schema = graphene.Schema(query=Query) + result = schema.execute(query, variables={"kind": "SHORT"}) + assert not result.errors + assert result.data == {"pet": None} + result = schema.execute(query, variables={"kind": "LONG"}) + assert not result.errors + assert result.data == expected, result.data + + def test_should_node(session): setup_fixtures(session) @@ -326,9 +398,9 @@ class Mutation(graphene.ObjectType): def sort_setup(session): pets = [ - Pet(id=2, name="Lassie", pet_kind="dog"), - Pet(id=22, name="Alf", pet_kind="cat"), - Pet(id=3, name="Barf", pet_kind="dog"), + Pet(id=2, name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG), + Pet(id=22, name="Alf", pet_kind="cat", hair_kind=Hairkind.LONG), + Pet(id=3, name="Barf", pet_kind="dog", hair_kind=Hairkind.LONG), ] session.add_all(pets) session.commit()