Skip to content

Commit

Permalink
Optimize GRPC method lookup; optimizations in protobuf decoding (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
igor-vovk authored Dec 1, 2024
1 parent 1816316 commit dfb80df
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ import cats.Endo
import cats.data.EitherT
import cats.effect.Async
import cats.implicits.*
import io.grpc.{CallOptions, Channel, ClientInterceptors, Metadata, StatusException, StatusRuntimeException}
import io.grpc.*
import io.grpc.MethodDescriptor.MethodType
import io.grpc.stub.MetadataUtils
import org.http4s.dsl.Http4sDsl
import org.http4s.{MediaType, Method, Response}
import org.ivovk.connect_rpc_scala.Mappings.*
import org.ivovk.connect_rpc_scala.grpc.{MethodName, MethodRegistry}
import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name`
import org.ivovk.connect_rpc_scala.http.MessageCodec.given
import org.ivovk.connect_rpc_scala.http.{MediaTypes, MessageCodec, MessageCodecRegistry, RequestEntity}
Expand All @@ -35,7 +36,7 @@ class ConnectHandler[F[_]: Async](
httpMethod: Method,
contentType: Option[MediaType],
entity: RequestEntity[F],
grpcMethodName: String,
grpcMethodName: MethodName,
): F[Response[F]] = {
val eitherT = for
given MessageCodec[F] <- EitherT.fromOptionM(
Expand All @@ -48,7 +49,7 @@ class ConnectHandler[F[_]: Async](
methodRegistry.get(grpcMethodName).pure[F],
NotFound(connectrpc.Error(
code = io.grpc.Status.NOT_FOUND.toConnectCode,
message = s"Method not found: $grpcMethodName".some
message = s"Method not found: ${grpcMethodName.fullyQualifiedName}".some
))
)

Expand All @@ -58,7 +59,7 @@ class ConnectHandler[F[_]: Async](
(),
Forbidden(connectrpc.Error(
code = io.grpc.Status.PERMISSION_DENIED.toConnectCode,
message = s"Only POST-requests are allowed for method: $grpcMethodName".some
message = s"Only POST-requests are allowed for method: ${grpcMethodName.fullyQualifiedName}".some
))
).leftSemiflatMap(identity)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import cats.implicits.*
import io.grpc.{ManagedChannelBuilder, ServerBuilder, ServerServiceDefinition}
import org.http4s.dsl.Http4sDsl
import org.http4s.{HttpApp, HttpRoutes, Method}
import org.ivovk.connect_rpc_scala.grpc.*
import org.ivovk.connect_rpc_scala.http.*
import org.ivovk.connect_rpc_scala.http.QueryParams.*
import scalapb.json4s.{JsonFormat, Printer}
Expand Down Expand Up @@ -89,12 +90,12 @@ case class ConnectRouteBuilder[F[_] : Async] private(

HttpRoutes.of[F] {
case req@Method.GET -> Root / serviceName / methodName :? EncodingQP(contentType) +& MessageQP(message) =>
val grpcMethod = grpcMethodName(serviceName, methodName)
val grpcMethod = MethodName(serviceName, methodName)
val entity = RequestEntity[F](message, req.headers)

handler.handle(Method.GET, contentType.some, entity, grpcMethod)
case req@Method.POST -> Root / serviceName / methodName =>
val grpcMethod = grpcMethodName(serviceName, methodName)
val grpcMethod = MethodName(serviceName, methodName)
val contentType = req.contentType.map(_.mediaType)
val entity = RequestEntity[F](req)

Expand All @@ -105,7 +106,4 @@ case class ConnectRouteBuilder[F[_] : Async] private(
def build: Resource[F, HttpApp[F]] =
buildRoutes.map(_.orNotFound)

private inline def grpcMethodName(service: String, method: String): String =
service + "/" + method

}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.ivovk.connect_rpc_scala
package org.ivovk.connect_rpc_scala.grpc

import cats.Endo
import cats.effect.{Resource, Sync}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.ivovk.connect_rpc_scala.grpc

import io.grpc.MethodDescriptor

type Service = String
type Method = String

object MethodName {
def apply(descriptor: MethodDescriptor[_, _]): MethodName =
MethodName(descriptor.getServiceName, descriptor.getBareMethodName)
}

case class MethodName(service: Service, method: Method) {
def fullyQualifiedName: String = s"$service/$method"
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.ivovk.connect_rpc_scala
package org.ivovk.connect_rpc_scala.grpc

import io.grpc.{MethodDescriptor, ServerMethodDefinition, ServerServiceDefinition}
import scalapb.{GeneratedMessage, GeneratedMessageCompanion}
Expand Down Expand Up @@ -35,18 +35,18 @@ object MethodRegistry {
descriptor = methodDescriptor,
)

methodDescriptor.getFullMethodName -> methodEntry
MethodName(methodDescriptor) -> methodEntry
}
.toMap
.groupMapReduce((mn, _) => mn.service)((mn, m) => Map(mn.method -> m))(_ ++ _)

new MethodRegistry(entries)
}

}

class MethodRegistry private(entries: Map[String, MethodRegistry.Entry]) {
class MethodRegistry private(entries: Map[Service, Map[Method, MethodRegistry.Entry]]) {

def get(fullMethodName: String): Option[MethodRegistry.Entry] =
entries.get(fullMethodName)
def get(methodName: MethodName): Option[MethodRegistry.Entry] =
entries.getOrElse(methodName.service, Map.empty).get(methodName.method)

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import scalapb.{GeneratedMessage as Message, GeneratedMessageCompanion as Compan

import java.net.URLDecoder
import java.util.Base64
import scala.util.chaining.*

object MessageCodec {
given [F[_] : Applicative, A <: Message](using codec: MessageCodec[F], cmp: Companion[A]): EntityDecoder[F, A] =
Expand Down Expand Up @@ -93,22 +94,22 @@ class ProtoMessageCodec[F[_] : Async](compressor: Compressor[F]) extends Message
override def decode[A <: Message](entity: RequestEntity[F])(using cmp: Companion[A]): DecodeResult[F, A] = {
val msg = entity.message match {
case str: String =>
Async[F].delay(base64dec.decode(str.getBytes(entity.charset.nioCharset)))
.flatMap(arr => Async[F].delay(cmp.parseFrom(arr)))
Async[F].delay(cmp.parseFrom(base64dec.decode(str.getBytes(entity.charset.nioCharset))))
case stream: Stream[F, Byte] =>
toInputStreamResource(compressor.decompressed(entity.encoding, stream))
.use(is => Async[F].delay(cmp.parseFrom(is)))
}

msg
.map { message =>
if (logger.isTraceEnabled) {
logger.trace(s">>> Headers: ${entity.headers.redactSensitive()}")
logger.trace(s">>> Proto: ${message.toProtoString}")
}

message
}
.pipe(
if logger.isTraceEnabled then
_.map { msg =>
logger.trace(s">>> Headers: ${entity.headers.redactSensitive()}")
logger.trace(s">>> Proto: ${msg.toProtoString}")
msg
}
else identity
)
.attemptT
.leftMap(e => InvalidMessageBodyFailure(e.getMessage, e.some))
}
Expand All @@ -129,7 +130,7 @@ class ProtoMessageCodec[F[_] : Async](compressor: Compressor[F]) extends Message

}

class Compressor[F[_]: Sync] {
class Compressor[F[_] : Sync] {

given Compression[F] = Compression.forSync[F]

Expand Down

0 comments on commit dfb80df

Please sign in to comment.