Skip to content

Commit 2a6e549

Browse files
uranusjr authored and HsiuChuanHsu committed
Add endpoint to watch dag run until finish (apache#51920)
1 parent ee98e89 commit 2a6e549

File tree

12 files changed

+573
-31
lines changed

12 files changed

+573
-31
lines changed

airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2147,6 +2147,90 @@ paths:
21472147
application/json:
21482148
schema:
21492149
$ref: '#/components/schemas/HTTPValidationError'
2150+
/api/v2/dags/{dag_id}/dagRuns/{dag_run_id}/wait:
2151+
get:
2152+
tags:
2153+
- DagRun
2154+
- experimental
2155+
summary: 'Experimental: Wait for a dag run to complete, and return task results
2156+
if requested.'
2157+
description: "\U0001F6A7 This is an experimental endpoint and may change or\
2158+
\ be removed without notice."
2159+
operationId: wait_dag_run_until_finished
2160+
security:
2161+
- OAuth2PasswordBearer: []
2162+
parameters:
2163+
- name: dag_id
2164+
in: path
2165+
required: true
2166+
schema:
2167+
type: string
2168+
title: Dag Id
2169+
- name: dag_run_id
2170+
in: path
2171+
required: true
2172+
schema:
2173+
type: string
2174+
title: Dag Run Id
2175+
- name: interval
2176+
in: query
2177+
required: true
2178+
schema:
2179+
type: number
2180+
exclusiveMinimum: 0.0
2181+
description: Seconds to wait between dag run state checks
2182+
title: Interval
2183+
description: Seconds to wait between dag run state checks
2184+
- name: result
2185+
in: query
2186+
required: false
2187+
schema:
2188+
anyOf:
2189+
- type: array
2190+
items:
2191+
type: string
2192+
- type: 'null'
2193+
description: Collect result XCom from task. Can be set multiple times.
2194+
title: Result
2195+
description: Collect result XCom from task. Can be set multiple times.
2196+
responses:
2197+
'200':
2198+
description: Successful Response
2199+
content:
2200+
application/json:
2201+
schema: {}
2202+
application/x-ndjson:
2203+
schema:
2204+
type: string
2205+
example: '{"state": "running"}
2206+
2207+
{"state": "success", "results": {"op": 42}}
2208+
2209+
'
2210+
'401':
2211+
content:
2212+
application/json:
2213+
schema:
2214+
$ref: '#/components/schemas/HTTPExceptionResponse'
2215+
description: Unauthorized
2216+
'403':
2217+
content:
2218+
application/json:
2219+
schema:
2220+
$ref: '#/components/schemas/HTTPExceptionResponse'
2221+
description: Forbidden
2222+
'404':
2223+
content:
2224+
application/json:
2225+
schema:
2226+
$ref: '#/components/schemas/HTTPExceptionResponse'
2227+
description: Not Found
2228+
'422':
2229+
description: Validation Error
2230+
content:
2231+
application/json:
2232+
schema:
2233+
$ref: '#/components/schemas/HTTPValidationError'
21502234
/api/v2/dags/{dag_id}/dagRuns/list:
21512235
post:
21522236
tags:

airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717

1818
from __future__ import annotations
1919

20+
import textwrap
2021
from typing import Annotated, Literal, cast
2122

2223
import structlog
2324
from fastapi import Depends, HTTPException, Query, status
2425
from fastapi.exceptions import RequestValidationError
26+
from fastapi.responses import StreamingResponse
2527
from pydantic import ValidationError
2628
from sqlalchemy import select
2729
from sqlalchemy.orm import joinedload
@@ -51,6 +53,7 @@
5153
search_param_factory,
5254
)
5355
from airflow.api_fastapi.common.router import AirflowRouter
56+
from airflow.api_fastapi.common.types import Mimetype
5457
from airflow.api_fastapi.core_api.datamodels.assets import AssetEventCollectionResponse
5558
from airflow.api_fastapi.core_api.datamodels.dag_run import (
5659
DAGRunClearBody,
@@ -72,6 +75,7 @@
7275
requires_access_asset,
7376
requires_access_dag,
7477
)
78+
from airflow.api_fastapi.core_api.services.public.dag_run import DagRunWaiter
7579
from airflow.api_fastapi.logging.decorators import action_logging
7680
from airflow.exceptions import ParamValidationError
7781
from airflow.listeners.listener import get_listener_manager
@@ -438,6 +442,57 @@ def trigger_dag_run(
438442
raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e))
439443

440444

445+
@dag_run_router.get(
    "/{dag_run_id}/wait",
    tags=["experimental"],
    summary="Experimental: Wait for a dag run to complete, and return task results if requested.",
    description="🚧 This is an experimental endpoint and may change or be removed without notice.",
    responses={
        **create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
        status.HTTP_200_OK: {
            "description": "Successful Response",
            "content": {
                Mimetype.NDJSON: {
                    "schema": {
                        "type": "string",
                        "example": textwrap.dedent(
                            """\
                            {"state": "running"}
                            {"state": "success", "results": {"op": 42}}
                            """
                        ),
                    }
                }
            },
        },
    },
    dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.RUN))],
)
def wait_dag_run_until_finished(
    dag_id: str,
    dag_run_id: str,
    session: SessionDep,
    interval: Annotated[float, Query(gt=0.0, description="Seconds to wait between dag run state checks")],
    result_task_ids: Annotated[
        list[str] | None,
        Query(alias="result", description="Collect result XCom from task. Can be set multiple times."),
    ] = None,
):
    """
    Wait for a dag run until it finishes, and return its result(s).

    The response is streamed as newline-delimited JSON: one line is emitted
    per state check performed by ``DagRunWaiter.wait()``.

    :param dag_id: ID of the dag the run belongs to.
    :param dag_run_id: Run ID of the dag run to wait for.
    :param session: Database session used for the up-front existence check.
    :param interval: Seconds to wait between dag run state checks; must be positive.
    :param result_task_ids: Task IDs whose return-value XComs should be collected
        (exposed as the repeatable ``result`` query parameter).
    :raises HTTPException: 404 if no dag run matches ``dag_id`` and ``dag_run_id``.
    """
    # Fail fast with a proper 404 before streaming starts; once the streaming
    # response has begun, the status code can no longer be changed.
    if not session.scalar(select(1).where(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)):
        raise HTTPException(
            status.HTTP_404_NOT_FOUND,
            f"The DagRun with dag_id: `{dag_id}` and run_id: `{dag_run_id}` was not found",
        )
    waiter = DagRunWaiter(
        dag_id=dag_id,
        run_id=dag_run_id,
        interval=interval,
        result_task_ids=result_task_ids,
    )
    # Declare the content type that the OpenAPI metadata above documents for
    # the 200 response, so clients can rely on the Content-Type header.
    return StreamingResponse(waiter.wait(), media_type=Mimetype.NDJSON)
494+
495+
441496
@dag_run_router.post(
442497
"/list",
443498
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
import asyncio
21+
import itertools
22+
import json
23+
import operator
24+
from typing import TYPE_CHECKING, Any
25+
26+
import attrs
27+
from sqlalchemy import select
28+
29+
from airflow.models.dagrun import DagRun
30+
from airflow.models.xcom import XCOM_RETURN_KEY, XComModel
31+
from airflow.utils.session import create_session_async
32+
from airflow.utils.state import State
33+
34+
if TYPE_CHECKING:
35+
from collections.abc import AsyncGenerator, Iterator
36+
37+
38+
@attrs.define
class DagRunWaiter:
    """
    Wait for the specified dag run to finish, and collect info from it.

    ``wait()`` is an async generator producing newline-delimited JSON: one
    JSON document followed by a ``"\\n"`` chunk per state check.
    """

    # Identifiers of the dag run being watched.
    dag_id: str
    run_id: str
    # Seconds to sleep between two consecutive state checks.
    interval: float
    # Task IDs whose return-value XComs are reported once the run finishes;
    # a falsy value (None or empty list) disables result collection.
    result_task_ids: list[str] | None

    async def _get_dag_run(self) -> DagRun | None:
        """Fetch the watched dag run in a fresh async session; None if no row matches."""
        async with create_session_async() as session:
            return await session.scalar(select(DagRun).filter_by(dag_id=self.dag_id, run_id=self.run_id))

    def _serialize_xcoms(self) -> dict[str, Any]:
        """
        Collect the return-value XComs of the requested tasks, keyed by task ID.

        A task with a single XCom whose ``map_index`` is negative maps to the
        bare value; otherwise the task maps to a list of values ordered by
        ``map_index``.
        """
        # NOTE(review): this looks like a synchronous query executed from the
        # async ``wait()`` generator -- confirm it cannot stall the event loop.
        xcom_query = XComModel.get_many(
            run_id=self.run_id,
            key=XCOM_RETURN_KEY,
            task_ids=self.result_task_ids,
            dag_ids=self.dag_id,
        )
        # groupby() only merges adjacent rows, so the rows must be ordered by task_id.
        xcom_query = xcom_query.order_by(XComModel.task_id, XComModel.map_index)

        def _group_xcoms(g: Iterator[XComModel]) -> Any:
            entries = list(g)
            if len(entries) == 1 and entries[0].map_index < 0:  # Unpack non-mapped task xcom.
                return entries[0].value
            return [entry.value for entry in entries]  # Task is mapped; return all xcoms in a list.

        return {
            task_id: _group_xcoms(g)
            for task_id, g in itertools.groupby(xcom_query, key=operator.attrgetter("task_id"))
        }

    def _serialize_response(self, dag_run: DagRun) -> str:
        """Render one JSON line: the run state, plus ``results`` once finished and requested."""
        resp: dict[str, Any] = {"state": dag_run.state}
        if dag_run.state in State.finished_dr_states and self.result_task_ids:
            resp["results"] = self._serialize_xcoms()
        return json.dumps(resp)

    async def wait(self) -> AsyncGenerator[str, None]:
        """
        Yield the serialized run state after every check until the run finishes.

        The first check happens immediately; subsequent checks are spaced by
        ``interval`` seconds. The stream ends after the line reporting a
        finished state, or without a final line if the run can no longer be
        found.
        """
        while True:
            dag_run = await self._get_dag_run()
            if dag_run is None:
                # The run disappeared (e.g. deleted while we were waiting).
                # End the stream cleanly instead of failing on ``None.state``.
                return
            yield self._serialize_response(dag_run)
            yield "\n"
            if dag_run.state in State.finished_dr_states:
                return
            await asyncio.sleep(self.interval)

airflow-core/src/airflow/settings.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import json
2323
import logging
2424
import os
25-
import platform
2625
import sys
2726
import warnings
2827
from collections.abc import Callable
@@ -321,6 +320,20 @@ def _is_sqlite_db_path_relative(sqla_conn_str: str) -> bool:
321320
return True
322321

323322

323+
def _configure_async_session():
    """Create the module-level async engine and the ``AsyncSession`` factory bound to it."""
    global AsyncSession, async_engine

    async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True)
    AsyncSession = sessionmaker(
        class_=SAAsyncSession,
        bind=async_engine,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
335+
336+
324337
def configure_orm(disable_connection_pool=False, pool_class=None):
325338
"""Configure ORM using SQLAlchemy."""
326339
from airflow.sdk.execution_time.secrets_masker import mask_secret
@@ -335,8 +348,6 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
335348

336349
global Session
337350
global engine
338-
global async_engine
339-
global AsyncSession
340351
global NonScopedSession
341352

342353
if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
@@ -359,34 +370,24 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
359370
connect_args["check_same_thread"] = False
360371

361372
engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True)
362-
async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True)
363-
AsyncSession = sessionmaker(
364-
bind=async_engine,
365-
autocommit=False,
366-
autoflush=False,
367-
class_=SAAsyncSession,
368-
expire_on_commit=False,
369-
)
370373
mask_secret(engine.url.password)
371-
372374
setup_event_handlers(engine)
373375

374376
if conf.has_option("database", "sql_alchemy_session_maker"):
375377
_session_maker = conf.getimport("database", "sql_alchemy_session_maker")
376378
else:
377-
378-
def _session_maker(_engine):
379-
return sessionmaker(
380-
autocommit=False,
381-
autoflush=False,
382-
bind=_engine,
383-
expire_on_commit=False,
384-
)
385-
379+
_session_maker = functools.partial(
380+
sessionmaker,
381+
autocommit=False,
382+
autoflush=False,
383+
expire_on_commit=False,
384+
)
386385
NonScopedSession = _session_maker(engine)
387386
Session = scoped_session(NonScopedSession)
388387

389-
if not platform.system() == "Windows":
388+
_configure_async_session()
389+
390+
if register_at_fork := getattr(os, "register_at_fork", None):
390391
# https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
391392
def clean_in_fork():
392393
_globals = globals()
@@ -396,7 +397,7 @@ def clean_in_fork():
396397
async_engine.sync_engine.dispose(close=False)
397398

398399
# Won't work on Windows
399-
os.register_at_fork(after_in_child=clean_in_fork)
400+
register_at_fork(after_in_child=clean_in_fork)
400401

401402

402403
DEFAULT_ENGINE_ARGS = {

airflow-core/src/airflow/ui/openapi-gen/queries/common.ts

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// generated with @7nohe/[email protected]
22

33
import { UseQueryResult } from "@tanstack/react-query";
4-
import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExtraLinksService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen";
4+
import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen";
55
import { DagRunState, DagWarningType } from "../requests/types.gen";
66
export type AssetServiceGetAssetsDefaultResponse = Awaited<ReturnType<typeof AssetService.getAssets>>;
77
export type AssetServiceGetAssetsQueryResult<TData = AssetServiceGetAssetsDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>;
@@ -159,6 +159,24 @@ export const UseDagRunServiceGetDagRunsKeyFn = ({ dagId, endDateGte, endDateLte,
159159
updatedAtGte?: string;
160160
updatedAtLte?: string;
161161
}, queryKey?: Array<unknown>) => [useDagRunServiceGetDagRunsKey, ...(queryKey ?? [{ dagId, endDateGte, endDateLte, limit, logicalDateGte, logicalDateLte, offset, orderBy, runAfterGte, runAfterLte, runIdPattern, runType, startDateGte, startDateLte, state, updatedAtGte, updatedAtLte }])];
162+
export type DagRunServiceWaitDagRunUntilFinishedDefaultResponse = Awaited<ReturnType<typeof DagRunService.waitDagRunUntilFinished>>;
163+
export type DagRunServiceWaitDagRunUntilFinishedQueryResult<TData = DagRunServiceWaitDagRunUntilFinishedDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>;
164+
export const useDagRunServiceWaitDagRunUntilFinishedKey = "DagRunServiceWaitDagRunUntilFinished";
165+
export const UseDagRunServiceWaitDagRunUntilFinishedKeyFn = ({ dagId, dagRunId, interval, result }: {
166+
dagId: string;
167+
dagRunId: string;
168+
interval: number;
169+
result?: string[];
170+
}, queryKey?: Array<unknown>) => [useDagRunServiceWaitDagRunUntilFinishedKey, ...(queryKey ?? [{ dagId, dagRunId, interval, result }])];
171+
export type ExperimentalServiceWaitDagRunUntilFinishedDefaultResponse = Awaited<ReturnType<typeof ExperimentalService.waitDagRunUntilFinished>>;
172+
export type ExperimentalServiceWaitDagRunUntilFinishedQueryResult<TData = ExperimentalServiceWaitDagRunUntilFinishedDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>;
173+
export const useExperimentalServiceWaitDagRunUntilFinishedKey = "ExperimentalServiceWaitDagRunUntilFinished";
174+
export const UseExperimentalServiceWaitDagRunUntilFinishedKeyFn = ({ dagId, dagRunId, interval, result }: {
175+
dagId: string;
176+
dagRunId: string;
177+
interval: number;
178+
result?: string[];
179+
}, queryKey?: Array<unknown>) => [useExperimentalServiceWaitDagRunUntilFinishedKey, ...(queryKey ?? [{ dagId, dagRunId, interval, result }])];
162180
export type DagSourceServiceGetDagSourceDefaultResponse = Awaited<ReturnType<typeof DagSourceService.getDagSource>>;
163181
export type DagSourceServiceGetDagSourceQueryResult<TData = DagSourceServiceGetDagSourceDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>;
164182
export const useDagSourceServiceGetDagSourceKey = "DagSourceServiceGetDagSource";

0 commit comments

Comments
 (0)