@@ -268,9 +268,11 @@ case class AdaptiveSparkPlanExec(
268
268
269
269
def finalPhysicalPlan : SparkPlan = withFinalPlanUpdate(identity)
270
270
271
- private def getFinalPhysicalPlan (): SparkPlan = lock.synchronized {
272
- if (isFinalPlan) return currentPhysicalPlan
273
-
271
+ /**
272
+ * Run `fun` on finalized physical plan
273
+ */
274
+ def withFinalPlanUpdate [T ](fun : SparkPlan => T ): T = lock.synchronized {
275
+ _isFinalPlan = false
274
276
// In case of this adaptive plan being executed out of `withActive` scoped functions, e.g.,
275
277
// `plan.queryExecution.rdd`, we need to set active session here as new plan nodes can be
276
278
// created in the middle of the execution.
@@ -279,7 +281,7 @@ case class AdaptiveSparkPlanExec(
279
281
// Use inputPlan logicalLink here in case some top level physical nodes may be removed
280
282
// during `initialPlan`
281
283
var currentLogicalPlan = inputPlan.logicalLink.get
282
- var result = createQueryStages(currentPhysicalPlan)
284
+ var result = createQueryStages(fun, currentPhysicalPlan, firstRun = true )
283
285
val events = new LinkedBlockingQueue [StageMaterializationEvent ]()
284
286
val errors = new mutable.ArrayBuffer [Throwable ]()
285
287
var stagesToReplace = Seq .empty[QueryStageExec ]
@@ -344,56 +346,53 @@ case class AdaptiveSparkPlanExec(
344
346
if (errors.nonEmpty) {
345
347
cleanUpAndThrowException(errors.toSeq, None )
346
348
}
347
-
348
- // Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
349
- // than that of the current plan; otherwise keep the current physical plan together with
350
- // the current logical plan since the physical plan's logical links point to the logical
351
- // plan it has originated from.
352
- // Meanwhile, we keep a list of the query stages that have been created since last plan
353
- // update, which stands for the "semantic gap" between the current logical and physical
354
- // plans. And each time before re-planning, we replace the corresponding nodes in the
355
- // current logical plan with logical query stages to make it semantically in sync with
356
- // the current physical plan. Once a new plan is adopted and both logical and physical
357
- // plans are updated, we can clear the query stage list because at this point the two plans
358
- // are semantically and physically in sync again.
359
- val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
360
- val afterReOptimize = reOptimize(logicalPlan)
361
- if (afterReOptimize.isDefined) {
362
- val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
363
- val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
364
- val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
365
- if (newCost < origCost ||
366
- (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
367
- lazy val plans =
368
- sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString(" \n " )
369
- logOnLevel(log " Plan changed: \n ${MDC (QUERY_PLAN , plans)}" )
370
- cleanUpTempTags(newPhysicalPlan)
371
- currentPhysicalPlan = newPhysicalPlan
372
- currentLogicalPlan = newLogicalPlan
373
- stagesToReplace = Seq .empty[QueryStageExec ]
349
+ if (! currentPhysicalPlan.isInstanceOf [ResultQueryStageExec ]) {
350
+ // Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
351
+ // than that of the current plan; otherwise keep the current physical plan together with
352
+ // the current logical plan since the physical plan's logical links point to the logical
353
+ // plan it has originated from.
354
+ // Meanwhile, we keep a list of the query stages that have been created since last plan
355
+ // update, which stands for the "semantic gap" between the current logical and physical
356
+ // plans. And each time before re-planning, we replace the corresponding nodes in the
357
+ // current logical plan with logical query stages to make it semantically in sync with
358
+ // the current physical plan. Once a new plan is adopted and both logical and physical
359
+ // plans are updated, we can clear the query stage list because at this point the two
360
+ // plans are semantically and physically in sync again.
361
+ val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
362
+ val afterReOptimize = reOptimize(logicalPlan)
363
+ if (afterReOptimize.isDefined) {
364
+ val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
365
+ val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
366
+ val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
367
+ if (newCost < origCost ||
368
+ (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
369
+ lazy val plans = sideBySide(
370
+ currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString(" \n " )
371
+ logOnLevel(log " Plan changed: \n ${MDC (QUERY_PLAN , plans)}" )
372
+ cleanUpTempTags(newPhysicalPlan)
373
+ currentPhysicalPlan = newPhysicalPlan
374
+ currentLogicalPlan = newLogicalPlan
375
+ stagesToReplace = Seq .empty[QueryStageExec ]
376
+ }
374
377
}
375
378
}
376
379
// Now that some stages have finished, we can try creating new stages.
377
- result = createQueryStages(currentPhysicalPlan)
380
+ result = createQueryStages(fun, currentPhysicalPlan, firstRun = false )
378
381
}
379
-
380
- // Run the final plan when there's no more unfinished stages.
381
- currentPhysicalPlan = applyPhysicalRules(
382
- optimizeQueryStage(result.newPlan, isFinalStage = true ),
383
- postStageCreationRules(supportsColumnar),
384
- Some ((planChangeLogger, " AQE Post Stage Creation" )))
385
- _isFinalPlan = true
386
- executionId.foreach(onUpdatePlan(_, Seq (currentPhysicalPlan)))
387
- currentPhysicalPlan
388
382
}
383
+ _isFinalPlan = true
384
+ finalPlanUpdate
385
+ // Dereference the result so it can be GCed. After this resultStage.isMaterialized will return
386
+ // false, which is expected. If we want to collect result again, we should invoke
387
+ // `withFinalPlanUpdate` and pass another result handler and we will create a new result stage.
388
+ currentPhysicalPlan.asInstanceOf [ResultQueryStageExec ].resultOption.getAndUpdate(_ => None )
389
+ .get.asInstanceOf [T ]
389
390
}
390
391
391
392
// Use a lazy val to avoid this being called more than once.
392
393
@ transient private lazy val finalPlanUpdate : Unit = {
393
- // Subqueries that don't belong to any query stage of the main query will execute after the
394
- // last UI update in `getFinalPhysicalPlan`, so we need to update UI here again to make sure
395
- // the newly generated nodes of those subqueries are updated.
396
- if (shouldUpdatePlan && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
394
+ // Do final plan update after result stage has materialized.
395
+ if (shouldUpdatePlan) {
397
396
getExecutionId.foreach(onUpdatePlan(_, Seq .empty))
398
397
}
399
398
logOnLevel(log " Final plan: \n ${MDC (QUERY_PLAN , currentPhysicalPlan)}" )
@@ -426,13 +425,6 @@ case class AdaptiveSparkPlanExec(
426
425
}
427
426
}
428
427
429
- private def withFinalPlanUpdate [T ](fun : SparkPlan => T ): T = {
430
- val plan = getFinalPhysicalPlan()
431
- val result = fun(plan)
432
- finalPlanUpdate
433
- result
434
- }
435
-
436
428
protected override def stringArgs : Iterator [Any ] = Iterator (s " isFinalPlan= $isFinalPlan" )
437
429
438
430
override def generateTreeString (
@@ -521,6 +513,66 @@ case class AdaptiveSparkPlanExec(
521
513
this .inputPlan == obj.asInstanceOf [AdaptiveSparkPlanExec ].inputPlan
522
514
}
523
515
516
+ /**
517
+ * We separate stage creation of result and non-result stages because there are several edge cases
518
+ * of result stage creation:
519
+ * - existing ResultQueryStage created in previous `withFinalPlanUpdate`.
520
+ * - the root node is a non-result query stage and we have to create query result stage on top of
521
+ * it.
522
+ * - we create a non-result query stage as root node and the stage is immediately materialized
523
+ * due to stage resue, therefore we have to create a result stage right after.
524
+ *
525
+ * This method wraps around `createNonResultQueryStages`, the general logic is:
526
+ * - Early return if ResultQueryStageExec already created before.
527
+ * - Create non result query stage if possible.
528
+ * - Try to create result query stage when there is no new non-result query stage created and all
529
+ * stages are materialized.
530
+ */
531
+ private def createQueryStages (
532
+ resultHandler : SparkPlan => Any ,
533
+ plan : SparkPlan ,
534
+ firstRun : Boolean ): CreateStageResult = {
535
+ plan match {
536
+ // 1. ResultQueryStageExec is already created, no need to create non-result stages
537
+ case resultStage @ ResultQueryStageExec (_, optimizedPlan, _) =>
538
+ assertStageNotFailed(resultStage)
539
+ if (firstRun) {
540
+ // There is already an existing ResultQueryStage created in previous `withFinalPlanUpdate`
541
+ // e.g, when we do `df.collect` multiple times. Here we create a new result stage to
542
+ // execute it again, as the handler function can be different.
543
+ val newResultStage = ResultQueryStageExec (currentStageId, optimizedPlan, resultHandler)
544
+ currentStageId += 1
545
+ setLogicalLinkForNewQueryStage(newResultStage, optimizedPlan)
546
+ CreateStageResult (newPlan = newResultStage,
547
+ allChildStagesMaterialized = false ,
548
+ newStages = Seq (newResultStage))
549
+ } else {
550
+ // We will hit this branch after we've created result query stage in the AQE loop, we
551
+ // should do nothing.
552
+ CreateStageResult (newPlan = resultStage,
553
+ allChildStagesMaterialized = resultStage.isMaterialized,
554
+ newStages = Seq .empty)
555
+ }
556
+ case _ =>
557
+ // 2. Create non result query stage
558
+ val result = createNonResultQueryStages(plan)
559
+ var allNewStages = result.newStages
560
+ var newPlan = result.newPlan
561
+ var allChildStagesMaterialized = result.allChildStagesMaterialized
562
+ // 3. Create result stage
563
+ if (allNewStages.isEmpty && allChildStagesMaterialized) {
564
+ val resultStage = newResultQueryStage(resultHandler, newPlan)
565
+ newPlan = resultStage
566
+ allChildStagesMaterialized = false
567
+ allNewStages :+= resultStage
568
+ }
569
+ CreateStageResult (
570
+ newPlan = newPlan,
571
+ allChildStagesMaterialized = allChildStagesMaterialized,
572
+ newStages = allNewStages)
573
+ }
574
+ }
575
+
524
576
/**
525
577
* This method is called recursively to traverse the plan tree bottom-up and create a new query
526
578
* stage or try reusing an existing stage if the current node is an [[Exchange ]] node and all of
@@ -531,7 +583,7 @@ case class AdaptiveSparkPlanExec(
531
583
* 2) Whether the child query stages (if any) of the current node have all been materialized.
532
584
* 3) A list of the new query stages that have been created.
533
585
*/
534
- private def createQueryStages (plan : SparkPlan ): CreateStageResult = plan match {
586
+ private def createNonResultQueryStages (plan : SparkPlan ): CreateStageResult = plan match {
535
587
case e : Exchange =>
536
588
// First have a quick check in the `stageCache` without having to traverse down the node.
537
589
context.stageCache.get(e.canonicalized) match {
@@ -544,7 +596,7 @@ case class AdaptiveSparkPlanExec(
544
596
newStages = if (isMaterialized) Seq .empty else Seq (stage))
545
597
546
598
case _ =>
547
- val result = createQueryStages (e.child)
599
+ val result = createNonResultQueryStages (e.child)
548
600
val newPlan = e.withNewChildren(Seq (result.newPlan)).asInstanceOf [Exchange ]
549
601
// Create a query stage only when all the child query stages are ready.
550
602
if (result.allChildStagesMaterialized) {
@@ -588,14 +640,28 @@ case class AdaptiveSparkPlanExec(
588
640
if (plan.children.isEmpty) {
589
641
CreateStageResult (newPlan = plan, allChildStagesMaterialized = true , newStages = Seq .empty)
590
642
} else {
591
- val results = plan.children.map(createQueryStages )
643
+ val results = plan.children.map(createNonResultQueryStages )
592
644
CreateStageResult (
593
645
newPlan = plan.withNewChildren(results.map(_.newPlan)),
594
646
allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized),
595
647
newStages = results.flatMap(_.newStages))
596
648
}
597
649
}
598
650
651
+ private def newResultQueryStage (
652
+ resultHandler : SparkPlan => Any ,
653
+ plan : SparkPlan ): ResultQueryStageExec = {
654
+ // Run the final plan when there's no more unfinished stages.
655
+ val optimizedRootPlan = applyPhysicalRules(
656
+ optimizeQueryStage(plan, isFinalStage = true ),
657
+ postStageCreationRules(supportsColumnar),
658
+ Some ((planChangeLogger, " AQE Post Stage Creation" )))
659
+ val resultStage = ResultQueryStageExec (currentStageId, optimizedRootPlan, resultHandler)
660
+ currentStageId += 1
661
+ setLogicalLinkForNewQueryStage(resultStage, plan)
662
+ resultStage
663
+ }
664
+
599
665
private def newQueryStage (plan : SparkPlan ): QueryStageExec = {
600
666
val queryStage = plan match {
601
667
case e : Exchange =>
0 commit comments