Skip to content

Commit e948a7c

Browse files
author
changgyoopark-db
committed
Impl
1 parent 454463b commit e948a7c

File tree

3 files changed

+29
-11
lines changed

3 files changed

+29
-11
lines changed

python/pyspark/sql/tests/test_dataframe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,9 @@ def test_extended_hint_types(self):
440440
self.assertIsInstance(df.hint("broadcast", ["foo", "bar"]), type(df))
441441

442442
with io.StringIO() as buf, redirect_stdout(buf):
443-
hinted_df.explain(True)
443+
# the plan cache may hold a fully analyzed plan
444+
with self.sql_conf({"spark.connect.session.planCache.enabled": False}):
445+
hinted_df.explain(True)
444446
explain_output = buf.getvalue()
445447
self.assertGreaterEqual(explain_output.count("1.2345"), 1)
446448
self.assertGreaterEqual(explain_output.count("what"), 1)

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import org.apache.spark.{SparkEnv, SparkException, SparkSQLException}
3232
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
3333
import org.apache.spark.connect.proto
3434
import org.apache.spark.internal.{Logging, LogKeys, MDC}
35-
import org.apache.spark.sql.DataFrame
35+
import org.apache.spark.sql.{DataFrame, Dataset}
3636
import org.apache.spark.sql.SparkSession
3737
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
3838
import org.apache.spark.sql.connect.common.InvalidPlanInput
@@ -450,14 +450,14 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
450450
*/
451451
private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
452452
transform: proto.Relation => LogicalPlan): LogicalPlan = {
453-
val planCacheEnabled = Option(session)
454-
.forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
455453
// We only cache plans that have a plan ID.
456-
val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
454+
val planCacheEnabled = rel.hasCommon && rel.getCommon.hasPlanId &&
455+
Option(session)
456+
.forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
457457

458458
def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
459459
planCache match {
460-
case Some(cache) if planCacheEnabled && hasPlanId =>
460+
case Some(cache) if planCacheEnabled =>
461461
Option(cache.getIfPresent(rel)) match {
462462
case Some(plan) =>
463463
logDebug(s"Using cached plan for relation '$rel': $plan")
@@ -466,20 +466,34 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
466466
}
467467
case _ => None
468468
}
469-
def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
469+
def putPlanCache(rel: proto.Relation, plan: LogicalPlan): LogicalPlan =
470470
planCache match {
471-
case Some(cache) if planCacheEnabled && hasPlanId =>
472-
cache.put(rel, plan)
473-
case _ =>
471+
case Some(cache) if planCacheEnabled =>
472+
val analyzedPlan = if (plan.analyzed) {
473+
plan
474+
} else {
475+
val qe = Dataset.ofRows(session, plan).queryExecution
476+
if (qe.isLazyAnalysis) {
477+
// The plan is intended to be lazily analyzed.
478+
plan
479+
} else {
480+
// Make sure that the plan is fully analyzed before being cached.
481+
qe.analyzed
482+
}
483+
}
484+
cache.put(rel, analyzedPlan)
485+
analyzedPlan
486+
case _ => plan
474487
}
475488

476489
getPlanCache(rel)
477490
.getOrElse({
478491
val plan = transform(rel)
479492
if (cachePlan) {
480493
putPlanCache(rel, plan)
494+
} else {
495+
plan
481496
}
482-
plan
483497
})
484498
}
485499

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,8 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
314314
case Some(expectedCachedRelations) =>
315315
val cachedRelations = sessionHolder.getPlanCache.get.asMap().keySet().asScala
316316
assert(cachedRelations.size == expectedCachedRelations.size)
317+
val cachedLogicalPlans = sessionHolder.getPlanCache.get.asMap().values().asScala
318+
cachedLogicalPlans.foreach(plan => assert(plan.analyzed))
317319
expectedCachedRelations.foreach(relation => assert(cachedRelations.contains(relation)))
318320
case None => assert(sessionHolder.getPlanCache.isEmpty)
319321
}

0 commit comments

Comments
 (0)