diff --git a/.travis.yml b/.travis.yml index bcfbec8..6261571 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,7 +20,7 @@ jobs: install: - ./gradlew assemble script: - - ./gradlew check + - ./gradlew -Dkotlinx.coroutines.debug=on check - cd example-project && ./gradlew test && ./gradlew clean test -PuseKrotoConfigDsl=true after_success: - bash <(curl -s https://codecov.io/bash) diff --git a/build.gradle b/build.gradle index a59bb1c..2937165 100644 --- a/build.gradle +++ b/build.gradle @@ -62,6 +62,7 @@ subprojects{ subproject -> } tasks.withType(Test) { + testLogging { // set options for log level LIFECYCLE events ( diff --git a/gradle.properties b/gradle.properties index 7fc6f1f..974d2f8 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1 +1,3 @@ kotlin.code.style=official + + diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcUtils.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcUtils.kt index da6fc29..09ac0a4 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcUtils.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcUtils.kt @@ -18,11 +18,11 @@ package com.github.marcoferrer.krotoplus.coroutines import com.github.marcoferrer.krotoplus.coroutines.call.newProducerScope import com.github.marcoferrer.krotoplus.coroutines.call.toRpcException +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineStart import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.Job -import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ProducerScope import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.launch @@ -60,7 +60,7 @@ public fun CoroutineScope.launchProducerJob( } } -internal suspend fun Channel<*>.awaitCloseOrThrow(){ +internal suspend fun SendChannel<*>.awaitCloseOrThrow(){ suspendCancellableCoroutine { cont -> invokeOnClose { error -> if(error == null) diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallDecorators.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallDecorators.kt new file mode 100644 index 0000000..7f9b437 --- /dev/null +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallDecorators.kt @@ -0,0 +1,31 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines.client + +import io.grpc.ClientCall +import io.grpc.ForwardingClientCall + +internal inline fun > C.beforeCancellation( + crossinline block: C.(message: String?, cause: Throwable?) -> Unit +): ClientCall { + return object : ForwardingClientCall.SimpleForwardingClientCall(this) { + override fun cancel(message: String?, cause: Throwable?){ + block(message, cause) + super.cancel(message, cause) + } + } +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCalls.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCalls.kt index 5ce3100..cb9b192 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCalls.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCalls.kt @@ -22,17 +22,33 @@ import com.github.marcoferrer.krotoplus.coroutines.call.newRpcScope import com.github.marcoferrer.krotoplus.coroutines.withCoroutineContext import io.grpc.CallOptions import io.grpc.MethodDescriptor +import io.grpc.Status import io.grpc.stub.AbstractStub +import io.grpc.stub.ClientCallStreamObserver import io.grpc.stub.ClientCalls.asyncBidiStreamingCall import io.grpc.stub.ClientCalls.asyncClientStreamingCall import io.grpc.stub.ClientCalls.asyncServerStreamingCall import io.grpc.stub.ClientCalls.asyncUnaryCall import io.grpc.stub.ClientResponseObserver import kotlinx.coroutines.CancellableContinuation +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.Job +import kotlinx.coroutines.cancel +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ProducerScope import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.flow.buffer +import kotlinx.coroutines.flow.callbackFlow +import kotlinx.coroutines.flow.emitAll +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.produceIn import kotlinx.coroutines.suspendCancellableCoroutine +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.coroutines.resume +import kotlin.coroutines.resumeWithException +internal const val MESSAGE_CLIENT_CANCELLED_CALL = "Client has cancelled the call" /** * Executes a unary rpc call using the [io.grpc.Channel] and [io.grpc.CallOptions] attached to the @@ -101,22 +117,70 @@ public fun > T.clientCallServerStreaming( public fun clientCallServerStreaming( request: ReqT, method: MethodDescriptor, - channel: io.grpc.Channel, + grpcChannel: io.grpc.Channel, callOptions: CallOptions = CallOptions.DEFAULT ): ReceiveChannel { - val initialContext = callOptions.getOption(CALL_OPTION_COROUTINE_CONTEXT) - with(newRpcScope(initialContext, method)) { - val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)) - val responseObserverChannel = ClientResponseStreamChannel(coroutineContext) - asyncServerStreamingCall( - call, - request, - responseObserverChannel - ) - bindScopeCancellationToCall(call) - return responseObserverChannel + val observerAdapter = ResponseObserverChannelAdapter() + val rpcScope = newRpcScope(callOptions.getOption(CALL_OPTION_COROUTINE_CONTEXT), method) + val responseFlow = callbackFlow flow@ { + observerAdapter.scope = this + + val call = grpcChannel + .newCall(method, callOptions.withCoroutineContext(coroutineContext)) + .beforeCancellation { message, cause -> + observerAdapter.beforeCallCancellation(message, cause) + } + + val job = coroutineContext[Job]!! + + // Start the RPC Call + asyncServerStreamingCall(call, request, observerAdapter) + + // If our parent job is cancelled before we can + // start the call then we need to propagate the + // cancellation to the underlying call + job.invokeOnCompletion { error -> + // Our job can be cancelled after completion due to the inner machinery + // of kotlinx.coroutines.flow.Channels.kt.emitAll(). Its final operation + // after receiving a close is a call to channel.cancelConsumed(cause). + // Even if it doesnt encounter an exception it will cancel with null. + // We will only invoke cancel on the call + if(job.isCancelled && observerAdapter.isActive){ + call.cancel(MESSAGE_CLIENT_CANCELLED_CALL, error) + } + } + + suspendCancellableCoroutine { cont -> + // Here we need to handle not only parent job cancellation + // but calls to `channel.cancel(...)` as well. + cont.invokeOnCancellation { error -> + if (observerAdapter.isActive) { + call.cancel(MESSAGE_CLIENT_CANCELLED_CALL, error) + } + } + invokeOnClose { error -> + if (error == null) + cont.resume(Unit) else + cont.resumeWithException(error) + } + } } + + // Use buffer UNLIMITED so that we dont drop any inbound messages + return flow { emitAll(responseFlow.buffer(Channel.UNLIMITED)) } + .onEach { + if(observerAdapter.isActive){ + observerAdapter.callStreamObserver.request(1) + } + } + // We use buffer RENDEZVOUS on the outer flow so that our + // `onEach` operator is only invoked each time a message is + // collected instead of each time a message is received from + // from the underlying call. + .buffer(Channel.RENDEZVOUS) + .produceIn(rpcScope) + } public fun > T.clientCallBidiStreaming( diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientResponseStreamChannel.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientResponseStreamChannel.kt deleted file mode 100644 index 0596162..0000000 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientResponseStreamChannel.kt +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright 2019 Kroto+ Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.github.marcoferrer.krotoplus.coroutines.client - -import com.github.marcoferrer.krotoplus.coroutines.call.FlowControlledInboundStreamObserver -import com.github.marcoferrer.krotoplus.coroutines.call.applyInboundFlowControl -import io.grpc.stub.ClientCallStreamObserver -import io.grpc.stub.ClientResponseObserver -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.channels.ReceiveChannel -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.atomic.AtomicInteger -import kotlin.coroutines.CoroutineContext - - -internal class ClientResponseStreamChannel( - override val coroutineContext: CoroutineContext, - override val inboundChannel: Channel = Channel() -) : ClientResponseObserver, - FlowControlledInboundStreamObserver, - ReceiveChannel by inboundChannel, - CoroutineScope { - - override val isInboundCompleted: AtomicBoolean = AtomicBoolean() - - override val transientInboundMessageCount: AtomicInteger = AtomicInteger() - - override lateinit var callStreamObserver: ClientCallStreamObserver - - private var aborted: Boolean = false - - override fun beforeStart(requestStream: ClientCallStreamObserver) { - callStreamObserver = requestStream.apply { - applyInboundFlowControl(inboundChannel,transientInboundMessageCount) - } - - inboundChannel.invokeOnClose { - // If the client prematurely cancels the responseChannel - // we need to propagate this as a cancellation to the underlying call - if(!isInboundCompleted.get() && !aborted){ - callStreamObserver.cancel("Client has cancelled call", it) - } - } - } - - override fun onNext(value: RespT): Unit = onNextWithBackPressure(value) - - override fun onError(t: Throwable) { - aborted = true - inboundChannel.close(t) - } -} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientStreamingCallChannel.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientStreamingCallChannel.kt index 7098384..0e9449f 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientStreamingCallChannel.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientStreamingCallChannel.kt @@ -74,8 +74,7 @@ internal class ClientStreamingCallChannelImpl( attachOutboundChannelCompletionHandler( callStreamObserver, outboundChannel, - onSuccess = { outboundMessageHandler.close() }, - onError = { error -> completableResponse.completeExceptionally(error) } + onSuccess = { outboundMessageHandler.close() } ) completableResponse.invokeOnCompletion { // If the client prematurely cancels the response diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ResponseObserverChannelAdapter.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ResponseObserverChannelAdapter.kt new file mode 100644 index 0000000..bf42ec6 --- /dev/null +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ResponseObserverChannelAdapter.kt @@ -0,0 +1,71 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines.client + +import io.grpc.Status +import io.grpc.stub.ClientCallStreamObserver +import io.grpc.stub.ClientResponseObserver +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.cancel +import kotlinx.coroutines.channels.ProducerScope +import java.util.concurrent.atomic.AtomicBoolean + +internal class ResponseObserverChannelAdapter: ClientResponseObserver { + + private val isAborted = AtomicBoolean() + private val isCompleted = AtomicBoolean() + + lateinit var scope: ProducerScope + + lateinit var callStreamObserver: ClientCallStreamObserver + private set + + val isActive: Boolean + get() = !(isAborted.get() || isCompleted.get()) + + override fun beforeStart(requestStream: ClientCallStreamObserver) { + require(::scope.isInitialized){ "Producer scope was not initialized" } + callStreamObserver = requestStream.apply { disableAutoInboundFlowControl() } + } + + fun beforeCallCancellation(message: String?, cause: Throwable?){ + if(!isAborted.getAndSet(true)) { + val cancellationStatus = Status.CANCELLED + .withDescription(message) + .withCause(cause) + .asRuntimeException() + + scope.close(CancellationException(message, cancellationStatus)) + } + } + + override fun onNext(value: RespT) { + scope.offer(value) + } + + override fun onError(t: Throwable) { + isAborted.set(true) + scope.close(t) + scope.cancel(CancellationException(t.message,t)) + } + + override fun onCompleted() { + isCompleted.set(true) + scope.close() + } +} + diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcCallTest.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcCallTest.kt new file mode 100644 index 0000000..9b34404 --- /dev/null +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcCallTest.kt @@ -0,0 +1,140 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines + +import com.github.marcoferrer.krotoplus.coroutines.utils.ClientCallSpyInterceptor +import com.github.marcoferrer.krotoplus.coroutines.utils.RpcStateInterceptor +import io.grpc.Channel +import io.grpc.ClientCall +import io.grpc.MethodDescriptor +import io.grpc.examples.helloworld.GreeterCoroutineGrpc +import io.grpc.examples.helloworld.GreeterGrpc +import io.grpc.examples.helloworld.HelloReply +import io.grpc.examples.helloworld.HelloRequest +import io.grpc.testing.GrpcServerRule +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import org.junit.Rule +import kotlin.test.BeforeTest +import kotlin.test.fail + +// Create StreamRecorder + +abstract class RpcCallTest( + val methodDescriptor: MethodDescriptor +) { + + @[Rule JvmField] + var grpcServerRule = GrpcServerRule().directExecutor() + + @[Rule JvmField] + var nonDirectGrpcServerRule = GrpcServerRule() + + // @[Rule JvmField] + // public val timeout = CoroutinesTimeout.seconds(COROUTINE_TEST_TIMEOUT) + + var callState = RpcStateInterceptor() + + val expectedRequest = HelloRequest.newBuilder().setName("success").build()!! + + @BeforeTest + fun setupCall() { + callState = RpcStateInterceptor() + } + + inner class RpcSpy(channel: Channel) { + + constructor(useDirectExecutor: Boolean = true) : this( + if(useDirectExecutor) grpcServerRule.channel else nonDirectGrpcServerRule.channel + ) + + private val _call = CompletableDeferred>() + + val stub = GreeterGrpc.newStub(channel) + .withInterceptors(ClientCallSpyInterceptor(_call), callState)!! + + val coStub = GreeterCoroutineGrpc.newStub(channel) + .withInterceptors(ClientCallSpyInterceptor(_call), callState)!! + + val call: ClientCall by lazy { + @Suppress("UNCHECKED_CAST") + runBlocking { _call.await() as ClientCall } + } + + } + + suspend fun RpcStateInterceptor.awaitCancellation(timeout: Long = DEFAULT_STATE_ASSERT_TIMEOUT){ + client.cancelled.assert(timeout){ "Client call should be cancelled" } + server.cancelled.assert(timeout){ "Server call should be cancelled" } + } + + suspend fun RpcStateInterceptor.awaitClose(timeout: Long = DEFAULT_STATE_ASSERT_TIMEOUT){ + client.closed.assert(timeout){ "Client call should be closed" } + server.closed.assert(timeout){ "Server call should be closed" } + } + + fun RpcStateInterceptor.blockUntilCancellation(timeout: Long = DEFAULT_STATE_ASSERT_TIMEOUT) = + runBlocking { awaitCancellation(timeout) } + + fun RpcStateInterceptor.blockUntilClosed(timeout: Long = DEFAULT_STATE_ASSERT_TIMEOUT) = + runBlocking { awaitClose(timeout) } + + suspend fun withTimeoutOrDumpState( + timeout: Long = DEFAULT_STATE_ASSERT_TIMEOUT, + message: String, + block: suspend () -> T + ) : T = try { + withTimeout(timeout) { block() } + } catch (e: TimeoutCancellationException) { + fail(""" + |$message + |Timeout after ${timeout}ms + |$callState + """.trimMargin()) + } + + suspend fun CompletableDeferred.assert( + timeout: Long = DEFAULT_STATE_ASSERT_TIMEOUT, + message: () -> String + ) = withTimeoutOrDumpState(timeout, message()){ + await() + } + + fun CompletableDeferred.assertBlocking( + timeout: Long = DEFAULT_STATE_ASSERT_TIMEOUT, + message: () -> String + ) = runBlocking { assert(timeout, message) } + + + inline fun runTest ( + timeout: Long = DEFAULT_STATE_ASSERT_TIMEOUT, + crossinline block: suspend CoroutineScope.() -> T + ): T = runBlocking(Dispatchers.Default) { + withTimeoutOrDumpState(timeout, "Rpc did not complete in time"){ + block() + } + } + + + companion object{ + const val DEFAULT_STATE_ASSERT_TIMEOUT = 10_000L + } +} diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallServerStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallServerStreamingTests.kt index e697050..132ea02 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallServerStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallServerStreamingTests.kt @@ -17,91 +17,79 @@ package com.github.marcoferrer.krotoplus.coroutines.client -import com.github.marcoferrer.krotoplus.coroutines.CALL_OPTION_COROUTINE_CONTEXT +import com.github.marcoferrer.krotoplus.coroutines.RpcCallTest import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus +import com.github.marcoferrer.krotoplus.coroutines.utils.invoke +import com.github.marcoferrer.krotoplus.coroutines.utils.newCancellingInterceptor import com.github.marcoferrer.krotoplus.coroutines.withCoroutineContext import io.grpc.CallOptions import io.grpc.Channel import io.grpc.ClientCall import io.grpc.ClientInterceptor +import io.grpc.ClientInterceptors +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener +import io.grpc.Metadata import io.grpc.MethodDescriptor +import io.grpc.ServerInterceptors import io.grpc.Status import io.grpc.examples.helloworld.GreeterGrpc import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest import io.grpc.stub.StreamObserver -import io.grpc.testing.GrpcServerRule -import io.mockk.Runs -import io.mockk.every -import io.mockk.just import io.mockk.spyk import io.mockk.verify import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CoroutineStart import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.cancel import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.channels.toList +import kotlinx.coroutines.delay import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking -import org.junit.Rule import org.junit.Test -import kotlin.test.BeforeTest +import java.util.concurrent.Phaser import kotlin.test.assertEquals import kotlin.test.assertFailsWith +import kotlin.test.fail -class ClientCallServerStreamingTests { +class ClientCallServerStreamingTests : + RpcCallTest(GreeterGrpc.getSayHelloServerStreamingMethod()) { - @[Rule JvmField] - var grpcServerRule = GrpcServerRule().directExecutor() - - @[Rule JvmField] - var nonDirectGrpcServerRule = GrpcServerRule() - - // @[Rule JvmField] - // public val timeout = CoroutinesTimeout.seconds(COROUTINE_TEST_TIMEOUT) - - private val methodDescriptor = GreeterGrpc.getSayHelloServerStreamingMethod() - private val service = spyk(object : GreeterGrpc.GreeterImplBase() {}) - private val expectedRequest = HelloRequest.newBuilder().setName("success").build() - - private val cancellingInterceptor = object : ClientInterceptor { + private val excessiveInboundMessageInterceptor = object : ClientInterceptor { override fun interceptCall( method: MethodDescriptor, callOptions: CallOptions, next: Channel - ): ClientCall { - val call = next.newCall(method, callOptions) - callOptions.getOption(CALL_OPTION_COROUTINE_CONTEXT)[Job]?.cancel() - return call - } - } - - inner class RpcSpy(val channel: Channel = grpcServerRule.channel){ - val stub: GreeterGrpc.GreeterStub - lateinit var call: ClientCall - - init { - val channelSpy = spyk(channel) - stub = GreeterGrpc.newStub(channelSpy) - - every { channelSpy.newCall(methodDescriptor, any()) } answers { - spyk(channel.newCall(methodDescriptor, secondArg())).also { - this@RpcSpy.call = it + ): ClientCall = + object : SimpleForwardingClientCall(next.newCall(method, callOptions)) { + override fun start(responseListener: Listener?, headers: Metadata?) { + super.start(object : SimpleForwardingClientCallListener(responseListener) { + override fun onMessage(message: RespT) { + repeat(3) { super.onMessage(message) } + } + }, headers) } } - } } - private fun setupServerHandlerNoop(){ - every { service.sayHelloServerStreaming(expectedRequest, any()) } just Runs - } + private fun setupServerHandlerNoop() = setupUpServerHandler { _, _ -> } - @BeforeTest - fun setupService() { - grpcServerRule.serviceRegistry.addService(service) + private fun setupUpServerHandler( + block: (request: HelloRequest, responseObserver: StreamObserver) -> Unit + ) { + val serviceImpl = object : GreeterGrpc.GreeterImplBase() { + override fun sayHelloServerStreaming(request: HelloRequest, responseObserver: StreamObserver) = + block(request, responseObserver) + } + + val service = ServerInterceptors.intercept(serviceImpl, callState) nonDirectGrpcServerRule.serviceRegistry.addService(service) + grpcServerRule.serviceRegistry.addService(service) } @Test @@ -109,13 +97,12 @@ class ClientCallServerStreamingTests { val rpcSpy = RpcSpy() val stub = rpcSpy.stub - every { service.sayHelloServerStreaming(expectedRequest, any()) } answers { - val actualRequest = firstArg() - with(secondArg>()) { + setupUpServerHandler { request, responseObserver -> + with(responseObserver) { repeat(3) { onNext( HelloReply.newBuilder() - .setMessage("Request#$it:${actualRequest.name}") + .setMessage("Request#$it:${request.name}") .build() ) } @@ -125,13 +112,16 @@ class ClientCallServerStreamingTests { val responseChannel = stub .clientCallServerStreaming(expectedRequest, methodDescriptor) - runBlocking { + + runTest { repeat(3) { assertEquals("Request#$it:${expectedRequest.name}", responseChannel.receive().message) } + delay(300) } + assert(responseChannel.isClosedForReceive) { "Response channel is closed after successful call" } - verify(exactly = 0) { rpcSpy.call.cancel(any(),any()) } + verify(exactly = 0) { rpcSpy.call.cancel(any(), any()) } } @Test @@ -139,13 +129,12 @@ class ClientCallServerStreamingTests { val rpcSpy = RpcSpy() val stub = rpcSpy.stub - every { service.sayHelloServerStreaming(expectedRequest, any()) } answers { - val actualRequest = firstArg() - with(secondArg>()) { + setupUpServerHandler { request, responseObserver -> + with(responseObserver) { repeat(10) { onNext( HelloReply.newBuilder() - .setMessage("Request#$it:${actualRequest.name}") + .setMessage("Request#$it:${request.name}") .build() ) } @@ -156,33 +145,37 @@ class ClientCallServerStreamingTests { val responseChannel = stub .clientCallServerStreaming(expectedRequest, methodDescriptor) val results = mutableListOf() - runBlocking { + runTest { repeat(3) { results.add(responseChannel.receive().message) } responseChannel.cancel() } + callState { + client.cancelled.assertBlocking{ "Client should be cancelled" } + } + assert(responseChannel.isClosedForReceive) { "Response channel is closed after successful call" } - verify(exactly = 1) { rpcSpy.call.cancel("Client has cancelled call",any()) } + verify(exactly = 1) { rpcSpy.call.cancel(MESSAGE_CLIENT_CANCELLED_CALL, any()) } } @Test fun `Call fails on server error`() { - val rpcSpy = RpcSpy(nonDirectGrpcServerRule.channel) + val rpcSpy = RpcSpy(useDirectExecutor = false) val stub = rpcSpy.stub - every { service.sayHelloServerStreaming(expectedRequest, any()) } answers { - val actualRequest = firstArg() - with(secondArg>()) { + val phaser = Phaser(2) + setupUpServerHandler { request, responseObserver -> + with(responseObserver) { repeat(2) { onNext( HelloReply.newBuilder() - .setMessage("Request#$it:${actualRequest.name}") + .setMessage("Request#$it:${request.name}") .build() ) } - + phaser.arriveAndAwaitAdvance() onError(Status.INVALID_ARGUMENT.asRuntimeException()) } } @@ -190,10 +183,11 @@ class ClientCallServerStreamingTests { val responseChannel = stub .clientCallServerStreaming(expectedRequest, methodDescriptor) - runBlocking { + runBlocking(Dispatchers.Default) { repeat(2) { assertEquals("Request#$it:${expectedRequest.name}", responseChannel.receive().message) } + phaser.arrive() assertFailsWithStatus(Status.INVALID_ARGUMENT) { responseChannel.receive() @@ -201,7 +195,7 @@ class ClientCallServerStreamingTests { } assert(responseChannel.isClosedForReceive) { "Response channel is closed after server error" } - verify(exactly = 0) { rpcSpy.call.cancel(any(),any()) } + verify(exactly = 0) { rpcSpy.call.cancel(any(), any()) } } @Test @@ -213,25 +207,40 @@ class ClientCallServerStreamingTests { val externalJob = Job() lateinit var responseChannel: ReceiveChannel - assertFailsWithStatus(Status.CANCELLED, "CANCELLED: Job was cancelled") { - runBlocking { - launch(Dispatchers.Default) { - launch(start = CoroutineStart.UNDISPATCHED) { - responseChannel = stub - .withCoroutineContext(externalJob) - .clientCallServerStreaming(expectedRequest, methodDescriptor) + lateinit var parentJob: Job - responseChannel.receive() - } - launch { - externalJob.cancel() - } + runTest { + launch(Dispatchers.Default) { + parentJob = launch(start = CoroutineStart.UNDISPATCHED) { + responseChannel = stub + .withCoroutineContext(externalJob) + .clientCallServerStreaming(expectedRequest, methodDescriptor) + responseChannel.receive() + fail("Should not reach here") + } + launch { + callState.client.started.assert { "Client should be started" } + externalJob.cancel() } + } + } + callState { + client { + closed.assertBlocking { "Client should be closed" } + } + server { + cancelled.assertBlocking { "Server should be cancelled" } + } + } + assert(parentJob.isCancelled) { "External job cancellation should propagate from receive channel" } + assertFailsWith(CancellationException::class) { + runBlocking { + responseChannel.receive() } } - verify { rpcSpy.call.cancel("Job was cancelled", any()) } + verify(exactly = 1) { rpcSpy.call.cancel(MESSAGE_CLIENT_CANCELLED_CALL, any()) } assert(responseChannel.isClosedForReceive) { "Response channel is closed after cancellation" } } @@ -240,22 +249,52 @@ class ClientCallServerStreamingTests { fun `Call is canceled when interceptor cancels scope normally`() { val rpcSpy = RpcSpy() - val stub = rpcSpy.stub.withInterceptors(cancellingInterceptor) + val stub = rpcSpy.stub.withInterceptors(newCancellingInterceptor(useNormalCancellation = true)) setupServerHandlerNoop() lateinit var responseChannel: ReceiveChannel - assertFailsWithStatus(Status.CANCELLED, "CANCELLED: Job was cancelled") { - runBlocking { - launch { - responseChannel = stub.clientCallServerStreaming(expectedRequest, methodDescriptor) + runTest { + launch { + responseChannel = stub.clientCallServerStreaming(expectedRequest, methodDescriptor) - responseChannel.receive() - } + responseChannel.receive() + } + } + + assertFailsWith(CancellationException::class) { + runBlocking { responseChannel.receive() } + } + + callState.client.cancelled.assertBlocking { "Client should be cancelled" } + + verify(exactly = 1) { rpcSpy.call.cancel(MESSAGE_CLIENT_CANCELLED_CALL, any()) } + assert(responseChannel.isClosedForReceive) { "Response channel is closed after cancellation" } + } + + @Test + fun `Call is canceled when interceptor cancels scope exceptionally`() { + + val rpcSpy = RpcSpy() + val stub = rpcSpy.stub.withInterceptors(newCancellingInterceptor(useNormalCancellation = false)) + + setupServerHandlerNoop() + + lateinit var responseChannel: ReceiveChannel + runTest { + launch { + responseChannel = stub.clientCallServerStreaming(expectedRequest, methodDescriptor) + + responseChannel.receive() } } - verify { rpcSpy.call.cancel("Job was cancelled", any()) } + assertFailsWith(CancellationException::class) { + runBlocking { responseChannel.receive() } + } + + callState.client.cancelled.assertBlocking { "Client should be cancelled" } + verify(exactly = 1) { rpcSpy.call.cancel(MESSAGE_CLIENT_CANCELLED_CALL, any()) } assert(responseChannel.isClosedForReceive) { "Response channel is closed after cancellation" } } @@ -268,7 +307,7 @@ class ClientCallServerStreamingTests { setupServerHandlerNoop() lateinit var responseChannel: ReceiveChannel - runBlocking { + runTest { launch(Dispatchers.Default) { launch(start = CoroutineStart.UNDISPATCHED) { responseChannel = stub @@ -277,11 +316,12 @@ class ClientCallServerStreamingTests { responseChannel.receive() } + delay(100) cancel() } } - verify { rpcSpy.call.cancel(any(), any()) } + verify(exactly = 1) { rpcSpy.call.cancel(MESSAGE_CLIENT_CANCELLED_CALL, any()) } assert(responseChannel.isClosedForReceive) { "Response channel is closed after cancellation" } } @@ -304,6 +344,7 @@ class ClientCallServerStreamingTests { responseChannel.receive() } + delay(100) launch { error("cancel") } @@ -311,8 +352,99 @@ class ClientCallServerStreamingTests { } } - verify { rpcSpy.call.cancel("Parent job is Cancelling", any()) } + assertFailsWith(CancellationException::class) { + runBlocking { + responseChannel.receive() + } + } + + verify(exactly = 1) { rpcSpy.call.cancel(MESSAGE_CLIENT_CANCELLED_CALL, any()) } assert(responseChannel.isClosedForReceive) { "Response channel is closed after cancellation" } } + @Test + fun `Call only requests messages after one is consumed`() { + val rpcSpy = RpcSpy(useDirectExecutor = false) + val stub = rpcSpy.stub + + setupUpServerHandler { request, responseObserver -> + with(responseObserver) { + repeat(20) { + onNext( + HelloReply.newBuilder() + .setMessage("Request#$it:${request.name}") + .build() + ) + } + onCompleted() + } + } + + val responseChannel = stub + .clientCallServerStreaming(expectedRequest, methodDescriptor) + + val result = runTest { + delay(100) + repeat(3) { + verify(exactly = it + 2) { rpcSpy.call.request(1) } + assertEquals("Request#$it:${expectedRequest.name}", responseChannel.receive().message) + delay(10) + } + + // Consume remaining messages + responseChannel.toList() + } + + assertEquals(17, result.size) + assert(responseChannel.isClosedForReceive) { "Response channel is closed after server error" } + verify(exactly = 0) { rpcSpy.call.cancel(any(), any()) } + } + + @Test + fun `Excessive messages are buffered without requesting new ones`() { + + val rpcSpy = RpcSpy( + ClientInterceptors.intercept(nonDirectGrpcServerRule.channel, excessiveInboundMessageInterceptor) + ) + val stub = rpcSpy.stub + + + setupUpServerHandler { request, responseObserver -> + with(responseObserver) { + repeat(20) { + onNext( + HelloReply.newBuilder() + .setMessage("Request#$it:${request.name}") + .build() + ) + } + onCompleted() + } + } + + val responseChannel = stub + .clientCallServerStreaming(expectedRequest, methodDescriptor) + + val consumedMessages = mutableListOf() + val result = runTest { + delay(300) + repeat(4) { + verify(exactly = it + 2) { rpcSpy.call.request(1) } + consumedMessages += responseChannel.receive().message + delay(10) + } + + // Consume remaining messages + responseChannel.toList() + } + + assertEquals("Request#0:${expectedRequest.name}", consumedMessages[0]) + assertEquals("Request#0:${expectedRequest.name}", consumedMessages[1]) + assertEquals("Request#0:${expectedRequest.name}", consumedMessages[2]) + assertEquals("Request#1:${expectedRequest.name}", consumedMessages[3]) + assertEquals(56, result.size) + assert(responseChannel.isClosedForReceive) { "Response channel is closed after server error" } + verify(exactly = 0) { rpcSpy.call.cancel(any(), any()) } + } + } \ No newline at end of file diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ClientStreamingBackPressureTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ClientStreamingBackPressureTests.kt index c544106..f36e756 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ClientStreamingBackPressureTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ClientStreamingBackPressureTests.kt @@ -16,21 +16,25 @@ package com.github.marcoferrer.krotoplus.coroutines.integration +import com.github.marcoferrer.krotoplus.coroutines.RpcCallTest import com.github.marcoferrer.krotoplus.coroutines.client.clientCallClientStreaming import com.github.marcoferrer.krotoplus.coroutines.utils.assertExEquals import com.github.marcoferrer.krotoplus.coroutines.utils.assertFails -import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus +import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus2 +import com.github.marcoferrer.krotoplus.coroutines.utils.invoke import com.github.marcoferrer.krotoplus.coroutines.utils.matchThrowable import com.github.marcoferrer.krotoplus.coroutines.withCoroutineContext import io.grpc.CallOptions +import io.grpc.Channel import io.grpc.ClientCall +import io.grpc.ClientInterceptor +import io.grpc.MethodDescriptor +import io.grpc.ServerInterceptors import io.grpc.Status import io.grpc.examples.helloworld.GreeterCoroutineGrpc import io.grpc.examples.helloworld.GreeterGrpc import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest -import io.grpc.testing.GrpcServerRule -import io.mockk.every import io.mockk.spyk import io.mockk.verify import kotlinx.coroutines.CancellationException @@ -46,7 +50,6 @@ import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking -import org.junit.Rule import org.junit.Test import java.util.concurrent.atomic.AtomicInteger import kotlin.coroutines.coroutineContext @@ -54,27 +57,20 @@ import kotlin.test.assertEquals import kotlin.test.assertFalse -class ClientStreamingBackPressureTests { +class ClientStreamingBackPressureTests : + RpcCallTest(GreeterCoroutineGrpc.sayHelloClientStreamingMethod) { - @[Rule JvmField] - var grpcServerRule = GrpcServerRule().directExecutor() - - private val methodDescriptor = GreeterGrpc.getSayHelloClientStreamingMethod() - - inner class RpcSpy{ - val stub: GreeterGrpc.GreeterStub - lateinit var call: ClientCall - - init { - val channelSpy = spyk(grpcServerRule.channel) - stub = GreeterGrpc.newStub(channelSpy) - - every { channelSpy.newCall(methodDescriptor, any()) } answers { - spyk(grpcServerRule.channel.newCall(methodDescriptor, secondArg())).also { - this@RpcSpy.call = it - } - } + private fun setupUpServerHandler( + block: suspend (requestChannel: ReceiveChannel) -> HelloReply + ) { + val serviceImpl = object : GreeterCoroutineGrpc.GreeterImplBase() { + override suspend fun sayHelloClientStreaming(requestChannel: ReceiveChannel): HelloReply = + block(requestChannel) } + + val service = ServerInterceptors.intercept(serviceImpl, callState) + grpcServerRule.serviceRegistry.addService(service) + nonDirectGrpcServerRule.serviceRegistry.addService(service) } // TODO(marco) @@ -86,13 +82,12 @@ class ClientStreamingBackPressureTests { @Test fun `Client send suspends until server invokes receive`() { val deferredServerChannel = CompletableDeferred>() - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase(){ - override suspend fun sayHelloClientStreaming(requestChannel: ReceiveChannel): HelloReply { - deferredServerChannel.complete(spyk(requestChannel)) - delay(Long.MAX_VALUE) - return HelloReply.getDefaultInstance() - } - }) + + setupUpServerHandler { requestChannel -> + deferredServerChannel.complete(spyk(requestChannel)) + delay(Long.MAX_VALUE) + HelloReply.getDefaultInstance() + } val rpcSpy = RpcSpy() val stub = rpcSpy.stub @@ -101,7 +96,7 @@ class ClientStreamingBackPressureTests { assertFails { runBlocking { - val (clientRequestChannel, response) = stub + val (clientRequestChannel, _) = stub .withCoroutineContext(coroutineContext + Dispatchers.Default) .clientCallClientStreaming(methodDescriptor) @@ -134,17 +129,15 @@ class ClientStreamingBackPressureTests { fun `Call completed successfully`() { val deferredServerChannel = CompletableDeferred>() val serverJob = CompletableDeferred() - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase(){ - override suspend fun sayHelloClientStreaming(requestChannel: ReceiveChannel): HelloReply { - val job = coroutineContext[Job]!! - job.invokeOnCompletion { serverJob.complete(job) } - deferredServerChannel.complete(spyk(requestChannel)) - val reqValues = requestChannel.consumeAsFlow().map { it.name }.toList() - return HelloReply.newBuilder() - .setMessage(reqValues.joinToString()) - .build() - } - }) + setupUpServerHandler { requestChannel -> + val job = coroutineContext[Job]!! + job.invokeOnCompletion { serverJob.complete(job) } + deferredServerChannel.complete(spyk(requestChannel)) + val reqValues = requestChannel.consumeAsFlow().map { it.name }.toList() + HelloReply.newBuilder() + .setMessage(reqValues.joinToString()) + .build() + } val rpcSpy = RpcSpy() val stub = rpcSpy.stub @@ -175,17 +168,15 @@ class ClientStreamingBackPressureTests { fun `Call is cancelled when client closes request channel`() { val deferredServerChannel = CompletableDeferred>() val serverJob = CompletableDeferred() - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase(){ - override suspend fun sayHelloClientStreaming(requestChannel: ReceiveChannel): HelloReply { - val job = coroutineContext[Job]!! - job.invokeOnCompletion { serverJob.complete(job) } - deferredServerChannel.complete(spyk(requestChannel)) - delay(Long.MAX_VALUE) - return HelloReply.getDefaultInstance() - } - }) + setupUpServerHandler { requestChannel -> + val job = coroutineContext[Job]!! + job.invokeOnCompletion { serverJob.complete(job) } + deferredServerChannel.complete(spyk(requestChannel)) + delay(Long.MAX_VALUE) + HelloReply.getDefaultInstance() + } - val rpcSpy = RpcSpy() + val rpcSpy = RpcSpy(nonDirectGrpcServerRule.channel) val stub = rpcSpy.stub val expectedCancelMessage = "Cancelled by client with StreamObserver.onError()" val expectedException = IllegalStateException("test") @@ -193,7 +184,7 @@ class ClientStreamingBackPressureTests { val (requestChannel, response) = stub .clientCallClientStreaming(methodDescriptor) - runBlocking(Dispatchers.Default) { + runTest { requestChannel.send( HelloRequest.newBuilder() .setName(0.toString()) @@ -201,11 +192,16 @@ class ClientStreamingBackPressureTests { ) requestChannel.close(expectedException) - assertFailsWithStatus(Status.CANCELLED){ - println(response.await()) + assertFailsWithStatus2(Status.CANCELLED){ + response.await() } } + callState { + blockUntilCancellation() + client.closed.assertBlocking { "Client must be closed" } + } + verify(exactly = 1) { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assertExEquals(expectedException, response.getCompletionExceptionOrNull()?.cause) diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ServerStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ServerStreamingTests.kt new file mode 100644 index 0000000..0033f05 --- /dev/null +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ServerStreamingTests.kt @@ -0,0 +1,112 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines.integration + +import com.github.marcoferrer.krotoplus.coroutines.RpcCallTest +import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus2 +import io.grpc.ServerInterceptors +import io.grpc.Status +import io.grpc.examples.helloworld.GreeterCoroutineGrpc +import io.grpc.examples.helloworld.HelloReply +import io.grpc.examples.helloworld.HelloRequest +import io.mockk.verify +import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.channels.toList +import java.util.concurrent.Phaser +import kotlin.test.Test +import kotlin.test.assertEquals + + +class ServerStreamingTests : + RpcCallTest(GreeterCoroutineGrpc.sayHelloServerStreamingMethod){ + + private fun setupUpServerHandler( + block: suspend (request: HelloRequest, responseChannel: SendChannel) -> Unit + ) { + val serviceImpl = object : GreeterCoroutineGrpc.GreeterImplBase() { + override suspend fun sayHelloServerStreaming( + request: HelloRequest, responseChannel: SendChannel) = block(request, responseChannel) + } + + val service = ServerInterceptors.intercept(serviceImpl, callState) + grpcServerRule.serviceRegistry.addService(service) + nonDirectGrpcServerRule.serviceRegistry.addService(service) + } + + @Test + fun `Call is successful`(){ + val rpcSpy = RpcSpy() + setupUpServerHandler { request, responseChannel -> + request.name.map { char -> + responseChannel.send(HelloReply.newBuilder() + .setMessage("response:$char") + .build() + ) + } + } + val responseChannel = rpcSpy.coStub.sayHelloServerStreaming(expectedRequest) + val result = runTest { responseChannel.toList() } + + callState.blockUntilClosed() + + assert(responseChannel.isClosedForReceive){ "Response channel should be closed" } + assertEquals(expectedRequest.name.length, result.size) + result.forEachIndexed { index, response -> + assertEquals("response:${expectedRequest.name[index]}", response.message) + } + verify(exactly = 0) { rpcSpy.call.cancel(any(), any()) } + } + + + @Test + fun `Server responds with error`(){ + val rpcSpy = RpcSpy() + val phaser = Phaser(2) + setupUpServerHandler { request, responseChannel -> + repeat(3){ + responseChannel.send(HelloReply.newBuilder() + .setMessage("response:$it") + .build() + ) + } + phaser.arriveAndAwaitAdvance() + responseChannel.close(Status.INVALID_ARGUMENT.asRuntimeException()) + } + + val responseChannel = rpcSpy.coStub.sayHelloServerStreaming(expectedRequest) + val result = mutableListOf() + runTest { + repeat(3){ + result += responseChannel.receive() + } + phaser.arrive() + assertFailsWithStatus2(Status.INVALID_ARGUMENT) { + responseChannel.receive() + } + } + + callState.blockUntilClosed() + + assert(responseChannel.isClosedForReceive){ "Response channel should be closed" } + assertEquals(3, result.size) + result.forEachIndexed { index, response -> + assertEquals("response:$index", response.message) + } + verify(exactly = 0) { rpcSpy.call.cancel(any(), any()) } + } + +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/Assertions.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/Assertions.kt index d7cbe64..7694db0 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/Assertions.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/Assertions.kt @@ -21,6 +21,23 @@ import io.grpc.StatusRuntimeException import kotlin.test.assertEquals import kotlin.test.fail +inline fun assertFailsWithStatus2( + status: Status, + message: String? = null, + block: () -> Unit +){ + try{ + block() + fail("Block did not fail") + }catch (e: Throwable){ + println("assertFailsWithStatus(${e.javaClass}, message: ${e.message})") + assertEquals(StatusRuntimeException::class.java.canonicalName, e.javaClass.canonicalName) + require(e is StatusRuntimeException) + message?.let { assertEquals(it,e.message) } + assertEquals(status.code, e.status.code) + } +} + inline fun assertFailsWithStatus( status: Status, message: String? = null, @@ -30,6 +47,11 @@ inline fun assertFailsWithStatus( block() fail("Block did not fail") }catch (e: StatusRuntimeException){ +// TODO: Fix this in separate PR +// }catch (e: Throwable){ +// assertEquals(StatusRuntimeException::class.java.canonicalName, e.javaClass.canonicalName) +// require(e is StatusRuntimeException) +// println("assertFailsWithStatus(${e.javaClass}, message: ${e.message})") message?.let { assertEquals(it,e.message) } assertEquals(status.code, e.status.code) } diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/CallUtils.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/CallUtils.kt index 984fbcf..ca67b98 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/CallUtils.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/CallUtils.kt @@ -16,12 +16,25 @@ package com.github.marcoferrer.krotoplus.coroutines.utils +import com.github.marcoferrer.krotoplus.coroutines.CALL_OPTION_COROUTINE_CONTEXT import io.grpc.CallOptions import io.grpc.Channel import io.grpc.ClientCall import io.grpc.ClientInterceptor -import io.grpc.ForwardingClientCall +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall +import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener +import io.grpc.Metadata import io.grpc.MethodDescriptor +import io.grpc.ServerCall +import io.grpc.ServerCallHandler +import io.grpc.ServerInterceptor +import io.grpc.Status +import io.mockk.spyk +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Job object CancellingClientInterceptor : ClientInterceptor { override fun interceptCall( @@ -30,7 +43,7 @@ object CancellingClientInterceptor : ClientInterceptor { next: Channel ): ClientCall { val _call = next.newCall(method,callOptions) - return object : ForwardingClientCall.SimpleForwardingClientCall(_call){ + return object : SimpleForwardingClientCall(_call){ override fun halfClose() { super.halfClose() // Cancel call after we've verified @@ -40,4 +53,196 @@ object CancellingClientInterceptor : ClientInterceptor { } } -} \ No newline at end of file +} + + +class ClientState( + val intercepted: CompletableDeferred = CompletableDeferred(), + val started: CompletableDeferred = CompletableDeferred(), + val halfClosed: CompletableDeferred = CompletableDeferred(), + val closed: CompletableDeferred = CompletableDeferred(), + val cancelled: CompletableDeferred = CompletableDeferred() +) : Invokable { + + override fun toString(): String { + return "\tClientState(\n" + + "\t\tintercepted=${intercepted.stateToString()}, \n" + + "\t\tstarted=${started.stateToString()},\n" + + "\t\thalfClosed=${halfClosed.stateToString()},\n" + + "\t\tclosed=${closed.stateToString()},\n" + + "\t\tcancelled=${cancelled.stateToString()}\n" + + "\t)" + } +} + +class ServerState( + val intercepted: CompletableDeferred = CompletableDeferred(), + val wasReady: CompletableDeferred = CompletableDeferred(), + val halfClosed: CompletableDeferred = CompletableDeferred(), + val closed: CompletableDeferred = CompletableDeferred(), + val cancelled: CompletableDeferred = CompletableDeferred(), + val completed: CompletableDeferred = CompletableDeferred() +) : Invokable { + override fun toString(): String { + return "\tServerState(\n" + + "\t\tintercepted=${intercepted.stateToString()},\n" + + "\t\twasReady=${wasReady.stateToString()},\n" + + "\t\thalfClosed=${halfClosed.stateToString()},\n" + + "\t\tclosed=${closed.stateToString()},\n" + + "\t\tcancelled=${cancelled.stateToString()},\n" + + "\t\tcompleted=${completed.stateToString()}\n" + + "\t)" + } +} + +class ClientCallSpyInterceptor( + val call: CompletableDeferred> +) : ClientInterceptor { + + override fun interceptCall( + method: MethodDescriptor, + callOptions: CallOptions, + next: Channel + ): ClientCall { + val spy = spyk(next.newCall(method,callOptions)) + call.complete(spy) + return spy + } +} + +class RpcStateInterceptor( + val client: ClientState = ClientState(), + val server: ServerState = ServerState() +) : Invokable, + ClientInterceptor by ClientStateInterceptor(client), + ServerInterceptor by ServerStateInterceptor(server) { + + override fun toString(): String { + return "RpcStateInterceptor(\n" + + "$client,\n" + + "$server\n" + + ")" + } +} + + +interface Invokable + +inline operator fun > T.invoke(block: T.()->Unit) = block() + + +fun CompletableDeferred.stateToString(): String = + "\tisCompleted:$isCompleted,\tisActive:$isActive,\tisCancelled:$isCancelled" + +class ClientStateInterceptor(val state: ClientState) : ClientInterceptor { + + override fun interceptCall( + method: MethodDescriptor, + callOptions: CallOptions, next: Channel + ): ClientCall { + return object : SimpleForwardingClientCall(next.newCall(method,callOptions)) { + + init { + state.intercepted.complete() + } + + override fun start(responseListener: Listener, headers: Metadata) { + println("Client: Call start()") + super.start(object : SimpleForwardingClientCallListener(responseListener){ + + override fun onClose(status: Status?, trailers: Metadata?) { + println("Client: Call Listener onClose(${status?.toDebugString()})") + super.onClose(status, trailers) + state.closed.complete() + } + + }, headers) + state.started.complete(Unit) + } + + override fun halfClose() { + println("Client: Call halfClose()") + super.halfClose() + state.halfClosed.complete() + } + + override fun cancel(message: String?, cause: Throwable?) { + println("Client: Call cancel(message=$message, cause=${cause?.toDebugString()})") + super.cancel(message, cause) + state.cancelled.complete() + } + } + } +} + +class ServerStateInterceptor(val state: ServerState) : ServerInterceptor { + + override fun interceptCall( + call: ServerCall, headers: Metadata, + next: ServerCallHandler + ): ServerCall.Listener { + + val interceptedCall = object : SimpleForwardingServerCall(call){ + + override fun close(status: Status?, trailers: Metadata?) { + println("Server: Call Close, ${status?.toDebugString()}") + super.close(status, trailers) + state.closed.complete() + } + } + + return object: SimpleForwardingServerCallListener(next.startCall(interceptedCall, headers)){ + init { + state.intercepted.complete() + } + + override fun onReady() { + println("Server: Call Listener onReady()") + super.onReady() + state.wasReady.complete() + } + + override fun onHalfClose() { + println("Server: Call Listener onHalfClose()") + super.onHalfClose() + state.halfClosed.complete() + } + + override fun onComplete() { + println("Server: Call Listener onComplete()") + super.onComplete() + state.completed.complete() + } + + override fun onCancel() { + println("Server: Call Listener onCancel()") + super.onCancel() + state.cancelled.complete() + } + } + } +} + +private fun Throwable.toDebugString(): String = + "(${this.javaClass.canonicalName}, ${this.message})" + +private fun Status.toDebugString(): String = + "Status{code=$code, description=$description, cause=${cause?.toDebugString()}}" + + +private fun CompletableDeferred.complete() = complete(Unit) + +fun newCancellingInterceptor(useNormalCancellation: Boolean) = object : ClientInterceptor { + override fun interceptCall( + method: MethodDescriptor, + callOptions: CallOptions, + next: Channel + ): ClientCall { + val job = callOptions.getOption(CALL_OPTION_COROUTINE_CONTEXT)[Job]!! + if (useNormalCancellation) + job.cancel() else + job.cancel(CancellationException("interceptor-cancel")) + return next.newCall(method, callOptions) + } +} +