Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support raw container in the map task #1547

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type, Union

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(
name: str,
image: str,
command: List[str],
inputs: Optional[Dict[str, Tuple[Type, Any]]] = None,
inputs: Optional[Union[Dict[str, Tuple[Type, Any]], OrderedDict[str, Type]]] = None,
metadata: Optional[TaskMetadata] = None,
arguments: Optional[List[str]] = None,
outputs: Optional[Dict[str, Type]] = None,
Expand Down
34 changes: 25 additions & 9 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from itertools import count
from typing import Any, Dict, List, Optional

from flytekit import ContainerTask
from flytekit.configuration import SerializationSettings
from flytekit.core import tracker
from flytekit.core.base_task import PythonTask
Expand All @@ -35,7 +36,7 @@ class MapPythonTask(PythonTask):

def __init__(
self,
python_function_task: PythonFunctionTask,
python_function_task: typing.Union[PythonFunctionTask, ContainerTask],
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
**kwargs,
Expand All @@ -55,8 +56,11 @@ def __init__(

collection_interface = transform_interface_to_list_interface(python_function_task.python_interface)
instance = next(self._ids)
_, mod, f, _ = tracker.extract_task_module(python_function_task.task_function)
name = f"{mod}.mapper_{f}_{instance}"
if isinstance(python_function_task, ContainerTask):
name = f"raw_container_task.mapper_{python_function_task.name}_{instance}"
else:
_, mod, f, _ = tracker.extract_task_module(python_function_task.task_function)
name = f"{mod}.mapper_{f}_{instance}"

self._cmd_prefix = None
self._run_task = python_function_task
Expand Down Expand Up @@ -114,14 +118,20 @@ def prepare_target(self):
self._run_task.reset_command_fn()

def get_container(self, settings: SerializationSettings) -> Container:
if isinstance(self._run_task, ContainerTask):
return self._run_task.get_container(settings)
with self.prepare_target():
return self._run_task.get_container(settings)

def get_k8s_pod(self, settings: SerializationSettings) -> K8sPod:
if isinstance(self._run_task, ContainerTask):
return self._run_task.get_k8s_pod(settings)
with self.prepare_target():
return self._run_task.get_k8s_pod(settings)

def get_sql(self, settings: SerializationSettings) -> Sql:
if isinstance(self._run_task, ContainerTask):
return self._run_task.get_sql(settings)
with self.prepare_target():
return self._run_task.get_sql(settings)

Expand Down Expand Up @@ -221,7 +231,12 @@ def _raw_execute(self, **kwargs) -> Any:
return outputs


def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs):
def map_task(
task_function: typing.Union[PythonFunctionTask, ContainerTask],
concurrency: int = 0,
min_success_ratio: float = 1.0,
**kwargs,
):
"""
Use a map task for parallelizable tasks that run across a list of an input type. A map task can be composed of
any individual :py:class:`flytekit.PythonFunctionTask`.
Expand Down Expand Up @@ -267,8 +282,9 @@ def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_succes
successfully before terminating this task and marking it successful.

"""
if not isinstance(task_function, PythonFunctionTask):
raise ValueError(
f"Only Flyte python task types are supported in map tasks currently, received {type(task_function)}"
)
return MapPythonTask(task_function, concurrency=concurrency, min_success_ratio=min_success_ratio, **kwargs)
if isinstance(task_function, PythonFunctionTask) or isinstance(task_function, ContainerTask):
return MapPythonTask(task_function, concurrency=concurrency, min_success_ratio=min_success_ratio, **kwargs)

raise ValueError(
f"Only Flyte python-task, and raw-container types are supported in map tasks currently, received {type(task_function)}"
)
34 changes: 33 additions & 1 deletion tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

import flytekit.configuration
from flytekit import LaunchPlan, map_task
from flytekit import ContainerTask, LaunchPlan, kwtypes, map_task
from flytekit.configuration import Image, ImageConfig
from flytekit.core.map_task import MapPythonTask
from flytekit.core.task import TaskMetadata, task
Expand All @@ -24,6 +24,22 @@ def serialization_settings():
)


raw_container = ContainerTask(
name="ellipse-area-metadata-python",
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs=kwtypes(a=int),
outputs=kwtypes(area=float),
image="flyte/raw-container:v1",
command=[
"python",
"test.py",
"{{.inputs.a}}",
"/var/outputs",
],
)


@task
def t1(a: int) -> str:
b = a + 2
Expand Down Expand Up @@ -96,6 +112,22 @@ def test_serialization(serialization_settings):
]


def test_serialization_with_raw_container(serialization_settings):
maptask = map_task(raw_container, metadata=TaskMetadata(retries=1))
task_spec = get_serializable(OrderedDict(), serialization_settings, maptask)

# By default all map_task tasks will have their custom fields set.
assert task_spec.template.custom["minSuccessRatio"] == 1.0
assert task_spec.template.type == "container_array"
assert task_spec.template.task_type_version == 1
assert task_spec.template.container.command == [
"python",
"test.py",
"{{.inputs.a}}",
"/var/outputs",
]


@pytest.mark.parametrize(
"custom_fields_dict, expected_custom_fields",
[
Expand Down