From ca4c0d6d8f58ffaeeab53052685e7bbc89446cec Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Tue, 7 Nov 2023 10:28:40 +0100 Subject: [PATCH] Support proto3 optional (#818) * Support proto3 optional * Fix protobuf refined --- docs/protobuf.md | 2 - .../magnolify/protobuf/ProtobufType.scala | 67 +++++-------------- .../magnolify/protobuf/unsafe/package.scala | 4 -- protobuf/src/test/protobuf/Proto3.proto | 25 ++++--- .../protobuf/ProtobufTypeSuite.scala | 62 +++-------------- .../magnolify/refined/RefinedSuite.scala | 31 +++++---- 6 files changed, 58 insertions(+), 133 deletions(-) diff --git a/docs/protobuf.md b/docs/protobuf.md index bffae92a1..9e0010c7b 100644 --- a/docs/protobuf.md +++ b/docs/protobuf.md @@ -42,8 +42,6 @@ implicit val efEnum = ProtobufField.enum[Color.Type, ColorProto] Additional `ProtobufField[T]` instances for `Byte`, `Char`, `Short`, and `UnsafeEnum[T]` are available from `import magnolify.protobuf.unsafe._`. These conversions are unsafe due to potential overflow. -By default nullable type `Option[T]` is not supported when `MsgT` is compiled with Protobuf 3 syntax. This is because Protobuf 3 does not offer a way to check if a field was set, and instead returns `0`, `""`, `false`, etc. when it was not. You can enable Protobuf 3 support for `Option[T]` by adding `import magnolify.protobuf.unsafe.Proto3Option._`. However with this, Scala `None`s will become `0/""/false` in Protobuf and come back as `Some(0/""/false)`. - To use a different field case format in target records, add an optional `CaseMapper` argument to `ProtobufType`. The following example maps `firstName` & `lastName` to `first_name` & `last_name`. ```scala diff --git a/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala b/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala index fd808b189..e83572813 100644 --- a/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala +++ b/protobuf/src/main/scala/magnolify/protobuf/ProtobufType.scala @@ -17,55 +17,31 @@ package magnolify.protobuf import java.lang.reflect.Method -import java.{util => ju} - -import com.google.protobuf.Descriptors.FileDescriptor.Syntax +import java.util as ju import com.google.protobuf.Descriptors.{Descriptor, EnumValueDescriptor, FieldDescriptor} import com.google.protobuf.{ByteString, Message, ProtocolMessageEnum} -import magnolia1._ -import magnolify.shared._ +import magnolia1.* +import magnolify.shared.* import magnolify.shims.FactoryCompat import scala.annotation.implicitNotFound import scala.collection.concurrent import scala.reflect.ClassTag -import scala.jdk.CollectionConverters._ -import scala.collection.compat._ +import scala.jdk.CollectionConverters.* +import scala.collection.compat.* + sealed trait ProtobufType[T, MsgT <: Message] extends Converter[T, MsgT, MsgT] { def apply(r: MsgT): T = from(r) def apply(t: T): MsgT = to(t) } -sealed trait ProtobufOption { - def check(f: ProtobufField.Record[_], syntax: Syntax): Unit -} - -object ProtobufOption { - implicit val proto2Option: ProtobufOption = new ProtobufOption { - override def check(f: ProtobufField.Record[_], syntax: Syntax): Unit = - if (f.hasOptional) { - require( - syntax == Syntax.PROTO2, - "Option[T] support is PROTO2 only, " + - "`import magnolify.protobuf.unsafe.Proto3Option._` to enable PROTO3 support" - ) - } - } - - private[protobuf] class Proto3Option extends ProtobufOption { - override def check(f: ProtobufField.Record[_], syntax: Syntax): Unit = () - } -} - object ProtobufType { - implicit def apply[T: ProtobufField, MsgT <: Message: ClassTag](implicit - po: ProtobufOption - ): ProtobufType[T, MsgT] = ProtobufType(CaseMapper.identity) + implicit def apply[T: ProtobufField, MsgT <: Message: ClassTag]: ProtobufType[T, MsgT] = + ProtobufType(CaseMapper.identity) def apply[T, MsgT <: Message](cm: CaseMapper)(implicit f: ProtobufField[T], - ct: ClassTag[MsgT], - po: ProtobufOption + ct: ClassTag[MsgT] ): ProtobufType[T, MsgT] = f match { case r: ProtobufField.Record[_] => new ProtobufType[T, MsgT] { @@ -74,9 +50,7 @@ object ProtobufType { .getMethod("getDescriptor") .invoke(null) .asInstanceOf[Descriptor] - if (r.hasOptional) { - po.check(r, descriptor.getFile.getSyntax) - } + r.checkDefaults(descriptor)(cm) } @@ -101,7 +75,6 @@ sealed trait ProtobufField[T] extends Serializable { type FromT type ToT - val hasOptional: Boolean val default: Option[T] def checkDefaults(descriptor: Descriptor)(cm: CaseMapper): Unit = () @@ -133,7 +106,7 @@ object ProtobufField { new ProtobufField[T] { override type FromT = tc.FromT override type ToT = tc.ToT - override val hasOptional: Boolean = tc.hasOptional + override val default: Option[T] = tc.default.map(x => caseClass.construct(_ => x)) override def from(v: FromT)(cm: CaseMapper): T = caseClass.construct(_ => tc.from(v)(cm)) override def to(v: T, b: Message.Builder)(cm: CaseMapper): ToT = @@ -157,14 +130,11 @@ object ProtobufField { } ) - override val hasOptional: Boolean = caseClass.parameters.exists(_.typeclass.hasOptional) - override def checkDefaults(descriptor: Descriptor)(cm: CaseMapper): Unit = { - val syntax = descriptor.getFile.getSyntax val fields = getFields(descriptor)(cm) caseClass.parameters.foreach { p => val field = fields(p.index) - val protoDefault = if (syntax == Syntax.PROTO2 && field.hasDefaultValue) { + val protoDefault = if (field.hasDefaultValue) { Some(p.typeclass.fromAny(field.getDefaultValue)(cm)) } else { p.typeclass.default @@ -183,13 +153,12 @@ object ProtobufField { override def from(v: Message)(cm: CaseMapper): T = { val descriptor = v.getDescriptorForType - val syntax = descriptor.getFile.getSyntax val fields = getFields(descriptor)(cm) caseClass.construct { p => val field = fields(p.index) - // hasField behaves correctly on PROTO2 optional fields - val value = if (syntax == Syntax.PROTO2 && field.isOptional && !v.hasField(field)) { + // check hasPresence to make sure hasField is meaningful + val value = if (field.hasPresence && !v.hasField(field)) { null } else { v.getField(field) @@ -234,7 +203,6 @@ object ProtobufField { class FromWord[T] { def apply[U](f: T => U)(g: U => T)(implicit pf: ProtobufField[T]): ProtobufField[U] = new Aux[U, pf.FromT, pf.ToT] { - override val hasOptional: Boolean = pf.hasOptional override val default: Option[U] = pf.default.map(f) override def from(v: FromT)(cm: CaseMapper): U = f(pf.from(v)(cm)) override def to(v: U, b: Message.Builder)(cm: CaseMapper): ToT = pf.to(g(v), null)(cm) @@ -243,7 +211,6 @@ object ProtobufField { private def aux[T, From, To](_default: T)(f: From => T)(g: T => To): ProtobufField[T] = new Aux[T, From, To] { - override val hasOptional: Boolean = false override val default: Option[T] = Some(_default) override def from(v: FromT)(cm: CaseMapper): T = f(v) override def to(v: T, b: Message.Builder)(cm: CaseMapper): ToT = g(v) @@ -262,9 +229,9 @@ object ProtobufField { implicit val pfString: ProtobufField[String] = id[String]("") implicit val pfByteString: ProtobufField[ByteString] = id[ByteString](ByteString.EMPTY) implicit val pfByteArray: ProtobufField[Array[Byte]] = - aux2[Array[Byte], ByteString](Array.emptyByteArray)(_.toByteArray)(ByteString.copyFrom) + aux2[Array[Byte], ByteString](Array.emptyByteArray)(b => b.toByteArray)(ByteString.copyFrom) - def `enum`[T, E <: Enum[E] with ProtocolMessageEnum](implicit + implicit def `enum`[T, E <: Enum[E] with ProtocolMessageEnum](implicit et: EnumType[T], ct: ClassTag[E] ): ProtobufField[T] = { @@ -282,7 +249,6 @@ object ProtobufField { implicit def pfOption[T](implicit f: ProtobufField[T]): ProtobufField[Option[T]] = new Aux[Option[T], f.FromT, f.ToT] { - override val hasOptional: Boolean = true override val default: Option[Option[T]] = f.default match { case Some(v) => Some(Some(v)) case None => None @@ -306,7 +272,6 @@ object ProtobufField { fc: FactoryCompat[T, C[T]] ): ProtobufField[C[T]] = new Aux[C[T], ju.List[f.FromT], ju.List[f.ToT]] { - override val hasOptional: Boolean = false override val default: Option[C[T]] = Some(fc.newBuilder.result()) override def from(v: ju.List[f.FromT])(cm: CaseMapper): C[T] = { val b = fc.newBuilder diff --git a/protobuf/src/main/scala/magnolify/protobuf/unsafe/package.scala b/protobuf/src/main/scala/magnolify/protobuf/unsafe/package.scala index 3454cb711..bbf6f27f7 100644 --- a/protobuf/src/main/scala/magnolify/protobuf/unsafe/package.scala +++ b/protobuf/src/main/scala/magnolify/protobuf/unsafe/package.scala @@ -22,10 +22,6 @@ package object unsafe { implicit val pfChar: ProtobufField[Char] = ProtobufField.from[Int](_.toChar)(_.toInt) implicit val pfShort: ProtobufField[Short] = ProtobufField.from[Int](_.toShort)(_.toInt) - object Proto3Option { - implicit val proto3Option: ProtobufOption = new ProtobufOption.Proto3Option - } - implicit def pfUnsafeEnum[T: EnumType]: ProtobufField[UnsafeEnum[T]] = ProtobufField .from[String](s => Option(s).filter(_.nonEmpty).map(UnsafeEnum.from[T]).orNull)( diff --git a/protobuf/src/test/protobuf/Proto3.proto b/protobuf/src/test/protobuf/Proto3.proto index 1e33411d0..8c828cf54 100644 --- a/protobuf/src/test/protobuf/Proto3.proto +++ b/protobuf/src/test/protobuf/Proto3.proto @@ -14,12 +14,18 @@ message FloatsP3 { double d = 2; } -message SingularP3 { +message RequiredP3 { bool b = 1; string s = 2; int32 i = 3; } +message NullableP3 { + optional bool b = 1; + optional string s = 2; + optional int32 i = 3; +} + message RepeatedP3 { repeated bool b = 1; repeated string s = 2; @@ -30,8 +36,9 @@ message NestedP3 { bool b = 1; string s = 2; int32 i = 3; - SingularP3 r = 4; - repeated SingularP3 l = 5; + RequiredP3 r = 4; + optional RequiredP3 o = 5; + repeated RequiredP3 l = 6; } message CollectionP3 { @@ -60,9 +67,9 @@ message EnumsP3 { JavaEnums j = 1; ScalaEnums s = 2; // Enumeration ScalaEnums a = 3; // ADT - JavaEnums jo = 4; - ScalaEnums so = 5; // Enumeration - ScalaEnums ao = 6; // ADT + optional JavaEnums jo = 4; + optional ScalaEnums so = 5; // Enumeration + optional ScalaEnums ao = 6; // ADT repeated JavaEnums jr = 7; repeated ScalaEnums sr = 8; // Enumeration repeated ScalaEnums ar = 9; // ADT @@ -72,9 +79,9 @@ message UnsafeEnumsP3 { string j = 1; string s = 2; string a = 3; - string jo = 4; - string so = 5; - string ao = 6; + optional string jo = 4; + optional string so = 5; + optional string ao = 6; repeated string jr = 7; repeated string sr = 8; repeated string ar = 9; diff --git a/protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala b/protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala index 198f2ed67..c6cb11637 100644 --- a/protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala +++ b/protobuf/src/test/scala/magnolify/protobuf/ProtobufTypeSuite.scala @@ -60,43 +60,26 @@ class ProtobufTypeSuite extends BaseProtobufTypeSuite { test[Floats, FloatsP2] test[Floats, FloatsP3] test[Required, RequiredP2] - test[Required, SingularP3] + test[Required, RequiredP3] test[Nullable, NullableP2] + test[Nullable, NullableP3] test[Repeated, RepeatedP2] test[Repeated, RepeatedP3] test[Nested, NestedP2] - test[NestedNoOption, NestedP3] + test[Nested, NestedP3] test[UnsafeByte, IntegersP2] test[UnsafeChar, IntegersP2] test[UnsafeShort, IntegersP2] test[Collections, CollectionP2] - test[MoreCollections, MoreCollectionP2] test[Collections, CollectionP3] + test[MoreCollections, MoreCollectionP2] test[MoreCollections, MoreCollectionP3] - // PROTO3 removes the notion of require vs optional fields. - // By default `Option[T] are not supported`. - test("Fail PROTO3 Option[T]") { - val msg = "requirement failed: Option[T] support is PROTO2 only, " + - "`import magnolify.protobuf.unsafe.Proto3Option._` to enable PROTO3 support" - interceptMessage[IllegalArgumentException](msg)(ProtobufType[Nullable, SingularP3]) - } - - // Adding `import magnolify.protobuf.unsafe.Proto3Option._` enables PROTO3 `Option[T]` support. - // The new singular field returns default value if unset. - // Hence `None` round trips back as `Some(false/0/"")`. - { - import magnolify.protobuf.unsafe.Proto3Option._ - implicit val eq: Eq[Nullable] = Eq.by { x => - Required(x.b.getOrElse(false), x.i.getOrElse(0), x.s.getOrElse("")) - } - test[Nullable, SingularP3] - } - test("AnyVal") { test[ProtoHasValueClass, IntegersP2] + test[ProtoHasValueClass, IntegersP3] } } @@ -121,23 +104,7 @@ class MoreProtobufTypeSuite extends BaseProtobufTypeSuite { { import Proto3Enums._ - import magnolify.protobuf.unsafe.Proto3Option._ - // Enums are encoded as integers and default to zero value - implicit val eq: Eq[Enums] = Eq.by(e => - ( - e.j, - e.s, - e.a, - e.jo.getOrElse(JavaEnums.Color.RED), - e.so.getOrElse(ScalaEnums.Color.Red), - e.ao.getOrElse(ADT.Red), - e.jr, - e.sr, - e.ar - ) - ) test[Enums, EnumsP3] - // Unsafe enums are encoded as string and default "" is treated as None test[UnsafeEnums, UnsafeEnumsP3] } @@ -160,16 +127,12 @@ class MoreProtobufTypeSuite extends BaseProtobufTypeSuite { import Proto3Enums._ test[DefaultIntegers3, IntegersP3] test[DefaultFloats3, FloatsP3] - test[DefaultRequired3, SingularP3] + test[DefaultRequired3, RequiredP3] test[DefaultEnums3, EnumsP3] } { - import magnolify.protobuf.unsafe.Proto3Option._ - implicit val eq: Eq[DefaultNullable3] = Eq.by { x => - Required(x.b.getOrElse(false), x.i.getOrElse(0), x.s.getOrElse("")) - } - test[DefaultNullable3, SingularP3] + test[DefaultNullable3, NullableP3] } { @@ -178,14 +141,13 @@ class MoreProtobufTypeSuite extends BaseProtobufTypeSuite { testFail[F, DefaultMismatch2](ProtobufType[DefaultMismatch2, DefaultRequiredP2])( "Default mismatch magnolify.protobuf.DefaultMismatch2#i: 321 != 123" ) - testFail[F, DefaultMismatch3](ProtobufType[DefaultMismatch3, SingularP3])( + testFail[F, DefaultMismatch3](ProtobufType[DefaultMismatch3, RequiredP3])( "Default mismatch magnolify.protobuf.DefaultMismatch3#i: 321 != 0" ) } } object Proto2Enums { - // FIXME: for some reasons these implicits fail to resolve without explicit types implicit val efJavaEnum2: ProtobufField[JavaEnums.Color] = ProtobufField.enum[JavaEnums.Color, EnumsP2.JavaEnums] implicit val efScalaEnum2: ProtobufField[ScalaEnums.Color.Type] = @@ -195,7 +157,6 @@ object Proto2Enums { } object Proto3Enums { - // FIXME: for some reasons these implicits fail to resolve without explicit types implicit val efJavaEnum3: ProtobufField[JavaEnums.Color] = ProtobufField.enum[JavaEnums.Color, EnumsP3.JavaEnums] implicit val efScalaEnum3: ProtobufField[ScalaEnums.Color.Type] = @@ -211,13 +172,6 @@ case class UnsafeChar(i: Char, l: Long) case class UnsafeShort(i: Short, l: Long) case class BytesA(b: ByteString) case class BytesB(b: Array[Byte]) -case class NestedNoOption( - b: Boolean, - i: Int, - s: String, - r: Required, - l: List[Required] -) case class DefaultsRequired2( i: Int = 123, diff --git a/refined/src/test/scala/magnolify/refined/RefinedSuite.scala b/refined/src/test/scala/magnolify/refined/RefinedSuite.scala index e696ed9ef..a244b242d 100644 --- a/refined/src/test/scala/magnolify/refined/RefinedSuite.scala +++ b/refined/src/test/scala/magnolify/refined/RefinedSuite.scala @@ -156,27 +156,32 @@ class RefinedSuite extends MagnolifySuite { test("protobuf") { import magnolify.protobuf._ import magnolify.protobuf.Proto3._ - import magnolify.protobuf.unsafe.Proto3Option._ import magnolify.refined.protobuf._ - val tpe1 = ensureSerializable(ProtobufType[ProtoRequired, SingularP3]) - val required = ProtoRequired(true, record.pct, refineV.unsafeFrom(record.uuid.value)) + val tpe1 = ensureSerializable(ProtobufType[ProtoRequired, RequiredP3]) + val required = ProtoRequired( + true, + record.pct, + refineV.unsafeFrom(record.uuid.value) + ) assertEquals(tpe1(tpe1(required)), required) - val tpe2 = ensureSerializable(ProtobufType[ProtoNullable, SingularP3]) - val nullable = - ProtoNullable(Some(true), Some(record.pct), Some(refineV.unsafeFrom(record.url.get.value))) + val tpe2 = ensureSerializable(ProtobufType[ProtoNullable, NullableP3]) + val nullable = ProtoNullable( + Some(true), + Some(record.pct), + Some(refineV.unsafeFrom(record.url.get.value)) + ) assertEquals(tpe2(tpe2(nullable)), nullable) val tpe3 = ensureSerializable(ProtobufType[ProtoRepeated, RepeatedP3]) - val repeated = - ProtoRepeated( - List(true), - List(record.pct), - List(refineV.unsafeFrom("US"), refineV.unsafeFrom("UK")) - ) + val repeated = ProtoRepeated( + List(true), + List(record.pct), + List(refineV.unsafeFrom("US"), refineV.unsafeFrom("UK")) + ) assertEquals(tpe3(tpe3(repeated)), repeated) - val bad = SingularP3.newBuilder().setB(true).setI(42).setS("foo").build() + val bad = NullableP3.newBuilder().setB(true).setI(42).setS("foo").build() val msg = """Both predicates of (isValidUrl("foo") || "foo".matches("^$")) failed. """ + """Left: Url predicate failed: URI is not absolute """ + """Right: Predicate failed: "foo".matches("^$")."""