Skip to content

Commit

Permalink
Run suspending calls within CoroutineDispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
jingibus committed Feb 11, 2022
1 parent 4860743 commit 4bb9250
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 106 deletions.
10 changes: 3 additions & 7 deletions retrofit-mock/src/main/java/retrofit2/mock/BehaviorDelegate.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,9 @@ public <R> T returning(Call<R> call) {

Call<Object> adaptedCall = (Call<Object>) adapted;
Continuation<Object> continuation = (Continuation<Object>) args[args.length - 1];
try {
return adapterInfo.wantsResponse
? KotlinExtensions.awaitResponse(adaptedCall, continuation)
: KotlinExtensions.await(adaptedCall, continuation);
} catch (Exception e) {
return KotlinExtensions.suspendAndThrow(e, continuation);
}
return adapterInfo.wantsResponse
? KotlinExtensions.awaitResponse(adaptedCall, continuation)
: KotlinExtensions.await(adaptedCall, continuation);
});
}

Expand Down
42 changes: 17 additions & 25 deletions retrofit/src/main/java/retrofit2/HttpServiceMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.lang.reflect.Type;
import javax.annotation.Nullable;
import kotlin.coroutines.Continuation;
import kotlinx.coroutines.CoroutineDispatcher;
import okhttp3.ResponseBody;

/** Adapts an invocation of an interface method into an HTTP call. */
Expand Down Expand Up @@ -94,7 +95,8 @@ static <ResponseT, ReturnT> HttpServiceMethod<ResponseT, ReturnT> parseAnnotatio
requestFactory,
callFactory,
responseConverter,
(CallAdapter<ResponseT, Call<ResponseT>>) callAdapter);
(CallAdapter<ResponseT, Call<ResponseT>>) callAdapter,
retrofit.coroutineDispatcher);
} else {
//noinspection unchecked Kotlin compiler guarantees ReturnT to be Object.
return (HttpServiceMethod<ResponseT, ReturnT>)
Expand All @@ -103,7 +105,8 @@ static <ResponseT, ReturnT> HttpServiceMethod<ResponseT, ReturnT> parseAnnotatio
callFactory,
responseConverter,
(CallAdapter<ResponseT, Call<ResponseT>>) callAdapter,
continuationBodyNullable);
continuationBodyNullable,
retrofit.coroutineDispatcher);
}
}

Expand Down Expand Up @@ -168,14 +171,17 @@ protected ReturnT adapt(Call<ResponseT> call, Object[] args) {

static final class SuspendForResponse<ResponseT> extends HttpServiceMethod<ResponseT, Object> {
private final CallAdapter<ResponseT, Call<ResponseT>> callAdapter;
private final @Nullable CoroutineDispatcher coroutineDispatcher;

SuspendForResponse(
RequestFactory requestFactory,
okhttp3.Call.Factory callFactory,
Converter<ResponseBody, ResponseT> responseConverter,
CallAdapter<ResponseT, Call<ResponseT>> callAdapter) {
CallAdapter<ResponseT, Call<ResponseT>> callAdapter,
@Nullable CoroutineDispatcher coroutineDispatcher) {
super(requestFactory, callFactory, responseConverter);
this.callAdapter = callAdapter;
this.coroutineDispatcher = coroutineDispatcher;
}

@Override
Expand All @@ -186,28 +192,26 @@ protected Object adapt(Call<ResponseT> call, Object[] args) {
Continuation<Response<ResponseT>> continuation =
(Continuation<Response<ResponseT>>) args[args.length - 1];

// See SuspendForBody for explanation about this try/catch.
try {
return KotlinExtensions.awaitResponse(call, continuation);
} catch (Exception e) {
return KotlinExtensions.suspendAndThrow(e, continuation);
}
return KotlinExtensions.awaitResponse(call, this.coroutineDispatcher, continuation);
}
}

static final class SuspendForBody<ResponseT> extends HttpServiceMethod<ResponseT, Object> {
private final CallAdapter<ResponseT, Call<ResponseT>> callAdapter;
private final boolean isNullable;
private final @Nullable CoroutineDispatcher coroutineDispatcher;

SuspendForBody(
RequestFactory requestFactory,
okhttp3.Call.Factory callFactory,
Converter<ResponseBody, ResponseT> responseConverter,
CallAdapter<ResponseT, Call<ResponseT>> callAdapter,
boolean isNullable) {
boolean isNullable,
@Nullable CoroutineDispatcher coroutineDispatcher) {
super(requestFactory, callFactory, responseConverter);
this.callAdapter = callAdapter;
this.isNullable = isNullable;
this.coroutineDispatcher = coroutineDispatcher;
}

@Override
Expand All @@ -217,21 +221,9 @@ protected Object adapt(Call<ResponseT> call, Object[] args) {
//noinspection unchecked Checked by reflection inside RequestFactory.
Continuation<ResponseT> continuation = (Continuation<ResponseT>) args[args.length - 1];

// Calls to OkHttp Call.enqueue() like those inside await and awaitNullable can sometimes
// invoke the supplied callback with an exception before the invoking stack frame can return.
// Coroutines will intercept the subsequent invocation of the Continuation and throw the
// exception synchronously. A Java Proxy cannot throw checked exceptions without them being
// declared on the interface method. To avoid the synchronous checked exception being wrapped
// in an UndeclaredThrowableException, it is intercepted and supplied to a helper which will
// force suspension to occur so that it can be instead delivered to the continuation to
// bypass this restriction.
try {
return isNullable
? KotlinExtensions.awaitNullable(call, continuation)
: KotlinExtensions.await(call, continuation);
} catch (Exception e) {
return KotlinExtensions.suspendAndThrow(e, continuation);
}
return isNullable
? KotlinExtensions.awaitNullable(call, coroutineDispatcher, continuation)
: KotlinExtensions.await(call, coroutineDispatcher, continuation);
}
}
}
126 changes: 59 additions & 67 deletions retrofit/src/main/java/retrofit2/KotlinExtensions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,95 +25,87 @@ import kotlin.coroutines.intrinsics.intercepted
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.withContext

inline fun <reified T: Any> Retrofit.create(): T = create(T::class.java)

suspend fun <T : Any> Call<T>.await(): T {
return suspendCancellableCoroutine { continuation ->
continuation.invokeOnCancellation {
cancel()
}
enqueue(object : Callback<T> {
override fun onResponse(call: Call<T>, response: Response<T>) {
if (response.isSuccessful) {
val body = response.body()
if (body == null) {
val invocation = call.request().tag(Invocation::class.java)!!
val method = invocation.method()
val e = KotlinNullPointerException("Response from " +
@JvmOverloads
suspend fun <T : Any> Call<T>.await(coroutineDispatcher: CoroutineDispatcher? = null): T {
return withContext(coroutineDispatcher ?: Dispatchers.Default) {
suspendCancellableCoroutine { continuation ->
continuation.invokeOnCancellation {
cancel()
}
enqueue(object : Callback<T> {
override fun onResponse(call: Call<T>, response: Response<T>) {
if (response.isSuccessful) {
val body = response.body()
if (body == null) {
val invocation = call.request().tag(Invocation::class.java)!!
val method = invocation.method()
val e = KotlinNullPointerException("Response from " +
method.declaringClass.name +
'.' +
method.name +
" was null but response body type was declared as non-null")
continuation.resumeWithException(e)
continuation.resumeWithException(e)
} else {
continuation.resume(body)
}
} else {
continuation.resume(body)
continuation.resumeWithException(HttpException(response))
}
} else {
continuation.resumeWithException(HttpException(response))
}
}

override fun onFailure(call: Call<T>, t: Throwable) {
continuation.resumeWithException(t)
}
})
override fun onFailure(call: Call<T>, t: Throwable) {
continuation.resumeWithException(t)
}
})
}
}
}

@JvmName("awaitNullable")
suspend fun <T : Any> Call<T?>.await(): T? {
return suspendCancellableCoroutine { continuation ->
continuation.invokeOnCancellation {
cancel()
}
enqueue(object : Callback<T?> {
override fun onResponse(call: Call<T?>, response: Response<T?>) {
if (response.isSuccessful) {
continuation.resume(response.body())
} else {
continuation.resumeWithException(HttpException(response))
}
suspend fun <T : Any> Call<T?>.await(coroutineDispatcher: CoroutineDispatcher?): T? {
return withContext(coroutineDispatcher ?: Dispatchers.Default) {
suspendCancellableCoroutine { continuation ->
continuation.invokeOnCancellation {
cancel()
}
enqueue(object : Callback<T?> {
override fun onResponse(call: Call<T?>, response: Response<T?>) {
if (response.isSuccessful) {
continuation.resume(response.body())
} else {
continuation.resumeWithException(HttpException(response))
}
}

override fun onFailure(call: Call<T?>, t: Throwable) {
continuation.resumeWithException(t)
}
})
override fun onFailure(call: Call<T?>, t: Throwable) {
continuation.resumeWithException(t)
}
})
}
}
}

suspend fun <T> Call<T>.awaitResponse(): Response<T> {
return suspendCancellableCoroutine { continuation ->
continuation.invokeOnCancellation {
cancel()
}
enqueue(object : Callback<T> {
override fun onResponse(call: Call<T>, response: Response<T>) {
continuation.resume(response)
}

override fun onFailure(call: Call<T>, t: Throwable) {
continuation.resumeWithException(t)
@JvmOverloads
suspend fun <T> Call<T>.awaitResponse(coroutineDispatcher: CoroutineDispatcher? = null): Response<T> {
return withContext(coroutineDispatcher ?: Dispatchers.Default) {
suspendCancellableCoroutine { continuation ->
continuation.invokeOnCancellation {
cancel()
}
})
}
}
enqueue(object : Callback<T> {
override fun onResponse(call: Call<T>, response: Response<T>) {
continuation.resume(response)
}

/**
* Force the calling coroutine to suspend before throwing [this].
*
* This is needed when a checked exception is synchronously caught in a [java.lang.reflect.Proxy]
* invocation to avoid being wrapped in [java.lang.reflect.UndeclaredThrowableException].
*
* The implementation is derived from:
* https://github.com/Kotlin/kotlinx.coroutines/pull/1667#issuecomment-556106349
*/
internal suspend fun Exception.suspendAndThrow(): Nothing {
suspendCoroutineUninterceptedOrReturn<Nothing> { continuation ->
Dispatchers.Default.dispatch(continuation.context) {
continuation.intercepted().resumeWithException(this@suspendAndThrow)
override fun onFailure(call: Call<T>, t: Throwable) {
continuation.resumeWithException(t)
}
})
}
COROUTINE_SUSPENDED
}
}
23 changes: 21 additions & 2 deletions retrofit/src/main/java/retrofit2/Retrofit.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
import kotlinx.coroutines.CoroutineDispatcher;
import kotlinx.coroutines.Dispatchers;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.RequestBody;
Expand Down Expand Up @@ -74,6 +76,7 @@ public final class Retrofit {
final int defaultCallAdapterFactoriesSize;
final @Nullable Executor callbackExecutor;
final boolean validateEagerly;
final @Nullable CoroutineDispatcher coroutineDispatcher;

Retrofit(
okhttp3.Call.Factory callFactory,
Expand All @@ -83,7 +86,8 @@ public final class Retrofit {
List<CallAdapter.Factory> callAdapterFactories,
int defaultCallAdapterFactoriesSize,
@Nullable Executor callbackExecutor,
boolean validateEagerly) {
boolean validateEagerly,
@Nullable CoroutineDispatcher coroutineDispatcher) {
this.callFactory = callFactory;
this.baseUrl = baseUrl;
this.converterFactories = converterFactories; // Copy+unmodifiable at call site.
Expand All @@ -92,6 +96,7 @@ public final class Retrofit {
this.defaultCallAdapterFactoriesSize = defaultCallAdapterFactoriesSize;
this.callbackExecutor = callbackExecutor;
this.validateEagerly = validateEagerly;
this.coroutineDispatcher = coroutineDispatcher;
}

/**
Expand Down Expand Up @@ -437,6 +442,7 @@ public static final class Builder {
private final List<CallAdapter.Factory> callAdapterFactories = new ArrayList<>();
private @Nullable Executor callbackExecutor;
private boolean validateEagerly;
private @Nullable CoroutineDispatcher coroutineDispatcher;

public Builder() {}

Expand Down Expand Up @@ -610,6 +616,18 @@ public Builder validateEagerly(boolean validateEagerly) {
return this;
}

/**
* Suspending method call implementations will be run on this {@link CoroutineDispatcher}, if
* provided. If no dispatcher is provided, calls are run on {@link Dispatchers#getDefault()}.
*
* <p>Network requests are still run on the {@link OkHttpClient}'s {@link okhttp3.Dispatcher},
* but {@link okhttp3.Call.Factory} invocations will be run on this dispatcher.
*/
public Builder coroutineDispatcher(@Nullable CoroutineDispatcher coroutineDispatcher) {
this.coroutineDispatcher = coroutineDispatcher;
return this;
}

/**
* Create the {@link Retrofit} instance using the configured values.
*
Expand Down Expand Up @@ -660,7 +678,8 @@ public Retrofit build() {
unmodifiableList(callAdapterFactories),
defaultCallAdapterFactories.size(),
callbackExecutor,
validateEagerly);
validateEagerly,
coroutineDispatcher);
}
}
}
Loading

0 comments on commit 4bb9250

Please sign in to comment.