Skip to content

Feat: Visualize DAG version changes and mixed versions in Grid view #53216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ class GridRunsResponse(BaseModel):
run_after: datetime
state: TaskInstanceState | None
run_type: DagRunType
dag_version_number: int | None = None
dag_version_id: str | None = None
is_version_changed: bool = False
has_mixed_versions: bool = False
latest_version_number: int | None = None

@computed_field
def duration(self) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@
from airflow.utils.state import TaskInstanceState


class GridTaskInstanceSummary(BaseModel):
    """
    Summary of a single task's instance data, as serialized for the Grid UI.

    Every field from ``task_id`` through ``note`` is declared without a
    default value; only the two DAG-version fields at the end default to
    ``None`` when they are not supplied.
    """

    # Which task this summary belongs to, and its attempt counter.
    task_id: str
    try_number: int

    # Timing information; each timestamp is nullable.
    start_date: datetime | None
    end_date: datetime | None
    queued_dttm: datetime | None

    # Aggregated child information.
    # NOTE(review): presumably ``child_states`` maps a state name to the number
    # of children in that state and ``task_count`` is the total number of
    # tasks summarized -- confirm against the code that builds this model.
    child_states: dict[str, int] | None
    task_count: int

    # Overall state and an optional free-text note.
    state: TaskInstanceState | None
    note: str | None

    # DAG version associated with the task instance; both default to None so
    # the model can still be built when no version information is provided.
    dag_version_id: str | None = None
    dag_version_number: int | None = None


class LightGridTaskInstanceSummary(BaseModel):
"""Task Instance Summary model for the Grid UI."""

Expand All @@ -32,6 +48,8 @@ class LightGridTaskInstanceSummary(BaseModel):
child_states: dict[TaskInstanceState | None, int] | None
min_start_date: datetime | None
max_end_date: datetime | None
dag_version_id: str | None = None
dag_version_number: int | None = None


class GridTISummaries(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1727,6 +1727,29 @@ components:
- type: 'null'
run_type:
$ref: '#/components/schemas/DagRunType'
dag_version_number:
anyOf:
- type: integer
- type: 'null'
title: Dag Version Number
dag_version_id:
anyOf:
- type: string
- type: 'null'
title: Dag Version Id
is_version_changed:
type: boolean
title: Is Version Changed
default: false
has_mixed_versions:
type: boolean
title: Has Mixed Versions
default: false
latest_version_number:
anyOf:
- type: integer
- type: 'null'
title: Latest Version Number
duration:
type: integer
title: Duration
Expand Down Expand Up @@ -1852,6 +1875,16 @@ components:
format: date-time
- type: 'null'
title: Max End Date
dag_version_id:
anyOf:
- type: string
- type: 'null'
title: Dag Version Id
dag_version_number:
anyOf:
- type: integer
- type: 'null'
title: Dag Version Number
type: object
required:
- task_id
Expand Down
113 changes: 77 additions & 36 deletions airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import structlog
from fastapi import Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.orm import joinedload, selectinload

from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
Expand All @@ -44,6 +45,7 @@
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_dag
from airflow.api_fastapi.core_api.services.ui.dag_version_service import DagVersionService
from airflow.api_fastapi.core_api.services.ui.grid import (
_find_aggregates,
_merge_node_dicts,
Expand Down Expand Up @@ -217,35 +219,74 @@ def get_grid_runs(
run_after: Annotated[RangeFilter, Depends(datetime_range_filter_factory("run_after", DagRun))],
) -> list[GridRunsResponse]:
"""Get info about a run for the grid."""
# Retrieve, sort the previous DAG Runs
base_query = select(
DagRun.dag_id,
DagRun.run_id,
DagRun.queued_at,
DagRun.start_date,
DagRun.end_date,
DagRun.run_after,
DagRun.state,
DagRun.run_type,
).where(DagRun.dag_id == dag_id)
try:
# Base query to get DagRun information with version details
base_query = (
select(DagRun).options(joinedload(DagRun.created_dag_version)).where(DagRun.dag_id == dag_id)
)

# This comparison is to fall back to DAG timetable when no order_by is provided
if order_by.value == order_by.get_primary_key_string():
latest_serdag = _get_latest_serdag(dag_id, session)
latest_dag = latest_serdag.dag
ordering = list(latest_dag.timetable.run_ordering)
order_by = SortParam(
allowed_attrs=ordering,
model=DagRun,
).set_value(ordering[0])
dag_runs_select_filter, _ = paginated_select(
statement=base_query,
order_by=order_by,
offset=offset,
filters=[run_after],
limit=limit,
)
return session.execute(dag_runs_select_filter)
# This comparison is to fall back to DAG timetable when no order_by is provided
if order_by.value == order_by.get_primary_key_string():
latest_serdag = _get_latest_serdag(dag_id, session)
latest_dag = latest_serdag.dag
ordering = list(latest_dag.timetable.run_ordering)
order_by = SortParam(
allowed_attrs=ordering,
model=DagRun,
).set_value(ordering[0])

dag_runs_select_filter, _ = paginated_select(
statement=base_query,
order_by=order_by,
offset=offset,
filters=[run_after],
limit=limit,
)

dag_runs = list(session.scalars(dag_runs_select_filter))

if not dag_runs:
return []

version_service = DagVersionService(session)
version_info_list = version_service.get_version_info_for_runs(dag_id, dag_runs)

response = []
for dag_run, version_info in zip(dag_runs, version_info_list):
grid_run = GridRunsResponse(
dag_id=dag_run.dag_id,
run_id=dag_run.run_id,
queued_at=dag_run.queued_at,
start_date=dag_run.start_date,
end_date=dag_run.end_date,
run_after=dag_run.run_after,
state=dag_run.state,
run_type=dag_run.run_type,
dag_version_number=version_info["dag_version_number"],
dag_version_id=version_info["dag_version_id"],
is_version_changed=version_info["is_version_changed"],
has_mixed_versions=version_info["has_mixed_versions"],
latest_version_number=version_info["latest_version_number"],
)
response.append(grid_run)

return response

except HTTPException:
# Re-raise HTTPException (like 404 from _get_latest_serdag) without modification
raise
except ValueError as e:
log.error("Invalid data format while retrieving grid runs", dag_id=dag_id, error=str(e))
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail={"reason": "invalid_data", "message": f"Invalid data format: {str(e)}"},
)
except Exception as e:
log.error("Unexpected error retrieving grid runs", dag_id=dag_id, error=str(e))
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"reason": "internal_error", "message": "An unexpected error occurred"},
)


@grid_router.get(
Expand Down Expand Up @@ -292,13 +333,8 @@ def get_grid_ti_summaries(
"""
tis_of_dag_runs, _ = paginated_select(
statement=(
select(
TaskInstance.task_id,
TaskInstance.state,
TaskInstance.dag_version_id,
TaskInstance.start_date,
TaskInstance.end_date,
)
select(TaskInstance)
.options(selectinload(TaskInstance.dag_version))
.where(TaskInstance.dag_id == dag_id)
.where(
TaskInstance.run_id == run_id,
Expand All @@ -309,18 +345,23 @@ def get_grid_ti_summaries(
limit=None,
return_total_entries=False,
)
task_instances = list(session.execute(tis_of_dag_runs))
task_instances = list(session.scalars(tis_of_dag_runs))
if not task_instances:
raise HTTPException(
status.HTTP_404_NOT_FOUND, f"No task instances for dag_id={dag_id} run_id={run_id}"
)
ti_details = collections.defaultdict(list)
for ti in task_instances:
dag_version_id = str(ti.dag_version_id) if ti.dag_version_id else None
dag_version_number = ti.dag_version.version_number if ti.dag_version else None

ti_details[ti.task_id].append(
{
"state": ti.state,
"start_date": ti.start_date,
"end_date": ti.end_date,
"dag_version_id": dag_version_id,
"dag_version_number": dag_version_number,
}
)
serdag = _get_serdag(
Expand Down
Loading