77
77
GetTICount ,
78
78
GetVariable ,
79
79
GetXCom ,
80
+ GetXComSequenceSlice ,
80
81
OKResponse ,
81
82
PrevSuccessfulDagRunResult ,
82
83
SetRenderedFields ,
91
92
TriggerDagRun ,
92
93
VariableResult ,
93
94
XComResult ,
95
+ XComSequenceSliceResult ,
94
96
)
95
97
from airflow .sdk .execution_time .context import (
96
98
ConnectionAccessor ,
@@ -1113,7 +1115,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s
1113
1115
task = BaseOperator (task_id = "hello" )
1114
1116
1115
1117
# Assume the context is sent from the API server
1116
- # `task_sdk /tests/api/test_client.py::test_task_instance_start` checks the context is received
1118
+ # `task-sdk /tests/api/test_client.py::test_task_instance_start` checks the context is received
1117
1119
# from the API server
1118
1120
runtime_ti = create_runtime_ti (task = task , dag_id = "basic_task" )
1119
1121
@@ -1387,7 +1389,17 @@ def execute(self, context):
1387
1389
runtime_ti = create_runtime_ti (task = task , ** extra_for_ti )
1388
1390
1389
1391
ser_value = BaseXCom .serialize_value (xcom_values )
1390
- mock_supervisor_comms .get_message .return_value = XComResult (key = "key" , value = ser_value )
1392
+
1393
+ def mock_get_message_side_effect (* args , ** kwargs ):
1394
+ calls = mock_supervisor_comms .send_request .call_args_list
1395
+ if calls :
1396
+ last_call = calls [- 1 ]
1397
+ msg = last_call [1 ]["msg" ]
1398
+ if isinstance (msg , GetXComSequenceSlice ):
1399
+ return XComSequenceSliceResult (root = [ser_value ])
1400
+ return XComResult (key = "key" , value = ser_value )
1401
+
1402
+ mock_supervisor_comms .get_message .side_effect = mock_get_message_side_effect
1391
1403
1392
1404
run (runtime_ti , context = runtime_ti .get_template_context (), log = mock .MagicMock ())
1393
1405
@@ -1403,25 +1415,37 @@ def execute(self, context):
1403
1415
task_id = test_task_id
1404
1416
for map_index in map_indexes :
1405
1417
if map_index == NOTSET :
1406
- map_index = - 1
1407
- mock_supervisor_comms .send_request .assert_any_call (
1408
- log = mock .ANY ,
1409
- msg = GetXCom (
1410
- key = "key" ,
1411
- dag_id = "test_dag" ,
1412
- run_id = "test_run" ,
1413
- task_id = task_id ,
1414
- map_index = map_index ,
1415
- ),
1416
- )
1418
+ mock_supervisor_comms .send_request .assert_any_call (
1419
+ log = mock .ANY ,
1420
+ msg = GetXComSequenceSlice (
1421
+ key = "key" ,
1422
+ dag_id = "test_dag" ,
1423
+ run_id = "test_run" ,
1424
+ task_id = task_id ,
1425
+ start = None ,
1426
+ stop = None ,
1427
+ step = None ,
1428
+ ),
1429
+ )
1430
+ else :
1431
+ expected_map_index = map_index if map_index is not None else None
1432
+ mock_supervisor_comms .send_request .assert_any_call (
1433
+ log = mock .ANY ,
1434
+ msg = GetXCom (
1435
+ key = "key" ,
1436
+ dag_id = "test_dag" ,
1437
+ run_id = "test_run" ,
1438
+ task_id = task_id ,
1439
+ map_index = expected_map_index ,
1440
+ ),
1441
+ )
1417
1442
1418
1443
@pytest .mark .parametrize (
1419
1444
"task_ids, map_indexes, expected_value" ,
1420
1445
[
1421
1446
pytest .param ("task_a" , 0 , {"a" : 1 , "b" : 2 }, id = "task_id is str, map_index is int" ),
1422
1447
pytest .param ("task_a" , [0 ], [{"a" : 1 , "b" : 2 }], id = "task_id is str, map_index is list" ),
1423
1448
pytest .param ("task_a" , None , {"a" : 1 , "b" : 2 }, id = "task_id is str, map_index is None" ),
1424
- pytest .param ("task_a" , NOTSET , {"a" : 1 , "b" : 2 }, id = "task_id is str, map_index is ArgNotSet" ),
1425
1449
pytest .param (["task_a" ], 0 , [{"a" : 1 , "b" : 2 }], id = "task_id is list, map_index is int" ),
1426
1450
pytest .param (["task_a" ], [0 ], [{"a" : 1 , "b" : 2 }], id = "task_id is list, map_index is list" ),
1427
1451
pytest .param (["task_a" ], None , [{"a" : 1 , "b" : 2 }], id = "task_id is list, map_index is None" ),
@@ -1431,6 +1455,13 @@ def execute(self, context):
1431
1455
pytest .param (None , 0 , {"a" : 1 , "b" : 2 }, id = "task_id is None, map_index is int" ),
1432
1456
pytest .param (None , [0 ], [{"a" : 1 , "b" : 2 }], id = "task_id is None, map_index is list" ),
1433
1457
pytest .param (None , None , {"a" : 1 , "b" : 2 }, id = "task_id is None, map_index is None" ),
1458
+ pytest .param (
1459
+ ["task_a" , "task_b" ],
1460
+ NOTSET ,
1461
+ [{"a" : 1 , "b" : 2 }, {"c" : 3 , "d" : 4 }],
1462
+ id = "multiple task_ids, map_index is ArgNotSet" ,
1463
+ ),
1464
+ pytest .param ("task_a" , NOTSET , {"a" : 1 , "b" : 2 }, id = "task_id is str, map_index is ArgNotSet" ),
1434
1465
pytest .param (None , NOTSET , {"a" : 1 , "b" : 2 }, id = "task_id is None, map_index is ArgNotSet" ),
1435
1466
],
1436
1467
)
@@ -1444,7 +1475,7 @@ def test_xcom_pull_return_values(
1444
1475
):
1445
1476
"""
1446
1477
Tests return value of xcom_pull under various combinations of task_ids and map_indexes.
1447
- The above test covers the expected calls to supervisor comms .
1478
+ Also verifies the correct XCom method (get_one vs get_all) is called .
1448
1479
"""
1449
1480
1450
1481
class CustomOperator (BaseOperator ):
@@ -1455,13 +1486,28 @@ def execute(self, context):
1455
1486
task = CustomOperator (task_id = test_task_id )
1456
1487
runtime_ti = create_runtime_ti (task = task )
1457
1488
1458
- value = {"a" : 1 , "b" : 2 }
1459
- # API server returns serialised value for xcom result, staging it in that way
1460
- xcom_value = BaseXCom .serialize_value (value )
1461
- mock_supervisor_comms .get_message .return_value = XComResult (key = "key" , value = xcom_value )
1462
-
1463
- returned_xcom = runtime_ti .xcom_pull (key = "key" , task_ids = task_ids , map_indexes = map_indexes )
1464
- assert returned_xcom == expected_value
1489
+ with patch .object (XCom , "get_one" ) as mock_get_one , patch .object (XCom , "get_all" ) as mock_get_all :
1490
+ if map_indexes == NOTSET :
1491
+ # Use side_effect to return different values for different tasks
1492
+ def mock_get_all_side_effect (task_id , ** kwargs ):
1493
+ if task_id == "task_b" :
1494
+ return [{"c" : 3 , "d" : 4 }]
1495
+ return [{"a" : 1 , "b" : 2 }]
1496
+
1497
+ mock_get_all .side_effect = mock_get_all_side_effect
1498
+ mock_get_one .return_value = None
1499
+ else :
1500
+ mock_get_one .return_value = {"a" : 1 , "b" : 2 }
1501
+ mock_get_all .return_value = None
1502
+
1503
+ xcom = runtime_ti .xcom_pull (key = "key" , task_ids = task_ids , map_indexes = map_indexes )
1504
+ assert xcom == expected_value
1505
+ if map_indexes == NOTSET :
1506
+ assert mock_get_all .called
1507
+ assert not mock_get_one .called
1508
+ else :
1509
+ assert mock_get_one .called
1510
+ assert not mock_get_all .called
1465
1511
1466
1512
def test_get_param_from_context (
1467
1513
self , mocked_parse , make_ti_context , mock_supervisor_comms , create_runtime_ti
@@ -1910,13 +1956,11 @@ def execute(self, context):
1910
1956
runtime_ti = create_runtime_ti (task = task )
1911
1957
run (runtime_ti , context = runtime_ti .get_template_context (), log = mock .MagicMock ())
1912
1958
1913
- mock_xcom_backend .get_one .assert_called_once_with (
1959
+ mock_xcom_backend .get_all .assert_called_once_with (
1914
1960
key = "key" ,
1915
1961
dag_id = "test_dag" ,
1916
1962
task_id = "pull_task" ,
1917
1963
run_id = "test_run" ,
1918
- map_index = - 1 ,
1919
- include_prior_dates = False ,
1920
1964
)
1921
1965
1922
1966
assert not any (
0 commit comments