Skip to content

Commit

Permalink
Optimized asyncUnaryCall implementation with cancellation support (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
igor-vovk authored Dec 6, 2024
1 parent 74403b0 commit 29dab9a
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ import io.grpc.stub.MetadataUtils
import org.http4s.dsl.Http4sDsl
import org.http4s.{Header, MediaType, MessageFailure, Method, Response}
import org.ivovk.connect_rpc_scala.Mappings.*
import org.ivovk.connect_rpc_scala.grpc.{MethodName, MethodRegistry}
import org.ivovk.connect_rpc_scala.grpc.{GrpcClientCalls, MethodName, MethodRegistry}
import org.ivovk.connect_rpc_scala.http.Headers.`X-Test-Case-Name`
import org.ivovk.connect_rpc_scala.http.codec.MessageCodec.given
import org.ivovk.connect_rpc_scala.http.codec.{MessageCodec, MessageCodecRegistry}
import org.ivovk.connect_rpc_scala.http.{MediaTypes, RequestEntity}
import org.slf4j.{Logger, LoggerFactory}
import scalapb.grpc.ClientCalls
import scalapb.{GeneratedMessage, GeneratedMessageCompanion, TextFormat}

import java.util.concurrent.atomic.AtomicReference
Expand Down Expand Up @@ -103,32 +102,35 @@ class ConnectHandler[F[_] : Async](
logger.trace(s">>> Method: ${method.descriptor.getFullMethodName}, Entity: $message")
}

Async[F].fromFuture(Async[F].delay {
ClientCalls.asyncUnaryCall[GeneratedMessage, GeneratedMessage](
val callOptions = CallOptions.DEFAULT
.pipe(
req.timeout match {
case Some(timeout) => _.withDeadlineAfter(timeout, MILLISECONDS)
case None => identity
}
)

GrpcClientCalls
.asyncUnaryCall2[F, GeneratedMessage, GeneratedMessage](
ClientInterceptors.intercept(
channel,
MetadataUtils.newAttachHeadersInterceptor(req.headers.toMetadata),
MetadataUtils.newCaptureMetadataInterceptor(responseHeaderMetadata, responseTrailerMetadata),
),
method.descriptor,
CallOptions.DEFAULT.pipe(
req.timeout match {
case Some(timeout) => _.withDeadlineAfter(timeout, MILLISECONDS)
case None => identity
}
),
callOptions,
message
)
}).map { response =>
val headers = responseHeaderMetadata.get.toHeaders() ++
responseTrailerMetadata.get.toHeaders(trailing = !treatTrailersAsHeaders)
.map { response =>
val headers = responseHeaderMetadata.get.toHeaders() ++
responseTrailerMetadata.get.toHeaders(trailing = !treatTrailersAsHeaders)

if (logger.isTraceEnabled) {
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
}
if (logger.isTraceEnabled) {
logger.trace(s"<<< Headers: ${headers.redactSensitive()}")
}

Response(Ok, headers = headers).withEntity(response)
}
Response(Ok, headers = headers).withEntity(response)
}
}
.recover { case e =>
val grpcStatus = e match {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package org.ivovk.connect_rpc_scala.grpc

import cats.effect.Async
import com.google.common.util.concurrent.{FutureCallback, Futures, MoreExecutors}
import io.grpc.stub.{ClientCalls, StreamObserver}
import io.grpc.{CallOptions, Channel, MethodDescriptor}

object GrpcClientCalls {

/**
* Asynchronous unary call.
*
* Optimized version of the `scalapb.grpc.ClientCalls.asyncUnaryCall` that skips Scala's Future instantiation
* and supports cancellation.
*/
def asyncUnaryCall[F[_] : Async, Req, Resp](
channel: Channel,
method: MethodDescriptor[Req, Resp],
options: CallOptions,
request: Req,
): F[Resp] = {
Async[F].async[Resp] { cb =>
Async[F].delay {
val future = ClientCalls.futureUnaryCall(channel.newCall(method, options), request)

Futures.addCallback(
future,
new FutureCallback[Resp] {
def onSuccess(result: Resp): Unit = cb(Right(result))

def onFailure(t: Throwable): Unit = cb(Left(t))
},
MoreExecutors.directExecutor(),
)

Some(Async[F].delay(future.cancel(true)))
}
}
}

/**
* Implementation that should be faster than the [[asyncUnaryCall]].
*/
def asyncUnaryCall2[F[_] : Async, Req, Resp](
channel: Channel,
method: MethodDescriptor[Req, Resp],
options: CallOptions,
request: Req,
): F[Resp] = {
Async[F].async[Resp] { cb =>
Async[F].delay {
val call = channel.newCall(method, options)

ClientCalls.asyncUnaryCall(call, request, new CallbackObserver(cb))

Some(Async[F].delay(call.cancel("Cancelled", null)))
}
}
}

/**
* [[CallbackObserver]] either executes [[onNext]] -> [[onCompleted]] during the happy path or just [[onError]] in case of
* an error.
*/
private class CallbackObserver[F[_] : Async, Resp](cb: Either[Throwable, Resp] => Unit) extends StreamObserver[Resp] {
private var value: Option[Either[Throwable, Resp]] = None

override def onNext(value: Resp): Unit = {
if this.value.isDefined then
throw new IllegalStateException("Value already received")

this.value = Some(Right(value))
}

override def onError(t: Throwable): Unit = {
cb(Left(t))
}

override def onCompleted(): Unit = {
this.value match
case Some(v) => cb(v)
case None => cb(Left(new IllegalStateException("No value received or call to onCompleted after onError")))
}

}

}

0 comments on commit 29dab9a

Please sign in to comment.