Skip to content

Commit b44c7b3

Browse files
committed
another method
1 parent a0aa0c2 commit b44c7b3

File tree

7 files changed

+37
-47
lines changed

7 files changed

+37
-47
lines changed

python/pyspark/sql/tests/test_dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def test_drop_II(self):
212212
df3 = df1.join(df2, df1["colC"] == df2["colC"]).withColumn(
213213
"colB",
214214
when(
215-
df1["colB"] == "b", concat(df1["colB"].cast("string"), lit("_newValue"))
215+
df1["colB"] == "b", concat(df1["colB"].cast("string"), lit("x"))
216216
).otherwise(df1["colB"]),
217217
)
218218
df4 = df3.drop(df1["colB"])

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
454454
ResolveNaturalAndUsingJoin ::
455455
ResolveOutputRelation ::
456456
new ResolveTableConstraints(catalogManager) ::
457-
new ResolveDataFrameDropColumns(catalogManager) ::
458457
new ResolveSetVariable(catalogManager) ::
459458
ExtractWindowExpressions ::
460459
GlobalAggregates ::
@@ -1483,6 +1482,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
14831482
new ResolveReferencesInUpdate(catalogManager)
14841483
private val resolveReferencesInSort =
14851484
new ResolveReferencesInSort(catalogManager)
1485+
private val resolveDataFrameDropColumns =
1486+
new ResolveDataFrameDropColumns(catalogManager)
14861487

14871488
/**
14881489
* Return true if there are conflicting attributes among children's outputs of a plan
@@ -1791,6 +1792,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
17911792
// Pass for Execute Immediate as arguments will be resolved by [[SubstituteExecuteImmediate]].
17921793
case e : ExecuteImmediateQuery => e
17931794

1795+
case d: DataFrameDropColumns if !d.resolved =>
1796+
resolveDataFrameDropColumns(d)
1797+
17941798
case q: LogicalPlan =>
17951799
logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}")
17961800
q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,27 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
509509
includeLastResort = includeLastResort)
510510
}
511511

512+
// Tries to resolve an `UnresolvedAttribute` by the children with plan IDs.
513+
// Returns `None` if it fails to resolve.
514+
private[spark] def tryResolveUnresolvedAttributeByPlanChildren(
515+
u: UnresolvedAttribute,
516+
q: LogicalPlan,
517+
includeLastResort: Boolean = false): Option[Expression] = {
518+
resolveDataFrameColumn(u, q.children).map { r =>
519+
resolveExpression(
520+
r,
521+
resolveColumnByName = nameParts => {
522+
q.resolveChildren(nameParts, conf.resolver)
523+
},
524+
getAttrCandidates = () => {
525+
assert(q.children.length == 1)
526+
q.children.head.output
527+
},
528+
throws = true,
529+
includeLastResort = includeLastResort)
530+
}
531+
}
532+
512533
/**
513534
* The last resort to resolve columns. Currently it does two things:
514535
* - Try to resolve column names as outer references
@@ -538,7 +559,6 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
538559
e: Expression,
539560
q: Seq[LogicalPlan]): Expression = e match {
540561
case u: UnresolvedAttribute =>
541-
542562
resolveDataFrameColumn(u, q).getOrElse(u)
543563
case u: UnresolvedDataFrameStar =>
544564
resolveDataFrameStar(u, q)
@@ -566,14 +586,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
566586
// df1.select(df2.a) <- illegal reference df2.a
567587
throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)
568588
}
569-
570-
if (resolved.nonEmpty) {
571-
resolved.map(_._1)
572-
} else if (u.getTagValue(LogicalPlan.ALLOW_NON_EXISTENT_COL).nonEmpty) {
573-
Some(NonExistentAttribute(u.name))
574-
} else {
575-
None
576-
}
589+
resolved.map(_._1)
577590
}
578591

579592
private def resolveDataFrameColumnByPlanId(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.sql.catalyst.analysis
1919

20-
import org.apache.spark.sql.catalyst.expressions.NonExistentAttribute
2120
import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, LogicalPlan, Project}
2221
import org.apache.spark.sql.catalyst.rules.Rule
2322
import org.apache.spark.sql.catalyst.trees.TreePattern.DF_DROP_COLUMNS
@@ -36,8 +35,14 @@ class ResolveDataFrameDropColumns(val catalogManager: CatalogManager)
3635
// expressions in dropList can be unresolved, e.g.
3736
// df.drop(col("non-existing-column"))
3837
val dropped = d.dropList.flatMap {
39-
case u: UnresolvedAttribute => Some(resolveExpressionByPlanChildren(u, d))
40-
case n: NonExistentAttribute => None
38+
case u: UnresolvedAttribute =>
39+
if (u.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty) {
40+
// Plan ID comes from Spark Connect,
41+
// here we ignore the column if it fails to resolve by plan ID.
42+
tryResolveUnresolvedAttributeByPlanChildren(u, d)
43+
} else {
44+
Some(resolveExpressionByPlanChildren(u, d))
45+
}
4146
case e => Some(e)
4247
}
4348
val remaining = d.child.output.filterNot(attr => dropped.exists(_.semanticEquals(attr)))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -790,29 +790,3 @@ object FileSourceGeneratedMetadataAttribute {
790790
.map(attr -> _)
791791
}
792792
}
793-
794-
795-
case class NonExistentAttribute(name: String) extends Attribute with Unevaluable {
796-
override def dataType: DataType = NullType
797-
override def canEqual(that: Any): Boolean = that.isInstanceOf[NonExistentAttribute]
798-
799-
override def withNullability(newNullability: Boolean): Attribute =
800-
throw SparkUnsupportedOperationException()
801-
override def newInstance(): Attribute =
802-
throw SparkUnsupportedOperationException()
803-
override def withQualifier(newQualifier: Seq[String]): Attribute =
804-
throw SparkUnsupportedOperationException()
805-
override def withName(newName: String): Attribute =
806-
throw SparkUnsupportedOperationException()
807-
override def withMetadata(newMetadata: Metadata): Attribute =
808-
throw SparkUnsupportedOperationException()
809-
override def qualifier: Seq[String] =
810-
throw SparkUnsupportedOperationException()
811-
override def exprId: ExprId =
812-
throw SparkUnsupportedOperationException()
813-
override def withExprId(newExprId: ExprId): Attribute =
814-
throw SparkUnsupportedOperationException()
815-
override def withDataType(newType: DataType): Attribute =
816-
throw SparkUnsupportedOperationException()
817-
override def nullable: Boolean = true
818-
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,6 @@ object LogicalPlan {
217217
// to the old code path.
218218
private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id")
219219
private[spark] val IS_METADATA_COL = TreeNodeTag[Unit]("is_metadata_col")
220-
// Whether an expression can be resolved to a non-existent column.
221-
private[spark] val ALLOW_NON_EXISTENT_COL = TreeNodeTag[Unit]("allow_non_existent_col")
222220
}
223221

224222
/**

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2420,11 +2420,7 @@ class SparkConnectPlanner(
24202420
private def transformDrop(rel: proto.Drop): LogicalPlan = {
24212421
var output = Dataset.ofRows(session, transformRelation(rel.getInput))
24222422
if (rel.getColumnsCount > 0) {
2423-
val cols = rel.getColumnsList.asScala.toSeq.map { expr =>
2424-
val e = transformExpression(expr)
2425-
e.setTagValue(LogicalPlan.ALLOW_NON_EXISTENT_COL, ())
2426-
Column(e)
2427-
}
2423+
val cols = rel.getColumnsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))
24282424
output = output.drop(cols.head, cols.tail: _*)
24292425
}
24302426
if (rel.getColumnNamesCount > 0) {

0 commit comments

Comments
 (0)