Skip to content

Commit

Permalink
Merge pull request #633 from AVSystem/ignore-transient-default
Browse files Browse the repository at this point in the history
Allow Output to ignore @transientDefault
  • Loading branch information
sebaciv authored Sep 30, 2024
2 parents 3242127 + c1307e7 commit e950fc6
Show file tree
Hide file tree
Showing 9 changed files with 401 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,13 @@ object GenCodec extends RecursiveAutoCodecs with TupleGenCodecs {
}

trait SizedCodec[T] extends GenCodec[T] {
def size(value: T): Int
def size(value: T): Int = size(value, Opt.Empty)

def size(value: T, output: Opt[SequentialOutput]): Int

protected final def declareSizeFor(output: SequentialOutput, value: T): Unit =
if (output.sizePolicy != SizePolicy.Ignored) {
output.declareSize(size(value))
output.declareSize(size(value, output.opt))
}
}

Expand All @@ -336,8 +338,8 @@ object GenCodec extends RecursiveAutoCodecs with TupleGenCodecs {
object OOOFieldsObjectCodec {
// this was introduced so that transparent wrapper cases are possible in flat sealed hierarchies
final class Transformed[A, B](val wrapped: OOOFieldsObjectCodec[B], onWrite: A => B, onRead: B => A) extends OOOFieldsObjectCodec[A] {
def size(value: A): Int =
wrapped.size(onWrite(value))
def size(value: A, output: Opt[SequentialOutput]): Int =
wrapped.size(onWrite(value), output)

def readObject(input: ObjectInput, outOfOrderFields: FieldValues): A =
onRead(wrapped.readObject(input, outOfOrderFields))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.avsystem.commons
package serialization

/**
* Instructs [[GenCodec]] to <b>ignore</b> the [[transientDefault]] annotation when serializing a case class.
* This ensures that even if a field's value is the same as its default, it will be <b>included</b> in the serialized
* representation. Deserialization behavior remains <b>unchanged</b>. If a field is missing from the input, the default
* value will be used as usual.
*
* This marker can be helpful when using the same model class in multiple contexts with different serialization
* formats that have conflicting requirements for handling default values.
*
* @see [[CustomMarkersOutputWrapper]] for an easy way to add markers to existing [[Output]] implementations
*/
object IgnoreTransientDefaultMarker extends CustomEventMarker[Unit]
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class OOOFieldCborRawKeysCodec[T](stdObjectCodec: OOOFieldsObjectCodec[T], keyCo
stdObjectCodec.writeFields(output, value)
}

def size(value: T): Int = stdObjectCodec.size(value)
def size(value: T, output: Opt[SequentialOutput]): Int = stdObjectCodec.size(value, output)
def nullable: Boolean = stdObjectCodec.nullable
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package com.avsystem.commons
package serialization

trait AcceptsAdditionalCustomMarkers extends AcceptsCustomEvents {

protected def markers: Set[CustomEventMarker[?]]

override def customEvent[T](marker: CustomEventMarker[T], event: T): Boolean =
markers(marker) || super.customEvent(marker, event)
}

/**
* [[Input]] implementation that adds additional markers [[CustomEventMarker]] to the provided [[Input]] instance
*/
final class CustomMarkersInputWrapper private(
override protected val wrapped: Input,
override protected val markers: Set[CustomEventMarker[?]],
) extends InputWrapper with AcceptsAdditionalCustomMarkers {

override def readList(): ListInput =
new CustomMarkersInputWrapper.AdjustedListInput(super.readList(), markers)

override def readObject(): ObjectInput =
new CustomMarkersInputWrapper.AdjustedObjectInput(super.readObject(), markers)
}
object CustomMarkersInputWrapper {
def apply(input: Input, markers: CustomEventMarker[?]*): CustomMarkersInputWrapper =
CustomMarkersInputWrapper(input, markers.toSet)

def apply(input: Input, markers: Set[CustomEventMarker[?]]): CustomMarkersInputWrapper =
new CustomMarkersInputWrapper(input, markers)

private final class AdjustedListInput(
override protected val wrapped: ListInput,
override protected val markers: Set[CustomEventMarker[?]],
) extends ListInputWrapper with AcceptsAdditionalCustomMarkers {
override def nextElement(): Input = new CustomMarkersInputWrapper(super.nextElement(), markers)
}

private final class AdjustedFieldInput(
override protected val wrapped: FieldInput,
override protected val markers: Set[CustomEventMarker[?]],
) extends FieldInputWrapper with AcceptsAdditionalCustomMarkers {

override def readList(): ListInput = new AdjustedListInput(super.readList(), markers)
override def readObject(): ObjectInput = new AdjustedObjectInput(super.readObject(), markers)
}

private final class AdjustedObjectInput(
override protected val wrapped: ObjectInput,
override protected val markers: Set[CustomEventMarker[?]],
) extends ObjectInputWrapper with AcceptsAdditionalCustomMarkers {

override def nextField(): FieldInput = new AdjustedFieldInput(super.nextField(), markers)
override def peekField(name: String): Opt[FieldInput] =
super.peekField(name).map(new AdjustedFieldInput(_, markers))
}
}

/**
* [[Output]] implementation that adds additional markers [[CustomEventMarker]] to the provided [[Output]] instance
*/
final class CustomMarkersOutputWrapper private(
override protected val wrapped: Output,
override protected val markers: Set[CustomEventMarker[?]],
) extends OutputWrapper with AcceptsAdditionalCustomMarkers {

override def writeSimple(): SimpleOutput =
new CustomMarkersOutputWrapper.AdjustedSimpleOutput(super.writeSimple(), markers)

override def writeList(): ListOutput =
new CustomMarkersOutputWrapper.AdjustedListOutput(super.writeList(), markers)

override def writeObject(): ObjectOutput =
new CustomMarkersOutputWrapper.AdjustedObjectOutput(super.writeObject(), markers)
}

object CustomMarkersOutputWrapper {
def apply(output: Output, markers: CustomEventMarker[?]*): CustomMarkersOutputWrapper =
CustomMarkersOutputWrapper(output, markers.toSet)

def apply(output: Output, markers: Set[CustomEventMarker[?]]): CustomMarkersOutputWrapper =
new CustomMarkersOutputWrapper(output, markers)

private final class AdjustedSimpleOutput(
override protected val wrapped: SimpleOutput,
override protected val markers: Set[CustomEventMarker[?]],
) extends SimpleOutputWrapper with AcceptsAdditionalCustomMarkers

private final class AdjustedListOutput(
override protected val wrapped: ListOutput,
override protected val markers: Set[CustomEventMarker[?]],
) extends ListOutputWrapper with AcceptsAdditionalCustomMarkers {

override def writeElement(): Output =
new CustomMarkersOutputWrapper(super.writeElement(), markers)
}

private final class AdjustedObjectOutput(
override protected val wrapped: ObjectOutput,
override protected val markers: Set[CustomEventMarker[?]],
) extends ObjectOutputWrapper with AcceptsAdditionalCustomMarkers {

override def writeField(key: String): Output =
new CustomMarkersOutputWrapper(super.writeField(key), markers)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class SingletonCodec[T <: Singleton](
) extends ErrorReportingCodec[T] with OOOFieldsObjectCodec[T] {
final def nullable = true
final def readObject(input: ObjectInput, outOfOrderFields: FieldValues): T = singletonValue
def size(value: T): Int = 0
def size(value: T, output: Opt[SequentialOutput]): Int = 0
def writeFields(output: ObjectOutput, value: T): Unit = ()
}

Expand Down Expand Up @@ -109,7 +109,7 @@ abstract class ProductCodec[T <: Product](
nullable: Boolean,
fieldNames: Array[String]
) extends ApplyUnapplyCodec[T](typeRepr, nullable, fieldNames) {
def size(value: T): Int = value.productArity
def size(value: T, output: Opt[SequentialOutput]): Int = value.productArity

final def writeFields(output: ObjectOutput, value: T): Unit = {
val size = value.productArity
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.avsystem.commons
package serialization

import com.avsystem.commons.serialization.CodecTestData.HasDefaults

object IgnoreTransientDefaultMarkerTest {
final case class NestedHasDefaults(
@transientDefault flag: Boolean = false,
obj: HasDefaults,
list: Seq[HasDefaults],
@transientDefault defaultObj: HasDefaults = HasDefaults(),
)
object NestedHasDefaults extends HasGenCodec[NestedHasDefaults]

final case class HasOptParam(
@transientDefault flag: Boolean = false,
@optionalParam str: Opt[String] = Opt.Empty,
)
object HasOptParam extends HasGenCodec[HasOptParam]
}

class IgnoreTransientDefaultMarkerTest extends AbstractCodecTest {
import IgnoreTransientDefaultMarkerTest.*

override type Raw = Any

def writeToOutput(write: Output => Unit): Any = {
var result: Any = null
write(CustomMarkersOutputWrapper(new SimpleValueOutput(v => result = v), IgnoreTransientDefaultMarker))
result
}

def createInput(raw: Any): Input =
CustomMarkersInputWrapper(new SimpleValueInput(raw), IgnoreTransientDefaultMarker)

test("write case class with default values") {
testWrite(HasDefaults(str = "lol"), Map("str" -> "lol", "int" -> 42))
testWrite(HasDefaults(43, "lol"), Map("int" -> 43, "str" -> "lol"))
testWrite(HasDefaults(str = null), Map("str" -> null, "int" -> 42))
testWrite(HasDefaults(str = "dafuq"), Map("str" -> "dafuq", "int" -> 42))
}

//noinspection RedundantDefaultArgument
test("read case class with default values") {
testRead(Map("str" -> "lol", "int" -> 42), HasDefaults(str = "lol", int = 42))
testRead(Map("str" -> "lol"), HasDefaults(str = "lol", int = 42))
testRead(Map("int" -> 43, "str" -> "lol"), HasDefaults(int = 43, str = "lol"))
testRead(Map("str" -> null, "int" -> 42), HasDefaults(str = null, int = 42))
testRead(Map("str" -> null), HasDefaults(str = null, int = 42))
testRead(Map(), HasDefaults(str = "dafuq", int = 42))
}

test("write case class with opt values") {
testWrite(HasOptParam(str = "lol".opt), Map("flag" -> false, "str" -> "lol"))
testWrite(HasOptParam(), Map("flag" -> false))
}

//noinspection RedundantDefaultArgument
test("write nested case class with default values") {
testWrite(
value = NestedHasDefaults(
flag = false,
obj = HasDefaults(str = "lol"),
list = Seq(HasDefaults(int = 43)),
defaultObj = HasDefaults(),
),
expectedRepr = Map(
"flag" -> false,
"defaultObj" -> Map[String, Any]("str" -> "kek", "int" -> 42),
"obj" -> Map[String, Any]("str" -> "lol", "int" -> 42),
"list" -> List(Map[String, Any]("str" -> "kek", "int" -> 43)),
),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,77 @@ package serialization

import org.scalatest.funsuite.AnyFunSuite

case class RecordWithDefaults(
final case class RecordWithDefaults(
@transientDefault a: String = "",
b: Int = 42
) {
@generated def c: String = s"$a-$b"
}
object RecordWithDefaults extends HasApplyUnapplyCodec[RecordWithDefaults]

class CustomRecordWithDefaults(val a: String, val b: Int)
final class CustomRecordWithDefaults(val a: String, val b: Int)
object CustomRecordWithDefaults extends HasApplyUnapplyCodec[CustomRecordWithDefaults] {
def apply(@transientDefault a: String = "", b: Int = 42): CustomRecordWithDefaults =
new CustomRecordWithDefaults(a, b)
def unapply(crwd: CustomRecordWithDefaults): Opt[(String, Int)] =
Opt((crwd.a, crwd.b))
}

class CustomWrapper(val a: String)
final class CustomWrapper(val a: String)
object CustomWrapper extends HasApplyUnapplyCodec[CustomWrapper] {
def apply(@transientDefault a: String = ""): CustomWrapper = new CustomWrapper(a)
def unapply(cw: CustomWrapper): Opt[String] = Opt(cw.a)
}

final case class RecordWithOpts(
@optionalParam abc: Opt[String] = Opt.Empty,
@transientDefault flag: Opt[Boolean] = Opt.Empty,
b: Int = 42,
)
object RecordWithOpts extends HasApplyUnapplyCodec[RecordWithOpts]

final case class SingleFieldRecordWithOpts(@optionalParam abc: Opt[String] = Opt.Empty)
object SingleFieldRecordWithOpts extends HasApplyUnapplyCodec[SingleFieldRecordWithOpts]

final case class SingleFieldRecordWithTD(@transientDefault abc: String = "abc")
object SingleFieldRecordWithTD extends HasApplyUnapplyCodec[SingleFieldRecordWithTD]

class ObjectSizeTest extends AnyFunSuite {
test("computing object size") {
assert(RecordWithDefaults.codec.size(RecordWithDefaults()) == 2)
assert(RecordWithDefaults.codec.size(RecordWithDefaults("fuu")) == 3)
assert(RecordWithOpts.codec.size(RecordWithOpts("abc".opt)) == 2)
assert(RecordWithOpts.codec.size(RecordWithOpts("abc".opt, true.opt)) == 3)
assert(RecordWithOpts.codec.size(RecordWithOpts()) == 1)
assert(SingleFieldRecordWithOpts.codec.size(SingleFieldRecordWithOpts()) == 0)
assert(SingleFieldRecordWithOpts.codec.size(SingleFieldRecordWithOpts("abc".opt)) == 1)
assert(SingleFieldRecordWithTD.codec.size(SingleFieldRecordWithTD()) == 0)
assert(SingleFieldRecordWithTD.codec.size(SingleFieldRecordWithTD("haha")) == 1)
assert(CustomRecordWithDefaults.codec.size(CustomRecordWithDefaults()) == 1)
assert(CustomRecordWithDefaults.codec.size(CustomRecordWithDefaults("fuu")) == 2)
assert(CustomWrapper.codec.size(CustomWrapper()) == 0)
assert(CustomWrapper.codec.size(CustomWrapper("fuu")) == 1)
}

test("computing object size with custom output") {
val defaultIgnoringOutput = new SequentialOutput {
override def customEvent[T](marker: CustomEventMarker[T], event: T): Boolean =
marker match {
case IgnoreTransientDefaultMarker => true
case _ => super.customEvent(marker, event)
}
override def finish(): Unit = ()
}
assert(RecordWithDefaults.codec.size(RecordWithDefaults(), defaultIgnoringOutput.opt) == 3)
assert(RecordWithDefaults.codec.size(RecordWithDefaults("fuu"), defaultIgnoringOutput.opt) == 3)
assert(RecordWithOpts.codec.size(RecordWithOpts("abc".opt), defaultIgnoringOutput.opt) == 3)
assert(RecordWithOpts.codec.size(RecordWithOpts("abc".opt, true.opt), defaultIgnoringOutput.opt) == 3)
assert(RecordWithOpts.codec.size(RecordWithOpts(), defaultIgnoringOutput.opt) == 2)
assert(SingleFieldRecordWithOpts.codec.size(SingleFieldRecordWithOpts(), defaultIgnoringOutput.opt) == 0) // @optionalParam field should NOT be counted
assert(SingleFieldRecordWithOpts.codec.size(SingleFieldRecordWithOpts("abc".opt), defaultIgnoringOutput.opt) == 1)
assert(SingleFieldRecordWithTD.codec.size(SingleFieldRecordWithTD(), defaultIgnoringOutput.opt) == 1) // @transientDefault field should be counted
assert(SingleFieldRecordWithTD.codec.size(SingleFieldRecordWithTD("haha"), defaultIgnoringOutput.opt) == 1)
assert(CustomRecordWithDefaults.codec.size(CustomRecordWithDefaults(), defaultIgnoringOutput.opt) == 2)
assert(CustomRecordWithDefaults.codec.size(CustomRecordWithDefaults("fuu"), defaultIgnoringOutput.opt) == 2)
}
}
Loading

0 comments on commit e950fc6

Please sign in to comment.