
Commit

EMR serverless Create/Start/Stop/Delete Application deferrable mode (apache#32513)

* Minor code refactoring

* Add type annotations

* update doc string about default value of deferrable
syedahsn authored Jul 24, 2023
1 parent 282854b commit 1706f05
Showing 6 changed files with 417 additions and 33 deletions.
26 changes: 17 additions & 9 deletions airflow/providers/amazon/aws/hooks/emr.py
@@ -256,9 +256,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["client_type"] = "emr-serverless"
super().__init__(*args, **kwargs)

def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}):
def cancel_running_jobs(
self, application_id: str, waiter_config: dict | None = None, wait_for_completion: bool = True
) -> int:
"""
List all jobs in an intermediate state, cancel them, then wait for those jobs to reach terminal state.
Cancel jobs in an intermediate state, and return the number of cancelled jobs.
If wait_for_completion is True, then the method will wait until all jobs are
cancelled before returning.
Note: if new jobs are triggered while this operation is ongoing,
it's going to time out and return an error.
@@ -284,13 +289,16 @@ def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}):
)
for job_id in job_ids:
self.conn.cancel_job_run(applicationId=application_id, jobRunId=job_id)
if count > 0:
self.log.info("now waiting for the %s cancelled job(s) to terminate", count)
self.get_waiter("no_job_running").wait(
applicationId=application_id,
states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
WaiterConfig=waiter_config,
)
if wait_for_completion:
if count > 0:
self.log.info("now waiting for the %s cancelled job(s) to terminate", count)
self.get_waiter("no_job_running").wait(
applicationId=application_id,
states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
WaiterConfig=waiter_config or {},
)

return count


class EmrContainerHook(AwsBaseHook):
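For context, a minimal sketch of how the reworked hook method might be driven now that cancelling and waiting are split; the connection id and application id below are placeholders, not values from this commit.

from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook

# Placeholder connection id and application id, for illustration only.
hook = EmrServerlessHook(aws_conn_id="aws_default")
app_id = "00example-app-id"

# Request cancellation of all intermediate-state jobs and return immediately;
# a deferrable operator can then hand the waiting off to a trigger.
cancelled_count = hook.cancel_running_jobs(
    application_id=app_id,
    wait_for_completion=False,
)

if cancelled_count:
    # Synchronous alternative: block until the cancelled jobs reach a terminal
    # state, passing an explicit botocore waiter configuration.
    hook.get_waiter("no_job_running").wait(
        applicationId=app_id,
        states=list(hook.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
        WaiterConfig={"Delay": 30, "MaxAttempts": 10},
    )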
172 changes: 155 additions & 17 deletions airflow/providers/amazon/aws/operators/emr.py
@@ -33,8 +33,12 @@
EmrAddStepsTrigger,
EmrContainerTrigger,
EmrCreateJobFlowTrigger,
EmrServerlessCancelJobsTrigger,
EmrServerlessCreateApplicationTrigger,
EmrServerlessDeleteApplicationTrigger,
EmrServerlessStartApplicationTrigger,
EmrServerlessStartJobTrigger,
EmrServerlessStopApplicationTrigger,
EmrTerminateJobFlowTrigger,
)
from airflow.providers.amazon.aws.utils.waiter import waiter
@@ -974,7 +978,7 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
:param release_label: The EMR release version associated with the application.
:param job_type: The type of application you want to start, such as Spark or Hive.
:param wait_for_completion: If true, wait for the Application to start before returning. Default to True.
If set to False, ``waiter_countdown`` and ``waiter_check_interval_seconds`` will only be applied when
If set to False, ``waiter_max_attempts`` and ``waiter_delay`` will only be applied when
waiting for the application to be in the ``CREATED`` state.
:param client_request_token: The client idempotency token of the application to create.
Its value must be unique for each request.
@@ -987,6 +991,9 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
:waiter_max_attempts: Number of times the waiter should poll the application to check the state.
If not set, the waiter will use its default value.
:param waiter_delay: Number of seconds between polling the state of the application.
:param deferrable: If True, the operator will wait asynchronously for application to be created.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
"""

def __init__(
@@ -1001,6 +1008,7 @@ def __init__(
waiter_check_interval_seconds: int | ArgNotSet = NOTSET,
waiter_max_attempts: int | ArgNotSet = NOTSET,
waiter_delay: int | ArgNotSet = NOTSET,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
if waiter_check_interval_seconds is NOTSET:
@@ -1032,6 +1040,7 @@ def __init__(
self.config = config or {}
self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type]
self.waiter_delay = int(waiter_delay) # type: ignore[arg-type]
self.deferrable = deferrable
super().__init__(**kwargs)

self.client_request_token = client_request_token or str(uuid4())
@@ -1054,8 +1063,19 @@ def execute(self, context: Context) -> str | None:
raise AirflowException(f"Application Creation failed: {response}")

self.log.info("EMR serverless application created: %s", application_id)
waiter = self.hook.get_waiter("serverless_app_created")
if self.deferrable:
self.defer(
trigger=EmrServerlessCreateApplicationTrigger(
application_id=application_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
method_name="start_application_deferred",
)

waiter = self.hook.get_waiter("serverless_app_created")
wait(
waiter=waiter,
waiter_delay=self.waiter_delay,
@@ -1081,6 +1101,32 @@ def execute(self, context: Context) -> str | None:
)
return application_id

def start_application_deferred(self, context: Context, event: dict[str, Any] | None = None) -> None:
if event is None:
self.log.error("Trigger error: event is None")
raise AirflowException("Trigger error: event is None")
elif event["status"] != "success":
raise AirflowException(f"Application {event['application_id']} failed to create")
self.log.info("Starting application %s", event["application_id"])
self.hook.conn.start_application(applicationId=event["application_id"])
self.defer(
trigger=EmrServerlessStartApplicationTrigger(
application_id=event["application_id"],
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
if event is None or event["status"] != "success":
raise AirflowException(f"Trigger error: Application failed to start, event is {event}")

self.log.info("Application %s started", event["application_id"])
return event["application_id"]

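As a rough usage sketch of the new deferrable path (the DAG id, release label, and job type below are placeholders, not values from this commit; deferrable mode additionally requires the aiobotocore module):

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.emr import (
    EmrServerlessCreateApplicationOperator,
)

with DAG(
    dag_id="example_emr_serverless_deferrable",  # placeholder DAG id
    start_date=datetime(2023, 7, 1),
    schedule=None,
    catchup=False,
):
    # Creates the application and defers: the first trigger waits for CREATED,
    # start_application_deferred() then calls start_application(), and a second
    # trigger waits for the application to finish starting.
    create_app = EmrServerlessCreateApplicationOperator(
        task_id="create_emr_serverless_app",
        release_label="emr-6.11.0",  # placeholder release label
        job_type="SPARK",
        deferrable=True,
    )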

class EmrServerlessStartJobOperator(BaseOperator):
"""
@@ -1312,14 +1358,21 @@ class EmrServerlessStopApplicationOperator(BaseOperator):
:param application_id: ID of the EMR Serverless application to stop.
:param wait_for_completion: If true, wait for the Application to stop before returning. Default to True
:param aws_conn_id: AWS connection to use
:param waiter_countdown: Total amount of time, in seconds, the operator will wait for
:param waiter_countdown: (deprecated) Total amount of time, in seconds, the operator will wait for
the application be stopped. Defaults to 5 minutes.
:param waiter_check_interval_seconds: Number of seconds between polling the state of the application.
Defaults to 30 seconds.
:param waiter_check_interval_seconds: (deprecated) Number of seconds between polling the state of the
application. Defaults to 60 seconds.
:param force_stop: If set to True, any job for that app that is not in a terminal state will be cancelled.
Otherwise, trying to stop an app with running jobs will return an error.
If you want to wait for the jobs to finish gracefully, use
:class:`airflow.providers.amazon.aws.sensors.emr.EmrServerlessJobSensor`
:waiter_max_attempts: Number of times the waiter should poll the application to check the state.
Default is 25.
:param waiter_delay: Number of seconds between polling the state of the application.
Default is 60 seconds.
:param deferrable: If True, the operator will wait asynchronously for the application to stop.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
"""

template_fields: Sequence[str] = ("application_id",)
@@ -1334,6 +1387,7 @@ def __init__(
waiter_max_attempts: int | ArgNotSet = NOTSET,
waiter_delay: int | ArgNotSet = NOTSET,
force_stop: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
if waiter_check_interval_seconds is NOTSET:
@@ -1359,10 +1413,11 @@
)
self.aws_conn_id = aws_conn_id
self.application_id = application_id
self.wait_for_completion = wait_for_completion
self.wait_for_completion = False if deferrable else wait_for_completion
self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type]
self.waiter_delay = int(waiter_delay) # type: ignore[arg-type]
self.force_stop = force_stop
self.deferrable = deferrable
super().__init__(**kwargs)

@cached_property
@@ -1374,16 +1429,46 @@ def execute(self, context: Context) -> None:
self.log.info("Stopping application: %s", self.application_id)

if self.force_stop:
self.hook.cancel_running_jobs(
self.application_id,
waiter_config={
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
},
count = self.hook.cancel_running_jobs(
application_id=self.application_id,
wait_for_completion=False,
)
if count > 0:
self.log.info("now waiting for the %s cancelled job(s) to terminate", count)
if self.deferrable:
self.defer(
trigger=EmrServerlessCancelJobsTrigger(
application_id=self.application_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
method_name="stop_application",
)
self.hook.get_waiter("no_job_running").wait(
applicationId=self.application_id,
states=list(self.hook.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
WaiterConfig={
"Delay": self.waiter_delay,
"MaxAttempts": self.waiter_max_attempts,
},
)
else:
self.log.info("no running jobs found with application ID %s", self.application_id)

self.hook.conn.stop_application(applicationId=self.application_id)

if self.deferrable:
self.defer(
trigger=EmrServerlessStopApplicationTrigger(
application_id=self.application_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
method_name="execute_complete",
)
if self.wait_for_completion:
waiter = self.hook.get_waiter("serverless_app_stopped")
wait(
Expand All @@ -1397,6 +1482,30 @@ def execute(self, context: Context) -> None:
)
self.log.info("EMR serverless application %s stopped successfully", self.application_id)

def stop_application(self, context: Context, event: dict[str, Any] | None = None) -> None:
if event is None:
self.log.error("Trigger error: event is None")
raise AirflowException("Trigger error: event is None")
elif event["status"] == "success":
self.hook.conn.stop_application(applicationId=self.application_id)
self.defer(
trigger=EmrServerlessStopApplicationTrigger(
application_id=self.application_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
if event is None:
self.log.error("Trigger error: event is None")
raise AirflowException("Trigger error: event is None")
elif event["status"] == "success":
self.log.info("EMR serverless application %s stopped successfully", self.application_id)

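Continuing the hypothetical DAG sketched above, stopping the application with force_stop and deferrable enabled exercises the full chain in this diff: cancel the running jobs, defer on the cancel-jobs trigger, issue the stop_application API call from the stop_application() callback, then defer again until the application is stopped.

from airflow.providers.amazon.aws.operators.emr import (
    EmrServerlessStopApplicationOperator,
)

# Placeholder task; application_id comes from the create task's XCom output.
stop_app = EmrServerlessStopApplicationOperator(
    task_id="stop_emr_serverless_app",
    application_id=create_app.output,
    force_stop=True,   # cancel any non-terminal jobs before stopping
    deferrable=True,   # both waits run in the triggerer instead of a worker slot
)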

class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperator):
"""
@@ -1410,10 +1519,17 @@ class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperator):
:param wait_for_completion: If true, wait for the Application to be deleted before returning.
Defaults to True. Note that this operator will always wait for the application to be STOPPED first.
:param aws_conn_id: AWS connection to use
:param waiter_countdown: Total amount of time, in seconds, the operator will wait for each step of first,
the application to be stopped, and then deleted. Defaults to 25 minutes.
:param waiter_check_interval_seconds: Number of seconds between polling the state of the application.
:param waiter_countdown: (deprecated) Total amount of time, in seconds, the operator will wait for each
step of first, the application to be stopped, and then deleted. Defaults to 25 minutes.
:param waiter_check_interval_seconds: (deprecated) Number of seconds between polling the state
of the application. Defaults to 60 seconds.
:waiter_max_attempts: Number of times the waiter should poll the application to check the state.
Defaults to 25.
:param waiter_delay: Number of seconds between polling the state of the application.
Defaults to 60 seconds.
:param deferrable: If True, the operator will wait asynchronously for application to be deleted.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
:param force_stop: If set to True, any job for that app that is not in a terminal state will be cancelled.
Otherwise, trying to delete an app with running jobs will return an error.
If you want to wait for the jobs to finish gracefully, use
@@ -1432,6 +1548,7 @@ def __init__(
waiter_max_attempts: int | ArgNotSet = NOTSET,
waiter_delay: int | ArgNotSet = NOTSET,
force_stop: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
):
if waiter_check_interval_seconds is NOTSET:
@@ -1467,6 +1584,8 @@ def __init__(
force_stop=force_stop,
**kwargs,
)
self.deferrable = deferrable
self.wait_for_delete_completion = False if deferrable else wait_for_completion

def execute(self, context: Context) -> None:
# super stops the app (or makes sure it's already stopped)
@@ -1478,7 +1597,19 @@ def execute(self, context: Context) -> None:
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Application deletion failed: {response}")

if self.wait_for_delete_completion:
if self.deferrable:
self.defer(
trigger=EmrServerlessDeleteApplicationTrigger(
application_id=self.application_id,
aws_conn_id=self.aws_conn_id,
waiter_delay=self.waiter_delay,
waiter_max_attempts=self.waiter_max_attempts,
),
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
method_name="execute_complete",
)

elif self.wait_for_delete_completion:
waiter = self.hook.get_waiter("serverless_app_terminated")

wait(
Expand All @@ -1492,3 +1623,10 @@ def execute(self, context: Context) -> None:
)

self.log.info("EMR serverless application deleted")

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
if event is None:
self.log.error("Trigger error: event is None")
raise AirflowException("Trigger error: event is None")
elif event["status"] == "success":
self.log.info("EMR serverless application %s deleted successfully", self.application_id)
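To round out the hypothetical DAG, the delete operator can also run deferrably; since it subclasses the stop operator, it stops the application first (force-cancelling jobs if requested) and only then issues the delete call and the final trigger wait.

from airflow.providers.amazon.aws.operators.emr import (
    EmrServerlessDeleteApplicationOperator,
)

delete_app = EmrServerlessDeleteApplicationOperator(
    task_id="delete_emr_serverless_app",
    application_id=create_app.output,
    force_stop=True,
    deferrable=True,
)

# Placeholder wiring for the sketch tasks above.
create_app >> stop_app >> delete_app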
