Skip to content

Commit 8324dae

Browse files
Fix the dependency handling logic for empty task groups (apache#49034)
1 parent 04c049b commit 8324dae

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

task-sdk/src/airflow/sdk/definitions/taskgroup.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,29 @@ def _set_relatives(
343343
if not isinstance(task_or_task_list, Sequence):
344344
task_or_task_list = [task_or_task_list]
345345

346+
# Helper function to find leaves from a task list or task group
347+
def find_leaves(group_or_task) -> list[Any]:
348+
while group_or_task:
349+
group_or_task_leaves = list(group_or_task.get_leaves())
350+
if group_or_task_leaves:
351+
return group_or_task_leaves
352+
if group_or_task.upstream_task_ids:
353+
upstream_task_ids_list = list(group_or_task.upstream_task_ids)
354+
return [self.dag.get_task(task_id) for task_id in upstream_task_ids_list]
355+
group_or_task = group_or_task.parent_group
356+
return []
357+
358+
# Check if the current TaskGroup is empty
359+
leaves = find_leaves(self)
360+
346361
for task_like in task_or_task_list:
347362
self.update_relative(task_like, upstream, edge_modifier=edge_modifier)
348363

349364
if upstream:
350365
for task in self.get_roots():
351366
task.set_upstream(task_or_task_list)
352367
else:
353-
for task in self.get_leaves():
368+
for task in leaves: # Use the fetched leaves
354369
task.set_downstream(task_or_task_list)
355370

356371
def __enter__(self) -> TaskGroup:

task-sdk/tests/task_sdk/definitions/test_taskgroup.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@
1717

1818
from __future__ import annotations
1919

20+
import pendulum
2021
import pytest
2122

23+
from airflow.providers.standard.operators.empty import EmptyOperator
24+
from airflow.sdk.definitions.dag import DAG
2225
from airflow.sdk.definitions.taskgroup import TaskGroup
2326

2427

@@ -51,3 +54,69 @@ def test_dag_id_validation(self, group_id, exc_type, exc_value):
5154
with pytest.raises(exc_type) as ctx:
5255
TaskGroup(group_id)
5356
assert str(ctx.value) == exc_value
57+
58+
59+
def test_task_group_dependencies_between_tasks_if_task_group_is_empty_1():
60+
"""
61+
Test that if a task group is empty, the dependencies between tasks are still maintained.
62+
"""
63+
with DAG(dag_id="test_dag", schedule=None, start_date=pendulum.parse("20200101")):
64+
task1 = EmptyOperator(task_id="task1")
65+
with TaskGroup("group1") as tg1:
66+
pass
67+
with TaskGroup("group2") as tg2:
68+
task2 = EmptyOperator(task_id="task2")
69+
task3 = EmptyOperator(task_id="task3")
70+
task2 >> task3
71+
72+
task1 >> tg1 >> tg2
73+
74+
assert task1.downstream_task_ids == {"group2.task2"}
75+
76+
77+
def test_task_group_dependencies_between_tasks_if_task_group_is_empty_2():
78+
"""
79+
Test that if a task group is empty, the dependencies between tasks are still maintained.
80+
"""
81+
with DAG(dag_id="test_dag", schedule=None, start_date=pendulum.parse("20200101")):
82+
task1 = EmptyOperator(task_id="task1")
83+
with TaskGroup("group1") as tg1:
84+
pass
85+
with TaskGroup("group2") as tg2:
86+
pass
87+
with TaskGroup("group3") as tg3:
88+
pass
89+
with TaskGroup("group4") as tg4:
90+
pass
91+
with TaskGroup("group5") as tg5:
92+
task2 = EmptyOperator(task_id="task2")
93+
task3 = EmptyOperator(task_id="task3")
94+
task2 >> task3
95+
task1 >> tg1 >> tg2 >> tg3 >> tg4 >> tg5
96+
97+
assert task1.downstream_task_ids == {"group5.task2"}
98+
99+
100+
def test_task_group_dependencies_between_tasks_if_task_group_is_empty_3():
101+
"""
102+
Test that if a task group is empty, the dependencies between tasks are still maintained.
103+
"""
104+
with DAG(dag_id="test_dag", schedule=None, start_date=pendulum.parse("20200101")):
105+
task1 = EmptyOperator(task_id="task1")
106+
with TaskGroup("group1") as tg1:
107+
pass
108+
with TaskGroup("group2") as tg2:
109+
pass
110+
task2 = EmptyOperator(task_id="task2")
111+
with TaskGroup("group3") as tg3:
112+
pass
113+
with TaskGroup("group4") as tg4:
114+
pass
115+
with TaskGroup("group5") as tg5:
116+
task3 = EmptyOperator(task_id="task3")
117+
task4 = EmptyOperator(task_id="task4")
118+
task3 >> task4
119+
task1 >> tg1 >> tg2 >> task2 >> tg3 >> tg4 >> tg5
120+
121+
assert task1.downstream_task_ids == {"task2"}
122+
assert task2.downstream_task_ids == {"group5.task3"}

0 commit comments

Comments
 (0)