Skip to content

Commit

Permalink
Fix race in client cancellation propagation (#98)
Browse files Browse the repository at this point in the history
* fix race in client cancellation propagation

* remove cancellation message assertion

* abstract outbound channel completion handler logic
  • Loading branch information
marcoferrer authored Dec 26, 2019
1 parent 8c62ab8 commit 3a75d61
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 28 deletions.
2 changes: 1 addition & 1 deletion kroto-plus-coroutines/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ protobuf {
}

jacoco {
toolVersion = "0.8.3"
toolVersion = "0.8.5"
}

jacocoTestReport {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@ import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ProducerScope
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException

/**
* Launch a [Job] within a [ProducerScope] using the supplied channel as the Receiver.
Expand Down Expand Up @@ -54,4 +58,14 @@ public fun <T> CoroutineScope.launchProducerJob(
channel.close(it?.toRpcException())
}
}
}
}

internal suspend fun Channel<*>.awaitCloseOrThrow(){
suspendCancellableCoroutine<Unit> { cont ->
invokeOnClose { error ->
if(error == null)
cont.resume(Unit) else
cont.resumeWithException(error)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@

package com.github.marcoferrer.krotoplus.coroutines.call

import com.github.marcoferrer.krotoplus.coroutines.awaitCloseOrThrow
import io.grpc.stub.CallStreamObserver
import io.grpc.stub.ClientCallStreamObserver
import kotlinx.coroutines.CoroutineExceptionHandler
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.channels.ActorScope
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.actor
import kotlinx.coroutines.launch
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger

Expand All @@ -48,6 +50,26 @@ internal fun <T> CallStreamObserver<*>.applyInboundFlowControl(

internal typealias MessageHandler = suspend ActorScope<*>.() -> Unit

internal inline fun <T> CoroutineScope.attachOutboundChannelCompletionHandler(
streamObserver: CallStreamObserver<T>,
targetChannel: Channel<T>,
crossinline onSuccess: () -> Unit = {},
crossinline onError: (Throwable) -> Unit = {}
){
launch(start = CoroutineStart.UNDISPATCHED) {
val job = coroutineContext[Job]!!
try {
targetChannel.awaitCloseOrThrow()
onSuccess()
} catch (error: Throwable) {
if(!job.isCancelled){
streamObserver.completeSafely(error, convertError = streamObserver !is ClientCallStreamObserver)
}
onError(error)
}
}
}

internal fun <T> CoroutineScope.applyOutboundFlowControl(
streamObserver: CallStreamObserver<T>,
targetChannel: Channel<T>
Expand Down Expand Up @@ -75,8 +97,6 @@ internal fun <T> CoroutineScope.applyOutboundFlowControl(
// the only way we can do this in the current implementation.
streamObserver.completeSafely(e, convertError = streamObserver !is ClientCallStreamObserver)
isCompleted.set(true)
} else {
throw e
}
}
if (targetChannel.isClosedForReceive &&
Expand All @@ -90,25 +110,21 @@ internal fun <T> CoroutineScope.applyOutboundFlowControl(

val messageHandlerActor = actor<MessageHandler>(
capacity = Channel.BUFFERED,
context = Dispatchers.Unconfined + CoroutineExceptionHandler { _, e ->
streamObserver.completeSafely(e)
targetChannel.close(e)
}
context = Dispatchers.Unconfined
) {

for (handler in channel) {
if (isCompleted.get()) break
handler(this)
}
if (!isCompleted.get()) {
streamObserver.completeSafely()
try {
for (handler in channel) {
if (isCompleted.get()) break
handler(this)
}
if (!isCompleted.get()) {
streamObserver.completeSafely()
}
}catch (e: Throwable){
channel.cancel()
}
}

targetChannel.invokeOnClose {
messageHandlerActor.close()
}

streamObserver.setOnReadyHandler {
try {
if (!messageHandlerActor.isClosedForSend) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package com.github.marcoferrer.krotoplus.coroutines.client
import com.github.marcoferrer.krotoplus.coroutines.call.FlowControlledInboundStreamObserver
import com.github.marcoferrer.krotoplus.coroutines.call.MessageHandler
import com.github.marcoferrer.krotoplus.coroutines.call.applyOutboundFlowControl
import com.github.marcoferrer.krotoplus.coroutines.call.attachOutboundChannelCompletionHandler
import io.grpc.Status
import io.grpc.stub.ClientCallStreamObserver
import io.grpc.stub.ClientResponseObserver
import kotlinx.coroutines.CancellationException
Expand Down Expand Up @@ -96,6 +98,8 @@ internal class ClientBidiCallChannelImpl<ReqT,RespT>(

override val isInboundCompleted = AtomicBoolean()

private var aborted: Boolean = false

override val transientInboundMessageCount: AtomicInteger = AtomicInteger()

override lateinit var callStreamObserver: ClientCallStreamObserver<ReqT>
Expand All @@ -106,18 +110,28 @@ internal class ClientBidiCallChannelImpl<ReqT,RespT>(
callStreamObserver = requestStream.apply { disableAutoInboundFlowControl() }
outboundMessageHandler = applyOutboundFlowControl(requestStream,outboundChannel)

attachOutboundChannelCompletionHandler(
callStreamObserver, outboundChannel,
onSuccess = { outboundMessageHandler.close() },
onError = { error -> inboundChannel.close(error) }
)

inboundChannel.invokeOnClose {
// If the client prematurely closes the response channel
// we need to propagate this as a cancellation to the underlying call
if(!outboundChannel.isClosedForSend && coroutineContext[Job]?.isCancelled == false){
callStreamObserver.cancel("Client has cancelled call", it)
if(!outboundChannel.isClosedForSend && coroutineContext[Job]?.isCancelled == false && !aborted){
outboundChannel.close(Status.CANCELLED
.withDescription("Client has cancelled call")
.withCause(it)
.asRuntimeException())
}
}
}

override fun onNext(value: RespT): Unit = onNextWithBackPressure(value)

override fun onError(t: Throwable) {
aborted = true
outboundChannel.close(t)
outboundChannel.cancel(CancellationException(t.message,t))
inboundChannel.close(t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.github.marcoferrer.krotoplus.coroutines.client

import com.github.marcoferrer.krotoplus.coroutines.call.MessageHandler
import com.github.marcoferrer.krotoplus.coroutines.call.applyOutboundFlowControl
import com.github.marcoferrer.krotoplus.coroutines.call.attachOutboundChannelCompletionHandler
import io.grpc.stub.ClientCallStreamObserver
import io.grpc.stub.ClientResponseObserver
import kotlinx.coroutines.CancellationException
Expand Down Expand Up @@ -71,6 +72,11 @@ internal class ClientStreamingCallChannelImpl<ReqT,RespT>(
callStreamObserver = requestStream
outboundMessageHandler = applyOutboundFlowControl(requestStream, outboundChannel)

attachOutboundChannelCompletionHandler(
callStreamObserver, outboundChannel,
onSuccess = { outboundMessageHandler.close() },
onError = { error -> completableResponse.completeExceptionally(error) }
)
completableResponse.invokeOnCompletion {
// If the client prematurely cancels the response
// we need to propagate this as a cancellation to the underlying call
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.github.marcoferrer.krotoplus.coroutines.server

import com.github.marcoferrer.krotoplus.coroutines.call.applyInboundFlowControl
import com.github.marcoferrer.krotoplus.coroutines.call.applyOutboundFlowControl
import com.github.marcoferrer.krotoplus.coroutines.call.attachOutboundChannelCompletionHandler
import com.github.marcoferrer.krotoplus.coroutines.call.bindToClientCancellation
import com.github.marcoferrer.krotoplus.coroutines.call.completeSafely
import com.github.marcoferrer.krotoplus.coroutines.call.newRpcScope
Expand Down Expand Up @@ -67,6 +68,12 @@ public fun <ReqT, RespT> ServiceScope.serverCallServerStreaming(
with(newRpcScope(initialContext, methodDescriptor)) {
bindToClientCancellation(serverCallObserver)
val outboundMessageHandler = applyOutboundFlowControl(serverCallObserver,responseChannel)

attachOutboundChannelCompletionHandler(
serverCallObserver, responseChannel,
onSuccess = { outboundMessageHandler.close() }
)

launch(start = CoroutineStart.ATOMIC) {
try{
block(responseChannel)
Expand Down Expand Up @@ -165,6 +172,11 @@ public fun <ReqT, RespT> ServiceScope.serverCallBidiStreaming(
}
)

attachOutboundChannelCompletionHandler(
serverCallObserver, responseChannel,
onSuccess = { outboundMessageHandler.close() }
)

launch(start = CoroutineStart.ATOMIC) {
serverCallObserver.request(1)
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ class ClientCallBidiStreamingTests {
}
}

verify(exactly = 0) { rpcSpy.call.cancel(any(), any()) }
assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" }
assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" }
}
Expand Down Expand Up @@ -401,7 +400,7 @@ class ClientCallBidiStreamingTests {
}

@Test
fun `Call is cancelled when request channel closed with error`() {
fun `Call is cancelled when request channel closed with error concurrently`() {
val rpcSpy = RpcSpy()
val stub = rpcSpy.stub
val expectedCancelMessage = "Cancelled by client with StreamObserver.onError()"
Expand Down Expand Up @@ -437,7 +436,43 @@ class ClientCallBidiStreamingTests {
result.forEachIndexed { index, message ->
assertEquals("Req:#$index/Resp:#$index",message)
}
verify(exactly = 1) { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) }
verify { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) }
assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" }
assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" }
}

@Test
fun `Call is cancelled when request channel closed with error sequentially`() {
val rpcSpy = RpcSpy()
val stub = rpcSpy.stub
val expectedCancelMessage = "Cancelled by client with StreamObserver.onError()"
val expectedException = IllegalStateException("test")

setupServerHandlerSuccess()
val (requestChannel, responseChannel) = stub
.clientCallBidiStreaming(methodDescriptor)

val result = mutableListOf<String>()
runBlocking(Dispatchers.Default) {
requestChannel.send(
HelloRequest.newBuilder()
.setName(0.toString())
.build()
)
requestChannel.close(expectedException)

assertFailsWithStatus(Status.CANCELLED,"CANCELLED: $expectedCancelMessage"){
responseChannel.consumeAsFlow()
.collect { result.add(it.message) }
}
}


assertEquals(1, result.size)
result.forEachIndexed { index, message ->
assertEquals("Req:#$index/Resp:#$index",message)
}
verify { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) }
assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" }
assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" }
}
Expand Down Expand Up @@ -472,7 +507,7 @@ class ClientCallBidiStreamingTests {
responseChannel.cancel()
}

verify(exactly = 1) { rpcSpy.call.cancel("Client has cancelled call",any()) }
verify { rpcSpy.call.cancel("Cancelled by client with StreamObserver.onError()",any()) }
assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" }
assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ class ClientCallClientStreamingTests {
}

@Test
fun `Call is cancelled when request channel closed with error`() {
fun `Call is cancelled when request channel closed with error concurrently`() {
val rpcSpy = RpcSpy()
val stub = rpcSpy.stub
val expectedCancelMessage = "Cancelled by client with StreamObserver.onError()"
Expand Down Expand Up @@ -363,7 +363,38 @@ class ClientCallClientStreamingTests {
}
}

verify(exactly = 1) { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) }
verify { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) }
assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" }
assertExEquals(expectedException, response.getCompletionExceptionOrNull()?.cause)
assert(response.isCancelled) { "Response should not be cancelled" }
}


@Test
fun `Call is cancelled when request channel closed with error sequentially`() {
val rpcSpy = RpcSpy()
val stub = rpcSpy.stub
val expectedCancelMessage = "Cancelled by client with StreamObserver.onError()"
val expectedException = IllegalStateException("test")

setupServerHandlerSuccess()
val (requestChannel, response) = stub
.clientCallClientStreaming(methodDescriptor)

runBlocking(Dispatchers.Default) {
requestChannel.send(
HelloRequest.newBuilder()
.setName(0.toString())
.build()
)
requestChannel.close(expectedException)

assertFailsWithStatus(Status.CANCELLED,"CANCELLED: $expectedCancelMessage"){
response.await()
}
}

verify { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) }
assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" }
assertExEquals(expectedException, response.getCompletionExceptionOrNull()?.cause)
assert(response.isCancelled) { "Response should not be cancelled" }
Expand Down
Loading

0 comments on commit 3a75d61

Please sign in to comment.