Skip to content

Commit

Permalink
refresh using refreshToken
Browse files Browse the repository at this point in the history
  • Loading branch information
charlenelyu-db committed Sep 1, 2023
1 parent 22a619e commit 7b8c8f5
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ trait DeltaSharingProfileProvider {

def getCustomTablePath(tablePath: String): String = tablePath

// Map[String, String] is the id to url map.
// Long is the minimum url expiration time for all the urls.
def getCustomRefresher(refresher: () => (Map[String, String], Option[Long])): () =>
(Map[String, String], Option[Long]) = {
// `refresher` takes an optional refreshToken, and returns
// (idToUrlMap, minUrlExpirationTimestamp, refreshToken)
def getCustomRefresher(
refresher: Option[String] => (Map[String, String], Option[Long], Option[String]))
: Option[String] => (Map[String, String], Option[Long], Option[String]) = {
refresher
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,18 @@ import io.delta.sharing.client.DeltaSharingProfileProvider
* remove the cached table from our cache.
* @param lastAccess When the table was accessed last time. We will remove old tables that are not
* accessed after `expireAfterAccessMs` milliseconds.
* @param refresher the function to generate a new file id to pre sign url map, as long as the new
* expiration timestamp of the urls.
* @param refresher the function to generate a new file id to pre sign url map, with the new
* expiration timestamp of the urls and the new refresh token.
* @param refreshToken the optional refresh token that can be used by the refresher to retrieve
* the same set of files with refreshed urls.
*/
class CachedTable(
val expiration: Long,
val idToUrl: Map[String, String],
val refs: Seq[WeakReference[AnyRef]],
@volatile var lastAccess: Long,
val refresher: () => (Map[String, String], Option[Long]))
val refresher: Option[String] => (Map[String, String], Option[Long], Option[String]),
val refreshToken: Option[String])

class CachedTableManager(
val preSignedUrlExpirationMs: Long,
Expand Down Expand Up @@ -96,7 +99,7 @@ class CachedTableManager(
logInfo(s"Updating pre signed urls for $tablePath (expiration time: " +
s"${new java.util.Date(cachedTable.expiration)})")
try {
val (idToUrl, expOpt) = cachedTable.refresher()
val (idToUrl, expOpt, refreshToken) = cachedTable.refresher(cachedTable.refreshToken)
val newTable = new CachedTable(
if (isValidUrlExpirationTime(expOpt)) {
expOpt.get
Expand All @@ -106,7 +109,8 @@ class CachedTableManager(
idToUrl,
cachedTable.refs,
cachedTable.lastAccess,
cachedTable.refresher
cachedTable.refresher,
refreshToken
)
// Failing to replace the table is fine because if it did happen, we would retry after
// `refreshCheckIntervalMs` milliseconds.
Expand Down Expand Up @@ -149,43 +153,47 @@ class CachedTableManager(
* @param tablePath the table path. This is usually the profile file path.
* @param idToUrl the pre signed url map. This will be refreshed when the pre signed urls is going
* to expire.
* @param refs A list of weak references which can be used to determine whether the cache is
* still needed. When all the weak references return null, we will remove the pre
* signed url cache of this table form the cache.
* @param refs A list of weak references which can be used to determine whether the cache is
* still needed. When all the weak references return null, we will remove the pre
* signed url cache of this table form the cache.
* @param profileProvider a profile Provider that can provide customized refresher function.
* @param refresher A function to re-generate pre signed urls for the table.
* @param expirationTimestamp Optional, If set, it's a timestamp to indicate the expiration
* timestamp of the idToUrl.
* @param refresher A function to re-generate pre signed urls for the table.
* @param expirationTimestamp Optional, If set, it's a timestamp to indicate the expiration
* timestamp of the idToUrl.
* @param refreshToken an optional refresh token that can be used by the refresher to retrieve
* the same set of files with refreshed urls.
*/
def register(
tablePath: String,
idToUrl: Map[String, String],
refs: Seq[WeakReference[AnyRef]],
profileProvider: DeltaSharingProfileProvider,
refresher: () => (Map[String, String], Option[Long]),
expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs
): Unit = {
refresher: Option[String] => (Map[String, String], Option[Long], Option[String]),
expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs,
refreshToken: Option[String]
): Unit = {
val customTablePath = profileProvider.getCustomTablePath(tablePath)
val customRefresher = profileProvider.getCustomRefresher(refresher)

val (resolvedIdToUrl, resolvedExpiration) =
val (resolvedIdToUrl, resolvedExpiration, resolvedRefreshToken) =
if (expirationTimestamp - System.currentTimeMillis() < refreshThresholdMs) {
val (refreshedIdToUrl, expOpt) = customRefresher()
val (refreshedIdToUrl, expOpt, newRefreshToken) = customRefresher(refreshToken)
if (isValidUrlExpirationTime(expOpt)) {
(refreshedIdToUrl, expOpt.get)
(refreshedIdToUrl, expOpt.get, newRefreshToken)
} else {
(refreshedIdToUrl, System.currentTimeMillis() + preSignedUrlExpirationMs)
(refreshedIdToUrl, System.currentTimeMillis() + preSignedUrlExpirationMs, newRefreshToken)
}
} else {
(idToUrl, expirationTimestamp)
(idToUrl, expirationTimestamp, refreshToken)
}

val cachedTable = new CachedTable(
resolvedExpiration,
idToUrl = resolvedIdToUrl,
refs,
System.currentTimeMillis(),
customRefresher
customRefresher,
resolvedRefreshToken
)
var oldTable = cache.putIfAbsent(customTablePath, cachedTable)
if (oldTable == null) {
Expand All @@ -203,7 +211,8 @@ class CachedTableManager(
// Try to avoid storing duplicate references
refs.filterNot(ref => oldTable.refs.exists(_.get eq ref.get)) ++ oldTable.refs,
lastAccess = System.currentTimeMillis(),
customRefresher
customRefresher,
cachedTable.refreshToken
)
if (cache.replace(customTablePath, oldTable, mergedTable)) {
// Put the merged one to the cache
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ case class DeltaSharingSource(
// The latest function used to fetch presigned urls for the delta sharing table, record it in
// a variable to be used by the CachedTableManager to refresh the presigned urls if the query
// runs for a long time.
private var latestRefreshFunc = () => { (Map.empty[String, String], None: Option[Long]) }
private var latestRefreshFunc = (_: Option[String]) => {
(Map.empty[String, String], Option.empty[Long], Option.empty[String])
}

// Check the latest table version from the delta sharing server through the client.getTableVersion
// RPC. Adding a minimum interval of QUERY_TABLE_VERSION_INTERVAL_MILLIS between two consecutive
Expand Down Expand Up @@ -413,7 +415,7 @@ case class DeltaSharingSource(
jsonPredicateHints = None,
refreshToken = None
)
latestRefreshFunc = () => {
latestRefreshFunc = _ => {
val queryTimestamp = System.currentTimeMillis()
val files = deltaLog.client.getFiles(
table = deltaLog.table,
Expand All @@ -439,7 +441,7 @@ case class DeltaSharingSource(

refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration)

(idToUrl, minUrlExpiration)
(idToUrl, minUrlExpiration, None)
}

val numFiles = tableFiles.files.size
Expand Down Expand Up @@ -473,7 +475,7 @@ case class DeltaSharingSource(
val tableFiles = deltaLog.client.getFiles(
deltaLog.table, fromVersion, Some(endingVersionForQuery)
)
latestRefreshFunc = () => {
latestRefreshFunc = _ => {
val queryTimestamp = System.currentTimeMillis()
val addFiles = deltaLog.client.getFiles(
deltaLog.table, fromVersion, Some(endingVersionForQuery)
Expand All @@ -493,7 +495,7 @@ case class DeltaSharingSource(

refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration)

(idToUrl, minUrlExpiration)
(idToUrl, minUrlExpiration, None)
}
val allAddFiles = validateCommitAndFilterAddFiles(tableFiles).groupBy(a => a.version)
for (v <- fromVersion to endingVersionForQuery) {
Expand Down Expand Up @@ -545,7 +547,7 @@ case class DeltaSharingSource(
),
true
)
latestRefreshFunc = () => {
latestRefreshFunc = _ => {
val queryTimestamp = System.currentTimeMillis()
val d = deltaLog.client.getCDFFiles(
deltaLog.table,
Expand All @@ -562,10 +564,7 @@ case class DeltaSharingSource(
d.addFiles, d.cdfFiles, d.removeFiles)
refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration)

(
idToUrl,
minUrlExpiration
)
(idToUrl, minUrlExpiration, None)
}

(Seq(tableFiles.metadata) ++ tableFiles.additionalMetadatas).foreach { m =>
Expand Down Expand Up @@ -774,7 +773,8 @@ case class DeltaSharingSource(
urlExpirationTimestamp.get
} else {
lastQueryTimestamp + CachedTableManager.INSTANCE.preSignedUrlExpirationMs
}
},
None
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ case class RemoteDeltaCDFRelation(
deltaTabelFiles.removeFiles,
DeltaTableUtils.addCdcSchema(deltaTabelFiles.metadata.schemaString),
false,
() => {
_ => {
val d = client.getCDFFiles(table, cdfOptions, false)
(
DeltaSharingCDFReader.getIdToUrl(d.addFiles, d.cdfFiles, d.removeFiles),
DeltaSharingCDFReader.getMinUrlExpiration(d.addFiles, d.cdfFiles, d.removeFiles)
DeltaSharingCDFReader.getMinUrlExpiration(d.addFiles, d.cdfFiles, d.removeFiles),
None
)
},
System.currentTimeMillis(),
Expand All @@ -82,7 +83,7 @@ object DeltaSharingCDFReader {
removeFiles: Seq[RemoveFile],
schema: StructType,
isStreaming: Boolean,
refresher: () => (Map[String, String], Option[Long]),
refresher: Option[String] => (Map[String, String], Option[Long], Option[String]),
lastQueryTableTimestamp: Long,
expirationTimestamp: Option[Long]
): DataFrame = {
Expand Down Expand Up @@ -111,7 +112,8 @@ object DeltaSharingCDFReader {
expirationTimestamp.get
} else {
lastQueryTableTimestamp + CachedTableManager.INSTANCE.preSignedUrlExpirationMs
}
},
None
)

dfs.reduce((df1, df2) => df1.unionAll(df2))
Expand Down
15 changes: 8 additions & 7 deletions spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,12 @@ class RemoteSnapshot(
idToUrl,
Seq(new WeakReference(fileIndex)),
fileIndex.params.profileProvider,
() => {
// TODO: use tableFiles.refreshToken
val files = client.getFiles(
table, Nil, None, versionAsOf, timestampAsOf, jsonPredicateHints, None).files
refreshToken => {
val tableFiles = client.getFiles(
table, Nil, None, versionAsOf, timestampAsOf, jsonPredicateHints, refreshToken
)
var minUrlExpiration: Option[Long] = None
val idToUrl = files.map { add =>
val idToUrl = tableFiles.files.map { add =>
if (add.expirationTimestamp != null) {
minUrlExpiration = if (minUrlExpiration.isDefined
&& minUrlExpiration.get < add.expirationTimestamp) {
Expand All @@ -252,13 +252,14 @@ class RemoteSnapshot(
}
add.id -> add.url
}.toMap
(idToUrl, minUrlExpiration)
(idToUrl, minUrlExpiration, tableFiles.refreshToken)
},
if (CachedTableManager.INSTANCE.isValidUrlExpirationTime(minUrlExpirationTimestamp)) {
minUrlExpirationTimestamp.get
} else {
System.currentTimeMillis() + CachedTableManager.INSTANCE.preSignedUrlExpirationMs
}
},
tableFiles.refreshToken
)
checkProtocolNotChange(tableFiles.protocol)
checkSchemaNotChange(tableFiles.metadata)
Expand Down

0 comments on commit 7b8c8f5

Please sign in to comment.