diff --git a/README.md b/README.md index 7b5bd71..d3d2778 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,7 @@ Current status: 11/79 tests pass. Known issues: -* `google.protobuf.Any` serialization doesn't follow Connect-RPC +* `google.protobuf.Any` serialization [doesn't follow](https://github.com/connectrpc/conformance/issues/948) Connect-RPC spec: [#32](https://github.com/igor-vovk/connect-rpc-scala/issues/32) ## Future improvements diff --git a/conformance/src/main/scala/org/ivovk/connect_rpc_scala/conformance/Main.scala b/conformance/src/main/scala/org/ivovk/connect_rpc_scala/conformance/Main.scala index b5e9e97..8515d03 100644 --- a/conformance/src/main/scala/org/ivovk/connect_rpc_scala/conformance/Main.scala +++ b/conformance/src/main/scala/org/ivovk/connect_rpc_scala/conformance/Main.scala @@ -4,8 +4,9 @@ import cats.effect.{IO, IOApp, Sync} import com.comcast.ip4s.{Port, host, port} import connectrpc.conformance.v1.{ConformanceServiceFs2GrpcTrailers, ServerCompatRequest, ServerCompatResponse} import org.http4s.ember.server.EmberServerBuilder +import org.ivovk.connect_rpc_scala.http.ConnectAnyFormat import org.ivovk.connect_rpc_scala.ConnectRouteBuilder -import scalapb.json4s.TypeRegistry +import scalapb.json4s.{AnyFormat, JsonFormat, TypeRegistry} import java.io.InputStream import java.nio.ByteBuffer @@ -45,7 +46,13 @@ object Main extends IOApp.Simple { .addMessage[connectrpc.conformance.v1.UnaryRequest] .addMessage[connectrpc.conformance.v1.IdempotentUnaryRequest] .addMessage[connectrpc.conformance.v1.ConformancePayload.RequestInfo] - ) + ).withFormatRegistry( + JsonFormat.DefaultRegistry + .registerMessageFormatter[com.google.protobuf.any.Any]( + ConnectAnyFormat.anyWriter, + AnyFormat.anyParser + ) + ) } .build diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala index ea09a6c..33a6933 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala @@ -180,7 +180,7 @@ class ConnectHandler[F[_] : Async]( Response[F](httpStatus).withEntity(connectrpc.Error( code = connectCode, message = messageWithDetails.map(_._1), - details = Seq.empty // details + details = details )) } } diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/Mappings.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/Mappings.scala index 62170e7..e9c66b0 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/Mappings.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/Mappings.scala @@ -1,11 +1,13 @@ package org.ivovk.connect_rpc_scala +import com.google.protobuf.struct.{ListValue, NullValue, Struct, Value} import io.grpc.{Metadata, Status} import org.http4s.{Header, Headers} import org.typelevel.ci.CIString import scalapb.GeneratedMessage +import scalapb.descriptors.* -object Mappings extends HeaderMappings, StatusCodeMappings, AnyMappings +object Mappings extends HeaderMappings, StatusCodeMappings, ProtoMappings trait HeaderMappings { @@ -103,7 +105,7 @@ trait StatusCodeMappings { } -trait AnyMappings { +trait ProtoMappings { extension [T <: GeneratedMessage](t: T) { def toProtoAny: com.google.protobuf.any.Any = @@ -111,6 +113,27 @@ trait AnyMappings { typeUrl = "type.googleapis.com/" + t.companion.scalaDescriptor.fullName, value = t.toByteString ) + + def toProtoStruct: Struct = toValue(t.toPMessage).kind match { + case Value.Kind.StructValue(struct) => struct + case _ => throw new IllegalArgumentException("Expected a struct value") + } + } + + def toValue(value: PValue): Value = { + value match { + case PEmpty => Value.of(Value.Kind.NullValue(NullValue.NULL_VALUE)) + case PInt(value) => Value.of(Value.Kind.NumberValue(value.toDouble)) + case PLong(value) => Value.of(Value.Kind.NumberValue(value.toDouble)) + case PString(value) => Value.of(Value.Kind.StringValue(value)) + case PDouble(value) => Value.of(Value.Kind.NumberValue(value)) + case PFloat(value) => Value.of(Value.Kind.NumberValue(value.toDouble)) + case PByteString(value) => Value.of(Value.Kind.StringValue(value.toStringUtf8)) + case PBoolean(value) => Value.of(Value.Kind.BoolValue(value)) + case PEnum(value) => Value.of(Value.Kind.StringValue(value.index.toString)) + case PMessage(value) => Value.of(Value.Kind.StructValue(Struct(value.map((k, v) => k.name -> toValue(v))))) + case PRepeated(value) => Value.of(Value.Kind.ListValue(ListValue(value.map(toValue)))) + } } } \ No newline at end of file diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/json/ConnectAnyFormat.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/json/ConnectAnyFormat.scala new file mode 100644 index 0000000..2df2b5b --- /dev/null +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/json/ConnectAnyFormat.scala @@ -0,0 +1,83 @@ +package org.ivovk.connect_rpc_scala.http.json + +import com.google.protobuf.any.Any as PBAny +import org.json4s.JsonAST.{JObject, JString, JValue} +import scalapb.json4s.Printer + +import scala.language.existentials + +object ConnectAnyFormat { + // Messages that have special representation are parsed/serialized from a `value` field of the + // any. + private val SpecialValues: Set[scalapb.GeneratedMessageCompanion[_]] = ( + com.google.protobuf.struct.StructProto.messagesCompanions ++ + com.google.protobuf.wrappers.WrappersProto.messagesCompanions ++ + Seq( + com.google.protobuf.any.Any, + com.google.protobuf.duration.Duration, + com.google.protobuf.timestamp.Timestamp, + com.google.protobuf.field_mask.FieldMask + ) + ).toSet + + val anyWriter: (Printer, PBAny) => JValue = { case (printer, any) => + // Find the companion so it can be used to JSON-serialize the message. Perhaps this can be circumvented by + // including the original GeneratedMessage with the Any (at least in memory). + val cmp = printer.typeRegistry + .findType(any.typeUrl) + .getOrElse( + throw new IllegalStateException( + s"Unknown type ${any.typeUrl} in Any. Add a TypeRegistry that supports this type to the Printer." + ) + ) + + // Unpack the message... + val message = any.unpack(cmp) + + // ... and add the @type marker to the resulting JSON + if (SpecialValues.contains(cmp)) + JObject( + "@type" -> JString(any.typeUrl), + "value" -> printer.toJson(message) + ) + else + printer.toJson(message) match { + case JObject(fields) => + JObject(("@type" -> JString(any.typeUrl)) +: fields) + case value => + // Safety net, this shouldn't happen + throw new IllegalStateException( + s"Message of type ${any.typeUrl} emitted non-object JSON: $value" + ) + } + } + +// val anyParser: (Parser, JValue) => PBAny = { +// case (parser, obj @ JObject(fields)) => +// obj \ "@type" match { +// case JString(typeUrl) => +// val cmp = parser.typeRegistry +// .findType(typeUrl) +// .getOrElse( +// throw new JsonFormatException( +// s"Unknown type ${typeUrl} in Any. Add a TypeRegistry that supports this type to the Parser." +// ) +// ) +// val input = if (SpecialValues.contains(cmp)) obj \ "value" else obj +// val message = parser.fromJson(input, true)(cmp) +// PBAny(typeUrl = typeUrl, value = message.toByteString) +// +// case JNothing => +// throw new JsonFormatException(s"Missing type url when parsing $obj") +// +// case unknown => +// throw new JsonFormatException( +// s"Expected string @type field, got $unknown" +// ) +// } +// +// case (_, unknown) => +// throw new JsonFormatException(s"Expected an object, got $unknown") +// } +} +