Skip to content

Commit 0c5ec1e

Browse files
amoghrajesh authored and choo121600 committed
Improve xcom_pull to cover different scenarios for mapped tasks (apache#51568)
1 parent 54458d2 commit 0c5ec1e

File tree

3 files changed

+148
-35
lines changed

3 files changed

+148
-35
lines changed

task-sdk/src/airflow/sdk/bases/xcom.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,14 @@
2121

2222
import structlog
2323

24-
from airflow.sdk.execution_time.comms import DeleteXCom, GetXCom, SetXCom, XComResult
24+
from airflow.sdk.execution_time.comms import (
25+
DeleteXCom,
26+
GetXCom,
27+
GetXComSequenceSlice,
28+
SetXCom,
29+
XComResult,
30+
XComSequenceSliceResult,
31+
)
2532

2633
log = structlog.get_logger(logger_name="task")
2734

@@ -274,6 +281,56 @@ def get_one(
274281
)
275282
return None
276283

284+
@classmethod
285+
def get_all(
286+
cls,
287+
*,
288+
key: str,
289+
dag_id: str,
290+
task_id: str,
291+
run_id: str,
292+
) -> Any:
293+
"""
294+
Retrieve all XCom values for a task, typically from all map indexes.
295+
296+
XComSequenceSliceResult can never have *None* in it, it returns an empty list
297+
if no values were found.
298+
299+
This is particularly useful for getting all XCom values from all map
300+
indexes of a mapped task at once.
301+
302+
:param key: A key for the XCom. Only XComs with this key will be returned.
303+
:param run_id: DAG run ID for the task.
304+
:param dag_id: DAG ID to pull XComs from.
305+
:param task_id: Task ID to pull XComs from.
306+
:return: List of all XCom values if found.
307+
"""
308+
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
309+
310+
# Since Triggers can hit this code path via `sync_to_async` (which uses threads internally)
311+
# we need to make sure that we "atomically" send a request and get the response to that
312+
# back so that two triggers don't end up interleaving requests and create a possible
313+
# race condition where the wrong trigger reads the response.
314+
with SUPERVISOR_COMMS.lock:
315+
SUPERVISOR_COMMS.send_request(
316+
log=log,
317+
msg=GetXComSequenceSlice(
318+
key=key,
319+
dag_id=dag_id,
320+
task_id=task_id,
321+
run_id=run_id,
322+
start=None,
323+
stop=None,
324+
step=None,
325+
),
326+
)
327+
msg = SUPERVISOR_COMMS.get_message()
328+
329+
if not isinstance(msg, XComSequenceSliceResult):
330+
raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}")
331+
332+
return msg.root
333+
277334
@staticmethod
278335
def serialize_value(
279336
value: Any,

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -338,19 +338,35 @@ def xcom_pull(
338338
run_id = self.run_id
339339

340340
single_task_requested = isinstance(task_ids, (str, type(None)))
341-
single_map_index_requested = isinstance(map_indexes, (int, type(None), ArgNotSet))
341+
single_map_index_requested = isinstance(map_indexes, (int, type(None)))
342342

343343
if task_ids is None:
344344
# default to the current task if not provided
345345
task_ids = [self.task_id]
346346
elif isinstance(task_ids, str):
347347
task_ids = [task_ids]
348348

349-
map_indexes_iterable: Iterable[int | None] = []
350-
# If map_indexes is not provided, default to use the map_index of the calling task
349+
# If map_indexes is not specified, pull xcoms from all map indexes for each task
351350
if isinstance(map_indexes, ArgNotSet):
352-
map_indexes_iterable = [self.map_index]
353-
elif isinstance(map_indexes, int) or map_indexes is None:
351+
xcoms = [
352+
value
353+
for t_id in task_ids
354+
for value in XCom.get_all(
355+
run_id=run_id,
356+
key=key,
357+
task_id=t_id,
358+
dag_id=dag_id,
359+
)
360+
]
361+
362+
# For single task pulling from unmapped task, return single value
363+
if single_task_requested and len(xcoms) == 1:
364+
return xcoms[0]
365+
return xcoms
366+
367+
# Original logic when map_indexes is explicitly specified
368+
map_indexes_iterable: Iterable[int | None] = []
369+
if isinstance(map_indexes, int) or map_indexes is None:
354370
map_indexes_iterable = [map_indexes]
355371
elif isinstance(map_indexes, Iterable):
356372
map_indexes_iterable = map_indexes
@@ -360,10 +376,6 @@ def xcom_pull(
360376
)
361377

362378
xcoms = []
363-
# TODO: AIP 72 Execution API only allows working with a single map_index at a time
364-
# this is inefficient and leads to task_id * map_index requests to the API.
365-
# And we can't achieve the original behavior of XCom pull with multiple tasks
366-
# directly now.
367379
for t_id, m_idx in product(task_ids, map_indexes_iterable):
368380
value = XCom.get_one(
369381
run_id=run_id,

task-sdk/tests/task_sdk/execution_time/test_task_runner.py

Lines changed: 69 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
GetTICount,
7878
GetVariable,
7979
GetXCom,
80+
GetXComSequenceSlice,
8081
OKResponse,
8182
PrevSuccessfulDagRunResult,
8283
SetRenderedFields,
@@ -91,6 +92,7 @@
9192
TriggerDagRun,
9293
VariableResult,
9394
XComResult,
95+
XComSequenceSliceResult,
9496
)
9597
from airflow.sdk.execution_time.context import (
9698
ConnectionAccessor,
@@ -1113,7 +1115,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s
11131115
task = BaseOperator(task_id="hello")
11141116

11151117
# Assume the context is sent from the API server
1116-
# `task_sdk/tests/api/test_client.py::test_task_instance_start` checks the context is received
1118+
# `task-sdk/tests/api/test_client.py::test_task_instance_start` checks the context is received
11171119
# from the API server
11181120
runtime_ti = create_runtime_ti(task=task, dag_id="basic_task")
11191121

@@ -1387,7 +1389,17 @@ def execute(self, context):
13871389
runtime_ti = create_runtime_ti(task=task, **extra_for_ti)
13881390

13891391
ser_value = BaseXCom.serialize_value(xcom_values)
1390-
mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=ser_value)
1392+
1393+
def mock_get_message_side_effect(*args, **kwargs):
1394+
calls = mock_supervisor_comms.send_request.call_args_list
1395+
if calls:
1396+
last_call = calls[-1]
1397+
msg = last_call[1]["msg"]
1398+
if isinstance(msg, GetXComSequenceSlice):
1399+
return XComSequenceSliceResult(root=[ser_value])
1400+
return XComResult(key="key", value=ser_value)
1401+
1402+
mock_supervisor_comms.get_message.side_effect = mock_get_message_side_effect
13911403

13921404
run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
13931405

@@ -1403,25 +1415,37 @@ def execute(self, context):
14031415
task_id = test_task_id
14041416
for map_index in map_indexes:
14051417
if map_index == NOTSET:
1406-
map_index = -1
1407-
mock_supervisor_comms.send_request.assert_any_call(
1408-
log=mock.ANY,
1409-
msg=GetXCom(
1410-
key="key",
1411-
dag_id="test_dag",
1412-
run_id="test_run",
1413-
task_id=task_id,
1414-
map_index=map_index,
1415-
),
1416-
)
1418+
mock_supervisor_comms.send_request.assert_any_call(
1419+
log=mock.ANY,
1420+
msg=GetXComSequenceSlice(
1421+
key="key",
1422+
dag_id="test_dag",
1423+
run_id="test_run",
1424+
task_id=task_id,
1425+
start=None,
1426+
stop=None,
1427+
step=None,
1428+
),
1429+
)
1430+
else:
1431+
expected_map_index = map_index if map_index is not None else None
1432+
mock_supervisor_comms.send_request.assert_any_call(
1433+
log=mock.ANY,
1434+
msg=GetXCom(
1435+
key="key",
1436+
dag_id="test_dag",
1437+
run_id="test_run",
1438+
task_id=task_id,
1439+
map_index=expected_map_index,
1440+
),
1441+
)
14171442

14181443
@pytest.mark.parametrize(
14191444
"task_ids, map_indexes, expected_value",
14201445
[
14211446
pytest.param("task_a", 0, {"a": 1, "b": 2}, id="task_id is str, map_index is int"),
14221447
pytest.param("task_a", [0], [{"a": 1, "b": 2}], id="task_id is str, map_index is list"),
14231448
pytest.param("task_a", None, {"a": 1, "b": 2}, id="task_id is str, map_index is None"),
1424-
pytest.param("task_a", NOTSET, {"a": 1, "b": 2}, id="task_id is str, map_index is ArgNotSet"),
14251449
pytest.param(["task_a"], 0, [{"a": 1, "b": 2}], id="task_id is list, map_index is int"),
14261450
pytest.param(["task_a"], [0], [{"a": 1, "b": 2}], id="task_id is list, map_index is list"),
14271451
pytest.param(["task_a"], None, [{"a": 1, "b": 2}], id="task_id is list, map_index is None"),
@@ -1431,6 +1455,13 @@ def execute(self, context):
14311455
pytest.param(None, 0, {"a": 1, "b": 2}, id="task_id is None, map_index is int"),
14321456
pytest.param(None, [0], [{"a": 1, "b": 2}], id="task_id is None, map_index is list"),
14331457
pytest.param(None, None, {"a": 1, "b": 2}, id="task_id is None, map_index is None"),
1458+
pytest.param(
1459+
["task_a", "task_b"],
1460+
NOTSET,
1461+
[{"a": 1, "b": 2}, {"c": 3, "d": 4}],
1462+
id="multiple task_ids, map_index is ArgNotSet",
1463+
),
1464+
pytest.param("task_a", NOTSET, {"a": 1, "b": 2}, id="task_id is str, map_index is ArgNotSet"),
14341465
pytest.param(None, NOTSET, {"a": 1, "b": 2}, id="task_id is None, map_index is ArgNotSet"),
14351466
],
14361467
)
@@ -1444,7 +1475,7 @@ def test_xcom_pull_return_values(
14441475
):
14451476
"""
14461477
Tests return value of xcom_pull under various combinations of task_ids and map_indexes.
1447-
The above test covers the expected calls to supervisor comms.
1478+
Also verifies the correct XCom method (get_one vs get_all) is called.
14481479
"""
14491480

14501481
class CustomOperator(BaseOperator):
@@ -1455,13 +1486,28 @@ def execute(self, context):
14551486
task = CustomOperator(task_id=test_task_id)
14561487
runtime_ti = create_runtime_ti(task=task)
14571488

1458-
value = {"a": 1, "b": 2}
1459-
# API server returns serialised value for xcom result, staging it in that way
1460-
xcom_value = BaseXCom.serialize_value(value)
1461-
mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=xcom_value)
1462-
1463-
returned_xcom = runtime_ti.xcom_pull(key="key", task_ids=task_ids, map_indexes=map_indexes)
1464-
assert returned_xcom == expected_value
1489+
with patch.object(XCom, "get_one") as mock_get_one, patch.object(XCom, "get_all") as mock_get_all:
1490+
if map_indexes == NOTSET:
1491+
# Use side_effect to return different values for different tasks
1492+
def mock_get_all_side_effect(task_id, **kwargs):
1493+
if task_id == "task_b":
1494+
return [{"c": 3, "d": 4}]
1495+
return [{"a": 1, "b": 2}]
1496+
1497+
mock_get_all.side_effect = mock_get_all_side_effect
1498+
mock_get_one.return_value = None
1499+
else:
1500+
mock_get_one.return_value = {"a": 1, "b": 2}
1501+
mock_get_all.return_value = None
1502+
1503+
xcom = runtime_ti.xcom_pull(key="key", task_ids=task_ids, map_indexes=map_indexes)
1504+
assert xcom == expected_value
1505+
if map_indexes == NOTSET:
1506+
assert mock_get_all.called
1507+
assert not mock_get_one.called
1508+
else:
1509+
assert mock_get_one.called
1510+
assert not mock_get_all.called
14651511

14661512
def test_get_param_from_context(
14671513
self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti
@@ -1910,13 +1956,11 @@ def execute(self, context):
19101956
runtime_ti = create_runtime_ti(task=task)
19111957
run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
19121958

1913-
mock_xcom_backend.get_one.assert_called_once_with(
1959+
mock_xcom_backend.get_all.assert_called_once_with(
19141960
key="key",
19151961
dag_id="test_dag",
19161962
task_id="pull_task",
19171963
run_id="test_run",
1918-
map_index=-1,
1919-
include_prior_dates=False,
19201964
)
19211965

19221966
assert not any(

0 commit comments

Comments
 (0)