Skip to content

Commit

Permalink
Initial work on the GRPC Transcoding (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
igor-vovk authored Dec 12, 2024
1 parent 3a0e68f commit 73a632c
Show file tree
Hide file tree
Showing 12 changed files with 474 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import org.ivovk.connect_rpc_scala.grpc.*
import org.ivovk.connect_rpc_scala.http.*
import org.ivovk.connect_rpc_scala.http.QueryParams.*
import org.ivovk.connect_rpc_scala.http.codec.*
import scalapb.GeneratedMessage

import java.util.concurrent.Executor
import scala.concurrent.ExecutionContext
Expand Down Expand Up @@ -108,8 +109,9 @@ final class ConnectRouteBuilder[F[_] : Async] private(
val httpDsl = Http4sDsl[F]
import httpDsl.*

val jsonCodec = customJsonCodec.getOrElse(JsonMessageCodecBuilder[F]().build)
val codecRegistry = MessageCodecRegistry[F](
customJsonCodec.getOrElse(JsonMessageCodecBuilder[F]().build),
jsonCodec,
ProtoMessageCodec[F](),
)

Expand All @@ -124,13 +126,13 @@ final class ConnectRouteBuilder[F[_] : Async] private(
waitForShutdown,
)
yield
val handler = new ConnectHandler(
val connectHandler = new ConnectHandler(
channel,
httpDsl,
treatTrailersAsHeaders,
)

HttpRoutes[F] {
val connectRoutes = HttpRoutes[F] {
case req@Method.GET -> `pathPrefix` / service / method :? EncodingQP(mediaType) +& MessageQP(message) =>
OptionT.fromOption[F](methodRegistry.get(service, method))
// Temporary support GET-requests for all methods,
Expand All @@ -140,7 +142,7 @@ final class ConnectRouteBuilder[F[_] : Async] private(
withCodec(httpDsl, codecRegistry, mediaType.some) { codec =>
val entity = RequestEntity[F](message, req.headers)

handler.handle(entity, methodEntry)(using codec)
connectHandler.handle(entity, methodEntry)(using codec)
}
}
case req@Method.POST -> `pathPrefix` / service / method =>
Expand All @@ -149,12 +151,41 @@ final class ConnectRouteBuilder[F[_] : Async] private(
withCodec(httpDsl, codecRegistry, req.contentType.map(_.mediaType)) { codec =>
val entity = RequestEntity[F](req.body, req.headers)

handler.handle(entity, methodEntry)(using codec)
connectHandler.handle(entity, methodEntry)(using codec)
}
}
case _ =>
OptionT.none
}

val transcodingUrlMatcher = TranscodingUrlMatcher.create[F](
methodRegistry.all,
pathPrefix,
)
val transcodingHandler = new TranscodingHandler(
channel,
httpDsl,
treatTrailersAsHeaders,
)

val transcodingRoutes = HttpRoutes[F] { req =>
OptionT.fromOption[F](transcodingUrlMatcher.matchesRequest(req))
.semiflatMap { case MatchedRequest(method, json) =>
given MessageCodec[F] = jsonCodec
given EncodeOptions = EncodeOptions(None)

RequestEntity[F](req.body, req.headers)
.as[GeneratedMessage](method.requestMessageCompanion)
.flatMap { entity =>
val entity2 = jsonCodec.parser.fromJson[GeneratedMessage](json)(method.requestMessageCompanion)
val finalEntity = method.requestMessageCompanion.parseFrom(entity.toByteArray ++ entity2.toByteArray)

transcodingHandler.handleUnary(finalEntity, req.headers, method)
}
}
}

connectRoutes <+> transcodingRoutes
}

def build: Resource[F, HttpApp[F]] =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package org.ivovk.connect_rpc_scala

import cats.effect.Async
import cats.implicits.*
import io.grpc.*
import org.http4s.dsl.Http4sDsl
import org.http4s.{Header, Headers, MessageFailure, Response}
import org.ivovk.connect_rpc_scala.Mappings.*
import org.ivovk.connect_rpc_scala.grpc.{ClientCalls, GrpcHeaders, MethodRegistry}
import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name`
import org.ivovk.connect_rpc_scala.http.RequestEntity
import org.ivovk.connect_rpc_scala.http.RequestEntity.*
import org.ivovk.connect_rpc_scala.http.codec.{EncodeOptions, MessageCodec}
import org.slf4j.{Logger, LoggerFactory}
import scalapb.GeneratedMessage

import scala.concurrent.duration.*
import scala.jdk.CollectionConverters.*
import scala.util.chaining.*

object TranscodingHandler {

extension [F[_]](response: Response[F]) {
def withMessage(entity: GeneratedMessage)(using codec: MessageCodec[F], options: EncodeOptions): Response[F] =
codec.encode(entity, options).applyTo(response)
}

}

class TranscodingHandler[F[_] : Async](
channel: Channel,
httpDsl: Http4sDsl[F],
treatTrailersAsHeaders: Boolean,
) {

import TranscodingHandler.*
import httpDsl.*

private val logger: Logger = LoggerFactory.getLogger(getClass)

def handleUnary(
message: GeneratedMessage,
headers: Headers,
method: MethodRegistry.Entry,
)(using MessageCodec[F], EncodeOptions): F[Response[F]] = {
if (logger.isTraceEnabled) {
// Used in conformance tests
headers.get[`X-Test-Case-Name`] match {
case Some(header) =>
logger.trace(s">>> Test Case: ${header.value}")
case None => // ignore
}
}

if (logger.isTraceEnabled) {
logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}")
}

val callOptions = CallOptions.DEFAULT
.pipe(
headers.timeout match {
case Some(timeout) => _.withDeadlineAfter(timeout, MILLISECONDS)
case None => identity
}
)

ClientCalls
.asyncUnaryCall(
channel,
method.descriptor,
callOptions,
headers.toMetadata,
message
)
.map { response =>
val headers = response.headers.toHeaders() ++
response.trailers.toHeaders(trailing = !treatTrailersAsHeaders)

if (logger.isTraceEnabled) {
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
}

Response(Ok, headers = headers).withMessage(response.value)
}
.handleError { e =>
val grpcStatus = e match {
case e: StatusException =>
e.getStatus.getDescription match {
case "an implementation is missing" => io.grpc.Status.UNIMPLEMENTED
case _ => e.getStatus
}
case e: StatusRuntimeException => e.getStatus
case _: MessageFailure => io.grpc.Status.INVALID_ARGUMENT
case _ => io.grpc.Status.INTERNAL
}

val (message, metadata) = e match {
case e: StatusRuntimeException => (Option(e.getStatus.getDescription), e.getTrailers)
case e: StatusException => (Option(e.getStatus.getDescription), e.getTrailers)
case e => (Option(e.getMessage), new Metadata())
}

val httpStatus = grpcStatus.toHttpStatus
val connectCode = grpcStatus.toConnectCode

// Should be called before converting metadata to headers
val details = Option(metadata.removeAll(GrpcHeaders.ErrorDetailsKey))
.fold(Seq.empty)(_.asScala.toSeq)

val headers = metadata.toHeaders(trailing = !treatTrailersAsHeaders)

if (logger.isTraceEnabled) {
logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode")
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
logger.trace(s"<<< Error processing request", e)
}

Response[F](httpStatus, headers = headers).withMessage(connectrpc.Error(
code = connectCode,
message = message,
details = details
))
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package org.ivovk.connect_rpc_scala

import cats.implicits.*
import com.google.api.HttpRule
import org.http4s.{Method, Request, Uri}
import org.ivovk.connect_rpc_scala
import org.ivovk.connect_rpc_scala.grpc.MethodRegistry
import org.json4s.JsonAST.{JField, JObject}
import org.json4s.{JString, JValue}

import scala.util.boundary
import scala.util.boundary.break

case class MatchedRequest(method: MethodRegistry.Entry, json: JValue)

object TranscodingUrlMatcher {
case class Entry(
method: MethodRegistry.Entry,
httpMethodMatcher: Method => Boolean,
pattern: Uri.Path,
)

def create[F[_]](
methods: Seq[MethodRegistry.Entry],
pathPrefix: Uri.Path,
): TranscodingUrlMatcher[F] = {
val entries = methods.flatMap { method =>
method.httpRule match {
case Some(httpRule) =>
val (httpMethod, pattern) = extractMethodAndPattern(httpRule)

val httpMethodMatcher: Method => Boolean = m => httpMethod.forall(_ == m)

Entry(
method,
httpMethodMatcher,
pathPrefix.dropEndsWithSlash.concat(pattern.toRelative)
).some
case None => none
}
}

new TranscodingUrlMatcher(
entries,
)
}

private def extractMethodAndPattern(rule: HttpRule): (Option[Method], Uri.Path) = {
val (method, str) = rule.getPatternCase match
case HttpRule.PatternCase.GET => (Method.GET.some, rule.getGet)
case HttpRule.PatternCase.PUT => (Method.PUT.some, rule.getPut)
case HttpRule.PatternCase.POST => (Method.POST.some, rule.getPost)
case HttpRule.PatternCase.DELETE => (Method.DELETE.some, rule.getDelete)
case HttpRule.PatternCase.PATCH => (Method.PATCH.some, rule.getPatch)
case HttpRule.PatternCase.CUSTOM => (none, rule.getCustom.getPath)
case other => throw new RuntimeException(s"Unsupported pattern case $other (Rule: $rule)")

val path = Uri.Path.unsafeFromString(str).dropEndsWithSlash

(method, path)
}
}

class TranscodingUrlMatcher[F[_]](
entries: Seq[TranscodingUrlMatcher.Entry],
) {

import org.ivovk.connect_rpc_scala.http.json.JsonProcessing.*

def matchesRequest(req: Request[F]): Option[MatchedRequest] = boundary {
entries.foreach { entry =>
if (entry.httpMethodMatcher(req.method)) {
matchExtract(entry.pattern, req.uri.path) match {
case Some(pathParams) =>
val queryParams = req.uri.query.toList.map((k, v) => k -> JString(v.getOrElse("")))

val merged = mergeFields(groupFields(pathParams), groupFields(queryParams))

break(Some(MatchedRequest(entry.method, JObject(merged))))
case None => // continue
}
}
}

none
}

/**
* Matches path segments with pattern segments and extracts variables from the path.
* Returns None if the path does not match the pattern.
*/
private def matchExtract(pattern: Uri.Path, path: Uri.Path): Option[List[JField]] = boundary {
if path.segments.length != pattern.segments.length then boundary.break(none)

path.segments.indices
.foldLeft(List.empty[JField]) { (state, idx) =>
val pathSegment = path.segments(idx)
val patternSegment = pattern.segments(idx)

if isVariable(patternSegment) then
val varName = patternSegment.encoded.substring(1, patternSegment.encoded.length - 1)

(varName -> JString(pathSegment.encoded)) :: state
else if pathSegment != patternSegment then
boundary.break(none)
else state
}
.some
}

private def isVariable(segment: Uri.Path.Segment): Boolean = {
val enc = segment.encoded
val length = enc.length

length > 2 && enc(0) == '{' && enc(length - 1) == '}'
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package org.ivovk.connect_rpc_scala.grpc

import com.google.api.AnnotationsProto
import com.google.api.http.HttpRule
import com.google.api.{AnnotationsProto, HttpRule}
import io.grpc.{MethodDescriptor, ServerMethodDefinition, ServerServiceDefinition}
import scalapb.grpc.ConcreteProtoMethodDescriptorSupplier
import scalapb.{GeneratedMessage, GeneratedMessageCompanion}
Expand Down Expand Up @@ -44,7 +43,6 @@ object MethodRegistry {
descriptor = methodDescriptor,
)
}
.groupMapReduce(_.name.service)(e => Map(e.name.method -> e))(_ ++ _)

new MethodRegistry(entries)
}
Expand All @@ -63,9 +61,16 @@ object MethodRegistry {

}

class MethodRegistry private(entries: Map[Service, Map[Method, MethodRegistry.Entry]]) {
class MethodRegistry private(entries: Seq[MethodRegistry.Entry]) {

private val serviceMethodEntries: Map[Service, Map[Method, MethodRegistry.Entry]] = entries
.groupMapReduce(_.name.service)(e => Map(e.name.method -> e))(_ ++ _)

def all: Seq[MethodRegistry.Entry] = entries

def get(name: MethodName): Option[MethodRegistry.Entry] = get(name.service, name.method)

def get(service: Service, method: Method): Option[MethodRegistry.Entry] =
entries.getOrElse(service, Map.empty).get(method)
serviceMethodEntries.getOrElse(service, Map.empty).get(method)

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ import org.ivovk.connect_rpc_scala.http.codec.MessageCodec
import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Companion}


object RequestEntity {
extension (h: Headers) {
def timeout: Option[Long] =
h.get[`Connect-Timeout-Ms`].map(_.value)
}
}

/**
* Encoded message and headers with the knowledge how this message can be decoded.
* Similar to [[org.http4s.Media]], but extends the message with `String` type representing message that is
Expand All @@ -18,6 +25,7 @@ case class RequestEntity[F[_]](
message: String | Stream[F, Byte],
headers: Headers,
) {
import RequestEntity.*

private def contentType: Option[`Content-Type`] =
headers.get[`Content-Type`]
Expand All @@ -28,8 +36,7 @@ case class RequestEntity[F[_]](
def encoding: Option[ContentCoding] =
headers.get[`Content-Encoding`].map(_.contentCoding)

def timeout: Option[Long] =
headers.get[`Connect-Timeout-Ms`].map(_.value)
def timeout: Option[Long] = headers.timeout

def as[A <: Message](cmp: Companion[A])(using M: MonadThrow[F], codec: MessageCodec[F]): F[A] =
M.rethrow(codec.decode(this)(using cmp).value)
Expand Down
Loading

0 comments on commit 73a632c

Please sign in to comment.