From fa8101e57dc271d4c22f367e02f7ae86b2153537 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 29 Jun 2022 19:35:33 +0800 Subject: [PATCH] Clear one mapped task's state via the REST API Due to limitations of the OpenAPI specification we use (3.0), I added a new field 'tasks' to the clearTaskInstances API. This field accepts a list of dicts, each dict expecting 'task_id' and optionally 'map_index'. Each dict is converted to either a plain string (task ID) or a 2-tuple (task ID and map index) to pass into DAG.clear(). In the future, once we are able to use heterogeneous arrays, we could merge 'task_ids' and 'tasks' back into one property to clean up the interface. --- .../endpoints/task_instance_endpoint.py | 20 ++++++- airflow/api_connexion/openapi/v1.yaml | 29 ++++++++-- .../schemas/task_instance_schema.py | 9 +++ .../endpoints/test_dag_run_endpoint.py | 1 + .../endpoints/test_task_instance_endpoint.py | 55 ++++++++++++++++++- tests/conftest.py | 14 ++++- 6 files changed, 118 insertions(+), 10 deletions(-) diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index 6cc3e784e62a3..29bd1e1d69d42 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -434,10 +434,26 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> if not dag: error_message = f"Dag id {dag_id} not found" raise NotFound(error_message) + reset_dag_runs = data.pop('reset_dag_runs') dry_run = data.pop('dry_run') - # We always pass dry_run here, otherwise this would try to confirm on the terminal! 
- task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, **data) + + if "task_ids" not in data and "tasks" not in data: + task_ids = None + else: + task_ids = data.pop("task_ids", []) + for task_data in data.pop("tasks", ()): + if task_data["map_index"] is None: + task_ids.append(task_data["task_id"]) + else: + task_ids.append((task_data["task_id"], task_data["map_index"])) + + task_instances = dag.clear( + dry_run=True, # Otherwise this would try to confirm on the terminal! + dag_bag=get_airflow_app().dag_bag, + task_ids=task_ids, + **data, + ) if not dry_run: clear_task_instances( task_instances.all(), diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 6ce6c51608304..371d6b6b064ca 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -2893,14 +2893,18 @@ components: type: string readOnly: true description: The DAG ID. + dag_run_id: + type: string + readOnly: true + description: The DAG run ID. execution_date: type: string format: datetime readOnly: true - dag_run_id: - type: string + map_index: + type: integer readOnly: true - description: The DAG run ID. + description: Index of a mapped task, or -1 if the task is not mapped. TaskInstanceReferenceCollection: type: object @@ -3439,9 +3443,26 @@ components: type: boolean default: true + tasks: + description: | + A list of {task_id, map_index} combinations to clear. + + *New in version 2.4.0* + type: array + items: + type: object + properties: + task_id: + type: string + map_index: + type: integer + required: + - "task_id" + minItems: 1 + task_ids: description: | - A list of task ids to clear. + A list of task IDs to clear. 
*New in version 2.1.0* type: array diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py index 74824dbaf87c6..2e31edf2d02a1 100644 --- a/airflow/api_connexion/schemas/task_instance_schema.py +++ b/airflow/api_connexion/schemas/task_instance_schema.py @@ -88,6 +88,13 @@ class TaskInstanceCollectionSchema(Schema): total_entries = fields.Int() +class TaskInstanceInRunSchema(Schema): + """Schema to specify one task instance in a DAG run.""" + + task_id = fields.String() + map_index = fields.Integer(load_default=None) + + class TaskInstanceBatchFormSchema(Schema): """Schema for the request form passed to Task Instance Batch endpoint""" @@ -118,6 +125,7 @@ class ClearTaskInstanceFormSchema(Schema): include_subdags = fields.Boolean(load_default=False) include_parentdag = fields.Boolean(load_default=False) reset_dag_runs = fields.Boolean(load_default=False) + tasks = fields.List(fields.Nested(TaskInstanceInRunSchema), validate=validate.Length(min=1)) task_ids = fields.List(fields.String(), validate=validate.Length(min=1)) @validates_schema @@ -157,6 +165,7 @@ class TaskInstanceReferenceSchema(Schema): run_id = fields.Str(data_key="dag_run_id") dag_id = fields.Str() execution_date = fields.DateTime() + map_index = fields.Int() class TaskInstanceReferenceCollection(NamedTuple): diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index dd5803564a268..a60decfcab3b8 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -1444,6 +1444,7 @@ def test_dry_run(self, dag_maker, session): "dag_run_id": dag_run_id, "execution_date": dr.execution_date.isoformat(), "task_id": "task_id", + "map_index": -1, } ] } diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index 
b4fd87ae2e0a6..ce074e3ee399c 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -26,7 +26,7 @@ from airflow.security import permissions from airflow.utils.platform import getuser from airflow.utils.session import provide_session -from airflow.utils.state import State +from airflow.utils.state import State, TaskInstanceState from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user @@ -1009,6 +1009,52 @@ def test_should_respond_200( assert response.status_code == 200 assert len(response.json["task_instances"]) == expected_ti + @pytest.mark.parametrize( + "payload, changed_ti_count", + [ + pytest.param( + {"tasks": [{"task_id": "mapped", "map_index": 1}, {"task_id": "normal", "map_index": 2}]}, + 1, + id="combo", + ), + pytest.param( + {"tasks": [{"task_id": "mapped"}]}, + 2, # The SUCCESS one is not cleared. + id="combo-no-index", + ), + pytest.param( + {"tasks": [{"task_id": "mapped", "map_index": 2}], "task_ids": ["normal"]}, + 2, + id="combo-and-id", + ), + ], + ) + def test_should_respond_200_mapped(self, dag_maker, payload, changed_ti_count): + def echo(i): + pass + + with dag_maker(dag_id="test_clear_for_mapped", dagbag=self.app.dag_bag) as dag: + dag.task(task_id="normal")(echo)(i=-1) # A non-mapped task named 'normal'. + dag.task(task_id="mapped")(echo).expand(i=[0, 1, 2]) # A mapped task named 'mapped'. 
+ + dr = dag_maker.create_dagrun(execution_date=DEFAULT_DATETIME_1, run_type=DagRunType.SCHEDULED) + dag.get_task("mapped").expand_mapped_task(dr.run_id, session=dag_maker.session) + tis = {(ti.task_id, ti.map_index): ti for ti in dr.get_task_instances(session=dag_maker.session)} + tis["normal", -1].state = TaskInstanceState.FAILED + tis["mapped", 0].state = TaskInstanceState.SUCCESS + tis["mapped", 1].state = TaskInstanceState.FAILED + tis["mapped", 2].state = TaskInstanceState.UPSTREAM_FAILED + + dag_maker.session.commit() + + response = self.client.post( + "/api/v1/dags/test_clear_for_mapped/clearTaskInstances", + environ_overrides={"REMOTE_USER": "test"}, + json={"dry_run": True, **payload}, + ) + assert response.status_code == 200 + assert len(response.json["task_instances"]) == changed_ti_count + @mock.patch("airflow.api_connexion.endpoints.task_instance_endpoint.clear_task_instances") def test_clear_taskinstance_is_called_with_queued_dr_state(self, mock_clearti, session): """Test that if reset_dag_runs is True, then clear_task_instances is called with State.QUEUED""" @@ -1096,30 +1142,35 @@ def test_should_respond_200_with_reset_dag_run(self, session): 'dag_run_id': 'TEST_DAG_RUN_ID_0', 'execution_date': '2020-01-01T00:00:00+00:00', 'task_id': 'print_the_context', + "map_index": -1, }, { 'dag_id': 'example_python_operator', 'dag_run_id': 'TEST_DAG_RUN_ID_1', 'execution_date': '2020-01-02T00:00:00+00:00', 'task_id': 'sleep_for_0', + "map_index": -1, }, { 'dag_id': 'example_python_operator', 'dag_run_id': 'TEST_DAG_RUN_ID_2', 'execution_date': '2020-01-03T00:00:00+00:00', 'task_id': 'sleep_for_1', + "map_index": -1, }, { 'dag_id': 'example_python_operator', 'dag_run_id': 'TEST_DAG_RUN_ID_3', 'execution_date': '2020-01-04T00:00:00+00:00', 'task_id': 'sleep_for_2', + "map_index": -1, }, { 'dag_id': 'example_python_operator', 'dag_run_id': 'TEST_DAG_RUN_ID_4', 'execution_date': '2020-01-05T00:00:00+00:00', 'task_id': 'sleep_for_3', + "map_index": -1, }, ] for 
task_instance in expected_response: @@ -1221,6 +1272,7 @@ def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, sessi 'dag_run_id': 'TEST_DAG_RUN_ID', 'execution_date': '2020-01-01T00:00:00+00:00', 'task_id': 'print_the_context', + "map_index": -1, } ] } @@ -1270,6 +1322,7 @@ def test_should_assert_call_mocked_api_when_run_id(self, mock_set_task_instance_ 'dag_run_id': 'TEST_DAG_RUN_ID', 'execution_date': '2020-01-01T00:00:00+00:00', 'task_id': 'print_the_context', + "map_index": -1, } ] } diff --git a/tests/conftest.py b/tests/conftest.py index b153c213d5f93..a81ec5160780f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -515,10 +515,11 @@ def __exit__(self, type, value, traceback): self.serialized_model = SerializedDagModel(dag) self.session.merge(self.serialized_model) serialized_dag = self._serialized_dag() - self.dagbag.bag_dag(serialized_dag, root_dag=serialized_dag) + self.active_dagbag.bag_dag(serialized_dag, root_dag=serialized_dag) self.session.flush() else: - self.dagbag.bag_dag(self.dag, self.dag) + self.active_dagbag.bag_dag(self.dag, self.dag) + del self.active_dagbag def create_dagrun(self, **kwargs): from airflow.utils import timezone @@ -568,7 +569,13 @@ def create_dagrun_after(self, dagrun, **kwargs): ) def __call__( - self, dag_id='test_dag', serialized=want_serialized, fileloc=None, session=None, **kwargs + self, + dag_id='test_dag', + dagbag=None, + serialized=want_serialized, + fileloc=None, + session=None, + **kwargs, ): from airflow import settings from airflow.models import DAG @@ -578,6 +585,7 @@ def __call__( self._own_session = True session = settings.Session() + self.active_dagbag = dagbag or self.dagbag self.kwargs = kwargs self.session = session self.start_date = self.kwargs.get('start_date', None)