|
| 1 | +"""Test the validity of all DAGs. **USED BY DEV PARSE COMMAND DO NOT EDIT**""" |
| 2 | +from contextlib import contextmanager |
| 3 | +import logging |
| 4 | +import os |
| 5 | + |
| 6 | +import pytest |
| 7 | + |
| 8 | +from airflow.models import DagBag, Variable, Connection |
| 9 | +from airflow.hooks.base import BaseHook |
| 10 | +from airflow.utils.db import initdb |
| 11 | + |
| 12 | +# init airflow database |
| 13 | +initdb() |
| 14 | + |
| 15 | +# The following code patches errors caused by missing OS Variables, Airflow Connections, and Airflow Variables |
| 16 | + |
| 17 | + |
| 18 | +# =========== MONKEYPATCH BaseHook.get_connection() =========== |
| 19 | +def basehook_get_connection_monkeypatch(key: str, *args, **kwargs): |
| 20 | + print( |
| 21 | + f"Attempted to fetch connection during parse returning an empty Connection object for {key}" |
| 22 | + ) |
| 23 | + return Connection(key) |
| 24 | + |
| 25 | + |
| 26 | +BaseHook.get_connection = basehook_get_connection_monkeypatch |
| 27 | +# # =========== /MONKEYPATCH BASEHOOK.GET_CONNECTION() =========== |
| 28 | + |
| 29 | + |
| 30 | +# =========== MONKEYPATCH OS.GETENV() =========== |
| 31 | +def os_getenv_monkeypatch(key: str, *args, **kwargs): |
| 32 | + default = None |
| 33 | + if args: |
| 34 | + default = args[0] # os.getenv should get at most 1 arg after the key |
| 35 | + if kwargs: |
| 36 | + default = kwargs.get( |
| 37 | + "default", None |
| 38 | + ) # and sometimes kwarg if people are using the sig |
| 39 | + |
| 40 | + env_value = os.environ.get(key, None) |
| 41 | + |
| 42 | + if env_value: |
| 43 | + return env_value # if the env_value is set, return it |
| 44 | + if ( |
| 45 | + key == "JENKINS_HOME" and default is None |
| 46 | + ): # fix https://github.com/astronomer/astro-cli/issues/601 |
| 47 | + return None |
| 48 | + if default: |
| 49 | + return default # otherwise return whatever default has been passed |
| 50 | + return f"MOCKED_{key.upper()}_VALUE" # if absolutely nothing has been passed - return the mocked value |
| 51 | + |
| 52 | + |
| 53 | +os.getenv = os_getenv_monkeypatch |
| 54 | +# # =========== /MONKEYPATCH OS.GETENV() =========== |
| 55 | + |
| 56 | +# =========== MONKEYPATCH VARIABLE.GET() =========== |
| 57 | + |
| 58 | + |
| 59 | +class magic_dict(dict): |
| 60 | + def __init__(self, *args, **kwargs): |
| 61 | + self.update(*args, **kwargs) |
| 62 | + |
| 63 | + def __getitem__(self, key): |
| 64 | + return {}.get(key, "MOCKED_KEY_VALUE") |
| 65 | + |
| 66 | + |
| 67 | +def variable_get_monkeypatch(key: str, default_var=None, deserialize_json=False): |
| 68 | + print( |
| 69 | + f"Attempted to get Variable value during parse, returning a mocked value for {key}" |
| 70 | + ) |
| 71 | + |
| 72 | + if default_var: |
| 73 | + return default_var |
| 74 | + if deserialize_json: |
| 75 | + return magic_dict() |
| 76 | + return "NON_DEFAULT_MOCKED_VARIABLE_VALUE" |
| 77 | + |
| 78 | + |
| 79 | +Variable.get = variable_get_monkeypatch |
| 80 | +# # =========== /MONKEYPATCH VARIABLE.GET() =========== |
| 81 | + |
| 82 | + |
| 83 | +@contextmanager |
| 84 | +def suppress_logging(namespace): |
| 85 | + """ |
| 86 | + Suppress logging within a specific namespace to keep tests "clean" during build |
| 87 | + """ |
| 88 | + logger = logging.getLogger(namespace) |
| 89 | + old_value = logger.disabled |
| 90 | + logger.disabled = True |
| 91 | + try: |
| 92 | + yield |
| 93 | + finally: |
| 94 | + logger.disabled = old_value |
| 95 | + |
| 96 | + |
| 97 | +def get_import_errors(): |
| 98 | + """ |
| 99 | + Generate a tuple for import errors in the dag bag |
| 100 | + """ |
| 101 | + with suppress_logging("airflow"): |
| 102 | + dag_bag = DagBag(include_examples=False) |
| 103 | + |
| 104 | + def strip_path_prefix(path): |
| 105 | + return os.path.relpath(path, os.environ.get("AIRFLOW_HOME")) |
| 106 | + |
| 107 | + # prepend "(None,None)" to ensure that a test object is always created even if it's a no op. |
| 108 | + return [(None, None)] + [ |
| 109 | + (strip_path_prefix(k), v.strip()) for k, v in dag_bag.import_errors.items() |
| 110 | + ] |
| 111 | + |
| 112 | + |
| 113 | +@pytest.mark.parametrize( |
| 114 | + "rel_path,rv", get_import_errors(), ids=[x[0] for x in get_import_errors()] |
| 115 | +) |
| 116 | +def test_file_imports(rel_path, rv): |
| 117 | + """Test for import errors on a file""" |
| 118 | + if rel_path and rv: # Make sure our no op test doesn't raise an error |
| 119 | + raise Exception(f"{rel_path} failed to import with message \n {rv}") |
0 commit comments