Skip to content

Commit b76da2e

Browse files
Add Patch Task Instance Summary
1 parent e37a40d commit b76da2e

File tree

7 files changed

+422
-78
lines changed

7 files changed

+422
-78
lines changed

airflow-core/src/airflow/api_fastapi/core_api/openapi/v1-rest-api-generated.yaml

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6002,6 +6002,99 @@ paths:
60026002
application/json:
60036003
schema:
60046004
$ref: '#/components/schemas/HTTPValidationError'
6005+
/api/v2/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/summary:
6006+
patch:
6007+
tags:
6008+
- Task Instance
6009+
summary: Patch Task Instance Summary
6010+
description: 'Update a task instance summary.
6011+
6012+
6013+
If the task is unmapped this is similar to the patch task instance endpoint.
6014+
For mapped task
6015+
6016+
updates all the mapped task instances at once.'
6017+
operationId: patch_task_instance_summary
6018+
security:
6019+
- OAuth2PasswordBearer: []
6020+
parameters:
6021+
- name: dag_id
6022+
in: path
6023+
required: true
6024+
schema:
6025+
type: string
6026+
title: Dag Id
6027+
- name: dag_run_id
6028+
in: path
6029+
required: true
6030+
schema:
6031+
type: string
6032+
title: Dag Run Id
6033+
- name: task_id
6034+
in: path
6035+
required: true
6036+
schema:
6037+
type: string
6038+
title: Task Id
6039+
- name: update_mask
6040+
in: query
6041+
required: false
6042+
schema:
6043+
anyOf:
6044+
- type: array
6045+
items:
6046+
type: string
6047+
- type: 'null'
6048+
title: Update Mask
6049+
requestBody:
6050+
required: true
6051+
content:
6052+
application/json:
6053+
schema:
6054+
$ref: '#/components/schemas/PatchTaskInstanceBody'
6055+
responses:
6056+
'200':
6057+
description: Successful Response
6058+
content:
6059+
application/json:
6060+
schema:
6061+
$ref: '#/components/schemas/TaskInstanceCollectionResponse'
6062+
'401':
6063+
content:
6064+
application/json:
6065+
schema:
6066+
$ref: '#/components/schemas/HTTPExceptionResponse'
6067+
description: Unauthorized
6068+
'403':
6069+
content:
6070+
application/json:
6071+
schema:
6072+
$ref: '#/components/schemas/HTTPExceptionResponse'
6073+
description: Forbidden
6074+
'400':
6075+
content:
6076+
application/json:
6077+
schema:
6078+
$ref: '#/components/schemas/HTTPExceptionResponse'
6079+
description: Bad Request
6080+
'404':
6081+
content:
6082+
application/json:
6083+
schema:
6084+
$ref: '#/components/schemas/HTTPExceptionResponse'
6085+
description: Not Found
6086+
'409':
6087+
content:
6088+
application/json:
6089+
schema:
6090+
$ref: '#/components/schemas/HTTPExceptionResponse'
6091+
description: Conflict
6092+
'422':
6093+
description: Validation Error
6094+
content:
6095+
application/json:
6096+
schema:
6097+
$ref: '#/components/schemas/HTTPValidationError'
60056098
/api/v2/dags/{dag_id}/tasks:
60066099
get:
60076100
tags:

airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from fastapi.exceptions import RequestValidationError
2525
from pydantic import ValidationError
2626
from sqlalchemy import or_, select
27-
from sqlalchemy.exc import MultipleResultsFound
2827
from sqlalchemy.orm import joinedload
2928
from sqlalchemy.sql.selectable import Select
3029

@@ -736,9 +735,9 @@ def _patch_ti_validate_request(
736735
dag_bag: DagBagDep,
737736
body: PatchTaskInstanceBody,
738737
session: SessionDep,
739-
map_index: int = -1,
738+
map_index: int | None = -1,
740739
update_mask: list[str] | None = Query(None),
741-
) -> tuple[DAG, TI, dict]:
740+
) -> tuple[DAG, list[TI], dict]:
742741
dag = dag_bag.get_dag(dag_id)
743742
if not dag:
744743
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not found")
@@ -752,20 +751,15 @@ def _patch_ti_validate_request(
752751
.join(TI.dag_run)
753752
.options(joinedload(TI.rendered_task_instance_fields))
754753
)
755-
query = query.where(TI.map_index == map_index)
754+
if map_index is not None:
755+
query = query.where(TI.map_index == map_index)
756756

757-
try:
758-
ti = session.scalar(query)
759-
except MultipleResultsFound:
760-
raise HTTPException(
761-
status.HTTP_400_BAD_REQUEST,
762-
"Multiple task instances found. As the TI is mapped, add the map_index value to the URL",
763-
)
757+
tis = session.scalars(query).all()
764758

765759
err_msg_404 = (
766760
f"The Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}` and map_index: `{map_index}` was not found",
767761
)
768-
if ti is None:
762+
if len(tis) == 0:
769763
raise HTTPException(status.HTTP_404_NOT_FOUND, err_msg_404)
770764

771765
fields_to_update = body.model_fields_set
@@ -777,7 +771,7 @@ def _patch_ti_validate_request(
777771
except ValidationError as e:
778772
raise RequestValidationError(errors=e.errors())
779773

780-
return dag, ti, body.model_dump(include=fields_to_update, by_alias=True)
774+
return dag, list(tis), body.model_dump(include=fields_to_update, by_alias=True)
781775

782776

783777
@task_instances_router.patch(
@@ -807,21 +801,16 @@ def patch_task_instance_dry_run(
807801
update_mask: list[str] | None = Query(None),
808802
) -> TaskInstanceCollectionResponse:
809803
"""Update a task instance dry_run mode."""
810-
if map_index is None:
811-
map_index = -1
812-
813-
dag, ti, data = _patch_ti_validate_request(
804+
dag, tis, data = _patch_ti_validate_request(
814805
dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask
815806
)
816807

817-
tis: list[TI] = []
818-
819808
if data.get("new_state"):
820809
tis = (
821810
dag.set_task_instance_state(
822811
task_id=task_id,
823812
run_id=dag_run_id,
824-
map_indexes=[map_index],
813+
map_indexes=[map_index] if map_index is not None else None,
825814
state=data["new_state"],
826815
upstream=body.include_upstream,
827816
downstream=body.include_downstream,
@@ -833,9 +822,6 @@ def patch_task_instance_dry_run(
833822
or []
834823
)
835824

836-
elif "note" in data:
837-
tis = [ti]
838-
839825
return TaskInstanceCollectionResponse(
840826
task_instances=[
841827
TaskInstanceResponse.model_validate(
@@ -881,19 +867,16 @@ def patch_task_instance(
881867
update_mask: list[str] | None = Query(None),
882868
) -> TaskInstanceCollectionResponse:
883869
"""Update a task instance."""
884-
if map_index is None:
885-
map_index = -1
886-
887-
dag, ti, data = _patch_ti_validate_request(
870+
dag, tis, data = _patch_ti_validate_request(
888871
dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask
889872
)
890873

891874
for key, _ in data.items():
892875
if key == "new_state":
893-
tis: list[TI] = dag.set_task_instance_state(
876+
tis = dag.set_task_instance_state(
894877
task_id=task_id,
895878
run_id=dag_run_id,
896-
map_indexes=[map_index],
879+
map_indexes=[map_index] if map_index is not None else None,
897880
state=data["new_state"],
898881
upstream=body.include_upstream,
899882
downstream=body.include_downstream,
@@ -906,37 +889,39 @@ def patch_task_instance(
906889
raise HTTPException(
907890
status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state"
908891
)
909-
ti = tis[0] if isinstance(tis, list) else tis
910-
try:
911-
if data["new_state"] == TaskInstanceState.SUCCESS:
912-
get_listener_manager().hook.on_task_instance_success(
913-
previous_state=None, task_instance=ti
914-
)
915-
elif data["new_state"] == TaskInstanceState.FAILED:
916-
get_listener_manager().hook.on_task_instance_failed(
917-
previous_state=None,
918-
task_instance=ti,
919-
error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
920-
)
921-
except Exception:
922-
log.exception("error calling listener")
892+
893+
for ti in tis:
894+
try:
895+
if data["new_state"] == TaskInstanceState.SUCCESS:
896+
get_listener_manager().hook.on_task_instance_success(
897+
previous_state=None, task_instance=ti
898+
)
899+
elif data["new_state"] == TaskInstanceState.FAILED:
900+
get_listener_manager().hook.on_task_instance_failed(
901+
previous_state=None,
902+
task_instance=ti,
903+
error=f"TaskInstance's state was manually set to `{TaskInstanceState.FAILED}`.",
904+
)
905+
except Exception:
906+
log.exception("error calling listener")
923907

924908
elif key == "note":
925-
if update_mask or body.note is not None:
926-
if ti.task_instance_note is None:
927-
ti.note = (body.note, user.get_id())
928-
else:
929-
ti.task_instance_note.content = body.note
930-
ti.task_instance_note.user_id = user.get_id()
931-
session.commit()
909+
for ti in tis:
910+
if update_mask or body.note is not None:
911+
if ti.task_instance_note is None:
912+
ti.note = (body.note, user.get_id())
913+
else:
914+
ti.task_instance_note.content = body.note
915+
ti.task_instance_note.user_id = user.get_id()
932916

933917
return TaskInstanceCollectionResponse(
934918
task_instances=[
935919
TaskInstanceResponse.model_validate(
936920
ti,
937921
)
922+
for ti in tis
938923
],
939-
total_entries=1,
924+
total_entries=len(tis),
940925
)
941926

942927

airflow-core/src/airflow/ui/openapi-gen/queries/common.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,6 +1918,9 @@ export type TaskInstanceServicePatchTaskInstanceDryRunByMapIndexMutationResult =
19181918
export type TaskInstanceServicePatchTaskInstanceDryRunMutationResult = Awaited<
19191919
ReturnType<typeof TaskInstanceService.patchTaskInstanceDryRun>
19201920
>;
1921+
export type TaskInstanceServicePatchTaskInstanceSummaryMutationResult = Awaited<
1922+
ReturnType<typeof TaskInstanceService.patchTaskInstanceSummary>
1923+
>;
19211924
export type PoolServicePatchPoolMutationResult = Awaited<ReturnType<typeof PoolService.patchPool>>;
19221925
export type PoolServiceBulkPoolsMutationResult = Awaited<ReturnType<typeof PoolService.bulkPools>>;
19231926
export type XcomServiceUpdateXcomEntryMutationResult = Awaited<

airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4292,6 +4292,64 @@ export const useTaskInstanceServicePatchTaskInstanceDryRun = <
42924292
}) as unknown as Promise<TData>,
42934293
...options,
42944294
});
4295+
/**
4296+
* Patch Task Instance Summary
4297+
* Update a task instance summary.
4298+
*
4299+
* If the task is unmapped this is similar to the patch task instance endpoint. For mapped task
4300+
* updates all the mapped task instances at once.
4301+
* @param data The data for the request.
4302+
* @param data.dagId
4303+
* @param data.dagRunId
4304+
* @param data.taskId
4305+
* @param data.requestBody
4306+
* @param data.updateMask
4307+
* @returns TaskInstanceCollectionResponse Successful Response
4308+
* @throws ApiError
4309+
*/
4310+
export const useTaskInstanceServicePatchTaskInstanceSummary = <
4311+
TData = Common.TaskInstanceServicePatchTaskInstanceSummaryMutationResult,
4312+
TError = unknown,
4313+
TContext = unknown,
4314+
>(
4315+
options?: Omit<
4316+
UseMutationOptions<
4317+
TData,
4318+
TError,
4319+
{
4320+
dagId: string;
4321+
dagRunId: string;
4322+
requestBody: PatchTaskInstanceBody;
4323+
taskId: string;
4324+
updateMask?: string[];
4325+
},
4326+
TContext
4327+
>,
4328+
"mutationFn"
4329+
>,
4330+
) =>
4331+
useMutation<
4332+
TData,
4333+
TError,
4334+
{
4335+
dagId: string;
4336+
dagRunId: string;
4337+
requestBody: PatchTaskInstanceBody;
4338+
taskId: string;
4339+
updateMask?: string[];
4340+
},
4341+
TContext
4342+
>({
4343+
mutationFn: ({ dagId, dagRunId, requestBody, taskId, updateMask }) =>
4344+
TaskInstanceService.patchTaskInstanceSummary({
4345+
dagId,
4346+
dagRunId,
4347+
requestBody,
4348+
taskId,
4349+
updateMask,
4350+
}) as unknown as Promise<TData>,
4351+
...options,
4352+
});
42954353
/**
42964354
* Patch Pool
42974355
* Update a Pool.

0 commit comments

Comments
 (0)