diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index f38d5533af..d606675299 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,13 +1,16 @@ +import typing from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, OrderedDict, Type from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.context_manager import FlyteContext from flytekit.core.interface import Interface from flytekit.core.pod_template import PodTemplate +from flytekit.core.python_auto_container import get_registerable_container_image from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.utils import _get_container_definition, _serialize_pod_spec +from flytekit.image_spec.image_spec import ImageSpec from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext @@ -37,9 +40,9 @@ class IOStrategy(Enum): def __init__( self, name: str, - image: str, + image: typing.Union[str, ImageSpec], command: List[str], - inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, + inputs: Optional[OrderedDict[str, Type]] = None, metadata: Optional[TaskMetadata] = None, arguments: Optional[List[str]] = None, outputs: Optional[Dict[str, Type]] = None, @@ -112,8 +115,11 @@ def _get_data_loading_config(self) -> _task_model.DataLoadingConfig: def _get_container(self, settings: SerializationSettings) -> _task_model.Container: env = settings.env or {} env = {**env, **self.environment} if self.environment else env + if isinstance(self._image, ImageSpec): + if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled: + self._image.source_root = settings.source_root return _get_container_definition( - image=self._image, + image=get_registerable_container_image(self._image, settings.image_config), command=self._cmd, args=self._args, data_loading_config=self._get_data_loading_config(), diff --git a/tests/flytekit/unit/cli/pyflyte/test_build.py b/tests/flytekit/unit/cli/pyflyte/test_build.py index 7b4b26fb69..15ab09066f 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_build.py +++ b/tests/flytekit/unit/cli/pyflyte/test_build.py @@ -3,17 +3,13 @@ from click.testing import CliRunner from flytekit.clis.sdk_in_container import pyflyte -from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpecBuilder +from flytekit.image_spec.image_spec import ImageBuildEngine WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_spec_wf.py") -def test_build(): - class TestImageSpecBuilder(ImageSpecBuilder): - def build_image(self, img): - ... - - ImageBuildEngine.register("test", TestImageSpecBuilder()) +def test_build(mock_image_spec_builder): + ImageBuildEngine.register("test", mock_image_spec_builder) runner = CliRunner() result = runner.invoke(pyflyte.main, ["build", "--fast", WORKFLOW_FILE, "wf"]) assert result.exit_code == 0 diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 9048cce9b1..c9c04d0347 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -12,7 +12,7 @@ from flytekit.clis.sdk_in_container.run import RunLevelParams, get_entities_in_file, run_command from flytekit.configuration import Config, Image, ImageConfig from flytekit.core.task import task -from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpecBuilder +from flytekit.image_spec.image_spec import ImageBuildEngine from flytekit.interaction.click_types import FileParamType from flytekit.remote import FlyteRemote @@ -308,14 +308,11 @@ def test_list_default_arguments(wf_path): os.environ.get("GITHUB_ACTIONS") == "true" and sys.platform == "darwin", reason="Github macos-latest image does not have docker installed as per https://github.com/orgs/community/discussions/25777", ) -def test_pyflyte_run_run(mock_image, image_string, leaf_configuration_file_name, final_image_config): +def test_pyflyte_run_run( + mock_image, image_string, leaf_configuration_file_name, final_image_config, mock_image_spec_builder +): mock_image.return_value = "cr.flyte.org/flyteorg/flytekit:py3.9-latest" - - class TestImageSpecBuilder(ImageSpecBuilder): - def build_image(self, img): - ... - - ImageBuildEngine.register("test", TestImageSpecBuilder()) + ImageBuildEngine.register("test", mock_image_spec_builder) @task def tk(): diff --git a/tests/flytekit/unit/conftest.py b/tests/flytekit/unit/conftest.py new file mode 100644 index 0000000000..7414b5e064 --- /dev/null +++ b/tests/flytekit/unit/conftest.py @@ -0,0 +1,13 @@ +import pytest + +from flytekit.image_spec.image_spec import ImageSpecBuilder + + +class MockImageSpecBuilder(ImageSpecBuilder): + def build_image(self, img): + print("Building an image...") + + +@pytest.fixture() +def mock_image_spec_builder(): + return MockImageSpecBuilder() diff --git a/tests/flytekit/unit/core/image_spec/test_image_spec.py b/tests/flytekit/unit/core/image_spec/test_image_spec.py index f2f71f80b3..96f84d5d88 100644 --- a/tests/flytekit/unit/core/image_spec/test_image_spec.py +++ b/tests/flytekit/unit/core/image_spec/test_image_spec.py @@ -5,13 +5,13 @@ from flytekit.core import context_manager from flytekit.core.context_manager import ExecutionState from flytekit.image_spec import ImageSpec -from flytekit.image_spec.image_spec import _F_IMG_ID, ImageBuildEngine, ImageSpecBuilder, calculate_hash_from_image_spec +from flytekit.image_spec.image_spec import _F_IMG_ID, ImageBuildEngine, calculate_hash_from_image_spec REQUIREMENT_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt") REGISTRY_CONFIG_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "registry_config.json") -def test_image_spec(): +def test_image_spec(mock_image_spec_builder): image_spec = ImageSpec( name="FLYTEKIT", builder="dummy", @@ -51,11 +51,7 @@ def test_image_spec(): os.environ[_F_IMG_ID] = "flytekit:123" assert image_spec.is_container() is False - class DummyImageSpecBuilder(ImageSpecBuilder): - def build_image(self, img): - ... - - ImageBuildEngine.register("dummy", DummyImageSpecBuilder()) + ImageBuildEngine.register("dummy", mock_image_spec_builder) ImageBuildEngine.build(image_spec) assert "dummy" in ImageBuildEngine._REGISTRY diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py index 1c1b718638..3ac4a47cc4 100644 --- a/tests/flytekit/unit/core/test_container_task.py +++ b/tests/flytekit/unit/core/test_container_task.py @@ -15,7 +15,7 @@ from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.container_task import ContainerTask from flytekit.core.pod_template import PodTemplate -from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, ImageSpecBuilder +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec from flytekit.tools.translator import get_serializable_task @@ -97,14 +97,32 @@ def test_local_execution(): ct() -class MockImageSpecBuilder(ImageSpecBuilder): - def build_image(self, img): - print("Building an image...") +def test_raw_container_with_image_spec(mock_image_spec_builder): + ImageBuildEngine.register("test-raw-container", mock_image_spec_builder) + image_spec = ImageSpec(registry="flyte", base_image="r-base", builder="test-raw-container") + calculate_ellipse_area_r = ContainerTask( + name="ellipse-area-metadata-r", + input_data_dir="/var/inputs", + output_data_dir="/var/outputs", + inputs=kwtypes(a=float, b=float), + outputs=kwtypes(area=float, metadata=str), + image=image_spec, + command=[ + "Rscript", + "--vanilla", + "/root/calculate-ellipse-area.R", + "{{.inputs.a}}", + "{{.inputs.b}}", + "/var/outputs", + ], + ) -@pytest.fixture() -def mock_image_spec_builder(): - return MockImageSpecBuilder() + default_serialization_settings = SerializationSettings( + project="p", domain="d", version="v", image_config=ImageConfig.auto() + ) + container = calculate_ellipse_area_r.get_container(default_serialization_settings) + assert container.image == image_spec.image_name() def test_container_task_image_spec(mock_image_spec_builder): @@ -152,7 +170,7 @@ def test_container_task_image_spec(mock_image_spec_builder): ct = ContainerTask( name="x", image="ddd", - command="ccc", + command=["ccc"], pod_template=pt, ) ImageBuildEngine.register("test", mock_image_spec_builder) diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py index 2599866588..fed612c98c 100644 --- a/tests/flytekit/unit/core/test_python_auto_container.py +++ b/tests/flytekit/unit/core/test_python_auto_container.py @@ -8,7 +8,7 @@ from flytekit.core.pod_template import PodTemplate from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image from flytekit.core.resources import Resources -from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, ImageSpecBuilder +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec from flytekit.tools.translator import get_serializable_task @@ -30,16 +30,6 @@ def minimal_serialization_settings(default_image_config): return SerializationSettings(project="p", domain="d", version="v", image_config=default_image_config) -class MockImageSpecBuilder(ImageSpecBuilder): - def build_image(self, img): - print("Building an image...") - - -@pytest.fixture() -def mock_image_spec_builder(): - return MockImageSpecBuilder() - - def test_image_name_interpolation(default_image_config): img_to_interpolate = "{{.image.default.fqn}}:{{.image.default.version}}-special" img = get_registerable_container_image(img=img_to_interpolate, cfg=default_image_config) diff --git a/tests/flytekit/unit/core/test_python_function_task.py b/tests/flytekit/unit/core/test_python_function_task.py index 3e0b016efb..9ba9d4c780 100644 --- a/tests/flytekit/unit/core/test_python_function_task.py +++ b/tests/flytekit/unit/core/test_python_function_task.py @@ -7,7 +7,7 @@ from flytekit.core.python_auto_container import get_registerable_container_image from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.tracker import isnested, istestfunction -from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, ImageSpecBuilder +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec from flytekit.tools.translator import get_serializable_task from tests.flytekit.unit.core import tasks @@ -34,7 +34,7 @@ def test_istestfunction(): assert istestfunction(tasks.tasks) is False -def test_container_image_conversion(): +def test_container_image_conversion(mock_image_spec_builder): default_img = Image(name="default", fqn="xyz.com/abc", tag="tag1") other_img = Image(name="other", fqn="xyz.com/other", tag="tag-other") cfg = ImageConfig(default_image=default_img, images=[default_img, other_img]) @@ -68,11 +68,7 @@ def test_container_image_conversion(): assert get_registerable_container_image("{{.image.default}}", cfg) == "xyz.com/abc:tag1" - class TestImageSpecBuilder(ImageSpecBuilder): - def build_image(self, img): - ... - - ImageBuildEngine.register("test", TestImageSpecBuilder()) + ImageBuildEngine.register("test", mock_image_spec_builder) image_spec = ImageSpec(builder="test", python_version="3.7", registry="") assert get_registerable_container_image(image_spec, cfg) == image_spec.image_name()