Skip to content

Commit

Permalink
Run active launchplan when available to launch, else run the latest o…
Browse files Browse the repository at this point in the history
…ne (#2796)

Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 authored Oct 12, 2024
1 parent 5d1642a commit 3f0b218
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 53 deletions.
41 changes: 20 additions & 21 deletions flytekit/clis/sdk_in_container/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,35 +40,34 @@ def launchplan(
remote: FlyteRemote = get_and_save_remote_with_click_context(ctx, project="flytesnacks", domain="development")

console = Console()
lps = []
title = f"LaunchPlans for {project}/{domain}"
if launchplan_name:
if not version:
if version:
lp = remote.client.get_launch_plan(
Identifier(ResourceType.LAUNCH_PLAN, project, domain, launchplan_name, version)
)
j = MessageToJson(lp.to_flyte_idl())
print(j)
return
else:
lps, _ = remote.client.list_launch_plans_paginated(
NamedEntityIdentifier(project, domain, name=launchplan_name),
limit=1,
limit=limit,
sort_by=Sort(key="updated_at", direction=Sort.Direction.DESCENDING),
)
if len(lps) > 0:
version = lps[0].id.version
lp = remote.client.get_launch_plan(
Identifier(ResourceType.LAUNCH_PLAN, project, domain, launchplan_name, version)
)
j = MessageToJson(lp.to_flyte_idl())
print(j)
return

title = f"LaunchPlans for {project}/{domain}"
if active_only:
title += " (active only)"
lps, _ = remote.client.list_active_launch_plans_paginated(project, domain, limit=limit)
else:
lps, _ = remote.client.list_launch_plans_paginated(
NamedEntityIdentifier(project, domain),
limit=limit,
sort_by=Sort(key="updated_at", direction=Sort.Direction.DESCENDING),
)
if active_only:
title += " (active only)"
lps, _ = remote.client.list_active_launch_plans_paginated(project, domain, limit=limit)
else:
lps, _ = remote.client.list_launch_plans_paginated(
NamedEntityIdentifier(project, domain),
limit=limit,
sort_by=Sort(key="updated_at", direction=Sort.Direction.DESCENDING),
)

table = Table(title=title)

table.add_column("Name", justify="right", style="cyan")
table.add_column("Version", justify="right", style="cyan")
table.add_column("State", justify="right", style="green")
Expand Down
80 changes: 60 additions & 20 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import yaml
from click import Context
from mashumaro.codecs.json import JSONEncoder
from rich.progress import Progress
from rich.progress import Progress, TextColumn, TimeElapsedColumn
from typing_extensions import get_origin

from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, Literal
from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, Literal, WorkflowExecutionPhase
from flytekit.clis.sdk_in_container.helpers import (
parse_copy,
patch_image_config,
Expand Down Expand Up @@ -499,21 +499,28 @@ def run_remote(
Helper method that executes the given remote FlyteLaunchplan, FlyteWorkflow or FlyteTask
"""

execution = remote.execute(
entity,
inputs=inputs,
project=project,
domain=domain,
execution_name=run_level_params.name,
wait=run_level_params.wait_execution,
options=options_from_run_params(run_level_params),
type_hints=type_hints,
overwrite_cache=run_level_params.overwrite_cache,
envs=run_level_params.envvars,
tags=run_level_params.tags,
cluster_pool=run_level_params.cluster_pool,
execution_cluster_label=run_level_params.execution_cluster_label,
)
msg = "Running execution on remote."
if run_level_params.wait_execution:
msg += " Waiting to complete..."
p = Progress(TimeElapsedColumn(), TextColumn(msg), transient=True)
t = p.add_task("exec")
with p:
p.start_task(t)
execution = remote.execute(
entity,
inputs=inputs,
project=project,
domain=domain,
execution_name=run_level_params.name,
wait=run_level_params.wait_execution,
options=options_from_run_params(run_level_params),
type_hints=type_hints,
overwrite_cache=run_level_params.overwrite_cache,
envs=run_level_params.envvars,
tags=run_level_params.tags,
cluster_pool=run_level_params.cluster_pool,
execution_cluster_label=run_level_params.execution_cluster_label,
)

console_url = remote.generate_console_url(execution)
s = (
Expand All @@ -524,6 +531,19 @@ def run_remote(
)
click.echo(s)

if run_level_params.wait_execution:
if execution.closure.phase != WorkflowExecutionPhase.SUCCEEDED:
click.secho(
f"Execution {execution.id.name} did not complete successfully, "
f"phase {WorkflowExecutionPhase.enum_to_string(execution.closure.phase)}",
fg="red",
)
if execution.closure.error:
click.secho(f"{execution.closure.error.message}", fg="red")
sys.exit(-1)
else:
click.secho(f"Execution {execution.id.name} has succeeded.", fg="green")

if run_level_params.dump_snippet:
dump_flyte_remote_snippet(execution, project, domain)

Expand Down Expand Up @@ -718,9 +738,26 @@ def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, Fly
run_level_params: RunLevelParams = ctx.obj
r = run_level_params.remote_instance()
if self._launcher == self.LP_LAUNCHER:
entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._entity_name)
parts = self._entity_name.split(":")
if len(parts) == 2:
entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, parts[0], parts[1])
else:
entity = r.fetch_active_launchplan(run_level_params.project, run_level_params.domain, self._entity_name)
if not entity:
click.echo(
click.style(
f"No active launch plan found with name {self._entity_name},"
f" using the latest version by created time.",
fg="yellow",
)
)
entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._entity_name)
else:
entity = r.fetch_task(run_level_params.project, run_level_params.domain, self._entity_name)
parts = self._entity_name.split(":")
if len(parts) == 2:
entity = r.fetch_task(run_level_params.project, run_level_params.domain, parts[0], parts[1])
else:
entity = r.fetch_task(run_level_params.project, run_level_params.domain, self._entity_name)
self._entity = entity
return entity

Expand Down Expand Up @@ -798,7 +835,10 @@ class RemoteEntityGroup(click.RichGroup):
def __init__(self, command_name: str):
super().__init__(
name=command_name,
help=f"Retrieve {command_name} from a remote flyte instance and execute them.",
help=f"Retrieve {command_name} from a remote flyte instance and execute them. The command only lists the "
f"names of the entities, but it is possible to pass in a specific version of the entity if known in "
f"the format <name>:<version>. If version is not provided, the latest version is used for tasks and "
f"active or latest version is used for launchplans.",
)
self._command_name = command_name
self._entities = []
Expand Down
38 changes: 30 additions & 8 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from flytekit.models.admin import common as admin_common_models
from flytekit.models.admin import workflow as admin_workflow_models
from flytekit.models.admin.common import Sort
from flytekit.models.common import NamedEntityIdentifier
from flytekit.models.core import identifier as id_models
from flytekit.models.core import workflow as workflow_model
from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier
Expand Down Expand Up @@ -462,6 +463,34 @@ def get_launch_plan_from_then_node(
get_launch_plan_from_branch(node.branch_node, node_launch_plans)
return FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans)

def _upgrade_launchplan(self, lp: launch_plan_models.LaunchPlan) -> FlyteLaunchPlan:
"""
Given a remote returned launchplan, upgrade it to a FlyteLaunchPlan object that holds the interface and
the FlyteWorkflow object. This can be used in the SDK.
"""
flyte_lp = FlyteLaunchPlan.promote_from_model(lp.id, lp.spec)
wf_id = flyte_lp.workflow_id
workflow = self.fetch_workflow(wf_id.project, wf_id.domain, wf_id.name, wf_id.version)
flyte_lp._interface = workflow.interface
flyte_lp._flyte_workflow = workflow
return flyte_lp

def fetch_active_launchplan(
self, project: str = None, domain: str = None, name: str = None
) -> typing.Optional[FlyteLaunchPlan]:
"""
Returns the active version of the launch plan if it exists or returns None
"""
try:
lp = self.client.get_active_launch_plan(
NamedEntityIdentifier(project or self.default_project, domain or self.default_domain, name)
)
if lp is not None:
return self._upgrade_launchplan(lp)
except FlyteEntityNotExistException as e:
logger.debug(f"Launch plan not found, error:{str(e)}")
return None

def fetch_launch_plan(
self, project: str = None, domain: str = None, name: str = None, version: str = None
) -> FlyteLaunchPlan:
Expand All @@ -486,14 +515,7 @@ def fetch_launch_plan(
version,
)
admin_launch_plan = self.client.get_launch_plan(launch_plan_id)
flyte_launch_plan = FlyteLaunchPlan.promote_from_model(launch_plan_id, admin_launch_plan.spec)

wf_id = flyte_launch_plan.workflow_id
workflow = self.fetch_workflow(wf_id.project, wf_id.domain, wf_id.name, wf_id.version)
flyte_launch_plan._interface = workflow.interface
flyte_launch_plan._flyte_workflow = workflow

return flyte_launch_plan
return self._upgrade_launchplan(admin_launch_plan)

def fetch_execution(self, project: str = None, domain: str = None, name: str = None) -> FlyteWorkflowExecution:
"""Fetch a workflow execution entity from flyte admin.
Expand Down
9 changes: 8 additions & 1 deletion tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def test_get_download_artifact_signed_url(register):

# Check if the signed URL is valid and starts with the expected prefix
signed_url = download_link_response.signed_url[0]
assert signed_url.startswith(f"http://localhost:30002/my-s3-bucket/metadata/propeller/{project}-{domain}-{name}/n0/data/0/deck.html")
assert signed_url.startswith(
f"http://localhost:30002/my-s3-bucket/metadata/propeller/{project}-{domain}-{name}/n0/data/0/deck.html")


def test_fetch_execute_launch_plan_with_args(register):
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
Expand Down Expand Up @@ -755,3 +757,8 @@ def test_register_wf_fast(register):
subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions
subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103}
subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"}


def test_fetch_active_launchplan_not_found(register):
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
assert remote.fetch_active_launchplan(name="basic.list_float_wf.fake_wf") is None
12 changes: 9 additions & 3 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions import user as user_exceptions
from flytekit.exceptions.user import FlyteEntityNotExistException
from flytekit.models import common as common_models
from flytekit.models import security
from flytekit.models.admin.workflow import Workflow, WorkflowClosure
Expand Down Expand Up @@ -54,7 +55,6 @@
ResourceType.LAUNCH_PLAN: "Launch Plan",
}


obj = _workflow.Node(
id="some:node:id",
metadata="1",
Expand Down Expand Up @@ -501,7 +501,7 @@ def test_fetch_workflow_with_nested_branch(mock_promote, mock_workflow, remote):
@mock.patch("flytekit.remote.remote.compress_scripts")
@pytest.mark.serial
def test_get_image_names(
compress_scripts_mock, upload_file_mock, register_workflow_mock, version_from_hash_mock, read_bytes_mock
compress_scripts_mock, upload_file_mock, register_workflow_mock, version_from_hash_mock, read_bytes_mock
):
md5_bytes = bytes([1, 2, 3])
read_bytes_mock.return_value = bytes([4, 5, 6])
Expand Down Expand Up @@ -603,7 +603,7 @@ def test_execution_name(mock_client, mock_uuid):
]
)
with pytest.raises(
ValueError, match="Only one of execution_name and execution_name_prefix can be set, but got both set"
ValueError, match="Only one of execution_name and execution_name_prefix can be set, but got both set"
):
remote._execute(
entity=ft,
Expand Down Expand Up @@ -684,3 +684,9 @@ def test_register_wf_script_mode(compress_scripts_mock, upload_file_mock, regist
source_path=str(pathlib.Path(flytekit.__file__).parent.parent),
module_name="tests.flytekit.unit.remote.resources",
)


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_fetch_active_launchplan_not_found(mock_client, remote):
mock_client.get_active_launch_plan.side_effect = FlyteEntityNotExistException("not found")
assert remote.fetch_active_launchplan(name="basic.list_float_wf.fake_wf") is None

0 comments on commit 3f0b218

Please sign in to comment.