Skip to content

Commit

Permalink
Improved Type engine for generic types and performance (#2815)
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Ketan Umare <[email protected]>
Co-authored-by: Yee Hing Tong <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
5 people authored Oct 17, 2024
1 parent 1f8a273 commit 945d2df
Show file tree
Hide file tree
Showing 7 changed files with 575 additions and 389 deletions.
299 changes: 123 additions & 176 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ def __init__(
self._to_literal_transformer = to_literal_transformer
self._from_literal_transformer = from_literal_transformer

# Read-only accessor for the Python type held in ``self._type``.
@property
def base_type(self) -> Type:
return self._type

# Round-trips the stored literal type (``self._lt``) through its IDL form and returns the
# result; the ``t`` argument is not consulted here.
def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType:
return LiteralType.from_flyte_idl(self._lt.to_flyte_idl())

Expand Down Expand Up @@ -909,8 +913,9 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:


def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any):
attribute_list: typing.List[tuple[Any, GenericAlias]] = []
attribute_list: typing.List[typing.Tuple[Any, Any]] = []
for property_key, property_val in schema["properties"].items():
property_type = ""
if property_val.get("anyOf"):
property_type = property_val["anyOf"][0]["type"]
elif property_val.get("enum"):
Expand All @@ -934,9 +939,8 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
)
)
elif property_val.get("additionalProperties"):
attribute_list.append(
(property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore
)
elem_type = _get_element_type(property_val["additionalProperties"])
attribute_list.append((property_key, typing.Dict[str, elem_type])) # type: ignore
else:
sub_schemea_name = property_val["title"]
attribute_list.append(
Expand Down Expand Up @@ -1003,114 +1007,64 @@ def register_additional_type(cls, transformer: TypeTransformer[T], additional_ty
cls._REGISTRY[additional_type] = transformer

@classmethod
def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
"""
The TypeEngine hierarchy for Flytekit. This method looks up and selects the type transformer. The algorithm is
as follows:
d = dictionary of registered transformers, where the key is a python `type`
v = lookup type
Step 1:
If the type is annotated with a TypeTransformer instance, use that.
Step 2:
find a transformer that matches v exactly
Step 3:
find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc
Step 4:
Walk the inheritance hierarchy of v and find a transformer that matches the first base class.
This is potentially non-deterministic - will depend on the registration pattern.
Special case:
If v inherits from Enum, use the Enum transformer even if Enum is not the first base class.
TODO: let's make this deterministic by using an ordered dict
Step 5:
if v is of type data class, use the dataclass transformer
Step 6:
Pickle transformer is used
"""
def _get_transformer(cls, python_type: Type) -> Optional[TypeTransformer[T]]:
cls.lazy_import_transformers()
# Step 1
if is_annotated(python_type):
args = get_args(python_type)
for annotation in args:
if isinstance(annotation, TypeTransformer):
return annotation
return cls.get_transformer(args[0])

python_type = args[0]

# Step 2
# this makes sure that if it's a list/dict of annotated types, we hit the unwrapping code in step 2
# see test_list_of_annotated in test_structured_dataset.py
if (
(not hasattr(python_type, "__origin__"))
or (
hasattr(python_type, "__origin__")
and (python_type.__origin__ is not list and python_type.__origin__ is not dict)
)
) and python_type in cls._REGISTRY:
return cls._REGISTRY[python_type]
if inspect.isclass(python_type) and issubclass(python_type, enum.Enum):
# Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used.
return cls._ENUM_TRANSFORMER

# Step 3
if hasattr(python_type, "__origin__"):
# Handling of annotated generics, eg:
# Annotated[typing.List[int], 'foo']
if is_annotated(python_type):
return cls.get_transformer(get_args(python_type)[0])

# If the type is a generic type, we should check the origin type. But a full generic such as Iterator[JSON]
# or List[int] may have been specifically registered, so we check for the entire type first.
# The challenge is StructuredDataset: for example, in List[StructuredDataset] the column names are an
# OrderedDict, which is not hashable, so looking up this type is not possible.
# In such a case, we have to skip the "type" lookup and use the origin type only.
try:
if python_type in cls._REGISTRY:
return cls._REGISTRY[python_type]
except TypeError:
pass
if python_type.__origin__ in cls._REGISTRY:
return cls._REGISTRY[python_type.__origin__]

raise ValueError(f"Generic Type {python_type.__origin__} not supported currently in Flytekit.")

# Step 4
# To facilitate cases where users may specify one transformer for multiple types that all inherit from one
# parent.
if inspect.isclass(python_type) and issubclass(python_type, enum.Enum):
# Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used.
return cls._ENUM_TRANSFORMER
# Handling UnionType specially - PEP 604
if sys.version_info >= (3, 10):
import types

from flytekit.types.iterator.json_iterator import JSONIterator
if isinstance(python_type, types.UnionType):
return cls._REGISTRY[types.UnionType]

for base_type in cls._REGISTRY.keys():
if base_type is None:
continue # None is actually one of the keys, but isinstance/issubclass doesn't work on it
try:
origin_type: Optional[typing.Any] = base_type
if hasattr(base_type, "__args__"):
origin_base_type = get_origin(base_type)
if isinstance(origin_base_type, type) and issubclass(
origin_base_type, typing.Iterator
): # Iterator[JSON]
origin_type = origin_base_type

if isinstance(python_type, origin_type) or ( # type: ignore[arg-type]
inspect.isclass(python_type) and issubclass(python_type, origin_type) # type: ignore[arg-type]
):
# Consider Iterator[JSON] but not vanilla Iterator when the value is a JSON iterator.
if (
isinstance(python_type, type)
and issubclass(python_type, JSONIterator)
and not get_args(base_type)
):
continue
return cls._REGISTRY[base_type]
except TypeError:
# As of python 3.9, calls to isinstance raise a TypeError if the base type is not a valid type, which
# is the case for one of the restricted types, namely NamedTuple.
logger.debug(f"Invalid base type {base_type} in call to isinstance", exc_info=True)
if python_type in cls._REGISTRY:
return cls._REGISTRY[python_type]

# Step 5
if dataclasses.is_dataclass(python_type):
return cls._DATACLASS_TRANSFORMER

# Step 6
return None

@classmethod
def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
"""
Looks up the transformer for ``python_type`` directly and, failing that, for each class in its MRO in order.
"""
v = cls._get_transformer(python_type)
if v is not None:
return v

if hasattr(python_type, "__mro__"):
class_tree = inspect.getmro(python_type)
for t in class_tree:
v = cls._get_transformer(t)
if v is not None:
return v

display_pickle_warning(str(python_type))
from flytekit.types.pickle.pickle import FlytePickleTransformer

Expand Down Expand Up @@ -2207,7 +2161,7 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
)
element_format = element_property["format"] if "format" in element_property else None

if type(element_type) == list:
if isinstance(element_type, list):
# Element type of Optional[int] is [integer, None]
return typing.Optional[_get_element_type({"type": element_type[0]})] # type: ignore

Expand Down Expand Up @@ -2255,89 +2209,82 @@ def _check_and_convert_void(lv: Literal) -> None:
return None


def _register_default_type_transformers():
TypeEngine.register(
SimpleTransformer(
"int",
int,
_type_models.LiteralType(simple=_type_models.SimpleType.INTEGER),
lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))),
lambda x: x.scalar.primitive.integer,
)
)

TypeEngine.register(
SimpleTransformer(
"float",
float,
_type_models.LiteralType(simple=_type_models.SimpleType.FLOAT),
lambda x: Literal(scalar=Scalar(primitive=Primitive(float_value=x))),
_check_and_covert_float,
)
)

TypeEngine.register(
SimpleTransformer(
"bool",
bool,
_type_models.LiteralType(simple=_type_models.SimpleType.BOOLEAN),
lambda x: Literal(scalar=Scalar(primitive=Primitive(boolean=x))),
lambda x: x.scalar.primitive.boolean,
)
)
# Module-level instances of the built-in scalar transformers. Each SimpleTransformer call below
# passes, in order: a display name, the Python type, the Flyte LiteralType, a converter from a
# Python value to a Literal, and a converter from a Literal back to a Python value.
IntTransformer = SimpleTransformer(
"int",
int,
_type_models.LiteralType(simple=_type_models.SimpleType.INTEGER),
lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))),
lambda x: x.scalar.primitive.integer,
)

# Unlike the other transformers here, float uses a named helper (defined elsewhere in the
# module) instead of an inline lambda for the Literal -> Python direction.
FloatTransformer = SimpleTransformer(
"float",
float,
_type_models.LiteralType(simple=_type_models.SimpleType.FLOAT),
lambda x: Literal(scalar=Scalar(primitive=Primitive(float_value=x))),
_check_and_covert_float,
)

BoolTransformer = SimpleTransformer(
"bool",
bool,
_type_models.LiteralType(simple=_type_models.SimpleType.BOOLEAN),
lambda x: Literal(scalar=Scalar(primitive=Primitive(boolean=x))),
lambda x: x.scalar.primitive.boolean,
)

StrTransformer = SimpleTransformer(
"str",
str,
_type_models.LiteralType(simple=_type_models.SimpleType.STRING),
lambda x: Literal(scalar=Scalar(primitive=Primitive(string_value=x))),
lambda x: x.scalar.primitive.string_value,
)

DatetimeTransformer = SimpleTransformer(
"datetime",
datetime.datetime,
_type_models.LiteralType(simple=_type_models.SimpleType.DATETIME),
lambda x: Literal(scalar=Scalar(primitive=Primitive(datetime=x))),
lambda x: x.scalar.primitive.datetime,
)

TimedeltaTransformer = SimpleTransformer(
"timedelta",
datetime.timedelta,
_type_models.LiteralType(simple=_type_models.SimpleType.DURATION),
lambda x: Literal(scalar=Scalar(primitive=Primitive(duration=x))),
lambda x: x.scalar.primitive.duration,
)

# date shares the DATETIME literal type with datetime: values are widened to a datetime at
# midnight on the way in and narrowed back with .date() on the way out.
DateTransformer = SimpleTransformer(
"date",
datetime.date,
_type_models.LiteralType(simple=_type_models.SimpleType.DATETIME),
lambda x: Literal(
scalar=Scalar(primitive=Primitive(datetime=datetime.datetime.combine(x, datetime.time.min)))
), # convert date to datetime (at midnight, datetime.time.min)
lambda x: x.scalar.primitive.datetime.date(), # get date from datetime
)

# The Python -> Literal converter ignores its argument and always emits a Void scalar.
NoneTransformer = SimpleTransformer(
"none",
type(None),
_type_models.LiteralType(simple=_type_models.SimpleType.NONE),
lambda x: Literal(scalar=Scalar(none_type=Void())),
lambda x: _check_and_convert_void(x),
)

TypeEngine.register(
SimpleTransformer(
"str",
str,
_type_models.LiteralType(simple=_type_models.SimpleType.STRING),
lambda x: Literal(scalar=Scalar(primitive=Primitive(string_value=x))),
lambda x: x.scalar.primitive.string_value,
)
)

TypeEngine.register(
SimpleTransformer(
"datetime",
datetime.datetime,
_type_models.LiteralType(simple=_type_models.SimpleType.DATETIME),
lambda x: Literal(scalar=Scalar(primitive=Primitive(datetime=x))),
lambda x: x.scalar.primitive.datetime,
)
)

TypeEngine.register(
SimpleTransformer(
"timedelta",
datetime.timedelta,
_type_models.LiteralType(simple=_type_models.SimpleType.DURATION),
lambda x: Literal(scalar=Scalar(primitive=Primitive(duration=x))),
lambda x: x.scalar.primitive.duration,
)
)

TypeEngine.register(
SimpleTransformer(
"date",
datetime.date,
_type_models.LiteralType(simple=_type_models.SimpleType.DATETIME),
lambda x: Literal(
scalar=Scalar(primitive=Primitive(datetime=datetime.datetime.combine(x, datetime.time.min)))
), # convert datetime to date
lambda x: x.scalar.primitive.datetime.date(), # get date from datetime
)
)

TypeEngine.register(
SimpleTransformer(
"none",
type(None),
_type_models.LiteralType(simple=_type_models.SimpleType.NONE),
lambda x: Literal(scalar=Scalar(none_type=Void())),
lambda x: _check_and_convert_void(x),
),
[None],
)
def _register_default_type_transformers():
TypeEngine.register(IntTransformer)
TypeEngine.register(FloatTransformer)
TypeEngine.register(StrTransformer)
TypeEngine.register(DatetimeTransformer)
TypeEngine.register(DateTransformer)
TypeEngine.register(TimedeltaTransformer)
TypeEngine.register(BoolTransformer)
TypeEngine.register(NoneTransformer, [None]) # noqa
TypeEngine.register(ListTransformer())
if sys.version_info >= (3, 10):
from types import UnionType
Expand Down
Loading

0 comments on commit 945d2df

Please sign in to comment.