[SPARK-52494] Support colon-sign operator syntax to access Variant fields #51190

Open: wants to merge 1 commit into master
@@ -1208,6 +1208,7 @@ primaryExpression
     | constant #constantDefault
     | ASTERISK exceptClause? #star
     | qualifiedName DOT ASTERISK exceptClause? #star
+    | col=primaryExpression COLON path=semiStructuredExtractionPath #semiStructuredExtract
     | LEFT_PAREN namedExpression (COMMA namedExpression)+ RIGHT_PAREN #rowConstructor
     | LEFT_PAREN query RIGHT_PAREN #subqueryExpression
     | functionName LEFT_PAREN (setQuantifier? argument+=functionArgument
@@ -1230,6 +1231,35 @@ primaryExpression
       FROM position=valueExpression (FOR length=valueExpression)? RIGHT_PAREN #overlay
     ;
 
+semiStructuredExtractionPath
+    : jsonPathFirstPart (jsonPathParts)*
+    ;
+
+jsonPathIdentifier
+    : identifier
+    | BACKQUOTED_IDENTIFIER
+    ;
+
+jsonPathBracketedIdentifier
+    : LEFT_BRACKET stringLit RIGHT_BRACKET
+    ;
+
+jsonPathFirstPart
+    : jsonPathIdentifier
+    | jsonPathBracketedIdentifier
+    | DOT
+    | LEFT_BRACKET INTEGER_VALUE RIGHT_BRACKET
+    | LEFT_BRACKET ASTERISK RIGHT_BRACKET
+    ;
+
+jsonPathParts
+    : DOT jsonPathIdentifier
+    | jsonPathBracketedIdentifier
+    | LEFT_BRACKET INTEGER_VALUE RIGHT_BRACKET
+    | LEFT_BRACKET ASTERISK RIGHT_BRACKET
+    | LEFT_BRACKET identifier RIGHT_BRACKET
+    ;
+
 literalType
     : DATE
     | TIME
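Read together, these rules mean a path after the colon can start with an identifier, a backquoted identifier, a bracketed string key, an integer subscript, or a wildcard, and continue with dotted or bracketed steps. A hedged sketch of queries each form would allow syntactically (illustrative literals, not taken from this PR's tests):

select parse_json('{ "a": 1 }'):a;                -- plain identifier
select parse_json('{ "a b": 1 }'):`a b`;          -- backquoted identifier
select parse_json('{ "a.b": 1 }'):['a.b'];        -- bracketed string key
select parse_json('{ "xs": [10, 20] }'):xs[1];    -- integer subscript
select parse_json('{ "xs": [10, 20] }'):xs[*];    -- wildcard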
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.variant.VariantGet
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{SEMI_STRUCTURED_EXTRACT, TreePattern}
import org.apache.spark.sql.types.{DataType, StringType, VariantType}
import org.apache.spark.unsafe.types.UTF8String

/**
 * Represents the extraction of data from a field that contains semi-structured data. The
 * semi-structured format can be anything (JSON, key-value delimited, etc.), and that information
 * comes from the child expression's column metadata.
 * @param child The semi-structured column
 * @param field The field to extract
 */

Reviewer comment (Contributor): it can be VARIANT only now

case class SemiStructuredExtract(
    child: Expression, field: String) extends UnaryExpression with Unevaluable {
  override lazy val resolved = false
  override def dataType: DataType = StringType

  final override val nodePatterns: Seq[TreePattern] = Seq(SEMI_STRUCTURED_EXTRACT)

  override protected def withNewChildInternal(newChild: Expression): SemiStructuredExtract =
    copy(child = newChild)
}

/**
 * Resolves SemiStructuredExtract expressions by rewriting them into an extraction of the
 * specified field from the semi-structured column (only VariantType is supported for now).
 */
case object ExtractSemiStructuredFields extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning(
    _.containsPattern(SEMI_STRUCTURED_EXTRACT), ruleId) {
    case SemiStructuredExtract(column, field) if column.resolved =>
      if (column.dataType.isInstanceOf[VariantType]) {
        VariantGet(column, Literal(UTF8String.fromString(field)), VariantType, failOnError = true)
      } else {
        throw new AnalysisException(
          errorClass = "COLUMN_IS_NOT_VARIANT_TYPE", messageParameters = Map.empty)
      }
  }
}
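Since the rule rewrites only VARIANT inputs, applying the colon operator to any other column type fails analysis with COLUMN_IS_NOT_VARIANT_TYPE. A minimal sketch, assuming a table t with a STRING column s (hypothetical names):

select s:price from t;   -- AnalysisException: COLUMN_IS_NOT_VARIANT_TYPE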
@@ -33,13 +33,13 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType, Str
 import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
 import org.apache.spark.util.Utils
 
-private[this] sealed trait PathInstruction
-private[this] object PathInstruction {
+sealed trait PathInstruction
+object PathInstruction {
   private[expressions] case object Subscript extends PathInstruction
   private[expressions] case object Wildcard extends PathInstruction
   private[expressions] case object Key extends PathInstruction
   private[expressions] case class Index(index: Long) extends PathInstruction
-  private[expressions] case class Named(name: String) extends PathInstruction
+  case class Named(name: String) extends PathInstruction
 }

private[this] sealed trait WriteStyle
@@ -49,7 +49,7 @@ private[this] object WriteStyle {
   private[expressions] case object FlattenStyle extends WriteStyle
 }
 
-private[this] object JsonPathParser extends RegexParsers {
+object JsonPathParser extends RegexParsers {
   import PathInstruction._
 
   def root: Parser[Char] = '$'
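Widening the visibility of PathInstruction and JsonPathParser lets AstBuilder reuse the existing JSON-path machinery instead of reimplementing it. The paths being parsed are the same $-rooted form that get_json_object already accepts, e.g.:

select get_json_object('{ "a": { "b": 2 } }', '$.a.b');   -- the same path shape the ':' operator builds internally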
@@ -36,6 +36,8 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, ClusterBySpec}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AnyValue, First, Last}
+import org.apache.spark.sql.catalyst.expressions.json.JsonPathParser
+import org.apache.spark.sql.catalyst.expressions.json.PathInstruction.Named
 import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -3322,6 +3324,24 @@ class AstBuilder extends DataTypeAstBuilder
     }
   }
 
+  /**
+   * Create a [[SemiStructuredExtract]] expression.
+   */
+  override def visitSemiStructuredExtract(
+      ctx: SemiStructuredExtractContext): Expression = withOrigin(ctx) {
+    val field = ctx.path.getText
+    // When `field` starts with a bracket, do not add a `.`, as the bracket already implies
+    // nesting; a bracketed field also implies case-sensitive extraction.
+    val path = if (field.startsWith("[")) "$" + field else s"$$.$field"
+    val parsedPath = JsonPathParser.parse(path)
+    if (parsedPath.isEmpty) {
+      throw new ParseException(errorClass = "PARSE_SYNTAX_ERROR", ctx = ctx)
+    }
+    val potentialAlias = parsedPath.get.collect { case Named(name) => name }.lastOption
+    val node = SemiStructuredExtract(expression(ctx.col), path)
+    potentialAlias.map { colName => Alias(node, colName)() }.getOrElse(node)
+  }

/**
* Create an [[UnresolvedAttribute]] expression or a [[UnresolvedRegex]] if it is a regex
* quoted in ``
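So `col:path` is rewritten to a `$`-rooted JSON path and, when the path ends in a named step, aliased to that name. A hedged sketch of the intended mapping (inferred from the builder above and the test output below):

select parse_json('{ "a": { "b": 1 } }'):a.b;    -- builds path $.a.b; output column named b
select parse_json('{ "a": [5] }'):['a'][0];      -- leading bracket: builds $['a'][0], no extra dot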
@@ -113,6 +113,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.expressions.ValidateAndStripPipeExpressions" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveUnresolvedHaving" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveTableConstraints" ::
+      "org.apache.spark.sql.catalyst.expressions.ExtractSemiStructuredFields" ::
       // Catalyst Optimizer rules
       "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
       "org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::
@@ -80,6 +80,7 @@ object TreePattern extends Enumeration {
   val REGEXP_EXTRACT_FAMILY: Value = Value
   val REGEXP_REPLACE: Value = Value
   val RUNTIME_REPLACEABLE: Value = Value
+  val SEMI_STRUCTURED_EXTRACT: Value = Value
   val SCALAR_SUBQUERY: Value = Value
   val SCALAR_SUBQUERY_REFERENCE: Value = Value
   val SCALA_UDF: Value = Value
@@ -22,7 +22,7 @@ import org.apache.spark.sql.artifact.ArtifactManager
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveDataSource, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry}
 import org.apache.spark.sql.catalyst.analysis.resolver.ResolverExtension
 import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExtractSemiStructuredFields}
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -244,6 +244,7 @@ abstract class BaseSessionStateBuilder(
       new EvalSubqueriesForTimeTravel +:
       new ResolveTranspose(session) +:
       new InvokeProcedures(session) +:
+      ExtractSemiStructuredFields +:
       customResolutionRules
 
   override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
@@ -0,0 +1,69 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
select parse_json('{ "price": 5 }'):price
-- !query analysis
Project [variant_get(parse_json({ "price": 5 }, true), $.price, VariantType, true, Some(America/Los_Angeles)) AS price#x]
+- OneRowRelation


-- !query
select parse_json('{ "price": 30 }'):price::decimal(5, 2)
-- !query analysis
Project [cast(variant_get(parse_json({ "price": 30 }, true), $.price, VariantType, true, Some(America/Los_Angeles)) as decimal(5,2)) AS price#x]
+- OneRowRelation


-- !query
select parse_json('{ "price": 30 }'):price::string
-- !query analysis
Project [cast(variant_get(parse_json({ "price": 30 }, true), $.price, VariantType, true, Some(America/Los_Angeles)) as string) AS price#x]
+- OneRowRelation


-- !query
select parse_json('{ "price": 12345.678 }'):price::decimal(3, 2)
-- !query analysis
Project [cast(variant_get(parse_json({ "price": 12345.678 }, true), $.price, VariantType, true, Some(America/Los_Angeles)) as decimal(3,2)) AS price#x]
+- OneRowRelation


-- !query
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].price::double
-- !query analysis
Project [cast(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }, true), $.item[0].price, VariantType, true, Some(America/Los_Angeles)) as double) AS price#x]
+- OneRowRelation


-- !query
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].price::int
-- !query analysis
Project [cast(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }, true), $.item[0].price, VariantType, true, Some(America/Los_Angeles)) as int) AS price#x]
+- OneRowRelation


-- !query
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].model
-- !query analysis
Project [variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }, true), $.item[0].model, VariantType, true, Some(America/Los_Angeles)) AS model#x]
+- OneRowRelation


-- !query
select substr(parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].model, 2, 3)
-- !query analysis
Project [substr(cast(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }, true), $.item[0].model, VariantType, true, Some(America/Los_Angeles)) as string), 2, 3) AS substr(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }), $.item[0].model) AS model, 2, 3)#x]
+- OneRowRelation


-- !query
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[1].price::double
-- !query analysis
Project [cast(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }, true), $.item[1].price, VariantType, true, Some(America/Los_Angeles)) as double) AS price#x]
+- OneRowRelation


-- !query
select ceil(sqrt(parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[1].price::double))
-- !query analysis
Project [CEIL(SQRT(cast(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }, true), $.item[1].price, VariantType, true, Some(America/Los_Angeles)) as double))) AS CEIL(SQRT(CAST(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }), $.item[1].price) AS price AS DOUBLE)))#xL]
+- OneRowRelation
@@ -0,0 +1,13 @@
-- Simple field extraction and type casting.
select parse_json('{ "price": 5 }'):price;
Reviewer comment (Contributor): nit: we can create a temp view with one or more VARIANT columns, to simplify the other SELECT queries in this test.
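A sketch of what the reviewer's suggestion could look like (hypothetical view and column names):

create temp view variant_tab as select parse_json('{ "price": 5 }') as v;
select v:price from variant_tab;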

select parse_json('{ "price": 30 }'):price::decimal(5, 2);
select parse_json('{ "price": 30 }'):price::string;
-- A cast whose target type cannot hold the value (fails at runtime).
select parse_json('{ "price": 12345.678 }'):price::decimal(3, 2);
-- Access a field inside an array and feed it into functions.
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].price::double;
Reviewer comment (Contributor): let's test all the valid syntaxes, e.g. ASTERISK, brackets with string, etc.
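Hedged examples of the additional forms the reviewer asks to cover, assumed syntactically valid under the new grammar but not yet in this file:

select parse_json('{ "item": [ { "price": 6.12 }, { "price": 9.24 } ] }'):item[*];
select parse_json('{ "a key": 1 }'):['a key'];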

select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].price::int;
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].model;
select substr(parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].model, 2, 3);
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[1].price::double;
select ceil(sqrt(parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[1].price::double));
@@ -0,0 +1,87 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
select parse_json('{ "price": 5 }'):price
-- !query schema
struct<price:variant>
-- !query output
5


-- !query
select parse_json('{ "price": 30 }'):price::decimal(5, 2)
-- !query schema
struct<price:decimal(5,2)>
-- !query output
30.00


-- !query
select parse_json('{ "price": 30 }'):price::string
-- !query schema
struct<price:string>
-- !query output
30


-- !query
select parse_json('{ "price": 12345.678 }'):price::decimal(3, 2)
-- !query schema
struct<>
-- !query output
org.apache.spark.SparkRuntimeException
{
"errorClass" : "INVALID_VARIANT_CAST",
"sqlState" : "22023",
"messageParameters" : {
"dataType" : "\"DECIMAL(3,2)\"",
"value" : "12345.678"
}
}


-- !query
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].price::double
-- !query schema
struct<price:double>
-- !query output
6.12


-- !query
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].price::int
-- !query schema
struct<price:int>
-- !query output
6


-- !query
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].model
-- !query schema
struct<model:variant>
-- !query output
"basic"


-- !query
select substr(parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[0].model, 2, 3)
-- !query schema
struct<substr(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }), $.item[0].model) AS model, 2, 3):string>
-- !query output
asi


-- !query
select parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[1].price::double
-- !query schema
struct<price:double>
-- !query output
9.24


-- !query
select ceil(sqrt(parse_json('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }'):item[1].price::double))
-- !query schema
struct<CEIL(SQRT(CAST(variant_get(parse_json({ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }), $.item[1].price) AS price AS DOUBLE))):bigint>
-- !query output
4
@@ -28,7 +28,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveDataSource, ResolveSessionCatalog, ResolveTranspose}
 import org.apache.spark.sql.catalyst.analysis.resolver.ResolverExtension
 import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExtractSemiStructuredFields}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.classic.{SparkSession, Strategy}
@@ -133,6 +133,7 @@ class HiveSessionStateBuilder(
       new DetermineTableStats(session) +:
       new ResolveTranspose(session) +:
       new InvokeProcedures(session) +:
+      ExtractSemiStructuredFields +:
       customResolutionRules
 
   override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =