Skip to content

Add Patch Task Instance Summary #50526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError
from sqlalchemy import or_, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload
from sqlalchemy.sql.selectable import Select

Expand Down Expand Up @@ -736,9 +735,9 @@ def _patch_ti_validate_request(
dag_bag: DagBagDep,
body: PatchTaskInstanceBody,
session: SessionDep,
map_index: int = -1,
map_index: int | None = -1,
update_mask: list[str] | None = Query(None),
) -> tuple[DAG, TI, dict]:
) -> tuple[DAG, list[TI], dict]:
dag = dag_bag.get_dag(dag_id)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not found")
Expand All @@ -752,20 +751,15 @@ def _patch_ti_validate_request(
.join(TI.dag_run)
.options(joinedload(TI.rendered_task_instance_fields))
)
query = query.where(TI.map_index == map_index)
if map_index is not None:
query = query.where(TI.map_index == map_index)

try:
ti = session.scalar(query)
except MultipleResultsFound:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"Multiple task instances found. As the TI is mapped, add the map_index value to the URL",
)
tis = session.scalars(query).all()

err_msg_404 = (
f"The Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}` and map_index: `{map_index}` was not found",
)
if ti is None:
if len(tis) == 0:
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)

fields_to_update = body.model_fields_set
Expand All @@ -777,7 +771,7 @@ def _patch_ti_validate_request(
except ValidationError as e:
raise RequestValidationError(errors=e.errors())

return dag, ti, body.model_dump(include=fields_to_update, by_alias=True)
return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)


@task_instances_router.patch(
Expand Down Expand Up @@ -807,21 +801,16 @@ def patch_task_instance_dry_run(
update_mask: list[str] | None = Query(None),
) -> TaskInstanceCollectionResponse:
"""Update a task instance dry_run mode."""
if map_index is None:
map_index = -1

dag, ti, data = _patch_ti_validate_request(
dag, tis, data = _patch_ti_validate_request(
dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask
)

tis: list[TI] = []

if data.get("new_state"):
tis = (
dag.set_task_instance_state(
task_id=task_id,
run_id=dag_run_id,
map_indexes=[map_index],
map_indexes=[map_index] if map_index is not None else None,
state=data["new_state"],
upstream=body.include_upstream,
downstream=body.include_downstream,
Expand All @@ -833,9 +822,6 @@ def patch_task_instance_dry_run(
or []
)

elif "note" in data:
tis = [ti]

return TaskInstanceCollectionResponse(
task_instances=[
TaskInstanceResponse.model_validate(
Expand Down Expand Up @@ -881,19 +867,16 @@ def patch_task_instance(
update_mask: list[str] | None = Query(None),
) -> TaskInstanceCollectionResponse:
"""Update a task instance."""
if map_index is None:
map_index = -1

dag, ti, data = _patch_ti_validate_request(
dag, tis, data = _patch_ti_validate_request(
dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask
)

for key, _ in data.items():
if key == "new_state":
tis: list[TI] = dag.set_task_instance_state(
tis = dag.set_task_instance_state(
task_id=task_id,
run_id=dag_run_id,
map_indexes=[map_index],
map_indexes=[map_index] if map_index is not None else None,
state=data["new_state"],
upstream=body.include_upstream,
downstream=body.include_downstream,
Expand All @@ -906,37 +889,39 @@ def patch_task_instance(
raise HTTPException(
status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state"
)
ti = tis[0] if isinstance(tis, list) else tis
try:
if data["new_state"] == TaskInstanceState.SUCCESS:
get_listener_manager().hook.on_task_instance_success(
previous_state=None, task_instance=ti
)
elif data["new_state"] == TaskInstanceState.FAILED:
get_listener_manager().hook.on_task_instance_failed(
previous_state=None,
task_instance=ti,
error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
)
except Exception:
log.exception("error calling listener")

for ti in tis:
try:
if data["new_state"] == TaskInstanceState.SUCCESS:
get_listener_manager().hook.on_task_instance_success(
previous_state=None, task_instance=ti
)
elif data["new_state"] == TaskInstanceState.FAILED:
get_listener_manager().hook.on_task_instance_failed(
previous_state=None,
task_instance=ti,
error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
)
except Exception:
log.exception("error calling listener")

elif key == "note":
if update_mask or body.note is not None:
if ti.task_instance_note is None:
ti.note = (body.note, user.get_id())
else:
ti.task_instance_note.content = body.note
ti.task_instance_note.user_id = user.get_id()
session.commit()
for ti in tis:
if update_mask or body.note is not None:
if ti.task_instance_note is None:
ti.note = (body.note, user.get_id())
else:
ti.task_instance_note.content = body.note
ti.task_instance_note.user_id = user.get_id()

return TaskInstanceCollectionResponse(
task_instances=[
TaskInstanceResponse.model_validate(
ti,
)
for ti in tis
],
total_entries=1,
total_entries=len(tis),
)


Expand Down
Loading