From cb59b6d5246d97fad21dcde7a33d38d9e0f347d4 Mon Sep 17 00:00:00 2001 From: Charlene Lyu Date: Fri, 1 Sep 2023 01:06:56 -0700 Subject: [PATCH 1/3] refresh using refreshToken --- .../client/DeltaSharingProfileProvider.scala | 9 ++-- .../delta/sharing/PreSignedUrlCache.scala | 51 +++++++++++-------- .../sharing/spark/DeltaSharingSource.scala | 22 ++++---- .../spark/RemoteDeltaCDFRelation.scala | 10 ++-- .../delta/sharing/spark/RemoteDeltaLog.scala | 15 +++--- 5 files changed, 60 insertions(+), 47 deletions(-) diff --git a/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala b/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala index a93324448..2a1f6acf2 100644 --- a/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala +++ b/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala @@ -47,10 +47,11 @@ trait DeltaSharingProfileProvider { def getCustomTablePath(tablePath: String): String = tablePath - // Map[String, String] is the id to url map. - // Long is the minimum url expiration time for all the urls. - def getCustomRefresher(refresher: () => (Map[String, String], Option[Long])): () => - (Map[String, String], Option[Long]) = { + // `refresher` takes an optional refreshToken, and returns + // (idToUrlMap, minUrlExpirationTimestamp, refreshToken) + def getCustomRefresher( + refresher: Option[String] => (Map[String, String], Option[Long], Option[String])) + : Option[String] => (Map[String, String], Option[Long], Option[String]) = { refresher } } diff --git a/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala b/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala index 38cd05bd2..3e6c10fe5 100644 --- a/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala +++ b/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala @@ -35,15 +35,18 @@ import io.delta.sharing.client.DeltaSharingProfileProvider * remove the cached table from our cache. * @param lastAccess When the table was accessed last time. We will remove old tables that are not * accessed after `expireAfterAccessMs` milliseconds. - * @param refresher the function to generate a new file id to pre sign url map, as long as the new - * expiration timestamp of the urls. + * @param refresher the function to generate a new file id to pre sign url map, with the new + * expiration timestamp of the urls and the new refresh token. + * @param refreshToken the optional refresh token that can be used by the refresher to retrieve + * the same set of files with refreshed urls. */ class CachedTable( val expiration: Long, val idToUrl: Map[String, String], val refs: Seq[WeakReference[AnyRef]], @volatile var lastAccess: Long, - val refresher: () => (Map[String, String], Option[Long])) + val refresher: Option[String] => (Map[String, String], Option[Long], Option[String]), + val refreshToken: Option[String]) class CachedTableManager( val preSignedUrlExpirationMs: Long, @@ -96,7 +99,7 @@ class CachedTableManager( logInfo(s"Updating pre signed urls for $tablePath (expiration time: " + s"${new java.util.Date(cachedTable.expiration)})") try { - val (idToUrl, expOpt) = cachedTable.refresher() + val (idToUrl, expOpt, refreshToken) = cachedTable.refresher(cachedTable.refreshToken) val newTable = new CachedTable( if (isValidUrlExpirationTime(expOpt)) { expOpt.get @@ -106,7 +109,8 @@ class CachedTableManager( idToUrl, cachedTable.refs, cachedTable.lastAccess, - cachedTable.refresher + cachedTable.refresher, + refreshToken ) // Failing to replace the table is fine because if it did happen, we would retry after // `refreshCheckIntervalMs` milliseconds. @@ -149,35 +153,38 @@ class CachedTableManager( * @param tablePath the table path. This is usually the profile file path. * @param idToUrl the pre signed url map. This will be refreshed when the pre signed urls is going * to expire. - * @param refs A list of weak references which can be used to determine whether the cache is - * still needed. When all the weak references return null, we will remove the pre - * signed url cache of this table form the cache. + * @param refs A list of weak references which can be used to determine whether the cache is + * still needed. When all the weak references return null, we will remove the pre + * signed url cache of this table form the cache. * @param profileProvider a profile Provider that can provide customized refresher function. - * @param refresher A function to re-generate pre signed urls for the table. - * @param expirationTimestamp Optional, If set, it's a timestamp to indicate the expiration - * timestamp of the idToUrl. + * @param refresher A function to re-generate pre signed urls for the table. + * @param expirationTimestamp Optional, If set, it's a timestamp to indicate the expiration + * timestamp of the idToUrl. + * @param refreshToken an optional refresh token that can be used by the refresher to retrieve + * the same set of files with refreshed urls. */ def register( tablePath: String, idToUrl: Map[String, String], refs: Seq[WeakReference[AnyRef]], profileProvider: DeltaSharingProfileProvider, - refresher: () => (Map[String, String], Option[Long]), - expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs - ): Unit = { + refresher: Option[String] => (Map[String, String], Option[Long], Option[String]), + expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs, + refreshToken: Option[String] + ): Unit = { val customTablePath = profileProvider.getCustomTablePath(tablePath) val customRefresher = profileProvider.getCustomRefresher(refresher) - val (resolvedIdToUrl, resolvedExpiration) = + val (resolvedIdToUrl, resolvedExpiration, resolvedRefreshToken) = if (expirationTimestamp - System.currentTimeMillis() < refreshThresholdMs) { - val (refreshedIdToUrl, expOpt) = customRefresher() + val (refreshedIdToUrl, expOpt, newRefreshToken) = customRefresher(refreshToken) if (isValidUrlExpirationTime(expOpt)) { - (refreshedIdToUrl, expOpt.get) + (refreshedIdToUrl, expOpt.get, newRefreshToken) } else { - (refreshedIdToUrl, System.currentTimeMillis() + preSignedUrlExpirationMs) + (refreshedIdToUrl, System.currentTimeMillis() + preSignedUrlExpirationMs, newRefreshToken) } } else { - (idToUrl, expirationTimestamp) + (idToUrl, expirationTimestamp, refreshToken) } val cachedTable = new CachedTable( @@ -185,7 +192,8 @@ class CachedTableManager( idToUrl = resolvedIdToUrl, refs, System.currentTimeMillis(), - customRefresher + customRefresher, + resolvedRefreshToken ) var oldTable = cache.putIfAbsent(customTablePath, cachedTable) if (oldTable == null) { @@ -203,7 +211,8 @@ class CachedTableManager( // Try to avoid storing duplicate references refs.filterNot(ref => oldTable.refs.exists(_.get eq ref.get)) ++ oldTable.refs, lastAccess = System.currentTimeMillis(), - customRefresher + customRefresher, + cachedTable.refreshToken ) if (cache.replace(customTablePath, oldTable, mergedTable)) { // Put the merged one to the cache diff --git a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala index 2fdfede19..e2c1d7555 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala @@ -162,7 +162,9 @@ case class DeltaSharingSource( // The latest function used to fetch presigned urls for the delta sharing table, record it in // a variable to be used by the CachedTableManager to refresh the presigned urls if the query // runs for a long time. - private var latestRefreshFunc = () => { (Map.empty[String, String], None: Option[Long]) } + private var latestRefreshFunc = (_: Option[String]) => { + (Map.empty[String, String], Option.empty[Long], Option.empty[String]) + } // Check the latest table version from the delta sharing server through the client.getTableVersion // RPC. Adding a minimum interval of QUERY_TABLE_VERSION_INTERVAL_MILLIS between two consecutive @@ -413,7 +415,7 @@ case class DeltaSharingSource( jsonPredicateHints = None, refreshToken = None ) - latestRefreshFunc = () => { + latestRefreshFunc = _ => { val queryTimestamp = System.currentTimeMillis() val files = deltaLog.client.getFiles( table = deltaLog.table, @@ -439,7 +441,7 @@ case class DeltaSharingSource( refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) - (idToUrl, minUrlExpiration) + (idToUrl, minUrlExpiration, None) } val numFiles = tableFiles.files.size @@ -473,7 +475,7 @@ case class DeltaSharingSource( val tableFiles = deltaLog.client.getFiles( deltaLog.table, fromVersion, Some(endingVersionForQuery) ) - latestRefreshFunc = () => { + latestRefreshFunc = _ => { val queryTimestamp = System.currentTimeMillis() val addFiles = deltaLog.client.getFiles( deltaLog.table, fromVersion, Some(endingVersionForQuery) @@ -493,7 +495,7 @@ case class DeltaSharingSource( refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) - (idToUrl, minUrlExpiration) + (idToUrl, minUrlExpiration, None) } val allAddFiles = validateCommitAndFilterAddFiles(tableFiles).groupBy(a => a.version) for (v <- fromVersion to endingVersionForQuery) { @@ -545,7 +547,7 @@ case class DeltaSharingSource( ), true ) - latestRefreshFunc = () => { + latestRefreshFunc = _ => { val queryTimestamp = System.currentTimeMillis() val d = deltaLog.client.getCDFFiles( deltaLog.table, @@ -562,10 +564,7 @@ case class DeltaSharingSource( d.addFiles, d.cdfFiles, d.removeFiles) refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) - ( - idToUrl, - minUrlExpiration - ) + (idToUrl, minUrlExpiration, None) } (Seq(tableFiles.metadata) ++ tableFiles.additionalMetadatas).foreach { m => @@ -774,7 +773,8 @@ case class DeltaSharingSource( urlExpirationTimestamp.get } else { lastQueryTimestamp + CachedTableManager.INSTANCE.preSignedUrlExpirationMs - } + }, + None ) diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala index 8341643ca..6e7b48fa4 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala @@ -56,11 +56,12 @@ case class RemoteDeltaCDFRelation( deltaTabelFiles.removeFiles, DeltaTableUtils.addCdcSchema(deltaTabelFiles.metadata.schemaString), false, - () => { + _ => { val d = client.getCDFFiles(table, cdfOptions, false) ( DeltaSharingCDFReader.getIdToUrl(d.addFiles, d.cdfFiles, d.removeFiles), - DeltaSharingCDFReader.getMinUrlExpiration(d.addFiles, d.cdfFiles, d.removeFiles) + DeltaSharingCDFReader.getMinUrlExpiration(d.addFiles, d.cdfFiles, d.removeFiles), + None ) }, System.currentTimeMillis(), @@ -82,7 +83,7 @@ object DeltaSharingCDFReader { removeFiles: Seq[RemoveFile], schema: StructType, isStreaming: Boolean, - refresher: () => (Map[String, String], Option[Long]), + refresher: Option[String] => (Map[String, String], Option[Long], Option[String]), lastQueryTableTimestamp: Long, expirationTimestamp: Option[Long] ): DataFrame = { @@ -111,7 +112,8 @@ object DeltaSharingCDFReader { expirationTimestamp.get } else { lastQueryTableTimestamp + CachedTableManager.INSTANCE.preSignedUrlExpirationMs - } + }, + None ) dfs.reduce((df1, df2) => df1.unionAll(df2)) diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala index 18028eee8..92940dc39 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala @@ -236,12 +236,12 @@ class RemoteSnapshot( idToUrl, Seq(new WeakReference(fileIndex)), fileIndex.params.profileProvider, - () => { - // TODO: use tableFiles.refreshToken - val files = client.getFiles( - table, Nil, None, versionAsOf, timestampAsOf, jsonPredicateHints, None).files + refreshToken => { + val tableFiles = client.getFiles( + table, Nil, None, versionAsOf, timestampAsOf, jsonPredicateHints, refreshToken + ) var minUrlExpiration: Option[Long] = None - val idToUrl = files.map { add => + val idToUrl = tableFiles.files.map { add => if (add.expirationTimestamp != null) { minUrlExpiration = if (minUrlExpiration.isDefined && minUrlExpiration.get < add.expirationTimestamp) { @@ -252,13 +252,14 @@ class RemoteSnapshot( } add.id -> add.url }.toMap - (idToUrl, minUrlExpiration) + (idToUrl, minUrlExpiration, tableFiles.refreshToken) }, if (CachedTableManager.INSTANCE.isValidUrlExpirationTime(minUrlExpirationTimestamp)) { minUrlExpirationTimestamp.get } else { System.currentTimeMillis() + CachedTableManager.INSTANCE.preSignedUrlExpirationMs - } + }, + tableFiles.refreshToken ) checkProtocolNotChange(tableFiles.protocol) checkSchemaNotChange(tableFiles.metadata) From 00504cc9153f8e7caaeed0735809965715cecc39 Mon Sep 17 00:00:00 2001 From: Charlene Lyu Date: Fri, 1 Sep 2023 10:29:49 -0700 Subject: [PATCH 2/3] unit test --- .../sharing/CachedTableManagerSuite.scala | 97 ++++++++++++++----- 1 file changed, 73 insertions(+), 24 deletions(-) diff --git a/client/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala b/client/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala index 1d2793892..094c8fede 100644 --- a/client/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala +++ b/client/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala @@ -47,9 +47,10 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { - (Map("id1" -> "url1", "id2" -> "url2"), None) - }) + _ => { + (Map("id1" -> "url1", "id2" -> "url2"), None, None) + }, + refreshToken = None) assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"), "id1")._1 == "url1") assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"), @@ -60,9 +61,10 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { - (Map("id1" -> "url3", "id2" -> "url4"), None) - }) + _ => { + (Map("id1" -> "url3", "id2" -> "url4"), None, None) + }, + refreshToken = None) // We should get the new urls eventually eventually(timeout(10.seconds)) { assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path2"), @@ -76,9 +78,10 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(new AnyRef)), provider, - () => { - (Map("id1" -> "url3", "id2" -> "url4"), None) - }) + _ => { + (Map("id1" -> "url3", "id2" -> "url4"), None, None) + }, + refreshToken = None) // We should remove the cached table eventually eventually(timeout(10.seconds)) { System.gc() @@ -93,10 +96,11 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { - (Map("id1" -> "url3", "id2" -> "url4"), None) + _ => { + (Map("id1" -> "url3", "id2" -> "url4"), None, None) }, - -1 + -1, + refreshToken = None ) // We should get new urls immediately because it's refreshed upon register assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path4"), @@ -124,14 +128,16 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { + _ => { refreshTime += 1 ( Map("id1" -> ("url" + refreshTime.toString), "id2" -> "url4"), - Some(System.currentTimeMillis() + 1900) + Some(System.currentTimeMillis() + 1900), + None ) }, - System.currentTimeMillis() + 1900 + System.currentTimeMillis() + 1900, + None ) // We should refresh at least 5 times within 10 seconds based on // (System.currentTimeMillis() + 1900). @@ -148,14 +154,16 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { + _ => { refreshTime2 += 1 ( Map("id1" -> ("url" + refreshTime2.toString), "id2" -> "url4"), - Some(System.currentTimeMillis() + 4900) + Some(System.currentTimeMillis() + 4900), + None ) }, - System.currentTimeMillis() + 4900 + System.currentTimeMillis() + 4900, + None ) // We should refresh 2 times within 10 seconds based on (System.currentTimeMillis() + 4900). eventually(timeout(10.seconds)) { @@ -171,14 +179,16 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { + _ => { refreshTime3 += 1 ( Map("id1" -> ("url" + refreshTime3.toString), "id2" -> "url4"), - Some(System.currentTimeMillis() - 4900) + Some(System.currentTimeMillis() - 4900), + None ) }, - System.currentTimeMillis() + 6000 + System.currentTimeMillis() + 6000, + None ) // We should refresh 1 times within 10 seconds based on (preSignedUrlExpirationMs = 6000). try { @@ -197,6 +207,44 @@ class CachedTableManagerSuite extends SparkFunSuite { } } + test("refresh using refresh token") { + val manager = new CachedTableManager( + preSignedUrlExpirationMs = 10, + refreshCheckIntervalMs = 10, + refreshThresholdMs = 10, + expireAfterAccessMs = 60000 + ) + try { + val ref = new AnyRef + val provider = new TestDeltaSharingProfileProvider + manager.register( + "test-table-path", + Map("id1" -> "url1", "id2" -> "url2"), + Seq(new WeakReference(ref)), + provider, + refreshToken => { + if (refreshToken.contains("refresh-token-1")) { + (Map("id1" -> "url3", "id2" -> "url4"), None, Some("refresh-token-2")) + } else if (refreshToken.contains("refresh-token-2")) { + (Map("id1" -> "url5", "id2" -> "url6"), None, Some("refresh-token-2")) + } else { + fail("Expecting to refresh with a refresh token") + } + }, + refreshToken = Some("refresh-token-1") + ) + // We should get url5 and url6 eventually. + eventually(timeout(10.seconds)) { + assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"), + "id1")._1 == "url5") + assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"), + "id2")._1 == "url6") + } + } finally { + manager.stop() + } + } + test("expireAfterAccessMs") { val manager = new CachedTableManager( preSignedUrlExpirationMs = 10, @@ -213,9 +261,10 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { - (Map("id1" -> "url1", "id2" -> "url2"), None) - }) + _ => { + (Map("id1" -> "url1", "id2" -> "url2"), None, None) + }, + refreshToken = None) Thread.sleep(1000) // We should remove the cached table when it's not accessed intercept[IllegalStateException](manager.getPreSignedUrl( From a79872d14fd6385ad34a783b463ec0ffbe84e2ea Mon Sep 17 00:00:00 2001 From: Charlene Lyu Date: Fri, 1 Sep 2023 16:39:44 -0700 Subject: [PATCH 3/3] refactor --- .../client/DeltaSharingProfileProvider.scala | 4 +-- .../delta/sharing/PreSignedUrlCache.scala | 32 ++++++++++++------- .../sharing/CachedTableManagerSuite.scala | 28 ++++++++++------ .../sharing/spark/DeltaSharingSource.scala | 26 ++++----------- .../spark/RemoteDeltaCDFRelation.scala | 6 ++-- .../delta/sharing/spark/RemoteDeltaLog.scala | 14 ++------ 6 files changed, 54 insertions(+), 56 deletions(-) diff --git a/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala b/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala index 2a1f6acf2..37d855b7b 100644 --- a/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala +++ b/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets.UTF_8 import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.delta.sharing.TableRefreshResult import io.delta.sharing.client.util.JsonUtils @@ -50,8 +51,7 @@ trait DeltaSharingProfileProvider { // `refresher` takes an optional refreshToken, and returns // (idToUrlMap, minUrlExpirationTimestamp, refreshToken) def getCustomRefresher( - refresher: Option[String] => (Map[String, String], Option[Long], Option[String])) - : Option[String] => (Map[String, String], Option[Long], Option[String]) = { + refresher: Option[String] => TableRefreshResult): Option[String] => TableRefreshResult = { refresher } } diff --git a/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala b/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala index 3e6c10fe5..d2f819642 100644 --- a/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala +++ b/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala @@ -28,6 +28,12 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils} import io.delta.sharing.client.DeltaSharingProfileProvider +case class TableRefreshResult( + idToUrl: Map[String, String], + expirationTimestamp: Option[Long], + refreshToken: Option[String] +) + /** * @param expiration the expiration time of the pre signed urls * @param idToUrl the file id to pre sign url map @@ -45,7 +51,7 @@ class CachedTable( val idToUrl: Map[String, String], val refs: Seq[WeakReference[AnyRef]], @volatile var lastAccess: Long, - val refresher: Option[String] => (Map[String, String], Option[Long], Option[String]), + val refresher: Option[String] => TableRefreshResult, val refreshToken: Option[String]) class CachedTableManager( @@ -99,18 +105,18 @@ class CachedTableManager( logInfo(s"Updating pre signed urls for $tablePath (expiration time: " + s"${new java.util.Date(cachedTable.expiration)})") try { - val (idToUrl, expOpt, refreshToken) = cachedTable.refresher(cachedTable.refreshToken) + val refreshRes = cachedTable.refresher(cachedTable.refreshToken) val newTable = new CachedTable( - if (isValidUrlExpirationTime(expOpt)) { - expOpt.get + if (isValidUrlExpirationTime(refreshRes.expirationTimestamp)) { + refreshRes.expirationTimestamp.get } else { preSignedUrlExpirationMs + System.currentTimeMillis() }, - idToUrl, + refreshRes.idToUrl, cachedTable.refs, cachedTable.lastAccess, cachedTable.refresher, - refreshToken + refreshRes.refreshToken ) // Failing to replace the table is fine because if it did happen, we would retry after // `refreshCheckIntervalMs` milliseconds. @@ -168,7 +174,7 @@ class CachedTableManager( idToUrl: Map[String, String], refs: Seq[WeakReference[AnyRef]], profileProvider: DeltaSharingProfileProvider, - refresher: Option[String] => (Map[String, String], Option[Long], Option[String]), + refresher: Option[String] => TableRefreshResult, expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs, refreshToken: Option[String] ): Unit = { @@ -177,11 +183,15 @@ class CachedTableManager( val (resolvedIdToUrl, resolvedExpiration, resolvedRefreshToken) = if (expirationTimestamp - System.currentTimeMillis() < refreshThresholdMs) { - val (refreshedIdToUrl, expOpt, newRefreshToken) = customRefresher(refreshToken) - if (isValidUrlExpirationTime(expOpt)) { - (refreshedIdToUrl, expOpt.get, newRefreshToken) + val refreshRes = customRefresher(refreshToken) + if (isValidUrlExpirationTime(refreshRes.expirationTimestamp)) { + (refreshRes.idToUrl, refreshRes.expirationTimestamp.get, refreshRes.refreshToken) } else { - (refreshedIdToUrl, System.currentTimeMillis() + preSignedUrlExpirationMs, newRefreshToken) + ( + refreshRes.idToUrl, + System.currentTimeMillis() + preSignedUrlExpirationMs, + refreshRes.refreshToken + ) } } else { (idToUrl, expirationTimestamp, refreshToken) diff --git a/client/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala b/client/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala index 094c8fede..8899d96df 100644 --- a/client/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala +++ b/client/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala @@ -48,7 +48,7 @@ class CachedTableManagerSuite extends SparkFunSuite { Seq(new WeakReference(ref)), provider, _ => { - (Map("id1" -> "url1", "id2" -> "url2"), None, None) + TableRefreshResult(Map("id1" -> "url1", "id2" -> "url2"), None, None) }, refreshToken = None) assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"), @@ -62,7 +62,7 @@ class CachedTableManagerSuite extends SparkFunSuite { Seq(new WeakReference(ref)), provider, _ => { - (Map("id1" -> "url3", "id2" -> "url4"), None, None) + TableRefreshResult(Map("id1" -> "url3", "id2" -> "url4"), None, None) }, refreshToken = None) // We should get the new urls eventually @@ -79,7 +79,7 @@ class CachedTableManagerSuite extends SparkFunSuite { Seq(new WeakReference(new AnyRef)), provider, _ => { - (Map("id1" -> "url3", "id2" -> "url4"), None, None) + TableRefreshResult(Map("id1" -> "url3", "id2" -> "url4"), None, None) }, refreshToken = None) // We should remove the cached table eventually @@ -97,7 +97,7 @@ class CachedTableManagerSuite extends SparkFunSuite { Seq(new WeakReference(ref)), provider, _ => { - (Map("id1" -> "url3", "id2" -> "url4"), None, None) + TableRefreshResult(Map("id1" -> "url3", "id2" -> "url4"), None, None) }, -1, refreshToken = None @@ -130,7 +130,7 @@ class CachedTableManagerSuite extends SparkFunSuite { provider, _ => { refreshTime += 1 - ( + TableRefreshResult( Map("id1" -> ("url" + refreshTime.toString), "id2" -> "url4"), Some(System.currentTimeMillis() + 1900), None @@ -156,7 +156,7 @@ class CachedTableManagerSuite extends SparkFunSuite { provider, _ => { refreshTime2 += 1 - ( + TableRefreshResult( Map("id1" -> ("url" + refreshTime2.toString), "id2" -> "url4"), Some(System.currentTimeMillis() + 4900), None @@ -181,7 +181,7 @@ class CachedTableManagerSuite extends SparkFunSuite { provider, _ => { refreshTime3 += 1 - ( + TableRefreshResult( Map("id1" -> ("url" + refreshTime3.toString), "id2" -> "url4"), Some(System.currentTimeMillis() - 4900), None @@ -224,9 +224,17 @@ class CachedTableManagerSuite extends SparkFunSuite { provider, refreshToken => { if (refreshToken.contains("refresh-token-1")) { - (Map("id1" -> "url3", "id2" -> "url4"), None, Some("refresh-token-2")) + TableRefreshResult( + Map("id1" -> "url3", "id2" -> "url4"), + None, + Some("refresh-token-2") + ) } else if (refreshToken.contains("refresh-token-2")) { - (Map("id1" -> "url5", "id2" -> "url6"), None, Some("refresh-token-2")) + TableRefreshResult( + Map("id1" -> "url5", "id2" -> "url6"), + None, + Some("refresh-token-2") + ) } else { fail("Expecting to refresh with a refresh token") } @@ -262,7 +270,7 @@ class CachedTableManagerSuite extends SparkFunSuite { Seq(new WeakReference(ref)), provider, _ => { - (Map("id1" -> "url1", "id2" -> "url2"), None, None) + TableRefreshResult(Map("id1" -> "url1", "id2" -> "url2"), None, None) }, refreshToken = None) Thread.sleep(1000) diff --git a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala index e2c1d7555..ad4df8a4f 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala @@ -21,28 +21,16 @@ import java.lang.ref.WeakReference import scala.collection.mutable.ArrayBuffer -import org.apache.spark.delta.sharing.CachedTableManager +import org.apache.spark.delta.sharing.{CachedTableManager, TableRefreshResult} import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, DeltaSharingScanUtils, SparkSession} import org.apache.spark.sql.connector.read.streaming -import org.apache.spark.sql.connector.read.streaming.{ - ReadAllAvailable, - ReadLimit, - ReadMaxFiles, - SupportsAdmissionControl -} +import org.apache.spark.sql.connector.read.streaming.{ReadAllAvailable, ReadLimit, ReadMaxFiles, SupportsAdmissionControl} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.types.StructType -import io.delta.sharing.client.model.{ - AddCDCFile, - AddFile, - AddFileForCDF, - DeltaTableFiles, - FileAction, - RemoveFile -} +import io.delta.sharing.client.model.{AddCDCFile, AddFile, AddFileForCDF, DeltaTableFiles, FileAction, RemoveFile} import io.delta.sharing.spark.util.SchemaUtils /** @@ -163,7 +151,7 @@ case class DeltaSharingSource( // a variable to be used by the CachedTableManager to refresh the presigned urls if the query // runs for a long time. private var latestRefreshFunc = (_: Option[String]) => { - (Map.empty[String, String], Option.empty[Long], Option.empty[String]) + TableRefreshResult(Map.empty[String, String], None, None) } // Check the latest table version from the delta sharing server through the client.getTableVersion @@ -441,7 +429,7 @@ case class DeltaSharingSource( refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) - (idToUrl, minUrlExpiration, None) + TableRefreshResult(idToUrl, minUrlExpiration, None) } val numFiles = tableFiles.files.size @@ -495,7 +483,7 @@ case class DeltaSharingSource( refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) - (idToUrl, minUrlExpiration, None) + TableRefreshResult(idToUrl, minUrlExpiration, None) } val allAddFiles = validateCommitAndFilterAddFiles(tableFiles).groupBy(a => a.version) for (v <- fromVersion to endingVersionForQuery) { @@ -564,7 +552,7 @@ case class DeltaSharingSource( d.addFiles, d.cdfFiles, d.removeFiles) refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) - (idToUrl, minUrlExpiration, None) + TableRefreshResult(idToUrl, minUrlExpiration, None) } (Seq(tableFiles.metadata) ++ tableFiles.additionalMetadatas).foreach { m => diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala index 6e7b48fa4..62c53689e 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala @@ -20,7 +20,7 @@ import java.lang.ref.WeakReference import scala.collection.mutable.ListBuffer -import org.apache.spark.delta.sharing.CachedTableManager +import org.apache.spark.delta.sharing.{CachedTableManager, TableRefreshResult} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, DeltaSharingScanUtils, Row, SparkSession, SQLContext} import org.apache.spark.sql.execution.LogicalRDD @@ -58,7 +58,7 @@ case class RemoteDeltaCDFRelation( false, _ => { val d = client.getCDFFiles(table, cdfOptions, false) - ( + TableRefreshResult( DeltaSharingCDFReader.getIdToUrl(d.addFiles, d.cdfFiles, d.removeFiles), DeltaSharingCDFReader.getMinUrlExpiration(d.addFiles, d.cdfFiles, d.removeFiles), None @@ -83,7 +83,7 @@ object DeltaSharingCDFReader { removeFiles: Seq[RemoveFile], schema: StructType, isStreaming: Boolean, - refresher: Option[String] => (Map[String, String], Option[Long], Option[String]), + refresher: Option[String] => TableRefreshResult, lastQueryTableTimestamp: Long, expirationTimestamp: Option[Long] ): DataFrame = { diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala index 92940dc39..0ed75266f 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala @@ -20,7 +20,7 @@ import java.lang.ref.WeakReference import org.apache.hadoop.fs.Path import org.apache.spark.SparkException -import org.apache.spark.delta.sharing.CachedTableManager +import org.apache.spark.delta.sharing.{CachedTableManager, TableRefreshResult} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, Encoder, SparkSession} import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} @@ -32,15 +32,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.{DataType, StructField, StructType} import io.delta.sharing.client.{DeltaSharingClient, DeltaSharingProfileProvider, DeltaSharingRestClient} -import io.delta.sharing.client.model.{ - AddFile, - CDFColumnInfo, - DeltaTableFiles, - FileAction, - Metadata, - Protocol, - Table => DeltaSharingTable -} +import io.delta.sharing.client.model.{AddFile, CDFColumnInfo, DeltaTableFiles, FileAction, Metadata, Protocol, Table => DeltaSharingTable} import io.delta.sharing.client.util.ConfUtils import io.delta.sharing.spark.perf.DeltaSharingLimitPushDown @@ -252,7 +244,7 @@ class RemoteSnapshot( } add.id -> add.url }.toMap - (idToUrl, minUrlExpiration, tableFiles.refreshToken) + TableRefreshResult(idToUrl, minUrlExpiration, tableFiles.refreshToken) }, if (CachedTableManager.INSTANCE.isValidUrlExpirationTime(minUrlExpirationTimestamp)) { minUrlExpirationTimestamp.get