diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 2c55e4c8fd375..2606dd2d77378 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -738,18 +738,21 @@ object LikeSimplification extends Rule[LogicalPlan] with PredicateHelper {
     } else {
       pattern match {
+        // Build the string literals with `input.dataType` so that the rewritten predicate
+        // compares operands of the same (possibly collated) string type as `input`.
         case startsWith(prefix) =>
-          Some(StartsWith(input, Literal(prefix)))
+          Some(StartsWith(input, Literal.create(prefix, input.dataType)))
         case endsWith(postfix) =>
-          Some(EndsWith(input, Literal(postfix)))
+          Some(EndsWith(input, Literal.create(postfix, input.dataType)))
         // 'a%a' pattern is basically same with 'a%' && '%a'.
         // However, the additional `Length` condition is required to prevent 'a' match 'a%a'.
-        case startsAndEndsWith(prefix, postfix) =>
-          Some(And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)),
-            And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))))
+        case startsAndEndsWith(prefix, postfix) => Some(
+          And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)),
+            And(StartsWith(input, Literal.create(prefix, input.dataType)),
+              EndsWith(input, Literal.create(postfix, input.dataType)))))
         case contains(infix) =>
-          Some(Contains(input, Literal(infix)))
+          Some(Contains(input, Literal.create(infix, input.dataType)))
         case equalTo(str) =>
-          Some(EqualTo(input, Literal(str)))
+          Some(EqualTo(input, Literal.create(str, input.dataType)))
         case _ => None
       }
     }
@@ -785,7 +788,7 @@ object LikeSimplification extends Rule[LogicalPlan] with PredicateHelper {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
     _.containsPattern(LIKE_FAMLIY), ruleId) {
-    case l @ Like(input, Literal(pattern, StringType), escapeChar) =>
+    case l @ Like(input, Literal(pattern, _: StringType), escapeChar) =>
       if (pattern == null) {
         // If pattern is null, return null value directly, since "col like null" == null.
         Literal(null, BooleanType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala
index 7405830642796..885ed37098680 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLRegexpSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{ArrayType, BooleanType, IntegerType, StringType}
 
@@ -55,6 +57,66 @@ class CollationSQLRegexpSuite
     })
   }
 
+  test("Like simplification should work with collated strings") {
+    // The table holds the single value 'ABC' and every pattern is lower case, so a match is
+    // expected only under UTF8_LCASE. `cls` is the expression class that the optimized plan
+    // must contain in place of the original Like.
+    case class SimplifyLikeTestCase[R](collation: String, str: String, cls: Class[_], result: R)
+    val testCases = Seq(
+      SimplifyLikeTestCase("UTF8_BINARY", "ab%", classOf[StartsWith], false),
+      SimplifyLikeTestCase("UTF8_BINARY", "%bc", classOf[EndsWith], false),
+      SimplifyLikeTestCase("UTF8_BINARY", "a%c", classOf[And], false),
+      SimplifyLikeTestCase("UTF8_BINARY", "%b%", classOf[Contains], false),
+      SimplifyLikeTestCase("UTF8_BINARY", "abc", classOf[EqualTo], false),
+      SimplifyLikeTestCase("UTF8_LCASE", "ab%", classOf[StartsWith], true),
+      SimplifyLikeTestCase("UTF8_LCASE", "%bc", classOf[EndsWith], true),
+      SimplifyLikeTestCase("UTF8_LCASE", "a%c", classOf[And], true),
+      SimplifyLikeTestCase("UTF8_LCASE", "%b%", classOf[Contains], true),
+      SimplifyLikeTestCase("UTF8_LCASE", "abc", classOf[EqualTo], true)
+    )
+    val tableName = "T"
+    withTable(tableName) {
+      sql(s"CREATE TABLE IF NOT EXISTS $tableName(c STRING) using PARQUET")
+      sql(s"INSERT INTO $tableName(c) VALUES('ABC')")
+      testCases.foreach { t =>
+        val query = sql(s"select c collate ${t.collation} like '${t.str}' FROM $tableName")
+        checkAnswer(query, Row(t.result))
+        // The projection's first column is an Alias over the rewritten predicate.
+        val optimizedPlan = query.queryExecution.optimizedPlan.asInstanceOf[Project]
+        assert(optimizedPlan.projectList.head.asInstanceOf[Alias].child.getClass == t.cls)
+      }
+    }
+  }
+
+  test("Like simplification should work with collated strings (for default collation)") {
+    // Same patterns as above, but the collation comes from the DEFAULT_COLLATION conf
+    // instead of an explicit COLLATE clause in the query.
+    val tableNameBinary = "T_BINARY"
+    withTable(tableNameBinary) {
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_BINARY") {
+        sql(s"CREATE TABLE IF NOT EXISTS $tableNameBinary(c STRING) using PARQUET")
+        sql(s"INSERT INTO $tableNameBinary(c) VALUES('ABC')")
+        checkAnswer(sql(s"select c like 'ab%' FROM $tableNameBinary"), Row(false))
+        checkAnswer(sql(s"select c like '%bc' FROM $tableNameBinary"), Row(false))
+        checkAnswer(sql(s"select c like 'a%c' FROM $tableNameBinary"), Row(false))
+        checkAnswer(sql(s"select c like '%b%' FROM $tableNameBinary"), Row(false))
+        checkAnswer(sql(s"select c like 'abc' FROM $tableNameBinary"), Row(false))
+      }
+    }
+    val tableNameLcase = "T_LCASE"
+    withTable(tableNameLcase) {
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_LCASE") {
+        sql(s"CREATE TABLE IF NOT EXISTS $tableNameLcase(c STRING) using PARQUET")
+        sql(s"INSERT INTO $tableNameLcase(c) VALUES('ABC')")
+        checkAnswer(sql(s"select c like 'ab%' FROM $tableNameLcase"), Row(true))
+        checkAnswer(sql(s"select c like '%bc' FROM $tableNameLcase"), Row(true))
+        checkAnswer(sql(s"select c like 'a%c' FROM $tableNameLcase"), Row(true))
+        checkAnswer(sql(s"select c like '%b%' FROM $tableNameLcase"), Row(true))
+        checkAnswer(sql(s"select c like 'abc' FROM $tableNameLcase"), Row(true))
+      }
+    }
+  }
+
   test("Support ILike string expression with collation") {
     // Supported collations
     case class ILikeTestCase[R](l: String, r: String, c: String, result: R)