|
41 | 41 | from airflow.exceptions import TaskNotFound
|
42 | 42 | from airflow.models import DAG, DagRun as DR
|
43 | 43 | from airflow.models.xcom import XComModel
|
44 |
| -from airflow.settings import conf |
45 | 44 |
|
46 | 45 | xcom_router = AirflowRouter(
|
47 | 46 | tags=["XCom"], prefix="/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries"
|
@@ -69,41 +68,41 @@ def get_xcom_entry(
|
69 | 68 | stringify: Annotated[bool, Query()] = False,
|
70 | 69 | ) -> XComResponseNative | XComResponseString:
|
71 | 70 | """Get an XCom entry."""
|
72 |
| - if deserialize: |
73 |
| - if not conf.getboolean("api", "enable_xcom_deserialize_support", fallback=False): |
74 |
| - raise HTTPException( |
75 |
| - status.HTTP_400_BAD_REQUEST, "XCom deserialization is disabled in configuration." |
76 |
| - ) |
77 |
| - query = select(XComModel, XComModel.value) |
78 |
| - else: |
79 |
| - query = select(XComModel) |
80 |
| - |
81 |
| - query = query.where( |
82 |
| - XComModel.dag_id == dag_id, |
83 |
| - XComModel.task_id == task_id, |
84 |
| - XComModel.key == xcom_key, |
85 |
| - XComModel.map_index == map_index, |
| 71 | + xcom_query = XComModel.get_many( |
| 72 | + run_id=dag_run_id, |
| 73 | + key=xcom_key, |
| 74 | + task_ids=task_id, |
| 75 | + dag_ids=dag_id, |
| 76 | + map_indexes=map_index, |
| 77 | + session=session, |
| 78 | + limit=1, |
86 | 79 | )
|
87 |
| - query = query.join(DR, and_(XComModel.dag_id == DR.dag_id, XComModel.run_id == DR.run_id)) |
88 |
| - query = query.where(DR.run_id == dag_run_id) |
89 |
| - query = query.options(joinedload(XComModel.dag_run).joinedload(DR.dag_model)) |
90 | 80 |
|
91 |
| - if deserialize: |
92 |
| - item = session.execute(query).one_or_none() |
93 |
| - else: |
94 |
| - item = session.scalars(query).one_or_none() |
| 81 | + # We use `XComModel.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. |
| 82 | + # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead |
| 83 | + # retrieves the raw serialized value from the database. |
| 84 | + result = xcom_query.limit(1).first() |
95 | 85 |
|
96 |
| - if item is None: |
| 86 | + if result is None: |
97 | 87 | raise HTTPException(status.HTTP_404_NOT_FOUND, f"XCom entry with key: `{xcom_key}` not found")
|
98 | 88 |
|
| 89 | + item = copy.copy(result) |
| 90 | + |
99 | 91 | if deserialize:
|
100 |
| - from airflow.sdk.execution_time.xcom import XCom |
| 92 | + # We use `airflow.serialization.serde` for deserialization here because custom XCom backends (with their own |
| 93 | + # serializers/deserializers) are only used on the worker side during task execution. |
| 94 | + |
| 95 | + # However, the XCom value is *always* stored in the metadata database as a valid JSON object. |
| 96 | + # Therefore, for purposes such as UI display or returning API responses, deserializing with |
| 97 | + # `airflow.serialization.serde` is safe and recommended. |
| 98 | + from airflow.serialization.serde import deserialize as serde_deserialize |
101 | 99 |
|
102 |
| - xcom, value = item |
103 |
| - xcom_stub = copy.copy(xcom) |
104 |
| - xcom_stub.value = value |
105 |
| - xcom_stub.value = XCom.deserialize_value(xcom_stub) |
106 |
| - item = xcom_stub |
| 100 | + # full=False ensures that the `item` is deserialized without loading the classes, and it returns a stringified version |
| 101 | + item.value = serde_deserialize(XComModel.deserialize_value(item), full=False) |
| 102 | + else: |
| 103 | + # For native format, return the raw serialized value from the database |
| 104 | + # This preserves the JSON string format that the API expects |
| 105 | + item.value = result.value |
107 | 106 |
|
108 | 107 | if stringify:
|
109 | 108 | return XComResponseString.model_validate(item)
|
|
0 commit comments