From 29dab9a5e77ee1c707767b0e0941c406f8984036 Mon Sep 17 00:00:00 2001 From: Ihor Vovk Date: Fri, 6 Dec 2024 21:04:25 +0100 Subject: [PATCH] Optimized asyncUnaryCall implementation with cancellation support (#45) --- .../connect_rpc_scala/ConnectHandler.scala | 38 ++++---- .../grpc/GrpcClientCalls.scala | 87 +++++++++++++++++++ 2 files changed, 107 insertions(+), 18 deletions(-) create mode 100644 core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/GrpcClientCalls.scala diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala index 0d6b9ee..2c79181 100644 --- a/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/ConnectHandler.scala @@ -9,13 +9,12 @@ import io.grpc.stub.MetadataUtils import org.http4s.dsl.Http4sDsl import org.http4s.{Header, MediaType, MessageFailure, Method, Response} import org.ivovk.connect_rpc_scala.Mappings.* -import org.ivovk.connect_rpc_scala.grpc.{MethodName, MethodRegistry} +import org.ivovk.connect_rpc_scala.grpc.{GrpcClientCalls, MethodName, MethodRegistry} import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name` import org.ivovk.connect_rpc_scala.http.codec.MessageCodec.given import org.ivovk.connect_rpc_scala.http.codec.{MessageCodec, MessageCodecRegistry} import org.ivovk.connect_rpc_scala.http.{MediaTypes, RequestEntity} import org.slf4j.{Logger, LoggerFactory} -import scalapb.grpc.ClientCalls import scalapb.{GeneratedMessage, GeneratedMessageCompanion, TextFormat} import java.util.concurrent.atomic.AtomicReference @@ -103,32 +102,35 @@ class ConnectHandler[F[_] : Async]( logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}, Entity: $message") } - Async[F].fromFuture(Async[F].delay { - ClientCalls.asyncUnaryCall[GeneratedMessage, GeneratedMessage]( + val callOptions = CallOptions.DEFAULT + .pipe( + req.timeout match { + case Some(timeout) => _.withDeadlineAfter(timeout, MILLISECONDS) + case None => identity + } + ) + + GrpcClientCalls + .asyncUnaryCall2[F, GeneratedMessage, GeneratedMessage]( ClientInterceptors.intercept( channel, MetadataUtils.newAttachHeadersInterceptor(req.headers.toMetadata), MetadataUtils.newCaptureMetadataInterceptor(responseHeaderMetadata, responseTrailerMetadata), ), method.descriptor, - CallOptions.DEFAULT.pipe( - req.timeout match { - case Some(timeout) => _.withDeadlineAfter(timeout, MILLISECONDS) - case None => identity - } - ), + callOptions, message ) - }).map { response => - val headers = responseHeaderMetadata.get.toHeaders() ++ - responseTrailerMetadata.get.toHeaders(trailing = !treatTrailersAsHeaders) + .map { response => + val headers = responseHeaderMetadata.get.toHeaders() ++ + responseTrailerMetadata.get.toHeaders(trailing = !treatTrailersAsHeaders) - if (logger.isTraceEnabled) { - logger.trace(s"<<< Headers: ${headers.redactSensitive()}") - } + if (logger.isTraceEnabled) { + logger.trace(s"<<< Headers: ${headers.redactSensitive()}") + } - Response(Ok, headers = headers).withEntity(response) - } + Response(Ok, headers = headers).withEntity(response) + } } .recover { case e => val grpcStatus = e match { diff --git a/core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/GrpcClientCalls.scala b/core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/GrpcClientCalls.scala new file mode 100644 index 0000000..eb27679 --- /dev/null +++ b/core/src/main/scala/org/ivovk/connect_rpc_scala/grpc/GrpcClientCalls.scala @@ -0,0 +1,87 @@ +package org.ivovk.connect_rpc_scala.grpc + +import cats.effect.Async +import com.google.common.util.concurrent.{FutureCallback, Futures, MoreExecutors} +import io.grpc.stub.{ClientCalls, StreamObserver} +import io.grpc.{CallOptions, Channel, MethodDescriptor} + +object GrpcClientCalls { + + /** + * Asynchronous unary call. + * + * Optimized version of the `scalapb.grpc.ClientCalls.asyncUnaryCall` that skips Scala's Future instantiation + * and supports cancellation. + */ + def asyncUnaryCall[F[_] : Async, Req, Resp]( + channel: Channel, + method: MethodDescriptor[Req, Resp], + options: CallOptions, + request: Req, + ): F[Resp] = { + Async[F].async[Resp] { cb => + Async[F].delay { + val future = ClientCalls.futureUnaryCall(channel.newCall(method, options), request) + + Futures.addCallback( + future, + new FutureCallback[Resp] { + def onSuccess(result: Resp): Unit = cb(Right(result)) + + def onFailure(t: Throwable): Unit = cb(Left(t)) + }, + MoreExecutors.directExecutor(), + ) + + Some(Async[F].delay(future.cancel(true))) + } + } + } + + /** + * Implementation that should be faster than the [[asyncUnaryCall]]. + */ + def asyncUnaryCall2[F[_] : Async, Req, Resp]( + channel: Channel, + method: MethodDescriptor[Req, Resp], + options: CallOptions, + request: Req, + ): F[Resp] = { + Async[F].async[Resp] { cb => + Async[F].delay { + val call = channel.newCall(method, options) + + ClientCalls.asyncUnaryCall(call, request, new CallbackObserver(cb)) + + Some(Async[F].delay(call.cancel("Cancelled", null))) + } + } + } + + /** + * [[CallbackObserver]] either executes [[onNext]] -> [[onCompleted]] during the happy path or just [[onError]] in case of + * an error. + */ + private class CallbackObserver[F[_] : Async, Resp](cb: Either[Throwable, Resp] => Unit) extends StreamObserver[Resp] { + private var value: Option[Either[Throwable, Resp]] = None + + override def onNext(value: Resp): Unit = { + if this.value.isDefined then + throw new IllegalStateException("Value already received") + + this.value = Some(Right(value)) + } + + override def onError(t: Throwable): Unit = { + cb(Left(t)) + } + + override def onCompleted(): Unit = { + this.value match + case Some(v) => cb(v) + case None => cb(Left(new IllegalStateException("No value received or call to onCompleted after onError"))) + } + + } + +}