Skip to content

Commit fab3d5a

Browse files
authored
Allow Asset decorator to work with any TaskFlow operator (#51229)
1 parent d8086e6 commit fab3d5a

File tree

3 files changed

+111
-3
lines changed

3 files changed

+111
-3
lines changed

airflow-core/docs/authoring-and-scheduling/assets.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,18 @@ The other way around also applies:
229229
def process_example_asset(example_asset):
230230
"""Process inlet example_asset..."""
231231
232+
In addition, ``@asset`` can be used with ``@task`` to customize the task that generates the asset,
233+
utilizing the modern TaskFlow approach described in :doc:`/tutorial/taskflow`.
234+
235+
This combination allows you to set initial arguments for the task and to use various operators, such as the ``BashOperator``:
236+
237+
.. code-block:: python
238+
239+
@asset(schedule=None)
240+
@task.bash(retries=3)
241+
def example_asset():
242+
"""Write to example_asset, from a Bash task with 3 retries..."""
243+
return "echo 'run'"
232244
233245
Output to multiple assets in one task
234246
-------------------------------------

task-sdk/src/airflow/sdk/definitions/asset/decorators.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from __future__ import annotations
1919

2020
import inspect
21-
from typing import TYPE_CHECKING, Any
21+
from typing import TYPE_CHECKING, Any, cast
2222

2323
import attrs
2424

@@ -30,6 +30,7 @@
3030
from collections.abc import Callable, Collection, Iterator, Mapping
3131

3232
from airflow.sdk import DAG, AssetAlias, ObjectStoragePath
33+
from airflow.sdk.bases.decorator import _TaskDecorator
3334
from airflow.sdk.definitions.asset import AssetUniqueKey
3435
from airflow.sdk.definitions.dag import DagStateChangeCallback, ScheduleArg
3536
from airflow.sdk.definitions.param import ParamsDict
@@ -96,6 +97,18 @@ def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
9697
return dict(self._iter_kwargs(context))
9798

9899

100+
def _instantiate_task(definition: AssetDefinition | MultiAssetDefinition) -> None:
101+
decorated_operator = cast("_TaskDecorator", definition._function)
102+
if getattr(decorated_operator, "_airflow_is_task_decorator", False):
103+
if "outlets" in decorated_operator.kwargs:
104+
raise TypeError("@task decorator with 'outlets' argument is not supported in @asset")
105+
106+
decorated_operator.kwargs["outlets"] = [v for _, v in definition.iter_assets()]
107+
decorated_operator()
108+
else:
109+
_AssetMainOperator.from_definition(definition)
110+
111+
99112
@attrs.define(kw_only=True)
100113
class AssetDefinition(Asset):
101114
"""
@@ -109,7 +122,7 @@ class AssetDefinition(Asset):
109122

110123
def __attrs_post_init__(self) -> None:
111124
with self._source.create_dag(default_dag_id=self.name):
112-
_AssetMainOperator.from_definition(self)
125+
_instantiate_task(self)
113126

114127

115128
@attrs.define(kw_only=True)
@@ -129,7 +142,7 @@ class MultiAssetDefinition(BaseAsset):
129142

130143
def __attrs_post_init__(self) -> None:
131144
with self._source.create_dag(default_dag_id=self._function.__name__):
132-
_AssetMainOperator.from_definition(self)
145+
_instantiate_task(self)
133146

134147
def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
135148
for o in self._source.outlets:

task-sdk/tests/task_sdk/definitions/test_asset_decorators.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from airflow.sdk.definitions.asset import Asset
2424
from airflow.sdk.definitions.asset.decorators import _AssetMainOperator, asset
25+
from airflow.sdk.definitions.decorators import task
2526
from airflow.sdk.execution_time.comms import AssetResult, GetAssetByName
2627

2728

@@ -159,6 +160,27 @@ def example_asset_func(self, /):
159160
== "positional-only argument 'self' without a default is not supported in @asset"
160161
)
161162

163+
def test_with_task_decorator(self, func_fixer):
164+
@task(retries=3)
165+
@func_fixer
166+
def _example_task_func():
167+
return "This is example_task"
168+
169+
asset_definition = asset(name="asset", dag_id="dag", schedule=None)(_example_task_func)
170+
assert asset_definition.name == "asset"
171+
assert asset_definition._source.dag_id == "dag"
172+
assert asset_definition._function == _example_task_func
173+
174+
def test_with_task_decorator_and_outlets(self, func_fixer):
175+
@task(retries=3, outlets=Asset(name="a"))
176+
@func_fixer
177+
def _example_task_func():
178+
return "This is example_task"
179+
180+
with pytest.raises(TypeError) as err:
181+
asset(schedule=None)(_example_task_func)
182+
assert err.value.args[0] == "@task decorator with 'outlets' argument is not supported in @asset"
183+
162184
@pytest.mark.parametrize(
163185
"provided_uri, expected_uri",
164186
[
@@ -222,6 +244,36 @@ def test__attrs_post_init__(self, DAG, from_definition, example_asset_func_with_
222244
)
223245
from_definition.assert_called_once_with(asset_definition)
224246

247+
@mock.patch("airflow.sdk.bases.decorator._TaskDecorator.__call__")
248+
@mock.patch("airflow.sdk.definitions.dag.DAG")
249+
def test_with_task_decorator(self, DAG, __call__, func_fixer):
250+
@task(retries=3)
251+
@func_fixer
252+
def _example_task_func():
253+
return "This is example_task"
254+
255+
asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
256+
_example_task_func
257+
)
258+
259+
DAG.assert_called_once_with(
260+
dag_id="example_asset_func",
261+
dag_display_name="example_asset_func",
262+
description=None,
263+
schedule=None,
264+
catchup=False,
265+
is_paused_upon_creation=None,
266+
on_failure_callback=None,
267+
on_success_callback=None,
268+
params=None,
269+
access_control=None,
270+
owner_links={},
271+
tags=set(),
272+
auto_register=True,
273+
)
274+
__call__.assert_called_once_with()
275+
assert asset_definition._function.kwargs["outlets"] == [asset_definition]
276+
225277

226278
class TestMultiAssetDefinition:
227279
@mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition")
@@ -249,6 +301,37 @@ def test__attrs_post_init__(self, DAG, from_definition, example_asset_func_with_
249301
)
250302
from_definition.assert_called_once_with(definition)
251303

304+
@mock.patch("airflow.sdk.bases.decorator._TaskDecorator.__call__")
305+
@mock.patch("airflow.sdk.definitions.dag.DAG")
306+
def test_with_task_decorator(self, DAG, __call__, func_fixer):
307+
@task(retries=3)
308+
@func_fixer
309+
def _example_task_func():
310+
return "This is example_task"
311+
312+
definition = asset.multi(
313+
schedule=None,
314+
outlets=[Asset(name="a"), Asset(name="b")],
315+
)(_example_task_func)
316+
317+
DAG.assert_called_once_with(
318+
dag_id="example_asset_func",
319+
dag_display_name="example_asset_func",
320+
description=None,
321+
schedule=None,
322+
catchup=False,
323+
is_paused_upon_creation=None,
324+
on_failure_callback=None,
325+
on_success_callback=None,
326+
params=None,
327+
access_control=None,
328+
owner_links={},
329+
tags=set(),
330+
auto_register=True,
331+
)
332+
__call__.assert_called_once_with()
333+
assert definition._function.kwargs["outlets"] == [Asset(name="a"), Asset(name="b")]
334+
252335

253336
class Test_AssetMainOperator:
254337
def test_from_definition(self, example_asset_func_with_valid_arg_as_inlet_asset):

0 commit comments

Comments
 (0)