From 8c464a862e8801f66a4c9d2b1842bca5c5a38c77 Mon Sep 17 00:00:00 2001
From: piyushn-stripe <74793191+piyushn-stripe@users.noreply.github.com>
Date: Tue, 25 Jul 2023 09:17:22 -0400
Subject: [PATCH] Add codec support to ser / deser pre-aggregated IR tiles (#523)

* Add codec support to ser / deser pre-aggregated IR tiles

* Remove trailing comma for Scala 2.11

* Use unpack instead of UnpackedAggregations
---
 .../scala/ai/chronon/online/TileCodec.scala | 57 ++++++++++++
 .../ai/chronon/online/TileCodecTest.scala   | 89 +++++++++++++++++++
 2 files changed, 146 insertions(+)
 create mode 100644 online/src/main/scala/ai/chronon/online/TileCodec.scala
 create mode 100644 online/src/test/scala/ai/chronon/online/TileCodecTest.scala

diff --git a/online/src/main/scala/ai/chronon/online/TileCodec.scala b/online/src/main/scala/ai/chronon/online/TileCodec.scala
new file mode 100644
index 000000000..aa70bc687
--- /dev/null
+++ b/online/src/main/scala/ai/chronon/online/TileCodec.scala
@@ -0,0 +1,57 @@
+package ai.chronon.online
+
+import ai.chronon.aggregator.row.RowAggregator
+import ai.chronon.api.{BooleanType, DataType, GroupBy, StructType}
+import org.apache.avro.generic.GenericData
+import ai.chronon.api.Extensions.{AggregationOps, MetadataOps}
+
+import scala.collection.JavaConverters._
+
+object TileCodec {
+  // Builds the RowAggregator for a GroupBy: every aggregation is unpacked (flattened out)
+  // so that the aggregator covers the full set of feature column aggregations to compute.
+  def buildRowAggregator(groupBy: GroupBy, inputSchema: Seq[(String, DataType)]): RowAggregator = {
+    val flattenedAggregations = groupBy.aggregations.asScala.flatMap(agg => agg.unpack)
+    new RowAggregator(inputSchema, flattenedAggregations)
+  }
+}
+
+/**
+  * TileCodec encodes and decodes pre-aggregated tiles of feature intermediate results (IRs).
+  * The serving layer can combine these tiles with the batch pre-aggregates produced by
+  * GroupByUploads to compute the final feature values.
+  * Tiles are serialized as Avro and carry a flag saying whether the tile is complete or only partial.
+  */
+class TileCodec(rowAggregator: RowAggregator, groupBy: GroupBy) {
+  val windowedIrSchema: StructType = StructType.from("WindowedIr", rowAggregator.irSchema)
+  // a tile is the windowed IR struct followed by the completeness flag, in that order
+  val fields: Array[(String, DataType)] =
+    Array(("collapsedIr", windowedIrSchema), ("isComplete", BooleanType))
+
+  val tileChrononSchema: StructType = {
+    val tileSchemaName = s"${groupBy.metaData.cleanName}_TILE_IR"
+    StructType.from(tileSchemaName, fields)
+  }
+  val tileAvroSchema: String = AvroConversions.fromChrononSchema(tileChrononSchema).toString()
+  val tileAvroCodec: AvroCodec = AvroCodec.of(tileAvroSchema)
+  private val irToBytesFn = AvroConversions.encodeBytes(tileChrononSchema, null)
+
+  /** Normalizes `ir` and encodes it, together with `isComplete`, into the tile's byte form. */
+  def makeTileIr(ir: Array[Any], isComplete: Boolean): Array[Byte] = {
+    val tileFields: Array[Any] = Array(rowAggregator.normalize(ir), Boolean.box(isComplete))
+    irToBytesFn(tileFields)
+  }
+
+  /** Decodes tile bytes back into the denormalized IR and the tile's isComplete flag. */
+  def decodeTileIr(tileIr: Array[Byte]): (Array[Any], Boolean) = {
+    val tileRecord = tileAvroCodec.decode(tileIr)
+    val irRecord = tileRecord.get("collapsedIr").asInstanceOf[GenericData.Record]
+    val normalizedIr =
+      AvroConversions.toChrononRow(irRecord, windowedIrSchema).asInstanceOf[Array[Any]]
+
+    // denormalize the decoded IR, then pair it with the completeness flag
+    val restoredIr = rowAggregator.denormalize(normalizedIr)
+    val completeFlag = tileRecord.get("isComplete").asInstanceOf[Boolean]
+    (restoredIr, completeFlag)
+  }
+}
diff --git a/online/src/test/scala/ai/chronon/online/TileCodecTest.scala b/online/src/test/scala/ai/chronon/online/TileCodecTest.scala
new file mode 100644
index 000000000..c666211fd
--- /dev/null
+++ b/online/src/test/scala/ai/chronon/online/TileCodecTest.scala
@@ -0,0 +1,89 @@
+package ai.chronon.online
+
+import ai.chronon.api.{Aggregation, Builders, FloatType, IntType, ListType, LongType, Operation, Row, StringType, TimeUnit, Window}
+import org.junit.Assert.assertEquals
+import org.junit.Test
+import scala.collection.JavaConverters._
+
+/**
+  * Verifies that intermediate results (IRs) survive a TileCodec encode / decode round trip
+  * for a mix of aggregation operations and windows.
+  */
+class TileCodecTest {
+  private val histogram = Map[String, Int]("A" -> 3, "B" -> 2).asJava
+
+  // each aggregation is paired with the value expected after finalizing the decoded IR
+  private val aggregationsAndExpected: Array[(Aggregation, Any)] = Array(
+    Builders.Aggregation(Operation.AVERAGE, "views", Seq(new Window(1, TimeUnit.DAYS))) -> 16.0,
+    Builders.Aggregation(Operation.AVERAGE, "rating", Seq(new Window(1, TimeUnit.DAYS))) -> 4.0,
+
+    Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(1, TimeUnit.DAYS))) -> 12.0f,
+    Builders.Aggregation(Operation.SUM, "rating", Seq(new Window(7, TimeUnit.DAYS))) -> 12.0f,
+
+    Builders.Aggregation(Operation.UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS))) -> 3L,
+    Builders.Aggregation(Operation.UNIQUE_COUNT, "title", Seq(new Window(7, TimeUnit.DAYS))) -> 3L,
+
+    Builders.Aggregation(Operation.LAST, "title", Seq(new Window(1, TimeUnit.DAYS))) -> "C",
+    Builders.Aggregation(Operation.LAST, "title", Seq(new Window(7, TimeUnit.DAYS))) -> "C",
+
+    Builders.Aggregation(Operation.LAST_K, "title", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> List("C", "B").asJava,
+    Builders.Aggregation(Operation.LAST_K, "title", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> List("C", "B").asJava,
+
+    Builders.Aggregation(Operation.TOP_K, "title", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "1")) -> List("C").asJava,
+    Builders.Aggregation(Operation.TOP_K, "title", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "1")) -> List("C").asJava,
+
+    Builders.Aggregation(Operation.MIN, "title", Seq(new Window(1, TimeUnit.DAYS))) -> "A",
+    Builders.Aggregation(Operation.MIN, "title", Seq(new Window(7, TimeUnit.DAYS))) -> "A",
+
+    Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "title", Seq(new Window(1, TimeUnit.DAYS))) -> 3L,
+    Builders.Aggregation(Operation.APPROX_UNIQUE_COUNT, "title", Seq(new Window(7, TimeUnit.DAYS))) -> 3L,
+
+    Builders.Aggregation(Operation.HISTOGRAM, "hist_input", Seq(new Window(1, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> histogram,
+    Builders.Aggregation(Operation.HISTOGRAM, "hist_input", Seq(new Window(7, TimeUnit.DAYS)), argMap = Map("k" -> "2")) -> histogram
+  )
+
+  private val schema = List(
+    "created" -> LongType,
+    "views" -> IntType,
+    "rating" -> FloatType,
+    "title" -> StringType,
+    "hist_input" -> ListType(StringType)
+  )
+
+  @Test
+  def testTileCodecIrSerRoundTrip(): Unit = {
+    val groupByMetadata = Builders.MetaData(name = "my_group_by")
+    val (aggregations, expectedVals) = aggregationsAndExpected.unzip
+    val groupBy = Builders.GroupBy(metaData = groupByMetadata, aggregations = aggregations)
+    val rowAggregator = TileCodec.buildRowAggregator(groupBy, schema)
+    val rowIR = rowAggregator.init
+    val tileCodec = new TileCodec(rowAggregator, groupBy)
+
+    val originalIsComplete = true
+    val rows = Seq(
+      createRow(1519862399984L, 4, 4.0f, "A", Seq("D", "A", "B", "A")),
+      createRow(1519862399984L, 40, 5.0f, "B", Seq()),
+      createRow(1519862399988L, 4, 3.0f, "C", Seq("A", "B", "C"))
+    )
+    rows.foreach(row => rowAggregator.update(rowIR, row))
+    val bytes = tileCodec.makeTileIr(rowIR, originalIsComplete)
+    assert(bytes.nonEmpty, "serialized tile should not be empty")
+
+    val (deserPayload, isComplete) = tileCodec.decodeTileIr(bytes)
+    assertEquals(originalIsComplete, isComplete)
+
+    // finalize the decoded intermediate results and compare them against the expected values
+    val finalResults = rowAggregator.finalize(deserPayload)
+    val outputNames = rowAggregator.outputSchema.map(_._1)
+    // zip silently drops unmatched elements, so pin the lengths before comparing pairwise
+    assertEquals(expectedVals.length, finalResults.length)
+    expectedVals.zip(finalResults).zip(outputNames).foreach {
+      case ((expected, actual), name) =>
+        assertEquals(s"Unexpected value for output column $name", expected, actual)
+    }
+  }
+
+  // values are listed in the same order as the columns of `schema`
+  def createRow(ts: Long, views: Int, rating: Float, title: String, histInput: Seq[String]): Row =
+    new ArrayRow(Array[Any](ts, views, rating, title, histInput), ts)
+}