Skip to content

Commit

Permalink
Initial preparations to support GRPC Transcoding (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
igor-vovk authored Dec 8, 2024
1 parent 8960a07 commit 68513b3
Show file tree
Hide file tree
Showing 17 changed files with 213 additions and 116 deletions.
3 changes: 3 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ lazy val core = project
"io.grpc" % "grpc-protobuf" % Versions.grpc,
"io.grpc" % "grpc-inprocess" % Versions.grpc,

"com.thesamet.scalapb.common-protos" %% "proto-google-common-protos-scalapb_0.11" % "2.9.6-0" % "protobuf",
"com.thesamet.scalapb.common-protos" %% "proto-google-common-protos-scalapb_0.11" % "2.9.6-0",

"com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.compiler.Version.scalapbVersion % "protobuf",

"org.http4s" %% "http4s-dsl" % Versions.http4s,
Expand Down
1 change: 1 addition & 0 deletions conformance/src/main/protobuf/connectrpc/package.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ option (scalapb.options) = {
enum_value_naming: CAMEL_CASE
enum_strip_prefix: true
preserve_unknown_fields: false
scala3_sources: true
};
8 changes: 0 additions & 8 deletions conformance/src/main/resources/logback.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,6 @@

<configuration>

<!-- <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">-->
<!-- <encoder>-->
<!-- <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>-->
<!-- </encoder>-->
<!-- </appender>-->

<!-- <appender name="NOOP" class="ch.qos.logback.core.helpers.NOPAppender"/>-->

<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>${LOGS_PATH}/out.log</file>
<append>true</append>
Expand Down
1 change: 1 addition & 0 deletions core/src/main/protobuf/connectrpc/package.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ option (scalapb.options) = {
enum_value_naming: CAMEL_CASE
enum_strip_prefix: true
preserve_unknown_fields: false
scala3_sources: true
};
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
package org.ivovk.connect_rpc_scala

import cats.data.EitherT
import cats.effect.Async
import cats.implicits.*
import io.grpc.*
import io.grpc.MethodDescriptor.MethodType
import org.http4s.dsl.Http4sDsl
import org.http4s.{Header, MediaType, MessageFailure, Method, Response}
import org.http4s.{Header, MessageFailure, Response}
import org.ivovk.connect_rpc_scala.Mappings.*
import org.ivovk.connect_rpc_scala.grpc.{ClientCalls, GrpcHeaders, MethodName, MethodRegistry}
import org.ivovk.connect_rpc_scala.grpc.{ClientCalls, GrpcHeaders, MethodRegistry}
import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name`
import org.ivovk.connect_rpc_scala.http.RequestEntity
import org.ivovk.connect_rpc_scala.http.codec.MessageCodec
import org.ivovk.connect_rpc_scala.http.codec.MessageCodec.given
import org.ivovk.connect_rpc_scala.http.codec.{MessageCodec, MessageCodecRegistry}
import org.ivovk.connect_rpc_scala.http.{MediaTypes, RequestEntity}
import org.slf4j.{Logger, LoggerFactory}
import scalapb.GeneratedMessage

Expand All @@ -21,8 +20,6 @@ import scala.jdk.CollectionConverters.*
import scala.util.chaining.*

class ConnectHandler[F[_] : Async](
codecRegistry: MessageCodecRegistry[F],
methodRegistry: MethodRegistry,
channel: Channel,
httpDsl: Http4sDsl[F],
treatTrailersAsHeaders: Boolean,
Expand All @@ -33,53 +30,22 @@ class ConnectHandler[F[_] : Async](
private val logger: Logger = LoggerFactory.getLogger(getClass)

def handle(
httpMethod: Method,
contentType: Option[MediaType],
entity: RequestEntity[F],
grpcMethodName: MethodName,
): F[Response[F]] = {
val eitherT = for
given MessageCodec[F] <- EitherT.fromOptionM(
contentType.flatMap(codecRegistry.byContentType).pure[F],
UnsupportedMediaType(s"Unsupported content-type ${contentType.show}. " +
s"Supported content types: ${MediaTypes.allSupported.map(_.show).mkString(", ")}")
)

method <- EitherT.fromOptionM(
methodRegistry.get(grpcMethodName).pure[F],
NotFound(connectrpc.Error(
code = io.grpc.Status.NOT_FOUND.toConnectCode,
message = s"Method not found: ${grpcMethodName.fullyQualifiedName}".some
))
)

_ <- EitherT.cond[F](
// Support GET-requests for all methods until https://github.com/scalapb/ScalaPB/pull/1774 is merged
httpMethod == Method.POST || (httpMethod == Method.GET && method.descriptor.isSafe) || true,
(),
Forbidden(connectrpc.Error(
code = io.grpc.Status.PERMISSION_DENIED.toConnectCode,
message = s"Only POST-requests are allowed for method: ${grpcMethodName.fullyQualifiedName}".some
req: RequestEntity[F],
method: MethodRegistry.Entry,
)(using MessageCodec[F]): F[Response[F]] = {
method.descriptor.getType match
case MethodType.UNARY =>
handleUnary(req, method)
case unsupported =>
NotImplemented(connectrpc.Error(
code = io.grpc.Status.UNIMPLEMENTED.toConnectCode,
message = s"Unsupported method type: $unsupported".some
))
).leftSemiflatMap(identity)

response <- method.descriptor.getType match
case MethodType.UNARY =>
EitherT.right(handleUnary(method, entity, channel))
case unsupported =>
EitherT.left(NotImplemented(connectrpc.Error(
code = io.grpc.Status.UNIMPLEMENTED.toConnectCode,
message = s"Unsupported method type: $unsupported".some
)))
yield response

eitherT.merge
}

private def handleUnary(
method: MethodRegistry.Entry,
req: RequestEntity[F],
channel: Channel
method: MethodRegistry.Entry,
)(using MessageCodec[F]): F[Response[F]] = {
if (logger.isTraceEnabled) {
// Used in conformance tests
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package org.ivovk.connect_rpc_scala

import cats.Endo
import cats.data.OptionT
import cats.effect.{Async, Resource}
import cats.implicits.*
import io.grpc.{ManagedChannelBuilder, ServerBuilder, ServerServiceDefinition}
import org.http4s.dsl.Http4sDsl
import org.http4s.{HttpApp, HttpRoutes, Method, Uri}
import org.http4s.{HttpApp, HttpRoutes, MediaType, Method, Response, Uri}
import org.ivovk.connect_rpc_scala.grpc.*
import org.ivovk.connect_rpc_scala.http.*
import org.ivovk.connect_rpc_scala.http.QueryParams.*
import org.ivovk.connect_rpc_scala.http.codec.{JsonMessageCodec, JsonMessageCodecBuilder, MessageCodecRegistry, ProtoMessageCodec}
import org.ivovk.connect_rpc_scala.http.codec.*

import java.util.concurrent.Executor
import scala.concurrent.ExecutionContext
Expand Down Expand Up @@ -124,29 +125,56 @@ final class ConnectRouteBuilder[F[_] : Async] private(
)
yield
val handler = new ConnectHandler(
codecRegistry,
methodRegistry,
channel,
httpDsl,
treatTrailersAsHeaders,
)

HttpRoutes.of[F] {
case req@Method.GET -> `pathPrefix` / serviceName / methodName :? EncodingQP(contentType) +& MessageQP(message) =>
val grpcMethod = MethodName(serviceName, methodName)
val entity = RequestEntity[F](message, req.headers)

handler.handle(Method.GET, contentType.some, entity, grpcMethod)
case req@Method.POST -> `pathPrefix` / serviceName / methodName =>
val grpcMethod = MethodName(serviceName, methodName)
val contentType = req.contentType.map(_.mediaType)
val entity = RequestEntity[F](req)

handler.handle(Method.POST, contentType, entity, grpcMethod)
HttpRoutes[F] {
case req@Method.GET -> `pathPrefix` / service / method :? EncodingQP(mediaType) +& MessageQP(message) =>
OptionT.fromOption[F](methodRegistry.get(service, method))
// Temporary support GET-requests for all methods,
// until https://github.com/scalapb/ScalaPB/pull/1774 is merged
.filter(_.descriptor.isSafe || true)
.semiflatMap { methodEntry =>
withCodec(httpDsl, codecRegistry, mediaType.some) { codec =>
val entity = RequestEntity[F](message, req.headers)

handler.handle(entity, methodEntry)(using codec)
}
}
case req@Method.POST -> `pathPrefix` / service / method =>
OptionT.fromOption[F](methodRegistry.get(service, method))
.semiflatMap { methodEntry =>
withCodec(httpDsl, codecRegistry, req.contentType.map(_.mediaType)) { codec =>
val entity = RequestEntity[F](req.body, req.headers)

handler.handle(entity, methodEntry)(using codec)
}
}
case _ =>
OptionT.none
}
}

def build: Resource[F, HttpApp[F]] =
buildRoutes.map(_.orNotFound)

private def withCodec(
dsl: Http4sDsl[F],
registry: MessageCodecRegistry[F],
mediaType: Option[MediaType]
)(r: MessageCodec[F] => F[Response[F]]): F[Response[F]] = {
import dsl.*

mediaType.flatMap(registry.byMediaType) match {
case Some(codec) => r(codec)
case None =>
val message = s"Unsupported media-type ${mediaType.show}. " +
s"Supported media types: ${MediaTypes.allSupported.map(_.show).mkString(", ")}"

UnsupportedMediaType(message)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ type Service = String
type Method = String

object MethodName {
def apply(descriptor: MethodDescriptor[_, _]): MethodName =
def from(descriptor: MethodDescriptor[_, _]): MethodName =
MethodName(descriptor.getServiceName, descriptor.getBareMethodName)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
package org.ivovk.connect_rpc_scala.grpc

import com.google.api.AnnotationsProto
import com.google.api.http.HttpRule
import io.grpc.{MethodDescriptor, ServerMethodDefinition, ServerServiceDefinition}
import scalapb.grpc.ConcreteProtoMethodDescriptorSupplier
import scalapb.{GeneratedMessage, GeneratedMessageCompanion}

import scala.jdk.CollectionConverters.*

object MethodRegistry {

case class Entry(
name: MethodName,
requestMessageCompanion: GeneratedMessageCompanion[GeneratedMessage],
httpRule: Option[HttpRule],
descriptor: MethodDescriptor[GeneratedMessage, GeneratedMessage],
)

Expand All @@ -30,23 +35,37 @@ object MethodRegistry {
val requestCompanion = companionField.get(requestMarshaller)
.asInstanceOf[GeneratedMessageCompanion[GeneratedMessage]]

val methodEntry = Entry(
val httpRule = extractHttpRule(methodDescriptor)

Entry(
name = MethodName.from(methodDescriptor),
requestMessageCompanion = requestCompanion,
httpRule = httpRule,
descriptor = methodDescriptor,
)

MethodName(methodDescriptor) -> methodEntry
}
.groupMapReduce((mn, _) => mn.service)((mn, m) => Map(mn.method -> m))(_ ++ _)
.groupMapReduce(_.name.service)(e => Map(e.name.method -> e))(_ ++ _)

new MethodRegistry(entries)
}

private def extractHttpRule(methodDescriptor: MethodDescriptor[_, _]): Option[HttpRule] = {
methodDescriptor.getSchemaDescriptor match
case sd: ConcreteProtoMethodDescriptorSupplier =>
val fields = sd.getMethodDescriptor.getOptions.getUnknownFields
val fieldNumber = AnnotationsProto.http.getNumber

if fields.hasField(fieldNumber) then
Some(HttpRule.parseFrom(fields.getField(fieldNumber).getLengthDelimitedList.get(0).toByteArray))
else None
case _ => None
}

}

class MethodRegistry private(entries: Map[Service, Map[Method, MethodRegistry.Entry]]) {

def get(methodName: MethodName): Option[MethodRegistry.Entry] =
entries.getOrElse(methodName.service, Map.empty).get(methodName.method)
def get(service: Service, method: Method): Option[MethodRegistry.Entry] =
entries.getOrElse(service, Map.empty).get(method)

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@ package org.ivovk.connect_rpc_scala.http
import cats.MonadThrow
import fs2.Stream
import org.http4s.headers.{`Content-Encoding`, `Content-Type`}
import org.http4s.{Charset, ContentCoding, Headers, Media}
import org.http4s.{Charset, ContentCoding, Headers}
import org.ivovk.connect_rpc_scala.http.Headers.`Connect-Timeout-Ms`
import org.ivovk.connect_rpc_scala.http.codec.MessageCodec
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}

object RequestEntity {
def apply[F[_]](m: Media[F]): RequestEntity[F] =
RequestEntity(m.body, m.headers)
}

/**
* Encoded message and headers with the knowledge how this message can be decoded.
Expand All @@ -23,7 +19,7 @@ case class RequestEntity[F[_]](
headers: Headers,
) {

def contentType: Option[`Content-Type`] =
private def contentType: Option[`Content-Type`] =
headers.get[`Content-Type`]

def charset: Charset =
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
package org.ivovk.connect_rpc_scala.http.codec

import cats.Applicative
import org.http4s.headers.`Content-Type`
import org.http4s.{DecodeResult, Entity, EntityDecoder, EntityEncoder, MediaRange, MediaType}
import org.http4s.{DecodeResult, Entity, EntityEncoder, MediaType}
import org.ivovk.connect_rpc_scala.http.RequestEntity
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}

object MessageCodec {
given [F[_] : Applicative, A <: Message](using codec: MessageCodec[F], cmp: Companion[A]): EntityDecoder[F, A] =
EntityDecoder.decodeBy(MediaRange.`*/*`)(m => codec.decode(RequestEntity(m)))

given [F[_], A <: Message](using codec: MessageCodec[F]): EntityEncoder[F, A] =
EntityEncoder.encodeBy(`Content-Type`(codec.mediaType))(codec.encode)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ object MessageCodecRegistry {

class MessageCodecRegistry[F[_]] private(encoders: Map[MediaType, MessageCodec[F]]) {

def byContentType(mediaType: MediaType): Option[MessageCodec[F]] = encoders.get(mediaType)
def byMediaType(mediaType: MediaType): Option[MessageCodec[F]] = encoders.get(mediaType)

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
syntax = "proto3";

package org.ivovk.connect_rpc_scala.test;
package test;

service TestService {
rpc Add(AddRequest) returns (AddResponse) {}
Expand Down
24 changes: 24 additions & 0 deletions core/src/test/protobuf/test/MethodRegistryTest.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
syntax = "proto3";

package test;

import "google/api/annotations.proto";

service MethodRegistryTestService {
rpc SimpleMethod(SimpleMethodRequest) returns (SimpleMethodResponse) {}

rpc HttpAnnotationMethod(HttpAnnotationMethodRequest) returns (HttpAnnotationMethodResponse) {
option (google.api.http) = {
post: "/v1/test/http_annotation_method"
body: "*"
};
}
}

message SimpleMethodRequest {}

message SimpleMethodResponse {}

message HttpAnnotationMethodRequest {}

message HttpAnnotationMethodResponse {}
15 changes: 15 additions & 0 deletions core/src/test/protobuf/test/package.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
syntax = "proto3";

package test;

import "scalapb/scalapb.proto";

option (scalapb.options) = {
scope: PACKAGE
flat_package: false
lenses: false
enum_value_naming: CAMEL_CASE
enum_strip_prefix: true
preserve_unknown_fields: false
scala3_sources: true
};
Loading

0 comments on commit 68513b3

Please sign in to comment.