
Fix double file scan from nested schema pruning #1096

Open · wants to merge 6 commits into master

@@ -22,7 +22,9 @@ import org.apache.spark.sql.delta.actions.AddFile
import org.apache.spark.sql.delta.stats.DeltaDataSkippingType.DeltaDataSkippingType
import com.fasterxml.jackson.databind.annotation.JsonDeserialize

import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StructType

/**
 * DataSize describes following attributes for data that consists of a list of input files
@@ -96,4 +98,44 @@ case class DeltaScan(

  lazy val filtersUsedForSkipping: ExpressionSet = partitionFilters ++ dataFilters
  lazy val allFilters: ExpressionSet = filtersUsedForSkipping ++ unusedFilters

  /**
   * Compare a set of filters to the filters for this DeltaScan. Because these filters may come
   * from the post-optimization plan, nested fields can have different schemas due to schema
   * pruning. To get around this, we convert every nested field extraction to an
   * UnresolvedExtractValue and re-resolve it for the comparison.
   *
   * @param other ExpressionSet to compare the filters against
   * @param resolver resolver used to re-resolve the prepared filters against the attributes
   *                 in `other`
   * @return whether the expressions match when nested schema differences are ignored
   */
  def filtersMatch(other: ExpressionSet, resolver: Resolver): Boolean = {
    DeltaScan.filtersMatch(allFilters, other, resolver) ||
      DeltaScan.filtersMatch(filtersUsedForSkipping, other, resolver)
  }
}

object DeltaScan {
  private[delta] def filtersMatch(source: ExpressionSet, target: ExpressionSet,
      resolver: Resolver): Boolean = {
    // Create a map of exprId -> Attribute from the target expressions
    val targetAttrs = target.flatMap { e =>
      e.collect { case a: Attribute => a.exprId -> a }
    }.toMap

    // Convert all GetStructField expressions to UnresolvedExtractValue
    val unresolvedSource = source.map(_ transform {
      case g: GetStructField =>
        UnresolvedExtractValue(g.child, Literal(g.extractFieldName))
    })

    // Re-resolve the extract values against the target's attributes
    val resolvedSource = unresolvedSource.map(_ transformUp {
      case a: Attribute =>
        targetAttrs.getOrElse(a.exprId, a)
      case UnresolvedExtractValue(child, extraction) =>
        ExtractValue(child, extraction, resolver)
    })

    resolvedSource == target
  }
}
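
As an illustration of the change (not part of the patch: the object name FiltersMatchSketch, the extends App wrapper, and placing it in the org.apache.spark.sql.delta.stats package so the private[delta] helper is visible are all assumptions), the standalone sketch below restates what the new unit test exercises: plain ExpressionSet equality fails once nested schema pruning changes the struct type behind an attribute with the same exprId, while DeltaScan.filtersMatch normalizes the nested field accesses and compares them equal.

package org.apache.spark.sql.delta.stats

import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StructType

object FiltersMatchSketch extends App {
  // Schema seen when the DeltaScan was prepared vs. the schema left by nested schema pruning.
  val fullSchema = StructType.fromDDL("inner STRUCT<a: STRING, b: STRING>")
  val prunedSchema = StructType.fromDDL("inner STRUCT<b: STRING>")

  // Same exprId, different struct type -- this is what the optimized plan hands back.
  val fullAttr = AttributeReference("outer", fullSchema, nullable = true)(ExprId(1))
  val prunedAttr = AttributeReference("outer", prunedSchema, nullable = true)(ExprId(1))

  // outer.inner.b expressed against each schema: the ordinal of `b` shifts from 1 to 0.
  val preparedFilters =
    ExpressionSet(Seq(IsNotNull(GetStructField(GetStructField(fullAttr, 0), 1))))
  val currentFilters =
    ExpressionSet(Seq(IsNotNull(GetStructField(GetStructField(prunedAttr, 0), 0))))

  // The plain comparison misses, which is what used to trigger a second log scan.
  println(preparedFilters == currentFilters)  // false

  // The normalized comparison recognizes both sides as the same filter.
  println(DeltaScan.filtersMatch(preparedFilters, currentFilters, caseSensitiveResolution))  // true
}
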
@@ -354,8 +354,8 @@ case class PreparedDeltaFileIndex(
      partitionFilters: Seq[Expression],
      dataFilters: Seq[Expression]): Seq[AddFile] = {
    val currentFilters = ExpressionSet(partitionFilters ++ dataFilters)
    val (addFiles, eventData) = if (currentFilters == preparedScan.allFilters ||
        currentFilters == preparedScan.filtersUsedForSkipping) {
    val resolver = spark.sessionState.conf.resolver
    val (addFiles, eventData) = if (preparedScan.filtersMatch(currentFilters, resolver)) {
      // [[DeltaScan]] was created using `allFilters` out of which only `filtersUsedForSkipping`
      // filters were used for skipping while creating the DeltaScan.
      // If currentFilters is same as allFilters, then no need to recalculate files and we can use
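
For context on the resolver now threaded into the comparison: a Resolver is Spark's (String, String) => Boolean field-name matching function, and spark.sessionState.conf.resolver returns the case-sensitive or case-insensitive variant depending on spark.sql.caseSensitive, so the re-resolution inside filtersMatch follows the same rules as the original analysis. The tiny sketch below (ResolverSketch is a made-up name; nothing here is in the patch) shows the two variants.

import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}

object ResolverSketch extends App {
  // Both resolvers simply decide whether two field names refer to the same column.
  println(caseInsensitiveResolution("Nested", "nested"))  // true
  println(caseSensitiveResolution("Nested", "nested"))    // false
}
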
@@ -33,10 +33,10 @@ import org.scalatest.GivenWhenThen
// scalastyle:off import.ordering.noEmptyLine
import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, PredicateHelper}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.functions.{col, lit, struct}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
@@ -1831,6 +1831,23 @@ trait DataSkippingDeltaTestsBase extends DeltaExcludedBySparkVersionTestMixinShi
    }
  }

test("Ensure that we do reuse a scan with nested column pruning") {
withTempDir { dir =>
Seq(("a", 1), ("b", 2), ("c", 3)).toDF("a", "b").select(struct('a, 'b).alias("nested"))
.write.format("delta").save(dir.getCanonicalPath)

val plans = DeltaTestUtils.withLogicalPlansCaptured(spark, optimizedPlan = true) {
sql(s"SELECT nested.b FROM delta.`$dir` WHERE nested.b < 2").collect()
}
val rddScans = plans.flatMap(_.collect {
// Only look for the RDD containg the path, ignore the initial cached RDD for all AddFile's
case l: LogicalRDD if l.output.exists(_.name == "path") => l
})
// We should only scan the log once
assert(rddScans.length == 1)
}
}

  protected def expectedStatsForFile(index: Int, colName: String, deltaLog: DeltaLog): String = {
    s"""{"numRecords":1,"minValues":{"$colName":$index},"maxValues":{"$colName":$index},""" +
      s""""nullCount":{"$colName":0}}""".stripMargin
@@ -1902,6 +1919,30 @@ trait DataSkippingDeltaTestsBase extends DeltaExcludedBySparkVersionTestMixinShi
    }
  }

test("Comparing filters should ignore nested schema differences") {
// Simulate outer.inner.b = "abc" filter
val schema = StructType.fromDDL("inner STRUCT<a: STRING, b: STRING>, c LONG")
val prunedSchema = StructType.fromDDL("inner STRUCT<b: STRING>")

val originalAttr = AttributeReference("outer", schema, true)(ExprId(1))
val prunedAttr = AttributeReference("outer", prunedSchema, true)(ExprId(1))

val originalExprs = Seq(
GetStructField(GetStructField(originalAttr, 0), 1),
IsNotNull(originalAttr),
IsNotNull(GetStructField(originalAttr, 0))
)
val prunedExprs = Seq(
GetStructField(GetStructField(prunedAttr, 0), 0),
IsNotNull(prunedAttr),
IsNotNull(GetStructField(prunedAttr, 0))
)

val resolver = org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
assert(DeltaScan.filtersMatch(ExpressionSet(originalExprs), ExpressionSet(prunedExprs),
resolver))
}

  protected def parse(deltaLog: DeltaLog, predicate: String): Seq[Expression] = {

    // We produce a wrong filter in this case otherwise