Skip to content

Commit d31279c

Browse files
Add Patch Task Instance Summary
1 parent 1794420 commit d31279c

File tree

7 files changed

+956
-31
lines changed

7 files changed

+956
-31
lines changed

airflow-core/src/airflow/api_fastapi/core_api/openapi/v1-rest-api-generated.yaml

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5853,7 +5853,9 @@ paths:
58535853
in: path
58545854
required: true
58555855
schema:
5856-
type: integer
5856+
anyOf:
5857+
- type: integer
5858+
- type: 'null'
58575859
title: Map Index
58585860
- name: update_mask
58595861
in: query
@@ -5940,8 +5942,9 @@ paths:
59405942
in: query
59415943
required: false
59425944
schema:
5943-
type: integer
5944-
default: -1
5945+
anyOf:
5946+
- type: integer
5947+
- type: 'null'
59455948
title: Map Index
59465949
- name: update_mask
59475950
in: query
@@ -5996,6 +5999,99 @@ paths:
59965999
application/json:
59976000
schema:
59986001
$ref: '#/components/schemas/HTTPValidationError'
6002+
/api/v2/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/summary:
6003+
patch:
6004+
tags:
6005+
- Task Instance
6006+
summary: Patch Task Instance Summary
6007+
description: 'Update a task instance summary.
6008+
6009+
6010+
If the task is unmapped this is similar to the patch task instance endpoint.
6011+
For mapped task
6012+
6013+
updates all the mapped task instances at once.'
6014+
operationId: patch_task_instance_summary
6015+
security:
6016+
- OAuth2PasswordBearer: []
6017+
parameters:
6018+
- name: dag_id
6019+
in: path
6020+
required: true
6021+
schema:
6022+
type: string
6023+
title: Dag Id
6024+
- name: dag_run_id
6025+
in: path
6026+
required: true
6027+
schema:
6028+
type: string
6029+
title: Dag Run Id
6030+
- name: task_id
6031+
in: path
6032+
required: true
6033+
schema:
6034+
type: string
6035+
title: Task Id
6036+
- name: update_mask
6037+
in: query
6038+
required: false
6039+
schema:
6040+
anyOf:
6041+
- type: array
6042+
items:
6043+
type: string
6044+
- type: 'null'
6045+
title: Update Mask
6046+
requestBody:
6047+
required: true
6048+
content:
6049+
application/json:
6050+
schema:
6051+
$ref: '#/components/schemas/PatchTaskInstanceBody'
6052+
responses:
6053+
'200':
6054+
description: Successful Response
6055+
content:
6056+
application/json:
6057+
schema:
6058+
$ref: '#/components/schemas/TaskInstanceCollectionResponse'
6059+
'401':
6060+
content:
6061+
application/json:
6062+
schema:
6063+
$ref: '#/components/schemas/HTTPExceptionResponse'
6064+
description: Unauthorized
6065+
'403':
6066+
content:
6067+
application/json:
6068+
schema:
6069+
$ref: '#/components/schemas/HTTPExceptionResponse'
6070+
description: Forbidden
6071+
'400':
6072+
content:
6073+
application/json:
6074+
schema:
6075+
$ref: '#/components/schemas/HTTPExceptionResponse'
6076+
description: Bad Request
6077+
'404':
6078+
content:
6079+
application/json:
6080+
schema:
6081+
$ref: '#/components/schemas/HTTPExceptionResponse'
6082+
description: Not Found
6083+
'409':
6084+
content:
6085+
application/json:
6086+
schema:
6087+
$ref: '#/components/schemas/HTTPExceptionResponse'
6088+
description: Conflict
6089+
'422':
6090+
description: Validation Error
6091+
content:
6092+
application/json:
6093+
schema:
6094+
$ref: '#/components/schemas/HTTPValidationError'
59996095
/api/v2/dags/{dag_id}/tasks:
60006096
get:
60016097
tags:

airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py

Lines changed: 96 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from fastapi.exceptions import RequestValidationError
2525
from pydantic import ValidationError
2626
from sqlalchemy import or_, select
27-
from sqlalchemy.exc import MultipleResultsFound
2827
from sqlalchemy.orm import joinedload
2928
from sqlalchemy.sql.selectable import Select
3029

@@ -736,9 +735,9 @@ def _patch_ti_validate_request(
736735
dag_bag: DagBagDep,
737736
body: PatchTaskInstanceBody,
738737
session: SessionDep,
739-
map_index: int = -1,
738+
map_index: int | None = -1,
740739
update_mask: list[str] | None = Query(None),
741-
) -> tuple[DAG, TI, dict]:
740+
) -> tuple[DAG, list[TI], dict]:
742741
dag = dag_bag.get_dag(dag_id)
743742
if not dag:
744743
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not found")
@@ -752,20 +751,15 @@ def _patch_ti_validate_request(
752751
.join(TI.dag_run)
753752
.options(joinedload(TI.rendered_task_instance_fields))
754753
)
755-
query = query.where(TI.map_index == map_index)
754+
if map_index is not None:
755+
query = query.where(TI.map_index == map_index)
756756

757-
try:
758-
ti = session.scalar(query)
759-
except MultipleResultsFound:
760-
raise HTTPException(
761-
status.HTTP_400_BAD_REQUEST,
762-
"Multiple task instances found. As the TI is mapped, add the map_index value to the URL",
763-
)
757+
tis = session.scalars(query).all()
764758

765759
err_msg_404 = (
766760
f"The Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}` and map_index: `{map_index}` was not found",
767761
)
768-
if ti is None:
762+
if len(tis) == 0:
769763
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)
770764

771765
fields_to_update = body.model_fields_set
@@ -777,7 +771,7 @@ def _patch_ti_validate_request(
777771
except ValidationError as e:
778772
raise RequestValidationError(errors=e.errors())
779773

780-
return dag, ti, body.model_dump(include=fields_to_update, by_alias=True)
774+
return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)
781775

782776

783777
@task_instances_router.patch(
@@ -803,22 +797,20 @@ def patch_task_instance_dry_run(
803797
dag_bag: DagBagDep,
804798
body: PatchTaskInstanceBody,
805799
session: SessionDep,
806-
map_index: int = -1,
800+
map_index: int | None = None,
807801
update_mask: list[str] | None = Query(None),
808802
) -> TaskInstanceCollectionResponse:
809803
"""Update a task instance dry_run mode."""
810-
dag, ti, data = _patch_ti_validate_request(
804+
dag, tis, data = _patch_ti_validate_request(
811805
dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask
812806
)
813807

814-
tis: list[TI] = []
815-
816808
if data.get("new_state"):
817809
tis = (
818810
dag.set_task_instance_state(
819811
task_id=task_id,
820812
run_id=dag_run_id,
821-
map_indexes=[map_index],
813+
map_indexes=[map_index] if map_index is not None else None,
822814
state=data["new_state"],
823815
upstream=body.include_upstream,
824816
downstream=body.include_downstream,
@@ -830,8 +822,88 @@ def patch_task_instance_dry_run(
830822
or []
831823
)
832824

833-
elif "note" in data:
834-
tis = [ti]
825+
return TaskInstanceCollectionResponse(
826+
task_instances=[
827+
TaskInstanceResponse.model_validate(
828+
ti,
829+
)
830+
for ti in tis
831+
],
832+
total_entries=len(tis),
833+
)
834+
835+
836+
@task_instances_router.patch(
837+
task_instances_prefix + "/{task_id}/summary",
838+
responses=create_openapi_http_exception_doc(
839+
[status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
840+
),
841+
dependencies=[
842+
Depends(action_logging()),
843+
Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE)),
844+
],
845+
)
846+
def patch_task_instance_summary(
847+
dag_id: str,
848+
dag_run_id: str,
849+
task_id: str,
850+
dag_bag: DagBagDep,
851+
body: PatchTaskInstanceBody,
852+
user: GetUserDep,
853+
session: SessionDep,
854+
update_mask: list[str] | None = Query(None),
855+
) -> TaskInstanceCollectionResponse:
856+
"""
857+
Update a task instance summary.
858+
859+
If the task is unmapped this is similar to the patch task instance endpoint. For mapped task
860+
updates all the mapped task instances at once.
861+
"""
862+
dag, tis, data = _patch_ti_validate_request(
863+
dag_id, dag_run_id, task_id, dag_bag, body, session, None, update_mask
864+
)
865+
866+
for key, _ in data.items():
867+
if key == "new_state":
868+
tis = dag.set_task_instance_state(
869+
task_id=task_id,
870+
run_id=dag_run_id,
871+
map_indexes=None,
872+
state=data["new_state"],
873+
upstream=body.include_upstream,
874+
downstream=body.include_downstream,
875+
future=body.include_future,
876+
past=body.include_past,
877+
commit=True,
878+
session=session,
879+
)
880+
if not tis:
881+
raise HTTPException(
882+
status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state"
883+
)
884+
for ti in tis:
885+
try:
886+
if data["new_state"] == TaskInstanceState.SUCCESS:
887+
get_listener_manager().hook.on_task_instance_success(
888+
previous_state=None, task_instance=ti
889+
)
890+
elif data["new_state"] == TaskInstanceState.FAILED:
891+
get_listener_manager().hook.on_task_instance_failed(
892+
previous_state=None,
893+
task_instance=ti,
894+
error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
895+
)
896+
except Exception:
897+
log.exception("error calling listener")
898+
899+
elif key == "note":
900+
for ti in tis:
901+
if update_mask or body.note is not None:
902+
if ti.task_instance_note is None:
903+
ti.note = (body.note, user.get_id())
904+
else:
905+
ti.task_instance_note.content = body.note
906+
ti.task_instance_note.user_id = user.get_id()
835907

836908
return TaskInstanceCollectionResponse(
837909
task_instances=[
@@ -878,13 +950,15 @@ def patch_task_instance(
878950
update_mask: list[str] | None = Query(None),
879951
) -> TaskInstanceResponse:
880952
"""Update a task instance."""
881-
dag, ti, data = _patch_ti_validate_request(
953+
dag, tis, data = _patch_ti_validate_request(
882954
dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask
883955
)
884956

957+
ti = tis[0]
958+
885959
for key, _ in data.items():
886960
if key == "new_state":
887-
tis: list[TI] = dag.set_task_instance_state(
961+
tis = dag.set_task_instance_state(
888962
task_id=task_id,
889963
run_id=dag_run_id,
890964
map_indexes=[map_index],
@@ -922,7 +996,6 @@ def patch_task_instance(
922996
else:
923997
ti.task_instance_note.content = body.note
924998
ti.task_instance_note.user_id = user.get_id()
925-
session.commit()
926999

9271000
return TaskInstanceResponse.model_validate(ti)
9281001

airflow-core/src/airflow/ui/openapi-gen/queries/common.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,6 +1918,9 @@ export type TaskInstanceServicePatchTaskInstanceDryRunByMapIndexMutationResult =
19181918
export type TaskInstanceServicePatchTaskInstanceDryRunMutationResult = Awaited<
19191919
ReturnType<typeof TaskInstanceService.patchTaskInstanceDryRun>
19201920
>;
1921+
export type TaskInstanceServicePatchTaskInstanceSummaryMutationResult = Awaited<
1922+
ReturnType<typeof TaskInstanceService.patchTaskInstanceSummary>
1923+
>;
19211924
export type PoolServicePatchPoolMutationResult = Awaited<ReturnType<typeof PoolService.patchPool>>;
19221925
export type PoolServiceBulkPoolsMutationResult = Awaited<ReturnType<typeof PoolService.bulkPools>>;
19231926
export type XcomServiceUpdateXcomEntryMutationResult = Awaited<

airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4292,6 +4292,64 @@ export const useTaskInstanceServicePatchTaskInstanceDryRun = <
42924292
}) as unknown as Promise<TData>,
42934293
...options,
42944294
});
4295+
/**
4296+
* Patch Task Instance Summary
4297+
* Update a task instance summary.
4298+
*
4299+
* If the task is unmapped this is similar to the patch task instance endpoint. For mapped task
4300+
* updates all the mapped task instances at once.
4301+
* @param data The data for the request.
4302+
* @param data.dagId
4303+
* @param data.dagRunId
4304+
* @param data.taskId
4305+
* @param data.requestBody
4306+
* @param data.updateMask
4307+
* @returns TaskInstanceCollectionResponse Successful Response
4308+
* @throws ApiError
4309+
*/
4310+
export const useTaskInstanceServicePatchTaskInstanceSummary = <
4311+
TData = Common.TaskInstanceServicePatchTaskInstanceSummaryMutationResult,
4312+
TError = unknown,
4313+
TContext = unknown,
4314+
>(
4315+
options?: Omit<
4316+
UseMutationOptions<
4317+
TData,
4318+
TError,
4319+
{
4320+
dagId: string;
4321+
dagRunId: string;
4322+
requestBody: PatchTaskInstanceBody;
4323+
taskId: string;
4324+
updateMask?: string[];
4325+
},
4326+
TContext
4327+
>,
4328+
"mutationFn"
4329+
>,
4330+
) =>
4331+
useMutation<
4332+
TData,
4333+
TError,
4334+
{
4335+
dagId: string;
4336+
dagRunId: string;
4337+
requestBody: PatchTaskInstanceBody;
4338+
taskId: string;
4339+
updateMask?: string[];
4340+
},
4341+
TContext
4342+
>({
4343+
mutationFn: ({ dagId, dagRunId, requestBody, taskId, updateMask }) =>
4344+
TaskInstanceService.patchTaskInstanceSummary({
4345+
dagId,
4346+
dagRunId,
4347+
requestBody,
4348+
taskId,
4349+
updateMask,
4350+
}) as unknown as Promise<TData>,
4351+
...options,
4352+
});
42954353
/**
42964354
* Patch Pool
42974355
* Update a Pool.

0 commit comments

Comments
 (0)