diff --git a/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala b/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala index a93324448..37d855b7b 100644 --- a/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala +++ b/client/src/main/scala/io/delta/sharing/client/DeltaSharingProfileProvider.scala @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets.UTF_8 import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.delta.sharing.TableRefreshResult import io.delta.sharing.client.util.JsonUtils @@ -47,10 +48,10 @@ trait DeltaSharingProfileProvider { def getCustomTablePath(tablePath: String): String = tablePath - // Map[String, String] is the id to url map. - // Long is the minimum url expiration time for all the urls. - def getCustomRefresher(refresher: () => (Map[String, String], Option[Long])): () => - (Map[String, String], Option[Long]) = { + // `refresher` takes an optional refreshToken, and returns + // (idToUrlMap, minUrlExpirationTimestamp, refreshToken) + def getCustomRefresher( + refresher: Option[String] => TableRefreshResult): Option[String] => TableRefreshResult = { refresher } } diff --git a/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala b/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala index 38cd05bd2..d2f819642 100644 --- a/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala +++ b/client/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala @@ -28,6 +28,12 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils} import io.delta.sharing.client.DeltaSharingProfileProvider +case class TableRefreshResult( + idToUrl: Map[String, String], + expirationTimestamp: Option[Long], + refreshToken: Option[String] +) + /** * @param expiration the expiration time of the pre signed urls * @param idToUrl the file id to pre sign url map @@ -35,15 +41,18 @@ import io.delta.sharing.client.DeltaSharingProfileProvider * remove the cached table from our cache. * @param lastAccess When the table was accessed last time. We will remove old tables that are not * accessed after `expireAfterAccessMs` milliseconds. - * @param refresher the function to generate a new file id to pre sign url map, as long as the new - * expiration timestamp of the urls. + * @param refresher the function to generate a new file id to pre sign url map, with the new + * expiration timestamp of the urls and the new refresh token. + * @param refreshToken the optional refresh token that can be used by the refresher to retrieve + * the same set of files with refreshed urls. 
 */
class CachedTable(
    val expiration: Long,
    val idToUrl: Map[String, String],
    val refs: Seq[WeakReference[AnyRef]],
    @volatile var lastAccess: Long,
-    val refresher: () => (Map[String, String], Option[Long]))
+    val refresher: Option[String] => TableRefreshResult,
+    val refreshToken: Option[String])

class CachedTableManager(
    val preSignedUrlExpirationMs: Long,
@@ -96,17 +105,18 @@ class CachedTableManager(
        logInfo(s"Updating pre signed urls for $tablePath (expiration time: " +
          s"${new java.util.Date(cachedTable.expiration)})")
        try {
-          val (idToUrl, expOpt) = cachedTable.refresher()
+          val refreshRes = cachedTable.refresher(cachedTable.refreshToken)
          val newTable = new CachedTable(
-            if (isValidUrlExpirationTime(expOpt)) {
-              expOpt.get
+            if (isValidUrlExpirationTime(refreshRes.expirationTimestamp)) {
+              refreshRes.expirationTimestamp.get
            } else {
              preSignedUrlExpirationMs + System.currentTimeMillis()
            },
-            idToUrl,
+            refreshRes.idToUrl,
            cachedTable.refs,
            cachedTable.lastAccess,
-            cachedTable.refresher
+            cachedTable.refresher,
+            refreshRes.refreshToken
          )
          // Failing to replace the table is fine because if it did happen, we would retry after
          // `refreshCheckIntervalMs` milliseconds.
@@ -149,35 +159,42 @@ class CachedTableManager(
   * @param tablePath the table path. This is usually the profile file path.
   * @param idToUrl the pre signed url map. This will be refreshed when the pre signed urls is going
   *                to expire.
-   * @param refs A list of weak references which can be used to determine whether the cache is
-   *             still needed. When all the weak references return null, we will remove the pre
-   *             signed url cache of this table form the cache.
+   * @param refs A list of weak references which can be used to determine whether the cache is
+   *              still needed. When all the weak references return null, we will remove the pre
+   *              signed url cache of this table from the cache.
   * @param profileProvider a profile Provider that can provide customized refresher function.
-   * @param refresher A function to re-generate pre signed urls for the table.
-   * @param expirationTimestamp Optional, If set, it's a timestamp to indicate the expiration
-   *                            timestamp of the idToUrl.
+   * @param refresher A function to re-generate pre signed urls for the table.
+   * @param expirationTimestamp Optional. If set, the expiration timestamp of the urls in
+   *                            idToUrl.
+   * @param refreshToken an optional refresh token that can be used by the refresher to retrieve
+   *                     the same set of files with refreshed urls.
*/ def register( tablePath: String, idToUrl: Map[String, String], refs: Seq[WeakReference[AnyRef]], profileProvider: DeltaSharingProfileProvider, - refresher: () => (Map[String, String], Option[Long]), - expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs - ): Unit = { + refresher: Option[String] => TableRefreshResult, + expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs, + refreshToken: Option[String] + ): Unit = { val customTablePath = profileProvider.getCustomTablePath(tablePath) val customRefresher = profileProvider.getCustomRefresher(refresher) - val (resolvedIdToUrl, resolvedExpiration) = + val (resolvedIdToUrl, resolvedExpiration, resolvedRefreshToken) = if (expirationTimestamp - System.currentTimeMillis() < refreshThresholdMs) { - val (refreshedIdToUrl, expOpt) = customRefresher() - if (isValidUrlExpirationTime(expOpt)) { - (refreshedIdToUrl, expOpt.get) + val refreshRes = customRefresher(refreshToken) + if (isValidUrlExpirationTime(refreshRes.expirationTimestamp)) { + (refreshRes.idToUrl, refreshRes.expirationTimestamp.get, refreshRes.refreshToken) } else { - (refreshedIdToUrl, System.currentTimeMillis() + preSignedUrlExpirationMs) + ( + refreshRes.idToUrl, + System.currentTimeMillis() + preSignedUrlExpirationMs, + refreshRes.refreshToken + ) } } else { - (idToUrl, expirationTimestamp) + (idToUrl, expirationTimestamp, refreshToken) } val cachedTable = new CachedTable( @@ -185,7 +202,8 @@ class CachedTableManager( idToUrl = resolvedIdToUrl, refs, System.currentTimeMillis(), - customRefresher + customRefresher, + resolvedRefreshToken ) var oldTable = cache.putIfAbsent(customTablePath, cachedTable) if (oldTable == null) { @@ -203,7 +221,8 @@ class CachedTableManager( // Try to avoid storing duplicate references refs.filterNot(ref => oldTable.refs.exists(_.get eq ref.get)) ++ oldTable.refs, lastAccess = System.currentTimeMillis(), - customRefresher + customRefresher, + cachedTable.refreshToken ) if (cache.replace(customTablePath, oldTable, mergedTable)) { // Put the merged one to the cache diff --git a/client/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala b/client/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala index 1d2793892..8899d96df 100644 --- a/client/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala +++ b/client/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala @@ -47,9 +47,10 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { - (Map("id1" -> "url1", "id2" -> "url2"), None) - }) + _ => { + TableRefreshResult(Map("id1" -> "url1", "id2" -> "url2"), None, None) + }, + refreshToken = None) assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"), "id1")._1 == "url1") assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"), @@ -60,9 +61,10 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { - (Map("id1" -> "url3", "id2" -> "url4"), None) - }) + _ => { + TableRefreshResult(Map("id1" -> "url3", "id2" -> "url4"), None, None) + }, + refreshToken = None) // We should get the new urls eventually eventually(timeout(10.seconds)) { assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path2"), @@ -76,9 +78,10 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> 
"url1", "id2" -> "url2"), Seq(new WeakReference(new AnyRef)), provider, - () => { - (Map("id1" -> "url3", "id2" -> "url4"), None) - }) + _ => { + TableRefreshResult(Map("id1" -> "url3", "id2" -> "url4"), None, None) + }, + refreshToken = None) // We should remove the cached table eventually eventually(timeout(10.seconds)) { System.gc() @@ -93,10 +96,11 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { - (Map("id1" -> "url3", "id2" -> "url4"), None) + _ => { + TableRefreshResult(Map("id1" -> "url3", "id2" -> "url4"), None, None) }, - -1 + -1, + refreshToken = None ) // We should get new urls immediately because it's refreshed upon register assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path4"), @@ -124,14 +128,16 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { + _ => { refreshTime += 1 - ( + TableRefreshResult( Map("id1" -> ("url" + refreshTime.toString), "id2" -> "url4"), - Some(System.currentTimeMillis() + 1900) + Some(System.currentTimeMillis() + 1900), + None ) }, - System.currentTimeMillis() + 1900 + System.currentTimeMillis() + 1900, + None ) // We should refresh at least 5 times within 10 seconds based on // (System.currentTimeMillis() + 1900). @@ -148,14 +154,16 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { + _ => { refreshTime2 += 1 - ( + TableRefreshResult( Map("id1" -> ("url" + refreshTime2.toString), "id2" -> "url4"), - Some(System.currentTimeMillis() + 4900) + Some(System.currentTimeMillis() + 4900), + None ) }, - System.currentTimeMillis() + 4900 + System.currentTimeMillis() + 4900, + None ) // We should refresh 2 times within 10 seconds based on (System.currentTimeMillis() + 4900). eventually(timeout(10.seconds)) { @@ -171,14 +179,16 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { + _ => { refreshTime3 += 1 - ( + TableRefreshResult( Map("id1" -> ("url" + refreshTime3.toString), "id2" -> "url4"), - Some(System.currentTimeMillis() - 4900) + Some(System.currentTimeMillis() - 4900), + None ) }, - System.currentTimeMillis() + 6000 + System.currentTimeMillis() + 6000, + None ) // We should refresh 1 times within 10 seconds based on (preSignedUrlExpirationMs = 6000). try { @@ -197,6 +207,52 @@ class CachedTableManagerSuite extends SparkFunSuite { } } + test("refresh using refresh token") { + val manager = new CachedTableManager( + preSignedUrlExpirationMs = 10, + refreshCheckIntervalMs = 10, + refreshThresholdMs = 10, + expireAfterAccessMs = 60000 + ) + try { + val ref = new AnyRef + val provider = new TestDeltaSharingProfileProvider + manager.register( + "test-table-path", + Map("id1" -> "url1", "id2" -> "url2"), + Seq(new WeakReference(ref)), + provider, + refreshToken => { + if (refreshToken.contains("refresh-token-1")) { + TableRefreshResult( + Map("id1" -> "url3", "id2" -> "url4"), + None, + Some("refresh-token-2") + ) + } else if (refreshToken.contains("refresh-token-2")) { + TableRefreshResult( + Map("id1" -> "url5", "id2" -> "url6"), + None, + Some("refresh-token-2") + ) + } else { + fail("Expecting to refresh with a refresh token") + } + }, + refreshToken = Some("refresh-token-1") + ) + // We should get url5 and url6 eventually. 
+ eventually(timeout(10.seconds)) { + assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"), + "id1")._1 == "url5") + assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"), + "id2")._1 == "url6") + } + } finally { + manager.stop() + } + } + test("expireAfterAccessMs") { val manager = new CachedTableManager( preSignedUrlExpirationMs = 10, @@ -213,9 +269,10 @@ class CachedTableManagerSuite extends SparkFunSuite { Map("id1" -> "url1", "id2" -> "url2"), Seq(new WeakReference(ref)), provider, - () => { - (Map("id1" -> "url1", "id2" -> "url2"), None) - }) + _ => { + TableRefreshResult(Map("id1" -> "url1", "id2" -> "url2"), None, None) + }, + refreshToken = None) Thread.sleep(1000) // We should remove the cached table when it's not accessed intercept[IllegalStateException](manager.getPreSignedUrl( diff --git a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala index 2fdfede19..ad4df8a4f 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala @@ -21,28 +21,16 @@ import java.lang.ref.WeakReference import scala.collection.mutable.ArrayBuffer -import org.apache.spark.delta.sharing.CachedTableManager +import org.apache.spark.delta.sharing.{CachedTableManager, TableRefreshResult} import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, DeltaSharingScanUtils, SparkSession} import org.apache.spark.sql.connector.read.streaming -import org.apache.spark.sql.connector.read.streaming.{ - ReadAllAvailable, - ReadLimit, - ReadMaxFiles, - SupportsAdmissionControl -} +import org.apache.spark.sql.connector.read.streaming.{ReadAllAvailable, ReadLimit, ReadMaxFiles, SupportsAdmissionControl} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.types.StructType -import io.delta.sharing.client.model.{ - AddCDCFile, - AddFile, - AddFileForCDF, - DeltaTableFiles, - FileAction, - RemoveFile -} +import io.delta.sharing.client.model.{AddCDCFile, AddFile, AddFileForCDF, DeltaTableFiles, FileAction, RemoveFile} import io.delta.sharing.spark.util.SchemaUtils /** @@ -162,7 +150,9 @@ case class DeltaSharingSource( // The latest function used to fetch presigned urls for the delta sharing table, record it in // a variable to be used by the CachedTableManager to refresh the presigned urls if the query // runs for a long time. - private var latestRefreshFunc = () => { (Map.empty[String, String], None: Option[Long]) } + private var latestRefreshFunc = (_: Option[String]) => { + TableRefreshResult(Map.empty[String, String], None, None) + } // Check the latest table version from the delta sharing server through the client.getTableVersion // RPC. 
Adding a minimum interval of QUERY_TABLE_VERSION_INTERVAL_MILLIS between two consecutive @@ -413,7 +403,7 @@ case class DeltaSharingSource( jsonPredicateHints = None, refreshToken = None ) - latestRefreshFunc = () => { + latestRefreshFunc = _ => { val queryTimestamp = System.currentTimeMillis() val files = deltaLog.client.getFiles( table = deltaLog.table, @@ -439,7 +429,7 @@ case class DeltaSharingSource( refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) - (idToUrl, minUrlExpiration) + TableRefreshResult(idToUrl, minUrlExpiration, None) } val numFiles = tableFiles.files.size @@ -473,7 +463,7 @@ case class DeltaSharingSource( val tableFiles = deltaLog.client.getFiles( deltaLog.table, fromVersion, Some(endingVersionForQuery) ) - latestRefreshFunc = () => { + latestRefreshFunc = _ => { val queryTimestamp = System.currentTimeMillis() val addFiles = deltaLog.client.getFiles( deltaLog.table, fromVersion, Some(endingVersionForQuery) @@ -493,7 +483,7 @@ case class DeltaSharingSource( refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) - (idToUrl, minUrlExpiration) + TableRefreshResult(idToUrl, minUrlExpiration, None) } val allAddFiles = validateCommitAndFilterAddFiles(tableFiles).groupBy(a => a.version) for (v <- fromVersion to endingVersionForQuery) { @@ -545,7 +535,7 @@ case class DeltaSharingSource( ), true ) - latestRefreshFunc = () => { + latestRefreshFunc = _ => { val queryTimestamp = System.currentTimeMillis() val d = deltaLog.client.getCDFFiles( deltaLog.table, @@ -562,10 +552,7 @@ case class DeltaSharingSource( d.addFiles, d.cdfFiles, d.removeFiles) refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) - ( - idToUrl, - minUrlExpiration - ) + TableRefreshResult(idToUrl, minUrlExpiration, None) } (Seq(tableFiles.metadata) ++ tableFiles.additionalMetadatas).foreach { m => @@ -774,7 +761,8 @@ case class DeltaSharingSource( urlExpirationTimestamp.get } else { lastQueryTimestamp + CachedTableManager.INSTANCE.preSignedUrlExpirationMs - } + }, + None ) diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala index 8341643ca..62c53689e 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala @@ -20,7 +20,7 @@ import java.lang.ref.WeakReference import scala.collection.mutable.ListBuffer -import org.apache.spark.delta.sharing.CachedTableManager +import org.apache.spark.delta.sharing.{CachedTableManager, TableRefreshResult} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, DeltaSharingScanUtils, Row, SparkSession, SQLContext} import org.apache.spark.sql.execution.LogicalRDD @@ -56,11 +56,12 @@ case class RemoteDeltaCDFRelation( deltaTabelFiles.removeFiles, DeltaTableUtils.addCdcSchema(deltaTabelFiles.metadata.schemaString), false, - () => { + _ => { val d = client.getCDFFiles(table, cdfOptions, false) - ( + TableRefreshResult( DeltaSharingCDFReader.getIdToUrl(d.addFiles, d.cdfFiles, d.removeFiles), - DeltaSharingCDFReader.getMinUrlExpiration(d.addFiles, d.cdfFiles, d.removeFiles) + DeltaSharingCDFReader.getMinUrlExpiration(d.addFiles, d.cdfFiles, d.removeFiles), + None ) }, System.currentTimeMillis(), @@ -82,7 +83,7 @@ object DeltaSharingCDFReader { removeFiles: Seq[RemoveFile], schema: StructType, isStreaming: Boolean, - refresher: () => (Map[String, String], Option[Long]), + refresher: Option[String] => 
TableRefreshResult, lastQueryTableTimestamp: Long, expirationTimestamp: Option[Long] ): DataFrame = { @@ -111,7 +112,8 @@ object DeltaSharingCDFReader { expirationTimestamp.get } else { lastQueryTableTimestamp + CachedTableManager.INSTANCE.preSignedUrlExpirationMs - } + }, + None ) dfs.reduce((df1, df2) => df1.unionAll(df2)) diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala index 18028eee8..0ed75266f 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala @@ -20,7 +20,7 @@ import java.lang.ref.WeakReference import org.apache.hadoop.fs.Path import org.apache.spark.SparkException -import org.apache.spark.delta.sharing.CachedTableManager +import org.apache.spark.delta.sharing.{CachedTableManager, TableRefreshResult} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, Encoder, SparkSession} import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} @@ -32,15 +32,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.{DataType, StructField, StructType} import io.delta.sharing.client.{DeltaSharingClient, DeltaSharingProfileProvider, DeltaSharingRestClient} -import io.delta.sharing.client.model.{ - AddFile, - CDFColumnInfo, - DeltaTableFiles, - FileAction, - Metadata, - Protocol, - Table => DeltaSharingTable -} +import io.delta.sharing.client.model.{AddFile, CDFColumnInfo, DeltaTableFiles, FileAction, Metadata, Protocol, Table => DeltaSharingTable} import io.delta.sharing.client.util.ConfUtils import io.delta.sharing.spark.perf.DeltaSharingLimitPushDown @@ -236,12 +228,12 @@ class RemoteSnapshot( idToUrl, Seq(new WeakReference(fileIndex)), fileIndex.params.profileProvider, - () => { - // TODO: use tableFiles.refreshToken - val files = client.getFiles( - table, Nil, None, versionAsOf, timestampAsOf, jsonPredicateHints, None).files + refreshToken => { + val tableFiles = client.getFiles( + table, Nil, None, versionAsOf, timestampAsOf, jsonPredicateHints, refreshToken + ) var minUrlExpiration: Option[Long] = None - val idToUrl = files.map { add => + val idToUrl = tableFiles.files.map { add => if (add.expirationTimestamp != null) { minUrlExpiration = if (minUrlExpiration.isDefined && minUrlExpiration.get < add.expirationTimestamp) { @@ -252,13 +244,14 @@ class RemoteSnapshot( } add.id -> add.url }.toMap - (idToUrl, minUrlExpiration) + TableRefreshResult(idToUrl, minUrlExpiration, tableFiles.refreshToken) }, if (CachedTableManager.INSTANCE.isValidUrlExpirationTime(minUrlExpirationTimestamp)) { minUrlExpirationTimestamp.get } else { System.currentTimeMillis() + CachedTableManager.INSTANCE.preSignedUrlExpirationMs - } + }, + tableFiles.refreshToken ) checkProtocolNotChange(tableFiles.protocol) checkSchemaNotChange(tableFiles.metadata)
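---

For illustration, the sketch below shows how a caller would wire the new token-aware
refresher into `CachedTableManager.register`. It is a minimal sketch, not part of the
patch: `RefreshTokenExample`, `registerWithRefreshToken`, `fetchPreSignedUrls`, and
`anchor` are hypothetical names, with `fetchPreSignedUrls` standing in for a server
call such as `client.getFiles(..., refreshToken)`; only `CachedTableManager`,
`TableRefreshResult`, and the `register` signature come from this change.

import java.lang.ref.WeakReference

import org.apache.spark.delta.sharing.{CachedTableManager, TableRefreshResult}

import io.delta.sharing.client.DeltaSharingProfileProvider

object RefreshTokenExample {
  // Hypothetical helper: registers a table whose refresher threads the refresh
  // token returned by each call back into the next one.
  def registerWithRefreshToken(
      tablePath: String,
      profileProvider: DeltaSharingProfileProvider,
      anchor: AnyRef,
      fetchPreSignedUrls: Option[String] => TableRefreshResult): Unit = {
    // First fetch: no token yet, so the server issues the initial one.
    val initial = fetchPreSignedUrls(None)
    CachedTableManager.INSTANCE.register(
      tablePath = tablePath,
      idToUrl = initial.idToUrl,
      refs = Seq(new WeakReference(anchor)),
      profileProvider = profileProvider,
      // Each background refresh receives the token stored on the CachedTable, and
      // the cache keeps whatever token the refresher returns for the next round.
      refresher = token => fetchPreSignedUrls(token),
      expirationTimestamp = initial.expirationTimestamp.getOrElse(
        System.currentTimeMillis() + CachedTableManager.INSTANCE.preSignedUrlExpirationMs),
      refreshToken = initial.refreshToken
    )
  }
}

Threading the token through lets the server return the same set of files with refreshed
urls on every refresh, which is why RemoteDeltaLog.scala above now passes `refreshToken`
to `client.getFiles` instead of `None`, replacing the old TODO.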