Commit 7b7430f

Author: changgyoopark-db

Commit message: Impl

Parent: 1a49237

File tree: 5 files changed, +141 -112 lines


sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala

Lines changed: 5 additions & 5 deletions
```diff
@@ -68,12 +68,12 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     } else {
       DoNotCleanup
     }
+    val rel = request.getPlan.getRoot
     val dataframe =
-      Dataset.ofRows(
-        sessionHolder.session,
-        planner.transformRelation(request.getPlan.getRoot, cachePlan = true),
-        tracker,
-        shuffleCleanupMode)
+      sessionHolder
+        .updatePlanCache(
+          rel,
+          Dataset.ofRows(session, planner.transformRelation(rel), tracker, shuffleCleanupMode))
     responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
     processAsArrowBatches(dataframe, responseObserver, executeHolder)
     responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
```
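This call site captures the heart of the change: `transformRelation` no longer takes a `cachePlan` flag, and the cache write moves to the caller, after the `Dataset` has been constructed (and therefore analyzed). A minimal, self-contained sketch of the write-back shape, with a plain mutable map standing in for the server's plan cache and `String` standing in for `proto.Relation` and `LogicalPlan`:

```scala
import scala.collection.mutable

object WriteBackSketch {
  // Illustrative stand-in for the session's plan cache.
  private val cache = mutable.Map.empty[String, String]

  // Build the value first, then record it under its key and return it
  // unchanged, mirroring updatePlanCache's contract of returning the
  // supplied data frame.
  def writeBack(key: String, value: String): String = {
    cache.getOrElseUpdate(key, value)
    value
  }

  def main(args: Array[String]): Unit = {
    val df = writeBack("rel-1", "analyzed plan for rel-1")
    println(df) // the value passes through untouched
  }
}
```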

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 2 additions & 18 deletions
```diff
@@ -117,32 +117,16 @@ class SparkConnectPlanner(
   private lazy val pythonExec =
     sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))

-  /**
-   * The root of the query plan is a relation and we apply the transformations to it. The resolved
-   * logical plan will not get cached. If the result needs to be cached, use
-   * `transformRelation(rel, cachePlan = true)` instead.
-   * @param rel
-   *   The relation to transform.
-   * @return
-   *   The resolved logical plan.
-   */
-  @DeveloperApi
-  def transformRelation(rel: proto.Relation): LogicalPlan =
-    transformRelation(rel, cachePlan = false)
-
   /**
    * The root of the query plan is a relation and we apply the transformations to it.
    * @param rel
    *   The relation to transform.
-   * @param cachePlan
-   *   Set to true for a performance optimization, if the plan is likely to be reused, e.g. built
-   *   upon by further dataset transformation. The default is false.
    * @return
    *   The resolved logical plan.
    */
   @DeveloperApi
-  def transformRelation(rel: proto.Relation, cachePlan: Boolean): LogicalPlan = {
-    sessionHolder.usePlanCache(rel, cachePlan) { rel =>
+  def transformRelation(rel: proto.Relation): LogicalPlan = {
+    sessionHolder.usePlanCache(rel) { rel =>
       val plan = rel.getRelTypeCase match {
         // DataFrame API
         case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
```
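With the flagged overload removed, `transformRelation` is a pure read-through: it serves a cached plan on a hit and otherwise runs the transformation, but it never writes. A runnable sketch of the curried shape (simplified types; the real signature is `usePlanCache(rel: proto.Relation)(transform: proto.Relation => LogicalPlan): LogicalPlan`):

```scala
import scala.collection.mutable

object ReadThroughSketch {
  private val cache = mutable.Map.empty[String, String]

  // Read-through only: a miss runs the transform but does NOT populate the
  // cache; writing is now updatePlanCache's responsibility.
  def usePlanCache(rel: String)(transform: String => String): String =
    cache.getOrElse(rel, transform(rel))

  def main(args: Array[String]): Unit =
    println(usePlanCache("show_string")(r => s"transformed($r)")) // cache miss
}
```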

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala

Lines changed: 53 additions & 32 deletions
```diff
@@ -441,46 +441,67 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
    * `spark.connect.session.planCache.enabled` is true.
    * @param rel
    *   The relation to transform.
-   * @param cachePlan
-   *   Whether to cache the result logical plan.
    * @param transform
    *   Function to transform the relation into a logical plan.
    * @return
    *   The logical plan.
    */
-  private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
+  private[connect] def usePlanCache(rel: proto.Relation)(
       transform: proto.Relation => LogicalPlan): LogicalPlan = {
-    val planCacheEnabled = Option(session)
-      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
-    // We only cache plans that have a plan ID.
-    val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
-
-    def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
-      planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          Option(cache.getIfPresent(rel)) match {
-            case Some(plan) =>
-              logDebug(s"Using cached plan for relation '$rel': $plan")
-              Some(plan)
-            case None => None
-          }
-        case _ => None
-      }
-    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
-      planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          cache.put(rel, plan)
-        case _ =>
-      }
+    val cachedPlan = planCache match {
+      case Some(cache) if planCacheEnabled(rel) =>
+        Option(cache.getIfPresent(rel)) match {
+          case Some(plan) =>
+            logDebug(s"Using cached plan for relation '$rel': $plan")
+            Some(plan)
+          case None => None
+        }
+      case _ => None
+    }
+    cachedPlan.getOrElse(transform(rel))
+  }

-    getPlanCache(rel)
-      .getOrElse({
-        val plan = transform(rel)
-        if (cachePlan) {
-          putPlanCache(rel, plan)
+  /**
+   * Update the plan cache with the supplied data frame.
+   *
+   * @param rel
+   *   A proto.Relation that is used as the key for the cache.
+   * @param df
+   *   A data frame containing the corresponding logical plan.
+   * @return
+   *   The supplied data frame is returned.
+   */
+  private[connect] def updatePlanCache(rel: proto.Relation, df: DataFrame): DataFrame = {
+    if (planCache.isDefined && planCacheEnabled(rel)) {
+      val plan = if (df.queryExecution.isLazyAnalysis) {
+        // Try to cache the unanalyzed plan if the plan is intended to be lazily analyzed.
+        if (planCache.get.getIfPresent(rel) == null) {
+          Some(df.queryExecution.logical)
+        } else {
+          None
         }
-        plan
-      })
+      } else if (df.queryExecution.logical.analyzed) {
+        // The plan was analyzed during transformation or the cache was hit.
+        if (planCache.get.getIfPresent(rel) == null) {
+          Some(df.queryExecution.analyzed)
+        } else {
+          None
+        }
+      } else {
+        // Being not `isLazyAnalysis` and not analyzed implies that the plan is not in the cache.
+        Some(df.queryExecution.analyzed)
+      }
+      plan.foreach(p => planCache.get.put(rel, p))
+    }
+    df
+  }
+
+  // Return true if the plan cache is enabled for the session and the relation.
+  private def planCacheEnabled(rel: proto.Relation): Boolean = {
+    // We only cache plans that have a plan ID.
+    rel.hasCommon && rel.getCommon.hasPlanId &&
+    Option(session)
+      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
   }

   // For testing. Expose the plan cache for testing purposes.
```
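The branching in `updatePlanCache` decides which representation of the plan is worth storing. A simplified, runnable restatement of that decision, with booleans standing in for `df.queryExecution.isLazyAnalysis`, `df.queryExecution.logical.analyzed`, and a prior cache hit (a sketch of the logic only, not the committed code):

```scala
object CacheDecisionSketch {
  // Returns which plan to store, or None when an entry already exists.
  def planToCache(isLazyAnalysis: Boolean, analyzed: Boolean, cached: Boolean): Option[String] =
    if (isLazyAnalysis) {
      // Lazily analyzed plans are cached unanalyzed, unless an entry exists.
      if (cached) None else Some("logical")
    } else if (analyzed) {
      // Analyzed during transformation, or already served from the cache.
      if (cached) None else Some("analyzed")
    } else {
      // Not lazy and not yet analyzed implies the plan cannot be in the
      // cache, so it is analyzed on demand and stored.
      Some("analyzed")
    }

  def main(args: Array[String]): Unit =
    println(planToCache(isLazyAnalysis = false, analyzed = true, cached = false)) // Some(analyzed)
}
```

Note that the third branch leans on the invariant spelled out in the committed comment: a plan that is neither lazily analyzed nor already analyzed cannot be in the cache yet.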

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala

Lines changed: 49 additions & 27 deletions
```diff
@@ -59,12 +59,13 @@ private[connect] class SparkConnectAnalyzeHandler(
     val session = sessionHolder.session
     val builder = proto.AnalyzePlanResponse.newBuilder()

-    def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true)
+    def transformRelation(rel: proto.Relation) = planner.transformRelation(rel)

     request.getAnalyzeCase match {
       case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
-        val schema = Dataset
-          .ofRows(session, transformRelation(request.getSchema.getPlan.getRoot))
+        val rel = request.getSchema.getPlan.getRoot
+        val schema = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .schema
         builder.setSchema(
           proto.AnalyzePlanResponse.Schema
@@ -73,8 +74,9 @@
           .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
-        val queryExecution = Dataset
-          .ofRows(session, transformRelation(request.getExplain.getPlan.getRoot))
+        val rel = request.getExplain.getPlan.getRoot
+        val queryExecution = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .queryExecution
         val explainString = request.getExplain.getExplainMode match {
           case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
@@ -96,8 +98,9 @@
           .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
-        val schema = Dataset
-          .ofRows(session, transformRelation(request.getTreeString.getPlan.getRoot))
+        val rel = request.getTreeString.getPlan.getRoot
+        val schema = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .schema
         val treeString = if (request.getTreeString.hasLevel) {
           schema.treeString(request.getTreeString.getLevel)
@@ -111,8 +114,9 @@
           .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
-        val isLocal = Dataset
-          .ofRows(session, transformRelation(request.getIsLocal.getPlan.getRoot))
+        val rel = request.getIsLocal.getPlan.getRoot
+        val isLocal = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .isLocal
         builder.setIsLocal(
           proto.AnalyzePlanResponse.IsLocal
@@ -121,8 +125,9 @@
           .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
-        val isStreaming = Dataset
-          .ofRows(session, transformRelation(request.getIsStreaming.getPlan.getRoot))
+        val rel = request.getIsStreaming.getPlan.getRoot
+        val isStreaming = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .isStreaming
         builder.setIsStreaming(
           proto.AnalyzePlanResponse.IsStreaming
@@ -131,8 +136,9 @@
           .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
-        val inputFiles = Dataset
-          .ofRows(session, transformRelation(request.getInputFiles.getPlan.getRoot))
+        val rel = request.getInputFiles.getPlan.getRoot
+        val inputFiles = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .inputFiles
         builder.setInputFiles(
           proto.AnalyzePlanResponse.InputFiles
@@ -156,29 +162,37 @@
           .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
-        val target = Dataset.ofRows(
-          session,
-          transformRelation(request.getSameSemantics.getTargetPlan.getRoot))
-        val other = Dataset.ofRows(
-          session,
-          transformRelation(request.getSameSemantics.getOtherPlan.getRoot))
+        val targetRel = request.getSameSemantics.getTargetPlan.getRoot
+        val target = sessionHolder
+          .updatePlanCache(targetRel, Dataset.ofRows(session, transformRelation(targetRel)))
+        val otherRel = request.getSameSemantics.getOtherPlan.getRoot
+        val other = sessionHolder
+          .updatePlanCache(otherRel, Dataset.ofRows(session, transformRelation(otherRel)))
         builder.setSameSemantics(
           proto.AnalyzePlanResponse.SameSemantics
             .newBuilder()
             .setResult(target.sameSemantics(other)))

       case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
-        val semanticHash = Dataset
-          .ofRows(session, transformRelation(request.getSemanticHash.getPlan.getRoot))
+        val rel = request.getSemanticHash.getPlan.getRoot
+        val semanticHash = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
           .semanticHash()
         builder.setSemanticHash(
           proto.AnalyzePlanResponse.SemanticHash
             .newBuilder()
             .setResult(semanticHash))

       case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getPersist.getRelation))
+        val rel = request.getPersist.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         if (request.getPersist.hasStorageLevel) {
           target.persist(
             StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel))
@@ -188,8 +202,12 @@
         builder.setPersist(proto.AnalyzePlanResponse.Persist.newBuilder().build())

       case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getUnpersist.getRelation))
+        val rel = request.getUnpersist.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         if (request.getUnpersist.hasBlocking) {
           target.unpersist(request.getUnpersist.getBlocking)
         } else {
@@ -198,8 +216,12 @@
         builder.setUnpersist(proto.AnalyzePlanResponse.Unpersist.newBuilder().build())

       case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getGetStorageLevel.getRelation))
+        val rel = request.getGetStorageLevel.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         val storageLevel = target.storageLevel
         builder.setGetStorageLevel(
           proto.AnalyzePlanResponse.GetStorageLevel
```
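Each `AnalyzeCase` above now repeats the same shape: fetch the root relation, build a `Dataset`, and route it through `updatePlanCache`. A hypothetical local helper (not part of this commit, shown only to make the pattern explicit) could collapse each case to a single call:

```scala
// Hypothetical helper, reusing only names that appear in the diff above.
def analyzedDataFrame(rel: proto.Relation): DataFrame =
  sessionHolder.updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))

// The SCHEMA case, for example, would then read:
//   val schema = analyzedDataFrame(request.getSchema.getPlan.getRoot).schema
```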
