Commit: fix

raphaelauv committed Aug 3, 2024
1 parent a7c19cf commit 7ae0db7
Showing 7 changed files with 139 additions and 92 deletions.
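Every file in this commit applies the same compatibility shim: a module-level import of FailPolicy (which exists only from Airflow 2.10.0 and made the whole test module fail to import on older versions) is replaced by a kwargs dict that each test builds at runtime and unpacks into the sensor constructor. A minimal sketch of the pattern as it appears throughout the diff below; SomeSensor is a placeholder for whichever sensor class a given test instantiates:

    from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS

    args = {}
    if AIRFLOW_V_2_10_PLUS:
        # FailPolicy exists only from Airflow 2.10.0, so import it lazily.
        from airflow.sensors.base import FailPolicy

        args["fail_policy"] = FailPolicy.SKIP_ON_TIMEOUT
    else:
        # Older Airflow exposes the equivalent skip-instead-of-fail knob as soft_fail.
        args["soft_fail"] = True

    sensor = SomeSensor(task_id="task", **args)  # SomeSensor is a placeholder

The parametrized tests extend the idea with a catch_mode boolean: True selects FailPolicy.SKIP_ON_TIMEOUT (or soft_fail=True) and expects AirflowSkipException, while False selects FailPolicy.NONE (or soft_fail=False) and expects a plain AirflowException.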
77 changes: 52 additions & 25 deletions tests/providers/amazon/aws/sensors/test_batch.py
@@ -28,10 +28,7 @@
    BatchSensor,
)
from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger
-from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error
-
-with ignore_provider_compatibility_error("2.10.0", __file__):
-    from airflow.sensors.base import FailPolicy
+from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS
TASK_ID = "batch_job_sensor"
JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
@@ -104,24 +101,25 @@ def test_execute_failure_in_deferrable_mode(self, deferrable_batch_sensor: BatchSensor):
        with pytest.raises(AirflowException):
            deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})

-    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
    def test_execute_failure_in_deferrable_mode_with_fail_policy(self):
        """Tests that an AirflowSkipException is raised in case of error event and fail_policy is set to True"""

+        args = {}
+        if AIRFLOW_V_2_10_PLUS:
+            from airflow.sensors.base import FailPolicy
+
+            args["fail_policy"] = FailPolicy.SKIP_ON_TIMEOUT
+        else:
+            args["soft_fail"] = True
        deferrable_batch_sensor = BatchSensor(
-            task_id="task",
-            job_id=JOB_ID,
-            region_name=AWS_REGION,
-            deferrable=True,
-            fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
+            task_id="task", job_id=JOB_ID, region_name=AWS_REGION, deferrable=True, **args
        )

        with pytest.raises(AirflowSkipException):
            deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})

-    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
    @pytest.mark.parametrize(
-        "fail_policy, expected_exception",
-        ((FailPolicy.NONE, AirflowException), (FailPolicy.SKIP_ON_TIMEOUT, AirflowSkipException)),
+        "catch_mode, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
    )
    @pytest.mark.parametrize(
        "state, error_message",
@@ -139,11 +137,22 @@ def test_fail_poke(
        mock_get_job_description,
        state,
        error_message,
-        fail_policy,
+        catch_mode,
        expected_exception,
    ):
+        args = {}
+        if AIRFLOW_V_2_10_PLUS:
+            from airflow.sensors.base import FailPolicy
+
+            if catch_mode:
+                args["fail_policy"] = FailPolicy.SKIP_ON_TIMEOUT
+            else:
+                args["fail_policy"] = FailPolicy.NONE
+        else:
+            args["soft_fail"] = catch_mode
+
        mock_get_job_description.return_value = {"status": state}
-        batch_sensor = BatchSensor(task_id="batch_job_sensor", job_id=JOB_ID, fail_policy=fail_policy)
+        batch_sensor = BatchSensor(task_id="batch_job_sensor", job_id=JOB_ID, **args)
        with pytest.raises(expected_exception, match=error_message):
            batch_sensor.poke({})

@@ -215,10 +224,8 @@ def test_poke_invalid(
        )
        assert "AWS Batch compute environment failed" in str(ctx.value)

-    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
    @pytest.mark.parametrize(
-        "fail_policy, expected_exception",
-        ((FailPolicy.NONE, AirflowException), (FailPolicy.SKIP_ON_TIMEOUT, AirflowSkipException)),
+        "catch_mode, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
    )
    @pytest.mark.parametrize(
        "compute_env, error_message",
@@ -236,14 +243,25 @@ def test_fail_poke(
        mock_batch_client,
        compute_env,
        error_message,
-        fail_policy,
+        catch_mode,
        expected_exception,
    ):
+        args = {}
+        if AIRFLOW_V_2_10_PLUS:
+            from airflow.sensors.base import FailPolicy
+
+            if catch_mode:
+                args["fail_policy"] = FailPolicy.SKIP_ON_TIMEOUT
+            else:
+                args["fail_policy"] = FailPolicy.NONE
+        else:
+            args["soft_fail"] = catch_mode
+
        mock_batch_client.describe_compute_environments.return_value = {"computeEnvironments": compute_env}
        batch_compute_environment_sensor = BatchComputeEnvironmentSensor(
            task_id="test_batch_compute_environment_sensor",
            compute_environment=ENVIRONMENT_NAME,
-            fail_policy=fail_policy,
+            **args,
        )

        with pytest.raises(expected_exception, match=error_message):
@@ -318,23 +336,32 @@ def test_poke_invalid(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
        )
        assert "AWS Batch job queue failed" in str(ctx.value)

-    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
    @pytest.mark.parametrize(
-        "fail_policy, expected_exception",
-        ((FailPolicy.NONE, AirflowException), (FailPolicy.SKIP_ON_TIMEOUT, AirflowSkipException)),
+        "catch_mode, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
    )
    @pytest.mark.parametrize("job_queue", ([], [{"status": "UNKNOWN_STATUS"}]))
    @mock.patch.object(BatchClientHook, "client")
    def test_fail_poke(
        self,
        mock_batch_client,
        job_queue,
-        fail_policy,
+        catch_mode,
        expected_exception,
    ):
+        args = {}
+        if AIRFLOW_V_2_10_PLUS:
+            from airflow.sensors.base import FailPolicy
+
+            if catch_mode:
+                args["fail_policy"] = FailPolicy.SKIP_ON_TIMEOUT
+            else:
+                args["fail_policy"] = FailPolicy.NONE
+        else:
+            args["soft_fail"] = catch_mode
+
        mock_batch_client.describe_job_queues.return_value = {"jobQueues": job_queue}
        batch_job_queue_sensor = BatchJobQueueSensor(
-            task_id="test_batch_job_queue_sensor", job_queue=JOB_QUEUE, fail_policy=fail_policy
+            task_id="test_batch_job_queue_sensor", job_queue=JOB_QUEUE, **args
        )
        batch_job_queue_sensor.treat_non_existing_as_deleted = False
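The catch_mode parametrization above folds what used to be two FailPolicy values into one boolean that is meaningful on both Airflow lines. A small self-contained sketch of the mapping the three batch tests rely on (the test name is illustrative):

    import pytest
    from airflow.exceptions import AirflowException, AirflowSkipException

    @pytest.mark.parametrize(
        "catch_mode, expected_exception",
        ((False, AirflowException), (True, AirflowSkipException)),
    )
    def test_catch_mode_mapping(catch_mode, expected_exception):
        # catch_mode=True means "skip instead of fail", so the sensor failure
        # should surface as AirflowSkipException rather than AirflowException.
        raised = AirflowSkipException if catch_mode else AirflowException
        assert raised is expected_exception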

tests/providers/amazon/aws/sensors/test_emr_serverless_application.py
@@ -23,22 +23,26 @@

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor
-from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error
-
-with ignore_provider_compatibility_error("2.10.0", __file__):
-    from airflow.sensors.base import FailPolicy
+from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS


class TestEmrServerlessApplicationSensor:
-    def setup_method(self, fail_policy):
+    def setup_method(self, args, optional_arg=None):
        self.app_id = "vzwemreks"
        self.job_run_id = "job1234"
-        self.sensor = EmrServerlessApplicationSensor(
-            task_id="test_emrcontainer_sensor",
-            application_id=self.app_id,
-            aws_conn_id="aws_default",
-            fail_policy=fail_policy,
-        )
+        if optional_arg:
+            self.sensor = EmrServerlessApplicationSensor(
+                task_id="test_emrcontainer_sensor",
+                application_id=self.app_id,
+                aws_conn_id="aws_default",
+                **optional_arg,
+            )
+        else:
+            self.sensor = EmrServerlessApplicationSensor(
+                task_id="test_emrcontainer_sensor",
+                application_id=self.app_id,
+                aws_conn_id="aws_default",
+            )

    def set_get_application_return_value(self, return_value: dict[str, str]):
        self.mock_hook = MagicMock()
@@ -82,10 +86,17 @@ def test_poke_raises_airflow_exception_with_failure_states(self, state):
        self.assert_get_application_was_called_once_with_app_id()


-@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
class TestPokeRaisesAirflowSkipException(TestEmrServerlessApplicationSensor):
-    def setup_method(self, fail_policy=None):
-        super().setup_method(FailPolicy.SKIP_ON_TIMEOUT)
+    def setup_method(self, args, optional_arg=None):
+        optional_arg = {}
+        if AIRFLOW_V_2_10_PLUS:
+            from airflow.sensors.base import FailPolicy
+
+            optional_arg["fail_policy"] = FailPolicy.SKIP_ON_TIMEOUT
+        else:
+            optional_arg["soft_fail"] = True
+
+        super().setup_method(args, optional_arg)

    def test_when_state_is_failed_and_soft_fail_is_true_poke_should_raise_skip_exception(self):
        self.set_get_application_return_value(
@@ -95,4 +106,3 @@ def test_when_state_is_failed_and_soft_fail_is_true_poke_should_raise_skip_exception(self):
            self.sensor.poke(None)
        assert "EMR Serverless application failed: mock stopped" == str(ctx.value)
        self.assert_get_application_was_called_once_with_app_id()
-        self.sensor.soft_fail = False
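A note on the setup_method signature introduced above: pytest calls setup_method(self, method) with the test method object, so the first extra parameter (named args here) receives that method rather than a caller-supplied value; only optional_arg carries real data, filled in by the TestPokeRaisesAirflowSkipException override. Since unpacking an empty dict is a no-op, the duplicated constructor branches could also be collapsed; a minimal alternative sketch under the same assumptions:

    def setup_method(self, method, optional_arg=None):
        self.app_id = "vzwemreks"
        self.job_run_id = "job1234"
        self.sensor = EmrServerlessApplicationSensor(
            task_id="test_emrcontainer_sensor",
            application_id=self.app_id,
            aws_conn_id="aws_default",
            **(optional_arg or {}),  # unpacking {} adds no kwargs
        )

The EmrServerlessJobSensor tests in the next file follow the identical structure.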
43 changes: 27 additions & 16 deletions tests/providers/amazon/aws/sensors/test_emr_serverless_job.py
@@ -23,23 +23,28 @@

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.sensors.emr import EmrServerlessJobSensor
-from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error
-
-with ignore_provider_compatibility_error("2.10.0", __file__):
-    from airflow.sensors.base import FailPolicy
+from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS


class TestEmrServerlessJobSensor:
-    def setup_method(self, fail_policy):
+    def setup_method(self, args, optional_arg=None):
        self.app_id = "vzwemreks"
        self.job_run_id = "job1234"
-        self.sensor = EmrServerlessJobSensor(
-            task_id="test_emrcontainer_sensor",
-            application_id=self.app_id,
-            job_run_id=self.job_run_id,
-            aws_conn_id="aws_default",
-            fail_policy=fail_policy,
-        )
+        if optional_arg:
+            self.sensor = EmrServerlessJobSensor(
+                task_id="test_emrcontainer_sensor",
+                application_id=self.app_id,
+                job_run_id=self.job_run_id,
+                aws_conn_id="aws_default",
+                **optional_arg,
+            )
+        else:
+            self.sensor = EmrServerlessJobSensor(
+                task_id="test_emrcontainer_sensor",
+                application_id=self.app_id,
+                job_run_id=self.job_run_id,
+                aws_conn_id="aws_default",
+            )

    def set_get_job_run_return_value(self, return_value: dict[str, str]):
        self.mock_hook = MagicMock()
@@ -85,15 +90,21 @@ def test_poke_raises_airflow_exception_with_specified_states(self, state):
        self.assert_get_job_run_was_called_once_with_app_and_run_id()


-@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
class TestPokeRaisesAirflowSkipException(TestEmrServerlessJobSensor):
-    def setup_method(self, fail_policy=None):
-        super().setup_method(FailPolicy.SKIP_ON_TIMEOUT)
+    def setup_method(self, args, optional_arg=None):
+        optional_arg = {}
+        if AIRFLOW_V_2_10_PLUS:
+            from airflow.sensors.base import FailPolicy
+
+            optional_arg["fail_policy"] = FailPolicy.SKIP_ON_TIMEOUT
+        else:
+            optional_arg["soft_fail"] = True
+
+        super().setup_method(args, optional_arg)

    def test_when_state_is_failed_and_soft_fail_is_true_poke_should_raise_skip_exception(self):
        self.set_get_job_run_return_value({"jobRun": {"state": "FAILED", "stateDetails": "mock failed"}})
        with pytest.raises(AirflowSkipException) as ctx:
            self.sensor.poke(None)
        assert "EMR Serverless job failed: mock failed" == str(ctx.value)
        self.assert_get_job_run_was_called_once_with_app_and_run_id()
-        self.sensor.soft_fail = False
16 changes: 10 additions & 6 deletions tests/providers/ftp/sensors/test_ftp.py
@@ -25,10 +25,7 @@
from airflow.exceptions import AirflowSensorTimeout, AirflowSkipException
from airflow.providers.ftp.hooks.ftp import FTPHook
from airflow.providers.ftp.sensors.ftp import FTPSensor
-from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error
-
-with ignore_provider_compatibility_error("2.10.0", __file__):
-    from airflow.sensors.base import FailPolicy
+from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS


class TestFTPSensor:
@@ -74,14 +71,21 @@ def test_poke_fail_on_transient_error(self, mock_hook):

        assert "434" in str(ctx.value.__cause__)

-    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
    @mock.patch("airflow.providers.ftp.sensors.ftp.FTPHook", spec=FTPHook)
    def test_poke_fail_on_transient_error_and_skip(self, mock_hook):
+        args = {}
+        if AIRFLOW_V_2_10_PLUS:
+            from airflow.sensors.base import FailPolicy
+
+            args["fail_policy"] = FailPolicy.SKIP_ON_TIMEOUT
+        else:
+            args["soft_fail"] = True
+
        op = FTPSensor(
            path="foobar.json",
            ftp_conn_id="bob_ftp",
            task_id="test_task",
-            fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
+            **args,
        )

        mock_hook.return_value.__enter__.return_value.get_mod_time.side_effect = error_perm(
16 changes: 10 additions & 6 deletions tests/providers/http/sensors/test_http.py
@@ -33,11 +33,8 @@
from airflow.providers.http.operators.http import HttpOperator
from airflow.providers.http.sensors.http import HttpSensor
from airflow.providers.http.triggers.http import HttpSensorTrigger
-from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error
-
-with ignore_provider_compatibility_error("2.10.0", __file__):
-    from airflow.sensors.base import FailPolicy
from airflow.utils.timezone import datetime
+from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS

pytestmark = pytest.mark.db_test

@@ -74,7 +71,6 @@ def resp_check(_):
        with pytest.raises(AirflowException, match="AirflowException raised here!"):
            task.execute(context={})

-    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
    @patch("airflow.providers.http.hooks.http.requests.Session.send")
    def test_poke_exception_with_skip_on_timeout(self, mock_session_send, create_task_of_operator):
        """
@@ -87,6 +83,14 @@ def resp_check(_):
        def resp_check(_):
            raise AirflowSensorTimeout("AirflowSensorTimeout raised here!")

+        args = {}
+        if AIRFLOW_V_2_10_PLUS:
+            from airflow.sensors.base import FailPolicy
+
+            args["fail_policy"] = FailPolicy.SKIP_ON_TIMEOUT
+        else:
+            args["soft_fail"] = True
+
        task = create_task_of_operator(
            HttpSensor,
            dag_id="http_sensor_poke_exception",
@@ -97,7 +101,7 @@ def resp_check(_):
            response_check=resp_check,
            timeout=5,
            poke_interval=1,
-            fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
+            **args,
        )
        with pytest.raises(AirflowSkipException):
            task.execute(context={})
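Across all five files the change trades module-level guards (ignore_provider_compatibility_error plus @pytest.mark.skipif) for in-test conditionals, so the tests now run on both sides of the 2.10 boundary instead of being skipped outright on older Airflow. For a test with no pre-2.10 equivalent, the dropped marker would remain the idiomatic tool; a sketch (the test name and assertion are illustrative):

    import pytest
    from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS

    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0")
    def test_requires_fail_policy():
        # Safe to import here: on older Airflow the test is skipped before this runs.
        from airflow.sensors.base import FailPolicy

        assert FailPolicy.SKIP_ON_TIMEOUT is not None  # placeholder assertion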