Skip to content

Add endpoint to watch dag run until finish #51920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2147,6 +2147,90 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/api/v2/dags/{dag_id}/dagRuns/{dag_run_id}/wait:
get:
tags:
- DagRun
- experimental
summary: 'Experimental: Wait for a dag run to complete, and return task results
if requested.'
description: "\U0001F6A7 This is an experimental endpoint and may change or\
\ be removed without notice."
operationId: wait_dag_run_until_finished
security:
- OAuth2PasswordBearer: []
parameters:
- name: dag_id
in: path
required: true
schema:
type: string
title: Dag Id
- name: dag_run_id
in: path
required: true
schema:
type: string
title: Dag Run Id
- name: interval
in: query
required: true
schema:
type: number
exclusiveMinimum: 0.0
description: Seconds to wait between dag run state checks
title: Interval
description: Seconds to wait between dag run state checks
- name: result
in: query
required: false
schema:
anyOf:
- type: array
items:
type: string
- type: 'null'
description: Collect result XCom from task. Can be set multiple times.
title: Result
description: Collect result XCom from task. Can be set multiple times.
responses:
'200':
description: Successful Response
content:
application/json:
schema: {}
application/x-ndjson:
schema:
type: string
example: '{"state": "running"}

{"state": "success", "results": {"op": 42}}

'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/api/v2/dags/{dag_id}/dagRuns/list:
post:
tags:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

from __future__ import annotations

import textwrap
from typing import Annotated, Literal, cast

import structlog
from fastapi import Depends, HTTPException, Query, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import StreamingResponse
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import joinedload
Expand Down Expand Up @@ -51,6 +53,7 @@
search_param_factory,
)
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.common.types import Mimetype
from airflow.api_fastapi.core_api.datamodels.assets import AssetEventCollectionResponse
from airflow.api_fastapi.core_api.datamodels.dag_run import (
DAGRunClearBody,
Expand All @@ -72,6 +75,7 @@
requires_access_asset,
requires_access_dag,
)
from airflow.api_fastapi.core_api.services.public.dag_run import DagRunWaiter
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.exceptions import ParamValidationError
from airflow.listeners.listener import get_listener_manager
Expand Down Expand Up @@ -438,6 +442,57 @@ def trigger_dag_run(
raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e))


@dag_run_router.get(
    "/{dag_run_id}/wait",
    tags=["experimental"],
    summary="Experimental: Wait for a dag run to complete, and return task results if requested.",
    description="🚧 This is an experimental endpoint and may change or be removed without notice.",
    responses={
        **create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
        status.HTTP_200_OK: {
            "description": "Successful Response",
            "content": {
                Mimetype.NDJSON: {
                    "schema": {
                        "type": "string",
                        "example": textwrap.dedent(
                            """\
                            {"state": "running"}
                            {"state": "success", "results": {"op": 42}}
                            """
                        ),
                    }
                }
            },
        },
    },
    dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.RUN))],
)
def wait_dag_run_until_finished(
    dag_id: str,
    dag_run_id: str,
    session: SessionDep,
    interval: Annotated[float, Query(gt=0.0, description="Seconds to wait between dag run state checks")],
    result_task_ids: Annotated[
        list[str] | None,
        Query(alias="result", description="Collect result XCom from task. Can be set multiple times."),
    ] = None,
) -> StreamingResponse:
    """Wait for a dag run until it finishes, and stream its state as NDJSON.

    One JSON object is emitted per line while the run is in progress; the final
    line carries the terminal state and, when ``result`` task ids were given,
    the collected XCom return values.
    """
    # Cheap existence check (SELECT 1) before entering the potentially long
    # polling loop, so a bad dag_id/run_id fails fast with 404.
    if not session.scalar(select(1).where(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)):
        raise HTTPException(
            status.HTTP_404_NOT_FOUND,
            f"The DagRun with dag_id: `{dag_id}` and run_id: `{dag_run_id}` was not found",
        )
    waiter = DagRunWaiter(
        dag_id=dag_id,
        run_id=dag_run_id,
        interval=interval,
        result_task_ids=result_task_ids,
    )
    # Declare the media type explicitly so the actual HTTP response matches the
    # application/x-ndjson content type documented in the `responses` spec above.
    return StreamingResponse(waiter.wait(), media_type=Mimetype.NDJSON)


@dag_run_router.post(
"/list",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import asyncio
import itertools
import json
import operator
from typing import TYPE_CHECKING, Any

import attrs
from sqlalchemy import select

from airflow.models.dagrun import DagRun
from airflow.models.xcom import XCOM_RETURN_KEY, XComModel
from airflow.utils.session import create_session_async
from airflow.utils.state import State

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Iterator


@attrs.define
class DagRunWaiter:
    """Poll a dag run until it reaches a terminal state, streaming status lines.

    Each poll yields one JSON-encoded line describing the run's state; once the
    run finishes and ``result_task_ids`` is set, the final line also carries the
    collected XCom return values keyed by task id.
    """

    dag_id: str
    run_id: str
    interval: float
    result_task_ids: list[str] | None

    async def _get_dag_run(self) -> DagRun:
        stmt = select(DagRun).filter_by(dag_id=self.dag_id, run_id=self.run_id)
        async with create_session_async() as session:
            return await session.scalar(stmt)

    def _serialize_xcoms(self) -> dict[str, Any]:
        query = XComModel.get_many(
            run_id=self.run_id,
            key=XCOM_RETURN_KEY,
            task_ids=self.result_task_ids,
            dag_ids=self.dag_id,
        ).order_by(XComModel.task_id, XComModel.map_index)

        def _collect(entries: list[XComModel]) -> Any:
            # A single entry with a negative map index means the task is not
            # mapped; unwrap its value. Otherwise return all mapped xcoms in
            # map-index order as a list.
            if len(entries) == 1 and entries[0].map_index < 0:
                return entries[0].value
            return [entry.value for entry in entries]

        grouped = itertools.groupby(query, key=operator.attrgetter("task_id"))
        return {task_id: _collect(list(group)) for task_id, group in grouped}

    def _serialize_response(self, dag_run: DagRun) -> str:
        payload: dict[str, Any] = {"state": dag_run.state}
        # Results are only attached once the run is finished and the caller
        # asked for them.
        if dag_run.state in State.finished_dr_states and self.result_task_ids:
            payload["results"] = self._serialize_xcoms()
        return json.dumps(payload)

    async def wait(self) -> AsyncGenerator[str, None]:
        while True:
            dag_run = await self._get_dag_run()
            yield self._serialize_response(dag_run)
            yield "\n"
            if dag_run.state in State.finished_dr_states:
                return
            await asyncio.sleep(self.interval)
47 changes: 24 additions & 23 deletions airflow-core/src/airflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import json
import logging
import os
import platform
import sys
import warnings
from collections.abc import Callable
Expand Down Expand Up @@ -321,6 +320,20 @@ def _is_sqlite_db_path_relative(sqla_conn_str: str) -> bool:
return True


def _configure_async_session():
    """Initialize the module-level async engine and async session factory."""
    global async_engine, AsyncSession

    async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True)
    AsyncSession = sessionmaker(
        class_=SAAsyncSession,
        bind=async_engine,
        autoflush=False,
        autocommit=False,
        expire_on_commit=False,
    )


def configure_orm(disable_connection_pool=False, pool_class=None):
"""Configure ORM using SQLAlchemy."""
from airflow.sdk.execution_time.secrets_masker import mask_secret
Expand All @@ -335,8 +348,6 @@ def configure_orm(disable_connection_pool=False, pool_class=None):

global Session
global engine
global async_engine
global AsyncSession
global NonScopedSession

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
Expand All @@ -359,34 +370,24 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
connect_args["check_same_thread"] = False

engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True)
async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True)
AsyncSession = sessionmaker(
bind=async_engine,
autocommit=False,
autoflush=False,
class_=SAAsyncSession,
expire_on_commit=False,
)
mask_secret(engine.url.password)

setup_event_handlers(engine)

if conf.has_option("database", "sql_alchemy_session_maker"):
_session_maker = conf.getimport("database", "sql_alchemy_session_maker")
else:

def _session_maker(_engine):
return sessionmaker(
autocommit=False,
autoflush=False,
bind=_engine,
expire_on_commit=False,
)

_session_maker = functools.partial(
sessionmaker,
autocommit=False,
autoflush=False,
expire_on_commit=False,
)
NonScopedSession = _session_maker(engine)
Session = scoped_session(NonScopedSession)

if not platform.system() == "Windows":
_configure_async_session()

if register_at_fork := getattr(os, "register_at_fork", None):
# https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
def clean_in_fork():
_globals = globals()
Expand All @@ -396,7 +397,7 @@ def clean_in_fork():
async_engine.sync_engine.dispose(close=False)

# Won't work on Windows
os.register_at_fork(after_in_child=clean_in_fork)
register_at_fork(after_in_child=clean_in_fork)


DEFAULT_ENGINE_ARGS = {
Expand Down
20 changes: 19 additions & 1 deletion airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// generated with @7nohe/[email protected]

import { UseQueryResult } from "@tanstack/react-query";
import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExtraLinksService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen";
import { AssetService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagReportService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, VariableService, VersionService, XcomService } from "../requests/services.gen";
import { DagRunState, DagWarningType } from "../requests/types.gen";
export type AssetServiceGetAssetsDefaultResponse = Awaited<ReturnType<typeof AssetService.getAssets>>;
export type AssetServiceGetAssetsQueryResult<TData = AssetServiceGetAssetsDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>;
Expand Down Expand Up @@ -159,6 +159,24 @@ export const UseDagRunServiceGetDagRunsKeyFn = ({ dagId, endDateGte, endDateLte,
updatedAtGte?: string;
updatedAtLte?: string;
}, queryKey?: Array<unknown>) => [useDagRunServiceGetDagRunsKey, ...(queryKey ?? [{ dagId, endDateGte, endDateLte, limit, logicalDateGte, logicalDateLte, offset, orderBy, runAfterGte, runAfterLte, runIdPattern, runType, startDateGte, startDateLte, state, updatedAtGte, updatedAtLte }])];
// Query-key helpers for DagRunService.waitDagRunUntilFinished (experimental wait endpoint).
// NOTE(review): this file is codegen output — regenerating will drop hand edits.
export type DagRunServiceWaitDagRunUntilFinishedDefaultResponse = Awaited<ReturnType<typeof DagRunService.waitDagRunUntilFinished>>;
export type DagRunServiceWaitDagRunUntilFinishedQueryResult<TData = DagRunServiceWaitDagRunUntilFinishedDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>;
export const useDagRunServiceWaitDagRunUntilFinishedKey = "DagRunServiceWaitDagRunUntilFinished";
export const UseDagRunServiceWaitDagRunUntilFinishedKeyFn = ({ dagId, dagRunId, interval, result }: {
dagId: string;
dagRunId: string;
interval: number;
result?: string[];
}, queryKey?: Array<unknown>) => [useDagRunServiceWaitDagRunUntilFinishedKey, ...(queryKey ?? [{ dagId, dagRunId, interval, result }])];
// Query-key helpers for ExperimentalService.waitDagRunUntilFinished (same endpoint,
// exposed under the "experimental" tag). NOTE(review): codegen output — do not hand edit.
export type ExperimentalServiceWaitDagRunUntilFinishedDefaultResponse = Awaited<ReturnType<typeof ExperimentalService.waitDagRunUntilFinished>>;
export type ExperimentalServiceWaitDagRunUntilFinishedQueryResult<TData = ExperimentalServiceWaitDagRunUntilFinishedDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>;
export const useExperimentalServiceWaitDagRunUntilFinishedKey = "ExperimentalServiceWaitDagRunUntilFinished";
export const UseExperimentalServiceWaitDagRunUntilFinishedKeyFn = ({ dagId, dagRunId, interval, result }: {
dagId: string;
dagRunId: string;
interval: number;
result?: string[];
}, queryKey?: Array<unknown>) => [useExperimentalServiceWaitDagRunUntilFinishedKey, ...(queryKey ?? [{ dagId, dagRunId, interval, result }])];
export type DagSourceServiceGetDagSourceDefaultResponse = Awaited<ReturnType<typeof DagSourceService.getDagSource>>;
export type DagSourceServiceGetDagSourceQueryResult<TData = DagSourceServiceGetDagSourceDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>;
export const useDagSourceServiceGetDagSourceKey = "DagSourceServiceGetDagSource";
Expand Down
Loading