Skip to content

Commit

Permalink
Add support for instr and unit test in CollationStringExpressionsSuit…
Browse files Browse the repository at this point in the history
…e.scala
  • Loading branch information
miland-db committed Mar 21, 2024
1 parent 3a1609a commit a4d3592
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,33 @@ public int indexOf(UTF8String v, int start) {
return -1;
}

public int indexOf(UTF8String substring, int start, int collationId) {
if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
return this.indexOf(substring, start);
}
if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) {
return this.toLowerCase().indexOf(substring.toLowerCase(), start);
}
return collatedIndexOf(substring, collationId);
}

private int collatedIndexOf(UTF8String substring, int collationId) {
if (substring.numBytes == 0) {
return 0;
}

StringSearch stringSearch = CollationFactory.getStringSearch(this, substring, collationId);

int pos = 0;
while ((pos = stringSearch.next()) != StringSearch.DONE) {
if (stringSearch.getMatchLength() == stringSearch.getPattern().length()) {
return pos;
}
}

return 0;
}

/**
* Find the `str` from left to right.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1377,17 +1377,34 @@ case class StringInstr(str: Expression, substr: Expression)
override def left: Expression = str
override def right: Expression = substr
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation)

override def nullSafeEval(string: Any, sub: Any): Any = {
string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0) + 1
val collationId = left.dataType.asInstanceOf[StringType].collationId
string.asInstanceOf[UTF8String].indexOf(sub.asInstanceOf[UTF8String], 0, collationId) + 1
}

override def checkInputDataTypes(): TypeCheckResult = {
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
return defaultCheck
}

val collationId = left.dataType.asInstanceOf[StringType].collationId
CollationTypeConstraints.checkCollationCompatibility(collationId, children.map(_.dataType))
}

override def prettyName: String = "instr"

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (l, r) =>
s"($l).indexOf($r, 0) + 1")
val collationId = left.dataType.asInstanceOf[StringType].collationId

if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1")
} else {
defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0, $collationId) + 1")
}
}

override protected def withNewChildrenInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.test.SharedSparkSession
class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession {

case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R)

case class CollationTestFail[R](s1: String, s2: String, collation: String)

test("Support ConcatWs string expression with Collation") {
Expand Down Expand Up @@ -70,6 +71,52 @@ class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession
})
}

case class SubstringIndexTestFail[R](s1: String, s2: String, c1: String, c2: String)

test("Support SubstringIndex with Collation") {
// Supported collations
val checks = Seq(
CollationTestCase("The quick brown fox jumps over the dog.", "fox", "UTF8_BINARY", 17),
CollationTestCase("The quick brown fox jumps over the dog.", "FOX", "UTF8_BINARY", 0),
CollationTestCase("The quick brown fox jumps over the dog.", "FOX", "UTF8_BINARY_LCASE", 17),
CollationTestCase("The quick brown fox jumps over the dog.", "fox", "UNICODE", 17),
CollationTestCase("The quick brown fox jumps over the dog.", "FOX", "UNICODE", 0),
CollationTestCase("The quick brown fox jumps over the dog.", "FOX", "UNICODE_CI", 17)
)
checks.foreach(ct => {
checkAnswer(sql(s"SELECT instr(collate('${ct.s1}', '${ct.collation}'), " +
s"collate('${ct.s2}', '${ct.collation}'))"),
Row(ct.expectedResult))
})
// Unsupported collation pairs
val fails = Seq(
SubstringIndexTestFail("The quick brown fox jumps over the dog.",
"Fox", "UTF8_BINARY_LCASE", "UTF8_BINARY"),
SubstringIndexTestFail("The quick brown fox jumps over the dog.",
"FOX", "UNICODE_CI", "UNICODE")
)
fails.foreach(ct => {
val expr = s"instr(collate('${ct.s1}', '${ct.c1}'), collate('${ct.s2}', '${ct.c2}'))"
checkError(
exception = intercept[ExtendedAnalysisException] {
sql(s"SELECT $expr")
},
errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH",
sqlState = "42K09",
parameters = Map(
"sqlExpr" -> s"\"instr(collate(${ct.s1}), collate(${ct.s2}))\"",
"collationNameLeft" -> s"${ct.c1}",
"collationNameRight" -> s"${ct.c2}"
),
context = ExpectedContext(
fragment = s"$expr",
start = 7,
stop = 45 + ct.s1.length + ct.c1.length + ct.s2.length + ct.c2.length
)
)
})
}

// TODO: Add more tests for other string expressions

}
Expand Down

0 comments on commit a4d3592

Please sign in to comment.