Skip to content

Commit

Permalink
support refresh token in ds server
Browse files Browse the repository at this point in the history
  • Loading branch information
charlenelyu-db committed Aug 31, 2023
1 parent c3b9ffa commit 279759b
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 46 deletions.
19 changes: 19 additions & 0 deletions server/src/main/protobuf/protocol.proto
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ message QueryTableRequest {
optional int32 maxFiles = 9;
// The page token used to retrieve the subsequent page.
optional string pageToken = 10;
// Whether or not to return a refresh token in the response. Only used in latest snapshot query
// AND first page query. For long running queries, delta sharing spark may make additional request
// to refresh pre-signed urls, and there might be table changes between the initial request and
// the refresh request. The refresh token will contain version information to make sure that
// the refresh request returns the same set of files.
optional bool includeRefreshToken = 11;
// The refresh token used to refresh pre-signed urls. Only used in latest snapshot query AND
// first page query.
optional string refreshToken = 12;

// Only one of the three parameters can be supported in a single query.
// If none of them is specified, the query is for the latest version.
Expand Down Expand Up @@ -136,3 +145,13 @@ message QueryTablePageToken {
// The latest version of the table when the first page request is received.
optional int64 latest_version = 8;
}

// Define a special class to generate the refresh token for latest snapshot query.
message RefreshToken {
// Id of the table being queried.
optional string id = 1;
// Only used in queryTable at snapshot, refers to the version being queried.
optional int64 version = 2;
// The expiration timestamp of the refresh token in milliseconds.
optional int64 expiration_timestamp = 3;
}
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ class DeltaSharingService(serverConfig: ServerConfig) {
endingVersion = None,
maxFiles = None,
pageToken = None,
includeRefreshToken = false,
refreshToken = None,
responseFormat = responseFormat)
streamingOutput(Some(v), responseFormat, actions)
}
Expand All @@ -336,8 +338,8 @@ class DeltaSharingService(serverConfig: ServerConfig) {
val capabilitiesMap = getDeltaSharingCapabilitiesMap(
req.headers().get(DELTA_SHARING_CAPABILITIES_HEADER)
)
val numVersionParams = Seq(request.version, request.timestamp, request.startingVersion)
.filter(_.isDefined).size
val numVersionParams =
Seq(request.version, request.timestamp, request.startingVersion).count(_.isDefined)
if (numVersionParams > 1) {
throw new DeltaSharingIllegalArgumentException(ErrorStrings.multipleParametersSetErrorMsg(
Seq("version", "timestamp", "startingVersion"))
Expand All @@ -352,6 +354,26 @@ class DeltaSharingService(serverConfig: ServerConfig) {
if (request.maxFiles.exists(_ <= 0)) {
throw new DeltaSharingIllegalArgumentException("maxFiles must be positive.")
}
if (numVersionParams > 0 && request.includeRefreshToken.contains(true)) {
throw new DeltaSharingIllegalArgumentException(
"includeRefreshToken must be used in latest version query."
)
}
if (request.pageToken.isDefined && request.includeRefreshToken.contains(true)) {
throw new DeltaSharingIllegalArgumentException(
"includeRefreshToken must be used in the first page request."
)
}
if (numVersionParams > 0 && request.refreshToken.isDefined) {
throw new DeltaSharingIllegalArgumentException(
"refreshToken must be used in latest version query."
)
}
if (request.pageToken.isDefined && request.refreshToken.isDefined) {
throw new DeltaSharingIllegalArgumentException(
"refreshToken must be used in the first page request."
)
}

val start = System.currentTimeMillis
val tableConfig = sharedTableManager.getTable(share, schema, table)
Expand Down Expand Up @@ -387,6 +409,8 @@ class DeltaSharingService(serverConfig: ServerConfig) {
request.endingVersion,
request.maxFiles,
request.pageToken,
request.includeRefreshToken.getOrElse(false),
request.refreshToken,
responseFormat = responseFormat)
if (version < tableConfig.startVersion) {
throw new DeltaSharingIllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ case class ServerConfig(
// The maximum page size permitted by queryTable/queryTableChanges API.
@BeanProperty var queryTablePageSizeLimit: Int,
// The TTL of the page token generated in queryTable/queryTableChanges API (in milliseconds).
@BeanProperty var queryTablePageTokenTtlMs: Int
@BeanProperty var queryTablePageTokenTtlMs: Int,
// The TTL of the refresh token generated in queryTable API (in milliseconds).
@BeanProperty var refreshTokenTtlMs: Int
) extends ConfigItem {
import ServerConfig._

Expand All @@ -82,7 +84,8 @@ case class ServerConfig(
evaluateJsonPredicateHints = false,
requestTimeoutSeconds = 30,
queryTablePageSizeLimit = 10000,
queryTablePageTokenTtlMs = 259200000 // 3 days
queryTablePageTokenTtlMs = 259200000, // 3 days
refreshTokenTtlMs = 3600000 // 1 hour
)
}

Expand Down
2 changes: 2 additions & 0 deletions server/src/main/scala/io/delta/sharing/server/model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,13 @@ case class RemoveFile(
* An action that is returned as the last line of the streaming response. It allows the server
* to include additional data that might be dynamically generated while the streaming message
* is sent, such as:
* - refreshToken: a token used to refresh pre-signed urls for a long running query
* - nextPageToken: a token used to retrieve the subsequent page of a query
* - minUrlExpirationTimestamp: the minimum url expiration timestamp of the urls returned in
* current response
*/
case class EndStreamAction(
refreshToken: String,
nextPageToken: String,
minUrlExpirationTimestamp: java.lang.Long
) extends Action {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import scala.collection.JavaConverters._
import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem
import com.google.common.cache.CacheBuilder
import com.google.common.hash.Hashing
import com.google.common.util.concurrent.UncheckedExecutionException
import io.delta.standalone.DeltaLog
import io.delta.standalone.internal.actions.{AddCDCFile, AddFile, Metadata, Protocol, RemoveFile}
import io.delta.standalone.internal.exception.DeltaErrors
Expand All @@ -41,21 +40,11 @@ import org.apache.hadoop.fs.s3a.S3AFileSystem
import org.apache.spark.sql.types.{DataType, MetadataBuilder, StructType}
import scala.collection.mutable.ListBuffer
import scala.util.control.NonFatal
import scalapb.{GeneratedMessage, GeneratedMessageCompanion}

import io.delta.sharing.server.{
model,
AbfsFileSigner,
CausedBy,
DeltaSharingIllegalArgumentException,
DeltaSharingUnsupportedOperationException,
ErrorStrings,
GCSFileSigner,
PreSignedUrl,
S3FileSigner,
WasbFileSigner
}
import io.delta.sharing.server.{model, AbfsFileSigner, CausedBy, DeltaSharingIllegalArgumentException, DeltaSharingUnsupportedOperationException, ErrorStrings, GCSFileSigner, PreSignedUrl, S3FileSigner, WasbFileSigner}
import io.delta.sharing.server.config.{ServerConfig, TableConfig}
import io.delta.sharing.server.protocol.QueryTablePageToken
import io.delta.sharing.server.protocol.{QueryTablePageToken, RefreshToken}
import io.delta.sharing.server.util.JsonUtils

/**
Expand All @@ -82,7 +71,8 @@ class DeltaSharedTableLoader(serverConfig: ServerConfig) {
serverConfig.evaluatePredicateHints,
serverConfig.evaluateJsonPredicateHints,
serverConfig.queryTablePageSizeLimit,
serverConfig.queryTablePageTokenTtlMs
serverConfig.queryTablePageTokenTtlMs,
serverConfig.refreshTokenTtlMs
)
}
)
Expand Down Expand Up @@ -122,7 +112,8 @@ class DeltaSharedTable(
evaluatePredicateHints: Boolean,
evaluateJsonPredicateHints: Boolean,
queryTablePageSizeLimit: Int,
queryTablePageTokenTtlMs: Int) {
queryTablePageTokenTtlMs: Int,
refreshTokenTtlMs: Int) {

private val conf = withClassLoader {
new Configuration()
Expand Down Expand Up @@ -376,8 +367,10 @@ class DeltaSharedTable(
// Construct and return the end action of the streaming response.
private def getEndStreamAction(
nextPageTokenStr: String,
minUrlExpirationTimestamp: Long): model.SingleAction = {
minUrlExpirationTimestamp: Long,
refreshTokenStr: String = null): model.SingleAction = {
model.EndStreamAction(
refreshTokenStr,
nextPageTokenStr,
if (minUrlExpirationTimestamp == Long.MaxValue) null else minUrlExpirationTimestamp
).wrap
Expand All @@ -395,6 +388,8 @@ class DeltaSharedTable(
endingVersion: Option[Long],
maxFiles: Option[Int],
pageToken: Option[String],
includeRefreshToken: Boolean,
refreshToken: Option[String],
responseFormat: String): (Long, Seq[Object]) = withClassLoader {
// scalastyle:on argcount
// TODO Support `limitHint`
Expand All @@ -419,17 +414,23 @@ class DeltaSharedTable(
)
)
val pageTokenOpt = pageToken.map(decodeAndValidatePageToken(_, queryParamChecksum))
// For queryTable at snapshot, override version in subsequent page calls using the version
// in the pageToken to make sure we're querying the same version across pages. Especially
// when the first page is querying the latest snapshot, table changes that are committed
// after the first page call should be ignored.
val versionFromPageToken = pageTokenOpt.flatMap(_.version)
// Validate refreshToken if it's specified
val refreshTokenOpt = refreshToken.map(decodeAndValidateRefreshToken)
// The version of the snapshot should follow the below precedence:
// 1. Use version specified in the pageToken, which is equal to the version we use in the
// first page request. This is to make sure that responses are consistent across pages.
// 2. Use version/timestamp/startingVersion specified by the user.
// 3. Use version specified in the refreshToken, which is equal to latest table version upon
// initial request. In this case, it must be a latest snapshot query and version/timestamp/
// startingVersion must not be specified.
val specifiedVersion = pageTokenOpt.flatMap(_.version)
.orElse(version)
.orElse(startingVersion)
.orElse(refreshTokenOpt.map(_.getVersion))
val snapshot =
if (versionFromPageToken.orElse(version).orElse(startingVersion).isDefined) {
if (specifiedVersion.isDefined) {
try {
deltaLog.getSnapshotForVersionAsOf(
versionFromPageToken.orElse(version).orElse(startingVersion).get
)
deltaLog.getSnapshotForVersionAsOf(specifiedVersion.get)
} catch {
case e: io.delta.standalone.exceptions.DeltaStandaloneException =>
throw new DeltaSharingIllegalArgumentException(e.getMessage)
Expand Down Expand Up @@ -509,7 +510,7 @@ class DeltaSharedTable(
// If number of valid files is greater than page size, generate nextPageToken and
// drop additional files.
if (pageSizeOpt.exists(_ < filteredIndexedFiles.length)) {
nextPageTokenStr = encodeQueryTablePageToken(
nextPageTokenStr = DeltaSharedTable.encodeToken(
QueryTablePageToken(
id = Some(tableConfig.id),
version = Some(snapshot.version),
Expand All @@ -533,11 +534,22 @@ class DeltaSharedTable(
responseFormat
)
}
// Return an `endStreamAction` object only when `maxFiles` is specified for backwards
// compatibility.
val refreshTokenStr = if (includeRefreshToken) {
DeltaSharedTable.encodeToken(
RefreshToken(
id = Some(tableConfig.id),
version = Some(snapshot.version),
expirationTimestamp = Some(System.currentTimeMillis() + refreshTokenTtlMs)
)
)
} else {
null
}
// For backwards compatibility, return an `endStreamAction` object only when
// `includeRefreshToken` is true or `maxFiles` is specified
filteredFiles ++ {
if (maxFiles.isDefined) {
Seq(getEndStreamAction(nextPageTokenStr, minUrlExpirationTimestamp))
if (includeRefreshToken || maxFiles.isDefined) {
Seq(getEndStreamAction(nextPageTokenStr, minUrlExpirationTimestamp, refreshTokenStr))
} else {
Nil
}
Expand Down Expand Up @@ -589,7 +601,7 @@ class DeltaSharedTable(
// Enforce page size only when `maxFiles` is specified for backwards compatibility.
val pageSizeOpt = maxFilesOpt.map(_.min(queryTablePageSizeLimit))
val tokenGenerator = { (v: Long, idx: Int) =>
encodeQueryTablePageToken(
DeltaSharedTable.encodeToken(
QueryTablePageToken(
id = Some(tableConfig.id),
startingVersion = Some(v),
Expand Down Expand Up @@ -735,7 +747,7 @@ class DeltaSharedTable(
// Enforce page size only when `maxFiles` is specified for backwards compatibility.
val pageSizeOpt = maxFiles.map(_.min(queryTablePageSizeLimit))
val tokenGenerator = { (v: Long, idx: Int) =>
encodeQueryTablePageToken(
DeltaSharedTable.encodeToken(
QueryTablePageToken(
id = Some(tableConfig.id),
startingVersion = Some(v),
Expand Down Expand Up @@ -893,7 +905,14 @@ class DeltaSharedTable(
private def decodeAndValidatePageToken(
tokenStr: String,
expectedChecksum: String): QueryTablePageToken = {
val token = decodeQueryTablePageToken(tokenStr)
val token = try {
DeltaSharedTable.decodeToken[QueryTablePageToken](tokenStr)
} catch {
case NonFatal(_) =>
throw new DeltaSharingIllegalArgumentException(
s"Error decoding the page token: $tokenStr."
)
}
if (token.getExpirationTimestamp < System.currentTimeMillis()) {
throw new DeltaSharingIllegalArgumentException(
"The page token has expired. Please restart the query."
Expand All @@ -913,23 +932,39 @@ class DeltaSharedTable(
token
}

private def encodeQueryTablePageToken(token: QueryTablePageToken): String = {
Base64.getUrlEncoder.encodeToString(token.toByteArray)
}

private def decodeQueryTablePageToken(tokenStr: String): QueryTablePageToken = {
try {
QueryTablePageToken.parseFrom(Base64.getUrlDecoder.decode(tokenStr))
private def decodeAndValidateRefreshToken(tokenStr: String): RefreshToken = {
val token = try {
DeltaSharedTable.decodeToken[RefreshToken](tokenStr)
} catch {
case NonFatal(_) =>
throw new DeltaSharingIllegalArgumentException(
s"Error decoding the page token: $tokenStr."
s"Error decoding refresh token: $tokenStr."
)
}
if (token.getExpirationTimestamp < System.currentTimeMillis()) {
throw new DeltaSharingIllegalArgumentException(
"The refresh token has expired. Please restart the query."
)
}
if (token.getId != tableConfig.id) {
throw new DeltaSharingIllegalArgumentException(
"The table specified in the refresh token does not match the table being queried."
)
}
token
}
}

object DeltaSharedTable {
val RESPONSE_FORMAT_PARQUET = "parquet"
val RESPONSE_FORMAT_DELTA = "delta"

private def encodeToken[T <: GeneratedMessage](token: T): String = {
Base64.getUrlEncoder.encodeToString(token.toByteArray)
}

private def decodeToken[T <: GeneratedMessage](tokenStr: String)(
implicit protoCompanion: GeneratedMessageCompanion[T]): T = {
protoCompanion.parseFrom(Base64.getUrlDecoder.decode(tokenStr))
}
}
Loading

0 comments on commit 279759b

Please sign in to comment.