Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CachedTableManager use refresh token to refresh urls #387

Merged
merged 3 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets.UTF_8
import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.delta.sharing.TableRefreshResult

import io.delta.sharing.client.util.JsonUtils

Expand All @@ -47,10 +48,10 @@ trait DeltaSharingProfileProvider {

def getCustomTablePath(tablePath: String): String = tablePath

// Map[String, String] is the id to url map.
// Long is the minimum url expiration time for all the urls.
def getCustomRefresher(refresher: () => (Map[String, String], Option[Long])): () =>
(Map[String, String], Option[Long]) = {
// `refresher` takes an optional refreshToken, and returns
// (idToUrlMap, minUrlExpirationTimestamp, refreshToken)
def getCustomRefresher(
refresher: Option[String] => TableRefreshResult): Option[String] => TableRefreshResult = {
refresher
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,31 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils}

import io.delta.sharing.client.DeltaSharingProfileProvider

/**
 * The result of one invocation of a table refresher.
 *
 * @param idToUrl the file id to pre signed url map produced by the refresh.
 * @param expirationTimestamp the optional expiration timestamp of the refreshed urls. When it is
 *                            not accepted by `isValidUrlExpirationTime`, `CachedTableManager`
 *                            falls back to `preSignedUrlExpirationMs` added to the current time.
 * @param refreshToken an optional refresh token; `CachedTableManager` stores it with the cached
 *                     table and passes it to the refresher on the next refresh, so the refresher
 *                     can retrieve the same set of files with refreshed urls.
 */
case class TableRefreshResult(
idToUrl: Map[String, String],
expirationTimestamp: Option[Long],
refreshToken: Option[String]
)

/**
* @param expiration the expiration time of the pre signed urls
* @param idToUrl the file id to pre sign url map
* @param refs the references that we track. When all of references in the table are gone, we will
* remove the cached table from our cache.
* @param lastAccess When the table was accessed last time. We will remove old tables that are not
* accessed after `expireAfterAccessMs` milliseconds.
* @param refresher the function to generate a new file id to pre sign url map, as long as the new
* expiration timestamp of the urls.
* @param refresher the function to generate a new file id to pre sign url map, with the new
* expiration timestamp of the urls and the new refresh token.
* @param refreshToken the optional refresh token that can be used by the refresher to retrieve
* the same set of files with refreshed urls.
*/
class CachedTable(
val expiration: Long,
val idToUrl: Map[String, String],
val refs: Seq[WeakReference[AnyRef]],
@volatile var lastAccess: Long,
val refresher: () => (Map[String, String], Option[Long]))
val refresher: Option[String] => TableRefreshResult,
val refreshToken: Option[String])

class CachedTableManager(
val preSignedUrlExpirationMs: Long,
Expand Down Expand Up @@ -96,17 +105,18 @@ class CachedTableManager(
logInfo(s"Updating pre signed urls for $tablePath (expiration time: " +
s"${new java.util.Date(cachedTable.expiration)})")
try {
val (idToUrl, expOpt) = cachedTable.refresher()
val refreshRes = cachedTable.refresher(cachedTable.refreshToken)
val newTable = new CachedTable(
if (isValidUrlExpirationTime(expOpt)) {
expOpt.get
if (isValidUrlExpirationTime(refreshRes.expirationTimestamp)) {
refreshRes.expirationTimestamp.get
} else {
preSignedUrlExpirationMs + System.currentTimeMillis()
},
idToUrl,
refreshRes.idToUrl,
cachedTable.refs,
cachedTable.lastAccess,
cachedTable.refresher
cachedTable.refresher,
refreshRes.refreshToken
)
// Failing to replace the table is fine because if it did happen, we would retry after
// `refreshCheckIntervalMs` milliseconds.
Expand Down Expand Up @@ -149,43 +159,51 @@ class CachedTableManager(
* @param tablePath the table path. This is usually the profile file path.
* @param idToUrl the pre signed url map. This will be refreshed when the pre signed urls are going
* to expire.
* @param refs A list of weak references which can be used to determine whether the cache is
* still needed. When all the weak references return null, we will remove the pre
* signed url cache of this table from the cache.
* @param refs A list of weak references which can be used to determine whether the cache is
* still needed. When all the weak references return null, we will remove the pre
* signed url cache of this table from the cache.
* @param profileProvider a profile Provider that can provide customized refresher function.
* @param refresher A function to re-generate pre signed urls for the table.
* @param expirationTimestamp Optional, If set, it's a timestamp to indicate the expiration
* timestamp of the idToUrl.
* @param refresher A function to re-generate pre signed urls for the table.
* @param expirationTimestamp Optional, If set, it's a timestamp to indicate the expiration
* timestamp of the idToUrl.
* @param refreshToken an optional refresh token that can be used by the refresher to retrieve
* the same set of files with refreshed urls.
*/
def register(
tablePath: String,
idToUrl: Map[String, String],
refs: Seq[WeakReference[AnyRef]],
profileProvider: DeltaSharingProfileProvider,
refresher: () => (Map[String, String], Option[Long]),
expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs
): Unit = {
refresher: Option[String] => TableRefreshResult,
expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs,
refreshToken: Option[String]
): Unit = {
val customTablePath = profileProvider.getCustomTablePath(tablePath)
val customRefresher = profileProvider.getCustomRefresher(refresher)

val (resolvedIdToUrl, resolvedExpiration) =
val (resolvedIdToUrl, resolvedExpiration, resolvedRefreshToken) =
if (expirationTimestamp - System.currentTimeMillis() < refreshThresholdMs) {
val (refreshedIdToUrl, expOpt) = customRefresher()
if (isValidUrlExpirationTime(expOpt)) {
(refreshedIdToUrl, expOpt.get)
val refreshRes = customRefresher(refreshToken)
if (isValidUrlExpirationTime(refreshRes.expirationTimestamp)) {
(refreshRes.idToUrl, refreshRes.expirationTimestamp.get, refreshRes.refreshToken)
} else {
(refreshedIdToUrl, System.currentTimeMillis() + preSignedUrlExpirationMs)
(
refreshRes.idToUrl,
System.currentTimeMillis() + preSignedUrlExpirationMs,
refreshRes.refreshToken
)
}
} else {
(idToUrl, expirationTimestamp)
(idToUrl, expirationTimestamp, refreshToken)
}

val cachedTable = new CachedTable(
resolvedExpiration,
idToUrl = resolvedIdToUrl,
refs,
System.currentTimeMillis(),
customRefresher
customRefresher,
resolvedRefreshToken
)
var oldTable = cache.putIfAbsent(customTablePath, cachedTable)
if (oldTable == null) {
Expand All @@ -203,7 +221,8 @@ class CachedTableManager(
// Try to avoid storing duplicate references
refs.filterNot(ref => oldTable.refs.exists(_.get eq ref.get)) ++ oldTable.refs,
lastAccess = System.currentTimeMillis(),
customRefresher
customRefresher,
cachedTable.refreshToken
)
if (cache.replace(customTablePath, oldTable, mergedTable)) {
// Put the merged one to the cache
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ class CachedTableManagerSuite extends SparkFunSuite {
Map("id1" -> "url1", "id2" -> "url2"),
Seq(new WeakReference(ref)),
provider,
() => {
(Map("id1" -> "url1", "id2" -> "url2"), None)
})
_ => {
TableRefreshResult(Map("id1" -> "url1", "id2" -> "url2"), None, None)
},
refreshToken = None)
assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"),
"id1")._1 == "url1")
assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"),
Expand All @@ -60,9 +61,10 @@ class CachedTableManagerSuite extends SparkFunSuite {
Map("id1" -> "url1", "id2" -> "url2"),
Seq(new WeakReference(ref)),
provider,
() => {
(Map("id1" -> "url3", "id2" -> "url4"), None)
})
_ => {
TableRefreshResult(Map("id1" -> "url3", "id2" -> "url4"), None, None)
},
refreshToken = None)
// We should get the new urls eventually
eventually(timeout(10.seconds)) {
assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path2"),
Expand All @@ -76,9 +78,10 @@ class CachedTableManagerSuite extends SparkFunSuite {
Map("id1" -> "url1", "id2" -> "url2"),
Seq(new WeakReference(new AnyRef)),
provider,
() => {
(Map("id1" -> "url3", "id2" -> "url4"), None)
})
_ => {
TableRefreshResult(Map("id1" -> "url3", "id2" -> "url4"), None, None)
},
refreshToken = None)
// We should remove the cached table eventually
eventually(timeout(10.seconds)) {
System.gc()
Expand All @@ -93,10 +96,11 @@ class CachedTableManagerSuite extends SparkFunSuite {
Map("id1" -> "url1", "id2" -> "url2"),
Seq(new WeakReference(ref)),
provider,
() => {
(Map("id1" -> "url3", "id2" -> "url4"), None)
_ => {
TableRefreshResult(Map("id1" -> "url3", "id2" -> "url4"), None, None)
},
-1
-1,
refreshToken = None
)
// We should get new urls immediately because it's refreshed upon register
assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path4"),
Expand Down Expand Up @@ -124,14 +128,16 @@ class CachedTableManagerSuite extends SparkFunSuite {
Map("id1" -> "url1", "id2" -> "url2"),
Seq(new WeakReference(ref)),
provider,
() => {
_ => {
refreshTime += 1
(
TableRefreshResult(
Map("id1" -> ("url" + refreshTime.toString), "id2" -> "url4"),
Some(System.currentTimeMillis() + 1900)
Some(System.currentTimeMillis() + 1900),
None
)
},
System.currentTimeMillis() + 1900
System.currentTimeMillis() + 1900,
None
)
// We should refresh at least 5 times within 10 seconds based on
// (System.currentTimeMillis() + 1900).
Expand All @@ -148,14 +154,16 @@ class CachedTableManagerSuite extends SparkFunSuite {
Map("id1" -> "url1", "id2" -> "url2"),
Seq(new WeakReference(ref)),
provider,
() => {
_ => {
refreshTime2 += 1
(
TableRefreshResult(
Map("id1" -> ("url" + refreshTime2.toString), "id2" -> "url4"),
Some(System.currentTimeMillis() + 4900)
Some(System.currentTimeMillis() + 4900),
None
)
},
System.currentTimeMillis() + 4900
System.currentTimeMillis() + 4900,
None
)
// We should refresh 2 times within 10 seconds based on (System.currentTimeMillis() + 4900).
eventually(timeout(10.seconds)) {
Expand All @@ -171,14 +179,16 @@ class CachedTableManagerSuite extends SparkFunSuite {
Map("id1" -> "url1", "id2" -> "url2"),
Seq(new WeakReference(ref)),
provider,
() => {
_ => {
refreshTime3 += 1
(
TableRefreshResult(
Map("id1" -> ("url" + refreshTime3.toString), "id2" -> "url4"),
Some(System.currentTimeMillis() - 4900)
Some(System.currentTimeMillis() - 4900),
None
)
},
System.currentTimeMillis() + 6000
System.currentTimeMillis() + 6000,
None
)
// We should refresh 1 time within 10 seconds based on (preSignedUrlExpirationMs = 6000).
try {
Expand All @@ -197,6 +207,52 @@ class CachedTableManagerSuite extends SparkFunSuite {
}
}

// Verifies that the refresh token returned by one refresh is the token passed to the next
// refresh: the stub refresher below only returns url5/url6 when it is called with
// "refresh-token-2", which it hands out itself when called with "refresh-token-1".
test("refresh using refresh token") {
// NOTE(review): the 10 ms settings presumably make the manager refresh almost immediately and
// repeatedly during the test -- confirm against CachedTableManager.
val manager = new CachedTableManager(
preSignedUrlExpirationMs = 10,
refreshCheckIntervalMs = 10,
refreshThresholdMs = 10,
expireAfterAccessMs = 60000
)
try {
// Keep a strong reference so the weak reference registered below is not cleared.
val ref = new AnyRef
val provider = new TestDeltaSharingProfileProvider
manager.register(
"test-table-path",
Map("id1" -> "url1", "id2" -> "url2"),
Seq(new WeakReference(ref)),
provider,
// Stub refresher: the urls it returns depend on which refresh token it receives.
refreshToken => {
if (refreshToken.contains("refresh-token-1")) {
// Called with the token supplied to register(...): return new urls and a new token.
TableRefreshResult(
Map("id1" -> "url3", "id2" -> "url4"),
None,
Some("refresh-token-2")
)
} else if (refreshToken.contains("refresh-token-2")) {
// Called with the token returned above: return the urls the test waits for.
TableRefreshResult(
Map("id1" -> "url5", "id2" -> "url6"),
None,
Some("refresh-token-2")
)
} else {
// Any other value (including None) means the refresh token was not passed through.
fail("Expecting to refresh with a refresh token")
}
},
// Initial refresh token handed to the manager at registration.
refreshToken = Some("refresh-token-1")
)
// We should get url5 and url6 eventually.
eventually(timeout(10.seconds)) {
assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"),
"id1")._1 == "url5")
assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"),
"id2")._1 == "url6")
}
} finally {
// Always stop the manager, even when an assertion above fails.
manager.stop()
}
}

test("expireAfterAccessMs") {
val manager = new CachedTableManager(
preSignedUrlExpirationMs = 10,
Expand All @@ -213,9 +269,10 @@ class CachedTableManagerSuite extends SparkFunSuite {
Map("id1" -> "url1", "id2" -> "url2"),
Seq(new WeakReference(ref)),
provider,
() => {
(Map("id1" -> "url1", "id2" -> "url2"), None)
})
_ => {
TableRefreshResult(Map("id1" -> "url1", "id2" -> "url2"), None, None)
},
refreshToken = None)
Thread.sleep(1000)
// We should remove the cached table when it's not accessed
intercept[IllegalStateException](manager.getPreSignedUrl(
Expand Down
Loading
Loading