Skip to content

Commit

Permalink
Use to_flyte_literal_type() when comparing schema columns (#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
wild-endeavor authored Jun 22, 2020
1 parent 58fd9cd commit a33f156
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

import flytekit.plugins

__version__ = '0.9.3'
__version__ = '0.9.4'
3 changes: 2 additions & 1 deletion flytekit/common/types/impl/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,8 @@ def cast_to(self, other_type):
additional_msg="Cannot cast because a required column '{}' was not found.".format(k),
received_value=self
)
if v != self.type.sdk_columns[k]:
if not isinstance(v, _base_sdk_types.FlyteSdkType) or \
v.to_flyte_literal_type() != self.type.sdk_columns[k].to_flyte_literal_type():
raise _user_exceptions.FlyteTypeException(
self.type.sdk_columns[k],
v,
Expand Down
17 changes: 17 additions & 0 deletions tests/flytekit/unit/common_tests/types/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,20 @@ def test_typed_schema():
assert len(b.type.columns) == len(_ALL_COLUMN_TYPES)
assert list(b.type.sdk_columns.items()) == _ALL_COLUMN_TYPES
assert b.remote_location.startswith(t.name)


# Ensures that subclassing types works inside a schema.
def test_casting():
class MyDateTime(primitives.Datetime):
...

with test_utils.LocalTestFileSystem() as t:
test_columns_1 = [('altered', MyDateTime)]
test_columns_2 = [('altered', primitives.Datetime)]

instantiator_1 = schema.schema_instantiator(test_columns_1)
a = instantiator_1()

instantiator_2 = schema.schema_instantiator(test_columns_2)

a.cast_to(instantiator_2._schema_type)

0 comments on commit a33f156

Please sign in to comment.