diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysNull.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysNull.java
new file mode 100644
index 0000000000000..6abd036cb8952
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysNull.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.filter;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Literal;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+
+/**
+ * A predicate that always evaluates to {@code null}.
+ *
+ * @since 4.1.0
+ */
+@Evolving
+public final class AlwaysNull extends Predicate implements Literal<Boolean> {
+
+  public AlwaysNull() {
+    super("ALWAYS_NULL", new Predicate[]{});
+  }
+
+  @Override
+  public Boolean value() {
+    return null;
+  }
+
+  @Override
+  public DataType dataType() {
+    return DataTypes.BooleanType;
+  }
+
+  @Override
+  public String toString() { return "NULL"; }
+}
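The new class mirrors the existing AlwaysTrue/AlwaysFalse constant predicates, differing only in its literal value. A minimal sketch of how it behaves (Scala, assuming spark-catalyst with this patch on the classpath; the AlwaysNullDemo object is illustrative only, not part of the change):

    import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysNull, AlwaysTrue}

    object AlwaysNullDemo {
      def main(args: Array[String]): Unit = {
        val alwaysNull = new AlwaysNull()
        println(alwaysNull.name())     // ALWAYS_NULL
        println(alwaysNull.value())    // null -- the literal carries no value
        println(alwaysNull.dataType()) // BooleanType -- still typed as a boolean predicate
        println(alwaysNull.toString)   // NULL
        // The two existing constant predicates carry concrete values instead.
        println(new AlwaysTrue().value())  // true
        println(new AlwaysFalse().value()) // false
      }
    }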
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index 7cc03f3ac3fa6..bf25eaa716830 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
 import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME
 import org.apache.spark.sql.connector.expressions.{BucketTransform, Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform}
-import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue}
+import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysNull, AlwaysTrue}
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types._
@@ -210,6 +210,7 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
   def toCatalyst(expr: V2Expression): Option[Expression] = expr match {
     case _: AlwaysTrue => Some(Literal.TrueLiteral)
     case _: AlwaysFalse => Some(Literal.FalseLiteral)
+    case _: AlwaysNull => Some(Literal.NullPredicateLiteral)
     case l: V2Literal[_] => Some(Literal(l.value, l.dataType))
     case r: NamedReference => Some(UnresolvedAttribute(r.fieldNames.toImmutableArraySeq))
     case c: V2Cast => toCatalyst(c.expression).map(Cast(_, c.dataType, ansiEnabled = true))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala
index a27460e2be1cd..1a7e3b03c0e6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
 import org.apache.spark.sql.connector.catalog.constraints.Constraint
 import org.apache.spark.sql.connector.expressions.FieldReference
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.{BooleanType, DataType}
+import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType}
 
 trait TableConstraint extends Expression with Unevaluable {
   /** Convert to a data source v2 constraint */
@@ -122,9 +122,12 @@ case class CheckConstraint(
     override val tableName: String = null,
     override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
   extends UnaryExpression
-  with TableConstraint {
+  with TableConstraint
+  with ImplicitCastInputTypes {
   // scalastyle:on line.size.limit
 
+  override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)
+
   def toV2Constraint: Constraint = {
     val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull
     val enforced = userProvidedCharacteristic.enforced.getOrElse(true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e3ed2c4a0b0b8..f0c3d74edf443 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -65,6 +65,8 @@ object Literal {
 
   val FalseLiteral: Literal = Literal(false, BooleanType)
 
+  val NullPredicateLiteral: Literal = Literal(null, BooleanType)
+
   def apply(v: Any): Literal = v match {
     case i: Int => Literal(i, IntegerType)
     case l: Long => Literal(l, LongType)
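With the pieces above, a constant NULL predicate now has a Catalyst counterpart: toCatalyst maps the V2 AlwaysNull to Literal.NullPredicateLiteral (a null literal typed as BooleanType rather than NullType), and ImplicitCastInputTypes makes the analyzer coerce a bare CHECK (null) condition to BooleanType so that this shape actually occurs. A rough sketch of the V2-to-Catalyst direction (Scala; the wrapper object is illustrative only, the calls are the ones shown in this diff):

    import org.apache.spark.sql.catalyst.expressions.{Literal, V2ExpressionUtils}
    import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysNull, AlwaysTrue}
    import org.apache.spark.sql.types.BooleanType

    object ToCatalystSketch {
      def main(args: Array[String]): Unit = {
        // Constant V2 predicates all translate to Boolean-typed Catalyst literals.
        assert(V2ExpressionUtils.toCatalyst(new AlwaysTrue()).contains(Literal.TrueLiteral))
        assert(V2ExpressionUtils.toCatalyst(new AlwaysFalse()).contains(Literal.FalseLiteral))
        // AlwaysNull becomes a null literal that keeps BooleanType, preserving
        // three-valued logic for whatever consumes the predicate downstream.
        assert(V2ExpressionUtils.toCatalyst(new AlwaysNull()).contains(Literal(null, BooleanType)))
      }
    }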
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index fad73a6d81464..e3ebf1822041e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
-import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
+import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysNull, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType}
+import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, NullType, StringType}
 
 /**
  * The builder to generate V2 expressions from catalyst expressions.
@@ -78,6 +78,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
       expr: Expression, isPredicate: Boolean = false): Option[V2Expression] = expr match {
     case Literal(true, BooleanType) => Some(new AlwaysTrue())
     case Literal(false, BooleanType) => Some(new AlwaysFalse())
+    case Cast(Literal(null, NullType), BooleanType, _, _) if isPredicate => Some(new AlwaysNull())
     case Literal(value, dataType) => Some(LiteralValue(value, dataType))
     case col @ ColumnOrField(nameParts) =>
       val ref = FieldReference(nameParts)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala
index a734f8507dac8..90f92aa06e752 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CheckConstraintSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command.v2
 
 import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.catalog.constraints.Check
 import org.apache.spark.sql.execution.command.DDLCommandTestUtils
@@ -908,4 +909,47 @@ class CheckConstraintSuite extends QueryTest with CommandSuiteBase with DDLComma
       }
     }
   }
+
+  test("Check constraint with constant valid expression should be optimized out") {
+    Seq(
+      "1 > 0",
+      "abs(-99) < 100",
+      "null",
+      "current_date() > DATE'2023-01-01'"
+    ).foreach { constant =>
+      withNamespaceAndTable("ns", "tbl", nonPartitionCatalog) { t =>
+        sql(s"CREATE TABLE $t (id INT, value INT," +
+          s" CONSTRAINT positive_id CHECK ($constant)) $defaultUsing")
+        val optimizedPlan =
+          sql(s"INSERT INTO $t VALUES (1, 10), (2, 20)").queryExecution.optimizedPlan
+        val filter = optimizedPlan.collectFirst {
+          case f: Filter => f
+        }
+        assert(filter.isEmpty)
+      }
+    }
+  }
+
+  test("Check constraint with constant invalid expression should throw error") {
+    Seq(
+      "1 < 0",
+      "abs(-99) > 100",
+      "current_date() < DATE'2023-01-01'"
+    ).foreach { constant =>
+      withNamespaceAndTable("ns", "tbl", nonPartitionCatalog) { t =>
+        sql(s"CREATE TABLE $t (id INT, value INT," +
+          s" CONSTRAINT positive_id CHECK ($constant)) $defaultUsing")
+        val error = intercept[SparkRuntimeException] {
+          sql(s"INSERT INTO $t VALUES (1, 10), (2, 20)")
+        }
+        checkError(
+          exception = error,
+          condition = "CHECK_CONSTRAINT_VIOLATION",
+          sqlState = "23001",
+          parameters = Map("constraintName" -> "positive_id", "expression" -> constant,
+            "values" -> "")
+        )
+      }
+    }
+  }
 }
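The "valid expression" test includes a bare null because SQL CHECK constraints only reject rows for which the predicate evaluates to FALSE; a NULL (unknown) result counts as satisfied, so the optimizer can drop the validation Filter entirely. On the translation side, CHECK (null) reaches the builder as a NULL literal cast to BOOLEAN (via the new ImplicitCastInputTypes mix-in), and the new case turns it into the AlwaysNull V2 predicate instead of failing to translate. A hedged sketch of that Catalyst-to-V2 direction (Scala; V2ExpressionBuilder is internal Catalyst API, the wrapper object is illustrative only, and buildPredicate is assumed to behave as it does in toV2Constraint above):

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
    import org.apache.spark.sql.connector.expressions.filter.AlwaysNull
    import org.apache.spark.sql.types.{BooleanType, NullType}

    object BuildAlwaysNullSketch {
      def main(args: Array[String]): Unit = {
        // The analyzed shape of CHECK (null): a NULL literal cast to BOOLEAN.
        val checkNull = Cast(Literal(null, NullType), BooleanType)
        // In predicate position (isPredicate = true) the builder now emits AlwaysNull.
        val predicate = new V2ExpressionBuilder(checkNull, true).buildPredicate()
        assert(predicate.exists(_.isInstanceOf[AlwaysNull]))
      }
    }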