
Commit 2e14d6f

Author: changgyoopark-db
Commit message: Optimize
1 parent 692f1b6, commit 2e14d6f

File tree: 5 files changed, +132 -118 lines

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala

Lines changed: 3 additions & 6 deletions
@@ -27,8 +27,8 @@ import io.grpc.stub.StreamObserver
 import org.apache.spark.SparkEnv
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.ExecutePlanResponse
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.classic.{DataFrame, Dataset}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
@@ -68,12 +68,9 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
       } else {
         DoNotCleanup
       }
+    val rel = request.getPlan.getRoot
     val dataframe =
-      Dataset.ofRows(
-        sessionHolder.session,
-        planner.transformRelation(request.getPlan.getRoot, cachePlan = true),
-        tracker,
-        shuffleCleanupMode)
+      sessionHolder.createDataFrame(rel, planner, Some((tracker, shuffleCleanupMode)))
     responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
     processAsArrowBatches(dataframe, responseObserver, executeHolder)
     responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
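
Note on the hunk above: the execution path no longer builds the DataFrame through Dataset.ofRows with an explicit cachePlan flag; it delegates to sessionHolder.createDataFrame and passes the execution-only extras as an optional (tracker, shuffleCleanupMode) tuple. A minimal standalone sketch of that call shape follows; the types and names in it are illustrative assumptions, not the Spark Connect API.

// Illustrative sketch only: models passing execution-only extras as an Option-wrapped tuple
// to one shared factory, as the hunk above does with createDataFrame.
object ExecuteSketch {
  final case class Tracker(name: String)
  sealed trait CleanupMode
  case object DoNotCleanup extends CleanupMode

  // Stand-in for SessionHolder.createDataFrame: execution passes (tracker, cleanupMode),
  // analysis-only callers pass nothing.
  def createDataFrame(rel: String, options: Option[(Tracker, CleanupMode)] = None): String =
    options match {
      case Some((tracker, mode)) => s"df($rel, tracker=${tracker.name}, cleanup=$mode)"
      case None => s"df($rel)"
    }

  def main(args: Array[String]): Unit = {
    println(createDataFrame("root", Some((Tracker("t1"), DoNotCleanup)))) // execution path
    println(createDataFrame("root")) // analysis path, no options
  }
}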

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 8 additions & 12 deletions
@@ -120,31 +120,27 @@ class SparkConnectPlanner(
     sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))

   /**
-   * The root of the query plan is a relation and we apply the transformations to it. The resolved
-   * logical plan will not get cached. If the result needs to be cached, use
-   * `transformRelation(rel, cachePlan = true)` instead.
+   * The root of the query plan is a relation and we apply the transformations to it.
    * @param rel
    *   The relation to transform.
    * @return
    *   The resolved logical plan.
    */
   @DeveloperApi
-  def transformRelation(rel: proto.Relation): LogicalPlan =
-    transformRelation(rel, cachePlan = false)
+  def transformRelation(rel: proto.Relation): LogicalPlan = transformRelationWithCache(rel)._1

   /**
-   * The root of the query plan is a relation and we apply the transformations to it.
+   * The root of the query plan is a relation and we apply the transformations to it. If the
+   * relation exists in the plan cache, return the cached plan, but it does not update the plan
+   * cache.
    * @param rel
    *   The relation to transform.
-   * @param cachePlan
-   *   Set to true for a performance optimization, if the plan is likely to be reused, e.g. built
-   *   upon by further dataset transformation. The default is false.
    * @return
-   *   The resolved logical plan.
+   *   The resolved logical plan and a flag indicating that the cache was hit.
    */
   @DeveloperApi
-  def transformRelation(rel: proto.Relation, cachePlan: Boolean): LogicalPlan = {
-    sessionHolder.usePlanCache(rel, cachePlan) { rel =>
+  def transformRelationWithCache(rel: proto.Relation): (LogicalPlan, Boolean) = {
+    sessionHolder.usePlanCache(rel) { rel =>
       val plan = rel.getRelTypeCase match {
         // DataFrame API
         case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
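
Note on the hunk above: the two-argument transformRelation(rel, cachePlan) overload is gone; transformRelationWithCache returns the plan together with a cache-hit flag, and the old single-argument entry point stays as a thin wrapper that drops the flag. A standalone sketch of that wrapper pattern, with illustrative names only (not the Spark API):

// Illustrative sketch only: a tuple-returning lookup plus a compatibility wrapper.
object TransformSketch {
  type Plan = String

  // New-style entry point: the plan plus a flag indicating whether the cache was hit.
  def transformWithCache(rel: String, cached: Map[String, Plan]): (Plan, Boolean) =
    cached.get(rel) match {
      case Some(plan) => (plan, true)
      case None => (s"transformed($rel)", false)
    }

  // Old-style entry point, kept for callers that only need the plan.
  def transform(rel: String, cached: Map[String, Plan]): Plan =
    transformWithCache(rel, cached)._1

  def main(args: Array[String]): Unit = {
    val cached = Map("r1" -> "plan1")
    println(transformWithCache("r1", cached)) // (plan1,true)
    println(transform("r2", cached)) // transformed(r2)
  }
}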

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala

Lines changed: 67 additions & 35 deletions
@@ -32,15 +32,19 @@ import org.apache.spark.{SparkEnv, SparkException, SparkSQLException}
 import org.apache.spark.api.python.PythonFunction.PythonAccumulator
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.QueryPlanningTracker
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.classic.{Dataset, SparkSession}
 import org.apache.spark.sql.connect.common.InvalidPlanInput
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.ml.MLCache
 import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
 import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC}
+import org.apache.spark.sql.execution.{CommandExecutionMode, ShuffleCleanupMode}
 import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.util.{SystemClock, Utils}

@@ -440,46 +444,74 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
    * `spark.connect.session.planCache.enabled` is true.
    * @param rel
    *   The relation to transform.
-   * @param cachePlan
-   *   Whether to cache the result logical plan.
    * @param transform
    *   Function to transform the relation into a logical plan.
    * @return
-   *   The logical plan.
+   *   The logical plan and a flag indicating that the plan cache was hit.
    */
-  private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
-      transform: proto.Relation => LogicalPlan): LogicalPlan = {
-    val planCacheEnabled = Option(session)
-      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
-    // We only cache plans that have a plan ID.
-    val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
-
-    def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
-      planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          Option(cache.getIfPresent(rel)) match {
-            case Some(plan) =>
+  private[connect] def usePlanCache(rel: proto.Relation)(
+      transform: proto.Relation => LogicalPlan): (LogicalPlan, Boolean) = {
+    planCache match {
+      case Some(cache) if canCachePlan(rel) =>
+        Option(cache.getIfPresent(rel)) match {
+          case Some(plan) =>
+            if (isPlanOutdated(plan)) {
+              // The plan is outdated, therefore remove it from the cache.
+              cache.invalidate(rel)
+            } else {
               logDebug(s"Using cached plan for relation '$rel': $plan")
-              Some(plan)
-            case None => None
-          }
-        case _ => None
-      }
-    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
-      planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          cache.put(rel, plan)
-        case _ =>
+              return (plan, true)
+            }
+          case None => ()
+        }
+      case _ => ()
+    }
+    (transform(rel), false)
+  }
+
+  /**
+   * Create a data frame from the supplied relation, and update the plan cache.
+   *
+   * @param rel
+   *   A proto.Relation to create a data frame.
+   * @param options
+   *   Options to pass to the data frame.
+   * @return
+   *   The created data frame.
+   */
+  private[connect] def createDataFrame(
+      rel: proto.Relation,
+      planner: SparkConnectPlanner,
+      options: Option[(QueryPlanningTracker, ShuffleCleanupMode)] = None): DataFrame = {
+    val (plan, cacheHit) = planner.transformRelationWithCache(rel)
+    val qe = session.sessionState.executePlan(plan, CommandExecutionMode.SKIP)
+    val df = new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
+    if (!cacheHit && planCache.isDefined && canCachePlan(rel)) {
+      if (df.queryExecution.isLazyAnalysis) {
+        val plan = df.queryExecution.logical
+        logDebug(s"Cache a lazyily analyzed logical plan for '$rel': $plan")
+        planCache.get.put(rel, plan)
+      } else {
+        val plan = df.queryExecution.analyzed
+        logDebug(s"Cache an analyzed logical plan for '$rel': $plan")
+        planCache.get.put(rel, plan)
       }
+    }
+    df
+  }

-    getPlanCache(rel)
-      .getOrElse({
-        val plan = transform(rel)
-        if (cachePlan) {
-          putPlanCache(rel, plan)
-        }
-        plan
-      })
+  // Return true if the plan is outdated and should be removed from the cache.
+  private def isPlanOutdated(plan: LogicalPlan): Boolean = {
+    // Currently, nothing is checked.
+    false
+  }
+
+  // Return true if the plan cache is enabled for the session and the relation.
+  private def canCachePlan(rel: proto.Relation): Boolean = {
+    // We only cache plans that have a plan ID.
+    rel.hasCommon && rel.getCommon.hasPlanId &&
+    Option(session)
+      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
   }

   // For testing. Expose the plan cache for testing purposes.
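
Note on the hunks above: SessionHolder now splits the cache into a read-only lookup (usePlanCache, which reports whether it hit) and a factory (createDataFrame) that writes to the cache only on a miss, gated by canCachePlan. A self-contained sketch of that read/write split follows; every name in it is an illustrative assumption, not the actual Spark Connect code.

// Illustrative sketch only: lookup never writes; the factory populates the cache on a miss
// and only for relations carrying a plan ID.
object PlanCacheSketch {
  final case class Relation(planId: Option[Long], body: String)
  final case class Plan(text: String)

  private val cache = scala.collection.mutable.Map.empty[Relation, Plan]

  // Read-only lookup: return the cached plan if present, otherwise transform, and report the hit.
  def usePlanCache(rel: Relation)(transform: Relation => Plan): (Plan, Boolean) =
    cache.get(rel) match {
      case Some(plan) => (plan, true)
      case None => (transform(rel), false)
    }

  // Factory: build the result, then cache it only when the lookup missed and caching is allowed.
  def createResult(rel: Relation)(transform: Relation => Plan): Plan = {
    val (plan, cacheHit) = usePlanCache(rel)(transform)
    if (!cacheHit && rel.planId.isDefined) cache.update(rel, plan)
    plan
  }

  def main(args: Array[String]): Unit = {
    val rel = Relation(planId = Some(1L), body = "select 1")
    val analyze: Relation => Plan = r => Plan(s"analyzed(${r.body})")
    println(createResult(rel)(analyze)) // first call: cache miss, plan is computed and stored
    println(usePlanCache(rel)(analyze)) // second call: (Plan(analyzed(select 1)),true)
  }
}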

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala

Lines changed: 29 additions & 36 deletions
@@ -23,13 +23,9 @@ import io.grpc.stub.StreamObserver

 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.classic.{DataFrame, Dataset}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
-import org.apache.spark.sql.execution.{CodegenMode, CommandExecutionMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
+import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.ArrayImplicits._

@@ -62,25 +58,20 @@ private[connect] class SparkConnectAnalyzeHandler(
     val session = sessionHolder.session
     val builder = proto.AnalyzePlanResponse.newBuilder()

-    def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true)
-
-    def getDataFrameWithoutExecuting(rel: LogicalPlan): DataFrame = {
-      val qe = session.sessionState.executePlan(rel, CommandExecutionMode.SKIP)
-      new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
-    }
-
     request.getAnalyzeCase match {
       case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
-        val rel = transformRelation(request.getSchema.getPlan.getRoot)
-        val schema = getDataFrameWithoutExecuting(rel).schema
+        val schema =
+          sessionHolder.createDataFrame(request.getSchema.getPlan.getRoot, planner).schema
         builder.setSchema(
           proto.AnalyzePlanResponse.Schema
             .newBuilder()
             .setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
             .build())
       case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
-        val rel = transformRelation(request.getExplain.getPlan.getRoot)
-        val queryExecution = getDataFrameWithoutExecuting(rel).queryExecution
+        val queryExecution =
+          sessionHolder
+            .createDataFrame(request.getExplain.getPlan.getRoot, planner)
+            .queryExecution
         val explainString = request.getExplain.getExplainMode match {
           case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
             queryExecution.explainString(SimpleMode)
@@ -101,8 +92,8 @@ private[connect] class SparkConnectAnalyzeHandler(
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
-        val rel = transformRelation(request.getTreeString.getPlan.getRoot)
-        val schema = getDataFrameWithoutExecuting(rel).schema
+        val schema =
+          sessionHolder.createDataFrame(request.getTreeString.getPlan.getRoot, planner).schema
         val treeString = if (request.getTreeString.hasLevel) {
           schema.treeString(request.getTreeString.getLevel)
         } else {
@@ -115,26 +106,28 @@ private[connect] class SparkConnectAnalyzeHandler(
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
-        val rel = transformRelation(request.getIsLocal.getPlan.getRoot)
-        val isLocal = getDataFrameWithoutExecuting(rel).isLocal
+        val isLocal =
+          sessionHolder.createDataFrame(request.getIsLocal.getPlan.getRoot, planner).isLocal
         builder.setIsLocal(
           proto.AnalyzePlanResponse.IsLocal
             .newBuilder()
             .setIsLocal(isLocal)
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
-        val rel = transformRelation(request.getIsStreaming.getPlan.getRoot)
-        val isStreaming = getDataFrameWithoutExecuting(rel).isStreaming
+        val isStreaming =
+          sessionHolder
+            .createDataFrame(request.getIsStreaming.getPlan.getRoot, planner)
+            .isStreaming
         builder.setIsStreaming(
           proto.AnalyzePlanResponse.IsStreaming
             .newBuilder()
             .setIsStreaming(isStreaming)
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
-        val rel = transformRelation(request.getInputFiles.getPlan.getRoot)
-        val inputFiles = getDataFrameWithoutExecuting(rel).inputFiles
+        val inputFiles =
+          sessionHolder.createDataFrame(request.getInputFiles.getPlan.getRoot, planner).inputFiles
         builder.setInputFiles(
           proto.AnalyzePlanResponse.InputFiles
             .newBuilder()
@@ -157,27 +150,27 @@ private[connect] class SparkConnectAnalyzeHandler(
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
-        val targetRel = transformRelation(request.getSameSemantics.getTargetPlan.getRoot)
-        val otherRel = transformRelation(request.getSameSemantics.getOtherPlan.getRoot)
-        val target = getDataFrameWithoutExecuting(targetRel)
-        val other = getDataFrameWithoutExecuting(otherRel)
+        val target =
+          sessionHolder.createDataFrame(request.getSameSemantics.getTargetPlan.getRoot, planner)
+        val other =
+          sessionHolder.createDataFrame(request.getSameSemantics.getOtherPlan.getRoot, planner)
         builder.setSameSemantics(
           proto.AnalyzePlanResponse.SameSemantics
             .newBuilder()
             .setResult(target.sameSemantics(other)))

       case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
-        val rel = transformRelation(request.getSemanticHash.getPlan.getRoot)
-        val semanticHash = getDataFrameWithoutExecuting(rel)
+        val semanticHash = sessionHolder
+          .createDataFrame(request.getSemanticHash.getPlan.getRoot, planner)
           .semanticHash()
         builder.setSemanticHash(
           proto.AnalyzePlanResponse.SemanticHash
             .newBuilder()
             .setResult(semanticHash))

       case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST =>
-        val rel = transformRelation(request.getPersist.getRelation)
-        val target = getDataFrameWithoutExecuting(rel)
+        val target = sessionHolder
+          .createDataFrame(request.getPersist.getRelation, planner)
         if (request.getPersist.hasStorageLevel) {
           target.persist(
             StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel))
@@ -187,8 +180,8 @@ private[connect] class SparkConnectAnalyzeHandler(
         builder.setPersist(proto.AnalyzePlanResponse.Persist.newBuilder().build())

       case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST =>
-        val rel = transformRelation(request.getUnpersist.getRelation)
-        val target = getDataFrameWithoutExecuting(rel)
+        val target = sessionHolder
+          .createDataFrame(request.getUnpersist.getRelation, planner)
         if (request.getUnpersist.hasBlocking) {
           target.unpersist(request.getUnpersist.getBlocking)
         } else {
@@ -197,8 +190,8 @@ private[connect] class SparkConnectAnalyzeHandler(
         builder.setUnpersist(proto.AnalyzePlanResponse.Unpersist.newBuilder().build())

       case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL =>
-        val rel = transformRelation(request.getGetStorageLevel.getRelation)
-        val target = getDataFrameWithoutExecuting(rel)
+        val target = sessionHolder
+          .createDataFrame(request.getGetStorageLevel.getRelation, planner)
         val storageLevel = target.storageLevel
         builder.setGetStorageLevel(
           proto.AnalyzePlanResponse.GetStorageLevel
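
Note on the hunks above: each analyze case drops its local transformRelation and getDataFrameWithoutExecuting helpers and builds its DataFrame through the same sessionHolder.createDataFrame call. A small illustrative sketch of that consolidation (names are assumptions, not the Spark Connect API):

// Illustrative sketch only: every analyze case funnels through one construction path.
object AnalyzeSketch {
  final case class Df(rel: String) {
    def schema: String = s"schema($rel)"
    def isLocal: Boolean = true
  }

  // Stand-in for sessionHolder.createDataFrame(rel, planner).
  def createDataFrame(rel: String): Df = Df(rel)

  def handle(analyzeCase: String, rel: String): String = analyzeCase match {
    case "SCHEMA" => createDataFrame(rel).schema
    case "IS_LOCAL" => createDataFrame(rel).isLocal.toString
    case other => s"unsupported: $other"
  }

  def main(args: Array[String]): Unit = {
    println(handle("SCHEMA", "root")) // schema(root)
    println(handle("IS_LOCAL", "root")) // true
  }
}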
