Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clear one mapped task's state via the REST API #24732

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,26 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) ->
if not dag:
error_message = f"Dag id {dag_id} not found"
raise NotFound(error_message)

reset_dag_runs = data.pop('reset_dag_runs')
dry_run = data.pop('dry_run')
# We always pass dry_run here, otherwise this would try to confirm on the terminal!
task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, **data)

if "task_ids" not in data and "tasks" not in data:
task_ids = None
else:
task_ids = data.pop("task_ids", [])
for task_data in data.pop("tasks", ()):
if task_data["map_index"] is None:
task_ids.append(task_data["task_id"])
else:
task_ids.append((task_data["task_id"], task_data["map_index"]))

task_instances = dag.clear(
dry_run=True, # Otherwise this would try to confirm on the terminal!
dag_bag=get_airflow_app().dag_bag,
task_ids=task_ids,
**data,
)
if not dry_run:
clear_task_instances(
task_instances.all(),
Expand Down
29 changes: 25 additions & 4 deletions airflow/api_connexion/openapi/v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2893,14 +2893,18 @@ components:
type: string
readOnly: true
description: The DAG ID.
dag_run_id:
type: string
readOnly: true
description: The DAG run ID.
execution_date:
type: string
format: datetime
readOnly: true
dag_run_id:
type: string
map_index:
type: integer
readOnly: true
description: The DAG run ID.
description: Index of a mapped task, or -1 if the task is not mapped.

TaskInstanceReferenceCollection:
type: object
Expand Down Expand Up @@ -3439,9 +3443,26 @@ components:
type: boolean
default: true

tasks:
description: |
A list of {task_id, map_index} combinations to clear.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth mentioning (somewhere? In the example? Is that possible) that you can pass both tasks and task_ids at the same time to this endpoint?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this meant to be communicated to the users? If not, YAML can contain comments.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, was thinking about communicating this to users via the rendered API docs, on https://airflow.apache.org/docs/apache-airflow/stable/stable-rest-api-ref.html


*New in version 2.4.0*
type: array
items:
type: object
properties:
task_id:
type: string
map_index:
type: integer
required:
- "task_id"
minItems: 1

task_ids:
description: |
A list of task ids to clear.
A list of task IDs to clear.

*New in version 2.1.0*
type: array
Expand Down
9 changes: 9 additions & 0 deletions airflow/api_connexion/schemas/task_instance_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ class TaskInstanceCollectionSchema(Schema):
total_entries = fields.Int()


class TaskInstanceInRunSchema(Schema):
    """Identify a single task instance within one DAG run.

    ``map_index`` is optional; when it is absent it loads as ``None``,
    which the clear endpoint treats as "the task without a specific map
    index" (i.e. the whole task rather than one mapped instance).
    """

    # Str/Int are marshmallow's documented short aliases for String/Integer.
    task_id = fields.Str()
    map_index = fields.Int(load_default=None)


class TaskInstanceBatchFormSchema(Schema):
"""Schema for the request form passed to Task Instance Batch endpoint"""

Expand Down Expand Up @@ -118,6 +125,7 @@ class ClearTaskInstanceFormSchema(Schema):
include_subdags = fields.Boolean(load_default=False)
include_parentdag = fields.Boolean(load_default=False)
reset_dag_runs = fields.Boolean(load_default=False)
tasks = fields.List(fields.Nested(TaskInstanceInRunSchema), validate=validate.Length(min=1))
task_ids = fields.List(fields.String(), validate=validate.Length(min=1))

@validates_schema
Expand Down Expand Up @@ -157,6 +165,7 @@ class TaskInstanceReferenceSchema(Schema):
run_id = fields.Str(data_key="dag_run_id")
dag_id = fields.Str()
execution_date = fields.DateTime()
map_index = fields.Int()


class TaskInstanceReferenceCollection(NamedTuple):
Expand Down
1 change: 1 addition & 0 deletions tests/api_connexion/endpoints/test_dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1444,6 +1444,7 @@ def test_dry_run(self, dag_maker, session):
"dag_run_id": dag_run_id,
"execution_date": dr.execution_date.isoformat(),
"task_id": "task_id",
"map_index": -1,
}
]
}
Expand Down
55 changes: 54 additions & 1 deletion tests/api_connexion/endpoints/test_task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from airflow.security import permissions
from airflow.utils.platform import getuser
from airflow.utils.session import provide_session
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user
Expand Down Expand Up @@ -1009,6 +1009,52 @@ def test_should_respond_200(
assert response.status_code == 200
assert len(response.json["task_instances"]) == expected_ti

    @pytest.mark.parametrize(
        "payload, changed_ti_count",
        [
            # ("normal", 2) matches nothing — 'normal' is unmapped (map_index
            # -1) — so only ("mapped", 1) is cleared.
            pytest.param(
                {"tasks": [{"task_id": "mapped", "map_index": 1}, {"task_id": "normal", "map_index": 2}]},
                1,
                id="combo",
            ),
            # No map_index selects every instance of 'mapped' that is
            # clearable (the FAILED and UPSTREAM_FAILED ones).
            pytest.param(
                {"tasks": [{"task_id": "mapped"}]},
                2,  # The SUCCESS one is not cleared.
                id="combo-no-index",
            ),
            # 'tasks' and the legacy 'task_ids' field may be combined in one
            # request; here each list contributes one cleared instance.
            pytest.param(
                {"tasks": [{"task_id": "mapped", "map_index": 2}], "task_ids": ["normal"]},
                2,
                id="combo-and-id",
            ),
        ],
    )
    def test_should_respond_200_mapped(self, dag_maker, payload, changed_ti_count):
        """Clearing via the 'tasks' payload returns the expected instances.

        Builds one unmapped and one mapped (3 instances) task, puts them in
        mixed terminal states, then POSTs a dry-run clear and checks how many
        task instances the endpoint reports as affected.
        """

        def echo(i):
            pass

        with dag_maker(dag_id="test_clear_for_mapped", dagbag=self.app.dag_bag) as dag:
            dag.task(task_id="normal")(echo)(i=-1)  # A non-mapped task named 'normal'.
            dag.task(task_id="mapped")(echo).expand(i=[0, 1, 2])  # A mapped task named 'mapped'.

        dr = dag_maker.create_dagrun(execution_date=DEFAULT_DATETIME_1, run_type=DagRunType.SCHEDULED)
        # Materialize the 3 mapped task instances for this run.
        dag.get_task("mapped").expand_mapped_task(dr.run_id, session=dag_maker.session)
        tis = {(ti.task_id, ti.map_index): ti for ti in dr.get_task_instances(session=dag_maker.session)}
        tis["normal", -1].state = TaskInstanceState.FAILED
        tis["mapped", 0].state = TaskInstanceState.SUCCESS
        tis["mapped", 1].state = TaskInstanceState.FAILED
        tis["mapped", 2].state = TaskInstanceState.UPSTREAM_FAILED

        dag_maker.session.commit()

        # dry_run keeps the database untouched; only the response is checked.
        response = self.client.post(
            "/api/v1/dags/test_clear_for_mapped/clearTaskInstances",
            environ_overrides={"REMOTE_USER": "test"},
            json={"dry_run": True, **payload},
        )
        assert response.status_code == 200
        assert len(response.json["task_instances"]) == changed_ti_count

@mock.patch("airflow.api_connexion.endpoints.task_instance_endpoint.clear_task_instances")
def test_clear_taskinstance_is_called_with_queued_dr_state(self, mock_clearti, session):
"""Test that if reset_dag_runs is True, then clear_task_instances is called with State.QUEUED"""
Expand Down Expand Up @@ -1096,30 +1142,35 @@ def test_should_respond_200_with_reset_dag_run(self, session):
'dag_run_id': 'TEST_DAG_RUN_ID_0',
'execution_date': '2020-01-01T00:00:00+00:00',
'task_id': 'print_the_context',
"map_index": -1,
},
{
'dag_id': 'example_python_operator',
'dag_run_id': 'TEST_DAG_RUN_ID_1',
'execution_date': '2020-01-02T00:00:00+00:00',
'task_id': 'sleep_for_0',
"map_index": -1,
},
{
'dag_id': 'example_python_operator',
'dag_run_id': 'TEST_DAG_RUN_ID_2',
'execution_date': '2020-01-03T00:00:00+00:00',
'task_id': 'sleep_for_1',
"map_index": -1,
},
{
'dag_id': 'example_python_operator',
'dag_run_id': 'TEST_DAG_RUN_ID_3',
'execution_date': '2020-01-04T00:00:00+00:00',
'task_id': 'sleep_for_2',
"map_index": -1,
},
{
'dag_id': 'example_python_operator',
'dag_run_id': 'TEST_DAG_RUN_ID_4',
'execution_date': '2020-01-05T00:00:00+00:00',
'task_id': 'sleep_for_3',
"map_index": -1,
},
]
for task_instance in expected_response:
Expand Down Expand Up @@ -1221,6 +1272,7 @@ def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, sessi
'dag_run_id': 'TEST_DAG_RUN_ID',
'execution_date': '2020-01-01T00:00:00+00:00',
'task_id': 'print_the_context',
"map_index": -1,
}
]
}
Expand Down Expand Up @@ -1270,6 +1322,7 @@ def test_should_assert_call_mocked_api_when_run_id(self, mock_set_task_instance_
'dag_run_id': 'TEST_DAG_RUN_ID',
'execution_date': '2020-01-01T00:00:00+00:00',
'task_id': 'print_the_context',
"map_index": -1,
}
]
}
Expand Down
14 changes: 11 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,10 +515,11 @@ def __exit__(self, type, value, traceback):
self.serialized_model = SerializedDagModel(dag)
self.session.merge(self.serialized_model)
serialized_dag = self._serialized_dag()
self.dagbag.bag_dag(serialized_dag, root_dag=serialized_dag)
self.active_dagbag.bag_dag(serialized_dag, root_dag=serialized_dag)
self.session.flush()
else:
self.dagbag.bag_dag(self.dag, self.dag)
self.active_dagbag.bag_dag(self.dag, self.dag)
del self.active_dagbag

def create_dagrun(self, **kwargs):
from airflow.utils import timezone
Expand Down Expand Up @@ -568,7 +569,13 @@ def create_dagrun_after(self, dagrun, **kwargs):
)

def __call__(
self, dag_id='test_dag', serialized=want_serialized, fileloc=None, session=None, **kwargs
self,
dag_id='test_dag',
dagbag=None,
serialized=want_serialized,
fileloc=None,
session=None,
**kwargs,
):
from airflow import settings
from airflow.models import DAG
Expand All @@ -578,6 +585,7 @@ def __call__(
self._own_session = True
session = settings.Session()

self.active_dagbag = dagbag or self.dagbag
self.kwargs = kwargs
self.session = session
self.start_date = self.kwargs.get('start_date', None)
Expand Down