Skip to content

Commit 87a9768

Browse files
committed
verify test parameters
1 parent 422388f commit 87a9768

File tree

2 files changed

+139
-3
lines changed

2 files changed

+139
-3
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
import inspect
21+
import os
22+
23+
import libcst as cst
24+
import pytest
25+
26+
from airflow.decorators import task
27+
28+
DECORATOR_OPERATOR_MAP = {
29+
"kubernetes": "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator",
30+
"sensor": "airflow.sdk.bases.sensor.BaseSensorOperator",
31+
"virtualenv": "airflow.providers.standard.operators.python.PythonVirtualenvOperator",
32+
"branch_virtualenv": "airflow.providers.standard.operators.python.BranchPythonVirtualenvOperator",
33+
# Add more here...
34+
}
35+
36+
37+
def extract_function_params(code, function_name, return_type):
38+
"""Extracts parameters from a specific function definition in the given code.
39+
40+
Args:
41+
code (str): The Python code to parse.
42+
function_name (str): The name of the function to extract parameters from.
43+
return_type (str): As the pyi file has multiple @overload decorator, extract function param based on return type.
44+
45+
Returns:
46+
list: A list of parameter names, or None if the function is not found.
47+
"""
48+
module = cst.parse_module(code)
49+
50+
class FunctionParamExtractor(cst.CSTVisitor):
51+
def __init__(self, target_function_name, target_return_type):
52+
self.target_function_name = target_function_name
53+
self.target_return_type = target_return_type
54+
self.params: list[str] = []
55+
56+
def visit_FunctionDef(self, node):
57+
# Match function name
58+
if node.name.value == self.target_function_name:
59+
if node.returns:
60+
annotation = node.returns.annotation
61+
if isinstance(annotation, cst.Name) and annotation.value == self.target_return_type:
62+
parameters_node = node.params
63+
self.params.extend(param.name.value for param in parameters_node.params)
64+
self.params.extend(param.name.value for param in parameters_node.kwonly_params)
65+
self.params.extend(param.name.value for param in parameters_node.posonly_params)
66+
if parameters_node.star_kwarg:
67+
self.params.append(parameters_node.star_kwarg.name.value)
68+
return False # Stop traversing after finding the real function
69+
return True # Keep traversing
70+
71+
extractor = FunctionParamExtractor(function_name, return_type)
72+
module.visit(extractor)
73+
return extractor.params
74+
75+
76+
def get_decorator_params(decorator_name: str):
77+
file_path = os.path.abspath(
78+
os.path.join(
79+
os.path.dirname(__file__),
80+
"../../../../task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi",
81+
)
82+
)
83+
params = []
84+
with open(file_path) as file:
85+
code = file.read()
86+
params = extract_function_params(code, decorator_name, "TaskDecorator")
87+
return set(params)
88+
89+
90+
def get_operator_params(operator_path: str):
91+
module_path, class_name = operator_path.rsplit(".", 1)
92+
module = __import__(module_path, fromlist=[class_name])
93+
operator_cls = getattr(module, class_name)
94+
sig = inspect.signature(operator_cls.__init__)
95+
return set(p for p in sig.parameters.keys() if p not in ("self", "args", "kwargs"))
96+
97+
98+
@pytest.mark.parametrize("decorator, operator_path", DECORATOR_OPERATOR_MAP.items())
99+
def test_decorator_matches_operator_signature(decorator, operator_path):
100+
decorator_params = get_decorator_params(decorator)
101+
operator_params = get_operator_params(operator_path)
102+
missing_in_decorator = operator_params - decorator_params
103+
104+
ignored = {"kwargs", "args", "self"}
105+
missing_in_decorator -= ignored
106+
assert not missing_in_decorator, f"{decorator} is missing params: {missing_in_decorator}"

task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ from typing import Any, Callable, TypeVar, overload
2727
from docker.types import Mount
2828
from kubernetes.client import models as k8s
2929

30+
from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback
3031
from airflow.providers.cncf.kubernetes.secret import Secret
3132
from airflow.sdk.bases.decorator import FParams, FReturn, Task, TaskDecorator, _TaskDecorator
3233
from airflow.sdk.definitions.dag import dag
@@ -96,16 +97,21 @@ class TaskDecoratorCollection:
9697
multiple_outputs: bool | None = None,
9798
# 'python_callable', 'op_args' and 'op_kwargs' since they are filled by
9899
# _PythonVirtualenvDecoratedOperator.
100+
python_callable: Callable,
101+
op_args: Collection[Any] | None = None,
102+
op_kwargs: Mapping[str, Any] | None = None,
99103
requirements: None | Iterable[str] | str = None,
100104
python_version: None | str | int | float = None,
101105
serializer: Literal["pickle", "cloudpickle", "dill"] | None = None,
102106
system_site_packages: bool = True,
103107
templates_dict: Mapping[str, Any] | None = None,
108+
templates_exts: list[str] | None = None,
104109
pip_install_options: list[str] | None = None,
110+
expect_airflow: bool = True,
105111
skip_on_exit_code: int | Container[int] | None = None,
106112
index_urls: None | Collection[str] | str = None,
107113
venv_cache_path: None | str = None,
108-
show_return_value_in_logs: bool = True,
114+
string_args: Iterable[str] | None = None,
109115
env_vars: dict[str, str] | None = None,
110116
inherit_env: bool = True,
111117
**kwargs,
@@ -218,18 +224,24 @@ class TaskDecoratorCollection:
218224
self,
219225
*,
220226
multiple_outputs: bool | None = None,
221-
# 'python_callable', 'op_args' and 'op_kwargs' since they are filled by
222-
# _PythonVirtualenvDecoratedOperator.
227+
python_callable: Callable,
223228
requirements: None | Iterable[str] | str = None,
229+
op_args: Collection[Any] | None = None,
230+
op_kwargs: Mapping[str, Any] | None = None,
231+
string_args: Iterable[str] | None = None,
224232
python_version: None | str | int | float = None,
225233
serializer: Literal["pickle", "cloudpickle", "dill"] | None = None,
226234
system_site_packages: bool = True,
227235
templates_dict: Mapping[str, Any] | None = None,
236+
templates_exts: list[str] | None = None,
228237
pip_install_options: list[str] | None = None,
229238
skip_on_exit_code: int | Container[int] | None = None,
230239
index_urls: None | Collection[str] | str = None,
231240
venv_cache_path: None | str = None,
241+
expect_airflow: bool = True,
232242
show_return_value_in_logs: bool = True,
243+
env_vars: dict[str, str] | None = None,
244+
inherit_env: bool = True,
233245
**kwargs,
234246
) -> TaskDecorator:
235247
"""Create a decorator to wrap the decorated callable into a BranchPythonVirtualenvOperator.
@@ -507,6 +519,7 @@ class TaskDecoratorCollection:
507519
image: str | None = None,
508520
name: str | None = None,
509521
random_name_suffix: bool = ...,
522+
cmds: list[str] | None = None,
510523
arguments: list[str] | None = None,
511524
ports: list[k8s.V1ContainerPort] | None = None,
512525
volume_mounts: list[k8s.V1VolumeMount] | None = None,
@@ -520,6 +533,7 @@ class TaskDecoratorCollection:
520533
reattach_on_restart: bool = ...,
521534
startup_timeout_seconds: int = ...,
522535
startup_check_interval_seconds: int = ...,
536+
schedule_timeout_seconds: int | None = None,
523537
get_logs: bool = True,
524538
container_logs: Iterable[str] | str | Literal[True] = ...,
525539
image_pull_policy: str | None = None,
@@ -530,6 +544,7 @@ class TaskDecoratorCollection:
530544
node_selector: dict | None = None,
531545
image_pull_secrets: list[k8s.V1LocalObjectReference] | None = None,
532546
service_account_name: str | None = None,
547+
automount_service_account_token: bool | None = None,
533548
hostnetwork: bool = False,
534549
host_aliases: list[k8s.V1HostAlias] | None = None,
535550
tolerations: list[k8s.V1Toleration] | None = None,
@@ -553,13 +568,19 @@ class TaskDecoratorCollection:
553568
skip_on_exit_code: int | Container[int] | None = None,
554569
base_container_name: str | None = None,
555570
base_container_status_polling_interval: float = ...,
571+
init_container_logs: Iterable[str] | str | Literal[True] | None = None,
556572
deferrable: bool = ...,
557573
poll_interval: float = ...,
558574
log_pod_spec_on_failure: bool = ...,
559575
on_finish_action: str = ...,
576+
is_delete_operator_pod: None | bool = None,
560577
termination_message_policy: str = ...,
561578
active_deadline_seconds: int | None = None,
579+
callbacks: (
580+
list[type[KubernetesPodOperatorCallback]] | type[KubernetesPodOperatorCallback] | None
581+
) = None,
562582
progress_callback: Callable[[str], None] | None = None,
583+
logging_interval: int | None = None,
563584
**kwargs,
564585
) -> TaskDecorator:
565586
"""Create a decorator to convert a callable to a Kubernetes Pod task.
@@ -849,6 +870,8 @@ class TaskDecoratorCollection:
849870
mode: str = ...,
850871
exponential_backoff: bool = False,
851872
max_wait: timedelta | float | None = None,
873+
silent_fail: bool = False,
874+
never_fail: bool = False,
852875
**kwargs,
853876
) -> TaskDecorator:
854877
"""
@@ -873,6 +896,13 @@ class TaskDecoratorCollection:
873896
:param exponential_backoff: allow progressive longer waits between
874897
pokes by using exponential backoff algorithm
875898
:param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds
899+
:param silent_fail: If true, and poke method raises an exception different from
900+
AirflowSensorTimeout, AirflowTaskTimeout, AirflowSkipException
901+
and AirflowFailException, the sensor will log the error and continue
902+
its execution. Otherwise, the sensor task fails, and it can be retried
903+
based on the provided `retries` parameter.
904+
:param never_fail: If true, and poke method raises an exception, sensor will be skipped.
905+
Mutually exclusive with soft_fail.
876906
"""
877907
@overload
878908
def sensor(self, python_callable: Callable[FParams, FReturn] | None = None) -> Task[FParams, FReturn]: ...

0 commit comments

Comments
 (0)