Skip to content

Commit

Permalink
Merge pull request #154 from curvetips/fix-enum-conversion
Browse files Browse the repository at this point in the history
Fix creation of graphene.Enum from enum.Enum
  • Loading branch information
syrusakbary authored Oct 26, 2018
2 parents 33d5b74 + d4365e1 commit 6a96d37
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 11 deletions.
10 changes: 6 additions & 4 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,14 @@ def convert_column_to_float(type, column, registry=None):

@convert_sqlalchemy_type.register(types.Enum)
def convert_enum_to_enum(type, column, registry=None):
try:
items = type.enum_class.__members__.items()
except AttributeError:
enum_class = getattr(type, 'enum_class', None)
if enum_class: # Check if an enum.Enum type is used
graphene_type = Enum.from_enum(enum_class)
else: # Nope, just a list of string options
items = zip(type.enums, type.enums)
graphene_type = Enum(type.name, items)
return Field(
Enum(type.name, items),
graphene_type,
description=get_column_doc(column),
required=not (is_column_nullable(column)),
)
Expand Down
7 changes: 7 additions & 0 deletions graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import mapper, relationship


class Hairkind(enum.Enum):
LONG = 'long'
SHORT = 'short'


Base = declarative_base()

association_table = Table(
Expand All @@ -27,6 +33,7 @@ class Pet(Base):
id = Column(Integer(), primary_key=True)
name = Column(String(30))
pet_kind = Column(Enum("cat", "dog", name="pet_kind"), nullable=False)
hair_kind = Column(Enum(Hairkind, name="hair_kind"), nullable=False)
reporter_id = Column(Integer(), ForeignKey("reporters.id"))


Expand Down
12 changes: 11 additions & 1 deletion graphene_sqlalchemy/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,14 +271,24 @@ def test_should_postgresql_uuid_convert():

def test_should_postgresql_enum_convert():
field = assert_column_conversion(
postgresql.ENUM(enum.Enum("one", "two"), name="two_numbers"), graphene.Field
postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field
)
field_type = field.type()
assert field_type.__class__.__name__ == "two_numbers"
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, "two")


def test_should_postgresql_py_enum_convert():
field = assert_column_conversion(
postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), graphene.Field
)
field_type = field.type()
assert field_type.__class__.__name__ == "TwoNumbers"
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, "two")


def test_should_postgresql_array_convert():
assert_column_conversion(postgresql.ARRAY(types.Integer), graphene.List)

Expand Down
84 changes: 78 additions & 6 deletions graphene_sqlalchemy/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..fields import SQLAlchemyConnectionField
from ..types import SQLAlchemyObjectType
from ..utils import sort_argument_for_model, sort_enum_for_model
from .models import Article, Base, Editor, Pet, Reporter
from .models import Article, Base, Editor, Pet, Reporter, Hairkind

db = create_engine("sqlite:///test_sqlalchemy.sqlite3")

Expand All @@ -34,7 +34,7 @@ def session():


def setup_fixtures(session):
pet = Pet(name="Lassie", pet_kind="dog")
pet = Pet(name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG)
session.add(pet)
reporter = Reporter(first_name="ABA", last_name="X")
session.add(reporter)
Expand Down Expand Up @@ -105,16 +105,88 @@ def resolve_pet(self, *args, **kwargs):
pet {
name,
petKind
hairKind
}
}
"""
expected = {"pet": {"name": "Lassie", "petKind": "dog"}}
expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}}
schema = graphene.Schema(query=Query)
result = schema.execute(query)
assert not result.errors
assert result.data == expected, result.data


def test_enum_parameter(session):
setup_fixtures(session)

class PetType(SQLAlchemyObjectType):
class Meta:
model = Pet

class Query(graphene.ObjectType):
pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['pet_kind'].type.of_type))

def resolve_pet(self, info, kind=None, *args, **kwargs):
query = session.query(Pet)
if kind:
query = query.filter(Pet.pet_kind == kind)
return query.first()

query = """
query PetQuery($kind: pet_kind) {
pet(kind: $kind) {
name,
petKind
hairKind
}
}
"""
expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}}
schema = graphene.Schema(query=Query)
result = schema.execute(query, variables={"kind": "cat"})
assert not result.errors
assert result.data == {"pet": None}
result = schema.execute(query, variables={"kind": "dog"})
assert not result.errors
assert result.data == expected, result.data


def test_py_enum_parameter(session):
setup_fixtures(session)

class PetType(SQLAlchemyObjectType):
class Meta:
model = Pet

class Query(graphene.ObjectType):
pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['hair_kind'].type.of_type))

def resolve_pet(self, info, kind=None, *args, **kwargs):
query = session.query(Pet)
if kind:
# XXX Why kind passed in as a str instead of a Hairkind instance?
query = query.filter(Pet.hair_kind == Hairkind(kind))
return query.first()

query = """
query PetQuery($kind: Hairkind) {
pet(kind: $kind) {
name,
petKind
hairKind
}
}
"""
expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}}
schema = graphene.Schema(query=Query)
result = schema.execute(query, variables={"kind": "SHORT"})
assert not result.errors
assert result.data == {"pet": None}
result = schema.execute(query, variables={"kind": "LONG"})
assert not result.errors
assert result.data == expected, result.data


def test_should_node(session):
setup_fixtures(session)

Expand Down Expand Up @@ -326,9 +398,9 @@ class Mutation(graphene.ObjectType):

def sort_setup(session):
pets = [
Pet(id=2, name="Lassie", pet_kind="dog"),
Pet(id=22, name="Alf", pet_kind="cat"),
Pet(id=3, name="Barf", pet_kind="dog"),
Pet(id=2, name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG),
Pet(id=22, name="Alf", pet_kind="cat", hair_kind=Hairkind.LONG),
Pet(id=3, name="Barf", pet_kind="dog", hair_kind=Hairkind.LONG),
]
session.add_all(pets)
session.commit()
Expand Down

0 comments on commit 6a96d37

Please sign in to comment.