Skip to content

Commit 444ba87

Browse files
committed
verify test parameters
1 parent 422388f commit 444ba87

File tree

2 files changed

+137
-3
lines changed

2 files changed

+137
-3
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
import inspect
21+
import os
22+
23+
import libcst as cst
24+
import pytest
25+
26+
DECORATOR_OPERATOR_MAP = {
27+
"kubernetes": "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator",
28+
"sensor": "airflow.sdk.bases.sensor.BaseSensorOperator",
29+
"virtualenv": "airflow.providers.standard.operators.python.PythonVirtualenvOperator",
30+
"branch_virtualenv": "airflow.providers.standard.operators.python.BranchPythonVirtualenvOperator",
31+
# Add more here...
32+
}
33+
34+
35+
def extract_function_params(code, function_name, return_type):
36+
"""Extracts parameters from a specific function definition in the given code.
37+
38+
Args:
39+
code (str): The Python code to parse.
40+
function_name (str): The name of the function to extract parameters from.
41+
return_type (str): As the pyi file has multiple @overload decorator, extract function param based on return type.
42+
43+
Returns:
44+
list: A list of parameter names, or None if the function is not found.
45+
"""
46+
module = cst.parse_module(code)
47+
48+
class FunctionParamExtractor(cst.CSTVisitor):
49+
def __init__(self, target_function_name, target_return_type):
50+
self.target_function_name = target_function_name
51+
self.target_return_type = target_return_type
52+
self.params: list[str] = []
53+
54+
def visit_FunctionDef(self, node):
55+
# Match function name
56+
if node.name.value == self.target_function_name:
57+
if node.returns:
58+
annotation = node.returns.annotation
59+
if isinstance(annotation, cst.Name) and annotation.value == self.target_return_type:
60+
parameters_node = node.params
61+
self.params.extend(param.name.value for param in parameters_node.params)
62+
self.params.extend(param.name.value for param in parameters_node.kwonly_params)
63+
self.params.extend(param.name.value for param in parameters_node.posonly_params)
64+
if parameters_node.star_kwarg:
65+
self.params.append(parameters_node.star_kwarg.name.value)
66+
return False # Stop traversing after finding the real function
67+
return True # Keep traversing
68+
69+
extractor = FunctionParamExtractor(function_name, return_type)
70+
module.visit(extractor)
71+
return extractor.params
72+
73+
74+
def get_decorator_params(decorator_name: str):
75+
file_path = os.path.abspath(
76+
os.path.join(
77+
os.path.dirname(__file__),
78+
"../../../../task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi",
79+
)
80+
)
81+
params = []
82+
with open(file_path) as file:
83+
code = file.read()
84+
params = extract_function_params(code, decorator_name, "TaskDecorator")
85+
return set(params)
86+
87+
88+
def get_operator_params(operator_path: str):
89+
module_path, class_name = operator_path.rsplit(".", 1)
90+
module = __import__(module_path, fromlist=[class_name])
91+
operator_cls = getattr(module, class_name)
92+
sig = inspect.signature(operator_cls.__init__)
93+
return set(p for p in sig.parameters.keys() if p not in ("self", "args", "kwargs"))
94+
95+
96+
@pytest.mark.parametrize("decorator, operator_path", DECORATOR_OPERATOR_MAP.items())
97+
def test_decorator_matches_operator_signature(decorator, operator_path):
98+
decorator_params = get_decorator_params(decorator)
99+
operator_params = get_operator_params(operator_path)
100+
missing_in_decorator = operator_params - decorator_params
101+
102+
ignored = {"kwargs", "args", "self"}
103+
missing_in_decorator -= ignored
104+
assert not missing_in_decorator, f"{decorator} is missing params: {missing_in_decorator}"

task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ from typing import Any, Callable, TypeVar, overload
2727
from docker.types import Mount
2828
from kubernetes.client import models as k8s
2929

30+
from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback
3031
from airflow.providers.cncf.kubernetes.secret import Secret
3132
from airflow.sdk.bases.decorator import FParams, FReturn, Task, TaskDecorator, _TaskDecorator
3233
from airflow.sdk.definitions.dag import dag
@@ -96,16 +97,21 @@ class TaskDecoratorCollection:
9697
multiple_outputs: bool | None = None,
9798
# 'python_callable', 'op_args' and 'op_kwargs' since they are filled by
9899
# _PythonVirtualenvDecoratedOperator.
100+
python_callable: Callable,
101+
op_args: Collection[Any] | None = None,
102+
op_kwargs: Mapping[str, Any] | None = None,
99103
requirements: None | Iterable[str] | str = None,
100104
python_version: None | str | int | float = None,
101105
serializer: Literal["pickle", "cloudpickle", "dill"] | None = None,
102106
system_site_packages: bool = True,
103107
templates_dict: Mapping[str, Any] | None = None,
108+
templates_exts: list[str] | None = None,
104109
pip_install_options: list[str] | None = None,
110+
expect_airflow: bool = True,
105111
skip_on_exit_code: int | Container[int] | None = None,
106112
index_urls: None | Collection[str] | str = None,
107113
venv_cache_path: None | str = None,
108-
show_return_value_in_logs: bool = True,
114+
string_args: Iterable[str] | None = None,
109115
env_vars: dict[str, str] | None = None,
110116
inherit_env: bool = True,
111117
**kwargs,
@@ -218,18 +224,24 @@ class TaskDecoratorCollection:
218224
self,
219225
*,
220226
multiple_outputs: bool | None = None,
221-
# 'python_callable', 'op_args' and 'op_kwargs' since they are filled by
222-
# _PythonVirtualenvDecoratedOperator.
227+
python_callable: Callable,
223228
requirements: None | Iterable[str] | str = None,
229+
op_args: Collection[Any] | None = None,
230+
op_kwargs: Mapping[str, Any] | None = None,
231+
string_args: Iterable[str] | None = None,
224232
python_version: None | str | int | float = None,
225233
serializer: Literal["pickle", "cloudpickle", "dill"] | None = None,
226234
system_site_packages: bool = True,
227235
templates_dict: Mapping[str, Any] | None = None,
236+
templates_exts: list[str] | None = None,
228237
pip_install_options: list[str] | None = None,
229238
skip_on_exit_code: int | Container[int] | None = None,
230239
index_urls: None | Collection[str] | str = None,
231240
venv_cache_path: None | str = None,
241+
expect_airflow: bool = True,
232242
show_return_value_in_logs: bool = True,
243+
env_vars: dict[str, str] | None = None,
244+
inherit_env: bool = True,
233245
**kwargs,
234246
) -> TaskDecorator:
235247
"""Create a decorator to wrap the decorated callable into a BranchPythonVirtualenvOperator.
@@ -507,6 +519,7 @@ class TaskDecoratorCollection:
507519
image: str | None = None,
508520
name: str | None = None,
509521
random_name_suffix: bool = ...,
522+
cmds: list[str] | None = None,
510523
arguments: list[str] | None = None,
511524
ports: list[k8s.V1ContainerPort] | None = None,
512525
volume_mounts: list[k8s.V1VolumeMount] | None = None,
@@ -520,6 +533,7 @@ class TaskDecoratorCollection:
520533
reattach_on_restart: bool = ...,
521534
startup_timeout_seconds: int = ...,
522535
startup_check_interval_seconds: int = ...,
536+
schedule_timeout_seconds: int | None = None,
523537
get_logs: bool = True,
524538
container_logs: Iterable[str] | str | Literal[True] = ...,
525539
image_pull_policy: str | None = None,
@@ -530,6 +544,7 @@ class TaskDecoratorCollection:
530544
node_selector: dict | None = None,
531545
image_pull_secrets: list[k8s.V1LocalObjectReference] | None = None,
532546
service_account_name: str | None = None,
547+
automount_service_account_token: bool | None = None,
533548
hostnetwork: bool = False,
534549
host_aliases: list[k8s.V1HostAlias] | None = None,
535550
tolerations: list[k8s.V1Toleration] | None = None,
@@ -553,13 +568,19 @@ class TaskDecoratorCollection:
553568
skip_on_exit_code: int | Container[int] | None = None,
554569
base_container_name: str | None = None,
555570
base_container_status_polling_interval: float = ...,
571+
init_container_logs: Iterable[str] | str | Literal[True] | None = None,
556572
deferrable: bool = ...,
557573
poll_interval: float = ...,
558574
log_pod_spec_on_failure: bool = ...,
559575
on_finish_action: str = ...,
576+
is_delete_operator_pod: None | bool = None,
560577
termination_message_policy: str = ...,
561578
active_deadline_seconds: int | None = None,
579+
callbacks: (
580+
list[type[KubernetesPodOperatorCallback]] | type[KubernetesPodOperatorCallback] | None
581+
) = None,
562582
progress_callback: Callable[[str], None] | None = None,
583+
logging_interval: int | None = None,
563584
**kwargs,
564585
) -> TaskDecorator:
565586
"""Create a decorator to convert a callable to a Kubernetes Pod task.
@@ -849,6 +870,8 @@ class TaskDecoratorCollection:
849870
mode: str = ...,
850871
exponential_backoff: bool = False,
851872
max_wait: timedelta | float | None = None,
873+
silent_fail: bool = False,
874+
never_fail: bool = False,
852875
**kwargs,
853876
) -> TaskDecorator:
854877
"""
@@ -873,6 +896,13 @@ class TaskDecoratorCollection:
873896
:param exponential_backoff: allow progressive longer waits between
874897
pokes by using exponential backoff algorithm
875898
:param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds
899+
:param silent_fail: If true, and poke method raises an exception different from
900+
AirflowSensorTimeout, AirflowTaskTimeout, AirflowSkipException
901+
and AirflowFailException, the sensor will log the error and continue
902+
its execution. Otherwise, the sensor task fails, and it can be retried
903+
based on the provided `retries` parameter.
904+
:param never_fail: If true, and poke method raises an exception, sensor will be skipped.
905+
Mutually exclusive with soft_fail.
876906
"""
877907
@overload
878908
def sensor(self, python_callable: Callable[FParams, FReturn] | None = None) -> Task[FParams, FReturn]: ...

0 commit comments

Comments
 (0)