
Commit 44d9fce

[SPARK-52492][SQL] Make InMemoryRelation.convertToColumnarIfPossible customizable
https://issues.apache.org/jira/browse/SPARK-52492

### What changes were proposed in this pull request?

This PR moves `InMemoryRelation.convertToColumnarIfPossible` to `CachedBatchSerializer` as a public API.

### Why are the changes needed?

TL;DR: so that plugins like Gluten can customize the relevant logic for their own cache serializers.

Currently, `InMemoryRelation.convertToColumnarIfPossible` is tightly coupled with vanilla Spark's columnar processing logic. It unwraps the input columnar plan by removing the topmost `ColumnarToRowExec`, then assumes that the resulting RDD can be recognized by the user-customized cache serializer. But sometimes this assumption is invalid. In the Apache Gluten project, we may need to keep distinguishing between plans that all have `supportsColumnar=true` but emit different columnar batch types. So even after the topmost `ColumnarToRowExec` is removed, we still don't know whether the unwrapped columnar RDD can be accepted by Gluten's cache serializer (assuming it only handles one particular type of columnar batch). Gluten therefore carried a rule to work around the logic in `InMemoryRelation.convertToColumnarIfPossible`: https://github.com/apache/incubator-gluten/blob/c6461b4e0c7d3022a31fa832aeab588b1a3200e6/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala#L192-L217. That was the best workaround we could come up with, but it is still not elegant; the rule is even caller-sensitive, since it has to determine whether it is being invoked during the cache planning process or not.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

With the added UTs.

Closes #51189 from zhztheplayer/wip-cache-customize-unwrap.

Lead-authored-by: Hongze Zhang <[email protected]>
Co-authored-by: Kent Yao <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
1 parent e8d6f54 commit 44d9fce
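For illustration, here is a sketch of how a plugin-side serializer could use the new hook, in the spirit of the Gluten use case described above. `PluginCachedBatchSerializer` and `acceptsBatchTypeOf` are hypothetical names, and the sketch extends the built-in `DefaultCachedBatchSerializer`; a complete plugin would also implement the columnar write path (`convertColumnarBatchToCachedBatch` and friends) for the batch type it accepts, which is elided here.

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer

// Hypothetical plugin serializer: only strip the topmost columnar-to-row
// transition when the exposed columnar plan is one this serializer can cache.
class PluginCachedBatchSerializer extends DefaultCachedBatchSerializer {

  // Advertise columnar input so Spark calls convertToColumnarPlanIfPossible
  // at cache-planning time (the default serializer declines columnar input).
  override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = true

  // Placeholder check; a real plugin would match its own ColumnarBatch flavor.
  private def acceptsBatchTypeOf(plan: SparkPlan): Boolean = plan.supportsColumnar

  override def convertToColumnarPlanIfPossible(plan: SparkPlan): SparkPlan = {
    // Reuse the default unwrapping, then veto it when the unwrapped plan's
    // batch type is not recognized, falling back to the row-based input plan.
    val unwrapped = super.convertToColumnarPlanIfPossible(plan)
    if ((unwrapped eq plan) || acceptsBatchTypeOf(unwrapped)) unwrapped else plan
  }
}
```

The point is that the serializer itself now decides whether the unwrapped plan is acceptable, instead of a caller-sensitive planner rule doing so from the outside.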

File tree: 3 files changed (+100, -20 lines)

sql/core/src/main/scala/org/apache/spark/sql/columnar/CachedBatchSerializer.scala

Lines changed: 36 additions & 0 deletions
```diff
@@ -24,6 +24,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, BindReferences, EqualNullSafe, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, Length, LessThan, LessThanOrEqual, Literal, Or, Predicate, StartsWith}
+import org.apache.spark.sql.execution.{ColumnarToRowTransition, InputAdapter, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.columnar.{ColumnStatisticsSchema, PartitionStatistics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{AtomicType, BinaryType, StructType}
@@ -58,6 +60,40 @@ trait CachedBatchSerializer extends Serializable {
    */
   def supportsColumnarInput(schema: Seq[Attribute]): Boolean
 
+  /**
+   * Attempt to convert a query plan to its columnar equivalence for columnar caching.
+   * Called on the query plan that is about to be cached once [[supportsColumnarInput]] returns
+   * true on its output schema.
+   *
+   * The default implementation works by stripping the topmost columnar-to-row transition to
+   * expose the columnar-based plan to the serializer.
+   *
+   * @param plan The plan to convert.
+   * @return The output plan. Could either be a columnar plan if the input plan is convertible, or
+   *         the input plan unchanged if no viable conversion can be done.
+   */
+  @DeveloperApi
+  @Since("4.1.0")
+  def convertToColumnarPlanIfPossible(plan: SparkPlan): SparkPlan = plan match {
+    case gen: WholeStageCodegenExec =>
+      gen.child match {
+        case c2r: ColumnarToRowTransition =>
+          c2r.child match {
+            case ia: InputAdapter => ia.child
+            case _ => plan
+          }
+        case _ => plan
+      }
+    case c2r: ColumnarToRowTransition => // This matches when whole stage code gen is disabled.
+      c2r.child
+    case adaptive: AdaptiveSparkPlanExec =>
+      // If AQE is enabled for cached plan and table cache supports columnar in, we should mark
+      // `AdaptiveSparkPlanExec.supportsColumnar` as true to avoid inserting `ColumnarToRow`, so
+      // that `CachedBatchSerializer` can use `convertColumnarBatchToCachedBatch` to cache data.
+      adaptive.copy(supportsColumnar = true)
+    case _ => plan
+  }
+
   /**
    * Convert an `RDD[InternalRow]` into an `RDD[CachedBatch]` in preparation for caching the data.
    * @param input the input `RDD` to be converted.
```
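A side note on wiring: the serializer class is resolved from the static configuration `spark.sql.cache.serializer` (`StaticSQLConf.SPARK_CACHE_SERIALIZER`, also used by the new test suite below), so a custom implementation such as the hypothetical `PluginCachedBatchSerializer` sketched earlier must be registered before the session starts. A minimal sketch:

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch: register a custom CachedBatchSerializer. The key is a static
// conf, so it must be set at session (or cluster) startup, not changed later.
val spark = SparkSession.builder()
  .master("local[2]")
  // Hypothetical class name from the sketch above.
  .config("spark.sql.cache.serializer", "org.example.PluginCachedBatchSerializer")
  .getOrCreate()

// From here on, df.cache() routes cache planning through the custom
// serializer's convertToColumnarPlanIfPossible.
```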

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala

Lines changed: 6 additions & 19 deletions
```diff
@@ -381,22 +381,8 @@ object InMemoryRelation {
   /* Visible for testing */
   private[columnar] def clearSerializer(): Unit = synchronized { ser = None }
 
-  def convertToColumnarIfPossible(plan: SparkPlan): SparkPlan = plan match {
-    case gen: WholeStageCodegenExec => gen.child match {
-      case c2r: ColumnarToRowTransition => c2r.child match {
-        case ia: InputAdapter => ia.child
-        case _ => plan
-      }
-      case _ => plan
-    }
-    case c2r: ColumnarToRowTransition => // This matches when whole stage code gen is disabled.
-      c2r.child
-    case adaptive: AdaptiveSparkPlanExec =>
-      // If AQE is enabled for cached plan and table cache supports columnar in, we should mark
-      // `AdaptiveSparkPlanExec.supportsColumnar` as true to avoid inserting `ColumnarToRow`, so
-      // that `CachedBatchSerializer` can use `convertColumnarBatchToCachedBatch` to cache data.
-      adaptive.copy(supportsColumnar = true)
-    case _ => plan
+  def convertToColumnarIfPossible(plan: SparkPlan): SparkPlan = {
+    getSerializer(plan.conf).convertToColumnarPlanIfPossible(plan)
   }
 
   def apply(
@@ -406,7 +392,7 @@
     val optimizedPlan = qe.optimizedPlan
     val serializer = getSerializer(optimizedPlan.conf)
     val child = if (serializer.supportsColumnarInput(optimizedPlan.output)) {
-      convertToColumnarIfPossible(qe.executedPlan)
+      serializer.convertToColumnarPlanIfPossible(qe.executedPlan)
     } else {
       qe.executedPlan
     }
@@ -433,8 +419,9 @@
 
   def apply(cacheBuilder: CachedRDDBuilder, qe: QueryExecution): InMemoryRelation = {
     val optimizedPlan = qe.optimizedPlan
-    val newBuilder = if (cacheBuilder.serializer.supportsColumnarInput(optimizedPlan.output)) {
-      cacheBuilder.copy(cachedPlan = convertToColumnarIfPossible(qe.executedPlan))
+    val serializer = cacheBuilder.serializer
+    val newBuilder = if (serializer.supportsColumnarInput(optimizedPlan.output)) {
+      cacheBuilder.copy(cachedPlan = serializer.convertToColumnarPlanIfPossible(qe.executedPlan))
     } else {
       cacheBuilder.copy(cachedPlan = qe.executedPlan)
     }
```

sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/CachedBatchSerializerSuite.scala

Lines changed: 58 additions & 1 deletion
```diff
@@ -25,7 +25,7 @@ import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection}
 import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer}
-import org.apache.spark.sql.execution.ColumnarToRowExec
+import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation.clearSerializer
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
@@ -122,6 +122,25 @@ class TestSingleIntColumnarCachedBatchSerializer extends CachedBatchSerializer {
   }
 }
 
+/**
+ * An equivalence of Spark's [[DefaultCachedBatchSerializer]] while the API
+ * [[convertToColumnarPlanIfPossible]] is being tested.
+ */
+class DefaultCachedBatchSerializerNoUnwrap extends DefaultCachedBatchSerializer {
+  override def supportsColumnarInput(schema: Seq[Attribute]): Boolean = {
+    // Return true to let Spark call #convertToColumnarPlanIfPossible to unwrap the input
+    // columnar plan out from the guard of the topmost ColumnarToRowExec.
+    true
+  }
+
+  override def convertToColumnarPlanIfPossible(plan: SparkPlan): SparkPlan = {
+    assert(!plan.supportsColumnar)
+    // Disable the unwrapping code path from default CachedBatchSerializer so
+    // Spark will keep the topmost columnar-to-row plan node.
+    plan
+  }
+}
+
 class CachedBatchSerializerSuite extends QueryTest
   with SharedSparkSession with AdaptiveSparkPlanHelper {
   import testImplicits._
@@ -180,3 +199,41 @@ class CachedBatchSerializerSuite extends QueryTest
     }
   }
 }
+
+
+class CachedBatchSerializerNoUnwrapSuite extends QueryTest
+  with SharedSparkSession with AdaptiveSparkPlanHelper {
+
+  import testImplicits._
+
+  override protected def sparkConf: SparkConf = {
+    super.sparkConf.set(
+      StaticSQLConf.SPARK_CACHE_SERIALIZER.key,
+      classOf[DefaultCachedBatchSerializerNoUnwrap].getName)
+  }
+
+  test("Do not unwrap ColumnarToRowExec") {
+    withTempPath { workDir =>
+      val workDirPath = workDir.getAbsolutePath
+      val input = Seq(100, 200).toDF("count")
+      input.write.parquet(workDirPath)
+      val data = spark.read.parquet(workDirPath)
+      data.cache()
+      val df = data.union(data)
+      assert(df.count() == 4)
+      checkAnswer(df, Row(100) :: Row(200) :: Row(100) :: Row(200) :: Nil)
+
+      val finalPlan = df.queryExecution.executedPlan
+      val cachedPlans = finalPlan.collect {
+        case i: InMemoryTableScanExec => i.relation.cachedPlan
+      }
+      assert(cachedPlans.length == 2)
+      cachedPlans.foreach {
+        cachedPlan =>
+          assert(cachedPlan.isInstanceOf[WholeStageCodegenExec])
+          assert(cachedPlan.asInstanceOf[WholeStageCodegenExec]
+            .child.isInstanceOf[ColumnarToRowExec])
+      }
+    }
+  }
+}
```
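As a usage note, the added suite can be run from a Spark dev checkout with the usual sbt test filter, e.g.:

```
build/sbt "sql/testOnly *CachedBatchSerializerNoUnwrapSuite"
```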
