Skip to content

Commit a5211f2

Browse files
authored
Run Task failure callbacks on DAG Processor when task is externally killed (#53058)
Until #44354 is implemented, tasks killed externally or when supervisor process dies unexpectedly, users have no way of knowing this happened. This has been a blocker for Airflow 3.0 adoption for some: - #44354 - https://apache-airflow.slack.com/archives/C07813CNKA8/p1751057525231389 #44354 is more involved and we might not get to it for Airflow 3.1 -- so this is a good fix until then similar to how we run Dag Run callback.
1 parent 3eaf4e9 commit a5211f2

File tree

8 files changed

+469
-123
lines changed

8 files changed

+469
-123
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ class TIRunContext(BaseModel):
301301
dag_run: DagRun
302302
"""DAG run information for the task instance."""
303303

304-
task_reschedule_count: Annotated[int, Field(default=0)]
304+
task_reschedule_count: int = 0
305305
"""How many times the task has been rescheduled."""
306306

307307
max_tries: int
@@ -327,7 +327,7 @@ class TIRunContext(BaseModel):
327327
xcom_keys_to_clear: Annotated[list[str], Field(default_factory=list)]
328328
"""List of Xcom keys that need to be cleared and purged on by the worker."""
329329

330-
should_retry: bool
330+
should_retry: bool = False
331331
"""If the ti encounters an error, whether it should enter retry or failed state."""
332332

333333

airflow-core/src/airflow/callbacks/callback_requests.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class TaskCallbackRequest(BaseCallbackRequest):
6161
"""Simplified Task Instance representation"""
6262
task_callback_type: TaskInstanceState | None = None
6363
"""Whether on success, on failure, on retry"""
64+
context_from_server: ti_datamodel.TIRunContext | None = None
65+
"""Task execution context from the Server"""
6466
type: Literal["TaskCallbackRequest"] = "TaskCallbackRequest"
6567

6668
@property

airflow-core/src/airflow/dag_processing/processor.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import contextlib
1920
import importlib
2021
import os
2122
import sys
2223
import traceback
23-
from collections.abc import Callable
24+
from collections.abc import Callable, Sequence
2425
from pathlib import Path
2526
from typing import TYPE_CHECKING, Annotated, BinaryIO, ClassVar, Literal
2627

@@ -45,9 +46,11 @@
4546
VariableResult,
4647
)
4748
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
49+
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
4850
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
4951
from airflow.stats import Stats
5052
from airflow.utils.file import iter_airflow_imports
53+
from airflow.utils.state import TaskInstanceState
5154

5255
if TYPE_CHECKING:
5356
from structlog.typing import FilteringBoundLogger
@@ -201,10 +204,7 @@ def _execute_callbacks(
201204
for request in callback_requests:
202205
log.debug("Processing Callback Request", request=request.to_json())
203206
if isinstance(request, TaskCallbackRequest):
204-
raise NotImplementedError(
205-
"Haven't coded Task callback yet - https://github.com/apache/airflow/issues/44354!"
206-
)
207-
# _execute_task_callbacks(dagbag, request)
207+
_execute_task_callbacks(dagbag, request, log)
208208
if isinstance(request, DagCallbackRequest):
209209
_execute_dag_callbacks(dagbag, request, log)
210210

@@ -238,6 +238,67 @@ def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: Fil
238238
Stats.incr("dag.callback_exceptions", tags={"dag_id": request.dag_id})
239239

240240

241+
def _execute_task_callbacks(dagbag: DagBag, request: TaskCallbackRequest, log: FilteringBoundLogger) -> None:
242+
if not request.is_failure_callback:
243+
log.warning(
244+
"Task callback requested but is not a failure callback",
245+
dag_id=request.ti.dag_id,
246+
task_id=request.ti.task_id,
247+
run_id=request.ti.run_id,
248+
)
249+
return
250+
251+
dag = dagbag.dags[request.ti.dag_id]
252+
task = dag.get_task(request.ti.task_id)
253+
254+
if request.task_callback_type is TaskInstanceState.UP_FOR_RETRY:
255+
callbacks = task.on_retry_callback
256+
else:
257+
callbacks = task.on_failure_callback
258+
259+
if not callbacks:
260+
log.warning(
261+
"Callback requested but no callback found",
262+
dag_id=request.ti.dag_id,
263+
task_id=request.ti.task_id,
264+
run_id=request.ti.run_id,
265+
ti_id=request.ti.id,
266+
)
267+
return
268+
269+
callbacks = callbacks if isinstance(callbacks, Sequence) else [callbacks]
270+
ctx_from_server = request.context_from_server
271+
272+
if ctx_from_server is not None:
273+
runtime_ti = RuntimeTaskInstance.model_construct(
274+
**request.ti.model_dump(exclude_unset=True),
275+
task=task,
276+
_ti_context_from_server=ctx_from_server,
277+
max_tries=ctx_from_server.max_tries,
278+
)
279+
else:
280+
runtime_ti = RuntimeTaskInstance.model_construct(
281+
**request.ti.model_dump(exclude_unset=True),
282+
task=task,
283+
)
284+
context = runtime_ti.get_template_context()
285+
286+
def get_callback_representation(callback):
287+
with contextlib.suppress(AttributeError):
288+
return callback.__name__
289+
with contextlib.suppress(AttributeError):
290+
return callback.__class__.__name__
291+
return callback
292+
293+
for idx, callback in enumerate(callbacks):
294+
callback_repr = get_callback_representation(callback)
295+
log.info("Executing Task callback at index %d: %s", idx, callback_repr)
296+
try:
297+
callback(context)
298+
except Exception:
299+
log.exception("Error in callback at index %d: %s", idx, callback_repr)
300+
301+
241302
def in_process_api_server() -> InProcessExecutionAPI:
242303
from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
243304

airflow-core/src/airflow/jobs/scheduler_job_runner.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from sqlalchemy.sql import expression
3939

4040
from airflow import settings
41+
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext
4142
from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
4243
from airflow.configuration import conf
4344
from airflow.dag_processing.bundles.base import BundleUsageTrackingManager
@@ -927,10 +928,16 @@ def process_executor_events(
927928
bundle_version=ti.dag_version.bundle_version,
928929
ti=ti,
929930
msg=msg,
931+
context_from_server=TIRunContext(
932+
dag_run=ti.dag_run,
933+
max_tries=ti.max_tries,
934+
variables=[],
935+
connections=[],
936+
xcom_keys_to_clear=[],
937+
),
930938
)
931939
executor.send_callback(request)
932-
else:
933-
ti.handle_failure(error=msg, session=session)
940+
ti.handle_failure(error=msg, session=session)
934941

935942
return len(event_buffer)
936943

@@ -2283,6 +2290,13 @@ def _purge_task_instances_without_heartbeats(
22832290
bundle_version=ti.dag_run.bundle_version,
22842291
ti=ti,
22852292
msg=str(task_instance_heartbeat_timeout_message_details),
2293+
context_from_server=TIRunContext(
2294+
dag_run=ti.dag_run,
2295+
max_tries=ti.max_tries,
2296+
variables=[],
2297+
connections=[],
2298+
xcom_keys_to_clear=[],
2299+
),
22862300
)
22872301
session.add(
22882302
Log(

airflow-core/tests/unit/callbacks/test_callback_requests.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from airflow.models.taskinstance import TaskInstance
3030
from airflow.providers.standard.operators.bash import BashOperator
3131
from airflow.utils import timezone
32-
from airflow.utils.state import State
32+
from airflow.utils.state import State, TaskInstanceState
3333

3434
pytestmark = pytest.mark.db_test
3535

@@ -87,3 +87,30 @@ def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create
8787
json_str = input.to_json()
8888
result = TaskCallbackRequest.from_json(json_str)
8989
assert input == result
90+
91+
@pytest.mark.parametrize(
92+
"task_callback_type,expected_is_failure",
93+
[
94+
(None, True),
95+
(TaskInstanceState.FAILED, True),
96+
(TaskInstanceState.UP_FOR_RETRY, True),
97+
(TaskInstanceState.UPSTREAM_FAILED, True),
98+
(TaskInstanceState.SUCCESS, False),
99+
(TaskInstanceState.RUNNING, False),
100+
],
101+
)
102+
def test_is_failure_callback_property(
103+
self, task_callback_type, expected_is_failure, create_task_instance
104+
):
105+
"""Test is_failure_callback property with different task callback types"""
106+
ti = create_task_instance()
107+
108+
request = TaskCallbackRequest(
109+
filepath="filepath",
110+
ti=ti,
111+
bundle_name="testing",
112+
bundle_version=None,
113+
task_callback_type=task_callback_type,
114+
)
115+
116+
assert request.is_failure_callback == expected_is_failure

0 commit comments

Comments
 (0)