From d1d1bd911cacb36d732ecb57977ece7931076b50 Mon Sep 17 00:00:00 2001 From: Yue Shang <138256885+ysysys3074@users.noreply.github.com> Date: Tue, 20 Feb 2024 20:59:54 -0800 Subject: [PATCH] [caching] Add overwrite_cache flag when creating launch plan (#2029) Signed-off-by: Yue Shang Co-authored-by: Kevin Su --- flytekit/core/launch_plan.py | 14 ++++++++++++++ flytekit/models/launch_plan.py | 8 ++++++++ flytekit/tools/translator.py | 2 ++ tests/flytekit/unit/core/test_launch_plan.py | 18 ++++++++++++++++++ tests/flytekit/unit/models/test_launch_plan.py | 3 +++ 5 files changed, 45 insertions(+) diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index a96f83b8ce..7f45287428 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -123,6 +123,7 @@ def create( max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, auth_role: Optional[_common_models.AuthRole] = None, + overwrite_cache: Optional[bool] = None, ) -> LaunchPlan: ctx = FlyteContextManager.current_context() default_inputs = default_inputs or {} @@ -173,6 +174,7 @@ def create( raw_output_data_config=raw_output_data_config, max_parallelism=max_parallelism, security_context=security_context, + overwrite_cache=overwrite_cache, ) # This is just a convenience - we'll need the fixed inputs LiteralMap for when serializing the Launch Plan out @@ -201,6 +203,7 @@ def get_or_create( max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, auth_role: Optional[_common_models.AuthRole] = None, + overwrite_cache: Optional[bool] = None, ) -> LaunchPlan: """ This function offers a friendlier interface for creating launch plans. If the name for the launch plan is not @@ -238,6 +241,7 @@ def get_or_create( or auth_role is not None or max_parallelism is not None or security_context is not None + or overwrite_cache is not None ): raise ValueError( "Only named launchplans can be created that have other properties. Drop the name if you want to create a default launchplan. Default launchplans cannot have any other associations" @@ -269,6 +273,7 @@ def get_or_create( or raw_output_data_config != cached_outputs["_raw_output_data_config"] or max_parallelism != cached_outputs["_max_parallelism"] or security_context != cached_outputs["_security_context"] + or overwrite_cache != cached_outputs["_overwrite_cache"] ): raise AssertionError("The cached values aren't the same as the current call arguments") @@ -294,6 +299,7 @@ def get_or_create( max_parallelism, auth_role=auth_role, security_context=security_context, + overwrite_cache=overwrite_cache, ) LaunchPlan.CACHE[name or workflow.name] = lp return lp @@ -311,6 +317,7 @@ def __init__( raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, + overwrite_cache: Optional[bool] = None, additional_metadata: Optional[Any] = None, ): self._name = name @@ -329,6 +336,7 @@ def __init__( self._raw_output_data_config = raw_output_data_config self._max_parallelism = max_parallelism self._security_context = security_context + self._overwrite_cache = overwrite_cache self._additional_metadata = additional_metadata FlyteEntities.entities.append(self) @@ -345,6 +353,7 @@ def clone_with( raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, + overwrite_cache: Optional[bool] = None, ) -> LaunchPlan: return LaunchPlan( name=name, @@ -358,8 +367,13 @@ def clone_with( raw_output_data_config=raw_output_data_config or self.raw_output_data_config, max_parallelism=max_parallelism or self.max_parallelism, security_context=security_context or self.security_context, + overwrite_cache=overwrite_cache or self.overwrite_cache, ) + @property + def overwrite_cache(self) -> Optional[bool]: + return self._overwrite_cache + @property def python_interface(self) -> Interface: return self.workflow.python_interface diff --git a/flytekit/models/launch_plan.py b/flytekit/models/launch_plan.py index 9f2af1b92e..2dc1a1947b 100644 --- a/flytekit/models/launch_plan.py +++ b/flytekit/models/launch_plan.py @@ -137,6 +137,7 @@ def __init__( raw_output_data_config: _common.RawOutputDataConfig, max_parallelism: typing.Optional[int] = None, security_context: typing.Optional[security.SecurityContext] = None, + overwrite_cache: typing.Optional[bool] = None, ): """ The spec for a Launch Plan. @@ -168,6 +169,7 @@ def __init__( self._raw_output_data_config = raw_output_data_config self._max_parallelism = max_parallelism self._security_context = security_context + self._overwrite_cache = overwrite_cache @property def workflow_id(self): @@ -240,6 +242,10 @@ def max_parallelism(self) -> typing.Optional[int]: def security_context(self) -> typing.Optional[security.SecurityContext]: return self._security_context + @property + def overwrite_cache(self) -> typing.Optional[bool]: + return self._overwrite_cache + def to_flyte_idl(self): """ :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanSpec @@ -255,6 +261,7 @@ def to_flyte_idl(self): raw_output_data_config=self.raw_output_data_config.to_flyte_idl(), max_parallelism=self.max_parallelism, security_context=self.security_context.to_flyte_idl() if self.security_context else None, + overwrite_cache=self.overwrite_cache if self.overwrite_cache else None, ) @classmethod @@ -287,6 +294,7 @@ def from_flyte_idl(cls, pb2): security_context=security.SecurityContext.from_flyte_idl(pb2.security_context) if pb2.security_context else None, + overwrite_cache=pb2.overwrite_cache if pb2.overwrite_cache else None, ) diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 1c2016b681..6d696bc4d6 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -84,6 +84,7 @@ class Options(object): max_parallelism: typing.Optional[int] = None notifications: typing.Optional[typing.List[common_models.Notification]] = None disable_notifications: typing.Optional[bool] = None + overwrite_cache: typing.Optional[bool] = None @classmethod def default_from( @@ -382,6 +383,7 @@ def get_serializable_launch_plan( raw_output_data_config=raw_prefix_config, max_parallelism=options.max_parallelism or entity.max_parallelism, security_context=options.security_context or entity.security_context, + overwrite_cache=options.overwrite_cache or entity.overwrite_cache, ) lp_id = _identifier_model.Identifier( diff --git a/tests/flytekit/unit/core/test_launch_plan.py b/tests/flytekit/unit/core/test_launch_plan.py index 50c35b224e..20426a254d 100644 --- a/tests/flytekit/unit/core/test_launch_plan.py +++ b/tests/flytekit/unit/core/test_launch_plan.py @@ -284,6 +284,20 @@ def wf(a: int, c: str) -> (int, str): ) assert max_parallelism_lp1 is max_parallelism_lp2 + # Overwrite cache + overwrite_cache = True + overwrite_cache_lp1 = launch_plan.LaunchPlan.get_or_create( + workflow=wf, + name="get_or_create_overwrite_cache", + overwrite_cache=overwrite_cache, + ) + overwrite_cache_lp2 = launch_plan.LaunchPlan.get_or_create( + workflow=wf, + name="get_or_create_overwrite_cache", + overwrite_cache=overwrite_cache, + ) + assert overwrite_cache_lp1 is overwrite_cache_lp2 + # Default LaunchPlan name_lp = launch_plan.LaunchPlan.get_or_create(workflow=wf) name_lp1 = launch_plan.LaunchPlan.get_or_create(workflow=wf) @@ -318,6 +332,7 @@ def wf(a: int, c: str) -> str: labels = Labels({"label": "foo"}) annotations = Annotations({"anno": "bar"}) raw_output_data_config = RawOutputDataConfig("s3://foo/output") + overwrite_cache = True lp = launch_plan.LaunchPlan.get_or_create( workflow=wf, @@ -330,6 +345,7 @@ def wf(a: int, c: str) -> str: labels=labels, annotations=annotations, raw_output_data_config=raw_output_data_config, + overwrite_cache=overwrite_cache, ) lp2 = launch_plan.LaunchPlan.get_or_create( workflow=wf, @@ -342,6 +358,7 @@ def wf(a: int, c: str) -> str: labels=labels, annotations=annotations, raw_output_data_config=raw_output_data_config, + overwrite_cache=overwrite_cache, ) assert lp is lp2 @@ -358,6 +375,7 @@ def wf(a: int, c: str) -> str: labels=labels, annotations=annotations, raw_output_data_config=raw_output_data_config, + overwrite_cache=overwrite_cache, ) diff --git a/tests/flytekit/unit/models/test_launch_plan.py b/tests/flytekit/unit/models/test_launch_plan.py index 6c0b8b2f13..b15af0c9a7 100644 --- a/tests/flytekit/unit/models/test_launch_plan.py +++ b/tests/flytekit/unit/models/test_launch_plan.py @@ -60,6 +60,7 @@ def test_launch_plan_spec(): raw_data_output_config = common.RawOutputDataConfig("s3://bucket") empty_raw_data_output_config = common.RawOutputDataConfig("") max_parallelism = 100 + overwrite_cache = True lp_spec_raw_output_prefixed = launch_plan.LaunchPlanSpec( identifier_model, @@ -71,6 +72,7 @@ def test_launch_plan_spec(): auth_role_model, raw_data_output_config, max_parallelism, + overwrite_cache=overwrite_cache, ) obj2 = launch_plan.LaunchPlanSpec.from_flyte_idl(lp_spec_raw_output_prefixed.to_flyte_idl()) @@ -86,6 +88,7 @@ def test_launch_plan_spec(): auth_role_model, empty_raw_data_output_config, max_parallelism, + overwrite_cache=overwrite_cache, ) obj2 = launch_plan.LaunchPlanSpec.from_flyte_idl(lp_spec_no_prefix.to_flyte_idl())