Skip to content

Commit f2986cf

Browse files
committed
Allow Asset decorator to work with any TaskFlow operator
1 parent 6041b77 commit f2986cf

File tree

3 files changed

+95
-2
lines changed

3 files changed

+95
-2
lines changed

airflow-core/docs/authoring-and-scheduling/assets.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,15 @@ The other way around also applies:
229229
def process_example_asset(example_asset):
230230
"""Process inlet example_asset..."""
231231
232+
In addition, ``@asset`` can be used with ``@task`` to set initial arguments for the task or to use an operator other than ``PythonOperator``:
233+
234+
.. code-block:: python
235+
236+
@asset(schedule=None)
237+
@task.bash(retries=3)
238+
def example_asset():
239+
"""Write to example_asset, from a Bash task with 3 retries..."""
240+
return "echo 'run'"
232241
233242
Output to multiple assets in one task
234243
-------------------------------------

task-sdk/src/airflow/sdk/definitions/asset/decorators.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import attrs
2424

2525
from airflow.providers.standard.operators.python import PythonOperator
26+
from airflow.sdk.bases.decorator import _TaskDecorator
2627
from airflow.sdk.definitions.asset import Asset, AssetRef, BaseAsset
2728
from airflow.sdk.exceptions import AirflowRuntimeError
2829

@@ -111,7 +112,12 @@ class AssetDefinition(Asset):
111112

112113
def __attrs_post_init__(self) -> None:
113114
with self._source.create_dag(default_dag_id=self.name):
114-
_AssetMainOperator.from_definition(self)
115+
if isinstance(self._function, _TaskDecorator):
116+
if "outlets" not in self._function.kwargs:
117+
self._function.kwargs["outlets"] = [v for _, v in self.iter_assets()]
118+
self._function()
119+
else:
120+
_AssetMainOperator.from_definition(self)
115121

116122

117123
@attrs.define(kw_only=True)
@@ -131,7 +137,12 @@ class MultiAssetDefinition(BaseAsset):
131137

132138
def __attrs_post_init__(self) -> None:
133139
with self._source.create_dag(default_dag_id=self._function.__name__):
134-
_AssetMainOperator.from_definition(self)
140+
if isinstance(self._function, _TaskDecorator):
141+
if "outlets" not in self._function.kwargs:
142+
self._function.kwargs["outlets"] = [v for _, v in self.iter_assets()]
143+
self._function()
144+
else:
145+
_AssetMainOperator.from_definition(self)
135146

136147
def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
137148
for o in self._source.outlets:

task-sdk/tests/task_sdk/definitions/test_asset_decorators.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from airflow.sdk.definitions.asset import Asset
2424
from airflow.sdk.definitions.asset.decorators import _AssetMainOperator, asset
25+
from airflow.sdk.definitions.decorators import task
2526
from airflow.sdk.execution_time.comms import AssetResult, GetAssetByName
2627

2728

@@ -159,6 +160,17 @@ def example_asset_func(self, /):
159160
== "positional-only argument 'self' without a default is not supported in @asset"
160161
)
161162

163+
def test_with_task_decorator(self, func_fixer):
164+
@task(retries=3)
165+
@func_fixer
166+
def _example_task_func():
167+
return "This is example_task"
168+
169+
asset_definition = asset(name="asset", dag_id="dag", schedule=None)(_example_task_func)
170+
assert asset_definition.name == "asset"
171+
assert asset_definition._source.dag_id == "dag"
172+
assert asset_definition._function == _example_task_func
173+
162174
@pytest.mark.parametrize(
163175
"provided_uri, expected_uri",
164176
[
@@ -222,6 +234,36 @@ def test__attrs_post_init__(self, DAG, from_definition, example_asset_func_with_
222234
)
223235
from_definition.assert_called_once_with(asset_definition)
224236

237+
@mock.patch("airflow.sdk.bases.decorator._TaskDecorator.__call__")
238+
@mock.patch("airflow.sdk.definitions.dag.DAG")
239+
def test_with_task_decorator(self, DAG, __call__, func_fixer):
240+
@task(retries=3)
241+
@func_fixer
242+
def _example_task_func():
243+
return "This is example_task"
244+
245+
asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
246+
_example_task_func
247+
)
248+
249+
DAG.assert_called_once_with(
250+
dag_id="example_asset_func",
251+
dag_display_name="example_asset_func",
252+
description=None,
253+
schedule=None,
254+
catchup=False,
255+
is_paused_upon_creation=None,
256+
on_failure_callback=None,
257+
on_success_callback=None,
258+
params=None,
259+
access_control=None,
260+
owner_links={},
261+
tags=set(),
262+
auto_register=True,
263+
)
264+
__call__.assert_called_once_with()
265+
assert asset_definition._function.kwargs["outlets"] == [asset_definition]
266+
225267

226268
class TestMultiAssetDefinition:
227269
@mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition")
@@ -249,6 +291,37 @@ def test__attrs_post_init__(self, DAG, from_definition, example_asset_func_with_
249291
)
250292
from_definition.assert_called_once_with(definition)
251293

294+
@mock.patch("airflow.sdk.bases.decorator._TaskDecorator.__call__")
295+
@mock.patch("airflow.sdk.definitions.dag.DAG")
296+
def test_with_task_decorator(self, DAG, __call__, func_fixer):
297+
@task(retries=3)
298+
@func_fixer
299+
def _example_task_func():
300+
return "This is example_task"
301+
302+
definition = asset.multi(
303+
schedule=None,
304+
outlets=[Asset(name="a"), Asset(name="b")],
305+
)(_example_task_func)
306+
307+
DAG.assert_called_once_with(
308+
dag_id="example_asset_func",
309+
dag_display_name="example_asset_func",
310+
description=None,
311+
schedule=None,
312+
catchup=False,
313+
is_paused_upon_creation=None,
314+
on_failure_callback=None,
315+
on_success_callback=None,
316+
params=None,
317+
access_control=None,
318+
owner_links={},
319+
tags=set(),
320+
auto_register=True,
321+
)
322+
__call__.assert_called_once_with()
323+
assert definition._function.kwargs["outlets"] == [Asset(name="a"), Asset(name="b")]
324+
252325

253326
class Test_AssetMainOperator:
254327
def test_from_definition(self, example_asset_func_with_valid_arg_as_inlet_asset):

0 commit comments

Comments
 (0)