From b5e1b7988031044d3cbdb277668b775c08db1a74 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Wed, 12 Jun 2024 20:23:03 +0800 Subject: [PATCH] [SPARK-48596][SQL] Perf improvement for calculating hex string for long ### What changes were proposed in this pull request? This pull request optimizes the `Hex.hex(num: Long)` method by computing the number of significant hex digits up front (via `Long.numberOfLeadingZeros`), so the output array is allocated at its exact size and no longer needs to be copied afterward to strip leading zeros. ### Why are the changes needed? Performance improvement, verified as follows: - Unit tests added - Did a benchmark locally (30~50% speedup) ```scala Hex Long Tests: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ Legacy 1062 1094 16 9.4 106.2 1.0X New 739 807 26 13.5 73.9 1.4X ``` ```scala object HexBenchmark extends BenchmarkBase { override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { val N = 10_000_000 runBenchmark("Hex") { val benchmark = new Benchmark("Hex Long Tests", N, 10, output = output) val range = 1 to 12 benchmark.addCase("Legacy") { _ => (1 to N).foreach(x => range.foreach(y => hexLegacy(x - y))) } benchmark.addCase("New") { _ => (1 to N).foreach(x => range.foreach(y => Hex.hex(x - y))) } benchmark.run() } } def hexLegacy(num: Long): UTF8String = { // Extract the hex digits of num into value[] from right to left val value = new Array[Byte](16) var numBuf = num var len = 0 do { len += 1 // Hex.hexDigits needs to be visible here value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt) numBuf >>>= 4 } while (numBuf != 0) UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) } } ``` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? See the unit tests and local benchmark listed above. ### Was this patch authored or co-authored using generative AI tooling? no Closes #46952 from yaooqinn/SPARK-48596. 
Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../expressions/mathExpressions.scala | 28 +++++++------ .../sql/catalyst/expressions/HexSuite.scala | 40 +++++++++++++++++++ 2 files changed, 55 insertions(+), 13 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 8df46500ddcf0..6801fc7c257c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1018,9 +1018,9 @@ case class Bin(child: Expression) } object Hex { - val hexDigits = Array[Char]( - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' - ).map(_.toByte) + private final val hexDigits = + Array[Byte]('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F') + private final val ZERO_UTF8 = UTF8String.fromBytes(Array[Byte]('0')) // lookup table to translate '0' -> 0 ... 
'F'/'f' -> 15 val unhexDigits = { @@ -1036,24 +1036,26 @@ object Hex { val value = new Array[Byte](length * 2) var i = 0 while (i < length) { - value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4) - value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F) + value(i * 2) = hexDigits((bytes(i) & 0xF0) >> 4) + value(i * 2 + 1) = hexDigits(bytes(i) & 0x0F) i += 1 } UTF8String.fromBytes(value) } def hex(num: Long): UTF8String = { - // Extract the hex digits of num into value[] from right to left - val value = new Array[Byte](16) + val zeros = jl.Long.numberOfLeadingZeros(num) + if (zeros == jl.Long.SIZE) return ZERO_UTF8 + val len = (jl.Long.SIZE - zeros + 3) / 4 var numBuf = num - var len = 0 - do { - len += 1 - value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt) + val value = new Array[Byte](len) + var i = len - 1 + while (i >= 0) { + value(i) = hexDigits((numBuf & 0xF).toInt) numBuf >>>= 4 - } while (numBuf != 0) - UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) + i -= 1 + } + UTF8String.fromBytes(value) } def unhex(bytes: Array[Byte]): Array[Byte] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexSuite.scala new file mode 100644 index 0000000000000..a3f963538f447 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HexSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite + +class HexSuite extends SparkFunSuite { + test("SPARK-48596: hex long values") { + assert(Hex.hex(0).toString === "0") + assert(Hex.hex(1).toString === "1") + assert(Hex.hex(15).toString === "F") + assert(Hex.hex(16).toString === "10") + assert(Hex.hex(255).toString === "FF") + assert(Hex.hex(256).toString === "100") + assert(Hex.hex(4095).toString === "FFF") + assert(Hex.hex(4096).toString === "1000") + assert(Hex.hex(65535).toString === "FFFF") + assert(Hex.hex(65536).toString === "10000") + assert(Hex.hex(1048575).toString === "FFFFF") + assert(Hex.hex(1048576).toString === "100000") + assert(Hex.hex(-1).toString === "FFFFFFFFFFFFFFFF") + assert(Hex.hex(Long.MinValue).toString === "8000000000000000") + assert(Hex.hex(Long.MaxValue).toString === "7FFFFFFFFFFFFFFF") + } +}