Skip to content

Commit

Permalink
Further optimizations in GRPC method execution (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
igor-vovk authored Dec 7, 2024
1 parent 6db2aed commit 909a0d7
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import cats.effect.Async
import cats.implicits.*
import io.grpc.*
import io.grpc.MethodDescriptor.MethodType
import io.grpc.stub.MetadataUtils
import org.http4s.dsl.Http4sDsl
import org.http4s.{Header, MediaType, MessageFailure, Method, Response}
import org.ivovk.connect_rpc_scala.Mappings.*
Expand All @@ -15,9 +14,8 @@ import org.ivovk.connect_rpc_scala.http.codec.MessageCodec.given
import org.ivovk.connect_rpc_scala.http.codec.{MessageCodec, MessageCodecRegistry}
import org.ivovk.connect_rpc_scala.http.{MediaTypes, RequestEntity}
import org.slf4j.{Logger, LoggerFactory}
import scalapb.{GeneratedMessage, GeneratedMessageCompanion}
import scalapb.GeneratedMessage

import java.util.concurrent.atomic.AtomicReference
import scala.concurrent.duration.*
import scala.jdk.CollectionConverters.*
import scala.util.chaining.*
Expand Down Expand Up @@ -82,7 +80,7 @@ class ConnectHandler[F[_] : Async](
method: MethodRegistry.Entry,
req: RequestEntity[F],
channel: Channel
)(using codec: MessageCodec[F]): F[Response[F]] = {
)(using MessageCodec[F]): F[Response[F]] = {
if (logger.isTraceEnabled) {
// Used in conformance tests
req.headers.get[`X-Test-Case-Name`] match {
Expand All @@ -92,15 +90,10 @@ class ConnectHandler[F[_] : Async](
}
}

given GeneratedMessageCompanion[GeneratedMessage] = method.requestMessageCompanion

req.as[GeneratedMessage]
req.as[GeneratedMessage](method.requestMessageCompanion)
.flatMap { message =>
val responseHeaderMetadata = new AtomicReference[Metadata]()
val responseTrailerMetadata = new AtomicReference[Metadata]()

if (logger.isTraceEnabled) {
logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}, Entity: $message")
logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}")
}

val callOptions = CallOptions.DEFAULT
Expand All @@ -111,37 +104,33 @@ class ConnectHandler[F[_] : Async](
}
)

GrpcClientCalls
.asyncUnaryCall2[F, GeneratedMessage, GeneratedMessage](
ClientInterceptors.intercept(
channel,
MetadataUtils.newAttachHeadersInterceptor(req.headers.toMetadata),
MetadataUtils.newCaptureMetadataInterceptor(responseHeaderMetadata, responseTrailerMetadata),
),
method.descriptor,
callOptions,
message
)
.map { response =>
val headers = responseHeaderMetadata.get.toHeaders() ++
responseTrailerMetadata.get.toHeaders(trailing = !treatTrailersAsHeaders)
GrpcClientCalls.asyncUnaryCall(
channel,
method.descriptor,
callOptions,
req.headers.toMetadata,
message
)
}
.map { response =>
val headers = response.headers.toHeaders() ++
response.trailers.toHeaders(trailing = !treatTrailersAsHeaders)

if (logger.isTraceEnabled) {
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
}
if (logger.isTraceEnabled) {
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
}

Response(Ok, headers = headers).withEntity(response)
}
Response(Ok, headers = headers).withEntity(response.value)
}
.recover { case e =>
val grpcStatus = e match {
case e: StatusRuntimeException =>
case e: StatusException =>
e.getStatus.getDescription match {
case "an implementation is missing" => io.grpc.Status.UNIMPLEMENTED
case _ => e.getStatus
}
case e: StatusException => e.getStatus
case e: MessageFailure => io.grpc.Status.INVALID_ARGUMENT
case e: StatusRuntimeException => e.getStatus
case _: MessageFailure => io.grpc.Status.INVALID_ARGUMENT
case _ => io.grpc.Status.INTERNAL
}

Expand All @@ -154,14 +143,16 @@ class ConnectHandler[F[_] : Async](
val httpStatus = grpcStatus.toHttpStatus
val connectCode = grpcStatus.toConnectCode

// Should be called before converting metadata to headers
val details = Option(metadata.removeAll(GrpcHeaders.ErrorDetailsKey))
.fold(Seq.empty)(_.asScala.toSeq)

val headers = metadata.toHeaders(trailing = !treatTrailersAsHeaders)

if (logger.isTraceEnabled) {
logger.warn(s"<<< Error processing request", e)
logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode, Message: ${message.orNull}")
logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode")
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
logger.trace(s"<<< Error processing request", e)
}

Response[F](httpStatus, headers = headers).withEntity(connectrpc.Error(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,85 +1,61 @@
package org.ivovk.connect_rpc_scala.grpc

import cats.effect.Async
import com.google.common.util.concurrent.{FutureCallback, Futures, MoreExecutors}
import io.grpc.stub.{ClientCalls, StreamObserver}
import io.grpc.{CallOptions, Channel, MethodDescriptor}
import io.grpc.*

object GrpcClientCalls {

case class Response[T](value: T, headers: Metadata, trailers: Metadata)

/**
* Asynchronous unary call.
*
* Optimized version of the `scalapb.grpc.ClientCalls.asyncUnaryCall` that skips Scala's Future instantiation
* and supports cancellation.
*/
def asyncUnaryCall[F[_] : Async, Req, Resp](
channel: Channel,
method: MethodDescriptor[Req, Resp],
options: CallOptions,
headers: Metadata,
request: Req,
): F[Resp] = {
Async[F].async[Resp] { cb =>
Async[F].delay {
val future = ClientCalls.futureUnaryCall(channel.newCall(method, options), request)

Futures.addCallback(
future,
new FutureCallback[Resp] {
def onSuccess(result: Resp): Unit = cb(Right(result))

def onFailure(t: Throwable): Unit = cb(Left(t))
},
MoreExecutors.directExecutor(),
)

Some(Async[F].delay(future.cancel(true)))
}
}
}

/**
* Implementation that should be faster than the [[asyncUnaryCall]].
*/
def asyncUnaryCall2[F[_] : Async, Req, Resp](
channel: Channel,
method: MethodDescriptor[Req, Resp],
options: CallOptions,
request: Req,
): F[Resp] = {
Async[F].async[Resp] { cb =>
): F[Response[Resp]] = {
Async[F].async[Response[Resp]] { cb =>
Async[F].delay {
val call = channel.newCall(method, options)

ClientCalls.asyncUnaryCall(call, request, new CallbackObserver(cb))
call.start(CallbackListener[Resp](cb), headers)
call.sendMessage(request)
call.halfClose()
call.request(2)

Some(Async[F].delay(call.cancel("Cancelled", null)))
}
}
}

/**
* [[StreamObserverToCallListenerAdapter]] either executes [[onNext]] -> [[onCompleted]] during the happy path
* or just [[onError]] in case of an error.
*/
private class CallbackObserver[Resp](cb: Either[Throwable, Resp] => Unit) extends StreamObserver[Resp] {
private var value: Option[Either[Throwable, Resp]] = None
private class CallbackListener[Resp](cb: Either[Throwable, Response[Resp]] => Unit) extends ClientCall.Listener[Resp] {
private var headers: Option[Metadata] = None
private var message: Option[Resp] = None

override def onNext(value: Resp): Unit = {
if this.value.isDefined then
throw new IllegalStateException("Value already received")

this.value = Some(Right(value))
override def onHeaders(headers: Metadata): Unit = {
this.headers = Some(headers)
}

override def onError(t: Throwable): Unit = {
cb(Left(t))
override def onMessage(message: Resp): Unit = {
if this.message.isDefined then
throw new IllegalStateException("More than one message received")

this.message = Some(message)
}

override def onCompleted(): Unit = {
this.value match
case Some(v) => cb(v)
case None => cb(Left(new IllegalStateException("No value received or call to onCompleted after onError")))
override def onClose(status: Status, trailers: Metadata): Unit = {
if status.isOk then
message match
case Some(value) => cb(Right(Response(
value = value,
headers = headers.getOrElse(new Metadata()),
trailers = trailers
)))
case None => cb(Left(new IllegalStateException("No value received")))
else
cb(Left(status.asException(trailers)))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ case class RequestEntity[F[_]](
def timeout: Option[Long] =
headers.get[`Connect-Timeout-Ms`].map(_.value)

def as[A <: Message](using M: MonadThrow[F], codec: MessageCodec[F], cmp: Companion[A]): F[A] =
M.rethrow(codec.decode(this).value)
def as[A <: Message](cmp: Companion[A])(using M: MonadThrow[F], codec: MessageCodec[F]): F[A] =
M.rethrow(codec.decode(this)(using cmp).value)

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import io.grpc.{StatusException, StatusRuntimeException}
import org.ivovk.connect_rpc_scala.grpc.GrpcHeaders
import scalapb.GeneratedMessage

object all extends RuntimeExceptionSyntax, ProtoMappingsSyntax
object all extends ExceptionSyntax, ProtoMappingsSyntax

trait RuntimeExceptionSyntax {
trait ExceptionSyntax {

extension (e: StatusRuntimeException) {
def withDetails[T <: GeneratedMessage](t: T): StatusRuntimeException = {
Expand Down

0 comments on commit 909a0d7

Please sign in to comment.