
[SPARK-52503][SQL][CONNECT] Fix drop when the input column does not exist #51196


Open: wants to merge 6 commits into master
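This change targets `DataFrame.drop` when the column passed in carries a Spark Connect plan ID but no longer exists in the DataFrame's output. A minimal PySpark sketch of the scenario, adapted from the `test_drop_II` test added below; the remote URL is a placeholder and a Spark Connect session is assumed, since the affected path only triggers when column references carry plan IDs:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import concat, lit, when

# Assumed: a Spark Connect session (placeholder URL), so columns carry plan IDs.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

df1 = spark.createDataFrame(
    [("a", "b", "c")], schema="colA string, colB string, colC string"
)
df2 = spark.createDataFrame(
    [("c", "d", "")], schema="colC string, colD string, colE string"
)

# withColumn replaces colB with a new attribute, so df1["colB"] no longer
# appears in df3's output.
df3 = df1.join(df2, df1["colC"] == df2["colC"]).withColumn(
    "colB",
    when(df1["colB"] == "b", concat(df1["colB"].cast("string"), lit("x"))).otherwise(
        df1["colB"]
    ),
)

# With this fix, dropping the stale reference is ignored rather than failing analysis.
df4 = df3.drop(df1["colB"])
print(df4.columns)  # ['colA', 'colB', 'colC', 'colC', 'colD', 'colE']
```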
33 changes: 32 additions & 1 deletion python/pyspark/sql/tests/test_dataframe.py
@@ -26,7 +26,18 @@
from contextlib import redirect_stdout

from pyspark.sql import Row, functions, DataFrame
from pyspark.sql.functions import col, lit, count, struct, date_format, to_date, array, explode
from pyspark.sql.functions import (
col,
lit,
count,
struct,
date_format,
to_date,
array,
explode,
when,
concat,
)
from pyspark.sql.types import (
StringType,
IntegerType,
@@ -189,6 +200,26 @@ def test_drop(self):
self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"])
self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"])

def test_drop_II(self):
df1 = self.spark.createDataFrame(
[("a", "b", "c")],
schema="colA string, colB string, colC string",
)
df2 = self.spark.createDataFrame(
[("c", "d", "")],
schema="colC string, colD string, colE string",
)
df3 = df1.join(df2, df1["colC"] == df2["colC"]).withColumn(
"colB",
when(df1["colB"] == "b", concat(df1["colB"].cast("string"), lit("x"))).otherwise(
df1["colB"]
),
)
df4 = df3.drop(df1["colB"])
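# df1["colB"] was replaced by the withColumn above, so it no longer matches any
# attribute in df3's output: the drop is expected to be a no-op, keeping every
# column (including both colC from the join).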

self.assertEqual(df4.columns, ["colA", "colB", "colC", "colC", "colD", "colE"])
self.assertEqual(df4.count(), 1)

def test_drop_join(self):
left_df = self.spark.createDataFrame(
[(1, "a"), (2, "b"), (3, "c")],
Analyzer.scala
@@ -454,7 +454,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ResolveNaturalAndUsingJoin ::
ResolveOutputRelation ::
new ResolveTableConstraints(catalogManager) ::
new ResolveDataFrameDropColumns(catalogManager) ::
new ResolveSetVariable(catalogManager) ::
ExtractWindowExpressions ::
GlobalAggregates ::
@@ -1483,6 +1482,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
new ResolveReferencesInUpdate(catalogManager)
private val resolveReferencesInSort =
new ResolveReferencesInSort(catalogManager)
private val resolveDataFrameDropColumns =
new ResolveDataFrameDropColumns(catalogManager)

/**
* Return true if there're conflicting attributes among children's outputs of a plan
@@ -1791,6 +1792,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// Pass for Execute Immediate as arguments will be resolved by [[SubstituteExecuteImmediate]].
case e : ExecuteImmediateQuery => e

case d: DataFrameDropColumns if !d.resolved =>
resolveDataFrameDropColumns(d)

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}")
q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true))
ColumnResolutionHelper.scala
@@ -509,6 +509,27 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
includeLastResort = includeLastResort)
}

// Tries to resolve an `UnresolvedAttribute` against the plan's children using plan IDs.
// Returns `None` if resolution fails.
private[sql] def tryResolveUnresolvedAttributeByPlanChildren(
u: UnresolvedAttribute,
q: LogicalPlan,
includeLastResort: Boolean = false): Option[Expression] = {
resolveDataFrameColumn(u, q.children).map { r =>
resolveExpression(
r,
resolveColumnByName = nameParts => {
q.resolveChildren(nameParts, conf.resolver)
},
getAttrCandidates = () => {
assert(q.children.length == 1)
q.children.head.output
},
throws = true,
includeLastResort = includeLastResort)
}
}

/**
* The last resort to resolve columns. Currently it does two things:
* - Try to resolve column names as outer references
ResolveDataFrameDropColumns.scala
@@ -34,10 +34,16 @@ class ResolveDataFrameDropColumns(val catalogManager: CatalogManager)
case d: DataFrameDropColumns if d.childrenResolved =>
// expressions in dropList can be unresolved, e.g.
// df.drop(col("non-existing-column"))
val dropped = d.dropList.map {
val dropped = d.dropList.flatMap {
case u: UnresolvedAttribute =>
resolveExpressionByPlanChildren(u, d)
case e => e
if (u.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty) {
// The plan ID comes from Spark Connect;
// ignore the column if it cannot be resolved by its plan ID.
tryResolveUnresolvedAttributeByPlanChildren(u, d)
} else {
Some(resolveExpressionByPlanChildren(u, d))
}
case e => Some(e)
}
val remaining = d.child.output.filterNot(attr => dropped.exists(_.semanticEquals(attr)))
if (remaining.size == d.child.output.size) {
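For intuition, here is a small plain-Python sketch (not Spark internals; all names are hypothetical stand-ins) of the pattern the Scala change adopts: drop targets that fail to resolve become `None` and are filtered out, so they can never match anything in the child's output and the plan's columns are left unchanged.

```python
# Simplified illustration of the flatMap-over-Option flow in ResolveDataFrameDropColumns;
# try_resolve stands in for tryResolveUnresolvedAttributeByPlanChildren.
child_output = ["colA", "colB", "colC"]

def try_resolve(name):
    # Stand-in resolver: returns the attribute if present, otherwise None.
    return name if name in child_output else None

drop_list = ["colB", "col_not_there"]
dropped = [r for n in drop_list if (r := try_resolve(n)) is not None]
remaining = [c for c in child_output if c not in dropped]
print(remaining)  # ['colA', 'colC']
```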