Skip to content

Improve xcom_pull to cover different scenarios for mapped tasks #51568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 12, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
adding tests
  • Loading branch information
amoghrajesh committed Jun 10, 2025
commit 8b72c73fc121dce040f06f28cf1083e7c8461060
41 changes: 31 additions & 10 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
@@ -1421,7 +1421,7 @@ def execute(self, context):
pytest.param("task_a", 0, {"a": 1, "b": 2}, id="task_id is str, map_index is int"),
pytest.param("task_a", [0], [{"a": 1, "b": 2}], id="task_id is str, map_index is list"),
pytest.param("task_a", None, {"a": 1, "b": 2}, id="task_id is str, map_index is None"),
pytest.param("task_a", NOTSET, {"a": 1, "b": 2}, id="task_id is str, map_index is ArgNotSet"),
pytest.param("task_a", NOTSET, [{"a": 1, "b": 2}], id="task_id is str, map_index is ArgNotSet"),
pytest.param(["task_a"], 0, [{"a": 1, "b": 2}], id="task_id is list, map_index is int"),
pytest.param(["task_a"], [0], [{"a": 1, "b": 2}], id="task_id is list, map_index is list"),
pytest.param(["task_a"], None, [{"a": 1, "b": 2}], id="task_id is list, map_index is None"),
@@ -1431,7 +1431,13 @@ def execute(self, context):
pytest.param(None, 0, {"a": 1, "b": 2}, id="task_id is None, map_index is int"),
pytest.param(None, [0], [{"a": 1, "b": 2}], id="task_id is None, map_index is list"),
pytest.param(None, None, {"a": 1, "b": 2}, id="task_id is None, map_index is None"),
pytest.param(None, NOTSET, {"a": 1, "b": 2}, id="task_id is None, map_index is ArgNotSet"),
pytest.param(None, NOTSET, [{"a": 1, "b": 2}], id="task_id is None, map_index is ArgNotSet"),
pytest.param(
["task_a", "task_b"],
NOTSET,
[{"a": 1, "b": 2}, {"c": 3, "d": 4}],
id="multiple task_ids, map_index is ArgNotSet",
),
],
)
def test_xcom_pull_return_values(
@@ -1444,7 +1450,7 @@ def test_xcom_pull_return_values(
):
"""
Tests return value of xcom_pull under various combinations of task_ids and map_indexes.
The above test covers the expected calls to supervisor comms.
Also verifies the correct XCom method (get_one vs get_all) is called.
"""

class CustomOperator(BaseOperator):
@@ -1455,13 +1461,28 @@ def execute(self, context):
task = CustomOperator(task_id=test_task_id)
runtime_ti = create_runtime_ti(task=task)

value = {"a": 1, "b": 2}
# API server returns serialised value for xcom result, staging it in that way
xcom_value = BaseXCom.serialize_value(value)
mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=xcom_value)

returned_xcom = runtime_ti.xcom_pull(key="key", task_ids=task_ids, map_indexes=map_indexes)
assert returned_xcom == expected_value
with patch.object(XCom, "get_one") as mock_get_one, patch.object(XCom, "get_all") as mock_get_all:
if map_indexes == NOTSET:
# Use side_effect to return different values for different tasks
def mock_get_all_side_effect(*, task_id, **kwargs):
if task_id == "task_b":
return [{"c": 3, "d": 4}]
return [{"a": 1, "b": 2}]

mock_get_all.side_effect = mock_get_all_side_effect
mock_get_one.return_value = None
else:
mock_get_one.return_value = {"a": 1, "b": 2}
mock_get_all.return_value = None

xcom = runtime_ti.xcom_pull(key="key", task_ids=task_ids, map_indexes=map_indexes)
assert xcom == expected_value
if map_indexes == NOTSET:
assert mock_get_all.called
assert not mock_get_one.called
else:
assert mock_get_one.called
assert not mock_get_all.called

def test_get_param_from_context(
self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti