Skip to content

Commit

Permalink
introduce fail_policy
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelauv committed Dec 9, 2024
1 parent ee07d63 commit baf9e1c
Show file tree
Hide file tree
Showing 16 changed files with 204 additions and 144 deletions.
5 changes: 3 additions & 2 deletions airflow/decorators/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ from airflow.decorators.short_circuit import short_circuit_task
from airflow.decorators.task_group import task_group
from airflow.models.dag import dag
from airflow.providers.cncf.kubernetes.secret import Secret
from airflow.sensors.base import FailPolicy
from airflow.typing_compat import Literal

# Please keep this in sync with __init__.py's __all__.
Expand Down Expand Up @@ -690,7 +691,7 @@ class TaskDecoratorCollection:
*,
poke_interval: float = ...,
timeout: float = ...,
soft_fail: bool = False,
fail_policy: FailPolicy = ...,
mode: str = ...,
exponential_backoff: bool = False,
max_wait: timedelta | float | None = None,
Expand All @@ -702,7 +703,7 @@ class TaskDecoratorCollection:
:param poke_interval: Time in seconds that the job should wait in
between each try
:param timeout: Time, in seconds before the task times out and fails.
:param soft_fail: Set to true to mark the task as SKIPPED on failure
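:param fail_policy: Rule by which the sensor skips itself. Options are:
``{ none | skip_on_any_error | skip_on_timeout | ignore_error }``.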
:param mode: How the sensor operates.
Options are: ``{ poke | reschedule }``, default is ``poke``.
When set to ``poke`` the sensor is taking up a worker slot for its
Expand Down
22 changes: 17 additions & 5 deletions airflow/example_dags/example_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from airflow.models.dag import DAG
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.sensors.base import FailPolicy
from airflow.providers.standard.sensors.bash import BashSensor
from airflow.providers.standard.sensors.filesystem import FileSensor
from airflow.providers.standard.sensors.python import PythonSensor
Expand Down Expand Up @@ -68,7 +69,7 @@ def failure_callable():
t2 = TimeSensor(
task_id="timeout_after_second_date_in_the_future",
timeout=1,
soft_fail=True,
fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
target_time=(datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(hours=1)).time(),
)
# [END example_time_sensors]
Expand All @@ -81,15 +82,20 @@ def failure_callable():
t2a = TimeSensorAsync(
task_id="timeout_after_second_date_in_the_future_async",
timeout=1,
soft_fail=True,
fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
target_time=(datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(hours=1)).time(),
)
# [END example_time_sensors_async]

# [START example_bash_sensors]
t3 = BashSensor(task_id="Sensor_succeeds", bash_command="exit 0")

t4 = BashSensor(task_id="Sensor_fails_after_3_seconds", timeout=3, soft_fail=True, bash_command="exit 1")
t4 = BashSensor(
task_id="Sensor_fails_after_3_seconds",
timeout=3,
fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
bash_command="exit 1",
)
# [END example_bash_sensors]

t5 = BashOperator(task_id="remove_file", bash_command="rm -rf /tmp/temporary_file_for_testing")
Expand All @@ -112,13 +118,19 @@ def failure_callable():
t9 = PythonSensor(task_id="success_sensor_python", python_callable=success_callable)

t10 = PythonSensor(
task_id="failure_timeout_sensor_python", timeout=3, soft_fail=True, python_callable=failure_callable
task_id="failure_timeout_sensor_python",
timeout=3,
fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
python_callable=failure_callable,
)
# [END example_python_sensors]

# [START example_day_of_week_sensor]
t11 = DayOfWeekSensor(
task_id="week_day_sensor_failing_on_timeout", timeout=3, soft_fail=True, week_day=WeekDay.MONDAY
task_id="week_day_sensor_failing_on_timeout",
timeout=3,
fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
week_day=WeekDay.MONDAY,
)
# [END example_day_of_week_sensor]

Expand Down
4 changes: 4 additions & 0 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class AirflowSensorTimeout(AirflowException):
"""Raise when there is a timeout on sensor polling."""


class AirflowPokeFailException(AirflowException):
    """
    Raise when a sensor must not try to poke again.

    Intended to be raised from a sensor's ``poke`` method to signal a failure
    that should end the poking loop instead of being followed by another poke.
    """


class AirflowRescheduleException(AirflowException):
"""
Raise when the task should be re-scheduled at a later time.
Expand Down
75 changes: 42 additions & 33 deletions airflow/sensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import datetime
import enum
import functools
import hashlib
import time
Expand All @@ -32,7 +33,7 @@
from airflow.configuration import conf
from airflow.exceptions import (
AirflowException,
AirflowFailException,
AirflowPokeFailException,
AirflowRescheduleException,
AirflowSensorTimeout,
AirflowSkipException,
Expand Down Expand Up @@ -107,15 +108,31 @@ def _orig_start_date(
)


class FailPolicy(str, enum.Enum):
    """
    Policies controlling how a sensor reacts when its poke method raises.

    Members also subclass ``str``, so each member compares equal to its plain
    string value (for example ``FailPolicy.NONE == "none"``).
    """

    # No special handling: an exception raised by the poke method is
    # propagated unchanged and never converted into a skip.
    NONE = "none"

    # If the poke method raises any exception, the sensor is skipped.
    SKIP_ON_ANY_ERROR = "skip_on_any_error"

    # If the poke method raises AirflowSensorTimeout, AirflowTaskTimeout,
    # AirflowPokeFailException or AirflowSkipException, the sensor is skipped.
    SKIP_ON_TIMEOUT = "skip_on_timeout"

    # If the poke method raises an exception other than AirflowSensorTimeout,
    # AirflowTaskTimeout, AirflowSkipException or AirflowPokeFailException,
    # the sensor logs the error and re-pokes until timeout.
    IGNORE_ERROR = "ignore_error"


class BaseSensorOperator(BaseOperator, SkipMixin):
"""
Sensor operators are derived from this class and inherit these attributes.
Sensor operators keep executing at a time interval and succeed when
a criteria is met and fail if and when they time out.
:param soft_fail: Set to true to mark the task as SKIPPED on failure.
Mutually exclusive with never_fail.
:param poke_interval: Time that the job should wait in between each try.
Can be ``timedelta`` or ``float`` seconds.
:param timeout: Time elapsed before the task times out and fails.
Expand Down Expand Up @@ -143,13 +160,10 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
:param exponential_backoff: allow progressive longer waits between
pokes by using exponential backoff algorithm
:param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds
:param silent_fail: If true, and poke method raises an exception different from
AirflowSensorTimeout, AirflowTaskTimeout, AirflowSkipException
and AirflowFailException, the sensor will log the error and continue
its execution. Otherwise, the sensor task fails, and it can be retried
based on the provided `retries` parameter.
:param never_fail: If true, and poke method raises an exception, sensor will be skipped.
Mutually exclusive with soft_fail.
:param fail_policy: Defines the rule by which the sensor skips itself. Options are:
``{ none | skip_on_any_error | skip_on_timeout | ignore_error }``,
default is ``none``. Options can be set as a string or
using the constants defined in the enum class ``airflow.sensors.base.FailPolicy``.
"""

ui_color: str = "#e6f1f2"
Expand All @@ -164,26 +178,19 @@ def __init__(
*,
poke_interval: timedelta | float = 60,
timeout: timedelta | float = conf.getfloat("sensors", "default_timeout"),
soft_fail: bool = False,
mode: str = "poke",
exponential_backoff: bool = False,
max_wait: timedelta | float | None = None,
silent_fail: bool = False,
never_fail: bool = False,
fail_policy: str = FailPolicy.NONE,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.poke_interval = self._coerce_poke_interval(poke_interval).total_seconds()
self.soft_fail = soft_fail
self.timeout: int | float = self._coerce_timeout(timeout).total_seconds()
self.mode = mode
self.exponential_backoff = exponential_backoff
self.max_wait = self._coerce_max_wait(max_wait)
if soft_fail is True and never_fail is True:
raise ValueError("soft_fail and never_fail are mutually exclusive, you can not provide both.")

self.silent_fail = silent_fail
self.never_fail = never_fail
self.fail_policy = fail_policy
self._validate_input_values()

@staticmethod
Expand Down Expand Up @@ -287,21 +294,20 @@ def run_duration() -> float:
except (
AirflowSensorTimeout,
AirflowTaskTimeout,
AirflowFailException,
AirflowPokeFailException,
AirflowSkipException,
) as e:
if self.soft_fail:
raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e
elif self.never_fail:
raise AirflowSkipException("Skipping due to never_fail is set to True.") from e
raise e
except AirflowSkipException as e:
if self.fail_policy == FailPolicy.SKIP_ON_TIMEOUT:
raise AirflowSkipException("Skipping due fail_policy set to SKIP_ON_TIMEOUT.") from e
elif self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR:
raise AirflowSkipException("Skipping due to SKIP_ON_ANY_ERROR is set to True.") from e
raise e
except Exception as e:
if self.silent_fail:
if self.fail_policy == FailPolicy.IGNORE_ERROR:
self.log.error("Sensor poke failed: \n %s", traceback.format_exc())
poke_return = False
elif self.never_fail:
raise AirflowSkipException("Skipping due to never_fail is set to True.") from e
elif self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR:
raise AirflowSkipException("Skipping due to SKIP_ON_ANY_ERROR is set to True.") from e
else:
raise e

Expand All @@ -311,13 +317,13 @@ def run_duration() -> float:
break

if run_duration() > self.timeout:
# If sensor is in soft fail mode but times out raise AirflowSkipException.
# If the sensor is in SKIP_ON_TIMEOUT mode and times out, raise AirflowSkipException.
message = (
f"Sensor has timed out; run duration of {run_duration()} seconds exceeds "
f"the specified timeout of {self.timeout}."
)

if self.soft_fail:
if self.fail_policy == FailPolicy.SKIP_ON_TIMEOUT:
raise AirflowSkipException(message)
else:
raise AirflowSensorTimeout(message)
Expand All @@ -340,9 +346,12 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None,
try:
return super().resume_execution(next_method, next_kwargs, context)
except TaskDeferralTimeout as e:
raise AirflowSensorTimeout(*e.args) from e
if self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR:
raise AirflowSkipException(str(e)) from e
else:
raise AirflowSensorTimeout(*e.args) from e
except (AirflowException, TaskDeferralError) as e:
if self.soft_fail:
if self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR:
raise AirflowSkipException(str(e)) from e
raise

Expand Down
3 changes: 2 additions & 1 deletion providers/src/airflow/providers/ftp/sensors/ftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowPokeFailException
from airflow.providers.ftp.hooks.ftp import FTPHook, FTPSHook
from airflow.sensors.base import BaseSensorOperator

Expand Down Expand Up @@ -83,7 +84,7 @@ def poke(self, context: Context) -> bool:
if (error_code != 550) and (
self.fail_on_transient_errors or (error_code not in self.transient_errors)
):
raise e
raise AirflowPokeFailException from e

return False

Expand Down
4 changes: 2 additions & 2 deletions providers/src/airflow/providers/sftp/sensors/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from paramiko.sftp import SFTP_NO_SUCH_FILE

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowPokeFailException
from airflow.providers.sftp.hooks.sftp import SFTPHook
from airflow.providers.sftp.triggers.sftp import SFTPTrigger
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue
Expand Down Expand Up @@ -99,7 +99,7 @@ def poke(self, context: Context) -> PokeReturnValue | bool:
self.log.info("Found File %s last modified: %s", actual_file_to_check, mod_time)
except OSError as e:
if e.errno != SFTP_NO_SUCH_FILE:
raise AirflowException from e
raise AirflowPokeFailException from e
continue

if self.newer_than:
Expand Down
Loading

0 comments on commit baf9e1c

Please sign in to comment.