Skip to content

Commit 7aad4d0

Browse files
mihailotim-db authored and cloud-fan committed
[SPARK-51160][SQL] Refactor literal function resolution
### What changes were proposed in this pull request?
Refactor literal function resolution into a separate object so that the single-pass analyzer can reuse this logic.

### Why are the changes needed?
Necessary to support literal function resolution in the single-pass analyzer.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49887 from mihailotim-db/mihailotim-db/literal_functions_refactor.

Authored-by: Mihailo Timotic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 207390b commit 7aad4d0

File tree

2 files changed

+70
-33
lines changed

2 files changed

+70
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions.SubExprUtils.wrapOuterReference
2828
import org.apache.spark.sql.catalyst.plans.logical._
2929
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
3030
import org.apache.spark.sql.catalyst.trees.TreePattern._
31-
import org.apache.spark.sql.catalyst.util.toPrettySQL
3231
import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier}
3332
import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors}
3433
import org.apache.spark.sql.internal.SQLConf
@@ -96,30 +95,6 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
9695
}
9796
}
9897

99-
// support CURRENT_DATE, CURRENT_TIMESTAMP, CURRENT_USER, USER, SESSION_USER and grouping__id
100-
private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq(
101-
(CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)),
102-
(CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)),
103-
(CurrentUser().prettyName, () => CurrentUser(), toPrettySQL),
104-
("user", () => CurrentUser(), toPrettySQL),
105-
("session_user", () => CurrentUser(), toPrettySQL),
106-
(VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName)
107-
)
108-
109-
/**
110-
* Literal functions do not require the user to specify braces when calling them
111-
* When an attributes is not resolvable, we try to resolve it as a literal function.
112-
*/
113-
private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = {
114-
if (nameParts.length != 1) return None
115-
val name = nameParts.head
116-
literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map {
117-
case (_, getFuncExpr, getAliasName) =>
118-
val funcExpr = getFuncExpr()
119-
Alias(funcExpr, getAliasName(funcExpr))()
120-
}
121-
}
122-
12398
/**
12499
* Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by
125100
* traversing the input expression in top-down manner. It must be top-down because we need to
@@ -167,14 +142,17 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
167142

168143
case u @ UnresolvedAttribute(nameParts) =>
169144
val result = withPosition(u) {
170-
resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map {
171-
// We trim unnecessary alias here. Note that, we cannot trim the alias at top-level,
172-
// as we should resolve `UnresolvedAttribute` to a named expression. The caller side
173-
// can trim the top-level alias if it's safe to do so. Since we will call
174-
// CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe.
175-
case Alias(child, _) if !isTopLevel => child
176-
case other => other
177-
}.getOrElse(u)
145+
resolveColumnByName(nameParts)
146+
.orElse(LiteralFunctionResolution.resolve(nameParts))
147+
.map {
148+
// We trim unnecessary alias here. Note that, we cannot trim the alias at top-level,
149+
// as we should resolve `UnresolvedAttribute` to a named expression. The caller side
150+
// can trim the top-level alias if it's safe to do so. Since we will call
151+
// CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe.
152+
case Alias(child, _) if !isTopLevel => child
153+
case other => other
154+
}
155+
.getOrElse(u)
178156
}
179157
logDebug(s"Resolving $u to $result")
180158
result
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.sql.catalyst.expressions.{
21+
Alias,
22+
CurrentDate,
23+
CurrentTimestamp,
24+
CurrentUser,
25+
Expression,
26+
GroupingID,
27+
NamedExpression,
28+
VirtualColumn
29+
}
30+
import org.apache.spark.sql.catalyst.util.toPrettySQL
31+
32+
/**
33+
* Resolves literal functions by mapping them to their ''real'' function counterparts.
34+
*/
35+
object LiteralFunctionResolution {
  /**
   * Literal functions do not require the user to specify parentheses when calling them.
   * When an attribute is not resolvable, we try to resolve it as a literal function.
   *
   * @param nameParts the name parts of the unresolved attribute; only single-part names are
   *                  considered, anything else yields `None`.
   * @return the matching function expression wrapped in an [[Alias]], or `None` if the name
   *         does not match any entry of `literalFunctions`.
   */
  def resolve(nameParts: Seq[String]): Option[NamedExpression] = nameParts match {
    // A literal function is always referenced by a single, unqualified name.
    case Seq(name) =>
      literalFunctions.collectFirst {
        case (funcName, getFuncExpr, getAliasName)
            if caseInsensitiveResolution(funcName, name) =>
          // Build a fresh expression for every resolution and name the alias after it.
          val funcExpr = getFuncExpr()
          Alias(funcExpr, getAliasName(funcExpr))()
      }
    case _ => None
  }

  // support CURRENT_DATE, CURRENT_TIMESTAMP, CURRENT_USER, USER, SESSION_USER and grouping__id
  //
  // Each entry is a triple of:
  //   1. the name the user may write without parentheses,
  //   2. a factory producing the function expression to resolve to,
  //   3. a function computing the alias name from that expression.
  // `grouping__id` keeps its own name as the alias instead of the pretty SQL of the expression.
  private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq(
    (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL),
    (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL),
    (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL),
    ("user", () => CurrentUser(), toPrettySQL),
    ("session_user", () => CurrentUser(), toPrettySQL),
    (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName)
  )
}

0 commit comments

Comments
 (0)