Skip to content

Improve xcom_pull to cover different scenarios for mapped tasks #51568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion task-sdk/src/airflow/sdk/bases/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@

import structlog

from airflow.sdk.execution_time.comms import DeleteXCom, GetXCom, SetXCom, XComResult
from airflow.sdk.execution_time.comms import (
DeleteXCom,
GetXCom,
GetXComSequenceSlice,
SetXCom,
XComResult,
XComSequenceSliceResult,
)

log = structlog.get_logger(logger_name="task")

Expand Down Expand Up @@ -274,6 +281,56 @@ def get_one(
)
return None

@classmethod
def get_all(
    cls,
    *,
    key: str,
    dag_id: str,
    task_id: str,
    run_id: str,
) -> Any:
    """
    Fetch every XCom value stored under ``key`` for one task in a single request.

    The request is not restricted to a single map index, so this is the
    convenient way to collect the values pushed by all map indexes of a
    mapped task at once.

    A ``GetXComSequenceSlice`` message with ``start``, ``stop`` and ``step``
    all set to ``None`` is sent to the supervisor, and the ``root`` of the
    ``XComSequenceSliceResult`` reply is handed back unchanged.

    NOTE(review): the values are returned exactly as they arrive in
    ``msg.root``; no deserialization happens in this method -- confirm
    callers expect that.

    :param key: A key for the XCom. Only XComs with this key will be returned.
    :param dag_id: DAG ID to pull XComs from.
    :param task_id: Task ID to pull XComs from.
    :param run_id: DAG run ID for the task.
    :return: The sequence carried by the ``XComSequenceSliceResult`` reply.
        It is expected to be an empty list (never containing *None*) when no
        values were found -- this relies on the supervisor's contract.
    :raises TypeError: If the supervisor replies with anything other than an
        ``XComSequenceSliceResult``.
    """
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    # No slice bounds: ask for the complete sequence.
    request = GetXComSequenceSlice(
        key=key,
        dag_id=dag_id,
        task_id=task_id,
        run_id=run_id,
        start=None,
        stop=None,
        step=None,
    )

    # Triggers can reach this code through `sync_to_async`, which runs on
    # threads. Holding the lock across both the send and the receive keeps the
    # request/response pair atomic, so two triggers cannot interleave and end
    # up reading each other's response.
    with SUPERVISOR_COMMS.lock:
        SUPERVISOR_COMMS.send_request(log=log, msg=request)
        msg = SUPERVISOR_COMMS.get_message()

    if isinstance(msg, XComSequenceSliceResult):
        return msg.root

    raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}")

@staticmethod
def serialize_value(
value: Any,
Expand Down
30 changes: 21 additions & 9 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,19 +338,35 @@ def xcom_pull(
run_id = self.run_id

single_task_requested = isinstance(task_ids, (str, type(None)))
single_map_index_requested = isinstance(map_indexes, (int, type(None), ArgNotSet))
single_map_index_requested = isinstance(map_indexes, (int, type(None)))

if task_ids is None:
# default to the current task if not provided
task_ids = [self.task_id]
elif isinstance(task_ids, str):
task_ids = [task_ids]

map_indexes_iterable: Iterable[int | None] = []
# If map_indexes is not provided, default to use the map_index of the calling task
# If map_indexes is not specified, pull xcoms from all map indexes for each task
if isinstance(map_indexes, ArgNotSet):
map_indexes_iterable = [self.map_index]
elif isinstance(map_indexes, int) or map_indexes is None:
xcoms = [
value
for t_id in task_ids
for value in XCom.get_all(
run_id=run_id,
key=key,
task_id=t_id,
dag_id=dag_id,
)
]

# For single task pulling from unmapped task, return single value
if single_task_requested and len(xcoms) == 1:
return xcoms[0]
return xcoms

# Original logic when map_indexes is explicitly specified
map_indexes_iterable: Iterable[int | None] = []
if isinstance(map_indexes, int) or map_indexes is None:
map_indexes_iterable = [map_indexes]
elif isinstance(map_indexes, Iterable):
map_indexes_iterable = map_indexes
Expand All @@ -360,10 +376,6 @@ def xcom_pull(
)

xcoms = []
# TODO: AIP 72 Execution API only allows working with a single map_index at a time
# this is inefficient and leads to task_id * map_index requests to the API.
# And we can't achieve the original behavior of XCom pull with multiple tasks
# directly now.
for t_id, m_idx in product(task_ids, map_indexes_iterable):
value = XCom.get_one(
run_id=run_id,
Expand Down
94 changes: 69 additions & 25 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
GetTICount,
GetVariable,
GetXCom,
GetXComSequenceSlice,
OKResponse,
PrevSuccessfulDagRunResult,
SetRenderedFields,
Expand All @@ -91,6 +92,7 @@
TriggerDagRun,
VariableResult,
XComResult,
XComSequenceSliceResult,
)
from airflow.sdk.execution_time.context import (
ConnectionAccessor,
Expand Down Expand Up @@ -1113,7 +1115,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s
task = BaseOperator(task_id="hello")

# Assume the context is sent from the API server
# `task_sdk/tests/api/test_client.py::test_task_instance_start` checks the context is received
# `task-sdk/tests/api/test_client.py::test_task_instance_start` checks the context is received
# from the API server
runtime_ti = create_runtime_ti(task=task, dag_id="basic_task")

Expand Down Expand Up @@ -1387,7 +1389,17 @@ def execute(self, context):
runtime_ti = create_runtime_ti(task=task, **extra_for_ti)

ser_value = BaseXCom.serialize_value(xcom_values)
mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=ser_value)

def mock_get_message_side_effect(*args, **kwargs):
    """
    Fake ``get_message`` that answers according to the most recent request.

    Looks at the last call recorded on ``mock_supervisor_comms.send_request``:
    if its ``msg`` keyword argument was a ``GetXComSequenceSlice`` the reply is
    an ``XComSequenceSliceResult`` wrapping ``ser_value`` in a one-element
    list; in every other case (including no request sent yet) the reply is a
    plain ``XComResult`` carrying ``ser_value``.
    """
    recorded = mock_supervisor_comms.send_request.call_args_list
    # Index 1 of a recorded call is its keyword-argument mapping.
    if recorded and isinstance(recorded[-1][1]["msg"], GetXComSequenceSlice):
        return XComSequenceSliceResult(root=[ser_value])
    return XComResult(key="key", value=ser_value)

mock_supervisor_comms.get_message.side_effect = mock_get_message_side_effect

run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

Expand All @@ -1403,25 +1415,37 @@ def execute(self, context):
task_id = test_task_id
for map_index in map_indexes:
if map_index == NOTSET:
map_index = -1
mock_supervisor_comms.send_request.assert_any_call(
log=mock.ANY,
msg=GetXCom(
key="key",
dag_id="test_dag",
run_id="test_run",
task_id=task_id,
map_index=map_index,
),
)
mock_supervisor_comms.send_request.assert_any_call(
log=mock.ANY,
msg=GetXComSequenceSlice(
key="key",
dag_id="test_dag",
run_id="test_run",
task_id=task_id,
start=None,
stop=None,
step=None,
),
)
else:
expected_map_index = map_index if map_index is not None else None
mock_supervisor_comms.send_request.assert_any_call(
log=mock.ANY,
msg=GetXCom(
key="key",
dag_id="test_dag",
run_id="test_run",
task_id=task_id,
map_index=expected_map_index,
),
)

@pytest.mark.parametrize(
"task_ids, map_indexes, expected_value",
[
pytest.param("task_a", 0, {"a": 1, "b": 2}, id="task_id is str, map_index is int"),
pytest.param("task_a", [0], [{"a": 1, "b": 2}], id="task_id is str, map_index is list"),
pytest.param("task_a", None, {"a": 1, "b": 2}, id="task_id is str, map_index is None"),
pytest.param("task_a", NOTSET, {"a": 1, "b": 2}, id="task_id is str, map_index is ArgNotSet"),
pytest.param(["task_a"], 0, [{"a": 1, "b": 2}], id="task_id is list, map_index is int"),
pytest.param(["task_a"], [0], [{"a": 1, "b": 2}], id="task_id is list, map_index is list"),
pytest.param(["task_a"], None, [{"a": 1, "b": 2}], id="task_id is list, map_index is None"),
Expand All @@ -1431,6 +1455,13 @@ def execute(self, context):
pytest.param(None, 0, {"a": 1, "b": 2}, id="task_id is None, map_index is int"),
pytest.param(None, [0], [{"a": 1, "b": 2}], id="task_id is None, map_index is list"),
pytest.param(None, None, {"a": 1, "b": 2}, id="task_id is None, map_index is None"),
pytest.param(
["task_a", "task_b"],
NOTSET,
[{"a": 1, "b": 2}, {"c": 3, "d": 4}],
id="multiple task_ids, map_index is ArgNotSet",
),
pytest.param("task_a", NOTSET, {"a": 1, "b": 2}, id="task_id is str, map_index is ArgNotSet"),
pytest.param(None, NOTSET, {"a": 1, "b": 2}, id="task_id is None, map_index is ArgNotSet"),
],
)
Expand All @@ -1444,7 +1475,7 @@ def test_xcom_pull_return_values(
):
"""
Tests return value of xcom_pull under various combinations of task_ids and map_indexes.
The above test covers the expected calls to supervisor comms.
Also verifies the correct XCom method (get_one vs get_all) is called.
"""

class CustomOperator(BaseOperator):
Expand All @@ -1455,13 +1486,28 @@ def execute(self, context):
task = CustomOperator(task_id=test_task_id)
runtime_ti = create_runtime_ti(task=task)

value = {"a": 1, "b": 2}
# API server returns serialised value for xcom result, staging it in that way
xcom_value = BaseXCom.serialize_value(value)
mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=xcom_value)

returned_xcom = runtime_ti.xcom_pull(key="key", task_ids=task_ids, map_indexes=map_indexes)
assert returned_xcom == expected_value
with patch.object(XCom, "get_one") as mock_get_one, patch.object(XCom, "get_all") as mock_get_all:
if map_indexes == NOTSET:
# Use side_effect to return different values for different tasks
def mock_get_all_side_effect(task_id, **kwargs):
    """
    Return a canned one-element XCom list that depends on the task.

    ``"task_b"`` yields ``[{"c": 3, "d": 4}]``; any other task id yields
    ``[{"a": 1, "b": 2}]``. Remaining keyword arguments are accepted and
    ignored so the function can stand in for ``XCom.get_all``.
    """
    payload = {"c": 3, "d": 4} if task_id == "task_b" else {"a": 1, "b": 2}
    return [payload]

mock_get_all.side_effect = mock_get_all_side_effect
mock_get_one.return_value = None
else:
mock_get_one.return_value = {"a": 1, "b": 2}
mock_get_all.return_value = None

xcom = runtime_ti.xcom_pull(key="key", task_ids=task_ids, map_indexes=map_indexes)
assert xcom == expected_value
if map_indexes == NOTSET:
assert mock_get_all.called
assert not mock_get_one.called
else:
assert mock_get_one.called
assert not mock_get_all.called

def test_get_param_from_context(
self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti
Expand Down Expand Up @@ -1910,13 +1956,11 @@ def execute(self, context):
runtime_ti = create_runtime_ti(task=task)
run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())

mock_xcom_backend.get_one.assert_called_once_with(
mock_xcom_backend.get_all.assert_called_once_with(
key="key",
dag_id="test_dag",
task_id="pull_task",
run_id="test_run",
map_index=-1,
include_prior_dates=False,
)

assert not any(
Expand Down