Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0ddd70b

Browse files
authoredJun 11, 2025··
Fix serialization of DeadlineAlert and add unit tests to prevent regr… (#51494)
* Fix serialization of DeadlineAlert and add unit tests to prevent regression
1 parent 8dd6079 commit 0ddd70b

File tree

9 files changed

+400
-26
lines changed

9 files changed

+400
-26
lines changed
 

‎airflow-core/src/airflow/models/dag.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
from airflow.sdk import TaskGroup
8888
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, BaseAsset
8989
from airflow.sdk.definitions.dag import DAG as TaskSDKDag, dag as task_sdk_dag_decorator
90+
from airflow.sdk.definitions.deadline import DeadlineAlert
9091
from airflow.settings import json
9192
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
9293
from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
@@ -1917,7 +1918,7 @@ class DagModel(Base):
19171918
# Asset expression based on asset triggers
19181919
asset_expression = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
19191920
# DAG deadline information
1920-
deadline = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
1921+
_deadline = Column("deadline", sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
19211922
# Tags for view filter
19221923
tags = relationship("DagTag", cascade="all, delete, delete-orphan", backref=backref("dag"))
19231924
# Dag owner links for DAGs view
@@ -2011,6 +2012,16 @@ def next_dagrun_data_interval(self, value: tuple[datetime, datetime] | None) ->
20112012
else:
20122013
self.next_dagrun_data_interval_start, self.next_dagrun_data_interval_end = value
20132014

2015+
@property
2016+
def deadline(self):
2017+
"""Get the deserialized deadline alert."""
2018+
return DeadlineAlert.deserialize_deadline_alert(self._deadline) if self._deadline else None
2019+
2020+
@deadline.setter
2021+
def deadline(self, value):
2022+
"""Set and serialize the deadline alert."""
2023+
self._deadline = None if value is None else value.serialize_deadline_alert()
2024+
20142025
@property
20152026
def timezone(self):
20162027
return settings.TIMEZONE

‎airflow-core/src/airflow/models/deadline.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131

3232
from airflow.models.base import Base, StringID
3333
from airflow.settings import json
34+
from airflow.utils import timezone
35+
from airflow.utils.decorators import classproperty
3436
from airflow.utils.log.logging_mixin import LoggingMixin
3537
from airflow.utils.session import NEW_SESSION, provide_session
3638
from airflow.utils.sqlalchemy import UtcDateTime
@@ -111,12 +113,36 @@ class ReferenceModels:
111113
to the user interface in airflow.sdk.definitions.deadline.DeadlineReference
112114
"""
113115

116+
REFERENCE_TYPE_FIELD = "reference_type"
117+
118+
@classmethod
119+
def get_reference_class(cls, reference_name: str) -> type[BaseDeadlineReference]:
120+
"""
121+
Get a reference class by its name.
122+
123+
:param reference_name: The name of the reference class to find
124+
"""
125+
try:
126+
return next(
127+
ref_class
128+
for name, ref_class in vars(cls).items()
129+
if isinstance(ref_class, type)
130+
and issubclass(ref_class, cls.BaseDeadlineReference)
131+
and ref_class.__name__ == reference_name
132+
)
133+
except StopIteration:
134+
raise ValueError(f"No reference class found with name: {reference_name}")
135+
114136
class BaseDeadlineReference(LoggingMixin, ABC):
115137
"""Base class for all Deadline implementations."""
116138

117139
# Set of required kwargs - subclasses should override this.
118140
required_kwargs: set[str] = set()
119141

142+
@classproperty
143+
def reference_name(cls: Any) -> str:
144+
return cls.__name__
145+
120146
def evaluate_with(self, **kwargs: Any) -> datetime:
121147
"""Validate the provided kwargs and evaluate this deadline with the given conditions."""
122148
filtered_kwargs = {k: v for k, v in kwargs.items() if k in self.required_kwargs}
@@ -136,6 +162,30 @@ def _evaluate_with(self, **kwargs: Any) -> datetime:
136162
"""Must be implemented by subclasses to perform the actual evaluation."""
137163
raise NotImplementedError
138164

165+
@classmethod
166+
def deserialize_reference(cls, reference_data: dict):
167+
"""
168+
Deserialize a reference type from its dictionary representation.
169+
170+
While the base implementation doesn't use reference_data, this parameter is required
171+
for subclasses that need additional data for initialization (like FixedDatetimeDeadline
172+
which needs a datetime value).
173+
174+
:param reference_data: Dictionary containing serialized reference data.
175+
Always includes a 'reference_type' field, and may include additional
176+
fields needed by specific reference implementations.
177+
"""
178+
return cls()
179+
180+
def serialize_reference(self) -> dict:
181+
"""
182+
Serialize this reference type into a dictionary representation.
183+
184+
This method assumes that the reference doesn't require any additional data.
185+
Override this method in subclasses if additional data is needed for serialization.
186+
"""
187+
return {ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name}
188+
139189
@dataclass
140190
class FixedDatetimeDeadline(BaseDeadlineReference):
141191
"""A deadline that always returns a fixed datetime."""
@@ -145,6 +195,16 @@ class FixedDatetimeDeadline(BaseDeadlineReference):
145195
def _evaluate_with(self, **kwargs: Any) -> datetime:
146196
return self._datetime
147197

198+
def serialize_reference(self) -> dict:
199+
return {
200+
ReferenceModels.REFERENCE_TYPE_FIELD: self.reference_name,
201+
"datetime": self._datetime.timestamp(),
202+
}
203+
204+
@classmethod
205+
def deserialize_reference(cls, reference_data: dict):
206+
return cls(_datetime=timezone.from_timestamp(reference_data["datetime"]))
207+
148208
class DagRunLogicalDateDeadline(BaseDeadlineReference):
149209
"""A deadline that returns a DagRun's logical date."""
150210

‎airflow-core/src/airflow/serialization/serialized_objects.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,8 @@ def deserialize(cls, encoded_var: Any) -> Any:
951951
return TaskInstanceKey(**var)
952952
elif type_ == DAT.ARG_NOT_SET:
953953
return NOTSET
954+
elif type_ == DAT.DEADLINE_ALERT:
955+
return DeadlineAlert.deserialize_deadline_alert(var)
954956
else:
955957
raise TypeError(f"Invalid type {type_!s} in deserialization.")
956958

@@ -1779,6 +1781,9 @@ def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG:
17791781
if "has_on_failure_callback" in encoded_dag:
17801782
dag.has_on_failure_callback = True
17811783

1784+
if "deadline" in encoded_dag and encoded_dag["deadline"] is not None:
1785+
dag.deadline = DeadlineAlert.deserialize_deadline_alert(encoded_dag["deadline"])
1786+
17821787
keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys()
17831788
for k in keys_to_set_none:
17841789
setattr(dag, k, None)

‎airflow-core/src/airflow/utils/decorators.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,30 @@ def fixup_decorator_warning_stack(func, delta: int = 2):
8686
# Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to
8787
# `warnings.warn` to ignore the decorator.
8888
func.__globals__["warnings"] = _autostacklevel_warn(delta)
89+
90+
91+
class classproperty:
92+
"""
93+
Decorator that converts a method with a single cls argument into a property.
94+
95+
Mypy won't let us use both @property and @classmethod together, this is a workaround
96+
to combine the two.
97+
98+
Usage:
99+
100+
class Circle:
101+
def __init__(self, radius):
102+
self.radius = radius
103+
104+
@classproperty
105+
def pi(cls):
106+
return 3.14159
107+
108+
print(Circle.pi) # Outputs: 3.14159
109+
"""
110+
111+
def __init__(self, method):
112+
self.method = method
113+
114+
def __get__(self, instance, cls=None):
115+
return self.method(cls)

‎airflow-core/tests/unit/models/test_dagrun.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from airflow.providers.standard.operators.empty import EmptyOperator
4444
from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator
4545
from airflow.sdk import setup, task, task_group, teardown
46+
from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference
4647
from airflow.serialization.serialized_objects import SerializedDAG
4748
from airflow.stats import Stats
4849
from airflow.triggers.base import StartTriggerArgs
@@ -68,6 +69,11 @@
6869
DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE)
6970

7071

72+
def test_callback_for_deadline():
73+
"""Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly."""
74+
pass
75+
76+
7177
@pytest.fixture(scope="module")
7278
def dagbag():
7379
from airflow.models.dagbag import DagBag
@@ -1244,6 +1250,40 @@ def test_dag_run_version_number(self, dag_maker, session):
12441250
# the latest task instance dag_version
12451251
assert dag_run.version_number == dag_v.version_number
12461252

1253+
def test_dagrun_success_deadline(self, dag_maker, session):
1254+
def on_success_callable(context):
1255+
assert context["dag_run"].dag_id == "test_dagrun_success_callback"
1256+
1257+
with dag_maker(
1258+
dag_id="test_dagrun_success_callback",
1259+
schedule=datetime.timedelta(days=1),
1260+
start_date=datetime.datetime(2017, 1, 1),
1261+
on_success_callback=on_success_callable,
1262+
deadline=DeadlineAlert(
1263+
reference=DeadlineReference.FIXED_DATETIME(DEFAULT_DATE),
1264+
interval=datetime.timedelta(hours=1),
1265+
callback=test_callback_for_deadline,
1266+
),
1267+
) as dag:
1268+
...
1269+
dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
1270+
dag_task2 = EmptyOperator(task_id="test_state_succeeded2", dag=dag)
1271+
dag_task1.set_downstream(dag_task2)
1272+
1273+
initial_task_states = {
1274+
"test_state_succeeded1": TaskInstanceState.SUCCESS,
1275+
"test_state_succeeded2": TaskInstanceState.SUCCESS,
1276+
}
1277+
1278+
# Scheduler uses Serialized DAG -- so use that instead of the Actual DAG
1279+
dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
1280+
1281+
dag_run = self.create_dag_run(dag=dag, task_states=initial_task_states, session=session)
1282+
_, callback = dag_run.update_state()
1283+
assert dag_run.state == DagRunState.SUCCESS
1284+
# Callbacks are not added until handle_callback = False is passed to dag_run.update_state()
1285+
assert callback is None
1286+
12471287

12481288
@pytest.mark.parametrize(
12491289
("run_type", "expected_tis"),

‎airflow-core/tests/unit/models/test_deadline.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,19 @@
3737
DAG_ID = "dag_id_1"
3838
RUN_ID = 1
3939

40-
TEST_CALLBACK_KWARGS = {"to": "the_boss@work.com"}
41-
TEST_CALLBACK_PATH = f"{__name__}.test_callback"
40+
TEST_CALLBACK_PATH = f"{__name__}.test_callback_for_deadline"
41+
TEST_CALLBACK_KWARGS = {"arg1": "value1"}
42+
43+
REFERENCE_TYPES = [
44+
pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, id="logical_date"),
45+
pytest.param(DeadlineReference.DAGRUN_QUEUED_AT, id="queued_at"),
46+
pytest.param(DeadlineReference.FIXED_DATETIME(DEFAULT_DATE), id="fixed_deadline"),
47+
]
48+
49+
50+
def test_callback_for_deadline():
51+
"""Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly."""
52+
pass
4253

4354

4455
def _clean_db():

‎airflow-core/tests/unit/serialization/test_serialized_objects.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from airflow.providers.standard.operators.python import PythonOperator
4545
from airflow.providers.standard.triggers.file import FileDeleteTrigger
4646
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey, AssetWatcher
47+
from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineAlertFields, DeadlineReference
4748
from airflow.sdk.definitions.decorators import task
4849
from airflow.sdk.definitions.param import Param
4950
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
@@ -57,6 +58,24 @@
5758
from airflow.utils.task_group import TaskGroup
5859
from airflow.utils.types import DagRunType
5960

61+
from unit.models import DEFAULT_DATE
62+
63+
DAG_ID = "dag_id_1"
64+
65+
TEST_CALLBACK_PATH = f"{__name__}.test_callback_for_deadline"
66+
TEST_CALLBACK_KWARGS = {"arg1": "value1"}
67+
68+
REFERENCE_TYPES = [
69+
pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, id="logical_date"),
70+
pytest.param(DeadlineReference.DAGRUN_QUEUED_AT, id="queued_at"),
71+
pytest.param(DeadlineReference.FIXED_DATETIME(DEFAULT_DATE), id="fixed_deadline"),
72+
]
73+
74+
75+
def test_callback_for_deadline():
76+
"""Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly."""
77+
pass
78+
6079

6180
def test_recursive_serialize_calls_must_forward_kwargs():
6281
"""Any time we recurse cls.serialize, we must forward all kwargs."""
@@ -315,6 +334,16 @@ def __len__(self) -> int:
315334
DAT.DAG,
316335
lambda _, b: list(b.task_group.children.keys()) == sorted(b.task_group.children.keys()),
317336
),
337+
(
338+
DeadlineAlert(
339+
reference=DeadlineReference.DAGRUN_QUEUED_AT,
340+
interval=timedelta(hours=1),
341+
callback="valid.callback.path",
342+
callback_kwargs={"arg1": "value1"},
343+
),
344+
DAT.DEADLINE_ALERT,
345+
equals,
346+
),
318347
],
319348
)
320349
def test_serialize_deserialize(input, encoded_type, cmp_func):
@@ -323,8 +352,8 @@ def test_serialize_deserialize(input, encoded_type, cmp_func):
323352
serialized = BaseSerialization.serialize(input) # does not raise
324353
json.dumps(serialized) # does not raise
325354
if encoded_type is not None:
326-
assert serialized["__type"] == encoded_type
327-
assert serialized["__var"] is not None
355+
assert serialized[Encoding.TYPE] == encoded_type
356+
assert serialized[Encoding.VAR] is not None
328357
if cmp_func is not None:
329358
deserialized = BaseSerialization.deserialize(serialized)
330359
assert cmp_func(input, deserialized)
@@ -336,6 +365,30 @@ def test_serialize_deserialize(input, encoded_type, cmp_func):
336365
json.dumps(serialized) # does not raise
337366

338367

368+
@pytest.mark.parametrize("reference", REFERENCE_TYPES)
369+
def test_serialize_deserialize_deadline_alert(reference):
370+
public_deadline_alert_fields = {
371+
field.lower() for field in vars(DeadlineAlertFields) if not field.startswith("_")
372+
}
373+
original = DeadlineAlert(
374+
reference=reference,
375+
interval=timedelta(hours=1),
376+
callback=test_callback_for_deadline,
377+
callback_kwargs=TEST_CALLBACK_KWARGS,
378+
)
379+
380+
serialized = original.serialize_deadline_alert()
381+
assert serialized[Encoding.TYPE] == DAT.DEADLINE_ALERT
382+
assert set(serialized[Encoding.VAR].keys()) == public_deadline_alert_fields
383+
384+
deserialized = DeadlineAlert.deserialize_deadline_alert(serialized)
385+
assert deserialized.reference.serialize_reference() == reference.serialize_reference()
386+
assert deserialized.interval == original.interval
387+
assert deserialized.callback_kwargs == original.callback_kwargs
388+
assert isinstance(deserialized.callback, str)
389+
assert deserialized.callback == TEST_CALLBACK_PATH
390+
391+
339392
@pytest.mark.parametrize(
340393
"conn_uri",
341394
[

‎task-sdk/src/airflow/sdk/definitions/deadline.py

Lines changed: 70 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from datetime import datetime, timedelta
2121
from typing import TYPE_CHECKING, Callable
2222

23+
from airflow.models.deadline import ReferenceModels
24+
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
2325
from airflow.utils.module_loading import import_string, is_valid_dotpath
2426

2527
if TYPE_CHECKING:
@@ -28,6 +30,20 @@
2830
logger = logging.getLogger(__name__)
2931

3032

33+
class DeadlineAlertFields:
34+
"""
35+
Define field names used in DeadlineAlert serialization/deserialization.
36+
37+
These constants provide a single source of truth for the field names used when
38+
serializing DeadlineAlert instances to and from their dictionary representation.
39+
"""
40+
41+
REFERENCE = "reference"
42+
INTERVAL = "interval"
43+
CALLBACK = "callback"
44+
CALLBACK_KWARGS = "callback_kwargs"
45+
46+
3147
class DeadlineAlert:
3248
"""Store Deadline values needed to calculate the need-by timestamp and the callback information."""
3349

@@ -40,12 +56,35 @@ def __init__(
4056
):
4157
self.reference = reference
4258
self.interval = interval
43-
self.callback_kwargs = callback_kwargs
59+
self.callback_kwargs = callback_kwargs or {}
4460
self.callback = self.get_callback_path(callback)
4561

62+
def __eq__(self, other: object) -> bool:
63+
if not isinstance(other, DeadlineAlert):
64+
return NotImplemented
65+
return (
66+
isinstance(self.reference, type(other.reference))
67+
and self.interval == other.interval
68+
and self.callback == other.callback
69+
and self.callback_kwargs == other.callback_kwargs
70+
)
71+
72+
def __hash__(self) -> int:
73+
return hash(
74+
(
75+
type(self.reference).__name__,
76+
self.interval,
77+
self.callback,
78+
tuple(sorted(self.callback_kwargs.items())) if self.callback_kwargs else None,
79+
)
80+
)
81+
4682
@staticmethod
4783
def get_callback_path(_callback: str | Callable) -> str:
84+
"""Convert callback to a string path that can be used to import it later."""
4885
if callable(_callback):
86+
# TODO: This implementation doesn't support using a lambda function as a callback.
87+
# We should consider that in the future, but the addition is non-trivial.
4988
# Get the reference path to the callable in the form `airflow.models.deadline.get_from_db`
5089
return f"{_callback.__module__}.{_callback.__qualname__}"
5190

@@ -57,6 +96,9 @@ def get_callback_path(_callback: str | Callable) -> str:
5796
try:
5897
# The provided callback is a string which appears to be a valid dotpath, attempt to import it.
5998
callback = import_string(stripped_callback)
99+
if not callable(callback):
100+
# The input is a string which can be imported, but is not callable.
101+
raise AttributeError(f"Provided callback {callback} is not callable.")
60102
except ImportError as e:
61103
# Logging here instead of failing because it is possible that the code for the callable
62104
# exists somewhere other than on the DAG processor. We are making a best effort to validate,
@@ -66,24 +108,37 @@ def get_callback_path(_callback: str | Callable) -> str:
66108
stripped_callback,
67109
e,
68110
)
69-
return stripped_callback
70-
71-
# If we get this far then the input is a string which can be imported, check if it is a callable.
72-
if not callable(callback):
73-
raise AttributeError(f"Provided callback {callback} is not callable.")
74111

75112
return stripped_callback
76113

77114
def serialize_deadline_alert(self):
78-
from airflow.serialization.serialized_objects import BaseSerialization
79-
80-
return BaseSerialization.serialize(
81-
{
82-
"reference": self.reference,
83-
"interval": self.interval,
84-
"callback": self.callback,
85-
"callback_kwargs": self.callback_kwargs,
86-
}
115+
"""Return the data in a format that BaseSerialization can handle."""
116+
return {
117+
Encoding.TYPE: DAT.DEADLINE_ALERT,
118+
Encoding.VAR: {
119+
DeadlineAlertFields.REFERENCE: self.reference.serialize_reference(),
120+
DeadlineAlertFields.INTERVAL: self.interval.total_seconds(),
121+
DeadlineAlertFields.CALLBACK: self.callback, # Already stored as a string path
122+
DeadlineAlertFields.CALLBACK_KWARGS: self.callback_kwargs,
123+
},
124+
}
125+
126+
@classmethod
127+
def deserialize_deadline_alert(cls, encoded_data: dict) -> DeadlineAlert:
128+
"""Deserialize a DeadlineAlert from serialized data."""
129+
data = encoded_data.get(Encoding.VAR, encoded_data)
130+
131+
reference_data = data[DeadlineAlertFields.REFERENCE]
132+
reference_type = reference_data[ReferenceModels.REFERENCE_TYPE_FIELD]
133+
134+
reference_class = ReferenceModels.get_reference_class(reference_type)
135+
reference = reference_class.deserialize_reference(reference_data)
136+
137+
return cls(
138+
reference=reference,
139+
interval=timedelta(seconds=data[DeadlineAlertFields.INTERVAL]),
140+
callback=data[DeadlineAlertFields.CALLBACK], # Keep as string path
141+
callback_kwargs=data[DeadlineAlertFields.CALLBACK_KWARGS],
87142
)
88143

89144

‎task-sdk/tests/task_sdk/definitions/test_deadline.py

Lines changed: 118 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,22 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
from datetime import datetime, timedelta
1920
from unittest import mock
2021

2122
import pytest
22-
from task_sdk.definitions.test_dag import DEFAULT_DATE
2323

2424
from airflow.models.deadline import ReferenceModels
2525
from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference
2626

27+
UNIMPORTABLE_DOT_PATH = "valid.but.nonexistent.path"
28+
2729
DAG_ID = "dag_id_1"
30+
RUN_ID = 1
31+
DEFAULT_DATE = datetime(2025, 6, 26)
2832

29-
TEST_CALLBACK_PATH = f"{__name__}.test_callback"
30-
UNIMPORTABLE_DOT_PATH = "valid.but.nonexistent.path"
33+
TEST_CALLBACK_PATH = f"{__name__}.test_callback_for_deadline"
34+
TEST_CALLBACK_KWARGS = {"arg1": "value1"}
3135

3236
REFERENCE_TYPES = [
3337
pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, id="logical_date"),
@@ -36,16 +40,16 @@
3640
]
3741

3842

39-
def test_callback():
40-
"""An empty Callable to use for the callback tests in this suite."""
43+
def test_callback_for_deadline():
44+
"""Used in a number of tests to confirm that Deadlines and DeadlineAlerts function correctly."""
4145
pass
4246

4347

4448
class TestDeadlineAlert:
4549
@pytest.mark.parametrize(
4650
"callback_value, expected_path",
4751
[
48-
pytest.param(test_callback, TEST_CALLBACK_PATH, id="valid_callable"),
52+
pytest.param(test_callback_for_deadline, TEST_CALLBACK_PATH, id="valid_callable"),
4953
pytest.param(TEST_CALLBACK_PATH, TEST_CALLBACK_PATH, id="valid_path_string"),
5054
pytest.param(lambda x: x, None, id="lambda_function"),
5155
pytest.param(TEST_CALLBACK_PATH + " ", TEST_CALLBACK_PATH, id="path_with_whitespace"),
@@ -77,6 +81,114 @@ def test_get_callback_path_error_cases(self, callback_value, error_type):
7781
with pytest.raises(error_type, match=expected_message):
7882
DeadlineAlert.get_callback_path(callback_value)
7983

84+
@pytest.mark.parametrize(
85+
"test_alert, should_equal",
86+
[
87+
pytest.param(
88+
DeadlineAlert(
89+
reference=DeadlineReference.DAGRUN_QUEUED_AT,
90+
interval=timedelta(hours=1),
91+
callback=TEST_CALLBACK_PATH,
92+
callback_kwargs=TEST_CALLBACK_KWARGS,
93+
),
94+
True,
95+
id="same_alert",
96+
),
97+
pytest.param(
98+
DeadlineAlert(
99+
reference=DeadlineReference.DAGRUN_LOGICAL_DATE,
100+
interval=timedelta(hours=1),
101+
callback=TEST_CALLBACK_PATH,
102+
callback_kwargs=TEST_CALLBACK_KWARGS,
103+
),
104+
False,
105+
id="different_reference",
106+
),
107+
pytest.param(
108+
DeadlineAlert(
109+
reference=DeadlineReference.DAGRUN_QUEUED_AT,
110+
interval=timedelta(hours=2),
111+
callback=TEST_CALLBACK_PATH,
112+
callback_kwargs=TEST_CALLBACK_KWARGS,
113+
),
114+
False,
115+
id="different_interval",
116+
),
117+
pytest.param(
118+
DeadlineAlert(
119+
reference=DeadlineReference.DAGRUN_QUEUED_AT,
120+
interval=timedelta(hours=1),
121+
callback="other.callback",
122+
callback_kwargs=TEST_CALLBACK_KWARGS,
123+
),
124+
False,
125+
id="different_callback",
126+
),
127+
pytest.param(
128+
DeadlineAlert(
129+
reference=DeadlineReference.DAGRUN_QUEUED_AT,
130+
interval=timedelta(hours=1),
131+
callback=TEST_CALLBACK_PATH,
132+
callback_kwargs={"arg2": "value2"},
133+
),
134+
False,
135+
id="different_kwargs",
136+
),
137+
pytest.param("not a DeadlineAlert", False, id="non_deadline_alert"),
138+
],
139+
)
140+
def test_deadline_alert_equality(self, test_alert, should_equal):
141+
base_alert = DeadlineAlert(
142+
reference=DeadlineReference.DAGRUN_QUEUED_AT,
143+
interval=timedelta(hours=1),
144+
callback=TEST_CALLBACK_PATH,
145+
callback_kwargs=TEST_CALLBACK_KWARGS,
146+
)
147+
148+
assert (base_alert == test_alert) == should_equal
149+
150+
def test_deadline_alert_hash(self):
151+
std_interval = timedelta(hours=1)
152+
std_callback = TEST_CALLBACK_PATH
153+
std_kwargs = TEST_CALLBACK_KWARGS
154+
155+
alert1 = DeadlineAlert(
156+
reference=DeadlineReference.DAGRUN_QUEUED_AT,
157+
interval=std_interval,
158+
callback=std_callback,
159+
callback_kwargs=std_kwargs,
160+
)
161+
alert2 = DeadlineAlert(
162+
reference=DeadlineReference.DAGRUN_QUEUED_AT,
163+
interval=std_interval,
164+
callback=std_callback,
165+
callback_kwargs=std_kwargs,
166+
)
167+
168+
assert hash(alert1) == hash(alert1)
169+
assert hash(alert1) == hash(alert2)
170+
171+
def test_deadline_alert_in_set(self):
172+
std_interval = timedelta(hours=1)
173+
std_callback = TEST_CALLBACK_PATH
174+
std_kwargs = TEST_CALLBACK_KWARGS
175+
176+
alert1 = DeadlineAlert(
177+
reference=DeadlineReference.DAGRUN_QUEUED_AT,
178+
interval=std_interval,
179+
callback=std_callback,
180+
callback_kwargs=std_kwargs,
181+
)
182+
alert2 = DeadlineAlert(
183+
reference=DeadlineReference.DAGRUN_QUEUED_AT,
184+
interval=std_interval,
185+
callback=std_callback,
186+
callback_kwargs=std_kwargs,
187+
)
188+
189+
alert_set = {alert1, alert2}
190+
assert len(alert_set) == 1
191+
80192

81193
class TestDeadlineReference:
82194
@pytest.mark.parametrize("reference", REFERENCE_TYPES)

0 commit comments

Comments
 (0)
Please sign in to comment.