Skip to content

Add run_on_latest_version support for backfill and clear operations #52177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion airflow-core/src/airflow/api_fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ def create_app(apps: str = "all") -> FastAPI:
dag_bag = create_dag_bag()

if "execution" in apps_list or "all" in apps_list:
from airflow.jobs.scheduler_job_runner import SchedulerDagBag

task_exec_api_app = create_task_execution_api_app()
task_exec_api_app.state.dag_bag = dag_bag
task_exec_api_app.state.dag_bag = SchedulerDagBag()
init_error_handlers(task_exec_api_app)
app.mount("/execution", task_exec_api_app)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class DAGRunClearBody(StrictBaseModel):

dry_run: bool = True
only_failed: bool = False
run_on_latest_version: bool = False


class DAGRunResponse(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class ClearTaskInstancesBody(StrictBaseModel):
include_downstream: bool = False
include_future: bool = False
include_past: bool = False
run_on_latest_version: bool = False

@model_validator(mode="before")
@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8396,6 +8396,10 @@ components:
type: boolean
title: Include Past
default: false
run_on_latest_version:
type: boolean
title: Run On Latest Version
default: false
additionalProperties: false
type: object
title: ClearTaskInstancesBody
Expand Down Expand Up @@ -9049,6 +9053,10 @@ components:
type: boolean
title: Only Failed
default: false
run_on_latest_version:
type: boolean
title: Run On Latest Version
default: false
additionalProperties: false
type: object
title: DAGRunClearBody
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def clear_dag_run(
run_id=dag_run_id,
task_ids=None,
only_failed=body.only_failed,
run_on_latest_version=body.run_on_latest_version,
dry_run=True,
session=session,
)
Expand All @@ -293,6 +294,7 @@ def clear_dag_run(
run_id=dag_run_id,
task_ids=None,
only_failed=body.only_failed,
run_on_latest_version=body.run_on_latest_version,
session=session,
)
dag_run_cleared = session.scalar(select(DagRun).where(DagRun.id == dag_run.id))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,11 @@ def post_clear_task_instances(
if dag_run is None:
error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}"
raise HTTPException(status.HTTP_404_NOT_FOUND, error_message)
# If dag_run_id is provided, we should get the dag from SchedulerDagBag
# to ensure we get the right version.
from airflow.jobs.scheduler_job_runner import SchedulerDagBag

dag = SchedulerDagBag().get_dag(dag_run=dag_run, session=session)
if past or future:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
Expand Down Expand Up @@ -724,6 +728,7 @@ def post_clear_task_instances(
task_instances,
session,
DagRunState.QUEUED if reset_dag_runs else False,
run_on_latest_version=body.run_on_latest_version,
)

return TaskInstanceCollectionResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload
from airflow.exceptions import DagRunAlreadyExists
from airflow.models.dag import DagModel
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
from airflow.utils.types import DagRunTriggeredByType

Expand Down Expand Up @@ -122,9 +121,20 @@ def clear_dag_run(
"message": f"DAG with dag_id: '{dag_id}' has import errors and cannot be triggered",
},
)
from airflow.jobs.scheduler_job_runner import SchedulerDagBag

dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id))
dag_bag = SchedulerDagBag()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given we did task_exec_api_app.state.dag_bag = SchedulerDagBag() in the app init, why do we need to create a new object here, rather than use the existing one?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because of the circular import error, as mentioned in #52177 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This endpoint doesn't have a dagbag dependency either. If you scroll down, you can see the usage of DagBag before this change.

dag = dag_bag.get_dag(dag_run=dag_run, session=session)
if not dag:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "Not Found",
"message": f"DAG with dag_id: '{dag_id}' was not found in the DagBag",
},
)

dag_bag = DagBag(dag_folder=dm.fileloc, read_dags_from_db=True)
dag = dag_bag.get_dag(dag_id)
dag.clear(run_id=run_id)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@
import attrs
import structlog
from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Query, status
from fastapi import Body, Depends, HTTPException, Query, status
from pydantic import JsonValue
from sqlalchemy import func, or_, tuple_, update
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.orm import joinedload
from sqlalchemy.sql import select
from structlog.contextvars import bind_contextvars

from airflow.api_fastapi.common.dagbag import DagBagDep
from airflow.api_fastapi.common.dagbag import dag_bag_from_app
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
Expand Down Expand Up @@ -76,6 +76,9 @@
from airflow.models.expandinput import SchedulerExpandInput
from airflow.sdk.types import Operator

from airflow.jobs.scheduler_job_runner import SchedulerDagBag

SchedulerDagBagDep = Annotated[SchedulerDagBag, Depends(dag_bag_from_app)]

router = VersionedAPIRouter()

Expand Down Expand Up @@ -104,7 +107,7 @@ def ti_run(
task_instance_id: UUID,
ti_run_payload: Annotated[TIEnterRunningPayload, Body()],
session: SessionDep,
dag_bag: DagBagDep,
dag_bag: SchedulerDagBagDep,
) -> TIRunContext:
"""
Run a TaskInstance.
Expand Down Expand Up @@ -255,7 +258,7 @@ def ti_run(
or 0
)

if dag := dag_bag.get_dag(ti.dag_id):
if dag := dag_bag.get_dag(dag_run=dr, session=session):
upstream_map_indexes = dict(
_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index, ti.run_id, session)
)
Expand Down Expand Up @@ -330,7 +333,7 @@ def ti_update_state(
task_instance_id: UUID,
ti_patch_payload: Annotated[TIStateUpdate, Body()],
session: SessionDep,
dag_bag: DagBagDep,
dag_bag: SchedulerDagBagDep,
):
"""
Update the state of a TaskInstance.
Expand Down Expand Up @@ -417,8 +420,9 @@ def ti_update_state(
)


def _handle_fail_fast_for_dag(ti: TI, dag_id: str, session: SessionDep, dag_bag: DagBagDep) -> None:
ser_dag = dag_bag.get_dag(dag_id)
def _handle_fail_fast_for_dag(ti: TI, dag_id: str, session: SessionDep, dag_bag: SchedulerDagBagDep) -> None:
dr = ti.dag_run
ser_dag = dag_bag.get_dag(dag_run=dr, session=session)
if ser_dag and getattr(ser_dag, "fail_fast", False):
task_dict = getattr(ser_dag, "task_dict")
task_teardown_map = {k: v.is_teardown for k, v in task_dict.items()}
Expand All @@ -432,7 +436,7 @@ def _create_ti_state_update_query_and_update_state(
query: Update,
updated_state,
session: SessionDep,
dag_bag: DagBagDep,
dag_bag: SchedulerDagBagDep,
dag_id: str,
) -> tuple[Update, TaskInstanceState]:
if isinstance(ti_patch_payload, (TITerminalStatePayload, TIRetryStatePayload, TISuccessStatePayload)):
Expand Down Expand Up @@ -893,7 +897,7 @@ def _get_group_tasks(dag_id: str, task_group_id: str, session: SessionDep, logic
def validate_inlets_and_outlets(
task_instance_id: UUID,
session: SessionDep,
dag_bag: DagBagDep,
dag_bag: SchedulerDagBagDep,
) -> InactiveAssetsResponse:
"""Validate whether there're inactive assets in inlets and outlets of a given task instance."""
ti_id_str = str(task_instance_id)
Expand All @@ -911,7 +915,8 @@ def validate_inlets_and_outlets(
)

if not ti.task:
dag = dag_bag.get_dag(ti.dag_id)
dr = ti.dag_run
dag = dag_bag.get_dag(dag_run=dr, session=session)
if dag:
with contextlib.suppress(TaskNotFound):
ti.task = dag.get_task(ti.task_id)
Expand Down
9 changes: 9 additions & 0 deletions airflow-core/src/airflow/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,14 @@ def string_lower_type(val):
),
choices=("none", "completed", "failed"),
)
ARG_BACKFILL_RUN_ON_LATEST_VERSION = Arg(
("--run-on-latest-version",),
help=(
"(Experimental) If set, the backfill will run tasks using the latest bundle version instead of "
"the version that was active when the original Dag run was created."
),
action="store_true",
)


# misc
Expand Down Expand Up @@ -968,6 +976,7 @@ class GroupCommand(NamedTuple):
ARG_RUN_BACKWARDS,
ARG_MAX_ACTIVE_RUNS,
ARG_BACKFILL_REPROCESS_BEHAVIOR,
ARG_BACKFILL_RUN_ON_LATEST_VERSION,
ARG_BACKFILL_DRY_RUN,
),
),
Expand Down
2 changes: 2 additions & 0 deletions airflow-core/src/airflow/cli/commands/backfill_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def create_backfill(args) -> None:
reverse=args.run_backwards,
dag_run_conf=args.dag_run_conf,
reprocess_behavior=reprocess_behavior,
run_on_latest_version=args.run_on_latest_version,
)
for k, v in params.items():
console.print(f" - {k} = {v}")
Expand Down Expand Up @@ -88,4 +89,5 @@ def create_backfill(args) -> None:
dag_run_conf=args.dag_run_conf,
triggering_user_name=user,
reprocess_behavior=reprocess_behavior,
run_on_latest_version=args.run_on_latest_version,
)
15 changes: 8 additions & 7 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,16 @@ def _get_dag(self, version_id: str, session: Session) -> DAG | None:
return dag

@staticmethod
def _version_from_dag_run(dag_run, session):
if dag_run.bundle_version:
dag_version = dag_run.created_dag_version
else:
def _version_from_dag_run(dag_run, latest, session):
if latest or not dag_run.bundle_version:
dag_version = DagVersion.get_latest_version(dag_id=dag_run.dag_id, session=session)
return dag_version
if dag_version:
return dag_version

return dag_run.created_dag_version

def get_dag(self, dag_run: DagRun, session: Session) -> DAG | None:
version = self._version_from_dag_run(dag_run=dag_run, session=session)
def get_dag(self, dag_run: DagRun, session: Session, latest=False) -> DAG | None:
version = self._version_from_dag_run(dag_run=dag_run, latest=latest, session=session)
if not version:
return None
return self._get_dag(version_id=version.id, session=session)
Expand Down
7 changes: 6 additions & 1 deletion airflow-core/src/airflow/models/backfill.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def _create_backfill_dag_run(
dag_run_conf,
backfill_sort_ordinal,
triggering_user_name,
run_on_latest_version,
session,
):
from airflow.models.dagrun import DagRun
Expand Down Expand Up @@ -328,6 +329,7 @@ def _create_backfill_dag_run(
info=info,
backfill_id=backfill_id,
sort_ordinal=backfill_sort_ordinal,
run_on_latest=run_on_latest_version,
)
else:
session.add(
Expand Down Expand Up @@ -401,7 +403,7 @@ def _get_info_list(
return dagrun_info_list


def _handle_clear_run(session, dag, dr, info, backfill_id, sort_ordinal):
def _handle_clear_run(session, dag, dr, info, backfill_id, sort_ordinal, run_on_latest=False):
"""Clear the existing DAG run and update backfill metadata."""
from sqlalchemy.sql import update

Expand All @@ -415,6 +417,7 @@ def _handle_clear_run(session, dag, dr, info, backfill_id, sort_ordinal):
session=session,
confirm_prompt=False,
dry_run=False,
run_on_latest_version=run_on_latest,
)

# Update backfill_id and run_type in DagRun table
Expand Down Expand Up @@ -447,6 +450,7 @@ def _create_backfill(
dag_run_conf: dict | None,
triggering_user_name: str | None,
reprocess_behavior: ReprocessBehavior | None = None,
run_on_latest_version: bool = False,
) -> Backfill | None:
from airflow.models import DagModel
from airflow.models.serialized_dag import SerializedDagModel
Expand Down Expand Up @@ -510,6 +514,7 @@ def _create_backfill(
reprocess_behavior=br.reprocess_behavior,
backfill_sort_ordinal=backfill_sort_ordinal,
triggering_user_name=br.triggering_user_name,
run_on_latest_version=run_on_latest_version,
session=session,
)
log.info(
Expand Down
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,7 @@ def clear(
only_running: bool = False,
confirm_prompt: bool = False,
dag_run_state: DagRunState = DagRunState.QUEUED,
run_on_latest_version: bool = False,
session: Session = NEW_SESSION,
dag_bag: DagBag | None = None,
exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(),
Expand All @@ -1314,6 +1315,7 @@ def clear(
confirm_prompt: bool = False,
dag_run_state: DagRunState = DagRunState.QUEUED,
dry_run: Literal[False] = False,
run_on_latest_version: bool = False,
session: Session = NEW_SESSION,
dag_bag: DagBag | None = None,
exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(),
Expand All @@ -1332,6 +1334,7 @@ def clear(
only_running: bool = False,
confirm_prompt: bool = False,
dag_run_state: DagRunState = DagRunState.QUEUED,
run_on_latest_version: bool = False,
session: Session = NEW_SESSION,
dag_bag: DagBag | None = None,
exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(),
Expand All @@ -1350,6 +1353,7 @@ def clear(
confirm_prompt: bool = False,
dag_run_state: DagRunState = DagRunState.QUEUED,
dry_run: Literal[False] = False,
run_on_latest_version: bool = False,
session: Session = NEW_SESSION,
dag_bag: DagBag | None = None,
exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(),
Expand All @@ -1369,6 +1373,7 @@ def clear(
confirm_prompt: bool = False,
dag_run_state: DagRunState = DagRunState.QUEUED,
dry_run: bool = False,
run_on_latest_version: bool = False,
session: Session = NEW_SESSION,
dag_bag: DagBag | None = None,
exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(),
Expand All @@ -1387,6 +1392,7 @@ def clear(
:param dag_run_state: state to set DagRun to. If set to False, dagrun state will not
be changed.
:param dry_run: Find the tasks to clear but don't clear them.
:param run_on_latest_version: whether to run on latest serialized DAG and Bundle version
:param session: The sqlalchemy session to use
:param dag_bag: The DagBag used to find the dags (Optional)
:param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``)
Expand Down Expand Up @@ -1432,6 +1438,7 @@ def clear(
list(tis),
session,
dag_run_state=dag_run_state,
run_on_latest_version=run_on_latest_version,
)
else:
count = 0
Expand Down
Loading