|
| 1 | +#!/usr/bin/env python |
| 2 | +# |
| 3 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 4 | +# or more contributor license agreements. See the NOTICE file |
| 5 | +# distributed with this work for additional information |
| 6 | +# regarding copyright ownership. The ASF licenses this file |
| 7 | +# to you under the Apache License, Version 2.0 (the |
| 8 | +# "License"); you may not use this file except in compliance |
| 9 | +# with the License. You may obtain a copy of the License at |
| 10 | +# |
| 11 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 12 | +# |
| 13 | +# Unless required by applicable law or agreed to in writing, |
| 14 | +# software distributed under the License is distributed on an |
| 15 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 16 | +# KIND, either express or implied. See the License for the |
| 17 | +# specific language governing permissions and limitations |
| 18 | +# under the License. |
| 19 | +from __future__ import annotations |
| 20 | + |
| 21 | +import inspect |
| 22 | +import sys |
| 23 | + |
| 24 | +import libcst as cst |
| 25 | +from in_container_utils import AIRFLOW_ROOT_PATH, console |
| 26 | + |
| 27 | +DECORATOR_OPERATOR_MAP = { |
| 28 | + "kubernetes": "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator", |
| 29 | + "sensor": "airflow.sdk.bases.sensor.BaseSensorOperator", |
| 30 | + "virtualenv": "airflow.providers.standard.operators.python.PythonVirtualenvOperator", |
| 31 | + "branch_virtualenv": "airflow.providers.standard.operators.python.BranchPythonVirtualenvOperator", |
| 32 | + # Add more here... |
| 33 | +} |
| 34 | +DECORATOR_PYI_PATH = ( |
| 35 | + AIRFLOW_ROOT_PATH / "task-sdk" / "src" / "airflow" / "sdk" / "definitions" / "decorators" / "__init__.pyi" |
| 36 | +) |
| 37 | +decorator_pyi_file_content = DECORATOR_PYI_PATH.read_text() |
| 38 | + |
| 39 | + |
| 40 | +def extract_function_params(code, function_name, return_type): |
| 41 | + """Extracts parameters from a specific function definition in the given code. |
| 42 | +
|
| 43 | + Args: |
| 44 | + code (str): The Python code to parse. |
| 45 | + function_name (str): The name of the function to extract parameters from. |
| 46 | + return_type (str): As the pyi file has multiple @overload decorator, extract function param based on return type. |
| 47 | +
|
| 48 | + Returns: |
| 49 | + list: A list of parameter names, or None if the function is not found. |
| 50 | + """ |
| 51 | + module = cst.parse_module(code) |
| 52 | + |
| 53 | + class FunctionParamExtractor(cst.CSTVisitor): |
| 54 | + def __init__(self, target_function_name, target_return_type): |
| 55 | + self.target_function_name = target_function_name |
| 56 | + self.target_return_type = target_return_type |
| 57 | + self.params: list[str] = [] |
| 58 | + |
| 59 | + def visit_FunctionDef(self, node): |
| 60 | + # Match function name |
| 61 | + if node.name.value == self.target_function_name: |
| 62 | + if node.returns: |
| 63 | + annotation = node.returns.annotation |
| 64 | + if isinstance(annotation, cst.Name) and annotation.value == self.target_return_type: |
| 65 | + parameters_node = node.params |
| 66 | + self.params.extend(param.name.value for param in parameters_node.params) |
| 67 | + self.params.extend(param.name.value for param in parameters_node.kwonly_params) |
| 68 | + self.params.extend(param.name.value for param in parameters_node.posonly_params) |
| 69 | + if parameters_node.star_kwarg: |
| 70 | + self.params.append(parameters_node.star_kwarg.name.value) |
| 71 | + return False # Stop traversing after finding the real function |
| 72 | + return True # Keep traversing |
| 73 | + |
| 74 | + extractor = FunctionParamExtractor(function_name, return_type) |
| 75 | + module.visit(extractor) |
| 76 | + return extractor.params |
| 77 | + |
| 78 | + |
| 79 | +def get_decorator_params(decorator_name: str): |
| 80 | + params = extract_function_params(decorator_pyi_file_content, decorator_name, "TaskDecorator") |
| 81 | + return set(params) |
| 82 | + |
| 83 | + |
| 84 | +def get_operator_params(operator_path: str): |
| 85 | + console.print("Operator path:", operator_path) |
| 86 | + module_path, class_name = operator_path.rsplit(".", 1) |
| 87 | + module = __import__(module_path, fromlist=[class_name]) |
| 88 | + operator_cls = getattr(module, class_name) |
| 89 | + sig = inspect.signature(operator_cls.__init__) |
| 90 | + return set(p for p in sig.parameters.keys() if p not in ("self", "args", "kwargs")) |
| 91 | + |
| 92 | + |
| 93 | +def verify_signature_consistency(): |
| 94 | + failure = False |
| 95 | + console.print("Verify signature consistency") |
| 96 | + for decorator, operator_path in DECORATOR_OPERATOR_MAP.items(): |
| 97 | + decorator_params = get_decorator_params(decorator) |
| 98 | + operator_params = get_operator_params(operator_path) |
| 99 | + missing_in_decorator = operator_params - decorator_params |
| 100 | + |
| 101 | + ignored = {"kwargs", "args", "self", "python_callable", "op_args", "op_kwargs"} |
| 102 | + missing_in_decorator -= ignored |
| 103 | + if missing_in_decorator: |
| 104 | + failure = True |
| 105 | + console.print(f"[yellow]Missing params in[/] [bold]__init__.py[/]: {missing_in_decorator}") |
| 106 | + if failure: |
| 107 | + console.print("[red]Some of the decorator signatures are missing in __init__.py[/]") |
| 108 | + sys.exit(1) |
| 109 | + console.print("[green]All decorator signature matches[/]") |
| 110 | + sys.exit(0) |
| 111 | + |
| 112 | + |
| 113 | +if __name__ == "__main__": |
| 114 | + verify_signature_consistency() |
0 commit comments