Skip to content

Commit

Permalink
Add UTF8_BINARY_LCASE collation support using custom function
Browse files Browse the repository at this point in the history
  • Loading branch information
miland-db committed Mar 25, 2024
1 parent 2a5fce7 commit e0ce699
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,104 @@ public UTF8String replace(UTF8String search, UTF8String replace) {
return buf.build();
}

public UTF8String replace(UTF8String search, UTF8String replace, int collationId) {
if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
return this.replace(search, replace);
}
if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) {
return lowercaseReplace(search, replace);
}
return collatedReplace(search, replace, collationId);
}

public UTF8String lowercaseReplace(UTF8String search, UTF8String replace) {
if (numBytes == 0 || search.numBytes == 0) {
return this;
}
UTF8String lowercaseString = this.toLowerCase();
UTF8String lowercaseSearch = search.toLowerCase();

int start = 0;
int end = lowercaseString.indexOf(lowercaseSearch, 0);
if (end == -1) {
// Search string was not found, so string is unchanged.
return this;
}

// Initialize byte positions
int c = 0;
int byteStart = 0; // position in byte
int byteEnd = 0; // position in byte
while (byteEnd < numBytes && c < end) {
byteEnd += numBytesForFirstByte(getByte(byteEnd));
c += 1;
}

// At least one match was found. Estimate space needed for result.
// The 16x multiplier here is chosen to match commons-lang3's implementation.
int increase = Math.max(0, replace.numBytes - search.numBytes) * 16;
final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase);
while (end != -1) {
buf.appendBytes(this.base, this.offset + byteStart, byteEnd - byteStart);
buf.append(replace);
// Update character positions
start = end + lowercaseSearch.numChars();
end = lowercaseString.indexOf(lowercaseSearch, start);
// Update byte positions
byteStart = byteEnd + search.numBytes;
while (byteEnd < numBytes && c < end) {
byteEnd += numBytesForFirstByte(getByte(byteEnd));
c += 1;
}
}
buf.appendBytes(this.base, this.offset + byteStart, numBytes - byteStart);
return buf.build();
}

private UTF8String collatedReplace(UTF8String search, UTF8String replace, int collationId) {
if (numBytes == 0 || search.numBytes == 0) {
return this;
}

StringSearch stringSearch = CollationFactory.getStringSearch(this, search, collationId);

// Find the first occurrence of the search string.
int end = stringSearch.next();
if (end == StringSearch.DONE) {
// Search string was not found, so string is unchanged.
return this;
}

// Initialize byte positions
int c = 0;
int byteStart = 0; // position in byte
int byteEnd = 0; // position in byte
while (byteEnd < numBytes && c < end) {
byteEnd += numBytesForFirstByte(getByte(byteEnd));
c += 1;
}

// At least one match was found. Estimate space needed for result.
// The 16x multiplier here is chosen to match commons-lang3's implementation.
int increase = Math.max(0, Math.abs(replace.numBytes - search.numBytes)) * 16;
final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase);
while (end != StringSearch.DONE) {
if(stringSearch.getMatchLength() == stringSearch.getPattern().length()) {
buf.appendBytes(this.base, this.offset + byteStart, byteEnd - byteStart);
buf.append(replace);
byteStart = byteEnd + search.numBytes;
}
end = stringSearch.next();
// Update byte positions
while (byteEnd < numBytes && c < end) {
byteEnd += numBytesForFirstByte(getByte(byteEnd));
c += 1;
}
}
buf.appendBytes(this.base, this.offset + byteStart, numBytes - byteStart);
return buf.build();
}

public UTF8String translate(Map<String, String> dict) {
String srcStr = this.toString();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ import scala.collection.immutable.Seq

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal, StringReplace}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StringType

class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession {
class CollationStringExpressionsSuite extends QueryTest
with SharedSparkSession with ExpressionEvalHelper {

case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R)
case class CollationTestFail[R](s1: String, s2: String, collation: String)
Expand Down Expand Up @@ -70,6 +73,46 @@ class CollationStringExpressionsSuite extends QueryTest with SharedSparkSession
})
}

test("REPLACE check result on explicitly collated strings") {
def testReplace(expected: String, collationId: Integer,
source: String, search: String, replace: String): Unit = {
val sourceLiteral = Literal.create(source, StringType(collationId))
val searchLiteral = Literal.create(search, StringType(collationId))
val replaceLiteral = Literal.create(replace, StringType(collationId))

checkEvaluation(StringReplace(sourceLiteral, searchLiteral, replaceLiteral), expected)
}

// scalastyle:off
// UTF8_BINARY
testReplace("r世e123ace", 0, "r世eplace", "pl", "123")
testReplace("reace", 0, "replace", "pl", "")
testReplace("repl世ace", 0, "repl世ace", "Pl", "")
testReplace("replace", 0, "replace", "", "123")
testReplace("a12ca12c", 0, "abcabc", "b", "12")
testReplace("adad", 0, "abcdabcd", "bc", "")
// UTF8_BINARY_LCASE
testReplace("r世exxace", 1, "r世eplace", "pl", "xx")
testReplace("reAB世ace", 1, "repl世ace", "PL", "AB")
testReplace("Replace", 1, "Replace", "", "123")
testReplace("rexplace", 1, "re世place", "", "x")
testReplace("a12ca12c", 1, "abcaBc", "B", "12")
testReplace("Adad", 1, "AbcdabCd", "Bc", "")
// UNICODE
testReplace("re世place", 2, "re世place", "plx", "123")
testReplace("世Replace", 2, "世Replace", "re", "")
testReplace("replace世", 2, "replace世", "", "123")
testReplace("aBc世a12c", 2, "aBc世abc", "b", "12")
testReplace("adad", 2, "abcdabcd", "bc", "")
// UNICODE_CI
testReplace("replace", 3, "replace", "plx", "123")
testReplace("place", 3, "Replace", "re", "")
testReplace("replace", 3, "replace", "", "123")
testReplace("a12c世a12c", 3, "aBc世abc", "b", "12")
testReplace("a世dad", 3, "a世Bcdabcd", "bC", "")
// scalastyle:on
}

// TODO: Add more tests for other string expressions

}
Expand Down

0 comments on commit e0ce699

Please sign in to comment.