From 1cbcb4a828dfccf08e0d773ea7343eab085cd954 Mon Sep 17 00:00:00 2001 From: Mikhail Bezoyan Date: Thu, 3 Oct 2024 14:07:59 +0000 Subject: [PATCH] [scrooge] Speeding up serialization of collections and in particular arrays of primitives Differential Revision: https://phabricator.twitter.biz/D1173708 --- scrooge-benchmark/BUILD | 5 - scrooge-benchmark/BUILD.bazel | 5 + .../src/main/scala/{BUILD => BUILD.bazel} | 14 ++- .../scrooge/benchmark/Collections.scala | 86 ++++++++++++--- .../src/main/thrift/{BUILD => BUILD.bazel} | 0 .../src/main/thrift/collections.thrift | 4 + .../twitter/scrooge/internal/TProtocols.scala | 100 +++++++++++++++++- .../scrooge/backend/StructTemplate.scala | 23 +++- 8 files changed, 207 insertions(+), 30 deletions(-) delete mode 100644 scrooge-benchmark/BUILD create mode 100644 scrooge-benchmark/BUILD.bazel rename scrooge-benchmark/src/main/scala/{BUILD => BUILD.bazel} (79%) rename scrooge-benchmark/src/main/thrift/{BUILD => BUILD.bazel} (100%) diff --git a/scrooge-benchmark/BUILD b/scrooge-benchmark/BUILD deleted file mode 100644 index 5f3a1dbb6..000000000 --- a/scrooge-benchmark/BUILD +++ /dev/null @@ -1,5 +0,0 @@ -target( - dependencies = [ - "scrooge/scrooge-benchmark/src/main/scala", - ], -) diff --git a/scrooge-benchmark/BUILD.bazel b/scrooge-benchmark/BUILD.bazel new file mode 100644 index 000000000..7f748d37a --- /dev/null +++ b/scrooge-benchmark/BUILD.bazel @@ -0,0 +1,5 @@ +target( + dependencies = [ + "scrooge/scrooge-benchmark/src/main/scala:benchmark", + ], +) diff --git a/scrooge-benchmark/src/main/scala/BUILD b/scrooge-benchmark/src/main/scala/BUILD.bazel similarity index 79% rename from scrooge-benchmark/src/main/scala/BUILD rename to scrooge-benchmark/src/main/scala/BUILD.bazel index 5c39c3927..ac8d4a200 100644 --- a/scrooge-benchmark/src/main/scala/BUILD +++ b/scrooge-benchmark/src/main/scala/BUILD.bazel @@ -1,4 +1,5 @@ -scala_library( +scala_benchmark_jmh( + name = "benchmark", sources = ["**/*.scala"], compiler_option_sets = ["fatal_warnings"], platform = "java8", @@ -11,9 +12,6 @@ scala_library( "scrooge/scrooge-core/src/main/scala", "scrooge/scrooge-serializer", ], - exports = [ - "3rdparty/jvm/org/openjdk/jmh:jmh-core", - ], ) jvm_binary( @@ -21,10 +19,16 @@ jvm_binary( main = "org.openjdk.jmh.Main", platform = "java8", dependencies = [ - ":scala", + ":benchmark_compiled_benchmark_lib", scoped( "3rdparty/jvm/org/slf4j:slf4j-nop", scope = "runtime", ), ], ) + +jvm_app( + name = "jmh-bundle", + basename = "scrooge-benchmark-bundle", + binary = ":jmh", +) diff --git a/scrooge-benchmark/src/main/scala/com/twitter/scrooge/benchmark/Collections.scala b/scrooge-benchmark/src/main/scala/com/twitter/scrooge/benchmark/Collections.scala index 4ed176940..3ed26eede 100644 --- a/scrooge-benchmark/src/main/scala/com/twitter/scrooge/benchmark/Collections.scala +++ b/scrooge-benchmark/src/main/scala/com/twitter/scrooge/benchmark/Collections.scala @@ -1,10 +1,12 @@ package com.twitter.scrooge.benchmark +import com.twitter.scrooge.ThriftStruct import com.twitter.scrooge.ThriftStructCodec import java.io.ByteArrayOutputStream import java.util.concurrent.TimeUnit import java.util.Random -import org.apache.thrift.protocol.{TProtocol, TBinaryProtocol} +import org.apache.thrift.protocol.TBinaryProtocol +import org.apache.thrift.protocol.TProtocol import org.apache.thrift.transport.TTransport import org.openjdk.jmh.annotations._ import thrift.benchmark._ @@ -39,6 +41,7 @@ class TRewindable extends TTransport { def rewind(): Unit = { pos = 0 + arr.reset() } def inspect: String = { @@ -63,33 +66,57 @@ class Collections(size: Int) { val list: TRewindable = new TRewindable val listProt: TBinaryProtocol = new TBinaryProtocol(list) + val listDouble: TRewindable = new TRewindable + val listDoubleProt: TBinaryProtocol = new TBinaryProtocol(listDouble) + val rng: Random = new Random(31415926535897932L) val mapVals: mutable.Builder[(Long, String), Map[Long, String]] = Map.newBuilder[Long, String] val setVals: mutable.Builder[Long, Set[Long]] = Set.newBuilder[Long] val listVals: mutable.Builder[Long, Seq[Long]] = Seq.newBuilder[Long] + val arrayVals = new Array[Long](size) + val arrayDoublesVals = new Array[Double](size) - val m: Unit = for (_ <- (0 until size)) { + val m: Unit = for (i <- (0 until size)) { val num = rng.nextLong() mapVals += (num -> num.toString) setVals += num listVals += num + arrayVals(i) = num + arrayDoublesVals(i) = num } - MapCollections.encode(MapCollections(mapVals.result), mapProt) - SetCollections.encode(SetCollections(setVals.result), setProt) - ListCollections.encode(ListCollections(listVals.result), listProt) + val mapCollections: MapCollections = MapCollections(mapVals.result) + val setCollections: SetCollections = SetCollections(setVals.result) + val listCollections: ListCollections = ListCollections(listVals.result) + val arrayCollections: ListCollections = ListCollections(arrayVals) + val arrayDoubleCollections: ListDoubleCollections = ListDoubleCollections(arrayDoublesVals) + + MapCollections.encode(mapCollections, mapProt) + SetCollections.encode(setCollections, setProt) + ListCollections.encode(listCollections, listProt) + ListDoubleCollections.encode(arrayDoubleCollections, listDoubleProt) - def run(codec: ThriftStructCodec[_], prot: TProtocol, buff: TRewindable): Unit = { + def decode(codec: ThriftStructCodec[_], prot: TProtocol, buff: TRewindable): Unit = { codec.decode(prot) buff.rewind() } + + def encode[T <: ThriftStruct]( + codec: ThriftStructCodec[T], + prot: TProtocol, + buff: TRewindable, + obj: T + ): Unit = { + codec.encode(obj, prot) + buff.rewind() + } } object CollectionsBenchmark { @State(Scope.Thread) class CollectionsState { - @Param(Array("1", "5", "10", "100", "500", "1000")) + @Param(Array("1", "5", "10", "100", "500")) var size: Int = 1 var col: Collections = _ @@ -98,24 +125,53 @@ object CollectionsBenchmark { def setup(): Unit = { col = new Collections(size) } - } } -@OutputTimeUnit(TimeUnit.NANOSECONDS) +@OutputTimeUnit(TimeUnit.SECONDS) @BenchmarkMode(Array(Mode.Throughput)) +@Fork(1) +@Warmup(iterations = 3, time = 10, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 5, time = 10, timeUnit = TimeUnit.SECONDS) class CollectionsBenchmark { import CollectionsBenchmark._ @Benchmark - def timeMap(state: CollectionsState): Unit = - state.col.run(MapCollections, state.col.mapProt, state.col.map) + def timeEncodeList(state: CollectionsState): Unit = + state.col.encode( + ListCollections, + state.col.listProt, + state.col.list, + state.col.listCollections + ) + + @Benchmark + def timeEncodeArray(state: CollectionsState): Unit = + state.col.encode( + ListCollections, + state.col.listProt, + state.col.list, + state.col.arrayCollections + ) + + @Benchmark + def timeEncodeDoubleArray(state: CollectionsState): Unit = + state.col.encode( + ListDoubleCollections, + state.col.listDoubleProt, + state.col.listDouble, + state.col.arrayDoubleCollections + ) + + @Benchmark + def timeDecodeMap(state: CollectionsState): Unit = + state.col.decode(MapCollections, state.col.mapProt, state.col.map) @Benchmark - def timeSet(state: CollectionsState): Unit = - state.col.run(SetCollections, state.col.setProt, state.col.set) + def timeDecodeSet(state: CollectionsState): Unit = + state.col.decode(SetCollections, state.col.setProt, state.col.set) @Benchmark - def timeList(state: CollectionsState): Unit = - state.col.run(ListCollections, state.col.listProt, state.col.list) + def timeDecodeList(state: CollectionsState): Unit = + state.col.decode(ListCollections, state.col.listProt, state.col.list) } diff --git a/scrooge-benchmark/src/main/thrift/BUILD b/scrooge-benchmark/src/main/thrift/BUILD.bazel similarity index 100% rename from scrooge-benchmark/src/main/thrift/BUILD rename to scrooge-benchmark/src/main/thrift/BUILD.bazel diff --git a/scrooge-benchmark/src/main/thrift/collections.thrift b/scrooge-benchmark/src/main/thrift/collections.thrift index 19171ff03..5a31f64c3 100644 --- a/scrooge-benchmark/src/main/thrift/collections.thrift +++ b/scrooge-benchmark/src/main/thrift/collections.thrift @@ -12,3 +12,7 @@ struct SetCollections { struct ListCollections { 1: list longs } + +struct ListDoubleCollections { + 1: list doubles +} diff --git a/scrooge-core/src/main/scala/com/twitter/scrooge/internal/TProtocols.scala b/scrooge-core/src/main/scala/com/twitter/scrooge/internal/TProtocols.scala index 7a1bafe3f..b41241d92 100644 --- a/scrooge-core/src/main/scala/com/twitter/scrooge/internal/TProtocols.scala +++ b/scrooge-core/src/main/scala/com/twitter/scrooge/internal/TProtocols.scala @@ -4,9 +4,12 @@ import com.twitter.scrooge.TFieldBlob import com.twitter.scrooge.ThriftEnum import com.twitter.scrooge.ThriftUnion import java.nio.ByteBuffer +import java.util.function.ObjDoubleConsumer +import java.util.function.ObjLongConsumer import org.apache.thrift.protocol._ import scala.collection.immutable import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer /** * Reads and writes fields for a `TProtocol`. Intended to be used @@ -93,11 +96,26 @@ final class TProtocols private[TProtocols] { elementType: Byte, writeElement: (TProtocol, T) => Unit ): Unit = { - protocol.writeListBegin(new TList(typeForCollection(elementType), list.size)) + val size = list.size + protocol.writeListBegin(new TList(typeForCollection(elementType), size)) list match { + case wrappedArray: mutable.WrappedArray[T] => + val arr = wrappedArray.array + var i = 0 + while (i < size) { + val el: T = arr(i).asInstanceOf[T] + writeElement(protocol, el) + i += 1 + } + case arrayBuffer: ArrayBuffer[T] => + var i = 0 + while (i < size) { + writeElement(protocol, arrayBuffer(i)) + i += 1 + } case _: IndexedSeq[_] => var i = 0 - while (i < list.size) { + while (i < size) { writeElement(protocol, list(i)) i += 1 } @@ -109,6 +127,78 @@ final class TProtocols private[TProtocols] { protocol.writeListEnd() } + def writeListDouble( + protocol: TProtocol, + list: collection.Seq[Double], + elementType: Byte, + writeElement: ObjDoubleConsumer[TProtocol] + ): Unit = { + val size = list.size + protocol.writeListBegin(new TList(typeForCollection(elementType), size)) + list match { + case wrappedArray: mutable.WrappedArray.ofDouble => + val arr = wrappedArray.array + var i = 0 + while (i < size) { + writeElement.accept(protocol, arr(i)) + i += 1 + } + case arrayBuffer: ArrayBuffer[Double] => + var i = 0 + while (i < size) { + writeElement.accept(protocol, arrayBuffer(i)) + i += 1 + } + case _: IndexedSeq[_] => + var i = 0 + while (i < size) { + writeElement.accept(protocol, list(i)) + i += 1 + } + case _ => + list.foreach { element => + writeElement.accept(protocol, element) + } + } + protocol.writeListEnd() + } + + def writeListI64( + protocol: TProtocol, + list: collection.Seq[Long], + elementType: Byte, + writeElement: ObjLongConsumer[TProtocol] + ): Unit = { + val len = list.size + protocol.writeListBegin(new TList(typeForCollection(elementType), len)) + list match { + case wrappedArray: mutable.WrappedArray.ofLong => + val arr = wrappedArray.array + var i = 0 + while (i < len) { + writeElement.accept(protocol, arr(i)) + i += 1 + } + case arrayBuffer: ArrayBuffer[Long] => + var i = 0 + while (i < len) { + writeElement.accept(protocol, arrayBuffer(i)) + i += 1 + } + case _: IndexedSeq[_] => + var i = 0 + while (i < len) { + writeElement.accept(protocol, list(i)) + i += 1 + } + case _ => + list.foreach { element => + writeElement.accept(protocol, element) + } + } + protocol.writeListEnd() + } + def writeSet[T]( protocol: TProtocol, set: collection.Set[T], @@ -193,9 +283,15 @@ object TProtocols { val writeI64Fn: (TProtocol, Long) => Unit = (protocol, value) => protocol.writeI64(value) + val writeI64Consumer: ObjLongConsumer[TProtocol] = + (protocol: TProtocol, value: Long) => protocol.writeI64(value) + val writeDoubleFn: (TProtocol, Double) => Unit = (protocol, value) => protocol.writeDouble(value) + val writeDoubleConsumer: ObjDoubleConsumer[TProtocol] = + (protocol: TProtocol, value: Double) => protocol.writeDouble(value) + val writeStringFn: (TProtocol, String) => Unit = (protocol, value) => protocol.writeString(value) diff --git a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala index 09d93b92a..b257b13ab 100644 --- a/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala +++ b/scrooge-generator/src/main/scala/com/twitter/scrooge/backend/StructTemplate.scala @@ -236,6 +236,25 @@ trait StructTemplate { self: TemplateGenerator => } } + @scala.annotation.tailrec + private[this] def genWriteListFn( + elementFieldType: FieldType, + fieldName: CodeFragment, + protoName: String + ): CodeFragment = { + elementFieldType match { + case at: AnnotatedFieldType => genWriteListFn(at.unwrap, fieldName, protoName) + case TDouble => + v(s"$rootProtos.writeListDouble($protoName, $fieldName, TType.DOUBLE, _root_.com.twitter.scrooge.internal.TProtocols.writeDoubleConsumer)") + case TI64 => + v(s"$rootProtos.writeListI64($protoName, $fieldName, TType.I64, _root_.com.twitter.scrooge.internal.TProtocols.writeI64Consumer)") + case _ => + val elemFieldType = s"TType.${genConstType(elementFieldType)}" + val writeElementFn = genWriteValueFn2(elementFieldType) + v(s"$rootProtos.writeList($protoName, $fieldName, $elemFieldType, $writeElementFn)") + } + } + @scala.annotation.tailrec private[this] def genWriteValueFn2(fieldType: FieldType): CodeFragment = { fieldType match { @@ -306,9 +325,7 @@ trait StructTemplate { self: TemplateGenerator => val writeElement = genWriteValueFn2(t.eltType) v(s"$rootProtos.writeSet($protoName, $fieldName, $elemFieldType, $writeElement)") case t: ListType => - val elemFieldType = s"TType.${genConstType(t.eltType)}" - val writeElement = genWriteValueFn2(t.eltType) - v(s"$rootProtos.writeList($protoName, $fieldName, $elemFieldType, $writeElement)") + genWriteListFn(t.eltType, fieldName, protoName) case t: MapType => val keyType = s"TType.${genConstType(t.keyType)}" val valType = s"TType.${genConstType(t.valueType)}"