Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core feature] Flytekit should support unsafe mode for types #2419

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
9 changes: 8 additions & 1 deletion flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def transform_function_to_interface(
fn: typing.Callable,
docstring: Optional[Docstring] = None,
is_reference_entity: bool = False,
pickle_untyped: bool = False,
) -> Interface:
"""
From the annotations on a task function that the user should have provided, and the output names they want to use
Expand All @@ -395,6 +396,9 @@ def transform_function_to_interface(
type_hints = get_type_hints(fn, include_extras=True)
signature = inspect.signature(fn)
return_annotation = type_hints.get("return", None)
# If the return annotation is None and the pickle_untyped is True, we will use it as Any
if return_annotation is None and pickle_untyped:
return_annotation = Any

ctx = FlyteContextManager.current_context()

Expand All @@ -420,7 +424,10 @@ def transform_function_to_interface(
for k, v in signature.parameters.items(): # type: ignore
annotation = type_hints.get(k, None)
if annotation is None:
raise FlyteMissingTypeException(fn=fn, param_name=k)
if not pickle_untyped:
raise FlyteMissingTypeException(fn=fn, param_name=k)
# If the pickle_untyped is True, we will use it as Any
annotation = Any
default = v.default if v.default is not inspect.Parameter.empty else None
# Inputs with default values are currently ignored, we may want to look into that in the future
inputs[k] = (annotation, default) # type: ignore
Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
node_dependency_hints: Optional[
Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]
] = None,
pickle_untyped: bool = False,
Copy link
Member

@thomasjpfan thomasjpfan Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kumare3 Since you proposed the safe parameter name, are you okay with using pickle_untyped?

My concern is that safe=False could mean many different unsafe behaviors. pickle_untyped is explicit about the behavior.

**kwargs,
):
"""
Expand All @@ -125,10 +126,15 @@ def __init__(
:param Optional[Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]] node_dependency_hints:
A list of tasks, launchplans, or workflows that this task depends on. This is only
for dynamic tasks/workflows, where flyte cannot automatically determine the dependencies prior to runtime.
:param bool pickle_untyped: If set to True, the task will pickle untyped outputs. This is just a convenience
flag to avoid having to specify the output types in the interface. This is not recommended for production
use.
"""
if task_function is None:
raise ValueError("TaskFunction is a required parameter for PythonFunctionTask")
self._native_interface = transform_function_to_interface(task_function, Docstring(callable_=task_function))
self._native_interface = transform_function_to_interface(
task_function, Docstring(callable_=task_function), pickle_untyped=pickle_untyped
)
mutated_interface = self._native_interface.remove_inputs(ignore_input_vars)
name, _, _, _ = extract_task_module(task_function)
super().__init__(
Expand Down
5 changes: 5 additions & 0 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def task(
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ...


Expand Down Expand Up @@ -168,6 +169,7 @@ def task(
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ...


Expand Down Expand Up @@ -211,6 +213,7 @@ def task(
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
pickle_untyped: bool = False,
) -> Union[
Callable[P, FuncOut],
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]],
Expand Down Expand Up @@ -333,6 +336,7 @@ def launch_dynamically():
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
:param pickle_untyped: Boolean that indicates if the task allows unspecified data types.
"""

def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
Expand Down Expand Up @@ -368,6 +372,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
pod_template=pod_template,
pod_template_name=pod_template_name,
accelerator=accelerator,
pickle_untyped=pickle_untyped,
)
update_wrapper(task_instance, decorated_fn)
return task_instance
Expand Down
7 changes: 6 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,12 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
"actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
"return v.x, instead of v, even if this has a single element"
)
if (python_val is None and python_type != type(None)) and expected and expected.union_type is None:
if (
(python_val is None and python_type != type(None))
and expected
and expected.union_type is None
and python_type is not Any
):
raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")
transformer = cls.get_transformer(python_type)
if transformer.type_assertions_enabled:
Expand Down
11 changes: 10 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,11 +668,14 @@ def __init__(
docstring: Optional[Docstring] = None,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
pickle_untyped: bool = False,
default_options: Optional[Options] = None,
):
name, _, _, _ = extract_task_module(workflow_function)
self._workflow_function = workflow_function
native_interface = transform_function_to_interface(workflow_function, docstring=docstring)
native_interface = transform_function_to_interface(
workflow_function, docstring=docstring, pickle_untyped=pickle_untyped
)

# TODO do we need this - can this not be in launchplan only?
# This can be in launch plan only, but is here only so that we don't have to re-evaluate. Or
Expand Down Expand Up @@ -837,6 +840,7 @@ def workflow(
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
pickle_untyped: bool = ...,
default_options: Optional[Options] = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]: ...

Expand All @@ -848,6 +852,7 @@ def workflow(
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
pickle_untyped: bool = ...,
default_options: Optional[Options] = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ...

Expand All @@ -858,6 +863,7 @@ def workflow(
interruptible: bool = False,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
pickle_untyped: bool = False,
default_options: Optional[Options] = None,
) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
Expand Down Expand Up @@ -890,6 +896,8 @@ def workflow(
:param on_failure: Invoke this workflow or task on failure. The Workflow / task has to match the signature of
the current workflow, with an additional parameter called `error` Error
:param docs: Description entity for the workflow
:param pickle_untyped: This is a flag that allows users to bypass the type-checking that Flytekit does when constructing
the workflow. This is not recommended for general use.
:param default_options: Default options for the workflow when creating a default launch plan. Currently only
the labels and annotations are allowed to be set as defaults.
"""
Expand All @@ -906,6 +914,7 @@ def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
docstring=Docstring(callable_=fn),
on_failure=on_failure,
docs=docs,
pickle_untyped=pickle_untyped,
default_options=default_options,
)
update_wrapper(workflow_instance, fn)
Expand Down
11 changes: 7 additions & 4 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import typing
from typing import Type
from typing import Optional, Type

import cloudpickle

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Void
from flytekit.models.types import LiteralType

T = typing.TypeVar("T")
Expand Down Expand Up @@ -86,13 +86,16 @@ def assert_type(self, t: Type[T], v: T):
# Every type can serialize to pickle, so we don't need to check the type here.
...

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]:
if lv.scalar.blob is None:
return None
uri = lv.scalar.blob.uri
return FlytePickle.from_pickle(uri)

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if python_val is None:
raise AssertionError("Cannot pickle None Value.")
# raise AssertionError("Cannot pickle None Value.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sure. I'll remove that shortly.

return Literal(scalar=Scalar(none_type=Void()))
meta = BlobMetadata(
type=_core_types.BlobType(
format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
Expand Down
6 changes: 3 additions & 3 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2070,9 +2070,9 @@ def __init__(self, number: int):
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
assert Foo(1).number == pv.number

with pytest.raises(AssertionError, match="Cannot pickle None Value"):
lt = TypeEngine.to_literal_type(typing.Optional[typing.Any])
TypeEngine.to_literal(ctx, None, FlytePickle, lt)
lt = TypeEngine.to_literal_type(typing.Optional[typing.Any])
lv = TypeEngine.to_literal(ctx, None, FlytePickle, lt)
assert lv.scalar.none_type == Void()

with pytest.raises(
AssertionError,
Expand Down
114 changes: 111 additions & 3 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from flytekit.core.testing import patch, task_mock
from flytekit.core.type_engine import RestrictedTypeError, SimpleTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteValidationException, FlyteFailureNodeInputMismatchException
from flytekit.exceptions.user import FlyteValidationException, FlyteFailureNodeInputMismatchException, FlyteMissingTypeException
from flytekit.models import literals as _literal_models
from flytekit.models.core import types as _core_types
from flytekit.models.interface import Parameter
Expand Down Expand Up @@ -82,7 +82,9 @@ def test_forwardref_namedtuple_output():
# This test case tests typing.NamedTuple outputs for cases where eg.
# from __future__ import annotations is enabled, such that all type hints become ForwardRef
@task
def my_task(a: int) -> typing.NamedTuple("OutputsBC", b=typing.ForwardRef("int"), c=typing.ForwardRef("str")):
def my_task(
a: int,
) -> typing.NamedTuple("OutputsBC", b=typing.ForwardRef("int"), c=typing.ForwardRef("str")):
ctx = flytekit.current_context()
assert str(ctx.execution_id) == "ex:local:local:local"
return a + 2, "hello world"
Expand Down Expand Up @@ -1967,7 +1969,12 @@ def wf() -> pd.DataFrame:

df = wf()

expected_df = pd.DataFrame(data={"col1": [1 + 10 + 100, 2 + 20 + 200], "col2": [3 + 30 + 300, 4 + 40 + 400]})
expected_df = pd.DataFrame(
data={
"col1": [1 + 10 + 100, 2 + 20 + 200],
"col2": [3 + 30 + 300, 4 + 40 + 400],
}
)
assert expected_df.equals(df)


Expand Down Expand Up @@ -2052,3 +2059,104 @@ def my_wf(a: int, retries: int) -> int:

with pytest.raises(AssertionError):
my_wf(a=1, retries=1)


def test_pickle_untyped_input_wf_and_task():
    """Un-annotated inputs are accepted only when ``pickle_untyped=True`` is set."""

    @task(pickle_untyped=True)
    def t1(a) -> int:
        # Only ints are incremented; any other input (str, None, ...) maps to 0.
        if isinstance(a, int):
            return a + 1
        return 0

    # Without the flag, the missing annotation is rejected while decorating.
    with pytest.raises(FlyteMissingTypeException):

        @task
        def t2_wo_pickle_untyped(a) -> int:
            return a + 1

    @workflow(pickle_untyped=True)
    def wf1_with_pickle_untyped(a) -> int:
        return t1(a=a)

    assert wf1_with_pickle_untyped(a=1) == 2
    assert wf1_with_pickle_untyped(a="1") == 0
    assert wf1_with_pickle_untyped(a=None) == 0

    # The same restriction applies to workflows declared without the flag.
    with pytest.raises(FlyteMissingTypeException):

        @workflow
        def wf1_wo_pickle_untyped(a) -> int:
            return t1(a=a)


def test_pickle_untyped_wf_and_task():
    """Fully un-annotated tasks can be chained inside an un-annotated workflow."""

    @task(pickle_untyped=True)
    def t1(a):
        # Non-int inputs short-circuit to None so the None path is exercised too.
        if not isinstance(a, int):
            return None
        return a + 1

    @task(pickle_untyped=True)
    def t2(a):
        if not isinstance(a, int):
            return None
        return a + 2

    @workflow(pickle_untyped=True)
    def wf1_with_pickle_untyped(a):
        a1 = t1(a=a)
        return t2(a=a1)

    # 1 -> t1 -> 2 -> t2 -> 4
    assert wf1_with_pickle_untyped(a=1) == 4
    # A str input yields None from t1, which t2 passes through as None.
    assert wf1_with_pickle_untyped(a="1") is None


def test_wf_with_pickle_untyped_and_safe_tasks():
    """An un-annotated task and a ``typing.Any``-annotated task chain in either order."""

    @task(pickle_untyped=True)
    def t1(a):
        if not isinstance(a, int):
            return None
        return a + 1

    @task
    def t2(a: typing.Any) -> typing.Any:
        if not isinstance(a, int):
            return None
        return a + 2

    # Untyped task feeding the explicitly typed one.
    @workflow(pickle_untyped=True)
    def wf1_with_pickle_untyped(a):
        a1 = t1(a=a)
        return t2(a=a1)

    assert wf1_with_pickle_untyped(a=1) == 4
    assert wf1_with_pickle_untyped(a="1") is None

    # Explicitly typed task feeding the untyped one.
    @workflow(pickle_untyped=True)
    def wf2_with_pickle_untyped(a):
        a1 = t2(a=a)
        return t1(a=a1)

    assert wf2_with_pickle_untyped(a=1) == 4
    assert wf2_with_pickle_untyped(a="1") is None


def test_pickle_untyped_task_with_specified_input():
    """Annotated and un-annotated parameters can be mixed under ``pickle_untyped=True``."""

    @task(pickle_untyped=True)
    def t1(a, b: typing.Any):
        # Sum the int arguments; return the lone int if only one is an int,
        # and None when neither is.
        a_is_int = isinstance(a, int)
        b_is_int = isinstance(b, int)
        if a_is_int and b_is_int:
            return a + b
        if a_is_int:
            return a
        if b_is_int:
            return b
        return None

    @workflow(pickle_untyped=True)
    def wf1_with_pickle_untyped(a: typing.Any, b):
        return t1(a=a, b=b)

    assert wf1_with_pickle_untyped(a=1, b=2) == 3
    assert wf1_with_pickle_untyped(a="1", b=2) == 2
    assert wf1_with_pickle_untyped(a=1, b="2") == 1
    assert wf1_with_pickle_untyped(a="1", b="2") is None
Loading