24
24
from fastapi .exceptions import RequestValidationError
25
25
from pydantic import ValidationError
26
26
from sqlalchemy import or_ , select
27
- from sqlalchemy .exc import MultipleResultsFound
28
27
from sqlalchemy .orm import joinedload
29
28
from sqlalchemy .sql .selectable import Select
30
29
@@ -736,9 +735,9 @@ def _patch_ti_validate_request(
736
735
dag_bag : DagBagDep ,
737
736
body : PatchTaskInstanceBody ,
738
737
session : SessionDep ,
739
- map_index : int = - 1 ,
738
+ map_index : int | None = - 1 ,
740
739
update_mask : list [str ] | None = Query (None ),
741
- ) -> tuple [DAG , TI , dict ]:
740
+ ) -> tuple [DAG , list [ TI ] , dict ]:
742
741
dag = dag_bag .get_dag (dag_id )
743
742
if not dag :
744
743
raise HTTPException (status .HTTP_404_NOT_FOUND , f"DAG { dag_id } not found" )
@@ -752,20 +751,15 @@ def _patch_ti_validate_request(
752
751
.join (TI .dag_run )
753
752
.options (joinedload (TI .rendered_task_instance_fields ))
754
753
)
755
- query = query .where (TI .map_index == map_index )
754
+ if map_index is not None :
755
+ query = query .where (TI .map_index == map_index )
756
756
757
- try :
758
- ti = session .scalar (query )
759
- except MultipleResultsFound :
760
- raise HTTPException (
761
- status .HTTP_400_BAD_REQUEST ,
762
- "Multiple task instances found. As the TI is mapped, add the map_index value to the URL" ,
763
- )
757
+ tis = session .scalars (query ).all ()
764
758
765
759
err_msg_404 = (
766
760
f"The Task Instance with dag_id: `{ dag_id } `, run_id: `{ dag_run_id } `, task_id: `{ task_id } ` and map_index: `{ map_index } ` was not found" ,
767
761
)
768
- if ti is None :
762
+ if len ( tis ) == 0 :
769
763
raise HTTPException (status .HTTP_404_NOT_FOUND , err_msg_404 )
770
764
771
765
fields_to_update = body .model_fields_set
@@ -777,7 +771,7 @@ def _patch_ti_validate_request(
777
771
except ValidationError as e :
778
772
raise RequestValidationError (errors = e .errors ())
779
773
780
- return dag , ti , body .model_dump (include = fields_to_update , by_alias = True )
774
+ return dag , list ( tis ) , body .model_dump (include = fields_to_update , by_alias = True )
781
775
782
776
783
777
@task_instances_router .patch (
@@ -807,21 +801,16 @@ def patch_task_instance_dry_run(
807
801
update_mask : list [str ] | None = Query (None ),
808
802
) -> TaskInstanceCollectionResponse :
809
803
"""Update a task instance dry_run mode."""
810
- if map_index is None :
811
- map_index = - 1
812
-
813
- dag , ti , data = _patch_ti_validate_request (
804
+ dag , tis , data = _patch_ti_validate_request (
814
805
dag_id , dag_run_id , task_id , dag_bag , body , session , map_index , update_mask
815
806
)
816
807
817
- tis : list [TI ] = []
818
-
819
808
if data .get ("new_state" ):
820
809
tis = (
821
810
dag .set_task_instance_state (
822
811
task_id = task_id ,
823
812
run_id = dag_run_id ,
824
- map_indexes = [map_index ],
813
+ map_indexes = [map_index ] if map_index is not None else None ,
825
814
state = data ["new_state" ],
826
815
upstream = body .include_upstream ,
827
816
downstream = body .include_downstream ,
@@ -833,9 +822,6 @@ def patch_task_instance_dry_run(
833
822
or []
834
823
)
835
824
836
- elif "note" in data :
837
- tis = [ti ]
838
-
839
825
return TaskInstanceCollectionResponse (
840
826
task_instances = [
841
827
TaskInstanceResponse .model_validate (
@@ -881,19 +867,16 @@ def patch_task_instance(
881
867
update_mask : list [str ] | None = Query (None ),
882
868
) -> TaskInstanceCollectionResponse :
883
869
"""Update a task instance."""
884
- if map_index is None :
885
- map_index = - 1
886
-
887
- dag , ti , data = _patch_ti_validate_request (
870
+ dag , tis , data = _patch_ti_validate_request (
888
871
dag_id , dag_run_id , task_id , dag_bag , body , session , map_index , update_mask
889
872
)
890
873
891
874
for key , _ in data .items ():
892
875
if key == "new_state" :
893
- tis : list [ TI ] = dag .set_task_instance_state (
876
+ tis = dag .set_task_instance_state (
894
877
task_id = task_id ,
895
878
run_id = dag_run_id ,
896
- map_indexes = [map_index ],
879
+ map_indexes = [map_index ] if map_index is not None else None ,
897
880
state = data ["new_state" ],
898
881
upstream = body .include_upstream ,
899
882
downstream = body .include_downstream ,
@@ -906,37 +889,39 @@ def patch_task_instance(
906
889
raise HTTPException (
907
890
status .HTTP_409_CONFLICT , f"Task id { task_id } is already in { data ['new_state' ]} state"
908
891
)
909
- ti = tis [0 ] if isinstance (tis , list ) else tis
910
- try :
911
- if data ["new_state" ] == TaskInstanceState .SUCCESS :
912
- get_listener_manager ().hook .on_task_instance_success (
913
- previous_state = None , task_instance = ti
914
- )
915
- elif data ["new_state" ] == TaskInstanceState .FAILED :
916
- get_listener_manager ().hook .on_task_instance_failed (
917
- previous_state = None ,
918
- task_instance = ti ,
919
- error = f"TaskInstance's state was manually set to `{ TaskInstanceState .FAILED } `." ,
920
- )
921
- except Exception :
922
- log .exception ("error calling listener" )
892
+
893
+ for ti in tis :
894
+ try :
895
+ if data ["new_state" ] == TaskInstanceState .SUCCESS :
896
+ get_listener_manager ().hook .on_task_instance_success (
897
+ previous_state = None , task_instance = ti
898
+ )
899
+ elif data ["new_state" ] == TaskInstanceState .FAILED :
900
+ get_listener_manager ().hook .on_task_instance_failed (
901
+ previous_state = None ,
902
+ task_instance = ti ,
903
+ error = f"TaskInstance's state was manually set to `{ TaskInstanceState .FAILED } `." ,
904
+ )
905
+ except Exception :
906
+ log .exception ("error calling listener" )
923
907
924
908
elif key == "note" :
925
- if update_mask or body . note is not None :
926
- if ti . task_instance_note is None :
927
- ti .note = ( body . note , user . get_id ())
928
- else :
929
- ti . task_instance_note . content = body . note
930
- ti .task_instance_note .user_id = user . get_id ()
931
- session . commit ()
909
+ for ti in tis :
910
+ if update_mask or body . note is not None :
911
+ if ti .task_instance_note is None :
912
+ ti . note = ( body . note , user . get_id ())
913
+ else :
914
+ ti .task_instance_note .content = body . note
915
+ ti . task_instance_note . user_id = user . get_id ()
932
916
933
917
return TaskInstanceCollectionResponse (
934
918
task_instances = [
935
919
TaskInstanceResponse .model_validate (
936
920
ti ,
937
921
)
922
+ for ti in tis
938
923
],
939
- total_entries = 1 ,
924
+ total_entries = len ( tis ) ,
940
925
)
941
926
942
927
0 commit comments