Skip to content

Commit 81ac13d

Browse files
pierrejeambrunsanederchik
authored andcommitted
Add Patch Task Instance Summary (apache#50526)
* Add Patch Task Instance Summary * Test small adjustment
1 parent d08b74d commit 81ac13d

File tree

2 files changed

+178
-78
lines changed

2 files changed

+178
-78
lines changed

airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from fastapi.exceptions import RequestValidationError
2525
from pydantic import ValidationError
2626
from sqlalchemy import or_, select
27-
from sqlalchemy.exc import MultipleResultsFound
2827
from sqlalchemy.orm import joinedload
2928
from sqlalchemy.sql.selectable import Select
3029

@@ -736,9 +735,9 @@ def _patch_ti_validate_request(
736735
dag_bag: DagBagDep,
737736
body: PatchTaskInstanceBody,
738737
session: SessionDep,
739-
map_index: int = -1,
738+
map_index: int | None = -1,
740739
update_mask: list[str] | None = Query(None),
741-
) -> tuple[DAG, TI, dict]:
740+
) -> tuple[DAG, list[TI], dict]:
742741
dag = dag_bag.get_dag(dag_id)
743742
if not dag:
744743
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not found")
@@ -752,20 +751,15 @@ def _patch_ti_validate_request(
752751
.join(TI.dag_run)
753752
.options(joinedload(TI.rendered_task_instance_fields))
754753
)
755-
query = query.where(TI.map_index == map_index)
754+
if map_index is not None:
755+
query = query.where(TI.map_index == map_index)
756756

757-
try:
758-
ti = session.scalar(query)
759-
except MultipleResultsFound:
760-
raise HTTPException(
761-
status.HTTP_400_BAD_REQUEST,
762-
"Multiple task instances found. As the TI is mapped, add the map_index value to the URL",
763-
)
757+
tis = session.scalars(query).all()
764758

765759
err_msg_404 = (
766760
f"The Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}` and map_index: `{map_index}` was not found",
767761
)
768-
if ti is None:
762+
if len(tis) == 0:
769763
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)
770764

771765
fields_to_update = body.model_fields_set
@@ -777,7 +771,7 @@ def _patch_ti_validate_request(
777771
except ValidationError as e:
778772
raise RequestValidationError(errors=e.errors())
779773

780-
return dag, ti, body.model_dump(include=fields_to_update, by_alias=True)
774+
return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)
781775

782776

783777
@task_instances_router.patch(
@@ -807,21 +801,16 @@ def patch_task_instance_dry_run(
807801
update_mask: list[str] | None = Query(None),
808802
) -> TaskInstanceCollectionResponse:
809803
"""Update a task instance dry_run mode."""
810-
if map_index is None:
811-
map_index = -1
812-
813-
dag, ti, data = _patch_ti_validate_request(
804+
dag, tis, data = _patch_ti_validate_request(
814805
dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask
815806
)
816807

817-
tis: list[TI] = []
818-
819808
if data.get("new_state"):
820809
tis = (
821810
dag.set_task_instance_state(
822811
task_id=task_id,
823812
run_id=dag_run_id,
824-
map_indexes=[map_index],
813+
map_indexes=[map_index] if map_index is not None else None,
825814
state=data["new_state"],
826815
upstream=body.include_upstream,
827816
downstream=body.include_downstream,
@@ -833,9 +822,6 @@ def patch_task_instance_dry_run(
833822
or []
834823
)
835824

836-
elif "note" in data:
837-
tis = [ti]
838-
839825
return TaskInstanceCollectionResponse(
840826
task_instances=[
841827
TaskInstanceResponse.model_validate(
@@ -881,19 +867,16 @@ def patch_task_instance(
881867
update_mask: list[str] | None = Query(None),
882868
) -> TaskInstanceCollectionResponse:
883869
"""Update a task instance."""
884-
if map_index is None:
885-
map_index = -1
886-
887-
dag, ti, data = _patch_ti_validate_request(
870+
dag, tis, data = _patch_ti_validate_request(
888871
dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask
889872
)
890873

891874
for key, _ in data.items():
892875
if key == "new_state":
893-
tis: list[TI] = dag.set_task_instance_state(
876+
tis = dag.set_task_instance_state(
894877
task_id=task_id,
895878
run_id=dag_run_id,
896-
map_indexes=[map_index],
879+
map_indexes=[map_index] if map_index is not None else None,
897880
state=data["new_state"],
898881
upstream=body.include_upstream,
899882
downstream=body.include_downstream,
@@ -906,37 +889,39 @@ def patch_task_instance(
906889
raise HTTPException(
907890
status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state"
908891
)
909-
ti = tis[0] if isinstance(tis, list) else tis
910-
try:
911-
if data["new_state"] == TaskInstanceState.SUCCESS:
912-
get_listener_manager().hook.on_task_instance_success(
913-
previous_state=None, task_instance=ti
914-
)
915-
elif data["new_state"] == TaskInstanceState.FAILED:
916-
get_listener_manager().hook.on_task_instance_failed(
917-
previous_state=None,
918-
task_instance=ti,
919-
error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
920-
)
921-
except Exception:
922-
log.exception("error calling listener")
892+
893+
for ti in tis:
894+
try:
895+
if data["new_state"] == TaskInstanceState.SUCCESS:
896+
get_listener_manager().hook.on_task_instance_success(
897+
previous_state=None, task_instance=ti
898+
)
899+
elif data["new_state"] == TaskInstanceState.FAILED:
900+
get_listener_manager().hook.on_task_instance_failed(
901+
previous_state=None,
902+
task_instance=ti,
903+
error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
904+
)
905+
except Exception:
906+
log.exception("error calling listener")
923907

924908
elif key == "note":
925-
if update_mask or body.note is not None:
926-
if ti.task_instance_note is None:
927-
ti.note = (body.note, user.get_id())
928-
else:
929-
ti.task_instance_note.content = body.note
930-
ti.task_instance_note.user_id = user.get_id()
931-
session.commit()
909+
for ti in tis:
910+
if update_mask or body.note is not None:
911+
if ti.task_instance_note is None:
912+
ti.note = (body.note, user.get_id())
913+
else:
914+
ti.task_instance_note.content = body.note
915+
ti.task_instance_note.user_id = user.get_id()
932916

933917
return TaskInstanceCollectionResponse(
934918
task_instances=[
935919
TaskInstanceResponse.model_validate(
936920
ti,
937921
)
922+
for ti in tis
938923
],
939-
total_entries=1,
924+
total_entries=len(tis),
940925
)
941926

942927

0 commit comments

Comments
 (0)