Skip to content

Commit efe07f1

Browse files
committed
verify test parameters
1 parent 422388f commit efe07f1

File tree

2 files changed

+127
-3
lines changed

2 files changed

+127
-3
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
import inspect
19+
import libcst as cst
20+
import pytest
21+
import os
22+
from airflow.decorators import task
23+
24+
DECORATOR_OPERATOR_MAP = {
25+
"kubernetes": "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator",
26+
"sensor": "airflow.sdk.bases.sensor.BaseSensorOperator",
27+
"virtualenv": "airflow.providers.standard.operators.python.PythonVirtualenvOperator",
28+
"branch_virtualenv": "airflow.providers.standard.operators.python.BranchPythonVirtualenvOperator",
29+
# Add more here...
30+
}
31+
32+
def extract_function_params(code, function_name, return_type):
33+
"""Extracts parameters from a specific function definition in the given code.
34+
35+
Args:
36+
code (str): The Python code to parse.
37+
function_name (str): The name of the function to extract parameters from.
38+
return_type (str): As the pyi file has multiple @overload decorator, extract function param based on return type.
39+
40+
Returns:
41+
list: A list of parameter names, or None if the function is not found.
42+
"""
43+
module = cst.parse_module(code)
44+
class FunctionParamExtractor(cst.CSTVisitor):
45+
def __init__(self, target_function_name, target_return_type):
46+
self.target_function_name = target_function_name
47+
self.target_return_type = target_return_type
48+
self.params: list[str] = []
49+
50+
def visit_FunctionDef(self, node):
51+
# Match function name
52+
if node.name.value == self.target_function_name:
53+
if node.returns:
54+
annotation = node.returns.annotation
55+
if isinstance(annotation, cst.Name) and annotation.value == self.target_return_type:
56+
parameters_node = node.params
57+
self.params.extend(param.name.value for param in parameters_node.params)
58+
self.params.extend(param.name.value for param in parameters_node.kwonly_params)
59+
self.params.extend(param.name.value for param in parameters_node.posonly_params)
60+
if parameters_node.star_kwarg:
61+
self.params.append(parameters_node.star_kwarg.name.value)
62+
return False # Stop traversing after finding the real function
63+
return True # Keep traversing
64+
65+
extractor = FunctionParamExtractor(function_name, return_type)
66+
module.visit(extractor)
67+
return extractor.params
68+
69+
def get_decorator_params(decorator_name: str):
70+
decorator_fn = getattr(task, decorator_name)
71+
file_path = os.path.abspath(
72+
os.path.join(os.path.dirname(__file__), "../../../../task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi")
73+
)
74+
params = []
75+
with open(file_path, "r") as file:
76+
code = file.read()
77+
params = extract_function_params(code, decorator_name, "TaskDecorator")
78+
return set(params)
79+
80+
def get_operator_params(operator_path: str):
81+
module_path, class_name = operator_path.rsplit(".", 1)
82+
module = __import__(module_path, fromlist=[class_name])
83+
operator_cls = getattr(module, class_name)
84+
sig = inspect.signature(operator_cls.__init__)
85+
return set(p for p in sig.parameters.keys() if p not in ("self", "args", "kwargs"))
86+
87+
@pytest.mark.parametrize("decorator, operator_path", DECORATOR_OPERATOR_MAP.items())
88+
def test_decorator_matches_operator_signature(decorator, operator_path):
89+
decorator_params = get_decorator_params(decorator)
90+
operator_params = get_operator_params(operator_path)
91+
missing_in_decorator = operator_params - decorator_params
92+
93+
ignored = {"kwargs", "args", "self"}
94+
missing_in_decorator -= ignored
95+
assert not missing_in_decorator, f"{decorator} is missing params: {missing_in_decorator}"

task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,21 @@ class TaskDecoratorCollection:
9696
multiple_outputs: bool | None = None,
9797
# 'python_callable', 'op_args' and 'op_kwargs' since they are filled by
9898
# _PythonVirtualenvDecoratedOperator.
99+
python_callable: Callable,
100+
op_args: Collection[Any] | None = None,
101+
op_kwargs: Mapping[str, Any] | None = None,
99102
requirements: None | Iterable[str] | str = None,
100103
python_version: None | str | int | float = None,
101104
serializer: Literal["pickle", "cloudpickle", "dill"] | None = None,
102105
system_site_packages: bool = True,
103106
templates_dict: Mapping[str, Any] | None = None,
107+
templates_exts: list[str] | None = None,
104108
pip_install_options: list[str] | None = None,
109+
expect_airflow: bool = True,
105110
skip_on_exit_code: int | Container[int] | None = None,
106111
index_urls: None | Collection[str] | str = None,
107112
venv_cache_path: None | str = None,
108-
show_return_value_in_logs: bool = True,
113+
string_args: Iterable[str] | None = None,
109114
env_vars: dict[str, str] | None = None,
110115
inherit_env: bool = True,
111116
**kwargs,
@@ -218,18 +223,24 @@ class TaskDecoratorCollection:
218223
self,
219224
*,
220225
multiple_outputs: bool | None = None,
221-
# 'python_callable', 'op_args' and 'op_kwargs' since they are filled by
222-
# _PythonVirtualenvDecoratedOperator.
226+
python_callable: Callable,
223227
requirements: None | Iterable[str] | str = None,
228+
op_args: Collection[Any] | None = None,
229+
op_kwargs: Mapping[str, Any] | None = None,
230+
string_args: Iterable[str] | None = None,
224231
python_version: None | str | int | float = None,
225232
serializer: Literal["pickle", "cloudpickle", "dill"] | None = None,
226233
system_site_packages: bool = True,
227234
templates_dict: Mapping[str, Any] | None = None,
235+
templates_exts: list[str] | None = None,
228236
pip_install_options: list[str] | None = None,
229237
skip_on_exit_code: int | Container[int] | None = None,
230238
index_urls: None | Collection[str] | str = None,
231239
venv_cache_path: None | str = None,
240+
expect_airflow: bool = True,
232241
show_return_value_in_logs: bool = True,
242+
env_vars: dict[str, str] | None = None,
243+
inherit_env: bool = True,
233244
**kwargs,
234245
) -> TaskDecorator:
235246
"""Create a decorator to wrap the decorated callable into a BranchPythonVirtualenvOperator.
@@ -507,6 +518,7 @@ class TaskDecoratorCollection:
507518
image: str | None = None,
508519
name: str | None = None,
509520
random_name_suffix: bool = ...,
521+
cmds: list[str] | None = None,
510522
arguments: list[str] | None = None,
511523
ports: list[k8s.V1ContainerPort] | None = None,
512524
volume_mounts: list[k8s.V1VolumeMount] | None = None,
@@ -520,6 +532,7 @@ class TaskDecoratorCollection:
520532
reattach_on_restart: bool = ...,
521533
startup_timeout_seconds: int = ...,
522534
startup_check_interval_seconds: int = ...,
535+
schedule_timeout_seconds: int | None = None,
523536
get_logs: bool = True,
524537
container_logs: Iterable[str] | str | Literal[True] = ...,
525538
image_pull_policy: str | None = None,
@@ -530,6 +543,7 @@ class TaskDecoratorCollection:
530543
node_selector: dict | None = None,
531544
image_pull_secrets: list[k8s.V1LocalObjectReference] | None = None,
532545
service_account_name: str | None = None,
546+
automount_service_account_token: bool | None = None,
533547
hostnetwork: bool = False,
534548
host_aliases: list[k8s.V1HostAlias] | None = None,
535549
tolerations: list[k8s.V1Toleration] | None = None,
@@ -553,13 +567,19 @@ class TaskDecoratorCollection:
553567
skip_on_exit_code: int | Container[int] | None = None,
554568
base_container_name: str | None = None,
555569
base_container_status_polling_interval: float = ...,
570+
init_container_logs: Iterable[str] | str | Literal[True] | None = None,
556571
deferrable: bool = ...,
557572
poll_interval: float = ...,
558573
log_pod_spec_on_failure: bool = ...,
559574
on_finish_action: str = ...,
575+
is_delete_operator_pod: None | bool = None,
560576
termination_message_policy: str = ...,
561577
active_deadline_seconds: int | None = None,
578+
callbacks: (
579+
list[type[KubernetesPodOperatorCallback]] | type[KubernetesPodOperatorCallback] | None
580+
) = None,
562581
progress_callback: Callable[[str], None] | None = None,
582+
logging_interval: int | None = None,
563583
**kwargs,
564584
) -> TaskDecorator:
565585
"""Create a decorator to convert a callable to a Kubernetes Pod task.
@@ -849,6 +869,8 @@ class TaskDecoratorCollection:
849869
mode: str = ...,
850870
exponential_backoff: bool = False,
851871
max_wait: timedelta | float | None = None,
872+
silent_fail: bool = False,
873+
never_fail: bool = False,
852874
**kwargs,
853875
) -> TaskDecorator:
854876
"""
@@ -873,6 +895,13 @@ class TaskDecoratorCollection:
873895
:param exponential_backoff: allow progressive longer waits between
874896
pokes by using exponential backoff algorithm
875897
:param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds
898+
:param silent_fail: If true, and poke method raises an exception different from
899+
AirflowSensorTimeout, AirflowTaskTimeout, AirflowSkipException
900+
and AirflowFailException, the sensor will log the error and continue
901+
its execution. Otherwise, the sensor task fails, and it can be retried
902+
based on the provided `retries` parameter.
903+
:param never_fail: If true, and poke method raises an exception, sensor will be skipped.
904+
Mutually exclusive with soft_fail.
876905
"""
877906
@overload
878907
def sensor(self, python_callable: Callable[FParams, FReturn] | None = None) -> Task[FParams, FReturn]: ...

0 commit comments

Comments
 (0)