Skip to content

Commit

Permalink
IgnoreTransientDefaultMarker - add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sebaciv committed Sep 25, 2024
1 parent e614e2d commit c1307e7
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,18 @@ package serialization

trait AcceptsAdditionalCustomMarkers extends AcceptsCustomEvents {

protected def markers: Set[CustomEventMarker[_]]
protected def markers: Set[CustomEventMarker[?]]

override def customEvent[T](marker: CustomEventMarker[T], event: T): Boolean =
marker match {
case marker if markers(marker) => true
case _ => super.customEvent(marker, event)
}
markers(marker) || super.customEvent(marker, event)
}

/**
* [[Input]] implementation that adds additional markers [[CustomEventMarker]] to the provided [[Input]] instance
*/
final class CustomMarkersInputWrapper(
final class CustomMarkersInputWrapper private(
override protected val wrapped: Input,
override protected val markers: Set[CustomEventMarker[_]],
override protected val markers: Set[CustomEventMarker[?]],
) extends InputWrapper with AcceptsAdditionalCustomMarkers {

override def readList(): ListInput =
Expand All @@ -27,19 +24,22 @@ final class CustomMarkersInputWrapper(
new CustomMarkersInputWrapper.AdjustedObjectInput(super.readObject(), markers)
}
object CustomMarkersInputWrapper {
def apply(input: Input, markers: CustomEventMarker[_]*): CustomMarkersInputWrapper =
new CustomMarkersInputWrapper(input, markers.toSet)
def apply(input: Input, markers: CustomEventMarker[?]*): CustomMarkersInputWrapper =
CustomMarkersInputWrapper(input, markers.toSet)

def apply(input: Input, markers: Set[CustomEventMarker[?]]): CustomMarkersInputWrapper =
new CustomMarkersInputWrapper(input, markers)

private final class AdjustedListInput(
override protected val wrapped: ListInput,
override protected val markers: Set[CustomEventMarker[_]],
override protected val markers: Set[CustomEventMarker[?]],
) extends ListInputWrapper with AcceptsAdditionalCustomMarkers {
override def nextElement(): Input = new CustomMarkersInputWrapper(super.nextElement(), markers)
}

private final class AdjustedFieldInput(
override protected val wrapped: FieldInput,
override protected val markers: Set[CustomEventMarker[_]],
override protected val markers: Set[CustomEventMarker[?]],
) extends FieldInputWrapper with AcceptsAdditionalCustomMarkers {

override def readList(): ListInput = new AdjustedListInput(super.readList(), markers)
Expand All @@ -48,7 +48,7 @@ object CustomMarkersInputWrapper {

private final class AdjustedObjectInput(
override protected val wrapped: ObjectInput,
override protected val markers: Set[CustomEventMarker[_]],
override protected val markers: Set[CustomEventMarker[?]],
) extends ObjectInputWrapper with AcceptsAdditionalCustomMarkers {

override def nextField(): FieldInput = new AdjustedFieldInput(super.nextField(), markers)
Expand All @@ -60,9 +60,9 @@ object CustomMarkersInputWrapper {
/**
* [[Output]] implementation that adds additional markers [[CustomEventMarker]] to the provided [[Output]] instance
*/
final class CustomMarkersOutputWrapper(
final class CustomMarkersOutputWrapper private(
override protected val wrapped: Output,
override protected val markers: Set[CustomEventMarker[_]],
override protected val markers: Set[CustomEventMarker[?]],
) extends OutputWrapper with AcceptsAdditionalCustomMarkers {

override def writeSimple(): SimpleOutput =
Expand All @@ -76,17 +76,20 @@ final class CustomMarkersOutputWrapper(
}

object CustomMarkersOutputWrapper {
def apply(output: Output, markers: CustomEventMarker[_]*): CustomMarkersOutputWrapper =
new CustomMarkersOutputWrapper(output, markers.toSet)
def apply(output: Output, markers: CustomEventMarker[?]*): CustomMarkersOutputWrapper =
CustomMarkersOutputWrapper(output, markers.toSet)

def apply(output: Output, markers: Set[CustomEventMarker[?]]): CustomMarkersOutputWrapper =
new CustomMarkersOutputWrapper(output, markers)

private final class AdjustedSimpleOutput(
override protected val wrapped: SimpleOutput,
override protected val markers: Set[CustomEventMarker[_]],
override protected val markers: Set[CustomEventMarker[?]],
) extends SimpleOutputWrapper with AcceptsAdditionalCustomMarkers

private final class AdjustedListOutput(
override protected val wrapped: ListOutput,
override protected val markers: Set[CustomEventMarker[_]],
override protected val markers: Set[CustomEventMarker[?]],
) extends ListOutputWrapper with AcceptsAdditionalCustomMarkers {

override def writeElement(): Output =
Expand All @@ -95,7 +98,7 @@ object CustomMarkersOutputWrapper {

private final class AdjustedObjectOutput(
override protected val wrapped: ObjectOutput,
override protected val markers: Set[CustomEventMarker[_]],
override protected val markers: Set[CustomEventMarker[?]],
) extends ObjectOutputWrapper with AcceptsAdditionalCustomMarkers {

override def writeField(key: String): Output =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ object IgnoreTransientDefaultMarkerTest {
}

class IgnoreTransientDefaultMarkerTest extends AbstractCodecTest {
import IgnoreTransientDefaultMarkerTest._
import IgnoreTransientDefaultMarkerTest.*

override type Raw = Any

def writeToOutput(write: Output => Unit): Any = {
var result: Any = null
write(CustomMarkersOutputWrapper(new SimpleValueOutput(result = _), IgnoreTransientDefaultMarker))
write(CustomMarkersOutputWrapper(new SimpleValueOutput(v => result = v), IgnoreTransientDefaultMarker))
result
}

Expand All @@ -40,6 +40,7 @@ class IgnoreTransientDefaultMarkerTest extends AbstractCodecTest {
testWrite(HasDefaults(str = "dafuq"), Map("str" -> "dafuq", "int" -> 42))
}

//noinspection RedundantDefaultArgument
test("read case class with default values") {
testRead(Map("str" -> "lol", "int" -> 42), HasDefaults(str = "lol", int = 42))
testRead(Map("str" -> "lol"), HasDefaults(str = "lol", int = 42))
Expand All @@ -54,6 +55,7 @@ class IgnoreTransientDefaultMarkerTest extends AbstractCodecTest {
testWrite(HasOptParam(), Map("flag" -> false))
}

//noinspection RedundantDefaultArgument
test("write nested case class with default values") {
testWrite(
value = NestedHasDefaults(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,39 @@ package serialization

import org.scalatest.funsuite.AnyFunSuite

case class RecordWithDefaults(
final case class RecordWithDefaults(
@transientDefault a: String = "",
b: Int = 42
) {
@generated def c: String = s"$a-$b"
}
object RecordWithDefaults extends HasApplyUnapplyCodec[RecordWithDefaults]

class CustomRecordWithDefaults(val a: String, val b: Int)
final class CustomRecordWithDefaults(val a: String, val b: Int)
object CustomRecordWithDefaults extends HasApplyUnapplyCodec[CustomRecordWithDefaults] {
def apply(@transientDefault a: String = "", b: Int = 42): CustomRecordWithDefaults =
new CustomRecordWithDefaults(a, b)
def unapply(crwd: CustomRecordWithDefaults): Opt[(String, Int)] =
Opt((crwd.a, crwd.b))
}

class CustomWrapper(val a: String)
final class CustomWrapper(val a: String)
object CustomWrapper extends HasApplyUnapplyCodec[CustomWrapper] {
def apply(@transientDefault a: String = ""): CustomWrapper = new CustomWrapper(a)
def unapply(cw: CustomWrapper): Opt[String] = Opt(cw.a)
}

case class RecordWithOpts(
final case class RecordWithOpts(
@optionalParam abc: Opt[String] = Opt.Empty,
@transientDefault flag: Opt[Boolean] = Opt.Empty,
b: Int = 42,
)
object RecordWithOpts extends HasApplyUnapplyCodec[RecordWithOpts]

case class SingleFieldRecordWithOpts(@optionalParam abc: Opt[String] = Opt.Empty)
final case class SingleFieldRecordWithOpts(@optionalParam abc: Opt[String] = Opt.Empty)
object SingleFieldRecordWithOpts extends HasApplyUnapplyCodec[SingleFieldRecordWithOpts]

case class SingleFieldRecordWithTD(@transientDefault abc: String = "abc")
final case class SingleFieldRecordWithTD(@transientDefault abc: String = "abc")
object SingleFieldRecordWithTD extends HasApplyUnapplyCodec[SingleFieldRecordWithTD]

class ObjectSizeTest extends AnyFunSuite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import org.scalatest.funsuite.AnyFunSuite

import java.io.{ByteArrayOutputStream, DataOutputStream}

case class Record(
final case class Record(
b: Boolean,
i: Int,
l: List[String],
Expand All @@ -19,7 +19,7 @@ case class Record(
)
object Record extends HasGenCodec[Record]

case class CustomKeysRecord(
final case class CustomKeysRecord(
@cborKey(1) first: Int,
@cborKey(true) second: Boolean,
@cborKey(Vector(1, 2, 3)) third: String,
Expand All @@ -28,13 +28,13 @@ case class CustomKeysRecord(
)
object CustomKeysRecord extends HasCborCodec[CustomKeysRecord]

case class CustomKeysRecordWithDefaults(
final case class CustomKeysRecordWithDefaults(
@transientDefault @cborKey(1) first: Int = 0,
@cborKey(true) second: Boolean,
)
object CustomKeysRecordWithDefaults extends HasCborCodec[CustomKeysRecordWithDefaults]

case class CustomKeysRecordWithNoDefaults(
final case class CustomKeysRecordWithNoDefaults(
@cborKey(1) first: Int = 0,
@cborKey(true) second: Boolean,
)
Expand Down Expand Up @@ -242,11 +242,13 @@ class CborInputOutputTest extends AnyFunSuite {
val value = CustomKeysRecordWithDefaults(first = 0, second = true)
GenCodec.write(output, value)
val bytes = Bytes(baos.toByteArray)
assert(bytes.toString == "A20100F5F5")

val expectedRawValue = "A20100F5F5"
assert(bytes.toString == expectedRawValue)
assert(RawCbor(bytes.bytes).readAs[CustomKeysRecordWithDefaults](keyCodec) == value)

// should be the same as model with @transientDefault and serialization ignoring it
assertRoundtrip(CustomKeysRecordWithNoDefaults(first = 0, second = true), "A20100F5F5")
assertRoundtrip(CustomKeysRecordWithNoDefaults(first = 0, second = true), expectedRawValue)
}

test("chunked text string") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with

import c.universe._

private def IgnoreTransientDefaultMarkerObj: Tree = q"$SerializationPkg.IgnoreTransientDefaultMarker"

override def allowOptionalParams: Boolean = true

def mkTupleCodec[T: WeakTypeTag](elementCodecs: Tree*): Tree = instrument {
Expand Down Expand Up @@ -183,37 +185,37 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with
doWriteField(p, value, transientValue)
}

def writeFieldTransientDefaultPossible(p: ApplyParam, value: Tree): Tree = {
val transientValue =
if (isTransientDefault(p)) Some(p.defaultValue)
else p.optionLike.map(ol => q"${ol.reference(Nil)}.none")
doWriteField(p, value, transientValue)
}
def writeFieldTransientDefaultPossible(p: ApplyParam, value: Tree): Tree =
if (isTransientDefault(p)) doWriteField(p, value, Some(p.defaultValue))
else writeFieldNoTransientDefault(p, value)

def writeField(p: ApplyParam, value: Tree, ignoreTransientDefault: Tree): Tree =
if (isTransientDefault(p))
if (isTransientDefault(p)) // optimize code to avoid calling 'output.customEvent' when param does not have @transientDefault
q"""
if ($ignoreTransientDefault) ${writeFieldNoTransientDefault(p, value)}
if($ignoreTransientDefault) ${writeFieldNoTransientDefault(p, value)}
else ${writeFieldTransientDefaultPossible(p, value)}
"""
else
writeFieldNoTransientDefault(p, value)

def ignoreTransientDefaultCheck: Tree =
q"output.customEvent($SerializationPkg.IgnoreTransientDefaultMarker, ())"
q"output.customEvent($IgnoreTransientDefaultMarkerObj, ())"

// when params size is 1
def writeSingle(p: ApplyParam, value: Tree): Tree =
writeField(p, value, ignoreTransientDefaultCheck)

// when params size is greater than 1
def writeMultiple(value: ApplyParam => Tree): Tree =
if (anyParamHasTransientDefault) {
// optimize code to avoid calling 'output.customEvent' when there no params with @transientDefault
// extracted to `val` to avoid calling 'output.customEvent' multiple times
if (anyParamHasTransientDefault)
q"""
val ignoreTransientDefault = $ignoreTransientDefaultCheck
..${params.map(p => writeField(p, value(p), q"ignoreTransientDefault"))}
"""
} else q"..${params.map(p => writeFieldNoTransientDefault(p, value(p)))}"
else
q"..${params.map(p => writeFieldNoTransientDefault(p, value(p)))}"

def writeFields: Tree = params match {
case Nil =>
Expand All @@ -231,7 +233,7 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with
else
q"""
val unapplyRes = $companion.$unapply[..${dtpe.typeArgs}](value)
if (unapplyRes.isEmpty) unapplyFailed
if(unapplyRes.isEmpty) unapplyFailed
else ${writeSingle(p, q"unapplyRes.get")}
"""
case _ =>
Expand All @@ -240,7 +242,7 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with
else
q"""
val unapplyRes = $companion.$unapply[..${dtpe.typeArgs}](value)
if (unapplyRes.isEmpty) unapplyFailed
if(unapplyRes.isEmpty) unapplyFailed
else {
val t = unapplyRes.get
${writeMultiple(p => q"t.${tupleGet(p.idx)}")}
Expand All @@ -262,10 +264,10 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with
case None => p.defaultValue
}

// assumes usage in in size(value, output) method implementation
// assumes usage in SizedCodec.size(value, output) method implementation
def countTransientFields: Tree = {
def checkIgnoreTransientDefaultMarker: Tree =
q"output.isDefined && output.get.customEvent($SerializationPkg.IgnoreTransientDefaultMarker, ())"
q"output.isDefined && output.get.customEvent($IgnoreTransientDefaultMarkerObj, ())"

def doCount(paramsToCount: List[ApplyParam], accessor: ApplyParam => Tree): Tree =
paramsToCount.foldLeft[Tree](q"0") {
Expand Down

0 comments on commit c1307e7

Please sign in to comment.