Skip to content

Commit

Permalink
feat: Add Token Refresh Implementation to Network Layer (#84)
Browse files Browse the repository at this point in the history
* feat: Add Token Refresh Implementation to Network Layer

It applies token authentication to the network requests before a
network call is made. The expiry duration is saved and it is
applied before a network request is queued.

* refactor: Use milliseconds for expiry time
  • Loading branch information
HamzaIsrar12 authored Nov 22, 2023
1 parent a11faa3 commit d1bf1a6
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@ import android.util.Log
import com.google.gson.Gson
import kotlinx.coroutines.runBlocking
import okhttp3.*
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.ResponseBody.Companion.toResponseBody
import okhttp3.logging.HttpLoggingInterceptor
import org.json.JSONException
import org.json.JSONObject
import org.openedx.app.system.notifier.AppNotifier
import org.openedx.app.system.notifier.LogoutEvent
import org.openedx.auth.data.api.AuthApi
import org.openedx.auth.data.model.AuthResponse
import org.openedx.auth.domain.model.AuthResponse
import org.openedx.core.ApiConstants
import org.openedx.core.ApiConstants.TOKEN_TYPE_JWT
import org.openedx.core.BuildConfig
import org.openedx.core.BuildConfig.ACCESS_TOKEN_TYPE
import org.openedx.core.data.storage.CorePreferences
import org.openedx.core.utils.TimeUtils
import retrofit2.Retrofit
import retrofit2.converter.gson.GsonConverterFactory
import java.io.IOException
Expand All @@ -24,9 +27,20 @@ import java.util.concurrent.TimeUnit
class OauthRefreshTokenAuthenticator(
private val preferencesManager: CorePreferences,
private val appNotifier: AppNotifier,
) : Authenticator {
) : Authenticator, Interceptor {

private val authApi: AuthApi
private var lastTokenRefreshRequestTime = 0L

override fun intercept(chain: Interceptor.Chain): Response {
if (isTokenExpired()) {
val response = createUnauthorizedResponse(chain)
val request = authenticate(chain.connection()?.route(), response)

return request?.let { chain.proceed(it) } ?: chain.proceed(chain.request())
}
return chain.proceed(chain.request())
}

init {
val okHttpClient = OkHttpClient.Builder().apply {
Expand All @@ -44,6 +58,7 @@ class OauthRefreshTokenAuthenticator(
.create(AuthApi::class.java)
}

@Synchronized
override fun authenticate(route: Route?, response: Response): Request? {
val accessToken = preferencesManager.accessToken
val refreshToken = preferencesManager.refreshToken
Expand Down Expand Up @@ -112,26 +127,42 @@ class OauthRefreshTokenAuthenticator(
return null
}

private fun isTokenExpired(): Boolean {
val time = TimeUtils.getCurrentTime() + REFRESH_TOKEN_EXPIRY_THRESHOLD
return time >= preferencesManager.accessTokenExpiresAt
}

private fun canRequestTokenRefresh(): Boolean {
return TimeUtils.getCurrentTime() - lastTokenRefreshRequestTime >
REFRESH_TOKEN_INTERVAL_MINIMUM
}

@Throws(IOException::class)
private fun refreshAccessToken(refreshToken: String): AuthResponse? {
val response = authApi.refreshAccessToken(
ApiConstants.TOKEN_TYPE_REFRESH,
BuildConfig.CLIENT_ID,
refreshToken,
ACCESS_TOKEN_TYPE
).execute()
val authResponse = response.body()
if (response.isSuccessful && authResponse != null) {
val newAccessToken = authResponse.accessToken ?: ""
val newRefreshToken = authResponse.refreshToken ?: ""

if (newAccessToken.isNotEmpty() && newRefreshToken.isNotEmpty()) {
preferencesManager.accessToken = newAccessToken
preferencesManager.refreshToken = newRefreshToken
var authResponse: AuthResponse? = null
if (canRequestTokenRefresh()) {
val response = authApi.refreshAccessToken(
ApiConstants.TOKEN_TYPE_REFRESH,
BuildConfig.CLIENT_ID,
refreshToken,
ACCESS_TOKEN_TYPE
).execute()
authResponse = response.body()?.mapToDomain()
if (response.isSuccessful && authResponse != null) {
val newAccessToken = authResponse.accessToken ?: ""
val newRefreshToken = authResponse.refreshToken ?: ""
val newExpireTime = authResponse.getTokenExpiryTime()

if (newAccessToken.isNotEmpty() && newRefreshToken.isNotEmpty()) {
preferencesManager.accessToken = newAccessToken
preferencesManager.refreshToken = newRefreshToken
preferencesManager.accessTokenExpiresAt = newExpireTime
lastTokenRefreshRequestTime = TimeUtils.getCurrentTime()
}
} else if (response.code() == 400) {
//another refresh already in progress
Thread.sleep(1500)
}
} else if (response.code() == 400) {
//another refresh already in progress
Thread.sleep(1500)
}

return authResponse
Expand All @@ -144,7 +175,8 @@ class OauthRefreshTokenAuthenticator(
return jsonObj.getString(FIELD_ERROR_CODE)
} else {
return if (TOKEN_TYPE_JWT.equals(ACCESS_TOKEN_TYPE, ignoreCase = true)) {
val errorType = if (jsonObj.has(FIELD_DETAIL)) FIELD_DETAIL else FIELD_DEVELOPER_MESSAGE
val errorType =
if (jsonObj.has(FIELD_DETAIL)) FIELD_DETAIL else FIELD_DEVELOPER_MESSAGE
jsonObj.getString(errorType)
} else {
val errorCode = jsonObj
Expand All @@ -163,6 +195,41 @@ class OauthRefreshTokenAuthenticator(
}
}

/**
* [createUnauthorizedResponse] creates an unauthorized okhttp response with the initial chain
* request for [authenticate] method of [OauthRefreshTokenAuthenticator]. The response is
* specially designed to trigger the 'Token Expired' case of the [authenticate] method so that
* it can handle the refresh logic of the access token accordingly.
*
* @param chain Chain request for authentication
* @return Custom unauthorized response builder with initial request
*/
private fun createUnauthorizedResponse(chain: Interceptor.Chain) = Response.Builder()
.code(401)
.request(chain.request())
.protocol(Protocol.HTTP_1_1)
.message("Unauthorized")
.headers(chain.request().headers)
.body(getResponseBody())
.build()

/**
* [getResponseBody] generates an error response body based on access token type because both
* Bearer and JWT have their own sets of errors.
*
* @return ResponseBody based on access token type
*/
private fun getResponseBody(): ResponseBody {
val tokenType = ACCESS_TOKEN_TYPE
val jsonObject = if (TOKEN_TYPE_JWT.equals(tokenType, ignoreCase = true)) {
JSONObject().put("detail", JWT_TOKEN_EXPIRED)
} else {
JSONObject().put("error_code", TOKEN_EXPIRED_ERROR_MESSAGE)
}

return jsonObject.toString().toResponseBody("application/json".toMediaType())
}

companion object {
private const val HEADER_AUTHORIZATION = "Authorization"

Expand All @@ -177,5 +244,19 @@ class OauthRefreshTokenAuthenticator(
private const val FIELD_ERROR_CODE = "error_code"
private const val FIELD_DETAIL = "detail"
private const val FIELD_DEVELOPER_MESSAGE = "developer_message"

/**
* [REFRESH_TOKEN_EXPIRY_THRESHOLD] behave as a buffer time to be used in the expiry
* verification method of the access token to ensure that the token doesn't expire during
* an active session.
*/
private const val REFRESH_TOKEN_EXPIRY_THRESHOLD = 60 * 1000

/**
* [REFRESH_TOKEN_INTERVAL_MINIMUM] behave as a buffer time for refresh token network
* requests. It prevents multiple calls to refresh network requests in case of an
* unauthorized access token during async requests.
*/
private const val REFRESH_TOKEN_INTERVAL_MINIMUM = 60 * 1000
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,36 @@ package org.openedx.app.data.storage
import android.content.Context
import com.google.gson.Gson
import org.openedx.app.BuildConfig
import org.openedx.core.data.storage.CorePreferences
import org.openedx.profile.data.model.Account
import org.openedx.core.data.model.User
import org.openedx.core.data.storage.CorePreferences
import org.openedx.core.data.storage.InAppReviewPreferences
import org.openedx.core.domain.model.VideoSettings
import org.openedx.profile.data.model.Account
import org.openedx.profile.data.storage.ProfilePreferences
import org.openedx.whatsnew.data.storage.WhatsNewPreferences

class PreferencesManager(context: Context) : CorePreferences, ProfilePreferences, WhatsNewPreferences,
InAppReviewPreferences {
class PreferencesManager(context: Context) : CorePreferences, ProfilePreferences,
WhatsNewPreferences, InAppReviewPreferences {

private val sharedPreferences = context.getSharedPreferences(BuildConfig.APPLICATION_ID, Context.MODE_PRIVATE)
private val sharedPreferences =
context.getSharedPreferences(BuildConfig.APPLICATION_ID, Context.MODE_PRIVATE)

private fun saveString(key: String, value: String) {
sharedPreferences.edit().apply {
putString(key, value)
}.apply()
}

private fun getString(key: String): String = sharedPreferences.getString(key, "") ?: ""

private fun saveLong(key: String, value: Long) {
sharedPreferences.edit().apply {
putLong(key, value)
}.apply()
}

private fun getLong(key: String): Long = sharedPreferences.getLong(key, 0L)

private fun saveBoolean(key: String, value: Boolean) {
sharedPreferences.edit().apply {
putBoolean(key, value)
Expand All @@ -36,6 +46,7 @@ class PreferencesManager(context: Context) : CorePreferences, ProfilePreferences
remove(ACCESS_TOKEN)
remove(REFRESH_TOKEN)
remove(USER)
remove(EXPIRES_IN)
}.apply()
}

Expand All @@ -51,6 +62,12 @@ class PreferencesManager(context: Context) : CorePreferences, ProfilePreferences
}
get() = getString(REFRESH_TOKEN)

override var accessTokenExpiresAt: Long
set(value) {
saveLong(EXPIRES_IN, value)
}
get() = getLong(EXPIRES_IN)

override var user: User?
set(value) {
val userJson = Gson().toJson(value)
Expand Down Expand Up @@ -95,7 +112,10 @@ class PreferencesManager(context: Context) : CorePreferences, ProfilePreferences
}
get() {
val versionNameString = getString(LAST_REVIEW_VERSION)
return Gson().fromJson(versionNameString, InAppReviewPreferences.VersionName::class.java)
return Gson().fromJson(
versionNameString,
InAppReviewPreferences.VersionName::class.java
)
?: InAppReviewPreferences.VersionName.default
}

Expand All @@ -109,11 +129,12 @@ class PreferencesManager(context: Context) : CorePreferences, ProfilePreferences
companion object {
private const val ACCESS_TOKEN = "access_token"
private const val REFRESH_TOKEN = "refresh_token"
private const val EXPIRES_IN = "expires_in"
private const val USER = "user"
private const val ACCOUNT = "account"
private const val VIDEO_SETTINGS = "video_settings"
private const val LAST_WHATS_NEW_VERSION = "last_whats_new_version"
private const val LAST_REVIEW_VERSION = "last_review_version"
private const val APP_WAS_POSITIVE_RATED = "app_was_positive_rated"
}
}
}
3 changes: 2 additions & 1 deletion app/src/main/java/org/openedx/app/di/NetworkingModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ val networkingModule = module {
}
addInterceptor(HandleErrorInterceptor(get()))
addInterceptor(AppUpgradeInterceptor(get()))
addInterceptor(get<OauthRefreshTokenAuthenticator>())
authenticator(get<OauthRefreshTokenAuthenticator>())
}.build()
}
Expand All @@ -53,4 +54,4 @@ val networkingModule = module {

inline fun <reified T> provideApi(retrofit: Retrofit): T {
return retrofit.create(T::class.java)
}
}
15 changes: 13 additions & 2 deletions auth/src/main/java/org/openedx/auth/data/model/AuthResponse.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.openedx.auth.data.model

import com.google.gson.annotations.SerializedName
import org.openedx.auth.domain.model.AuthResponse

data class AuthResponse(
@SerializedName("access_token")
Expand All @@ -15,5 +16,15 @@ data class AuthResponse(
var error: String?,
@SerializedName("refresh_token")
var refreshToken: String?,
)

) {
fun mapToDomain(): AuthResponse {
return AuthResponse(
accessToken = accessToken,
tokenType = tokenType,
expiresIn = expiresIn?.times(1000),
scope = scope,
error = error,
refreshToken = refreshToken,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.openedx.auth.data.repository

import org.openedx.auth.data.api.AuthApi
import org.openedx.auth.data.model.ValidationFields
import org.openedx.auth.domain.model.AuthResponse
import org.openedx.core.ApiConstants
import org.openedx.core.data.storage.CorePreferences
import org.openedx.core.domain.model.RegistrationField
Expand All @@ -16,18 +17,19 @@ class AuthRepository(
username: String,
password: String,
) {
val authResponse = api.getAccessToken(
val authResponse: AuthResponse = api.getAccessToken(
ApiConstants.GRANT_TYPE_PASSWORD,
org.openedx.core.BuildConfig.CLIENT_ID,
username,
password,
org.openedx.core.BuildConfig.ACCESS_TOKEN_TYPE
)
).mapToDomain()
if (authResponse.error != null) {
throw EdxError.UnknownException(authResponse.error!!)
}
preferencesManager.accessToken = authResponse.accessToken ?: ""
preferencesManager.refreshToken = authResponse.refreshToken ?: ""
preferencesManager.accessTokenExpiresAt = authResponse.getTokenExpiryTime()
val user = api.getProfile()
preferencesManager.user = user
}
Expand All @@ -47,4 +49,4 @@ class AuthRepository(
suspend fun passwordReset(email: String): Boolean {
return api.passwordReset(email).success
}
}
}
19 changes: 19 additions & 0 deletions auth/src/main/java/org/openedx/auth/domain/model/AuthResponse.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.openedx.auth.domain.model

import android.os.Parcelable
import kotlinx.parcelize.Parcelize
import org.openedx.core.utils.TimeUtils

@Parcelize
data class AuthResponse(
var accessToken: String?,
var tokenType: String?,
var expiresIn: Long?,
var scope: String?,
var error: String?,
var refreshToken: String?,
) : Parcelable {
fun getTokenExpiryTime(): Long {
return (expiresIn ?: 0L) + TimeUtils.getCurrentTime()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ import org.openedx.core.domain.model.VideoSettings
interface CorePreferences {
var accessToken: String
var refreshToken: String
var accessTokenExpiresAt: Long
var user: User?
var videoSettings: VideoSettings

fun clear()
}
}
Loading

0 comments on commit d1bf1a6

Please sign in to comment.