From 3f0b218987ada75a30e03bb4916325493d4e1938 Mon Sep 17 00:00:00 2001 From: Ketan Umare <16888709+kumare3@users.noreply.github.com> Date: Fri, 11 Oct 2024 18:23:14 -0700 Subject: [PATCH] Run active launchplan when available to launch, else run the latest one (#2796) Signed-off-by: Ketan Umare --- flytekit/clis/sdk_in_container/get.py | 41 +++++----- flytekit/clis/sdk_in_container/run.py | 80 ++++++++++++++----- flytekit/remote/remote.py | 38 +++++++-- .../integration/remote/test_remote.py | 9 ++- tests/flytekit/unit/remote/test_remote.py | 12 ++- 5 files changed, 127 insertions(+), 53 deletions(-) diff --git a/flytekit/clis/sdk_in_container/get.py b/flytekit/clis/sdk_in_container/get.py index 473dd82b15..c88b20203d 100644 --- a/flytekit/clis/sdk_in_container/get.py +++ b/flytekit/clis/sdk_in_container/get.py @@ -40,35 +40,34 @@ def launchplan( remote: FlyteRemote = get_and_save_remote_with_click_context(ctx, project="flytesnacks", domain="development") console = Console() + lps = [] + title = f"LaunchPlans for {project}/{domain}" if launchplan_name: - if not version: + if version: + lp = remote.client.get_launch_plan( + Identifier(ResourceType.LAUNCH_PLAN, project, domain, launchplan_name, version) + ) + j = MessageToJson(lp.to_flyte_idl()) + print(j) + return + else: lps, _ = remote.client.list_launch_plans_paginated( NamedEntityIdentifier(project, domain, name=launchplan_name), - limit=1, + limit=limit, sort_by=Sort(key="updated_at", direction=Sort.Direction.DESCENDING), ) - if len(lps) > 0: - version = lps[0].id.version - lp = remote.client.get_launch_plan( - Identifier(ResourceType.LAUNCH_PLAN, project, domain, launchplan_name, version) - ) - j = MessageToJson(lp.to_flyte_idl()) - print(j) - return - - title = f"LaunchPlans for {project}/{domain}" - if active_only: - title += " (active only)" - lps, _ = remote.client.list_active_launch_plans_paginated(project, domain, limit=limit) else: - lps, _ = remote.client.list_launch_plans_paginated( - 
NamedEntityIdentifier(project, domain), - limit=limit, - sort_by=Sort(key="updated_at", direction=Sort.Direction.DESCENDING), - ) + if active_only: + title += " (active only)" + lps, _ = remote.client.list_active_launch_plans_paginated(project, domain, limit=limit) + else: + lps, _ = remote.client.list_launch_plans_paginated( + NamedEntityIdentifier(project, domain), + limit=limit, + sort_by=Sort(key="updated_at", direction=Sort.Direction.DESCENDING), + ) table = Table(title=title) - table.add_column("Name", justify="right", style="cyan") table.add_column("Version", justify="right", style="cyan") table.add_column("State", justify="right", style="green") diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index eac6dcbc6b..9bdfb12cc7 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -15,10 +15,10 @@ import yaml from click import Context from mashumaro.codecs.json import JSONEncoder -from rich.progress import Progress +from rich.progress import Progress, TextColumn, TimeElapsedColumn from typing_extensions import get_origin -from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, Literal +from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, Literal, WorkflowExecutionPhase from flytekit.clis.sdk_in_container.helpers import ( parse_copy, patch_image_config, @@ -499,21 +499,28 @@ def run_remote( Helper method that executes the given remote FlyteLaunchplan, FlyteWorkflow or FlyteTask """ - execution = remote.execute( - entity, - inputs=inputs, - project=project, - domain=domain, - execution_name=run_level_params.name, - wait=run_level_params.wait_execution, - options=options_from_run_params(run_level_params), - type_hints=type_hints, - overwrite_cache=run_level_params.overwrite_cache, - envs=run_level_params.envvars, - tags=run_level_params.tags, - cluster_pool=run_level_params.cluster_pool, - 
execution_cluster_label=run_level_params.execution_cluster_label, - ) + msg = "Running execution on remote." + if run_level_params.wait_execution: + msg += " Waiting to complete..." + p = Progress(TimeElapsedColumn(), TextColumn(msg), transient=True) + t = p.add_task("exec") + with p: + p.start_task(t) + execution = remote.execute( + entity, + inputs=inputs, + project=project, + domain=domain, + execution_name=run_level_params.name, + wait=run_level_params.wait_execution, + options=options_from_run_params(run_level_params), + type_hints=type_hints, + overwrite_cache=run_level_params.overwrite_cache, + envs=run_level_params.envvars, + tags=run_level_params.tags, + cluster_pool=run_level_params.cluster_pool, + execution_cluster_label=run_level_params.execution_cluster_label, + ) console_url = remote.generate_console_url(execution) s = ( @@ -524,6 +531,19 @@ def run_remote( ) click.echo(s) + if run_level_params.wait_execution: + if execution.closure.phase != WorkflowExecutionPhase.SUCCEEDED: + click.secho( + f"Execution {execution.id.name} did not complete successfully, " + f"phase {WorkflowExecutionPhase.enum_to_string(execution.closure.phase)}", + fg="red", + ) + if execution.closure.error: + click.secho(f"{execution.closure.error.message}", fg="red") + sys.exit(-1) + else: + click.secho(f"Execution {execution.id.name} has succeeded.", fg="green") + if run_level_params.dump_snippet: dump_flyte_remote_snippet(execution, project, domain) @@ -718,9 +738,26 @@ def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, Fly run_level_params: RunLevelParams = ctx.obj r = run_level_params.remote_instance() if self._launcher == self.LP_LAUNCHER: - entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._entity_name) + parts = self._entity_name.split(":") + if len(parts) == 2: + entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, parts[0], parts[1]) + else: + entity = 
r.fetch_active_launchplan(run_level_params.project, run_level_params.domain, self._entity_name) + if not entity: + click.echo( + click.style( + f"No active launch plan found with name {self._entity_name}," + f" using the latest version by created time.", + fg="yellow", + ) + ) + entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._entity_name) else: - entity = r.fetch_task(run_level_params.project, run_level_params.domain, self._entity_name) + parts = self._entity_name.split(":") + if len(parts) == 2: + entity = r.fetch_task(run_level_params.project, run_level_params.domain, parts[0], parts[1]) + else: + entity = r.fetch_task(run_level_params.project, run_level_params.domain, self._entity_name) self._entity = entity return entity @@ -798,7 +835,10 @@ class RemoteEntityGroup(click.RichGroup): def __init__(self, command_name: str): super().__init__( name=command_name, - help=f"Retrieve {command_name} from a remote flyte instance and execute them.", + help=f"Retrieve {command_name} from a remote flyte instance and execute them. The command only lists the " + f"names of the entities, but it is possible to pass in a specific version of the entity if known in " + f"the format <name>:<version>. 
If version is not provided, the latest version is used for tasks and " + f"active or latest version is used for launchplans.", ) self._command_name = command_name self._entities = [] diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 319b67d0d5..0528d0d155 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -64,6 +64,7 @@ from flytekit.models.admin import common as admin_common_models from flytekit.models.admin import workflow as admin_workflow_models from flytekit.models.admin.common import Sort +from flytekit.models.common import NamedEntityIdentifier from flytekit.models.core import identifier as id_models from flytekit.models.core import workflow as workflow_model from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier @@ -462,6 +463,34 @@ def get_launch_plan_from_then_node( get_launch_plan_from_branch(node.branch_node, node_launch_plans) return FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans) + def _upgrade_launchplan(self, lp: launch_plan_models.LaunchPlan) -> FlyteLaunchPlan: + """ + Given a remote returned launchplan, upgrade it to a FlyteLaunchPlan object that holds the interface and + the FlyteWorkflow object. This can be used in the SDK. 
+ """ + flyte_lp = FlyteLaunchPlan.promote_from_model(lp.id, lp.spec) + wf_id = flyte_lp.workflow_id + workflow = self.fetch_workflow(wf_id.project, wf_id.domain, wf_id.name, wf_id.version) + flyte_lp._interface = workflow.interface + flyte_lp._flyte_workflow = workflow + return flyte_lp + + def fetch_active_launchplan( + self, project: str = None, domain: str = None, name: str = None + ) -> typing.Optional[FlyteLaunchPlan]: + """ + Returns the active version of the launch plan if it exists or returns None + """ + try: + lp = self.client.get_active_launch_plan( + NamedEntityIdentifier(project or self.default_project, domain or self.default_domain, name) + ) + if lp is not None: + return self._upgrade_launchplan(lp) + except FlyteEntityNotExistException as e: + logger.debug(f"Launch plan not found, error:{str(e)}") + return None + def fetch_launch_plan( self, project: str = None, domain: str = None, name: str = None, version: str = None ) -> FlyteLaunchPlan: @@ -486,14 +515,7 @@ def fetch_launch_plan( version, ) admin_launch_plan = self.client.get_launch_plan(launch_plan_id) - flyte_launch_plan = FlyteLaunchPlan.promote_from_model(launch_plan_id, admin_launch_plan.spec) - - wf_id = flyte_launch_plan.workflow_id - workflow = self.fetch_workflow(wf_id.project, wf_id.domain, wf_id.name, wf_id.version) - flyte_launch_plan._interface = workflow.interface - flyte_launch_plan._flyte_workflow = workflow - - return flyte_launch_plan + return self._upgrade_launchplan(admin_launch_plan) def fetch_execution(self, project: str = None, domain: str = None, name: str = None) -> FlyteWorkflowExecution: """Fetch a workflow execution entity from flyte admin. 
diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index f80a76b4c5..4d77e1b610 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -121,7 +121,9 @@ def test_get_download_artifact_signed_url(register): # Check if the signed URL is valid and starts with the expected prefix signed_url = download_link_response.signed_url[0] - assert signed_url.startswith(f"http://localhost:30002/my-s3-bucket/metadata/propeller/{project}-{domain}-{name}/n0/data/0/deck.html") + assert signed_url.startswith( + f"http://localhost:30002/my-s3-bucket/metadata/propeller/{project}-{domain}-{name}/n0/data/0/deck.html") + def test_fetch_execute_launch_plan_with_args(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) @@ -755,3 +757,8 @@ def test_register_wf_fast(register): subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103} subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"} + + +def test_fetch_active_launchplan_not_found(register): + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + assert remote.fetch_active_launchplan(name="basic.list_float_wf.fake_wf") is None diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 655cd5cc1c..98b50bbc2b 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -21,6 +21,7 @@ from flytekit.core.context_manager import FlyteContextManager from flytekit.core.type_engine import TypeEngine from flytekit.exceptions import user as user_exceptions +from flytekit.exceptions.user import FlyteEntityNotExistException from flytekit.models import common as common_models from flytekit.models import security from flytekit.models.admin.workflow import Workflow, 
WorkflowClosure @@ -54,7 +55,6 @@ ResourceType.LAUNCH_PLAN: "Launch Plan", } - obj = _workflow.Node( id="some:node:id", metadata="1", @@ -501,7 +501,7 @@ def test_fetch_workflow_with_nested_branch(mock_promote, mock_workflow, remote): @mock.patch("flytekit.remote.remote.compress_scripts") @pytest.mark.serial def test_get_image_names( - compress_scripts_mock, upload_file_mock, register_workflow_mock, version_from_hash_mock, read_bytes_mock + compress_scripts_mock, upload_file_mock, register_workflow_mock, version_from_hash_mock, read_bytes_mock ): md5_bytes = bytes([1, 2, 3]) read_bytes_mock.return_value = bytes([4, 5, 6]) @@ -603,7 +603,7 @@ def test_execution_name(mock_client, mock_uuid): ] ) with pytest.raises( - ValueError, match="Only one of execution_name and execution_name_prefix can be set, but got both set" + ValueError, match="Only one of execution_name and execution_name_prefix can be set, but got both set" ): remote._execute( entity=ft, @@ -684,3 +684,9 @@ def test_register_wf_script_mode(compress_scripts_mock, upload_file_mock, regist source_path=str(pathlib.Path(flytekit.__file__).parent.parent), module_name="tests.flytekit.unit.remote.resources", ) + + +@mock.patch("flytekit.remote.remote.FlyteRemote.client") +def test_fetch_active_launchplan_not_found(mock_client, remote): + mock_client.get_active_launch_plan.side_effect = FlyteEntityNotExistException("not found") + assert remote.fetch_active_launchplan(name="basic.list_float_wf.fake_wf") is None