Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelauv committed Aug 3, 2024
1 parent a786a7d commit a7c19cf
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 10 deletions.
9 changes: 8 additions & 1 deletion tests/decorators/test_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,13 @@
from airflow.decorators import task
from airflow.exceptions import AirflowSensorTimeout
from airflow.models import XCom
from airflow.sensors.base import FailPolicy, PokeReturnValue
from airflow.sensors.base import PokeReturnValue
from tests.test_utils.compat import ignore_provider_compatibility_error

with ignore_provider_compatibility_error("2.10.0", __file__):
from airflow.sensors.base import FailPolicy
from airflow.utils.state import State
from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -141,6 +146,7 @@ def dummy_f():
if ti.task_id == "dummy_f":
assert ti.state == State.NONE

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
def test_basic_sensor_skip_on_timeout(self, dag_maker):
@task.sensor(timeout=0, fail_policy=FailPolicy.SKIP_ON_TIMEOUT)
def sensor_f():
Expand All @@ -165,6 +171,7 @@ def dummy_f():
if ti.task_id == "dummy_f":
assert ti.state == State.NONE

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
def test_basic_sensor_skip_on_timeout_returns_bool(self, dag_maker):
@task.sensor(timeout=0, fail_policy=FailPolicy.SKIP_ON_TIMEOUT)
def sensor_f():
Expand Down
9 changes: 8 additions & 1 deletion tests/providers/amazon/aws/sensors/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
BatchSensor,
)
from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger
from airflow.sensors.base import FailPolicy
from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error

with ignore_provider_compatibility_error("2.10.0", __file__):
from airflow.sensors.base import FailPolicy

TASK_ID = "batch_job_sensor"
JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
Expand Down Expand Up @@ -101,6 +104,7 @@ def test_execute_failure_in_deferrable_mode(self, deferrable_batch_sensor: Batch
with pytest.raises(AirflowException):
deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
def test_execute_failure_in_deferrable_mode_with_fail_policy(self):
"""Tests that an AirflowSkipException is raised in case of error event and fail_policy is set to True"""
deferrable_batch_sensor = BatchSensor(
Expand All @@ -114,6 +118,7 @@ def test_execute_failure_in_deferrable_mode_with_fail_policy(self):
with pytest.raises(AirflowSkipException):
deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
@pytest.mark.parametrize(
"fail_policy, expected_exception",
((FailPolicy.NONE, AirflowException), (FailPolicy.SKIP_ON_TIMEOUT, AirflowSkipException)),
Expand Down Expand Up @@ -210,6 +215,7 @@ def test_poke_invalid(
)
assert "AWS Batch compute environment failed" in str(ctx.value)

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
@pytest.mark.parametrize(
"fail_policy, expected_exception",
((FailPolicy.NONE, AirflowException), (FailPolicy.SKIP_ON_TIMEOUT, AirflowSkipException)),
Expand Down Expand Up @@ -312,6 +318,7 @@ def test_poke_invalid(self, mock_batch_client, batch_job_queue_sensor: BatchJobQ
)
assert "AWS Batch job queue failed" in str(ctx.value)

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
@pytest.mark.parametrize(
"fail_policy, expected_exception",
((FailPolicy.NONE, AirflowException), (FailPolicy.SKIP_ON_TIMEOUT, AirflowSkipException)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor
from airflow.sensors.base import FailPolicy
from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error

with ignore_provider_compatibility_error("2.10.0", __file__):
from airflow.sensors.base import FailPolicy


class TestEmrServerlessApplicationSensor:
Expand Down Expand Up @@ -79,6 +82,7 @@ def test_poke_raises_airflow_exception_with_failure_states(self, state):
self.assert_get_application_was_called_once_with_app_id()


@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
class TestPokeRaisesAirflowSkipException(TestEmrServerlessApplicationSensor):
def setup_method(self, fail_policy=None):
super().setup_method(FailPolicy.SKIP_ON_TIMEOUT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.sensors.emr import EmrServerlessJobSensor
from airflow.sensors.base import FailPolicy
from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error

with ignore_provider_compatibility_error("2.10.0", __file__):
from airflow.sensors.base import FailPolicy


class TestEmrServerlessJobSensor:
Expand Down Expand Up @@ -82,6 +85,7 @@ def test_poke_raises_airflow_exception_with_specified_states(self, state):
self.assert_get_job_run_was_called_once_with_app_and_run_id()


@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
class TestPokeRaisesAirflowSkipException(TestEmrServerlessJobSensor):
def setup_method(self, fail_policy=None):
super().setup_method(FailPolicy.SKIP_ON_TIMEOUT)
Expand Down
6 changes: 5 additions & 1 deletion tests/providers/ftp/sensors/test_ftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
from airflow.exceptions import AirflowSensorTimeout, AirflowSkipException
from airflow.providers.ftp.hooks.ftp import FTPHook
from airflow.providers.ftp.sensors.ftp import FTPSensor
from airflow.sensors.base import FailPolicy
from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error

with ignore_provider_compatibility_error("2.10.0", __file__):
from airflow.sensors.base import FailPolicy


class TestFTPSensor:
Expand Down Expand Up @@ -71,6 +74,7 @@ def test_poke_fail_on_transient_error(self, mock_hook):

assert "434" in str(ctx.value.__cause__)

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
@mock.patch("airflow.providers.ftp.sensors.ftp.FTPHook", spec=FTPHook)
def test_poke_fail_on_transient_error_and_skip(self, mock_hook):
op = FTPSensor(
Expand Down
6 changes: 5 additions & 1 deletion tests/providers/http/sensors/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
from airflow.providers.http.operators.http import HttpOperator
from airflow.providers.http.sensors.http import HttpSensor
from airflow.providers.http.triggers.http import HttpSensorTrigger
from airflow.sensors.base import FailPolicy
from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error

with ignore_provider_compatibility_error("2.10.0", __file__):
from airflow.sensors.base import FailPolicy
from airflow.utils.timezone import datetime

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -71,6 +74,7 @@ def resp_check(_):
with pytest.raises(AirflowException, match="AirflowException raised here!"):
task.execute(context={})

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
@patch("airflow.providers.http.hooks.http.requests.Session.send")
def test_poke_exception_with_skip_on_timeout(self, mock_session_send, create_task_of_operator):
"""
Expand Down
7 changes: 6 additions & 1 deletion tests/providers/sftp/sensors/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@

from airflow.exceptions import AirflowSensorTimeout
from airflow.providers.sftp.sensors.sftp import SFTPSensor
from airflow.sensors.base import FailPolicy, PokeReturnValue
from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error

with ignore_provider_compatibility_error("2.10.0", __file__):
from airflow.sensors.base import FailPolicy
from airflow.sensors.base import PokeReturnValue

# Ignore missing args provided by default_args
# mypy: disable-error-code="arg-type"
Expand All @@ -52,6 +56,7 @@ def test_file_absent(self, sftp_hook_mock):
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt")
assert not output

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
@pytest.mark.parametrize(
"fail_policy, expected_exception",
((FailPolicy.NONE, AirflowSensorTimeout), (FailPolicy.SKIP_ON_TIMEOUT, AirflowSensorTimeout)),
Expand Down
16 changes: 14 additions & 2 deletions tests/sensors/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@
from airflow.providers.celery.executors.celery_kubernetes_executor import CeleryKubernetesExecutor
from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubernetesExecutor
from airflow.providers.cncf.kubernetes.executors.local_kubernetes_executor import LocalKubernetesExecutor
from airflow.sensors.base import BaseSensorOperator, FailPolicy, PokeReturnValue, poke_mode_only
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue, poke_mode_only
from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error

with ignore_provider_compatibility_error("2.10.0", __file__):
from airflow.sensors.base import FailPolicy
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
from airflow.utils import timezone
from airflow.utils.session import create_session
Expand Down Expand Up @@ -96,7 +100,7 @@ def __init__(self, return_value=False, **kwargs):
self.return_value = return_value

def execute_complete(self, context, event=None):
raise AirflowException("Should be skipped")
raise AirflowException()


class DummySensorWithXcomValue(BaseSensorOperator):
Expand Down Expand Up @@ -179,6 +183,7 @@ def test_fail(self, make_sensor):
if ti.task_id == DUMMY_OP:
assert ti.state == State.NONE

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
def test_skip_on_timeout(self, make_sensor):
sensor, dr = make_sensor(False, fail_policy=FailPolicy.SKIP_ON_TIMEOUT)

Expand All @@ -191,6 +196,7 @@ def test_skip_on_timeout(self, make_sensor):
if ti.task_id == DUMMY_OP:
assert ti.state == State.NONE

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
@pytest.mark.parametrize(
"exception_cls",
(ValueError,),
Expand All @@ -209,6 +215,7 @@ def test_skip_on_timeout_with_exception(self, make_sensor, exception_cls):
if ti.task_id == DUMMY_OP:
assert ti.state == State.NONE

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
@pytest.mark.parametrize(
"exception_cls",
(
Expand All @@ -230,6 +237,7 @@ def test_skip_on_timeout_with_skip_exception(self, make_sensor, exception_cls):
if ti.task_id == DUMMY_OP:
assert ti.state == State.NONE

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
@pytest.mark.parametrize(
"exception_cls",
(AirflowSensorTimeout, AirflowTaskTimeout, AirflowFailException, AirflowPokeFailException, Exception),
Expand All @@ -247,6 +255,7 @@ def test_skip_on_any_error_with_skip_exception(self, make_sensor, exception_cls)
if ti.task_id == DUMMY_OP:
assert ti.state == State.NONE

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
def test_skip_on_timeout_with_retries(self, make_sensor):
sensor, dr = make_sensor(
return_value=False,
Expand Down Expand Up @@ -360,6 +369,7 @@ def _get_tis():
assert sensor_ti.state == State.FAILED
assert dummy_ti.state == State.NONE

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
def test_skip_on_timeout_with_reschedule(self, make_sensor, time_machine, session):
sensor, dr = make_sensor(
return_value=False,
Expand Down Expand Up @@ -884,6 +894,7 @@ def _increment_try_number():
assert sensor_ti.max_tries == 4
assert sensor_ti.state == State.FAILED

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
def test_reschedule_and_retry_timeout_and_silent_fail(self, make_sensor, time_machine, session):
"""
Test mode="reschedule", silent_fail=True then retries and timeout configurations interact correctly.
Expand Down Expand Up @@ -1117,6 +1128,7 @@ def test_poke_mode_only_bad_poke(self):


class TestAsyncSensor:
@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
@pytest.mark.parametrize(
"fail_policy, expected_exception",
[
Expand Down
10 changes: 9 additions & 1 deletion tests/sensors/test_external_task_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.sensors.base import FailPolicy
from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error

with ignore_provider_compatibility_error("2.10.0", __file__):
from airflow.sensors.base import FailPolicy

from airflow.sensors.external_task import (
ExternalTaskMarker,
ExternalTaskSensor,
Expand Down Expand Up @@ -336,6 +340,7 @@ def test_external_task_sensor_failed_states_as_success(self, caplog):
f"Poking for tasks ['{TEST_TASK_ID}'] in dag {TEST_DAG_ID} on {DEFAULT_DATE.isoformat()} ... "
) in caplog.messages

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
def test_external_task_sensor_skip_on_timeout_failed_states_as_skipped(self):
self.add_time_sensor()
op = ExternalTaskSensor(
Expand Down Expand Up @@ -474,6 +479,7 @@ def test_external_dag_sensor_log(self, caplog):
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
assert (f"Poking for DAG 'other_dag' on {DEFAULT_DATE.isoformat()} ... ") in caplog.messages

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
def test_external_dag_sensor_skip_on_timeout_as_skipped(self):
other_dag = DAG("other_dag", default_args=self.args, end_date=DEFAULT_DATE, schedule="@once")
other_dag.create_dagrun(
Expand Down Expand Up @@ -861,6 +867,7 @@ def test_external_task_group_when_there_is_no_TIs(self):
ignore_ti_state=True,
)

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
@pytest.mark.parametrize(
"kwargs, expected_message",
(
Expand Down Expand Up @@ -920,6 +927,7 @@ def test_fail_poke(
with pytest.raises(expected_exception, match=expected_message):
op.execute(context={})

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
@pytest.mark.parametrize(
"response_get_current, response_exists, kwargs, expected_message",
(
Expand Down

0 comments on commit a7c19cf

Please sign in to comment.