From 1a09157c9e912182343a9eef1214e10f5144a263 Mon Sep 17 00:00:00 2001 From: Matt Dziuban Date: Tue, 9 Jul 2024 15:40:45 -0400 Subject: [PATCH] Add support for encoding enum singletons as products. --- .../io/bullet/borer/AdtEncodingStrategy.scala | 8 +++- .../borer/derivation/ArrayBasedCodecs.scala | 2 + .../borer/derivation/DerivedAdtEncoder.scala | 1 + .../io/bullet/borer/derivation/Deriver.scala | 42 +++++++++++++++---- .../borer/derivation/MapBasedCodecs.scala | 2 + 5 files changed, 46 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/io/bullet/borer/AdtEncodingStrategy.scala b/core/src/main/scala/io/bullet/borer/AdtEncodingStrategy.scala index 443a586a..61fca137 100644 --- a/core/src/main/scala/io/bullet/borer/AdtEncodingStrategy.scala +++ b/core/src/main/scala/io/bullet/borer/AdtEncodingStrategy.scala @@ -15,7 +15,13 @@ import io.bullet.borer.internal.{ElementDeque, Util} import scala.annotation.tailrec -sealed abstract class AdtEncodingStrategy: +sealed abstract class AdtEncodingStrategy(private var enumCasesAsProduct0: Boolean = false): + final def enumCasesAsProduct: Boolean = enumCasesAsProduct0 + final def withEnumCasesAsProduct(enumCasesAsProduct: Boolean): this.type = { + enumCasesAsProduct0 = enumCasesAsProduct + this + } + def writeAdtEnvelopeOpen(w: Writer, typeName: String): w.type def writeAdtEnvelopeClose(w: Writer, typeName: String): w.type diff --git a/derivation/src/main/scala/io/bullet/borer/derivation/ArrayBasedCodecs.scala b/derivation/src/main/scala/io/bullet/borer/derivation/ArrayBasedCodecs.scala index 6f7e37c6..d9f9e99c 100644 --- a/derivation/src/main/scala/io/bullet/borer/derivation/ArrayBasedCodecs.scala +++ b/derivation/src/main/scala/io/bullet/borer/derivation/ArrayBasedCodecs.scala @@ -335,6 +335,8 @@ object ArrayBasedCodecs extends DerivationApi { case enc: AdtEncoder[A] => enc.write(w, value) case enc => enc.write(w.writeArrayOpen(2).writeString(typeId), value).writeArrayClose() } + + final def enumCasesAsProduct: Boolean = false } abstract class ArrayBasedAdtDecoder[T] extends DerivedAdtDecoder[T] diff --git a/derivation/src/main/scala/io/bullet/borer/derivation/DerivedAdtEncoder.scala b/derivation/src/main/scala/io/bullet/borer/derivation/DerivedAdtEncoder.scala index b3cdd974..ba628ebf 100644 --- a/derivation/src/main/scala/io/bullet/borer/derivation/DerivedAdtEncoder.scala +++ b/derivation/src/main/scala/io/bullet/borer/derivation/DerivedAdtEncoder.scala @@ -17,6 +17,7 @@ abstract class DerivedAdtEncoder[T] extends AdtEncoder[T] { def writeAdtValue[A](w: Writer, typeId: Long, value: A)(using encoder: Encoder[A]): Writer def writeAdtValue[A](w: Writer, typeId: String, value: A)(using encoder: Encoder[A]): Writer + def enumCasesAsProduct: Boolean } /** diff --git a/derivation/src/main/scala/io/bullet/borer/derivation/Deriver.scala b/derivation/src/main/scala/io/bullet/borer/derivation/Deriver.scala index b7b89bbc..7327cbd6 100644 --- a/derivation/src/main/scala/io/bullet/borer/derivation/Deriver.scala +++ b/derivation/src/main/scala/io/bullet/borer/derivation/Deriver.scala @@ -242,15 +242,44 @@ abstract private[derivation] class Deriver[F[_]: Type, T: Type, Q <: Quotes](usi val flattened = flattenedSubs[Encoder](rootNode, deepRecurse = false, includeEnumSingletonCases = true) val typeIdsAndNodesSorted = extractTypeIdsAndSort(rootNode, flattened) + def writeEnumSingletonTypeId(typeId: Long | String): Expr[Writer] = + typeId match + case x: Long => '{ $w.writeLong(${ Expr(x) }) } + case x: String => '{ $w.writeString(${ Expr(x) }) } + + def writeAdtValue[A: Type]( + typeId: Long | String, + valueAsA: Expr[A], + encA: Expr[Encoder[A]]): Expr[Writer] = + typeId match + case x: Long => '{ $self.writeAdtValue[A]($w, ${ Expr(x) }, $valueAsA)(using $encA) } + case x: String => '{ $self.writeAdtValue[A]($w, ${ Expr(x) }, $valueAsA)(using $encA) } + def rec(ix: Int): Expr[Writer] = if (ix < typeIdsAndNodesSorted.length) { val (typeId, sub) = typeIdsAndNodesSorted(ix) if (sub.isEnumSingletonCase) { val enumRef = sub.enumRef.asExprOf[AnyRef] - val writeTypeId = typeId match - case x: Long => '{ $w.writeLong(${ Expr(x) }) } - case x: String => '{ $w.writeString(${ Expr(x) }) } - '{ if ($value.asInstanceOf[AnyRef] eq $enumRef) $writeTypeId else ${ rec(ix + 1) } } + lazy val writeTypeId = writeEnumSingletonTypeId(typeId) + lazy val writeAsProduct = sub.tpe.asType match + case '[a] => + val valueAsA = '{ $value.asInstanceOf[a] } + val encA = '{ + new Encoder[a] { + def write(w: Writer, value: a): Writer = { + w.writeMapStart() + w.writeBreak() + } + } + } + writeAdtValue[a](typeId, valueAsA, encA) + + '{ + if ($value.asInstanceOf[AnyRef] eq $enumRef) + if ($self.enumCasesAsProduct) $writeAsProduct else $writeTypeId + else + ${ rec(ix + 1) } + } } else { val testType = sub.tpe match case AppliedType(x, _) => x @@ -259,10 +288,7 @@ abstract private[derivation] class Deriver[F[_]: Type, T: Type, Q <: Quotes](usi case '[a] => val valueAsA = '{ $value.asInstanceOf[a] } val encA = Expr.summon[Encoder[a]].getOrElse(fail(s"Cannot find given Encoder[${Type.show[a]}]")) - val writeKeyed = typeId match - case x: Long => '{ $self.writeAdtValue[a]($w, ${ Expr(x) }, $valueAsA)(using $encA) } - case x: String => - '{ $self.writeAdtValue[a]($w, ${ Expr(x) }, $valueAsA)(using $encA) } + val writeKeyed = writeAdtValue[a](typeId, valueAsA, encA) testType.asType match case '[b] => '{ if ($value.isInstanceOf[b]) $writeKeyed else ${ rec(ix + 1) } } } diff --git a/derivation/src/main/scala/io/bullet/borer/derivation/MapBasedCodecs.scala b/derivation/src/main/scala/io/bullet/borer/derivation/MapBasedCodecs.scala index 51ceb97f..d84ea18d 100644 --- a/derivation/src/main/scala/io/bullet/borer/derivation/MapBasedCodecs.scala +++ b/derivation/src/main/scala/io/bullet/borer/derivation/MapBasedCodecs.scala @@ -617,6 +617,8 @@ object MapBasedCodecs extends DerivationApi { enc.write(w.writeString(typeId), value) strategy.writeAdtEnvelopeClose(w, typeName) } + + final def enumCasesAsProduct: Boolean = strategy.enumCasesAsProduct } abstract class MapBasedAdtDecoder[T] extends DerivedAdtDecoder[T]