From f97d5db4053cef5719377f01dd86e8e090d2e70a Mon Sep 17 00:00:00 2001 From: Charlene Lyu Date: Fri, 1 Sep 2023 01:06:56 -0700 Subject: [PATCH] refresh using refreshToken --- .../client/DeltaSharingProfileProvider.scala | 9 ++-- .../delta/sharing/PreSignedUrlCache.scala | 51 +++++++++++-------- .../sharing/spark/DeltaSharingSource.scala | 22 ++++---- .../spark/RemoteDeltaCDFRelation.scala | 10 ++-- .../delta/sharing/spark/RemoteDeltaLog.scala | 15 +++--- 5 files changed, 60 insertions(+), 47 deletions(-) diff --git a/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala b/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala index a93324448..2a1f6acf2 100644 --- a/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala +++ b/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala @@ -47,10 +47,11 @@ trait DeltaSharingProfileProvider { def getCustomTablePath(tablePath: String): String = tablePath - // Map[String, String] is the id to url map. - // Long is the minimum url expiration time for all the urls. - def getCustomRefresher(refresher: () => (Map[String, String], Option[Long])): () => - (Map[String, String], Option[Long]) = { + // `refresher` takes an optional refreshToken, and returns + // (idToUrlMap, minUrlExpirationTimestamp, refreshToken) + def getCustomRefresher( + refresher: Option[String] => (Map[String, String], Option[Long], Option[String])) + : Option[String] => (Map[String, String], Option[Long], Option[String]) = { refresher } } diff --git a/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala b/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala index 38cd05bd2..3e6c10fe5 100644 --- a/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala +++ b/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala @@ -35,15 +35,18 @@ import io.delta.sharing.client.DeltaSharingProfileProvider * remove the cached table from our cache. * @param lastAccess When the table was accessed last time. We will remove old tables that are not * accessed after `expireAfterAccessMs` milliseconds. - * @param refresher the function to generate a new file id to pre sign url map, as long as the new - * expiration timestamp of the urls. + * @param refresher the function to generate a new file id to pre sign url map, with the new + * expiration timestamp of the urls and the new refresh token. + * @param refreshToken the optional refresh token that can be used by the refresher to retrieve + * the same set of files with refreshed urls. */ class CachedTable( val expiration: Long, val idToUrl: Map[String, String], val refs: Seq[WeakReference[AnyRef]], @volatile var lastAccess: Long, - val refresher: () => (Map[String, String], Option[Long])) + val refresher: Option[String] => (Map[String, String], Option[Long], Option[String]), + val refreshToken: Option[String]) class CachedTableManager( val preSignedUrlExpirationMs: Long, @@ -96,7 +99,7 @@ class CachedTableManager( logInfo(s"Updating pre signed urls for $tablePath (expiration time: " + s"${new java.util.Date(cachedTable.expiration)})") try { - val (idToUrl, expOpt) = cachedTable.refresher() + val (idToUrl, expOpt, refreshToken) = cachedTable.refresher(cachedTable.refreshToken) val newTable = new CachedTable( if (isValidUrlExpirationTime(expOpt)) { expOpt.get @@ -106,7 +109,8 @@ class CachedTableManager( idToUrl, cachedTable.refs, cachedTable.lastAccess, - cachedTable.refresher + cachedTable.refresher, + refreshToken ) // Failing to replace the table is fine because if it did happen, we would retry after // `refreshCheckIntervalMs` milliseconds. @@ -149,35 +153,38 @@ class CachedTableManager( * @param tablePath the table path. This is usually the profile file path. * @param idToUrl the pre signed url map. This will be refreshed when the pre signed urls is going * to expire. - * @param refs A list of weak references which can be used to determine whether the cache is - * still needed. When all the weak references return null, we will remove the pre - * signed url cache of this table form the cache. + * @param refs A list of weak references which can be used to determine whether the cache is + * still needed. When all the weak references return null, we will remove the pre + * signed url cache of this table form the cache. * @param profileProvider a profile Provider that can provide customized refresher function. - * @param refresher A function to re-generate pre signed urls for the table. - * @param expirationTimestamp Optional, If set, it's a timestamp to indicate the expiration - * timestamp of the idToUrl. + * @param refresher A function to re-generate pre signed urls for the table. + * @param expirationTimestamp Optional, If set, it's a timestamp to indicate the expiration + * timestamp of the idToUrl. + * @param refreshToken an optional refresh token that can be used by the refresher to retrieve + * the same set of files with refreshed urls. */ def register( tablePath: String, idToUrl: Map[String, String], refs: Seq[WeakReference[AnyRef]], profileProvider: DeltaSharingProfileProvider, - refresher: () => (Map[String, String], Option[Long]), - expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs - ): Unit = { + refresher: Option[String] => (Map[String, String], Option[Long], Option[String]), + expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs, + refreshToken: Option[String] + ): Unit = { val customTablePath = profileProvider.getCustomTablePath(tablePath) val customRefresher = profileProvider.getCustomRefresher(refresher) - val (resolvedIdToUrl, resolvedExpiration) = + val (resolvedIdToUrl, resolvedExpiration, resolvedRefreshToken) = if (expirationTimestamp - System.currentTimeMillis() < refreshThresholdMs) { - val (refreshedIdToUrl, expOpt) = customRefresher() + val (refreshedIdToUrl, expOpt, newRefreshToken) = customRefresher(refreshToken) if (isValidUrlExpirationTime(expOpt)) { - (refreshedIdToUrl, expOpt.get) + (refreshedIdToUrl, expOpt.get, newRefreshToken) } else { - (refreshedIdToUrl, System.currentTimeMillis() + preSignedUrlExpirationMs) + (refreshedIdToUrl, System.currentTimeMillis() + preSignedUrlExpirationMs, newRefreshToken) } } else { - (idToUrl, expirationTimestamp) + (idToUrl, expirationTimestamp, refreshToken) } val cachedTable = new CachedTable( @@ -185,7 +192,8 @@ class CachedTableManager( idToUrl = resolvedIdToUrl, refs, System.currentTimeMillis(), - customRefresher + customRefresher, + resolvedRefreshToken ) var oldTable = cache.putIfAbsent(customTablePath, cachedTable) if (oldTable == null) { @@ -203,7 +211,8 @@ class CachedTableManager( // Try to avoid storing duplicate references refs.filterNot(ref => oldTable.refs.exists(_.get eq ref.get)) ++ oldTable.refs, lastAccess = System.currentTimeMillis(), - customRefresher + customRefresher, + cachedTable.refreshToken ) if (cache.replace(customTablePath, oldTable, mergedTable)) { // Put the merged one to the cache diff --git a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala index 2fdfede19..e2c1d7555 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala @@ -162,7 +162,9 @@ case class DeltaSharingSource( // The latest function used to fetch presigned urls for the delta sharing table, record it in // a variable to be used by the CachedTableManager to refresh the presigned urls if the query // runs for a long time. - private var latestRefreshFunc = () => { (Map.empty[String, String], None: Option[Long]) } + private var latestRefreshFunc = (_: Option[String]) => { + (Map.empty[String, String], Option.empty[Long], Option.empty[String]) + } // Check the latest table version from the delta sharing server through the client.getTableVersion // RPC. Adding a minimum interval of QUERY_TABLE_VERSION_INTERVAL_MILLIS between two consecutive @@ -413,7 +415,7 @@ case class DeltaSharingSource( jsonPredicateHints = None, refreshToken = None ) - latestRefreshFunc = () => { + latestRefreshFunc = _ => { val queryTimestamp = System.currentTimeMillis() val files = deltaLog.client.getFiles( table = deltaLog.table, @@ -439,7 +441,7 @@ case class DeltaSharingSource( refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) - (idToUrl, minUrlExpiration) + (idToUrl, minUrlExpiration, None) } val numFiles = tableFiles.files.size @@ -473,7 +475,7 @@ case class DeltaSharingSource( val tableFiles = deltaLog.client.getFiles( deltaLog.table, fromVersion, Some(endingVersionForQuery) ) - latestRefreshFunc = () => { + latestRefreshFunc = _ => { val queryTimestamp = System.currentTimeMillis() val addFiles = deltaLog.client.getFiles( deltaLog.table, fromVersion, Some(endingVersionForQuery) @@ -493,7 +495,7 @@ case class DeltaSharingSource( refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) - (idToUrl, minUrlExpiration) + (idToUrl, minUrlExpiration, None) } val allAddFiles = validateCommitAndFilterAddFiles(tableFiles).groupBy(a => a.version) for (v <- fromVersion to endingVersionForQuery) { @@ -545,7 +547,7 @@ case class DeltaSharingSource( ), true ) - latestRefreshFunc = () => { + latestRefreshFunc = _ => { val queryTimestamp = System.currentTimeMillis() val d = deltaLog.client.getCDFFiles( deltaLog.table, @@ -562,10 +564,7 @@ case class DeltaSharingSource( d.addFiles, d.cdfFiles, d.removeFiles) refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) - ( - idToUrl, - minUrlExpiration - ) + (idToUrl, minUrlExpiration, None) } (Seq(tableFiles.metadata) ++ tableFiles.additionalMetadatas).foreach { m => @@ -774,7 +773,8 @@ case class DeltaSharingSource( urlExpirationTimestamp.get } else { lastQueryTimestamp + CachedTableManager.INSTANCE.preSignedUrlExpirationMs - } + }, + None ) diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala index 8341643ca..6e7b48fa4 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala @@ -56,11 +56,12 @@ case class RemoteDeltaCDFRelation( deltaTabelFiles.removeFiles, DeltaTableUtils.addCdcSchema(deltaTabelFiles.metadata.schemaString), false, - () => { + _ => { val d = client.getCDFFiles(table, cdfOptions, false) ( DeltaSharingCDFReader.getIdToUrl(d.addFiles, d.cdfFiles, d.removeFiles), - DeltaSharingCDFReader.getMinUrlExpiration(d.addFiles, d.cdfFiles, d.removeFiles) + DeltaSharingCDFReader.getMinUrlExpiration(d.addFiles, d.cdfFiles, d.removeFiles), + None ) }, System.currentTimeMillis(), @@ -82,7 +83,7 @@ object DeltaSharingCDFReader { removeFiles: Seq[RemoveFile], schema: StructType, isStreaming: Boolean, - refresher: () => (Map[String, String], Option[Long]), + refresher: Option[String] => (Map[String, String], Option[Long], Option[String]), lastQueryTableTimestamp: Long, expirationTimestamp: Option[Long] ): DataFrame = { @@ -111,7 +112,8 @@ object DeltaSharingCDFReader { expirationTimestamp.get } else { lastQueryTableTimestamp + CachedTableManager.INSTANCE.preSignedUrlExpirationMs - } + }, + None ) dfs.reduce((df1, df2) => df1.unionAll(df2)) diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala index 18028eee8..92940dc39 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala @@ -236,12 +236,12 @@ class RemoteSnapshot( idToUrl, Seq(new WeakReference(fileIndex)), fileIndex.params.profileProvider, - () => { - // TODO: use tableFiles.refreshToken - val files = client.getFiles( - table, Nil, None, versionAsOf, timestampAsOf, jsonPredicateHints, None).files + refreshToken => { + val tableFiles = client.getFiles( + table, Nil, None, versionAsOf, timestampAsOf, jsonPredicateHints, refreshToken + ) var minUrlExpiration: Option[Long] = None - val idToUrl = files.map { add => + val idToUrl = tableFiles.files.map { add => if (add.expirationTimestamp != null) { minUrlExpiration = if (minUrlExpiration.isDefined && minUrlExpiration.get < add.expirationTimestamp) { @@ -252,13 +252,14 @@ class RemoteSnapshot( } add.id -> add.url }.toMap - (idToUrl, minUrlExpiration) + (idToUrl, minUrlExpiration, tableFiles.refreshToken) }, if (CachedTableManager.INSTANCE.isValidUrlExpirationTime(minUrlExpirationTimestamp)) { minUrlExpirationTimestamp.get } else { System.currentTimeMillis() + CachedTableManager.INSTANCE.preSignedUrlExpirationMs - } + }, + tableFiles.refreshToken ) checkProtocolNotChange(tableFiles.protocol) checkSchemaNotChange(tableFiles.metadata)