Skip to content

Commit

Permalink
[SPARK-49924][SQL] Keep containsNull after ArrayCompact replacement
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Fix `containsNull` of `ArrayCompact`, by adding a new expression `KnownNotContainsNull`

### Why are the changes needed?

#47430 attempted to set `containsNull = false` for `ArrayCompact` for further optimization, but in an incomplete way:

The `ArrayCompact` is a runtime replaceable expression, so will be replaced in optimizer, and cause the `containsNull` be reverted, e.g.

```sql
select array_compact(array(1, null))
```

Rule `ReplaceExpressions` changed `containsNull: false -> true`
```
old schema:
StructField(array_compact(array(1, NULL)),ArrayType(IntegerType,false),false)

new schema
StructField(array_compact(array(1, NULL)),ArrayType(IntegerType,true),false)
```

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
added tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #48410 from zhengruifeng/fix_array_compact_null.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
zhengruifeng authored and HyukjinKwon committed Oct 12, 2024
1 parent ed4847f commit 62ade5f
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.SparkException.internalError
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.KnownNotContainsNull
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
Expand Down Expand Up @@ -5330,15 +5331,12 @@ case class ArrayCompact(child: Expression)
child.dataType.asInstanceOf[ArrayType].elementType, true)
lazy val lambda = LambdaFunction(isNotNull(lv), Seq(lv))

override lazy val replacement: Expression = ArrayFilter(child, lambda)
override lazy val replacement: Expression = KnownNotContainsNull(ArrayFilter(child, lambda))

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

override def prettyName: String = "array_compact"

override def dataType: ArrayType =
child.dataType.asInstanceOf[ArrayType].copy(containsNull = false)

override protected def withNewChildInternal(newChild: Expression): ArrayCompact =
copy(child = newChild)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{ArrayType, DataType}

trait TaggingExpression extends UnaryExpression {
override def nullable: Boolean = child.nullable
Expand Down Expand Up @@ -52,6 +52,17 @@ case class KnownNotNull(child: Expression) extends TaggingExpression {
copy(child = newChild)
}

case class KnownNotContainsNull(child: Expression) extends TaggingExpression {
override def dataType: DataType =
child.dataType.asInstanceOf[ArrayType].copy(containsNull = false)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
child.genCode(ctx)

override protected def withNewChildInternal(newChild: Expression): KnownNotContainsNull =
copy(child = newChild)
}

case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression {
override protected def withNewChildInternal(newChild: Expression): KnownFloatingPointNormalized =
copy(child = newChild)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, AttributeReference, IntegerLiteral, Literal, Multiply, NamedExpression, Remainder}
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayCompact, AttributeReference, CreateArray, CreateStruct, IntegerLiteral, Literal, MapFromEntries, Multiply, NamedExpression, Remainder}
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, StructField, StructType}

/**
* A dummy optimizer rule for testing that decrements integer literals until 0.
Expand Down Expand Up @@ -313,4 +313,25 @@ class OptimizerSuite extends PlanTest {
assert(message1.contains("not a valid aggregate expression"))
}
}

test("SPARK-49924: Keep containsNull after ArrayCompact replacement") {
val optimizer = new SimpleTestOptimizer() {
override def defaultBatches: Seq[Batch] =
Batch("test", fixedPoint,
ReplaceExpressions) :: Nil
}

val array1 = ArrayCompact(CreateArray(Literal(1) :: Literal.apply(null) :: Nil, false))
val plan1 = Project(Alias(array1, "arr")() :: Nil, OneRowRelation()).analyze
val optimized1 = optimizer.execute(plan1)
assert(optimized1.schema ===
StructType(StructField("arr", ArrayType(IntegerType, false), false) :: Nil))

val struct = CreateStruct(Literal(1) :: Literal(2) :: Nil)
val array2 = ArrayCompact(CreateArray(struct :: Literal.apply(null) :: Nil, false))
val plan2 = Project(Alias(MapFromEntries(array2), "map")() :: Nil, OneRowRelation()).analyze
val optimized2 = optimizer.execute(plan2)
assert(optimized2.schema ===
StructType(StructField("map", MapType(IntegerType, IntegerType, false), false) :: Nil))
}
}
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [filter(e#0, lambdafunction(isnotnull(lambda arg#0), lambda arg#0, false)) AS array_compact(e)#0]
Project [knownnotcontainsnull(filter(e#0, lambdafunction(isnotnull(lambda arg#0), lambda arg#0, false))) AS array_compact(e)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

0 comments on commit 62ade5f

Please sign in to comment.