Skip to content

Commit

Permalink
Use ImageSpec inside raw container task (#1944)
Browse files Browse the repository at this point in the history
---------

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Nov 22, 2023
1 parent 38b15ca commit e5c3c78
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 52 deletions.
14 changes: 10 additions & 4 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import typing
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, OrderedDict, Type

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
from flytekit.core.context_manager import FlyteContext
from flytekit.core.interface import Interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec
from flytekit.image_spec.image_spec import ImageSpec
from flytekit.models import task as _task_model
from flytekit.models.security import Secret, SecurityContext

Expand Down Expand Up @@ -37,9 +40,9 @@ class IOStrategy(Enum):
def __init__(
self,
name: str,
image: str,
image: typing.Union[str, ImageSpec],
command: List[str],
inputs: Optional[Dict[str, Tuple[Type, Any]]] = None,
inputs: Optional[OrderedDict[str, Type]] = None,
metadata: Optional[TaskMetadata] = None,
arguments: Optional[List[str]] = None,
outputs: Optional[Dict[str, Type]] = None,
Expand Down Expand Up @@ -112,8 +115,11 @@ def _get_data_loading_config(self) -> _task_model.DataLoadingConfig:
def _get_container(self, settings: SerializationSettings) -> _task_model.Container:
env = settings.env or {}
env = {**env, **self.environment} if self.environment else env
if isinstance(self._image, ImageSpec):
if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled:
self._image.source_root = settings.source_root
return _get_container_definition(
image=self._image,
image=get_registerable_container_image(self._image, settings.image_config),
command=self._cmd,
args=self._args,
data_loading_config=self._get_data_loading_config(),
Expand Down
10 changes: 3 additions & 7 deletions tests/flytekit/unit/cli/pyflyte/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@
from click.testing import CliRunner

from flytekit.clis.sdk_in_container import pyflyte
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpecBuilder
from flytekit.image_spec.image_spec import ImageBuildEngine

WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_spec_wf.py")


def test_build():
class TestImageSpecBuilder(ImageSpecBuilder):
def build_image(self, img):
...

ImageBuildEngine.register("test", TestImageSpecBuilder())
def test_build(mock_image_spec_builder):
ImageBuildEngine.register("test", mock_image_spec_builder)
runner = CliRunner()
result = runner.invoke(pyflyte.main, ["build", "--fast", WORKFLOW_FILE, "wf"])
assert result.exit_code == 0
Expand Down
13 changes: 5 additions & 8 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from flytekit.clis.sdk_in_container.run import RunLevelParams, get_entities_in_file, run_command
from flytekit.configuration import Config, Image, ImageConfig
from flytekit.core.task import task
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpecBuilder
from flytekit.image_spec.image_spec import ImageBuildEngine
from flytekit.interaction.click_types import FileParamType
from flytekit.remote import FlyteRemote

Expand Down Expand Up @@ -308,14 +308,11 @@ def test_list_default_arguments(wf_path):
os.environ.get("GITHUB_ACTIONS") == "true" and sys.platform == "darwin",
reason="Github macos-latest image does not have docker installed as per https://github.com/orgs/community/discussions/25777",
)
def test_pyflyte_run_run(mock_image, image_string, leaf_configuration_file_name, final_image_config):
def test_pyflyte_run_run(
mock_image, image_string, leaf_configuration_file_name, final_image_config, mock_image_spec_builder
):
mock_image.return_value = "cr.flyte.org/flyteorg/flytekit:py3.9-latest"

class TestImageSpecBuilder(ImageSpecBuilder):
def build_image(self, img):
...

ImageBuildEngine.register("test", TestImageSpecBuilder())
ImageBuildEngine.register("test", mock_image_spec_builder)

@task
def tk():
Expand Down
13 changes: 13 additions & 0 deletions tests/flytekit/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from flytekit.image_spec.image_spec import ImageSpecBuilder


class MockImageSpecBuilder(ImageSpecBuilder):
def build_image(self, img):
print("Building an image...")


@pytest.fixture()
def mock_image_spec_builder():
return MockImageSpecBuilder()
10 changes: 3 additions & 7 deletions tests/flytekit/unit/core/image_spec/test_image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from flytekit.core import context_manager
from flytekit.core.context_manager import ExecutionState
from flytekit.image_spec import ImageSpec
from flytekit.image_spec.image_spec import _F_IMG_ID, ImageBuildEngine, ImageSpecBuilder, calculate_hash_from_image_spec
from flytekit.image_spec.image_spec import _F_IMG_ID, ImageBuildEngine, calculate_hash_from_image_spec

REQUIREMENT_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
REGISTRY_CONFIG_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "registry_config.json")


def test_image_spec():
def test_image_spec(mock_image_spec_builder):
image_spec = ImageSpec(
name="FLYTEKIT",
builder="dummy",
Expand Down Expand Up @@ -51,11 +51,7 @@ def test_image_spec():
os.environ[_F_IMG_ID] = "flytekit:123"
assert image_spec.is_container() is False

class DummyImageSpecBuilder(ImageSpecBuilder):
def build_image(self, img):
...

ImageBuildEngine.register("dummy", DummyImageSpecBuilder())
ImageBuildEngine.register("dummy", mock_image_spec_builder)
ImageBuildEngine.build(image_spec)

assert "dummy" in ImageBuildEngine._REGISTRY
Expand Down
34 changes: 26 additions & 8 deletions tests/flytekit/unit/core/test_container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.container_task import ContainerTask
from flytekit.core.pod_template import PodTemplate
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, ImageSpecBuilder
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.tools.translator import get_serializable_task


Expand Down Expand Up @@ -97,14 +97,32 @@ def test_local_execution():
ct()


class MockImageSpecBuilder(ImageSpecBuilder):
def build_image(self, img):
print("Building an image...")
def test_raw_container_with_image_spec(mock_image_spec_builder):
ImageBuildEngine.register("test-raw-container", mock_image_spec_builder)
image_spec = ImageSpec(registry="flyte", base_image="r-base", builder="test-raw-container")

calculate_ellipse_area_r = ContainerTask(
name="ellipse-area-metadata-r",
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs=kwtypes(a=float, b=float),
outputs=kwtypes(area=float, metadata=str),
image=image_spec,
command=[
"Rscript",
"--vanilla",
"/root/calculate-ellipse-area.R",
"{{.inputs.a}}",
"{{.inputs.b}}",
"/var/outputs",
],
)

@pytest.fixture()
def mock_image_spec_builder():
return MockImageSpecBuilder()
default_serialization_settings = SerializationSettings(
project="p", domain="d", version="v", image_config=ImageConfig.auto()
)
container = calculate_ellipse_area_r.get_container(default_serialization_settings)
assert container.image == image_spec.image_name()


def test_container_task_image_spec(mock_image_spec_builder):
Expand Down Expand Up @@ -152,7 +170,7 @@ def test_container_task_image_spec(mock_image_spec_builder):
ct = ContainerTask(
name="x",
image="ddd",
command="ccc",
command=["ccc"],
pod_template=pt,
)
ImageBuildEngine.register("test", mock_image_spec_builder)
Expand Down
12 changes: 1 addition & 11 deletions tests/flytekit/unit/core/test_python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image
from flytekit.core.resources import Resources
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, ImageSpecBuilder
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.tools.translator import get_serializable_task


Expand All @@ -30,16 +30,6 @@ def minimal_serialization_settings(default_image_config):
return SerializationSettings(project="p", domain="d", version="v", image_config=default_image_config)


class MockImageSpecBuilder(ImageSpecBuilder):
def build_image(self, img):
print("Building an image...")


@pytest.fixture()
def mock_image_spec_builder():
return MockImageSpecBuilder()


def test_image_name_interpolation(default_image_config):
img_to_interpolate = "{{.image.default.fqn}}:{{.image.default.version}}-special"
img = get_registerable_container_image(img=img_to_interpolate, cfg=default_image_config)
Expand Down
10 changes: 3 additions & 7 deletions tests/flytekit/unit/core/test_python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.tracker import isnested, istestfunction
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, ImageSpecBuilder
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.tools.translator import get_serializable_task
from tests.flytekit.unit.core import tasks

Expand All @@ -34,7 +34,7 @@ def test_istestfunction():
assert istestfunction(tasks.tasks) is False


def test_container_image_conversion():
def test_container_image_conversion(mock_image_spec_builder):
default_img = Image(name="default", fqn="xyz.com/abc", tag="tag1")
other_img = Image(name="other", fqn="xyz.com/other", tag="tag-other")
cfg = ImageConfig(default_image=default_img, images=[default_img, other_img])
Expand Down Expand Up @@ -68,11 +68,7 @@ def test_container_image_conversion():

assert get_registerable_container_image("{{.image.default}}", cfg) == "xyz.com/abc:tag1"

class TestImageSpecBuilder(ImageSpecBuilder):
def build_image(self, img):
...

ImageBuildEngine.register("test", TestImageSpecBuilder())
ImageBuildEngine.register("test", mock_image_spec_builder)
image_spec = ImageSpec(builder="test", python_version="3.7", registry="")
assert get_registerable_container_image(image_spec, cfg) == image_spec.image_name()

Expand Down

0 comments on commit e5c3c78

Please sign in to comment.