
Commit db8db47

Author: changgyoopark-db (committed)
Commit message: Impl
1 parent 6d66f26 · commit db8db47

2 files changed: +19 −9 lines changed


sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala

Lines changed: 17 additions & 9 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.{SparkEnv, SparkException, SparkSQLException}
 import org.apache.spark.api.python.PythonFunction.PythonAccumulator
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connect.common.InvalidPlanInput
@@ -450,14 +450,14 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
    */
   private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
       transform: proto.Relation => LogicalPlan): LogicalPlan = {
-    val planCacheEnabled = Option(session)
-      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
     // We only cache plans that have a plan ID.
-    val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
+    val planCacheEnabled = rel.hasCommon && rel.getCommon.hasPlanId &&
+      Option(session)
+        .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
 
     def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
       planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
+        case Some(cache) if planCacheEnabled =>
           Option(cache.getIfPresent(rel)) match {
             case Some(plan) =>
               logDebug(s"Using cached plan for relation '$rel': $plan")
@@ -466,18 +466,26 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
           }
         case _ => None
       }
-    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
+    def putPlanCache(rel: proto.Relation, resolvePlan: => LogicalPlan): Unit =
       planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          cache.put(rel, plan)
+        case Some(cache) if planCacheEnabled =>
+          cache.put(rel, resolvePlan)
         case _ =>
       }
 
     getPlanCache(rel)
       .getOrElse({
         val plan = transform(rel)
         if (cachePlan) {
-          putPlanCache(rel, plan)
+          putPlanCache(
+            rel, {
+              if (plan.resolved) {
+                plan
+              } else {
+                // Make sure that the plan is fully analyzed before being cached.
+                Dataset.ofRows(session, plan).logicalPlan
+              }
+            })
         }
         plan
       })
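
Note on the change above: putPlanCache now takes its second argument by name (resolvePlan: => LogicalPlan), so the extra analysis pass via Dataset.ofRows only runs on the branch that actually stores the plan, i.e. when a cache exists and plan caching is enabled. A minimal, self-contained Scala sketch of that evaluation behaviour; the Plan and analyze names below are hypothetical stand-ins, not Spark APIs:

object ByNameCachingSketch {
  // Hypothetical stand-ins for LogicalPlan and the analyzer.
  final case class Plan(resolved: Boolean)

  var analysisRuns = 0
  def analyze(p: Plan): Plan = { analysisRuns += 1; p.copy(resolved = true) }

  // Mirrors the new signature: the by-name argument is evaluated only where it
  // is referenced, inside the branch that stores it.
  def putPlanCache(cacheEnabled: Boolean)(resolvePlan: => Plan): Unit =
    if (cacheEnabled) {
      val analyzed = resolvePlan // evaluation happens here, and only here
      println(s"cached: $analyzed")
    }

  def main(args: Array[String]): Unit = {
    val parsed = Plan(resolved = false)
    putPlanCache(cacheEnabled = false) { analyze(parsed) } // analyze() never runs
    putPlanCache(cacheEnabled = true) { analyze(parsed) }  // analyze() runs once
    println(s"analysisRuns = $analysisRuns")               // prints 1
  }
}

With a strict parameter, the Dataset.ofRows re-analysis in the diff above would be paid even when the plan cache is disabled or absent; the by-name form avoids that cost.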

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala

Lines changed: 2 additions & 0 deletions
@@ -314,6 +314,8 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
       case Some(expectedCachedRelations) =>
         val cachedRelations = sessionHolder.getPlanCache.get.asMap().keySet().asScala
         assert(cachedRelations.size == expectedCachedRelations.size)
+        val cachedLogicalPlans = sessionHolder.getPlanCache.get.asMap().values().asScala
+        cachedLogicalPlans.foreach(plan => assert(plan.resolved))
         expectedCachedRelations.foreach(relation => assert(cachedRelations.contains(relation)))
       case None => assert(sessionHolder.getPlanCache.isEmpty)
     }
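
The added assertions check the invariant that every value stored in the plan cache is a fully analyzed plan. A hedged, standalone illustration of that invariant using plain Spark SQL rather than the Connect test harness (local SparkSession and example query chosen here for illustration only):

import org.apache.spark.sql.SparkSession

object ResolvedPlanCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("resolved-check").getOrCreate()
    val df = spark.sql("SELECT id FROM range(3)")
    // The raw parsed plan still contains unresolved relations and attributes.
    println(s"parsed.resolved   = ${df.queryExecution.logical.resolved}")  // false
    // After analysis (the step Dataset.ofRows performs) the plan reports resolved == true,
    // which is what the new test assertions expect of every cached plan.
    println(s"analyzed.resolved = ${df.queryExecution.analyzed.resolved}") // true
    spark.stop()
  }
}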
