Skip to content

Commit

Permalink
[caching] Add overwrite_cache flag when creating launch plan (#2029)
Browse files: browse the repository at this point in the history
Signed-off-by: Yue Shang <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
ysysys3074 and pingsutw authored Feb 21, 2024
1 parent 72252e7 commit d1d1bd9
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 0 deletions.
14 changes: 14 additions & 0 deletions flytekit/core/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def create(
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
auth_role: Optional[_common_models.AuthRole] = None,
overwrite_cache: Optional[bool] = None,
) -> LaunchPlan:
ctx = FlyteContextManager.current_context()
default_inputs = default_inputs or {}
Expand Down Expand Up @@ -173,6 +174,7 @@ def create(
raw_output_data_config=raw_output_data_config,
max_parallelism=max_parallelism,
security_context=security_context,
overwrite_cache=overwrite_cache,
)

# This is just a convenience - we'll need the fixed inputs LiteralMap for when serializing the Launch Plan out
Expand Down Expand Up @@ -201,6 +203,7 @@ def get_or_create(
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
auth_role: Optional[_common_models.AuthRole] = None,
overwrite_cache: Optional[bool] = None,
) -> LaunchPlan:
"""
This function offers a friendlier interface for creating launch plans. If the name for the launch plan is not
Expand Down Expand Up @@ -238,6 +241,7 @@ def get_or_create(
or auth_role is not None
or max_parallelism is not None
or security_context is not None
or overwrite_cache is not None
):
raise ValueError(
"Only named launchplans can be created that have other properties. Drop the name if you want to create a default launchplan. Default launchplans cannot have any other associations"
Expand Down Expand Up @@ -269,6 +273,7 @@ def get_or_create(
or raw_output_data_config != cached_outputs["_raw_output_data_config"]
or max_parallelism != cached_outputs["_max_parallelism"]
or security_context != cached_outputs["_security_context"]
or overwrite_cache != cached_outputs["_overwrite_cache"]
):
raise AssertionError("The cached values aren't the same as the current call arguments")

Expand All @@ -294,6 +299,7 @@ def get_or_create(
max_parallelism,
auth_role=auth_role,
security_context=security_context,
overwrite_cache=overwrite_cache,
)
LaunchPlan.CACHE[name or workflow.name] = lp
return lp
Expand All @@ -311,6 +317,7 @@ def __init__(
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
overwrite_cache: Optional[bool] = None,
additional_metadata: Optional[Any] = None,
):
self._name = name
Expand All @@ -329,6 +336,7 @@ def __init__(
self._raw_output_data_config = raw_output_data_config
self._max_parallelism = max_parallelism
self._security_context = security_context
self._overwrite_cache = overwrite_cache
self._additional_metadata = additional_metadata

FlyteEntities.entities.append(self)
Expand All @@ -345,6 +353,7 @@ def clone_with(
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
overwrite_cache: Optional[bool] = None,
) -> LaunchPlan:
return LaunchPlan(
name=name,
Expand All @@ -358,8 +367,13 @@ def clone_with(
raw_output_data_config=raw_output_data_config or self.raw_output_data_config,
max_parallelism=max_parallelism or self.max_parallelism,
security_context=security_context or self.security_context,
overwrite_cache=overwrite_cache or self.overwrite_cache,
)

# Read-only accessor: returns the overwrite_cache value stored on this instance
# as _overwrite_cache (an Optional[bool], so it may be None when never set).
@property
def overwrite_cache(self) -> Optional[bool]:
return self._overwrite_cache

# Read-only accessor: delegates to self.workflow.python_interface and returns it unchanged.
@property
def python_interface(self) -> Interface:
return self.workflow.python_interface
Expand Down
8 changes: 8 additions & 0 deletions flytekit/models/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(
raw_output_data_config: _common.RawOutputDataConfig,
max_parallelism: typing.Optional[int] = None,
security_context: typing.Optional[security.SecurityContext] = None,
overwrite_cache: typing.Optional[bool] = None,
):
"""
The spec for a Launch Plan.
Expand Down Expand Up @@ -168,6 +169,7 @@ def __init__(
self._raw_output_data_config = raw_output_data_config
self._max_parallelism = max_parallelism
self._security_context = security_context
self._overwrite_cache = overwrite_cache

@property
def workflow_id(self):
Expand Down Expand Up @@ -240,6 +242,10 @@ def max_parallelism(self) -> typing.Optional[int]:
def security_context(self) -> typing.Optional[security.SecurityContext]:
return self._security_context

# Read-only accessor: returns the overwrite_cache value kept in _overwrite_cache
# (an Optional[bool]; None when the constructor argument was left at its default).
@property
def overwrite_cache(self) -> typing.Optional[bool]:
return self._overwrite_cache

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
Expand All @@ -255,6 +261,7 @@ def to_flyte_idl(self):
raw_output_data_config=self.raw_output_data_config.to_flyte_idl(),
max_parallelism=self.max_parallelism,
security_context=self.security_context.to_flyte_idl() if self.security_context else None,
overwrite_cache=self.overwrite_cache if self.overwrite_cache else None,
)

@classmethod
Expand Down Expand Up @@ -287,6 +294,7 @@ def from_flyte_idl(cls, pb2):
security_context=security.SecurityContext.from_flyte_idl(pb2.security_context)
if pb2.security_context
else None,
overwrite_cache=pb2.overwrite_cache if pb2.overwrite_cache else None,
)


Expand Down
2 changes: 2 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class Options(object):
max_parallelism: typing.Optional[int] = None
notifications: typing.Optional[typing.List[common_models.Notification]] = None
disable_notifications: typing.Optional[bool] = None
overwrite_cache: typing.Optional[bool] = None

@classmethod
def default_from(
Expand Down Expand Up @@ -382,6 +383,7 @@ def get_serializable_launch_plan(
raw_output_data_config=raw_prefix_config,
max_parallelism=options.max_parallelism or entity.max_parallelism,
security_context=options.security_context or entity.security_context,
overwrite_cache=options.overwrite_cache or entity.overwrite_cache,
)

lp_id = _identifier_model.Identifier(
Expand Down
18 changes: 18 additions & 0 deletions tests/flytekit/unit/core/test_launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,20 @@ def wf(a: int, c: str) -> (int, str):
)
assert max_parallelism_lp1 is max_parallelism_lp2

# Overwrite cache
overwrite_cache = True
overwrite_cache_lp1 = launch_plan.LaunchPlan.get_or_create(
workflow=wf,
name="get_or_create_overwrite_cache",
overwrite_cache=overwrite_cache,
)
overwrite_cache_lp2 = launch_plan.LaunchPlan.get_or_create(
workflow=wf,
name="get_or_create_overwrite_cache",
overwrite_cache=overwrite_cache,
)
assert overwrite_cache_lp1 is overwrite_cache_lp2

# Default LaunchPlan
name_lp = launch_plan.LaunchPlan.get_or_create(workflow=wf)
name_lp1 = launch_plan.LaunchPlan.get_or_create(workflow=wf)
Expand Down Expand Up @@ -318,6 +332,7 @@ def wf(a: int, c: str) -> str:
labels = Labels({"label": "foo"})
annotations = Annotations({"anno": "bar"})
raw_output_data_config = RawOutputDataConfig("s3://foo/output")
overwrite_cache = True

lp = launch_plan.LaunchPlan.get_or_create(
workflow=wf,
Expand All @@ -330,6 +345,7 @@ def wf(a: int, c: str) -> str:
labels=labels,
annotations=annotations,
raw_output_data_config=raw_output_data_config,
overwrite_cache=overwrite_cache,
)
lp2 = launch_plan.LaunchPlan.get_or_create(
workflow=wf,
Expand All @@ -342,6 +358,7 @@ def wf(a: int, c: str) -> str:
labels=labels,
annotations=annotations,
raw_output_data_config=raw_output_data_config,
overwrite_cache=overwrite_cache,
)

assert lp is lp2
Expand All @@ -358,6 +375,7 @@ def wf(a: int, c: str) -> str:
labels=labels,
annotations=annotations,
raw_output_data_config=raw_output_data_config,
overwrite_cache=overwrite_cache,
)


Expand Down
3 changes: 3 additions & 0 deletions tests/flytekit/unit/models/test_launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_launch_plan_spec():
raw_data_output_config = common.RawOutputDataConfig("s3://bucket")
empty_raw_data_output_config = common.RawOutputDataConfig("")
max_parallelism = 100
overwrite_cache = True

lp_spec_raw_output_prefixed = launch_plan.LaunchPlanSpec(
identifier_model,
Expand All @@ -71,6 +72,7 @@ def test_launch_plan_spec():
auth_role_model,
raw_data_output_config,
max_parallelism,
overwrite_cache=overwrite_cache,
)

obj2 = launch_plan.LaunchPlanSpec.from_flyte_idl(lp_spec_raw_output_prefixed.to_flyte_idl())
Expand All @@ -86,6 +88,7 @@ def test_launch_plan_spec():
auth_role_model,
empty_raw_data_output_config,
max_parallelism,
overwrite_cache=overwrite_cache,
)

obj2 = launch_plan.LaunchPlanSpec.from_flyte_idl(lp_spec_no_prefix.to_flyte_idl())
Expand Down

0 comments on commit d1d1bd9

Please sign in to comment.