diff --git a/airflow-core/docs/authoring-and-scheduling/assets.rst b/airflow-core/docs/authoring-and-scheduling/assets.rst index 4664c51d7f0e8..366a4069ca52b 100644 --- a/airflow-core/docs/authoring-and-scheduling/assets.rst +++ b/airflow-core/docs/authoring-and-scheduling/assets.rst @@ -229,6 +229,18 @@ The other way around also applies: def process_example_asset(example_asset): """Process inlet example_asset...""" +In addition, ``@asset`` can be used with ``@task`` to customize the task that generates the asset, +utilizing the modern TaskFlow approach described in :doc:`/tutorial/taskflow`. + +This combination allows you to set initial arguments for the task and to use various operators, such as the ``BashOperator``: + +.. code-block:: python + + @asset(schedule=None) + @task.bash(retries=3) + def example_asset(): + """Write to example_asset, from a Bash task with 3 retries...""" + return "echo 'run'" Output to multiple assets in one task ------------------------------------- diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index 44479fbb9cd42..daeb6204c3f05 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -18,7 +18,7 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import attrs @@ -30,6 +30,7 @@ from collections.abc import Callable, Collection, Iterator, Mapping from airflow.sdk import DAG, AssetAlias, ObjectStoragePath + from airflow.sdk.bases.decorator import _TaskDecorator from airflow.sdk.definitions.asset import AssetUniqueKey from airflow.sdk.definitions.dag import DagStateChangeCallback, ScheduleArg from airflow.sdk.definitions.param import ParamsDict @@ -96,6 +97,18 @@ def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: return dict(self._iter_kwargs(context)) +def _instantiate_task(definition: AssetDefinition | MultiAssetDefinition) -> None: + decorated_operator = cast("_TaskDecorator", definition._function) + if getattr(decorated_operator, "_airflow_is_task_decorator", False): + if "outlets" in decorated_operator.kwargs: + raise TypeError("@task decorator with 'outlets' argument is not supported in @asset") + + decorated_operator.kwargs["outlets"] = [v for _, v in definition.iter_assets()] + decorated_operator() + else: + _AssetMainOperator.from_definition(definition) + + @attrs.define(kw_only=True) class AssetDefinition(Asset): """ @@ -109,7 +122,7 @@ class AssetDefinition(Asset): def __attrs_post_init__(self) -> None: with self._source.create_dag(default_dag_id=self.name): - _AssetMainOperator.from_definition(self) + _instantiate_task(self) @attrs.define(kw_only=True) @@ -129,7 +142,7 @@ class MultiAssetDefinition(BaseAsset): def __attrs_post_init__(self) -> None: with self._source.create_dag(default_dag_id=self._function.__name__): - _AssetMainOperator.from_definition(self) + _instantiate_task(self) def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]: for o in self._source.outlets: diff --git a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py index 406785cb7964e..3dfc0b44588a0 100644 --- a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py @@ -22,6 +22,7 @@ from airflow.sdk.definitions.asset import Asset from airflow.sdk.definitions.asset.decorators import _AssetMainOperator, asset +from airflow.sdk.definitions.decorators import task from airflow.sdk.execution_time.comms import AssetResult, GetAssetByName @@ -159,6 +160,27 @@ def example_asset_func(self, /): == "positional-only argument 'self' without a default is not supported in @asset" ) + def test_with_task_decorator(self, func_fixer): + @task(retries=3) + @func_fixer + def _example_task_func(): + return "This is example_task" + + asset_definition = asset(name="asset", dag_id="dag", schedule=None)(_example_task_func) + assert asset_definition.name == "asset" + assert asset_definition._source.dag_id == "dag" + assert asset_definition._function == _example_task_func + + def test_with_task_decorator_and_outlets(self, func_fixer): + @task(retries=3, outlets=Asset(name="a")) + @func_fixer + def _example_task_func(): + return "This is example_task" + + with pytest.raises(TypeError) as err: + asset(schedule=None)(_example_task_func) + assert err.value.args[0] == "@task decorator with 'outlets' argument is not supported in @asset" + @pytest.mark.parametrize( "provided_uri, expected_uri", [ @@ -222,6 +244,36 @@ def test__attrs_post_init__(self, DAG, from_definition, example_asset_func_with_ ) from_definition.assert_called_once_with(asset_definition) + @mock.patch("airflow.sdk.bases.decorator._TaskDecorator.__call__") + @mock.patch("airflow.sdk.definitions.dag.DAG") + def test_with_task_decorator(self, DAG, __call__, func_fixer): + @task(retries=3) + @func_fixer + def _example_task_func(): + return "This is example_task" + + asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})( + _example_task_func + ) + + DAG.assert_called_once_with( + dag_id="example_asset_func", + dag_display_name="example_asset_func", + description=None, + schedule=None, + catchup=False, + is_paused_upon_creation=None, + on_failure_callback=None, + on_success_callback=None, + params=None, + access_control=None, + owner_links={}, + tags=set(), + auto_register=True, + ) + __call__.assert_called_once_with() + assert asset_definition._function.kwargs["outlets"] == [asset_definition] + class TestMultiAssetDefinition: @mock.patch("airflow.sdk.definitions.asset.decorators._AssetMainOperator.from_definition") @@ -249,6 +301,37 @@ def test__attrs_post_init__(self, DAG, from_definition, example_asset_func_with_ ) from_definition.assert_called_once_with(definition) + @mock.patch("airflow.sdk.bases.decorator._TaskDecorator.__call__") + @mock.patch("airflow.sdk.definitions.dag.DAG") + def test_with_task_decorator(self, DAG, __call__, func_fixer): + @task(retries=3) + @func_fixer + def _example_task_func(): + return "This is example_task" + + definition = asset.multi( + schedule=None, + outlets=[Asset(name="a"), Asset(name="b")], + )(_example_task_func) + + DAG.assert_called_once_with( + dag_id="example_asset_func", + dag_display_name="example_asset_func", + description=None, + schedule=None, + catchup=False, + is_paused_upon_creation=None, + on_failure_callback=None, + on_success_callback=None, + params=None, + access_control=None, + owner_links={}, + tags=set(), + auto_register=True, + ) + __call__.assert_called_once_with() + assert definition._function.kwargs["outlets"] == [Asset(name="a"), Asset(name="b")] + class Test_AssetMainOperator: def test_from_definition(self, example_asset_func_with_valid_arg_as_inlet_asset):