
Commit 207390b

liuzqt and cloud-fan committed
[SPARK-51008][SQL] Add ResultStage for AQE
### What changes were proposed in this pull request?

Added ResultQueryStageExec for AQE.

How the query plan looks in the explain string:

```
AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
   ResultQueryStage 2   ------> newly added
   +- *(5) Project [id#26L]
      +- *(5) SortMergeJoin [id#26L], [id#27L], Inner
         :- *(3) Sort [id#26L ASC NULLS FIRST], false, 0
         :  +- AQEShuffleRead coalesced
         :     +- ShuffleQueryStage 0
         :        +- Exchange hashpartitioning(id#26L, 200), ENSURE_REQUIREMENTS, [plan_id=247]
         :           +- *(1) Range (0, 25600, step=1, splits=10)
         +- *(4) Sort [id#27L ASC NULLS FIRST], false, 0
            +- AQEShuffleRead coalesced
               +- ShuffleQueryStage 1
                  +- Exchange hashpartitioning(id#27L, 200), ENSURE_REQUIREMENTS, [plan_id=257]
                     +- *(2) Ran...
```

How the query plan looks in the Spark UI:

![Screenshot 2025-02-03 at 4 11 43 PM](https://github.com/user-attachments/assets/86946e19-ffdd-42dd-974a-62a8300ddac8)

### Why are the changes needed?

Currently the AQE framework is not fully self-contained, since not all plan segments can be put into a query stage: the final "stage" is basically executed as a non-AQE plan. This PR adds a result query stage for AQE to unify the framework. With this change, we can build more query-stage-level features; one use case is #44013 (comment).

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit tests. Existing tests impacted by this change were also updated to keep their original test semantics.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49715 from liuzqt/SPARK-51008.

Lead-authored-by: liuzqt <[email protected]>
Co-authored-by: Ziqi Liu <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent: cb2732d
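For reference, a minimal sketch that should reproduce an explain output like the one in the commit message (assumes AQE is enabled, the default since Spark 3.2; column and stage ids will differ):

```scala
import org.apache.spark.sql.SparkSession

// Hedged repro sketch: mirrors the Range-join-Range plan shown above.
val spark = SparkSession.builder().master("local[*]").appName("aqe-result-stage").getOrCreate()

val df = spark.range(0, 25600, 1, 10)
  .join(spark.range(0, 25600, 1, 10), "id")
  .select("id")

df.collect() // run the query so AQE finalizes the adaptive plan
df.explain() // should print AdaptiveSparkPlan isFinalPlan=true with ResultQueryStage at the root
```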

File tree: 13 files changed, +270 −125 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala

Lines changed: 9 additions & 0 deletions
```diff
@@ -210,6 +210,15 @@ object StaticSQLConf {
       .checkValue(thres => thres > 0 && thres <= 128, "The threshold must be in (0,128].")
       .createWithDefault(16)
 
+  val RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD =
+    buildStaticConf("spark.sql.resultQueryStage.maxThreadThreshold")
+      .internal()
+      .doc("The maximum degree of parallelism to execute ResultQueryStageExec in AQE")
+      .version("4.0.0")
+      .intConf
+      .checkValue(thres => thres > 0 && thres <= 1024, "The threshold must be in (0,1024].")
+      .createWithDefault(1024)
+
   val SQL_EVENT_TRUNCATE_LENGTH = buildStaticConf("spark.sql.event.truncate.length")
     .doc("Threshold of SQL length beyond which it will be truncated before adding to " +
       "event. Defaults to no truncation. If set to 0, callsite will be logged instead.")
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

Lines changed: 4 additions & 4 deletions
```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
+import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, ExecutorService}
 import java.util.concurrent.atomic.AtomicLong
 
 import scala.jdk.CollectionConverters._
@@ -301,15 +301,15 @@ object SQLExecution extends Logging {
    * SparkContext local properties are forwarded to execution thread
    */
   def withThreadLocalCaptured[T](
-      sparkSession: SparkSession, exec: ExecutorService) (body: => T): JFuture[T] = {
+      sparkSession: SparkSession, exec: ExecutorService) (body: => T): CompletableFuture[T] = {
     val activeSession = sparkSession
     val sc = sparkSession.sparkContext
     val localProps = Utils.cloneProperties(sc.getLocalProperties)
     // `getCurrentJobArtifactState` will return a state only in Spark Connect mode. In non-Connect
     // mode, we default back to the resources of the current Spark session.
     val artifactState = JobArtifactSet.getCurrentJobArtifactState.getOrElse(
       activeSession.artifactManager.state)
-    exec.submit(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
+    CompletableFuture.supplyAsync(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
       val originalSession = SparkSession.getActiveSession
       val originalLocalProps = sc.getLocalProperties
       SparkSession.setActiveSession(activeSession)
@@ -326,6 +326,6 @@ object SQLExecution extends Logging {
       SparkSession.clearActiveSession()
     }
     res
-    })
+    }, exec)
   }
 }
```
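The switch from `exec.submit` to `CompletableFuture.supplyAsync` keeps the caller-supplied executor but returns a future that supports completion callbacks, which the new result stage code relies on. A standalone sketch of the difference (executor and values are illustrative, not from this patch):

```scala
import java.util.concurrent.{CompletableFuture, Executors}

val exec = Executors.newSingleThreadExecutor()

// ExecutorService.submit returns a java.util.concurrent.Future: the result can only
// be obtained by blocking on get().
val blocking = exec.submit[Int](() => 1 + 1)
println(blocking.get()) // 2

// CompletableFuture.supplyAsync on the same executor allows non-blocking callbacks.
val async = CompletableFuture.supplyAsync(() => 1 + 1, exec)
  .whenComplete((r: Int, e: Throwable) => if (e == null) println(s"async: $r"))
async.join() // wait so this sketch prints deterministically

exec.shutdown()
```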

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 121 additions & 55 deletions
```diff
@@ -268,9 +268,11 @@ case class AdaptiveSparkPlanExec(
 
   def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity)
 
-  private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized {
-    if (isFinalPlan) return currentPhysicalPlan
-
+  /**
+   * Run `fun` on finalized physical plan
+   */
+  def withFinalPlanUpdate[T](fun: SparkPlan => T): T = lock.synchronized {
+    _isFinalPlan = false
     // In case of this adaptive plan being executed out of `withActive` scoped functions, e.g.,
     // `plan.queryExecution.rdd`, we need to set active session here as new plan nodes can be
     // created in the middle of the execution.
@@ -279,7 +281,7 @@ case class AdaptiveSparkPlanExec(
       // Use inputPlan logicalLink here in case some top level physical nodes may be removed
       // during `initialPlan`
       var currentLogicalPlan = inputPlan.logicalLink.get
-      var result = createQueryStages(currentPhysicalPlan)
+      var result = createQueryStages(fun, currentPhysicalPlan, firstRun = true)
       val events = new LinkedBlockingQueue[StageMaterializationEvent]()
       val errors = new mutable.ArrayBuffer[Throwable]()
       var stagesToReplace = Seq.empty[QueryStageExec]
@@ -344,56 +346,53 @@ case class AdaptiveSparkPlanExec(
        if (errors.nonEmpty) {
          cleanUpAndThrowException(errors.toSeq, None)
        }
-
-        // Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
-        // than that of the current plan; otherwise keep the current physical plan together with
-        // the current logical plan since the physical plan's logical links point to the logical
-        // plan it has originated from.
-        // Meanwhile, we keep a list of the query stages that have been created since last plan
-        // update, which stands for the "semantic gap" between the current logical and physical
-        // plans. And each time before re-planning, we replace the corresponding nodes in the
-        // current logical plan with logical query stages to make it semantically in sync with
-        // the current physical plan. Once a new plan is adopted and both logical and physical
-        // plans are updated, we can clear the query stage list because at this point the two plans
-        // are semantically and physically in sync again.
-        val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
-        val afterReOptimize = reOptimize(logicalPlan)
-        if (afterReOptimize.isDefined) {
-          val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
-          val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
-          val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
-          if (newCost < origCost ||
-              (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
-            lazy val plans =
-              sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
-            logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
-            cleanUpTempTags(newPhysicalPlan)
-            currentPhysicalPlan = newPhysicalPlan
-            currentLogicalPlan = newLogicalPlan
-            stagesToReplace = Seq.empty[QueryStageExec]
+        if (!currentPhysicalPlan.isInstanceOf[ResultQueryStageExec]) {
+          // Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
+          // than that of the current plan; otherwise keep the current physical plan together with
+          // the current logical plan since the physical plan's logical links point to the logical
+          // plan it has originated from.
+          // Meanwhile, we keep a list of the query stages that have been created since last plan
+          // update, which stands for the "semantic gap" between the current logical and physical
+          // plans. And each time before re-planning, we replace the corresponding nodes in the
+          // current logical plan with logical query stages to make it semantically in sync with
+          // the current physical plan. Once a new plan is adopted and both logical and physical
+          // plans are updated, we can clear the query stage list because at this point the two
+          // plans are semantically and physically in sync again.
+          val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
+          val afterReOptimize = reOptimize(logicalPlan)
+          if (afterReOptimize.isDefined) {
+            val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
+            val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
+            val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
+            if (newCost < origCost ||
+                (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
+              lazy val plans = sideBySide(
+                currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
+              logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
+              cleanUpTempTags(newPhysicalPlan)
+              currentPhysicalPlan = newPhysicalPlan
+              currentLogicalPlan = newLogicalPlan
+              stagesToReplace = Seq.empty[QueryStageExec]
+            }
          }
        }
        // Now that some stages have finished, we can try creating new stages.
-        result = createQueryStages(currentPhysicalPlan)
+        result = createQueryStages(fun, currentPhysicalPlan, firstRun = false)
      }
-
-      // Run the final plan when there's no more unfinished stages.
-      currentPhysicalPlan = applyPhysicalRules(
-        optimizeQueryStage(result.newPlan, isFinalStage = true),
-        postStageCreationRules(supportsColumnar),
-        Some((planChangeLogger, "AQE Post Stage Creation")))
-      _isFinalPlan = true
-      executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
-      currentPhysicalPlan
    }
+    _isFinalPlan = true
+    finalPlanUpdate
+    // Dereference the result so it can be GCed. After this resultStage.isMaterialized will return
+    // false, which is expected. If we want to collect result again, we should invoke
+    // `withFinalPlanUpdate` and pass another result handler and we will create a new result stage.
+    currentPhysicalPlan.asInstanceOf[ResultQueryStageExec].resultOption.getAndUpdate(_ => None)
+      .get.asInstanceOf[T]
  }
 
  // Use a lazy val to avoid this being called more than once.
  @transient private lazy val finalPlanUpdate: Unit = {
-    // Subqueries that don't belong to any query stage of the main query will execute after the
-    // last UI update in `getFinalPhysicalPlan`, so we need to update UI here again to make sure
-    // the newly generated nodes of those subqueries are updated.
-    if (shouldUpdatePlan && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
+    // Do final plan update after result stage has materialized.
+    if (shouldUpdatePlan) {
      getExecutionId.foreach(onUpdatePlan(_, Seq.empty))
    }
    logOnLevel(log"Final plan:\n${MDC(QUERY_PLAN, currentPhysicalPlan)}")
@@ -426,13 +425,6 @@
    }
  }
 
-  private def withFinalPlanUpdate[T](fun: SparkPlan => T): T = {
-    val plan = getFinalPhysicalPlan()
-    val result = fun(plan)
-    finalPlanUpdate
-    result
-  }
-
  protected override def stringArgs: Iterator[Any] = Iterator(s"isFinalPlan=$isFinalPlan")
 
  override def generateTreeString(
@@ -521,6 +513,66 @@
    this.inputPlan == obj.asInstanceOf[AdaptiveSparkPlanExec].inputPlan
  }
 
+  /**
+   * We separate stage creation of result and non-result stages because there are several edge
+   * cases of result stage creation:
+   * - existing ResultQueryStage created in previous `withFinalPlanUpdate`.
+   * - the root node is a non-result query stage and we have to create a result query stage on
+   *   top of it.
+   * - we create a non-result query stage as root node and the stage is immediately materialized
+   *   due to stage reuse, therefore we have to create a result stage right after.
+   *
+   * This method wraps around `createNonResultQueryStages`, the general logic is:
+   * - Early return if ResultQueryStageExec already created before.
+   * - Create non result query stage if possible.
+   * - Try to create result query stage when there is no new non-result query stage created and
+   *   all stages are materialized.
+   */
+  private def createQueryStages(
+      resultHandler: SparkPlan => Any,
+      plan: SparkPlan,
+      firstRun: Boolean): CreateStageResult = {
+    plan match {
+      // 1. ResultQueryStageExec is already created, no need to create non-result stages
+      case resultStage @ ResultQueryStageExec(_, optimizedPlan, _) =>
+        assertStageNotFailed(resultStage)
+        if (firstRun) {
+          // There is already an existing ResultQueryStage created in previous
+          // `withFinalPlanUpdate`, e.g., when we do `df.collect` multiple times. Here we create
+          // a new result stage to execute it again, as the handler function can be different.
+          val newResultStage = ResultQueryStageExec(currentStageId, optimizedPlan, resultHandler)
+          currentStageId += 1
+          setLogicalLinkForNewQueryStage(newResultStage, optimizedPlan)
+          CreateStageResult(newPlan = newResultStage,
+            allChildStagesMaterialized = false,
+            newStages = Seq(newResultStage))
+        } else {
+          // We will hit this branch after we've created result query stage in the AQE loop, we
+          // should do nothing.
+          CreateStageResult(newPlan = resultStage,
+            allChildStagesMaterialized = resultStage.isMaterialized,
+            newStages = Seq.empty)
+        }
+      case _ =>
+        // 2. Create non result query stage
+        val result = createNonResultQueryStages(plan)
+        var allNewStages = result.newStages
+        var newPlan = result.newPlan
+        var allChildStagesMaterialized = result.allChildStagesMaterialized
+        // 3. Create result stage
+        if (allNewStages.isEmpty && allChildStagesMaterialized) {
+          val resultStage = newResultQueryStage(resultHandler, newPlan)
+          newPlan = resultStage
+          allChildStagesMaterialized = false
+          allNewStages :+= resultStage
+        }
+        CreateStageResult(
+          newPlan = newPlan,
+          allChildStagesMaterialized = allChildStagesMaterialized,
+          newStages = allNewStages)
+    }
+  }
+
  /**
   * This method is called recursively to traverse the plan tree bottom-up and create a new query
   * stage or try reusing an existing stage if the current node is an [[Exchange]] node and all of
@@ -531,7 +583,7 @@ case class AdaptiveSparkPlanExec(
   * 2) Whether the child query stages (if any) of the current node have all been materialized.
   * 3) A list of the new query stages that have been created.
   */
-  private def createQueryStages(plan: SparkPlan): CreateStageResult = plan match {
+  private def createNonResultQueryStages(plan: SparkPlan): CreateStageResult = plan match {
    case e: Exchange =>
      // First have a quick check in the `stageCache` without having to traverse down the node.
      context.stageCache.get(e.canonicalized) match {
@@ -544,7 +596,7 @@
            newStages = if (isMaterialized) Seq.empty else Seq(stage))
 
        case _ =>
-          val result = createQueryStages(e.child)
+          val result = createNonResultQueryStages(e.child)
          val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange]
          // Create a query stage only when all the child query stages are ready.
          if (result.allChildStagesMaterialized) {
@@ -588,14 +640,28 @@
    if (plan.children.isEmpty) {
      CreateStageResult(newPlan = plan, allChildStagesMaterialized = true, newStages = Seq.empty)
    } else {
-      val results = plan.children.map(createQueryStages)
+      val results = plan.children.map(createNonResultQueryStages)
      CreateStageResult(
        newPlan = plan.withNewChildren(results.map(_.newPlan)),
        allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized),
        newStages = results.flatMap(_.newStages))
    }
  }
 
+  private def newResultQueryStage(
+      resultHandler: SparkPlan => Any,
+      plan: SparkPlan): ResultQueryStageExec = {
+    // Run the final plan when there's no more unfinished stages.
+    val optimizedRootPlan = applyPhysicalRules(
+      optimizeQueryStage(plan, isFinalStage = true),
+      postStageCreationRules(supportsColumnar),
+      Some((planChangeLogger, "AQE Post Stage Creation")))
+    val resultStage = ResultQueryStageExec(currentStageId, optimizedRootPlan, resultHandler)
+    currentStageId += 1
+    setLogicalLinkForNewQueryStage(resultStage, plan)
+    resultStage
+  }
+
  private def newQueryStage(plan: SparkPlan): QueryStageExec = {
    val queryStage = plan match {
      case e: Exchange =>
```
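One behavioral consequence of the `firstRun` branch above: repeated actions on the same DataFrame now each get their own result stage, while materialized shuffle stages underneath are reused. A hedged illustration, reusing the `df` from the earlier repro sketch:

```scala
// First action: shuffle stages materialize one by one; once all children are ready,
// createQueryStages wraps the root in a ResultQueryStageExec and executes it.
df.collect()

// Second action on the same adaptive plan: the root is already a ResultQueryStageExec,
// so the firstRun branch installs a fresh result stage carrying the new result handler;
// the shuffle stages below it are reused rather than re-executed.
df.collect()
```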

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala

Lines changed: 4 additions & 2 deletions
```diff
@@ -129,10 +129,12 @@ trait AdaptiveSparkPlanHelper {
   }
 
   /**
-   * Strip the executePlan of AdaptiveSparkPlanExec leaf node.
+   * Strip the top [[AdaptiveSparkPlanExec]] and [[ResultQueryStageExec]] nodes off
+   * the [[SparkPlan]].
    */
   def stripAQEPlan(p: SparkPlan): SparkPlan = p match {
-    case a: AdaptiveSparkPlanExec => a.executedPlan
+    case a: AdaptiveSparkPlanExec => stripAQEPlan(a.executedPlan)
+    case ResultQueryStageExec(_, plan, _) => plan
     case other => other
   }
 }
```
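With `executedPlan` now potentially rooted at a `ResultQueryStageExec`, callers reach the real root operator through this helper. A hedged sketch of the call pattern, inside code that mixes in `AdaptiveSparkPlanHelper` (`df` as in the earlier sketch; plan shapes are illustrative):

```scala
// Unwraps AdaptiveSparkPlanExec recursively, then strips the ResultQueryStageExec root.
val root = stripAQEPlan(df.queryExecution.executedPlan)
// For the Range-join-Range example the root would now be the final ProjectExec,
// not a query stage wrapper.
assert(!root.isInstanceOf[org.apache.spark.sql.execution.adaptive.QueryStageExec])
```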

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala

Lines changed: 45 additions & 0 deletions
```diff
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.adaptive
 
 import java.util.concurrent.atomic.AtomicReference
 
+import scala.concurrent.ExecutionContext
 import scala.concurrent.Future
+import scala.concurrent.Promise
 
 import org.apache.spark.{MapOutputStatistics, SparkException}
 import org.apache.spark.broadcast.Broadcast
@@ -32,7 +34,10 @@ import org.apache.spark.sql.columnar.CachedBatch
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
 import org.apache.spark.sql.execution.exchange._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.StaticSQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.ThreadUtils
 
 /**
  * A query stage is an independent subgraph of the query plan. AQE framework will materialize its
@@ -303,3 +308,43 @@ case class TableCacheQueryStageExec(
 
   override def getRuntimeStatistics: Statistics = inMemoryTableScan.runtimeStatistics
 }
+
+case class ResultQueryStageExec(
+    override val id: Int,
+    override val plan: SparkPlan,
+    resultHandler: SparkPlan => Any) extends QueryStageExec {
+
+  override def resetMetrics(): Unit = {
+    plan.resetMetrics()
+  }
+
+  override protected def doMaterialize(): Future[Any] = {
+    val javaFuture = SQLExecution.withThreadLocalCaptured(
+      session,
+      ResultQueryStageExec.executionContext) {
+      resultHandler(plan)
+    }
+    val scalaPromise: Promise[Any] = Promise()
+    javaFuture.whenComplete { (result: Any, exception: Throwable) =>
+      if (exception != null) {
+        scalaPromise.failure(exception match {
+          case completionException: java.util.concurrent.CompletionException =>
+            completionException.getCause
+          case ex => ex
+        })
+      } else {
+        scalaPromise.success(result)
+      }
+    }
+    scalaPromise.future
+  }
+
+  // Result stage could be any SparkPlan, so we don't have a specific runtime statistics for it.
+  override def getRuntimeStatistics: Statistics = Statistics.DUMMY
+}
+
+object ResultQueryStageExec {
+  private[execution] val executionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("ResultQueryStageExecution",
+      SQLConf.get.getConf(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD)))
+}
```
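`doMaterialize` above adapts the Java `CompletableFuture` returned by `withThreadLocalCaptured` into the Scala `Future` that the AQE loop expects. The same bridge, as a self-contained hedged sketch (pool size and element type are illustrative):

```scala
import java.util.concurrent.{CompletableFuture, CompletionException, Executors}

import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration._

val pool = Executors.newFixedThreadPool(4) // illustrative; Spark uses a capped daemon pool

// Complete a Scala Promise from a Java CompletableFuture, unwrapping the
// CompletionException that supplyAsync uses to wrap failures.
def toScalaFuture(javaFuture: CompletableFuture[String]): Future[String] = {
  val promise = Promise[String]()
  javaFuture.whenComplete { (result: String, exception: Throwable) =>
    if (exception != null) {
      promise.failure(exception match {
        case ce: CompletionException => ce.getCause
        case ex => ex
      })
    } else {
      promise.success(result)
    }
  }
  promise.future
}

val jf = CompletableFuture.supplyAsync(() => "done", pool)
println(Await.result(toScalaFuture(jf), 10.seconds)) // prints "done"
pool.shutdown()
```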

sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -106,7 +106,7 @@ object SparkPlanGraph {
         buildSparkPlanGraphNode(
           planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
       }
-    case "TableCacheQueryStage" =>
+    case "TableCacheQueryStage" | "ResultQueryStage" =>
       buildSparkPlanGraphNode(
         planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
     case "Subquery" if subgraph != null =>
```

sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

Lines changed: 3 additions & 1 deletion
```diff
@@ -1659,7 +1659,9 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
         _.nodeName.contains("TableCacheQueryStage"))
       val aqeNode = findNodeInSparkPlanInfo(inMemoryScanNode.get,
         _.nodeName.contains("AdaptiveSparkPlan"))
-      aqeNode.get.children.head.nodeName == "AQEShuffleRead"
+      val aqePlanRoot = findNodeInSparkPlanInfo(inMemoryScanNode.get,
+        _.nodeName.contains("ResultQueryStage"))
+      aqePlanRoot.get.children.head.nodeName == "AQEShuffleRead"
     }
 
     withTempView("t0", "t1", "t2") {
```
