diff --git a/sdk/python/kfp/cli/compile_.py b/sdk/python/kfp/cli/compile_.py index 2bd3bab18c23..9552370dc406 100644 --- a/sdk/python/kfp/cli/compile_.py +++ b/sdk/python/kfp/cli/compile_.py @@ -22,6 +22,7 @@ import click from kfp import compiler +from kfp.cli.utils import parsing from kfp.dsl import base_component from kfp.dsl import graph_component @@ -133,12 +134,19 @@ def parse_parameters(parameters: Optional[str]) -> Dict: is_flag=True, default=False, help='Whether to disable type checking.') +@click.option( + '--enable-caching/--disable-caching', + type=bool, + default=None, + help=parsing.get_param_descr(compiler.Compiler.compile, 'enable_caching'), +) def compile_( py: str, output: str, function_name: Optional[str] = None, pipeline_parameters: Optional[str] = None, disable_type_check: bool = False, + enable_caching: Optional[bool] = None, ) -> None: """Compiles a pipeline or component written in a .py file.""" pipeline_func = collect_pipeline_or_component_func( @@ -149,7 +157,8 @@ def compile_( pipeline_func=pipeline_func, pipeline_parameters=parsed_parameters, package_path=package_path, - type_check=not disable_type_check) + type_check=not disable_type_check, + enable_caching=enable_caching) click.echo(package_path) diff --git a/sdk/python/kfp/client/client.py b/sdk/python/kfp/client/client.py index f88972363433..591b0200e4e5 100644 --- a/sdk/python/kfp/client/client.py +++ b/sdk/python/kfp/client/client.py @@ -955,7 +955,7 @@ def _create_job_config( # Caching option set at submission time overrides the compile time # settings. if enable_caching is not None: - _override_caching_options(pipeline_doc.pipeline_spec, + compiler.override_caching_options(pipeline_doc.pipeline_spec, enable_caching) pipeline_spec = pipeline_doc.to_dict() @@ -1676,17 +1676,3 @@ def _safe_load_yaml(stream: TextIO) -> _PipelineDoc: raise ValueError( f'The package_file {package_file} should end with one of the ' 'following formats: [.tar.gz, .tgz, .zip, .yaml, .yml].') - - -def _override_caching_options( - pipeline_spec: pipeline_spec_pb2.PipelineSpec, - enable_caching: bool, -) -> None: - """Overrides caching options. - - Args: - pipeline_spec: The PipelineSpec object to update in-place. - enable_caching: Overrides options, one of True, False. - """ - for _, task_spec in pipeline_spec.root.dag.tasks.items(): - task_spec.caching_options.enable_cache = enable_caching diff --git a/sdk/python/kfp/compiler/compiler.py b/sdk/python/kfp/compiler/compiler.py index a77f606e89c5..3f54b587f3fb 100644 --- a/sdk/python/kfp/compiler/compiler.py +++ b/sdk/python/kfp/compiler/compiler.py @@ -22,6 +22,7 @@ from kfp.compiler import pipeline_spec_builder as builder from kfp.dsl import base_component from kfp.dsl.types import type_utils +from kfp.pipeline_spec import pipeline_spec_pb2 class Compiler: @@ -53,6 +54,7 @@ def compile( pipeline_name: Optional[str] = None, pipeline_parameters: Optional[Dict[str, Any]] = None, type_check: bool = True, + enable_caching: Optional[bool] = None, ) -> None: """Compiles the pipeline or component function into IR YAML. @@ -62,6 +64,12 @@ def compile( pipeline_name: Name of the pipeline. pipeline_parameters: Map of parameter names to argument values. type_check: Whether to enable type checking of component interfaces during compilation. + enable_caching: Whether or not to enable caching for the + run. If not set, defaults to the compile time settings, which + is ``True`` for all tasks by default, while users may specify + different caching options for individual tasks. If set, the + setting applies to all tasks in the pipeline (overrides the + compile time settings). """ with type_utils.TypeCheckManager(enable=type_check): @@ -78,9 +86,26 @@ def compile( pipeline_parameters=pipeline_parameters, ) + if enable_caching is not None: + override_caching_options(pipeline_spec, enable_caching) + builder.write_pipeline_spec_to_file( pipeline_spec=pipeline_spec, pipeline_description=pipeline_func.description, platform_spec=pipeline_func.platform_spec, package_path=package_path, ) + + +def override_caching_options( + pipeline_spec: pipeline_spec_pb2.PipelineSpec, + enable_caching: bool, +) -> None: + """Overrides caching options. + + Args: + pipeline_spec: The PipelineSpec object to update in-place. + enable_caching: Overrides options, one of True, False. + """ + for _, task_spec in pipeline_spec.root.dag.tasks.items(): + task_spec.caching_options.enable_cache = enable_caching