Commit eede681

Convert nested fields to UnresolvedAttributes to avoid schema pruning mismatches
1 parent 475451f commit eede681

File tree

3 files changed: +82 -2 lines changed

core/src/main/scala/org/apache/spark/sql/delta/stats/DeltaScan.scala

Lines changed: 52 additions & 0 deletions
@@ -20,7 +20,9 @@ import org.apache.spark.sql.delta.actions.AddFile
 import org.apache.spark.sql.delta.stats.DeltaDataSkippingType.DeltaDataSkippingType
 import com.fasterxml.jackson.databind.annotation.JsonDeserialize
 
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.StructType
 
 /**
  * DataSize describes following attributes for data that consists of a list of input files
@@ -82,4 +84,54 @@ case class DeltaScan(
     val scanDurationMs: Long,
     val dataSkippingType: DeltaDataSkippingType) {
   def allFilters: ExpressionSet = partitionFilters ++ dataFilters ++ unusedFilters
+
+  /**
+   * Compare a set of filters to the filters for this DeltaScan. Because these filters could
+   * be post optimization, nested fields may have different schemas due to schema pruning. To
+   * get around this, we convert any nested field to an UnresolvedAttribute for the comparison.
+   *
+   * @param other ExpressionSet to compare the filters against
+   * @return Whether the expressions match with nested schemas ignored
+   */
+  def filtersMatch(other: ExpressionSet): Boolean = DeltaScan.filtersMatch(allFilters, other)
+}
+
+object DeltaScan {
+  private def constructSchema(source: StructType, ordinals: Seq[Int]): StructType = {
+    val extractedField = source.fields(ordinals.head)
+    val nestedType = if (ordinals.tail.nonEmpty) {
+      constructSchema(extractedField.dataType.asInstanceOf[StructType], ordinals.tail)
+    } else {
+      extractedField.dataType
+    }
+    StructType(Seq(extractedField.copy(dataType = nestedType)))
+  }
+
+  private def pruneExpression(expr: Expression): Expression = expr transform {
+    case NestedFieldExtraction(nameParts) =>
+      new UnresolvedAttribute(nameParts)
+  }
+
+  private[delta] def filtersMatch(source: ExpressionSet, target: ExpressionSet): Boolean = {
+    val prunedSource = source.map(pruneExpression _)
+    val prunedTarget = target.map(pruneExpression _)
+    prunedSource == prunedTarget
+  }
+}
+
+object NestedFieldExtraction {
+  def unapply(e: Expression): Option[Seq[String]] = e match {
+    case GetStructField(child, ordinal, _) =>
+      val nested = child match {
+        case NestedFieldExtraction(nameParts) => Some(nameParts)
+        case _ => None
+      }
+      val childSchema = child.dataType.asInstanceOf[StructType]
+      nested.map { nameParts =>
+        nameParts :+ childSchema.fields(ordinal).name
+      }
+    case a: AttributeReference if a.dataType.isInstanceOf[StructType] =>
+      Some(Seq(a.name))
+    case _ => None
+  }
 }
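A brief sketch of what the new normalization does, reusing originalAttr, prunedAttr, captured, and optimized from the sketch above (illustration only, not code from the commit): NestedFieldExtraction recovers the field path by name, pruneExpression replaces the whole GetStructField chain with an UnresolvedAttribute over that path, and filtersMatch compares the rewritten sets.

// Both extraction chains resolve to the same name path, regardless of the pruned schema.
val pathFromOriginal = GetStructField(GetStructField(originalAttr, 0), 1) match {
  case NestedFieldExtraction(nameParts) => nameParts   // Seq("outer", "inner", "b")
  case _ => Seq.empty[String]
}
val pathFromPruned = GetStructField(GetStructField(prunedAttr, 0), 0) match {
  case NestedFieldExtraction(nameParts) => nameParts   // Seq("outer", "inner", "b")
  case _ => Seq.empty[String]
}
assert(pathFromOriginal == pathFromPruned)

// pruneExpression rewrites each chain to UnresolvedAttribute(Seq("outer", "inner", "b")),
// so the comparison now succeeds. filtersMatch is private[delta], so this call only
// compiles inside the org.apache.spark.sql.delta package, as in the test added below.
assert(DeltaScan.filtersMatch(captured, optimized))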

core/src/main/scala/org/apache/spark/sql/delta/stats/PrepareDeltaScan.scala

Lines changed: 1 addition & 1 deletion
@@ -320,7 +320,7 @@ case class PreparedDeltaFileIndex(
       partitionFilters: Seq[Expression],
       dataFilters: Seq[Expression]): Seq[AddFile] = {
     val actualFilters = ExpressionSet(partitionFilters ++ dataFilters)
-    if (preparedScan.allFilters == actualFilters) {
+    if (preparedScan.filtersMatch(actualFilters)) {
      preparedScan.files.distinct
     } else {
       logInfo(
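The practical effect of this one-line change, sketched with a hypothetical query (the table path and column names are made up, and an active SparkSession named spark is assumed): when the optimizer prunes the nested schema referenced by the filter, the filters handed to PreparedDeltaFileIndex.matchingFiles no longer equal preparedScan.allFilters, so before this commit execution fell into the fallback branch (the logInfo above) instead of reusing preparedScan.files.

// Hypothetical query whose nested filter column is subject to schema pruning.
import org.apache.spark.sql.functions.col

spark.read.format("delta").load("/tmp/some_delta_table")
  .where(col("outer.inner.b") === "abc")  // filter captured in the prepared DeltaScan
  .select("outer.inner.b")                // only outer.inner.b survives pruning
  .collect()
// With filtersMatch, the pruned filter still matches the prepared scan's filters,
// so the files already selected by data skipping are reused.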

core/src/test/scala/org/apache/spark/sql/delta/stats/DataSkippingDeltaTests.scala

Lines changed: 29 additions & 1 deletion
@@ -31,13 +31,41 @@ import org.scalatest.GivenWhenThen
 
 // scalastyle:off import.ordering.noEmptyLine
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
+
+class DeltaScanSuite extends QueryTest
+  with SharedSparkSession
+  with PredicateHelper {
+
+  test("Comparing filters should ignore nested schema differences") {
+    // Simulate outer.inner.b = "abc" filter
+    val schema = StructType.fromDDL("inner STRUCT<a: STRING, b: STRING>, c LONG")
+    val prunedSchema = StructType.fromDDL("inner STRUCT<b: STRING>")
+
+    val originalAttr = AttributeReference("outer", schema, true)(ExprId(1))
+    val prunedAttr = AttributeReference("outer", prunedSchema, true)(ExprId(1))
+
+    val originalExprs = Seq(
+      GetStructField(GetStructField(originalAttr, 0), 1),
+      IsNotNull(originalAttr),
+      IsNotNull(GetStructField(originalAttr, 0))
+    )
+    val prunedExprs = Seq(
+      GetStructField(GetStructField(prunedAttr, 0), 0),
+      IsNotNull(prunedAttr),
+      IsNotNull(GetStructField(prunedAttr, 0))
+    )
+
+    assert(DeltaScan.filtersMatch(ExpressionSet(originalExprs), ExpressionSet(prunedExprs)))
+  }
+}
+
 trait DataSkippingDeltaTestsBase extends QueryTest
   with SharedSparkSession with DeltaSQLCommandTest
   with PredicateHelper
