[SPARK-51008][SQL] Add ResultStage for AQE
### What changes were proposed in this pull request?

Added ResultQueryStageExec for AQE

What the query plan looks like in the explain string:
```
AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
   ResultQueryStage 2 ------> newly added
   +- *(5) Project [id#26L]
      +- *(5) SortMergeJoin [id#26L], [id#27L], Inner
         :- *(3) Sort [id#26L ASC NULLS FIRST], false, 0
         :  +- AQEShuffleRead coalesced
         :     +- ShuffleQueryStage 0
         :        +- Exchange hashpartitioning(id#26L, 200), ENSURE_REQUIREMENTS, [plan_id=247]
         :           +- *(1) Range (0, 25600, step=1, splits=10)
         +- *(4) Sort [id#27L ASC NULLS FIRST], false, 0
            +- AQEShuffleRead coalesced
               +- ShuffleQueryStage 1
                  +- Exchange hashpartitioning(id#27L, 200), ENSURE_REQUIREMENTS, [plan_id=257]
                     +- *(2) Ran...

```
What the query plan looks like in the Spark UI:

<img width="680" alt="Screenshot 2025-02-03 at 4 11 43 PM" src="https://github.com/user-attachments/assets/86946e19-ffdd-42dd-974a-62a8300ddac8" />

### Why are the changes needed?

Currently the AQE framework is not fully self-contained, because not all plan segments can be put into a query stage: the final "stage" is basically executed as a non-AQE plan. This PR adds a result query stage for AQE to unify the framework. With this change, we can build more query-stage-level features; one use case is described in #44013 (comment).
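
To illustrate the idea of treating the final plan segment as a stage whose materialization is an asynchronous future, here is a minimal, self-contained sketch that uses only the JDK and the Scala standard library. `ToyPlan`, `ToyResultStage`, and `ResultStageSketch` are hypothetical names invented for this example and are not Spark classes. The sketch runs a result handler on a thread pool through `CompletableFuture.supplyAsync`, then bridges the Java future to a Scala `Future` through a `Promise`, unwrapping `CompletionException` on failure. It is a simplified model of the approach, not the actual implementation in this commit.

```scala
import java.util.concurrent.{CompletableFuture, CompletionException, Executors, ThreadFactory}
import java.util.function.{BiConsumer, Supplier}

import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration._

// Hypothetical stand-in for a physical plan; not a Spark class.
final case class ToyPlan(rows: Seq[Long])

object ToyResultStage {
  // Daemon threads so the pool does not keep the JVM alive after main returns.
  private val pool = Executors.newCachedThreadPool(new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "toy-result-stage")
      t.setDaemon(true)
      t
    }
  })
}

// Hypothetical stand-in for a result stage: a plan plus the handler that
// produces the final result (collect, count, and so on) from that plan.
final class ToyResultStage(plan: ToyPlan, resultHandler: ToyPlan => Any) {

  def materialize(): Future[Any] = {
    // Run the handler asynchronously on the shared pool.
    val javaFuture: CompletableFuture[Any] = CompletableFuture.supplyAsync(
      new Supplier[Any] { override def get(): Any = resultHandler(plan) },
      ToyResultStage.pool)

    // Bridge the Java future to a Scala future.
    val promise = Promise[Any]()
    javaFuture.whenComplete(new BiConsumer[Any, Throwable] {
      override def accept(result: Any, error: Throwable): Unit = {
        if (error != null) {
          // supplyAsync wraps handler failures in CompletionException;
          // surface the original cause to the caller instead.
          promise.failure(error match {
            case e: CompletionException if e.getCause != null => e.getCause
            case other => other
          })
        } else {
          promise.success(result)
        }
      }
    })
    promise.future
  }
}

object ResultStageSketch {
  def main(args: Array[String]): Unit = {
    val stage = new ToyResultStage(ToyPlan(Seq(1L, 2L, 3L)), plan => plan.rows.sum)
    println(Await.result(stage.materialize(), 10.seconds)) // prints 6
  }
}
```

Because the handler is supplied by the caller, the same toy plan can be materialized again with a different handler by constructing a new `ToyResultStage`, which loosely mirrors the way a new result stage is created on each result collection.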

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?
New unit tests.

Existing tests affected by this change were also updated to preserve their original test semantics.

### Was this patch authored or co-authored using generative AI tooling?
NO

Closes #49715 from liuzqt/SPARK-51008.

Lead-authored-by: liuzqt <liuzq12@hotmail.com>
Co-authored-by: Ziqi Liu <ziqi.liu@databricks.com>
Co-authored-by: Wenchen Fan <cloud0fan@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
3 people committed Feb 12, 2025
1 parent cb2732d commit 207390b
Showing 13 changed files with 270 additions and 125 deletions.
@@ -210,6 +210,15 @@ object StaticSQLConf {
.checkValue(thres => thres > 0 && thres <= 128, "The threshold must be in (0,128].")
.createWithDefault(16)

val RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD =
buildStaticConf("spark.sql.resultQueryStage.maxThreadThreshold")
.internal()
.doc("The maximum degree of parallelism to execute ResultQueryStageExec in AQE")
.version("4.0.0")
.intConf
.checkValue(thres => thres > 0 && thres <= 1024, "The threshold must be in (0,1024].")
.createWithDefault(1024)

val SQL_EVENT_TRUNCATE_LENGTH = buildStaticConf("spark.sql.event.truncate.length")
.doc("Threshold of SQL length beyond which it will be truncated before adding to " +
"event. Defaults to no truncation. If set to 0, callsite will be logged instead.")
@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution

import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, ExecutorService}
import java.util.concurrent.atomic.AtomicLong

import scala.jdk.CollectionConverters._
@@ -301,15 +301,15 @@ object SQLExecution extends Logging {
* SparkContext local properties are forwarded to execution thread
*/
def withThreadLocalCaptured[T](
sparkSession: SparkSession, exec: ExecutorService) (body: => T): JFuture[T] = {
sparkSession: SparkSession, exec: ExecutorService) (body: => T): CompletableFuture[T] = {
val activeSession = sparkSession
val sc = sparkSession.sparkContext
val localProps = Utils.cloneProperties(sc.getLocalProperties)
// `getCurrentJobArtifactState` will return a stat only in Spark Connect mode. In non-Connect
// mode, we default back to the resources of the current Spark session.
val artifactState = JobArtifactSet.getCurrentJobArtifactState.getOrElse(
activeSession.artifactManager.state)
exec.submit(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
CompletableFuture.supplyAsync(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
val originalSession = SparkSession.getActiveSession
val originalLocalProps = sc.getLocalProperties
SparkSession.setActiveSession(activeSession)
@@ -326,6 +326,6 @@ object SQLExecution extends Logging {
SparkSession.clearActiveSession()
}
res
})
}, exec)
}
}
@@ -268,9 +268,11 @@ case class AdaptiveSparkPlanExec(

def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity)

private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized {
if (isFinalPlan) return currentPhysicalPlan

/**
* Run `fun` on finalized physical plan
*/
def withFinalPlanUpdate[T](fun: SparkPlan => T): T = lock.synchronized {
_isFinalPlan = false
// In case of this adaptive plan being executed out of `withActive` scoped functions, e.g.,
// `plan.queryExecution.rdd`, we need to set active session here as new plan nodes can be
// created in the middle of the execution.
@@ -279,7 +281,7 @@ case class AdaptiveSparkPlanExec(
// Use inputPlan logicalLink here in case some top level physical nodes may be removed
// during `initialPlan`
var currentLogicalPlan = inputPlan.logicalLink.get
var result = createQueryStages(currentPhysicalPlan)
var result = createQueryStages(fun, currentPhysicalPlan, firstRun = true)
val events = new LinkedBlockingQueue[StageMaterializationEvent]()
val errors = new mutable.ArrayBuffer[Throwable]()
var stagesToReplace = Seq.empty[QueryStageExec]
@@ -344,56 +346,53 @@ case class AdaptiveSparkPlanExec(
if (errors.nonEmpty) {
cleanUpAndThrowException(errors.toSeq, None)
}

// Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
// than that of the current plan; otherwise keep the current physical plan together with
// the current logical plan since the physical plan's logical links point to the logical
// plan it has originated from.
// Meanwhile, we keep a list of the query stages that have been created since last plan
// update, which stands for the "semantic gap" between the current logical and physical
// plans. And each time before re-planning, we replace the corresponding nodes in the
// current logical plan with logical query stages to make it semantically in sync with
// the current physical plan. Once a new plan is adopted and both logical and physical
// plans are updated, we can clear the query stage list because at this point the two plans
// are semantically and physically in sync again.
val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
val afterReOptimize = reOptimize(logicalPlan)
if (afterReOptimize.isDefined) {
val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
if (newCost < origCost ||
(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
lazy val plans =
sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
cleanUpTempTags(newPhysicalPlan)
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
stagesToReplace = Seq.empty[QueryStageExec]
if (!currentPhysicalPlan.isInstanceOf[ResultQueryStageExec]) {
// Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
// than that of the current plan; otherwise keep the current physical plan together with
// the current logical plan since the physical plan's logical links point to the logical
// plan it has originated from.
// Meanwhile, we keep a list of the query stages that have been created since last plan
// update, which stands for the "semantic gap" between the current logical and physical
// plans. And each time before re-planning, we replace the corresponding nodes in the
// current logical plan with logical query stages to make it semantically in sync with
// the current physical plan. Once a new plan is adopted and both logical and physical
// plans are updated, we can clear the query stage list because at this point the two
// plans are semantically and physically in sync again.
val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
val afterReOptimize = reOptimize(logicalPlan)
if (afterReOptimize.isDefined) {
val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
if (newCost < origCost ||
(newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
lazy val plans = sideBySide(
currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
cleanUpTempTags(newPhysicalPlan)
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
stagesToReplace = Seq.empty[QueryStageExec]
}
}
}
// Now that some stages have finished, we can try creating new stages.
result = createQueryStages(currentPhysicalPlan)
result = createQueryStages(fun, currentPhysicalPlan, firstRun = false)
}

// Run the final plan when there's no more unfinished stages.
currentPhysicalPlan = applyPhysicalRules(
optimizeQueryStage(result.newPlan, isFinalStage = true),
postStageCreationRules(supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
_isFinalPlan = true
executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
currentPhysicalPlan
}
_isFinalPlan = true
finalPlanUpdate
// Dereference the result so it can be GCed. After this resultStage.isMaterialized will return
// false, which is expected. If we want to collect result again, we should invoke
// `withFinalPlanUpdate` and pass another result handler and we will create a new result stage.
currentPhysicalPlan.asInstanceOf[ResultQueryStageExec].resultOption.getAndUpdate(_ => None)
.get.asInstanceOf[T]
}

// Use a lazy val to avoid this being called more than once.
@transient private lazy val finalPlanUpdate: Unit = {
// Subqueries that don't belong to any query stage of the main query will execute after the
// last UI update in `getFinalPhysicalPlan`, so we need to update UI here again to make sure
// the newly generated nodes of those subqueries are updated.
if (shouldUpdatePlan && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
// Do final plan update after result stage has materialized.
if (shouldUpdatePlan) {
getExecutionId.foreach(onUpdatePlan(_, Seq.empty))
}
logOnLevel(log"Final plan:\n${MDC(QUERY_PLAN, currentPhysicalPlan)}")
@@ -426,13 +425,6 @@ case class AdaptiveSparkPlanExec(
}
}

private def withFinalPlanUpdate[T](fun: SparkPlan => T): T = {
val plan = getFinalPhysicalPlan()
val result = fun(plan)
finalPlanUpdate
result
}

protected override def stringArgs: Iterator[Any] = Iterator(s"isFinalPlan=$isFinalPlan")

override def generateTreeString(
@@ -521,6 +513,66 @@ case class AdaptiveSparkPlanExec(
this.inputPlan == obj.asInstanceOf[AdaptiveSparkPlanExec].inputPlan
}

/**
* We separate stage creation of result and non-result stages because there are several edge cases
* of result stage creation:
* - existing ResultQueryStage created in previous `withFinalPlanUpdate`.
* - the root node is a non-result query stage and we have to create query result stage on top of
* it.
* - we create a non-result query stage as root node and the stage is immediately materialized
* due to stage reuse, therefore we have to create a result stage right after.
*
* This method wraps around `createNonResultQueryStages`, the general logic is:
* - Early return if ResultQueryStageExec already created before.
* - Create non result query stage if possible.
* - Try to create result query stage when there is no new non-result query stage created and all
* stages are materialized.
*/
private def createQueryStages(
resultHandler: SparkPlan => Any,
plan: SparkPlan,
firstRun: Boolean): CreateStageResult = {
plan match {
// 1. ResultQueryStageExec is already created, no need to create non-result stages
case resultStage @ ResultQueryStageExec(_, optimizedPlan, _) =>
assertStageNotFailed(resultStage)
if (firstRun) {
// There is already an existing ResultQueryStage created in previous `withFinalPlanUpdate`
// e.g, when we do `df.collect` multiple times. Here we create a new result stage to
// execute it again, as the handler function can be different.
val newResultStage = ResultQueryStageExec(currentStageId, optimizedPlan, resultHandler)
currentStageId += 1
setLogicalLinkForNewQueryStage(newResultStage, optimizedPlan)
CreateStageResult(newPlan = newResultStage,
allChildStagesMaterialized = false,
newStages = Seq(newResultStage))
} else {
// We will hit this branch after we've created result query stage in the AQE loop, we
// should do nothing.
CreateStageResult(newPlan = resultStage,
allChildStagesMaterialized = resultStage.isMaterialized,
newStages = Seq.empty)
}
case _ =>
// 2. Create non result query stage
val result = createNonResultQueryStages(plan)
var allNewStages = result.newStages
var newPlan = result.newPlan
var allChildStagesMaterialized = result.allChildStagesMaterialized
// 3. Create result stage
if (allNewStages.isEmpty && allChildStagesMaterialized) {
val resultStage = newResultQueryStage(resultHandler, newPlan)
newPlan = resultStage
allChildStagesMaterialized = false
allNewStages :+= resultStage
}
CreateStageResult(
newPlan = newPlan,
allChildStagesMaterialized = allChildStagesMaterialized,
newStages = allNewStages)
}
}

/**
* This method is called recursively to traverse the plan tree bottom-up and create a new query
* stage or try reusing an existing stage if the current node is an [[Exchange]] node and all of
@@ -531,7 +583,7 @@ case class AdaptiveSparkPlanExec(
* 2) Whether the child query stages (if any) of the current node have all been materialized.
* 3) A list of the new query stages that have been created.
*/
private def createQueryStages(plan: SparkPlan): CreateStageResult = plan match {
private def createNonResultQueryStages(plan: SparkPlan): CreateStageResult = plan match {
case e: Exchange =>
// First have a quick check in the `stageCache` without having to traverse down the node.
context.stageCache.get(e.canonicalized) match {
@@ -544,7 +596,7 @@ case class AdaptiveSparkPlanExec(
newStages = if (isMaterialized) Seq.empty else Seq(stage))

case _ =>
val result = createQueryStages(e.child)
val result = createNonResultQueryStages(e.child)
val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange]
// Create a query stage only when all the child query stages are ready.
if (result.allChildStagesMaterialized) {
@@ -588,14 +640,28 @@ case class AdaptiveSparkPlanExec(
if (plan.children.isEmpty) {
CreateStageResult(newPlan = plan, allChildStagesMaterialized = true, newStages = Seq.empty)
} else {
val results = plan.children.map(createQueryStages)
val results = plan.children.map(createNonResultQueryStages)
CreateStageResult(
newPlan = plan.withNewChildren(results.map(_.newPlan)),
allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized),
newStages = results.flatMap(_.newStages))
}
}

private def newResultQueryStage(
resultHandler: SparkPlan => Any,
plan: SparkPlan): ResultQueryStageExec = {
// Run the final plan when there's no more unfinished stages.
val optimizedRootPlan = applyPhysicalRules(
optimizeQueryStage(plan, isFinalStage = true),
postStageCreationRules(supportsColumnar),
Some((planChangeLogger, "AQE Post Stage Creation")))
val resultStage = ResultQueryStageExec(currentStageId, optimizedRootPlan, resultHandler)
currentStageId += 1
setLogicalLinkForNewQueryStage(resultStage, plan)
resultStage
}

private def newQueryStage(plan: SparkPlan): QueryStageExec = {
val queryStage = plan match {
case e: Exchange =>
@@ -129,10 +129,12 @@ trait AdaptiveSparkPlanHelper {
}

/**
* Strip the executePlan of AdaptiveSparkPlanExec leaf node.
* Strip the top [[AdaptiveSparkPlanExec]] and [[ResultQueryStageExec]] nodes off
* the [[SparkPlan]].
*/
def stripAQEPlan(p: SparkPlan): SparkPlan = p match {
case a: AdaptiveSparkPlanExec => a.executedPlan
case a: AdaptiveSparkPlanExec => stripAQEPlan(a.executedPlan)
case ResultQueryStageExec(_, plan, _) => plan
case other => other
}
}
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.adaptive

import java.util.concurrent.atomic.AtomicReference

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise

import org.apache.spark.{MapOutputStatistics, SparkException}
import org.apache.spark.broadcast.Broadcast
@@ -32,7 +34,10 @@ import org.apache.spark.sql.columnar.CachedBatch
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
import org.apache.spark.sql.execution.exchange._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ThreadUtils

/**
* A query stage is an independent subgraph of the query plan. AQE framework will materialize its
@@ -303,3 +308,43 @@ case class TableCacheQueryStageExec(

override def getRuntimeStatistics: Statistics = inMemoryTableScan.runtimeStatistics
}

case class ResultQueryStageExec(
override val id: Int,
override val plan: SparkPlan,
resultHandler: SparkPlan => Any) extends QueryStageExec {

override def resetMetrics(): Unit = {
plan.resetMetrics()
}

override protected def doMaterialize(): Future[Any] = {
val javaFuture = SQLExecution.withThreadLocalCaptured(
session,
ResultQueryStageExec.executionContext) {
resultHandler(plan)
}
val scalaPromise: Promise[Any] = Promise()
javaFuture.whenComplete { (result: Any, exception: Throwable) =>
if (exception != null) {
scalaPromise.failure(exception match {
case completionException: java.util.concurrent.CompletionException =>
completionException.getCause
case ex => ex
})
} else {
scalaPromise.success(result)
}
}
scalaPromise.future
}

// Result stage could be any SparkPlan, so we don't have a specific runtime statistics for it.
override def getRuntimeStatistics: Statistics = Statistics.DUMMY
}

object ResultQueryStageExec {
private[execution] val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("ResultQueryStageExecution",
SQLConf.get.getConf(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD)))
}
@@ -106,7 +106,7 @@ object SparkPlanGraph {
buildSparkPlanGraphNode(
planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
}
case "TableCacheQueryStage" =>
case "TableCacheQueryStage" | "ResultQueryStage" =>
buildSparkPlanGraphNode(
planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
case "Subquery" if subgraph != null =>
@@ -1659,7 +1659,9 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
_.nodeName.contains("TableCacheQueryStage"))
val aqeNode = findNodeInSparkPlanInfo(inMemoryScanNode.get,
_.nodeName.contains("AdaptiveSparkPlan"))
aqeNode.get.children.head.nodeName == "AQEShuffleRead"
val aqePlanRoot = findNodeInSparkPlanInfo(inMemoryScanNode.get,
_.nodeName.contains("ResultQueryStage"))
aqePlanRoot.get.children.head.nodeName == "AQEShuffleRead"
}

withTempView("t0", "t1", "t2") {