Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core feature] Flytekit should support unsafe mode for types #2419

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
9 changes: 8 additions & 1 deletion flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def transform_function_to_interface(
fn: typing.Callable,
docstring: Optional[Docstring] = None,
is_reference_entity: bool = False,
unsafe: bool = False,
) -> Interface:
"""
From the annotations on a task function that the user should have provided, and the output names they want to use
Expand All @@ -395,6 +396,9 @@ def transform_function_to_interface(
type_hints = get_type_hints(fn, include_extras=True)
signature = inspect.signature(fn)
return_annotation = type_hints.get("return", None)
# If there is no return annotation and unsafe is True, fall back to Any
if return_annotation is None and unsafe:
return_annotation = Any

ctx = FlyteContextManager.current_context()

Expand All @@ -420,7 +424,10 @@ def transform_function_to_interface(
for k, v in signature.parameters.items(): # type: ignore
annotation = type_hints.get(k, None)
if annotation is None:
raise FlyteMissingTypeException(fn=fn, param_name=k)
if not unsafe:
raise FlyteMissingTypeException(fn=fn, param_name=k)
# unsafe is True, so fall back to Any for the un-annotated parameter
annotation = Any
default = v.default if v.default is not inspect.Parameter.empty else None
# Inputs with default values are currently ignored, we may want to look into that in the future
inputs[k] = (annotation, default) # type: ignore
Expand Down
5 changes: 4 additions & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
node_dependency_hints: Optional[
Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]
] = None,
unsafe: bool = False,
**kwargs,
):
"""
Expand All @@ -128,7 +129,9 @@ def __init__(
"""
if task_function is None:
raise ValueError("TaskFunction is a required parameter for PythonFunctionTask")
self._native_interface = transform_function_to_interface(task_function, Docstring(callable_=task_function))
self._native_interface = transform_function_to_interface(
task_function, Docstring(callable_=task_function), unsafe=unsafe
)
mutated_interface = self._native_interface.remove_inputs(ignore_input_vars)
name, _, _, _ = extract_task_module(task_function)
super().__init__(
Expand Down
5 changes: 5 additions & 0 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def task(
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
unsafe: bool = ...,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to have a more explicit name. Something like:

Suggested change
unsafe: bool = ...,
pickle_untyped: bool = ...,

) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ...


Expand Down Expand Up @@ -168,6 +169,7 @@ def task(
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
unsafe: bool = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ...


Expand Down Expand Up @@ -211,6 +213,7 @@ def task(
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
unsafe: bool = False,
) -> Union[
Callable[P, FuncOut],
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]],
Expand Down Expand Up @@ -333,6 +336,7 @@ def launch_dynamically():
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
:param unsafe: Boolean that indicates if the task allows unspecified data types.
"""

def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
Expand Down Expand Up @@ -368,6 +372,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
pod_template=pod_template,
pod_template_name=pod_template_name,
accelerator=accelerator,
unsafe=unsafe,
)
update_wrapper(task_instance, decorated_fn)
return task_instance
Expand Down
9 changes: 7 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from typing import Dict, List, NamedTuple, Optional, Type, cast
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast

from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
Expand Down Expand Up @@ -1124,7 +1124,12 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
"actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
"return v.x, instead of v, even if this has a single element"
)
if (python_val is None and python_type != type(None)) and expected and expected.union_type is None:
if (
(python_val is None and python_type != type(None))
and expected
and expected.union_type is None
and python_type is not Any
):
raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")
transformer = cls.get_transformer(python_type)
if transformer.type_assertions_enabled:
Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,10 +661,11 @@ def __init__(
docstring: Optional[Docstring] = None,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
unsafe: bool = False,
):
name, _, _, _ = extract_task_module(workflow_function)
self._workflow_function = workflow_function
native_interface = transform_function_to_interface(workflow_function, docstring=docstring)
native_interface = transform_function_to_interface(workflow_function, docstring=docstring, unsafe=unsafe)

# TODO do we need this - can this not be in launchplan only?
# This can be in launch plan only, but is here only so that we don't have to re-evaluate. Or
Expand Down Expand Up @@ -828,6 +829,7 @@ def workflow(
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
unsafe: bool = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]: ...


Expand All @@ -847,6 +849,7 @@ def workflow(
interruptible: bool = False,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
unsafe: bool = False,
) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
Expand Down Expand Up @@ -878,6 +881,8 @@ def workflow(
:param on_failure: Invoke this workflow or task on failure. The Workflow / task has to match the signature of
the current workflow, with an additional parameter called `error` Error
:param docs: Description entity for the workflow
:param unsafe: This is a flag that allows users to bypass the type-checking that Flytekit does when constructing
the workflow. This is not recommended for general use.
"""

def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
Expand All @@ -892,6 +897,7 @@ def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
docstring=Docstring(callable_=fn),
on_failure=on_failure,
docs=docs,
unsafe=unsafe,
)
update_wrapper(workflow_instance, fn)
return workflow_instance
Expand Down
11 changes: 7 additions & 4 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
import typing
from typing import Type
from typing import Optional, Type

import cloudpickle

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Void
from flytekit.models.types import LiteralType

T = typing.TypeVar("T")
Expand Down Expand Up @@ -86,13 +86,16 @@ def assert_type(self, t: Type[T], v: T):
# Every type can serialize to pickle, so we don't need to check the type here.
...

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]:
if lv.scalar.blob is None:
return None
uri = lv.scalar.blob.uri
return FlytePickle.from_pickle(uri)

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if python_val is None:
raise AssertionError("Cannot pickle None Value.")
# raise AssertionError("Cannot pickle None Value.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sure. I'll remove that shortly.

return Literal(scalar=Scalar(none_type=Void()))
meta = BlobMetadata(
type=_core_types.BlobType(
format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
Expand Down
6 changes: 3 additions & 3 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2061,9 +2061,9 @@ def __init__(self, number: int):
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
assert Foo(1).number == pv.number

with pytest.raises(AssertionError, match="Cannot pickle None Value"):
lt = TypeEngine.to_literal_type(typing.Optional[typing.Any])
TypeEngine.to_literal(ctx, None, FlytePickle, lt)
lt = TypeEngine.to_literal_type(typing.Optional[typing.Any])
lv = TypeEngine.to_literal(ctx, None, FlytePickle, lt)
assert lv.scalar.none_type == Void()

with pytest.raises(
AssertionError,
Expand Down
114 changes: 111 additions & 3 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from flytekit.core.testing import patch, task_mock
from flytekit.core.type_engine import RestrictedTypeError, SimpleTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteValidationException, FlyteFailureNodeInputMismatchException
from flytekit.exceptions.user import FlyteValidationException, FlyteFailureNodeInputMismatchException, FlyteMissingTypeException
from flytekit.models import literals as _literal_models
from flytekit.models.core import types as _core_types
from flytekit.models.interface import Parameter
Expand Down Expand Up @@ -81,7 +81,9 @@ def test_forwardref_namedtuple_output():
# This test case tests typing.NamedTuple outputs for cases where eg.
# from __future__ import annotations is enabled, such that all type hints become ForwardRef
@task
def my_task(a: int) -> typing.NamedTuple("OutputsBC", b=typing.ForwardRef("int"), c=typing.ForwardRef("str")):
def my_task(
a: int,
) -> typing.NamedTuple("OutputsBC", b=typing.ForwardRef("int"), c=typing.ForwardRef("str")):
ctx = flytekit.current_context()
assert str(ctx.execution_id) == "ex:local:local:local"
return a + 2, "hello world"
Expand Down Expand Up @@ -1967,7 +1969,12 @@ def wf() -> pd.DataFrame:

df = wf()

expected_df = pd.DataFrame(data={"col1": [1 + 10 + 100, 2 + 20 + 200], "col2": [3 + 30 + 300, 4 + 40 + 400]})
expected_df = pd.DataFrame(
data={
"col1": [1 + 10 + 100, 2 + 20 + 200],
"col2": [3 + 30 + 300, 4 + 40 + 400],
}
)
assert expected_df.equals(df)


Expand Down Expand Up @@ -2052,3 +2059,104 @@ def my_wf(a: int, retries: int) -> int:

with pytest.raises(AssertionError):
my_wf(a=1, retries=1)


def test_unsafe_input_wf_and_task():
    """Check that ``unsafe=True`` accepts un-annotated inputs on tasks and workflows.

    Without the flag, decorating a task or workflow whose parameter has no
    type hint is expected to raise ``FlyteMissingTypeException``.
    """

    @task(unsafe=True)
    def t1(a) -> int:
        # ``a`` has no annotation, so values of any type may arrive here.
        if isinstance(a, int):
            return a + 1
        return 0

    # Default (safe) mode: a parameter without a type hint is rejected.
    with pytest.raises(FlyteMissingTypeException):

        @task
        def t2_wo_unsafe(a) -> int:
            return a + 1

    @workflow(unsafe=True)
    def wf1_with_unsafe(a) -> int:
        return t1(a=a)

    assert wf1_with_unsafe(a=1) == 2
    assert wf1_with_unsafe(a="1") == 0
    assert wf1_with_unsafe(a=None) == 0

    # The same rejection applies to workflows declared without the flag.
    with pytest.raises(FlyteMissingTypeException):

        @workflow
        def wf1_wo_unsafe(a) -> int:
            return t1(a=a)


def test_unsafe_wf_and_task():
    """Chain two fully un-annotated tasks inside an un-annotated workflow.

    Neither inputs nor return values carry type hints; every entity is
    declared with ``unsafe=True``.
    """

    @task(unsafe=True)
    def t1(a):
        if not isinstance(a, int):
            return None
        return a + 1

    @task(unsafe=True)
    def t2(a):
        if not isinstance(a, int):
            return None
        return a + 2

    @workflow(unsafe=True)
    def wf1_with_unsafe(a):
        a1 = t1(a=a)
        return t2(a=a1)

    # int input flows through both tasks: 1 + 1 + 2.
    assert wf1_with_unsafe(a=1) == 4
    # A non-int input makes t1 return None, which t2 passes along as None.
    assert wf1_with_unsafe(a="1") is None


def test_wf_with_unsafe_and_safe_tasks():
    """Mix an ``unsafe=True`` task with a task explicitly annotated with ``typing.Any``.

    The two tasks are wired together in both orders to check that values pass
    between an un-annotated task and an ``Any``-annotated task either way.
    """

    @task(unsafe=True)
    def t1(a):
        if not isinstance(a, int):
            return None
        return a + 1

    @task
    def t2(a: typing.Any) -> typing.Any:
        if not isinstance(a, int):
            return None
        return a + 2

    # Un-annotated task feeding the Any-annotated task.
    @workflow(unsafe=True)
    def wf1_with_unsafe(a):
        a1 = t1(a=a)
        return t2(a=a1)

    assert wf1_with_unsafe(a=1) == 4
    assert wf1_with_unsafe(a="1") is None

    # Any-annotated task feeding the un-annotated task.
    @workflow(unsafe=True)
    def wf2_with_unsafe(a):
        a1 = t2(a=a)
        return t1(a=a1)

    assert wf2_with_unsafe(a=1) == 4
    assert wf2_with_unsafe(a="1") is None


def test_unsafe_task_with_specified_input():
    """Combine annotated and un-annotated parameters on the same unsafe entity.

    The task leaves ``a`` un-annotated and annotates ``b``; the workflow does
    the opposite. Only int arguments contribute to the result.
    """

    @task(unsafe=True)
    def t1(a, b: typing.Any):
        a_is_int = isinstance(a, int)
        b_is_int = isinstance(b, int)
        if a_is_int and b_is_int:
            return a + b
        if a_is_int:
            return a
        if b_is_int:
            return b
        # Neither argument is an int.
        return None

    @workflow(unsafe=True)
    def wf1_with_unsafe(a: typing.Any, b):
        r = t1(a=a, b=b)
        return r

    assert wf1_with_unsafe(a=1, b=2) == 3
    assert wf1_with_unsafe(a="1", b=2) == 2
    assert wf1_with_unsafe(a=1, b="2") == 1
    assert wf1_with_unsafe(a="1", b="2") is None
Loading