Skip to content

Commit a3295a8

Browse files
authored
AIP-86 - Add Deadline References (#50677)
* AIP-86 - Add Deadline References
1 parent 3bd9746 commit a3295a8

File tree

4 files changed

+385
-23
lines changed

4 files changed

+385
-23
lines changed

airflow-core/src/airflow/models/deadline.py

Lines changed: 110 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,30 @@
1717
from __future__ import annotations
1818

1919
import logging
20+
from abc import ABC, abstractmethod
21+
from dataclasses import dataclass
2022
from datetime import datetime, timedelta
21-
from enum import Enum
22-
from typing import TYPE_CHECKING, Callable
23+
from typing import TYPE_CHECKING, Any, Callable
2324

2425
import sqlalchemy_jsonfield
2526
import uuid6
26-
from sqlalchemy import Column, ForeignKey, Index, Integer, String
27+
from sqlalchemy import Column, ForeignKey, Index, Integer, String, select
28+
from sqlalchemy.exc import SQLAlchemyError
2729
from sqlalchemy.orm import relationship
2830
from sqlalchemy_utils import UUIDType
2931

3032
from airflow.models.base import Base, StringID
3133
from airflow.settings import json
34+
from airflow.utils.log.logging_mixin import LoggingMixin
3235
from airflow.utils.module_loading import import_string, is_valid_dotpath
3336
from airflow.utils.session import NEW_SESSION, provide_session
3437
from airflow.utils.sqlalchemy import UtcDateTime
3538

3639
if TYPE_CHECKING:
3740
from sqlalchemy.orm import Session
3841

42+
from airflow.sdk.definitions.deadline import DeadlineReference
43+
3944
logger = logging.getLogger(__name__)
4045

4146

@@ -100,33 +105,67 @@ def add_deadline(cls, deadline: Deadline, session: Session = NEW_SESSION):
100105
session.add(deadline)
101106

102107

103-
class DeadlineReference(Enum):
108+
class ReferenceModels:
    """
    Store the implementations for the different Deadline References.

    After adding the implementations here, all DeadlineReferences should be added
    to the user interface in airflow.sdk.definitions.deadline.DeadlineReference
    """

    class BaseDeadlineReference(LoggingMixin, ABC):
        """
        Base class for all Deadline implementations.

        Subclasses declare the keyword arguments they need in ``required_kwargs``
        and implement ``_evaluate_with``; callers use ``evaluate_with``.
        """

        # Set of required kwargs - subclasses should override this.
        required_kwargs: set[str] = set()

        def evaluate_with(self, **kwargs: Any) -> datetime:
            """
            Validate the provided kwargs and evaluate this deadline with the given conditions.

            Only the kwargs named in ``required_kwargs`` are forwarded to
            ``_evaluate_with``; any others are reported at debug level and dropped.

            :param kwargs: Conditions to evaluate the deadline with.
            :return: The datetime produced by ``_evaluate_with``.
            :raises ValueError: If one or more of ``required_kwargs`` is not provided.
            """
            filtered_kwargs = {k: v for k, v in kwargs.items() if k in self.required_kwargs}

            if missing_kwargs := self.required_kwargs - filtered_kwargs.keys():
                # Sets are unordered; sort so the message is stable between runs.
                missing = ", ".join(sorted(missing_kwargs))
                raise ValueError(f"{self.__class__.__name__} is missing required parameters: {missing}")

            if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
                self.log.debug("Ignoring unexpected parameters: %s", ", ".join(sorted(extra_kwargs)))

            return self._evaluate_with(**filtered_kwargs)

        @abstractmethod
        def _evaluate_with(self, **kwargs: Any) -> datetime:
            """Must be implemented by subclasses to perform the actual evaluation."""
            raise NotImplementedError

    @dataclass
    class FixedDatetimeDeadline(BaseDeadlineReference):
        """A deadline that always returns a fixed datetime."""

        # The datetime returned by every evaluation.
        _datetime: datetime

        def _evaluate_with(self, **kwargs: Any) -> datetime:
            # No required kwargs, so nothing from ``kwargs`` is consulted.
            return self._datetime

    class DagRunLogicalDateDeadline(BaseDeadlineReference):
        """A deadline that returns a DagRun's logical date."""

        required_kwargs = {"dag_id"}

        def _evaluate_with(self, **kwargs: Any) -> datetime:
            # Imported at call time rather than at module import.
            from airflow.models import DagRun

            return _fetch_from_db(DagRun.logical_date, **kwargs)

    class DagRunQueuedAtDeadline(BaseDeadlineReference):
        """A deadline that returns when a DagRun was queued."""

        required_kwargs = {"dag_id"}

        def _evaluate_with(self, **kwargs: Any) -> datetime:
            # Imported at call time rather than at module import.
            from airflow.models import DagRun

            return _fetch_from_db(DagRun.queued_at, **kwargs)
130169

131170

132171
class DeadlineAlert:
@@ -186,3 +225,52 @@ def serialize_deadline_alert(self):
186225
"callback_kwargs": self.callback_kwargs,
187226
}
188227
)
228+
229+
230+
@provide_session
def _fetch_from_db(model_reference: Column, session: Session = NEW_SESSION, **conditions) -> datetime:
    """
    Fetch a datetime value from the database using the provided model reference and filtering conditions.

    For example, to fetch a TaskInstance's start_date:
        _fetch_from_db(
            TaskInstance.start_date, dag_id='example_dag', task_id='example_task', run_id='example_run'
        )

    This generates SQL equivalent to:
        SELECT start_date
        FROM task_instance
        WHERE dag_id = 'example_dag'
            AND task_id = 'example_task'
            AND run_id = 'example_run'

    :param model_reference: SQLAlchemy Column to select (e.g., DagRun.logical_date, TaskInstance.start_date)
    :param conditions: Filtering conditions applied as equality comparisons in the WHERE clause.
        Multiple conditions are combined with AND.
    :param session: SQLAlchemy session (auto-provided by decorator)
    :return: The scalar result of the query.
    :raises AttributeError: If a condition key is not an attribute of the model that owns the column.
    :raises SQLAlchemyError: If executing the query fails; the error is logged and re-raised.
    :raises ValueError: If the query yields ``None``.
    """
    model = model_reference.class_
    query = select(model_reference)

    for key, value in conditions.items():
        query = query.where(getattr(model, key) == value)

    def _render_sql() -> str:
        # Inlining the bound values is only needed for human-readable output, so it is
        # done on demand instead of on every call.
        compiled_query = query.compile(compile_kwargs={"literal_binds": True})
        return "\n ".join(str(compiled_query).splitlines())

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "Executing query:\n %r\nAs SQL:\n %s",
            query,
            _render_sql(),
        )

    try:
        result = session.scalar(query)
    except SQLAlchemyError:
        logger.exception("Database query failed.")
        raise

    if result is None:
        message = f"No matching record found in the database for query:\n {_render_sql()}"
        logger.error(message)
        raise ValueError(message)

    return result

airflow-core/tests/unit/models/test_deadline.py

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,18 @@
1818

1919
import json
2020
import logging
21+
from datetime import datetime
22+
from unittest import mock
2123

2224
import pytest
2325
import time_machine
2426
from sqlalchemy import select
27+
from sqlalchemy.exc import SQLAlchemyError
2528

2629
from airflow.models import DagRun
27-
from airflow.models.deadline import Deadline, DeadlineAlert
30+
from airflow.models.deadline import Deadline, DeadlineAlert, _fetch_from_db
2831
from airflow.providers.standard.operators.empty import EmptyOperator
32+
from airflow.sdk.definitions.deadline import DeadlineReference
2933
from airflow.utils.state import DagRunState
3034

3135
from tests_common.test_utils import db
@@ -182,3 +186,126 @@ def test_log_unimportable_but_properly_formatted_callback(self, caplog):
182186

183187
assert "could not be imported" in caplog.text
184188
assert path == UNIMPORTABLE_DOT_PATH
189+
190+
191+
@pytest.mark.db_test
class TestCalculatedDeadlineDatabaseCalls:
    """Exercise ``_fetch_from_db`` and the deadline references that delegate to it."""

    @staticmethod
    def setup_method():
        _clean_db()

    @staticmethod
    def teardown_method():
        _clean_db()

    @pytest.mark.parametrize(
        "column, conditions, expected_query",
        [
            pytest.param(
                DagRun.logical_date,
                {"dag_id": DAG_ID},
                "SELECT dag_run.logical_date \nFROM dag_run \nWHERE dag_run.dag_id = :dag_id_1",
                id="single_condition_logical_date",
            ),
            pytest.param(
                DagRun.queued_at,
                {"dag_id": DAG_ID},
                "SELECT dag_run.queued_at \nFROM dag_run \nWHERE dag_run.dag_id = :dag_id_1",
                id="single_condition_queued_at",
            ),
            pytest.param(
                DagRun.logical_date,
                {"dag_id": DAG_ID, "state": "running"},
                "SELECT dag_run.logical_date \nFROM dag_run \nWHERE dag_run.dag_id = :dag_id_1 AND dag_run.state = :state_1",
                id="multiple_conditions",
            ),
        ],
    )
    @mock.patch("sqlalchemy.orm.Session")
    def test_fetch_from_db_success(self, mock_session, column, conditions, expected_query):
        """A successful lookup returns the scalar and issues exactly the expected query."""
        mock_session.scalar.return_value = DEFAULT_DATE

        fetched = _fetch_from_db(column, session=mock_session, **conditions)

        assert isinstance(fetched, datetime)
        mock_session.scalar.assert_called_once()

        # The statement handed to the session must render to the expected SQL.
        executed_query = mock_session.scalar.call_args.args[0]
        assert str(executed_query) == expected_query

        # The bound values must match the conditions that were passed in.
        # Note that SQLAlchemy appends the _1 to ensure unique template field names
        bound_params = executed_query.compile().params
        for name, expected_value in conditions.items():
            assert bound_params[f"{name}_1"] == expected_value

    @pytest.mark.parametrize(
        "use_valid_conditions, scalar_side_effect, expected_error, expected_message",
        [
            pytest.param(
                False,
                mock.DEFAULT,  # This will allow the call to pass through
                AttributeError,
                None,
                id="invalid_attribute",
            ),
            pytest.param(
                True,
                SQLAlchemyError("Database connection failed"),
                SQLAlchemyError,
                "Database connection failed",
                id="database_error",
            ),
            pytest.param(
                True,
                lambda _query: None,
                ValueError,
                "No matching record found in the database",
                id="no_results",
            ),
        ],
    )
    @mock.patch("sqlalchemy.orm.Session")
    def test_fetch_from_db_error_cases(
        self, mock_session, use_valid_conditions, scalar_side_effect, expected_error, expected_message, caplog
    ):
        """Each failure mode raises the expected error and, where a message is given, logs it."""
        model_reference = DagRun.logical_date
        if use_valid_conditions:
            conditions = {"dag_id": "test_dag"}
        else:
            conditions = {"non_existent_column": "some_value"}

        # Drive the mocked session into the failure under test.
        mock_session.scalar.side_effect = scalar_side_effect

        with caplog.at_level(logging.ERROR):
            with pytest.raises(expected_error, match=expected_message):
                _fetch_from_db(model_reference, session=mock_session, **conditions)
            if expected_message:
                assert expected_message in caplog.text

    @pytest.mark.parametrize(
        "reference, expected_column",
        [
            pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, DagRun.logical_date, id="logical_date"),
            pytest.param(DeadlineReference.DAGRUN_QUEUED_AT, DagRun.queued_at, id="queued_at"),
            pytest.param(DeadlineReference.FIXED_DATETIME(DEFAULT_DATE), None, id="fixed_deadline"),
        ],
    )
    def test_deadline_database_integration(self, reference, expected_column):
        """
        Test database integration for all deadline types.

        Verifies:
            1. Calculated deadlines call _fetch_from_db with correct column.
            2. Fixed deadlines do not interact with database.
        """
        conditions = {"dag_id": DAG_ID}

        with mock.patch("airflow.models.deadline._fetch_from_db") as mock_fetch:
            mock_fetch.return_value = DEFAULT_DATE

            evaluated = reference.evaluate_with(**conditions)

            if expected_column is None:
                mock_fetch.assert_not_called()
            else:
                mock_fetch.assert_called_once_with(expected_column, **conditions)

            assert evaluated == DEFAULT_DATE
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
from datetime import datetime
20+
21+
22+
class DeadlineReference:
    """
    The public interface class for all DeadlineReference options.

    Every option exposed here is backed by an implementation in
    ``airflow.models.deadline.ReferenceModels``. Two flavours are offered:
    calculated deadlines (which fetch values from the database) and fixed
    deadlines (which return a predefined datetime).

    ------
    Usage:
    ------

    1. Example deadline references:
       fixed = DeadlineReference.FIXED_DATETIME(datetime(2025, 5, 4))
       logical = DeadlineReference.DAGRUN_LOGICAL_DATE
       queued = DeadlineReference.DAGRUN_QUEUED_AT

    2. Using in a DAG:
       DAG(
           dag_id='dag_with_deadline',
           deadline=DeadlineAlert(
               reference=DeadlineReference.DAGRUN_LOGICAL_DATE,
               interval=timedelta(hours=1),
               callback=hello_callback,
           )
       )

    3. Evaluating deadlines will ignore unexpected parameters:
       # For deadlines requiring parameters:
           deadline = DeadlineReference.DAGRUN_LOGICAL_DATE
           deadline.evaluate_with(dag_id=dag.dag_id)

       # For deadlines with no required parameters:
           deadline = DeadlineReference.FIXED_DATETIME(datetime(2025, 5, 4))
           deadline.evaluate_with()
    """

    # Imported in the class body, which also makes it reachable as ``cls.ReferenceModels``.
    from airflow.models.deadline import ReferenceModels

    DAGRUN_LOGICAL_DATE: ReferenceModels.BaseDeadlineReference = ReferenceModels.DagRunLogicalDateDeadline()
    DAGRUN_QUEUED_AT: ReferenceModels.BaseDeadlineReference = ReferenceModels.DagRunQueuedAtDeadline()

    @classmethod
    def FIXED_DATETIME(cls, datetime: datetime) -> ReferenceModels.BaseDeadlineReference:
        """Return a fixed deadline reference built from the given ``datetime``."""
        fixed_reference_type = cls.ReferenceModels.FixedDatetimeDeadline
        return fixed_reference_type(datetime)

0 commit comments

Comments
 (0)