refactor
charlenelyu-db committed Sep 1, 2023
1 parent 6dc8ba1 commit 02ddb7e
Showing 6 changed files with 54 additions and 56 deletions.
DeltaSharingProfileProvider.scala

@@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets.UTF_8
 import org.apache.commons.io.IOUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
+import org.apache.spark.delta.sharing.TableRefreshResult

 import io.delta.sharing.client.util.JsonUtils

@@ -50,8 +51,7 @@ trait DeltaSharingProfileProvider {
   // `refresher` takes an optional refreshToken, and returns
   // (idToUrlMap, minUrlExpirationTimestamp, refreshToken)
   def getCustomRefresher(
-      refresher: Option[String] => (Map[String, String], Option[Long], Option[String]))
-    : Option[String] => (Map[String, String], Option[Long], Option[String]) = {
+      refresher: Option[String] => TableRefreshResult): Option[String] => TableRefreshResult = {
     refresher
   }
 }
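For orientation: getCustomRefresher now passes refreshers around as Option[String] => TableRefreshResult, and the comment above still lists the old tuple components, which survive as the case class's named fields (defined in the next file). A minimal sketch of a provider that decorates the refresher it is given (the logging wrapper and its names are hypothetical, not part of this commit):

    // Sketch: wrap a refresher to observe each refresh while preserving the
    // Option[String] => TableRefreshResult contract. `underlying` supplies the
    // real profile; println stands in for real logging.
    class LoggingProfileProvider(underlying: DeltaSharingProfileProvider)
        extends DeltaSharingProfileProvider {
      override def getProfile: DeltaSharingProfile = underlying.getProfile

      override def getCustomRefresher(
          refresher: Option[String] => TableRefreshResult)
        : Option[String] => TableRefreshResult = { refreshToken =>
        val result = refresher(refreshToken)
        println(s"Refreshed ${result.idToUrl.size} pre-signed urls, " +
          s"next token: ${result.refreshToken}")
        result
      }
    }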
CachedTableManager.scala

@@ -28,6 +28,12 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils}

 import io.delta.sharing.client.DeltaSharingProfileProvider

+case class TableRefreshResult(
+    idToUrl: Map[String, String],
+    expirationTimestamp: Option[Long],
+    refreshToken: Option[String]
+)
+
 /**
  * @param expiration the expiration time of the pre signed urls
  * @param idToUrl the file id to pre sign url map
@@ -45,7 +51,7 @@ class CachedTable(
     val idToUrl: Map[String, String],
     val refs: Seq[WeakReference[AnyRef]],
     @volatile var lastAccess: Long,
-    val refresher: Option[String] => (Map[String, String], Option[Long], Option[String]),
+    val refresher: Option[String] => TableRefreshResult,
     val refreshToken: Option[String])

 class CachedTableManager(
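Since CachedTable.refresher now has type Option[String] => TableRefreshResult, a refresher is just a function that builds the case class, and named fields replace positional tuple access at every call site. A minimal sketch (url, horizon, and token handling are illustrative):

    // Sketch: construct the new result type instead of a 3-tuple.
    val refresher: Option[String] => TableRefreshResult = refreshToken =>
      TableRefreshResult(
        idToUrl = Map("file-id-1" -> "https://example.com/presigned-1"),
        expirationTimestamp = Some(System.currentTimeMillis() + 15 * 60 * 1000L),
        refreshToken = refreshToken // hand the server's token back next time
      )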
@@ -99,18 +105,18 @@
         logInfo(s"Updating pre signed urls for $tablePath (expiration time: " +
           s"${new java.util.Date(cachedTable.expiration)})")
         try {
-          val (idToUrl, expOpt, refreshToken) = cachedTable.refresher(cachedTable.refreshToken)
+          val refreshRes = cachedTable.refresher(cachedTable.refreshToken)
           val newTable = new CachedTable(
-            if (isValidUrlExpirationTime(expOpt)) {
-              expOpt.get
+            if (isValidUrlExpirationTime(refreshRes.expirationTimestamp)) {
+              refreshRes.expirationTimestamp.get
             } else {
               preSignedUrlExpirationMs + System.currentTimeMillis()
             },
-            idToUrl,
+            refreshRes.idToUrl,
             cachedTable.refs,
             cachedTable.lastAccess,
             cachedTable.refresher,
-            refreshToken
+            refreshRes.refreshToken
           )
           // Failing to replace the table is fine because if it did happen, we would retry after
           // `refreshCheckIntervalMs` milliseconds.
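The conditional above trusts the refreshed expiration only when isValidUrlExpirationTime accepts it, and otherwise falls back to a fixed horizon from now. The same decision, isolated as a sketch (resolveExpiration is a made-up helper; the two names it uses are the manager's own):

    // Sketch: choose the next expiration for a refreshed cache entry.
    def resolveExpiration(refreshRes: TableRefreshResult): Long =
      if (isValidUrlExpirationTime(refreshRes.expirationTimestamp)) {
        refreshRes.expirationTimestamp.get // server-provided timestamp
      } else {
        preSignedUrlExpirationMs + System.currentTimeMillis() // fixed fallback
      }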
@@ -168,7 +174,7 @@
       idToUrl: Map[String, String],
       refs: Seq[WeakReference[AnyRef]],
       profileProvider: DeltaSharingProfileProvider,
-      refresher: Option[String] => (Map[String, String], Option[Long], Option[String]),
+      refresher: Option[String] => TableRefreshResult,
       expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs,
       refreshToken: Option[String]
   ): Unit = {
@@ -177,11 +183,15 @@

     val (resolvedIdToUrl, resolvedExpiration, resolvedRefreshToken) =
       if (expirationTimestamp - System.currentTimeMillis() < refreshThresholdMs) {
-        val (refreshedIdToUrl, expOpt, newRefreshToken) = customRefresher(refreshToken)
-        if (isValidUrlExpirationTime(expOpt)) {
-          (refreshedIdToUrl, expOpt.get, newRefreshToken)
+        val refreshRes = customRefresher(refreshToken)
+        if (isValidUrlExpirationTime(refreshRes.expirationTimestamp)) {
+          (refreshRes.idToUrl, refreshRes.expirationTimestamp.get, refreshRes.refreshToken)
         } else {
-          (refreshedIdToUrl, System.currentTimeMillis() + preSignedUrlExpirationMs, newRefreshToken)
+          (
+            refreshRes.idToUrl,
+            System.currentTimeMillis() + preSignedUrlExpirationMs,
+            refreshRes.refreshToken
+          )
         }
       } else {
         (idToUrl, expirationTimestamp, refreshToken)
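register takes the initial url map plus the refresher, and the hunk above makes it refresh eagerly when the supplied expiration is already within refreshThresholdMs of now. A condensed usage sketch in the style of the test suite that follows (ref, provider, and the urls are illustrative):

    // Sketch: register a table with the cache; the entry can be evicted once
    // `ref` is garbage collected. expirationTimestamp keeps its default here.
    CachedTableManager.INSTANCE.register(
      "prefix.test-table-path",
      Map("id1" -> "url1", "id2" -> "url2"),
      Seq(new WeakReference(ref)),
      provider,
      _ => TableRefreshResult(Map("id1" -> "url1", "id2" -> "url2"), None, None),
      refreshToken = None
    )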
CachedTableManagerSuite.scala

@@ -48,7 +48,7 @@ class CachedTableManagerSuite extends SparkFunSuite {
       Seq(new WeakReference(ref)),
       provider,
       _ => {
-        (Map("id1" -> "url1", "id2" -> "url2"), None, None)
+        TableRefreshResult(Map("id1" -> "url1", "id2" -> "url2"), None, None)
       },
       refreshToken = None)
     assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"),
@@ -62,7 +62,7 @@
       Seq(new WeakReference(ref)),
       provider,
       _ => {
-        (Map("id1" -> "url3", "id2" -> "url4"), None, None)
+        TableRefreshResult(Map("id1" -> "url3", "id2" -> "url4"), None, None)
       },
       refreshToken = None)
     // We should get the new urls eventually
@@ -79,7 +79,7 @@
       Seq(new WeakReference(new AnyRef)),
       provider,
       _ => {
-        (Map("id1" -> "url3", "id2" -> "url4"), None, None)
+        TableRefreshResult(Map("id1" -> "url3", "id2" -> "url4"), None, None)
       },
       refreshToken = None)
     // We should remove the cached table eventually
@@ -97,7 +97,7 @@
       Seq(new WeakReference(ref)),
       provider,
       _ => {
-        (Map("id1" -> "url3", "id2" -> "url4"), None, None)
+        TableRefreshResult(Map("id1" -> "url3", "id2" -> "url4"), None, None)
       },
       -1,
       refreshToken = None
@@ -130,7 +130,7 @@
       provider,
       _ => {
         refreshTime += 1
-        (
+        TableRefreshResult(
           Map("id1" -> ("url" + refreshTime.toString), "id2" -> "url4"),
           Some(System.currentTimeMillis() + 1900),
           None
@@ -156,7 +156,7 @@
       provider,
       _ => {
         refreshTime2 += 1
-        (
+        TableRefreshResult(
           Map("id1" -> ("url" + refreshTime2.toString), "id2" -> "url4"),
           Some(System.currentTimeMillis() + 4900),
           None
@@ -181,7 +181,7 @@
       provider,
       _ => {
         refreshTime3 += 1
-        (
+        TableRefreshResult(
           Map("id1" -> ("url" + refreshTime3.toString), "id2" -> "url4"),
           Some(System.currentTimeMillis() - 4900),
           None
@@ -224,9 +224,17 @@
       provider,
       refreshToken => {
         if (refreshToken.contains("refresh-token-1")) {
-          (Map("id1" -> "url3", "id2" -> "url4"), None, Some("refresh-token-2"))
+          TableRefreshResult(
+            Map("id1" -> "url3", "id2" -> "url4"),
+            None,
+            Some("refresh-token-2")
+          )
         } else if (refreshToken.contains("refresh-token-2")) {
-          (Map("id1" -> "url5", "id2" -> "url6"), None, Some("refresh-token-2"))
+          TableRefreshResult(
+            Map("id1" -> "url5", "id2" -> "url6"),
+            None,
+            Some("refresh-token-2")
+          )
         } else {
           fail("Expecting to refresh with a refresh token")
         }
@@ -262,7 +270,7 @@
       Seq(new WeakReference(ref)),
       provider,
       _ => {
-        (Map("id1" -> "url1", "id2" -> "url2"), None, None)
+        TableRefreshResult(Map("id1" -> "url1", "id2" -> "url2"), None, None)
       },
       refreshToken = None)
     Thread.sleep(1000)
DeltaSharingSource.scala

@@ -21,28 +21,16 @@ import java.lang.ref.WeakReference

 import scala.collection.mutable.ArrayBuffer

-import org.apache.spark.delta.sharing.CachedTableManager
+import org.apache.spark.delta.sharing.{CachedTableManager, TableRefreshResult}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, DeltaSharingScanUtils, SparkSession}
 import org.apache.spark.sql.connector.read.streaming
-import org.apache.spark.sql.connector.read.streaming.{
-  ReadAllAvailable,
-  ReadLimit,
-  ReadMaxFiles,
-  SupportsAdmissionControl
-}
+import org.apache.spark.sql.connector.read.streaming.{ReadAllAvailable, ReadLimit, ReadMaxFiles, SupportsAdmissionControl}
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.types.StructType

-import io.delta.sharing.client.model.{
-  AddCDCFile,
-  AddFile,
-  AddFileForCDF,
-  DeltaTableFiles,
-  FileAction,
-  RemoveFile
-}
+import io.delta.sharing.client.model.{AddCDCFile, AddFile, AddFileForCDF, DeltaTableFiles, FileAction, RemoveFile}
 import io.delta.sharing.spark.util.SchemaUtils

 /**
@@ -163,7 +151,7 @@ case class DeltaSharingSource(
   // a variable to be used by the CachedTableManager to refresh the presigned urls if the query
   // runs for a long time.
   private var latestRefreshFunc = (_: Option[String]) => {
-    (Map.empty[String, String], Option.empty[Long], Option.empty[String])
+    TableRefreshResult(Map.empty[String, String], None, None)
   }

   // Check the latest table version from the delta sharing server through the client.getTableVersion
@@ -441,7 +429,7 @@

       refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration)

-      (idToUrl, minUrlExpiration, None)
+      TableRefreshResult(idToUrl, minUrlExpiration, None)
     }

     val numFiles = tableFiles.files.size
@@ -495,7 +483,7 @@

       refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration)

-      (idToUrl, minUrlExpiration, None)
+      TableRefreshResult(idToUrl, minUrlExpiration, None)
     }
     val allAddFiles = validateCommitAndFilterAddFiles(tableFiles).groupBy(a => a.version)
     for (v <- fromVersion to endingVersionForQuery) {
@@ -564,7 +552,7 @@
         d.addFiles, d.cdfFiles, d.removeFiles)
       refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration)

-      (idToUrl, minUrlExpiration, None)
+      TableRefreshResult(idToUrl, minUrlExpiration, None)
     }

     (Seq(tableFiles.metadata) ++ tableFiles.additionalMetadatas).foreach { m =>
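Each refresh closure above rebuilds idToUrl from the freshly fetched files and reports the smallest per-file expiration as minUrlExpiration, with None for the refresh token since the streaming source does not use one. The derivation as a sketch over an assumed file shape (PresignedFile is a stand-in for the client's file actions):

    // Sketch: gather id -> url and the earliest expiration across files.
    case class PresignedFile(id: String, url: String, expirationTimestamp: Option[Long])

    def toRefreshResult(files: Seq[PresignedFile]): TableRefreshResult = {
      val idToUrl = files.map(f => f.id -> f.url).toMap
      val minUrlExpiration =
        files.flatMap(_.expirationTimestamp).reduceOption(math.min(_, _))
      TableRefreshResult(idToUrl, minUrlExpiration, None)
    }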
RemoteDeltaCDFRelation.scala

@@ -20,7 +20,7 @@ import java.lang.ref.WeakReference

 import scala.collection.mutable.ListBuffer

-import org.apache.spark.delta.sharing.CachedTableManager
+import org.apache.spark.delta.sharing.{CachedTableManager, TableRefreshResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, DeltaSharingScanUtils, Row, SparkSession, SQLContext}
 import org.apache.spark.sql.execution.LogicalRDD
@@ -58,7 +58,7 @@ case class RemoteDeltaCDFRelation(
       false,
       _ => {
         val d = client.getCDFFiles(table, cdfOptions, false)
-        (
+        TableRefreshResult(
           DeltaSharingCDFReader.getIdToUrl(d.addFiles, d.cdfFiles, d.removeFiles),
           DeltaSharingCDFReader.getMinUrlExpiration(d.addFiles, d.cdfFiles, d.removeFiles),
           None
@@ -83,7 +83,7 @@ object DeltaSharingCDFReader {
       removeFiles: Seq[RemoveFile],
       schema: StructType,
       isStreaming: Boolean,
-      refresher: Option[String] => (Map[String, String], Option[Long], Option[String]),
+      refresher: Option[String] => TableRefreshResult,
       lastQueryTableTimestamp: Long,
       expirationTimestamp: Option[Long]
   ): DataFrame = {
14 changes: 3 additions & 11 deletions spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala
@@ -20,7 +20,7 @@ import java.lang.ref.WeakReference

 import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkException
-import org.apache.spark.delta.sharing.CachedTableManager
+import org.apache.spark.delta.sharing.{CachedTableManager, TableRefreshResult}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, Encoder, SparkSession}
 import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
@@ -32,15 +32,7 @@ import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.{DataType, StructField, StructType}

 import io.delta.sharing.client.{DeltaSharingClient, DeltaSharingProfileProvider, DeltaSharingRestClient}
-import io.delta.sharing.client.model.{
-  AddFile,
-  CDFColumnInfo,
-  DeltaTableFiles,
-  FileAction,
-  Metadata,
-  Protocol,
-  Table => DeltaSharingTable
-}
+import io.delta.sharing.client.model.{AddFile, CDFColumnInfo, DeltaTableFiles, FileAction, Metadata, Protocol, Table => DeltaSharingTable}
 import io.delta.sharing.client.util.ConfUtils
 import io.delta.sharing.spark.perf.DeltaSharingLimitPushDown

@@ -252,7 +244,7 @@ class RemoteSnapshot(
         }
         add.id -> add.url
       }.toMap
-      (idToUrl, minUrlExpiration, tableFiles.refreshToken)
+      TableRefreshResult(idToUrl, minUrlExpiration, tableFiles.refreshToken)
     },
     if (CachedTableManager.INSTANCE.isValidUrlExpirationTime(minUrlExpirationTimestamp)) {
       minUrlExpirationTimestamp.get
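Unlike the streaming and CDF paths, the snapshot refresher above forwards tableFiles.refreshToken, so the next refresh returns it to the server and picks up whatever token comes back. Schematically (fetchFiles and its response shape are stand-ins, not the client's real API):

    // Sketch: thread the refresh token through successive refreshes.
    val refresher: Option[String] => TableRefreshResult = refreshToken => {
      val response = fetchFiles(table, refreshToken) // hypothetical fetch
      TableRefreshResult(
        response.files.map(f => f.id -> f.url).toMap,
        response.minUrlExpiration,
        response.refreshToken // the server may rotate it on each call
      )
    }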
