From 909a0d7fdb5a5431de9d49c4a9515e83d0a1ea4a Mon Sep 17 00:00:00 2001 From: Ihor Vovk Date: Sat, 7 Dec 2024 12:22:28 +0100 Subject: [PATCH] Further optimizations in GRPC method execution (#48) --- .../connect_rpc_scala/ConnectHandler.scala | 61 ++++++------- .../grpc/GrpcClientCalls.scala | 86 +++++++------------ .../http/RequestEntity.scala | 4 +- .../ivovk/connect_rpc_scala/syntax/all.scala | 4 +- 4 files changed, 61 insertions(+), 94 deletions(-) diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala index 1bf4e41..cba0eb7 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala @@ -5,7 +5,6 @@ import cats.effect.Async import cats.implicits.* import io.grpc.* import io.grpc.MethodDescriptor.MethodType -import io.grpc.stub.MetadataUtils import org.http4s.dsl.Http4sDsl import org.http4s.{Header, MediaType, MessageFailure, Method, Response} import org.ivovk.connect_rpc_scala.Mappings.* @@ -15,9 +14,8 @@ import org.ivovk.connect_rpc_scala.http.codec.MessageCodec.given import org.ivovk.connect_rpc_scala.http.codec.{MessageCodec, MessageCodecRegistry} import org.ivovk.connect_rpc_scala.http.{MediaTypes, RequestEntity} import org.slf4j.{Logger, LoggerFactory} -import scalapb.{GeneratedMessage, GeneratedMessageCompanion} +import scalapb.GeneratedMessage -import java.util.concurrent.atomic.AtomicReference import scala.concurrent.duration.* import scala.jdk.CollectionConverters.* import scala.util.chaining.* @@ -82,7 +80,7 @@ class ConnectHandler[F[_] : Async]( method: MethodRegistry.Entry, req: RequestEntity[F], channel: Channel - )(using codec: MessageCodec[F]): F[Response[F]] = { + )(using MessageCodec[F]): F[Response[F]] = { if (logger.isTraceEnabled) { // Used in conformance tests req.headers.get[`X-Test-Case-Name`] match { @@ -92,15 +90,10 @@ class ConnectHandler[F[_] : Async]( } } - given GeneratedMessageCompanion[GeneratedMessage] = method.requestMessageCompanion - - req.as[GeneratedMessage] + req.as[GeneratedMessage](method.requestMessageCompanion) .flatMap { message => - val responseHeaderMetadata = new AtomicReference[Metadata]() - val responseTrailerMetadata = new AtomicReference[Metadata]() - if (logger.isTraceEnabled) { - logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}, Entity: $message") + logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}") } val callOptions = CallOptions.DEFAULT @@ -111,37 +104,33 @@ class ConnectHandler[F[_] : Async]( } ) - GrpcClientCalls - .asyncUnaryCall2[F, GeneratedMessage, GeneratedMessage]( - ClientInterceptors.intercept( - channel, - MetadataUtils.newAttachHeadersInterceptor(req.headers.toMetadata), - MetadataUtils.newCaptureMetadataInterceptor(responseHeaderMetadata, responseTrailerMetadata), - ), - method.descriptor, - callOptions, - message - ) - .map { response => - val headers = responseHeaderMetadata.get.toHeaders() ++ - responseTrailerMetadata.get.toHeaders(trailing = !treatTrailersAsHeaders) + GrpcClientCalls.asyncUnaryCall( + channel, + method.descriptor, + callOptions, + req.headers.toMetadata, + message + ) + } + .map { response => + val headers = response.headers.toHeaders() ++ + response.trailers.toHeaders(trailing = !treatTrailersAsHeaders) - if (logger.isTraceEnabled) { - logger.trace(s"<<< Headers: ${headers.redactSensitive()}") - } + if (logger.isTraceEnabled) { + logger.trace(s"<<< Headers: ${headers.redactSensitive()}") + } - Response(Ok, headers = headers).withEntity(response) - } + Response(Ok, headers = headers).withEntity(response.value) } .recover { case e => val grpcStatus = e match { - case e: StatusRuntimeException => + case e: StatusException => e.getStatus.getDescription match { case "an implementation is missing" => io.grpc.Status.UNIMPLEMENTED case _ => e.getStatus } - case e: StatusException => e.getStatus - case e: MessageFailure => io.grpc.Status.INVALID_ARGUMENT + case e: StatusRuntimeException => e.getStatus + case _: MessageFailure => io.grpc.Status.INVALID_ARGUMENT case _ => io.grpc.Status.INTERNAL } @@ -154,14 +143,16 @@ class ConnectHandler[F[_] : Async]( val httpStatus = grpcStatus.toHttpStatus val connectCode = grpcStatus.toConnectCode + // Should be called before converting metadata to headers val details = Option(metadata.removeAll(GrpcHeaders.ErrorDetailsKey)) .fold(Seq.empty)(_.asScala.toSeq) val headers = metadata.toHeaders(trailing = !treatTrailersAsHeaders) if (logger.isTraceEnabled) { - logger.warn(s"<<< Error processing request", e) - logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode, Message: ${message.orNull}") + logger.trace(s"<<< Http Status: $httpStatus, Connect Error Code: $connectCode") + logger.trace(s"<<< Headers: ${headers.redactSensitive()}") + logger.trace(s"<<< Error processing request", e) } Response[F](httpStatus, headers = headers).withEntity(connectrpc.Error( diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/GrpcClientCalls.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/GrpcClientCalls.scala index 01eb0f9..f589455 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/GrpcClientCalls.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/GrpcClientCalls.scala @@ -1,85 +1,61 @@ package org.ivovk.connect_rpc_scala.grpc import cats.effect.Async -import com.google.common.util.concurrent.{FutureCallback, Futures, MoreExecutors} -import io.grpc.stub.{ClientCalls, StreamObserver} -import io.grpc.{CallOptions, Channel, MethodDescriptor} +import io.grpc.* object GrpcClientCalls { + case class Response[T](value: T, headers: Metadata, trailers: Metadata) + /** * Asynchronous unary call. - * - * Optimized version of the `scalapb.grpc.ClientCalls.asyncUnaryCall` that skips Scala's Future instantiation - * and supports cancellation. */ def asyncUnaryCall[F[_] : Async, Req, Resp]( channel: Channel, method: MethodDescriptor[Req, Resp], options: CallOptions, + headers: Metadata, request: Req, - ): F[Resp] = { - Async[F].async[Resp] { cb => - Async[F].delay { - val future = ClientCalls.futureUnaryCall(channel.newCall(method, options), request) - - Futures.addCallback( - future, - new FutureCallback[Resp] { - def onSuccess(result: Resp): Unit = cb(Right(result)) - - def onFailure(t: Throwable): Unit = cb(Left(t)) - }, - MoreExecutors.directExecutor(), - ) - - Some(Async[F].delay(future.cancel(true))) - } - } - } - - /** - * Implementation that should be faster than the [[asyncUnaryCall]]. - */ - def asyncUnaryCall2[F[_] : Async, Req, Resp]( - channel: Channel, - method: MethodDescriptor[Req, Resp], - options: CallOptions, - request: Req, - ): F[Resp] = { - Async[F].async[Resp] { cb => + ): F[Response[Resp]] = { + Async[F].async[Response[Resp]] { cb => Async[F].delay { val call = channel.newCall(method, options) - - ClientCalls.asyncUnaryCall(call, request, new CallbackObserver(cb)) + call.start(CallbackListener[Resp](cb), headers) + call.sendMessage(request) + call.halfClose() + call.request(2) Some(Async[F].delay(call.cancel("Cancelled", null))) } } } - /** - * [[StreamObserverToCallListenerAdapter]] either executes [[onNext]] -> [[onCompleted]] during the happy path - * or just [[onError]] in case of an error. - */ - private class CallbackObserver[Resp](cb: Either[Throwable, Resp] => Unit) extends StreamObserver[Resp] { - private var value: Option[Either[Throwable, Resp]] = None + private class CallbackListener[Resp](cb: Either[Throwable, Response[Resp]] => Unit) extends ClientCall.Listener[Resp] { + private var headers: Option[Metadata] = None + private var message: Option[Resp] = None - override def onNext(value: Resp): Unit = { - if this.value.isDefined then - throw new IllegalStateException("Value already received") - - this.value = Some(Right(value)) + override def onHeaders(headers: Metadata): Unit = { + this.headers = Some(headers) } - override def onError(t: Throwable): Unit = { - cb(Left(t)) + override def onMessage(message: Resp): Unit = { + if this.message.isDefined then + throw new IllegalStateException("More than one message received") + + this.message = Some(message) } - override def onCompleted(): Unit = { - this.value match - case Some(v) => cb(v) - case None => cb(Left(new IllegalStateException("No value received or call to onCompleted after onError"))) + override def onClose(status: Status, trailers: Metadata): Unit = { + if status.isOk then + message match + case Some(value) => cb(Right(Response( + value = value, + headers = headers.getOrElse(new Metadata()), + trailers = trailers + ))) + case None => cb(Left(new IllegalStateException("No value received"))) + else + cb(Left(status.asException(trailers))) } } diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala index 319c751..3a60b04 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/http/RequestEntity.scala @@ -35,7 +35,7 @@ case class RequestEntity[F[_]]( def timeout: Option[Long] = headers.get[`Connect-Timeout-Ms`].map(_.value) - def as[A <: Message](using M: MonadThrow[F], codec: MessageCodec[F], cmp: Companion[A]): F[A] = - M.rethrow(codec.decode(this).value) + def as[A <: Message](cmp: Companion[A])(using M: MonadThrow[F], codec: MessageCodec[F]): F[A] = + M.rethrow(codec.decode(this)(using cmp).value) } diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/syntax/all.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/syntax/all.scala index d83eee5..d2d40e9 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/syntax/all.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/syntax/all.scala @@ -4,9 +4,9 @@ import io.grpc.{StatusException, StatusRuntimeException} import org.ivovk.connect_rpc_scala.grpc.GrpcHeaders import scalapb.GeneratedMessage -object all extends RuntimeExceptionSyntax, ProtoMappingsSyntax +object all extends ExceptionSyntax, ProtoMappingsSyntax -trait RuntimeExceptionSyntax { +trait ExceptionSyntax { extension (e: StatusRuntimeException) { def withDetails[T <: GeneratedMessage](t: T): StatusRuntimeException = {