Skip to content

Commit

Permalink
[MSGPACK IDL] Gate feature by setting ENV (#2894)
Browse files Browse the repository at this point in the history
* Gate MSGPACK IDL feature by setting ENV

Signed-off-by: Future-Outlier <[email protected]>

* lint

Signed-off-by: Future-Outlier <[email protected]>

* Add async to def dict_to_old_generic_idl

Signed-off-by: Future-Outlier <[email protected]>

* print

Signed-off-by: Future-Outlier <[email protected]>

* Fix structured dataset bug

Signed-off-by: Future-Outlier <[email protected]>

* dict update

Signed-off-by: Future-Outlier <[email protected]>

* remvoe breakpoint

Signed-off-by: Future-Outlier <[email protected]>

* update

Signed-off-by: Future-Outlier <[email protected]>

* update

Signed-off-by: Future-Outlier <[email protected]>

* remove promise.py

Signed-off-by: Future-Outlier <[email protected]>

* ad tests

Signed-off-by: Future-Outlier <[email protected]>

* Add Comments

Signed-off-by: Future-Outlier <[email protected]>

* add commetns

Signed-off-by: Future-Outlier <[email protected]>

* apply naming suggestion from Kevin

Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: pingsutw  <[email protected]>

* nit

Signed-off-by: Future-Outlier <[email protected]>

* use flyte_use_old_dc_format as constant

Signed-off-by: Future-Outlier <[email protected]>

---------

Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: pingsutw <[email protected]>
  • Loading branch information
Future-Outlier and pingsutw authored Nov 6, 2024
1 parent a9f4f22 commit 2fbdc63
Show file tree
Hide file tree
Showing 14 changed files with 6,605 additions and 544 deletions.
3 changes: 3 additions & 0 deletions flytekit/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,8 @@
# Binary IDL Serialization Format
MESSAGEPACK = "msgpack"

# Use the old way to create protobuf struct for dict, dataclass, and pydantic basemodel.
FLYTE_USE_OLD_DC_FORMAT = "FLYTE_USE_OLD_DC_FORMAT"

# Set this environment variable to true to force the task to return non-zero exit code on failure.
FLYTE_FAIL_ON_ERROR = "FLYTE_FAIL_ON_ERROR"
6 changes: 6 additions & 0 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ async def resolve_attr_path_in_promise(p: Promise) -> Promise:
if len(p.attr_path) > 0 and type(curr_val.value) is _literals_models.Scalar:
# We keep it for reference task local execution in the future.
if type(curr_val.value.value) is _struct.Struct:
"""
Local execution currently has issues with struct attribute resolution.
This works correctly in remote execution.
Issue Link: https://github.com/flyteorg/flyte/issues/5959
"""

st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
Expand Down
103 changes: 94 additions & 9 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import inspect
import json
import mimetypes
import os
import sys
import textwrap
import threading
Expand All @@ -29,17 +30,17 @@
from google.protobuf.json_format import ParseDict as _ParseDict
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from mashumaro.codecs.json import JSONDecoder
from mashumaro.codecs.json import JSONDecoder, JSONEncoder
from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin

from flytekit.core.annotation import FlyteAnnotation
from flytekit.core.constants import MESSAGEPACK
from flytekit.core.constants import FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK
from flytekit.core.context_manager import FlyteContext
from flytekit.core.hash import HashMethod
from flytekit.core.type_helpers import load_type_from_tag
from flytekit.core.utils import load_proto_from_file, timeit
from flytekit.core.utils import load_proto_from_file, str2bool, timeit
from flytekit.exceptions import user as user_exceptions
from flytekit.interaction.string_literals import literal_map_string_repr
from flytekit.lazy_import.lazy_module import is_imported
Expand Down Expand Up @@ -498,7 +499,8 @@ class Test(DataClassJsonMixin):

def __init__(self) -> None:
super().__init__("Object-Dataclass-Transformer", object)
self._decoder: Dict[Type, JSONDecoder] = dict()
self._json_encoder: Dict[Type, JSONEncoder] = dict()
self._json_decoder: Dict[Type, JSONDecoder] = dict()

def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
# Skip iterating all attributes in the dataclass if the type of v already matches the expected_type
Expand Down Expand Up @@ -655,14 +657,58 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
)
)

# This is for attribute access in FlytePropeller.
ts = TypeStructure(tag="", dataclass_type=literal_type)

return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema, structure=ts)

def to_generic_literal(
self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType
) -> Literal:
"""
Serializes a dataclass or dictionary to a Flyte literal, handling both JSON and MessagePack formats.
Set `FLYTE_USE_OLD_DC_FORMAT=true` to use the old JSON-based format.
Note: This is deprecated and will be removed in the future.
"""
if isinstance(python_val, dict):
json_str = json.dumps(python_val)
return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())))

if not dataclasses.is_dataclass(python_val):
raise TypeTransformerFailedError(
f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for "
f"user defined datatypes in Flytekit"
)

self._make_dataclass_serializable(python_val, python_type)

# JSON serialization using mashumaro's DataClassJSONMixin
if isinstance(python_val, DataClassJSONMixin):
json_str = python_val.to_json()
else:
try:
encoder = self._json_encoder[python_type]
except KeyError:
encoder = JSONEncoder(python_type)
self._json_encoder[python_type] = encoder

try:
json_str = encoder.encode(python_val)
except NotImplementedError:
raise NotImplementedError(
f"{python_type} should inherit from mashumaro.types.SerializableType"
f" and implement _serialize and _deserialize methods."
)

return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if str2bool(os.getenv(FLYTE_USE_OLD_DC_FORMAT)):
return self.to_generic_literal(ctx, python_val, python_type, expected)

if isinstance(python_val, dict):
msgpack_bytes = msgpack.dumps(python_val)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))

if not dataclasses.is_dataclass(python_val):
raise TypeTransformerFailedError(
Expand Down Expand Up @@ -697,7 +743,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
f" and implement _serialize and _deserialize methods."
)

return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))

def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
# dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is
Expand Down Expand Up @@ -863,10 +909,10 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
# The function looks up or creates a JSONDecoder specifically designed for the object's type.
# This decoder is then used to convert a JSON string into a data class.
try:
decoder = self._decoder[expected_python_type]
decoder = self._json_decoder[expected_python_type]
except KeyError:
decoder = JSONDecoder(expected_python_type)
self._decoder[expected_python_type] = decoder
self._json_decoder[expected_python_type] = decoder

dc = decoder.decode(json_str)

Expand Down Expand Up @@ -1929,6 +1975,43 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
return _args # type: ignore
return None, None

@staticmethod
async def dict_to_generic_literal(
ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool
) -> Literal:
"""
This is deprecated from flytekit 1.14.0.
Creates a flyte-specific ``Literal`` value from a native python dictionary.
Note: This is deprecated and will be removed in the future.
"""
from flytekit.types.pickle import FlytePickle

try:
try:
# JSONEncoder is mashumaro's codec and this can triggered Flyte Types customized serialization and deserialization.
encoder = JSONEncoder(python_type)
json_str = encoder.encode(v)
except NotImplementedError:
raise NotImplementedError(
f"{python_type} should inherit from mashumaro.types.SerializableType"
f" and implement _serialize and _deserialize methods."
)

return Literal(
scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())),
metadata={"format": "json"},
)
except TypeError as e:
if allow_pickle:
remote_path = await FlytePickle.to_pickle(ctx, v)
return Literal(
scalar=Scalar(
generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), _struct.Struct())
),
metadata={"format": "pickle"},
)
raise TypeTransformerFailedError(f"Cannot convert `{v}` to Flyte Literal.\n" f"Error Message: {e}")

@staticmethod
async def dict_to_binary_literal(
ctx: FlyteContext, v: dict, python_type: Type[dict], allow_pickle: bool
Expand All @@ -1943,7 +2026,7 @@ async def dict_to_binary_literal(
# Handle dictionaries with non-string keys (e.g., Dict[int, Type])
encoder = MessagePackEncoder(python_type)
msgpack_bytes = encoder.encode(v)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))
except TypeError as e:
if allow_pickle:
remote_path = await FlytePickle.to_pickle(ctx, v)
Expand Down Expand Up @@ -2004,6 +2087,8 @@ async def async_to_literal(
allow_pickle, base_type = DictTransformer.is_pickle(python_type)

if expected and expected.simple and expected.simple == SimpleType.STRUCT:
if str2bool(os.getenv(FLYTE_USE_OLD_DC_FORMAT)):
return await self.dict_to_generic_literal(ctx, python_val, python_type, allow_pickle)
return await self.dict_to_binary_literal(ctx, python_val, python_type, allow_pickle)

lit_map = {}
Expand Down
4 changes: 2 additions & 2 deletions flytekit/extras/pydantic_transformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@

from . import transformer
except (ImportError, OSError) as e:
logger.warning(f"Meet error when importing pydantic: `{e}`")
logger.warning("Flytekit only support pydantic version > 2.")
logger.debug(f"Meet error when importing pydantic: `{e}`")
logger.debug("Flytekit only support pydantic version > 2.")
2 changes: 1 addition & 1 deletion flytekit/extras/pydantic_transformer/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
It looks nicer in the real Flyte File/Directory class, but we also want it to not fail.
"""

logger.warning(
logger.debug(
"Pydantic is not installed.\n" "Please install Pydantic version > 2 to use FlyteTypes in pydantic BaseModel."
)

Expand Down
22 changes: 21 additions & 1 deletion flytekit/extras/pydantic_transformer/transformer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import json
import os
from typing import Type

import msgpack
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
from pydantic import BaseModel

from flytekit import FlyteContext
from flytekit.core.constants import MESSAGEPACK
from flytekit.core.constants import FLYTE_USE_OLD_DC_FORMAT, MESSAGEPACK
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.core.utils import str2bool
from flytekit.loggers import logger
from flytekit.models import types
from flytekit.models.literals import Binary, Literal, Scalar
Expand All @@ -31,10 +34,24 @@ def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
"Field {} of type {} cannot be converted to a literal type. Error: {}".format(name, python_type, e)
)

# This is for attribute access in FlytePropeller.
ts = TypeStructure(tag="", dataclass_type=literal_type)

return types.LiteralType(simple=types.SimpleType.STRUCT, metadata=schema, structure=ts)

def to_generic_literal(
self,
ctx: FlyteContext,
python_val: BaseModel,
python_type: Type[BaseModel],
expected: types.LiteralType,
) -> Literal:
"""
Note: This is deprecated and will be removed in the future.
"""
json_str = python_val.model_dump_json()
return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())))

def to_literal(
self,
ctx: FlyteContext,
Expand All @@ -47,6 +64,9 @@ def to_literal(
This is for handling enum in basemodel.
More details: https://github.com/flyteorg/flytekit/pull/2792
"""
if str2bool(os.getenv(FLYTE_USE_OLD_DC_FORMAT)):
return self.to_generic_literal(ctx, python_val, python_type, expected)

json_str = python_val.model_dump_json()
dict_obj = json.loads(json_str)
msgpack_bytes = msgpack.dumps(dict_obj)
Expand Down
1 change: 1 addition & 0 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ async def async_to_literal(
# that we will need to invoke an encoder for. Figure out which encoder to call and invoke it.
df_type = type(python_val.dataframe)
protocol = self._protocol_from_type_or_prefix(ctx, df_type, python_val.uri)

return self.encode(
ctx,
python_val,
Expand Down
Loading

0 comments on commit 2fbdc63

Please sign in to comment.