
Commit b03dd56 · committed Jun 17, 2025 · 1 parent e16d540

numTargetRowsCopied should refer only to rows actually copied

File tree: 5 files changed (+103, -25 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
2 additions, 3 deletions

@@ -23,7 +23,7 @@
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta}
-import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split}
+import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{CarryOver, Discard, Instruction, Keep, ROW_ID, Split}
 import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{OPERATION_COLUMN, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
 import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
 import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta}
@@ -199,8 +199,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
       // as the last MATCHED and NOT MATCHED BY SOURCE instruction
       // this logic is specific to data sources that replace groups of data
       val carryoverRowsOutput = Literal(WRITE_WITH_METADATA_OPERATION) +: targetTable.output
-      val keepCarryoverRowsInstruction = Keep(TrueLiteral, carryoverRowsOutput,
-        systemPredicate = true)
+      val keepCarryoverRowsInstruction = CarryOver(carryoverRowsOutput)

       val matchedInstructions = matchedActions.map { action =>
         toInstruction(action, metadataAttrs)

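As the context lines above note, the carry-over instruction is planned only for data sources that replace whole groups of data; delta-based plans never produce one. A minimal, self-contained sketch of that planning decision, using hypothetical stand-in types rather than Spark's own classes:

// Stand-in instruction types (hypothetical, for illustration only).
sealed trait Instr
case class Keep(clause: String) extends Instr
case class CarryOver(desc: String) extends Instr

object PlanSketch extends App {
  // Group-based sources rewrite entire groups of files, so target rows that
  // match no user clause must still be written back: append a CarryOver last.
  // Delta-based sources emit only row-level deltas, so unmatched rows are
  // simply not emitted and no carry-over instruction is planned.
  def planInstructions(userClauses: Seq[String], groupBased: Boolean): Seq[Instr] = {
    val userInstrs: Seq[Instr] = userClauses.map(Keep)
    if (groupBased) userInstrs :+ CarryOver("copy target row unchanged")
    else userInstrs
  }

  assert(planInstructions(Seq("UPDATE SET salary = 1000"), groupBased = true).size == 2)
  assert(planInstructions(Seq("UPDATE SET salary = 1000"), groupBased = false).size == 1)
}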
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala
14 additions, 4 deletions

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, ROW_ID}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.truncatedString
@@ -87,12 +88,21 @@ object MergeRows {
     override def dataType: DataType = NullType
   }

+  // A special case of Keep where the row is kept as is.
+  case class CarryOver(output: Seq[Expression]) extends Instruction {
+    override def condition: Expression = TrueLiteral
+    override def outputs: Seq[Seq[Expression]] = Seq(output)
+    override def children: Seq[Expression] = output
+
+    override protected def withNewChildrenInternal(
+        newChildren: IndexedSeq[Expression]): Expression = {
+      copy(output = newChildren)
+    }
+  }
+
   case class Keep(
       condition: Expression,
-      output: Seq[Expression],
-      // flag marking that row should be considered not matching
-      // any user predicate for metric calculations
-      systemPredicate: Boolean = false)
+      output: Seq[Expression])
     extends Instruction {
     def children: Seq[Expression] = condition +: output
     override def outputs: Seq[Seq[Expression]] = Seq(output)

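To make the relationship concrete: a CarryOver is behaviorally a Keep whose condition is pinned to true, exactly as the new comment says; the dedicated type exists so the executor can tell "copied unchanged" apart from user-clause matches. A small self-contained sketch with simplified stand-in types (not the actual Spark expression classes):

// Simplified stand-ins for Expression and Instruction (illustration only).
sealed trait Expr
case object TrueLiteral extends Expr
case class Col(name: String) extends Expr

sealed trait Instruction {
  def condition: Expr
  def outputs: Seq[Seq[Expr]]
}

// Keep emits its projected output when its user-supplied condition holds.
case class Keep(condition: Expr, output: Seq[Expr]) extends Instruction {
  override def outputs: Seq[Seq[Expr]] = Seq(output)
}

// CarryOver pins the condition to true: the row always passes through.
case class CarryOver(output: Seq[Expr]) extends Instruction {
  override def condition: Expr = TrueLiteral
  override def outputs: Seq[Seq[Expr]] = Seq(output)
}

object CarryOverDemo extends App {
  val out = Seq(Col("pk"), Col("salary"), Col("dep"))
  // Same condition and outputs as Keep(TrueLiteral, out); only the
  // metric accounting in the executor differs.
  assert(CarryOver(out).condition == Keep(TrueLiteral, out).condition)
  assert(CarryOver(out).outputs == Keep(TrueLiteral, out).outputs)
}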
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
24 additions, 14 deletions

@@ -26,10 +26,11 @@
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.AttributeSet
 import org.apache.spark.sql.catalyst.expressions.BasePredicate
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.expressions.Projection
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
-import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split}
+import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{CarryOver, Discard, Instruction, Keep, ROW_ID, Split}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SparkPlan
@@ -72,7 +73,10 @@ case class MergeRowsExec(

   override lazy val metrics: Map[String, SQLMetric] = Map(
     "numTargetRowsCopied" -> SQLMetrics.createMetric(sparkContext,
-      "number of target rows rewritten unmodified"))
+      "Number of target rows copied over because they did not match any condition."),
+    "numTargetRowsUnmatched" -> SQLMetrics.createMetric(sparkContext,
+      "Number of target rows processed that do not match any condition. " +
+      "These will be dropped for delta-based merge and retained for group-based merge."))

   protected override def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitions(processPartition)
@@ -112,8 +116,11 @@ case class MergeRowsExec(

   private def planInstructions(instructions: Seq[Instruction]): Seq[InstructionExec] = {
     instructions.map {
-      case Keep(cond, output, isSystem) =>
-        KeepExec(createPredicate(cond), createProjection(output), isSystem)
+      case CarryOver(output) =>
+        CarryOverExec(createProjection(output))
+
+      case Keep(cond, output) =>
+        KeepExec(createPredicate(cond), createProjection(output))

       case Discard(cond) =>
         DiscardExec(createPredicate(cond))
@@ -132,12 +139,14 @@ case class MergeRowsExec(
     def condition: BasePredicate
   }

+  case class CarryOverExec(projection: Projection) extends InstructionExec {
+    override def condition: BasePredicate = createPredicate(TrueLiteral)
+    def apply(row: InternalRow): InternalRow = projection.apply(row)
+  }
+
   case class KeepExec(
       condition: BasePredicate,
-      projection: Projection,
-      // flag marking that row should be considered not matching
-      // any user predicate for metric calculations
-      systemPredicate: Boolean = false) extends InstructionExec {
+      projection: Projection) extends InstructionExec {
     def apply(row: InternalRow): InternalRow = projection.apply(row)
   }

@@ -231,12 +240,13 @@ case class MergeRowsExec(
     for (instruction <- instructions) {
       if (instruction.condition.eval(row)) {
         instruction match {
-          case keep: KeepExec =>
-            // For GroupBased Merge, Spark inserts a keep predicate for join matches
+          case carryOver: CarryOverExec =>
+            // For GroupBased Merge, Spark inserts a CarryOver predicate
             // to retain the row if no other case matches
-            if (keep.systemPredicate) {
-              longMetric("numTargetRowsCopied") += 1
-            }
+            longMetric("numTargetRowsCopied") += 1
+            longMetric("numTargetRowsUnmatched") += 1
+            return carryOver.apply(row)
+          case keep: KeepExec =>
            return keep.apply(row)

          case _: DiscardExec =>
@@ -250,7 +260,7 @@ case class MergeRowsExec(
       }
     }

     if (targetPresent) {
-      longMetric("numTargetRowsCopied") += 1
+      longMetric("numTargetRowsUnmatched") += 1
     }
     null
   }

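The net effect of the executor change is easiest to see in a stripped-down model of the per-row dispatch loop (hypothetical simplified types; the real code evaluates generated predicates and projections, and also handles Discard and Split): the first matching instruction wins, a CarryOver match counts as both copied and unmatched, and a target row that matches nothing counts only as unmatched and is dropped on the delta-based path.

// Minimal model of MergeRowsExec's per-row dispatch (illustration only).
case class Row(values: Map[String, Int])

sealed trait InstructionExec { def matches(row: Row): Boolean }
case class KeepExec(cond: Row => Boolean) extends InstructionExec {
  override def matches(row: Row): Boolean = cond(row)
}
case object CarryOverExec extends InstructionExec {
  override def matches(row: Row): Boolean = true // condition pinned to true
}

final class Metrics {
  var numTargetRowsCopied = 0L
  var numTargetRowsUnmatched = 0L
}

object DispatchSketch extends App {
  // Returns Some(row) to emit, or None when the row is dropped.
  def processRow(
      row: Row,
      instructions: Seq[InstructionExec],
      targetPresent: Boolean,
      metrics: Metrics): Option[Row] = {
    instructions.find(_.matches(row)) match {
      case Some(CarryOverExec) =>
        // Group-based merge: the row is physically copied over, and it also
        // matched no user clause, so both metrics move.
        metrics.numTargetRowsCopied += 1
        metrics.numTargetRowsUnmatched += 1
        Some(row)
      case Some(_: KeepExec) =>
        Some(row) // matched a user clause: neither metric moves
      case None =>
        // Delta-based merge plans no CarryOver, so an unmatched target row
        // lands here: counted as unmatched, then dropped.
        if (targetPresent) metrics.numTargetRowsUnmatched += 1
        None
    }
  }

  val m = new Metrics
  val instrs = Seq(KeepExec(r => r.values("salary") < 200))
  processRow(Row(Map("salary" -> 300)), instrs, targetPresent = true, m)
  assert(m.numTargetRowsUnmatched == 1 && m.numTargetRowsCopied == 0)
}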
sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuite.scala
2 additions, 0 deletions

@@ -24,6 +24,8 @@ class DeltaBasedMergeIntoTableSuite extends DeltaBasedMergeIntoTableSuiteBase {

   import testImplicits._

+  override protected def deltaMerge = true
+
   override protected lazy val extraTableProps: java.util.Map[String, String] = {
     val props = new java.util.HashMap[String, String]()
     props.put("supports-deltas", "true")

sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
61 additions, 4 deletions

@@ -35,6 +35,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase

   import testImplicits._

+  protected def deltaMerge: Boolean = false
+
   test("merge into table containing added column with default value") {
     withTempView("source") {
       sql(
@@ -1729,7 +1731,48 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }

-  test("Emit numTargetRowsCopied metrics") {
+  test("Merge metrics with matched clause") {
+    withTempView("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "software" }
+          |{ "pk": 3, "salary": 300, "dep": "hr" }
+          |""".stripMargin)
+
+      val sourceDF = Seq(1, 2).toDF("pk")
+      sourceDF.createOrReplaceTempView("source")
+
+      val mergeExec = findMergeExec {
+        s"""MERGE INTO $tableNameAsString t
+           |USING source s
+           |ON t.pk = s.pk
+           |WHEN MATCHED AND salary < 200 THEN
+           | UPDATE SET salary = 1000
+           |""".stripMargin
+      }
+
+      mergeExec.metrics.get("numTargetRowsCopied") match {
+        case Some(metric) =>
+          val expectedMetrics = if (deltaMerge) 0 else 2
+          assert(metric.value == expectedMetrics)
+        case None => fail("numCopiedRows metric not found")
+      }
+
+      mergeExec.metrics.get("numTargetRowsUnmatched") match {
+        case Some(metric) => assert(metric.value == 2, "2 rows unmatched")
+        case None => fail("numTargetRowsUnmatched metric not found")
+      }
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(1, 1000, "hr"), // updated
+          Row(2, 200, "software"),
+          Row(3, 300, "hr")))
+    }
+  }
+
+  test("Merge metrics with matched and not matched by source clauses") {
     withTempView("source") {
       createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
         """{ "pk": 1, "salary": 100, "dep": "hr" }
@@ -1754,10 +1797,17 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
       }

       mergeExec.metrics.get("numTargetRowsCopied") match {
-        case Some(metric) => assert(metric.value == 3, "3 rows copied without updates")
+        case Some(metric) =>
+          val expectedMetrics = if (deltaMerge) 0 else 3
+          assert(metric.value == expectedMetrics)
         case None => fail("numCopiedRows metric not found")
       }

+      mergeExec.metrics.get("numTargetRowsUnmatched") match {
+        case Some(metric) => assert(metric.value == 3, "3 rows unmatched")
+        case None => fail("numTargetRowsUnmatched metric not found")
+      }
+
       checkAnswer(
         sql(s"SELECT * FROM $tableNameAsString"),
         Seq(
@@ -1769,7 +1819,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }

-  test("Emit numTargetRowsCopied metrics 2") {
+  test("Merge metrics with matched, not matched, and not matched by source clauses") {
     withTempView("source") {
       createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
         """{ "pk": 1, "salary": 100, "dep": "hr" }
@@ -1796,10 +1846,17 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
       }

       mergeExec.metrics.get("numTargetRowsCopied") match {
-        case Some(metric) => assert(metric.value == 3, "3 rows copied without updates")
+        case Some(metric) =>
+          val expectedMetrics = if (deltaMerge) 0 else 3
+          assert(metric.value == expectedMetrics)
         case None => fail("numCopiedRows metric not found")
       }

+      mergeExec.metrics.get("numTargetRowsUnmatched") match {
+        case Some(metric) => assert(metric.value == 3, "3 rows unmatched")
+        case None => fail("numTargetRowsUnmatched metric not found")
+      }
+
       checkAnswer(
         sql(s"SELECT * FROM $tableNameAsString"),
         Seq(
