From 279759b053f72500097bce5d401179428380746e Mon Sep 17 00:00:00 2001 From: Charlene Lyu Date: Thu, 31 Aug 2023 15:08:56 -0700 Subject: [PATCH] support refresh token in ds server --- server/src/main/protobuf/protocol.proto | 19 +++ .../sharing/server/DeltaSharingService.scala | 28 ++++- .../sharing/server/config/ServerConfig.scala | 7 +- .../scala/io/delta/sharing/server/model.scala | 2 + .../internal/DeltaSharedTableLoader.scala | 119 +++++++++++------- .../server/DeltaSharingServiceSuite.scala | 113 +++++++++++++++++ .../conf/delta-sharing-server.yaml.template | 2 + 7 files changed, 244 insertions(+), 46 deletions(-) diff --git a/server/src/main/protobuf/protocol.proto b/server/src/main/protobuf/protocol.proto index 8ec9a3de4..d70361d27 100644 --- a/server/src/main/protobuf/protocol.proto +++ b/server/src/main/protobuf/protocol.proto @@ -54,6 +54,15 @@ message QueryTableRequest { optional int32 maxFiles = 9; // The page token used to retrieve the subsequent page. optional string pageToken = 10; + // Whether or not to return a refresh token in the response. Only used in latest snapshot query + // AND first page query. For long running queries, delta sharing spark may make additional request + // to refresh pre-signed urls, and there might be table changes between the initial request and + // the refresh request. The refresh token will contain version information to make sure that + // the refresh request returns the same set of files. + optional bool includeRefreshToken = 11; + // The refresh token used to refresh pre-signed urls. Only used in latest snapshot query AND + // first page query. + optional string refreshToken = 12; // Only one of the three parameters can be supported in a single query. // If none of them is specified, the query is for the latest version. @@ -136,3 +145,13 @@ message QueryTablePageToken { // The latest version of the table when the first page request is received. 
optional int64 latest_version = 8; } + +// Define a special class to generate the refresh token for latest snapshot query. +message RefreshToken { + // Id of the table being queried. + optional string id = 1; + // Only used in queryTable at snapshot, refers to the version being queried. + optional int64 version = 2; + // The expiration timestamp of the refresh token in milliseconds. + optional int64 expiration_timestamp = 3; +} diff --git a/server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala b/server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala index 79fa6fc6e..89080ef1d 100644 --- a/server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala +++ b/server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala @@ -321,6 +321,8 @@ class DeltaSharingService(serverConfig: ServerConfig) { endingVersion = None, maxFiles = None, pageToken = None, + includeRefreshToken = false, + refreshToken = None, responseFormat = responseFormat) streamingOutput(Some(v), responseFormat, actions) } @@ -336,8 +338,8 @@ class DeltaSharingService(serverConfig: ServerConfig) { val capabilitiesMap = getDeltaSharingCapabilitiesMap( req.headers().get(DELTA_SHARING_CAPABILITIES_HEADER) ) - val numVersionParams = Seq(request.version, request.timestamp, request.startingVersion) - .filter(_.isDefined).size + val numVersionParams = + Seq(request.version, request.timestamp, request.startingVersion).count(_.isDefined) if (numVersionParams > 1) { throw new DeltaSharingIllegalArgumentException(ErrorStrings.multipleParametersSetErrorMsg( Seq("version", "timestamp", "startingVersion")) @@ -352,6 +354,26 @@ class DeltaSharingService(serverConfig: ServerConfig) { if (request.maxFiles.exists(_ <= 0)) { throw new DeltaSharingIllegalArgumentException("maxFiles must be positive.") } + if (numVersionParams > 0 && request.includeRefreshToken.contains(true)) { + throw new DeltaSharingIllegalArgumentException( + "includeRefreshToken must be used in 
latest version query." + ) + } + if (request.pageToken.isDefined && request.includeRefreshToken.contains(true)) { + throw new DeltaSharingIllegalArgumentException( + "includeRefreshToken must be used in the first page request." + ) + } + if (numVersionParams > 0 && request.refreshToken.isDefined) { + throw new DeltaSharingIllegalArgumentException( + "refreshToken must be used in latest version query." + ) + } + if (request.pageToken.isDefined && request.refreshToken.isDefined) { + throw new DeltaSharingIllegalArgumentException( + "refreshToken must be used in the first page request." + ) + } val start = System.currentTimeMillis val tableConfig = sharedTableManager.getTable(share, schema, table) @@ -387,6 +409,8 @@ class DeltaSharingService(serverConfig: ServerConfig) { request.endingVersion, request.maxFiles, request.pageToken, + request.includeRefreshToken.getOrElse(false), + request.refreshToken, responseFormat = responseFormat) if (version < tableConfig.startVersion) { throw new DeltaSharingIllegalArgumentException( diff --git a/server/src/main/scala/io/delta/sharing/server/config/ServerConfig.scala b/server/src/main/scala/io/delta/sharing/server/config/ServerConfig.scala index cc9232f8a..1cec58f05 100644 --- a/server/src/main/scala/io/delta/sharing/server/config/ServerConfig.scala +++ b/server/src/main/scala/io/delta/sharing/server/config/ServerConfig.scala @@ -61,7 +61,9 @@ case class ServerConfig( // The maximum page size permitted by queryTable/queryTableChanges API. @BeanProperty var queryTablePageSizeLimit: Int, // The TTL of the page token generated in queryTable/queryTableChanges API (in milliseconds). - @BeanProperty var queryTablePageTokenTtlMs: Int + @BeanProperty var queryTablePageTokenTtlMs: Int, + // The TTL of the refresh token generated in queryTable API (in milliseconds). 
+ @BeanProperty var refreshTokenTtlMs: Int ) extends ConfigItem { import ServerConfig._ @@ -82,7 +84,8 @@ case class ServerConfig( evaluateJsonPredicateHints = false, requestTimeoutSeconds = 30, queryTablePageSizeLimit = 10000, - queryTablePageTokenTtlMs = 259200000 // 3 days + queryTablePageTokenTtlMs = 259200000, // 3 days + refreshTokenTtlMs = 3600000 // 1 hour ) } diff --git a/server/src/main/scala/io/delta/sharing/server/model.scala b/server/src/main/scala/io/delta/sharing/server/model.scala index 2be48e54e..a0a8fbe6d 100644 --- a/server/src/main/scala/io/delta/sharing/server/model.scala +++ b/server/src/main/scala/io/delta/sharing/server/model.scala @@ -137,11 +137,13 @@ case class RemoveFile( * An action that is returned as the last line of the streaming response. It allows the server * to include additional data that might be dynamically generated while the streaming message * is sent, such as: + * - refreshToken: a token used to refresh pre-signed urls for a long running query * - nextPageToken: a token used to retrieve the subsequent page of a query * - minUrlExpirationTimestamp: the minimum url expiration timestamp of the urls returned in * current response */ case class EndStreamAction( + refreshToken: String, nextPageToken: String, minUrlExpirationTimestamp: java.lang.Long ) extends Action { diff --git a/server/src/main/scala/io/delta/standalone/internal/DeltaSharedTableLoader.scala b/server/src/main/scala/io/delta/standalone/internal/DeltaSharedTableLoader.scala index faa852b43..212ff8e82 100644 --- a/server/src/main/scala/io/delta/standalone/internal/DeltaSharedTableLoader.scala +++ b/server/src/main/scala/io/delta/standalone/internal/DeltaSharedTableLoader.scala @@ -27,7 +27,6 @@ import scala.collection.JavaConverters._ import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem import com.google.common.cache.CacheBuilder import com.google.common.hash.Hashing -import com.google.common.util.concurrent.UncheckedExecutionException import 
io.delta.standalone.DeltaLog import io.delta.standalone.internal.actions.{AddCDCFile, AddFile, Metadata, Protocol, RemoveFile} import io.delta.standalone.internal.exception.DeltaErrors @@ -41,21 +40,11 @@ import org.apache.hadoop.fs.s3a.S3AFileSystem import org.apache.spark.sql.types.{DataType, MetadataBuilder, StructType} import scala.collection.mutable.ListBuffer import scala.util.control.NonFatal +import scalapb.{GeneratedMessage, GeneratedMessageCompanion} -import io.delta.sharing.server.{ - model, - AbfsFileSigner, - CausedBy, - DeltaSharingIllegalArgumentException, - DeltaSharingUnsupportedOperationException, - ErrorStrings, - GCSFileSigner, - PreSignedUrl, - S3FileSigner, - WasbFileSigner -} +import io.delta.sharing.server.{model, AbfsFileSigner, CausedBy, DeltaSharingIllegalArgumentException, DeltaSharingUnsupportedOperationException, ErrorStrings, GCSFileSigner, PreSignedUrl, S3FileSigner, WasbFileSigner} import io.delta.sharing.server.config.{ServerConfig, TableConfig} -import io.delta.sharing.server.protocol.QueryTablePageToken +import io.delta.sharing.server.protocol.{QueryTablePageToken, RefreshToken} import io.delta.sharing.server.util.JsonUtils /** @@ -82,7 +71,8 @@ class DeltaSharedTableLoader(serverConfig: ServerConfig) { serverConfig.evaluatePredicateHints, serverConfig.evaluateJsonPredicateHints, serverConfig.queryTablePageSizeLimit, - serverConfig.queryTablePageTokenTtlMs + serverConfig.queryTablePageTokenTtlMs, + serverConfig.refreshTokenTtlMs ) } ) @@ -122,7 +112,8 @@ class DeltaSharedTable( evaluatePredicateHints: Boolean, evaluateJsonPredicateHints: Boolean, queryTablePageSizeLimit: Int, - queryTablePageTokenTtlMs: Int) { + queryTablePageTokenTtlMs: Int, + refreshTokenTtlMs: Int) { private val conf = withClassLoader { new Configuration() @@ -376,8 +367,10 @@ class DeltaSharedTable( // Construct and return the end action of the streaming response. 
private def getEndStreamAction( nextPageTokenStr: String, - minUrlExpirationTimestamp: Long): model.SingleAction = { + minUrlExpirationTimestamp: Long, + refreshTokenStr: String = null): model.SingleAction = { model.EndStreamAction( + refreshTokenStr, nextPageTokenStr, if (minUrlExpirationTimestamp == Long.MaxValue) null else minUrlExpirationTimestamp ).wrap @@ -395,6 +388,8 @@ class DeltaSharedTable( endingVersion: Option[Long], maxFiles: Option[Int], pageToken: Option[String], + includeRefreshToken: Boolean, + refreshToken: Option[String], responseFormat: String): (Long, Seq[Object]) = withClassLoader { // scalastyle:on argcount // TODO Support `limitHint` @@ -419,17 +414,23 @@ class DeltaSharedTable( ) ) val pageTokenOpt = pageToken.map(decodeAndValidatePageToken(_, queryParamChecksum)) - // For queryTable at snapshot, override version in subsequent page calls using the version - // in the pageToken to make sure we're querying the same version across pages. Especially - // when the first page is querying the latest snapshot, table changes that are committed - // after the first page call should be ignored. - val versionFromPageToken = pageTokenOpt.flatMap(_.version) + // Validate refreshToken if it's specified + val refreshTokenOpt = refreshToken.map(decodeAndValidateRefreshToken) + // The version of the snapshot should follow the below precedence: + // 1. Use version specified in the pageToken, which is equal to the version we use in the + // first page request. This is to make sure that responses are consistent across pages. + // 2. Use version/timestamp/startingVersion specified by the user. + // 3. Use version specified in the refreshToken, which is equal to latest table version upon + // initial request. In this case, it must be a latest snapshot query and version/timestamp/ + // startingVersion must not be specified. 
+ val specifiedVersion = pageTokenOpt.flatMap(_.version) + .orElse(version) + .orElse(startingVersion) + .orElse(refreshTokenOpt.map(_.getVersion)) val snapshot = - if (versionFromPageToken.orElse(version).orElse(startingVersion).isDefined) { + if (specifiedVersion.isDefined) { try { - deltaLog.getSnapshotForVersionAsOf( - versionFromPageToken.orElse(version).orElse(startingVersion).get - ) + deltaLog.getSnapshotForVersionAsOf(specifiedVersion.get) } catch { case e: io.delta.standalone.exceptions.DeltaStandaloneException => throw new DeltaSharingIllegalArgumentException(e.getMessage) @@ -509,7 +510,7 @@ class DeltaSharedTable( // If number of valid files is greater than page size, generate nextPageToken and // drop additional files. if (pageSizeOpt.exists(_ < filteredIndexedFiles.length)) { - nextPageTokenStr = encodeQueryTablePageToken( + nextPageTokenStr = DeltaSharedTable.encodeToken( QueryTablePageToken( id = Some(tableConfig.id), version = Some(snapshot.version), @@ -533,11 +534,22 @@ class DeltaSharedTable( responseFormat ) } - // Return an `endStreamAction` object only when `maxFiles` is specified for backwards - // compatibility. + val refreshTokenStr = if (includeRefreshToken) { + DeltaSharedTable.encodeToken( + RefreshToken( + id = Some(tableConfig.id), + version = Some(snapshot.version), + expirationTimestamp = Some(System.currentTimeMillis() + refreshTokenTtlMs) + ) + ) + } else { + null + } + // For backwards compatibility, return an `endStreamAction` object only when + // `includeRefreshToken` is true or `maxFiles` is specified filteredFiles ++ { - if (maxFiles.isDefined) { - Seq(getEndStreamAction(nextPageTokenStr, minUrlExpirationTimestamp)) + if (includeRefreshToken || maxFiles.isDefined) { + Seq(getEndStreamAction(nextPageTokenStr, minUrlExpirationTimestamp, refreshTokenStr)) } else { Nil } @@ -589,7 +601,7 @@ class DeltaSharedTable( // Enforce page size only when `maxFiles` is specified for backwards compatibility. 
val pageSizeOpt = maxFilesOpt.map(_.min(queryTablePageSizeLimit)) val tokenGenerator = { (v: Long, idx: Int) => - encodeQueryTablePageToken( + DeltaSharedTable.encodeToken( QueryTablePageToken( id = Some(tableConfig.id), startingVersion = Some(v), @@ -735,7 +747,7 @@ class DeltaSharedTable( // Enforce page size only when `maxFiles` is specified for backwards compatibility. val pageSizeOpt = maxFiles.map(_.min(queryTablePageSizeLimit)) val tokenGenerator = { (v: Long, idx: Int) => - encodeQueryTablePageToken( + DeltaSharedTable.encodeToken( QueryTablePageToken( id = Some(tableConfig.id), startingVersion = Some(v), @@ -893,7 +905,14 @@ class DeltaSharedTable( private def decodeAndValidatePageToken( tokenStr: String, expectedChecksum: String): QueryTablePageToken = { - val token = decodeQueryTablePageToken(tokenStr) + val token = try { + DeltaSharedTable.decodeToken[QueryTablePageToken](tokenStr) + } catch { + case NonFatal(_) => + throw new DeltaSharingIllegalArgumentException( + s"Error decoding the page token: $tokenStr." + ) + } if (token.getExpirationTimestamp < System.currentTimeMillis()) { throw new DeltaSharingIllegalArgumentException( "The page token has expired. Please restart the query." @@ -913,23 +932,39 @@ class DeltaSharedTable( token } - private def encodeQueryTablePageToken(token: QueryTablePageToken): String = { - Base64.getUrlEncoder.encodeToString(token.toByteArray) - } - - private def decodeQueryTablePageToken(tokenStr: String): QueryTablePageToken = { - try { - QueryTablePageToken.parseFrom(Base64.getUrlDecoder.decode(tokenStr)) + private def decodeAndValidateRefreshToken(tokenStr: String): RefreshToken = { + val token = try { + DeltaSharedTable.decodeToken[RefreshToken](tokenStr) } catch { case NonFatal(_) => throw new DeltaSharingIllegalArgumentException( - s"Error decoding the page token: $tokenStr." + s"Error decoding refresh token: $tokenStr." 
) } + if (token.getExpirationTimestamp < System.currentTimeMillis()) { + throw new DeltaSharingIllegalArgumentException( + "The refresh token has expired. Please restart the query." + ) + } + if (token.getId != tableConfig.id) { + throw new DeltaSharingIllegalArgumentException( + "The table specified in the refresh token does not match the table being queried." + ) + } + token } } object DeltaSharedTable { val RESPONSE_FORMAT_PARQUET = "parquet" val RESPONSE_FORMAT_DELTA = "delta" + + private def encodeToken[T <: GeneratedMessage](token: T): String = { + Base64.getUrlEncoder.encodeToString(token.toByteArray) + } + + private def decodeToken[T <: GeneratedMessage](tokenStr: String)( + implicit protoCompanion: GeneratedMessageCompanion[T]): T = { + protoCompanion.parseFrom(Base64.getUrlDecoder.decode(tokenStr)) + } } diff --git a/server/src/test/scala/io/delta/sharing/server/DeltaSharingServiceSuite.scala b/server/src/test/scala/io/delta/sharing/server/DeltaSharingServiceSuite.scala index 8807d3191..2b928e4b1 100644 --- a/server/src/test/scala/io/delta/sharing/server/DeltaSharingServiceSuite.scala +++ b/server/src/test/scala/io/delta/sharing/server/DeltaSharingServiceSuite.scala @@ -739,6 +739,119 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { } } + integrationTest("refresh query returns the same set of files as initial query") { + val initialResponse = readNDJson( + requestPath("/shares/share1/schemas/default/tables/table1/query"), + Some("POST"), + Some("""{"includeRefreshToken": true}"""), + Some(2) + ).split("\n") + assert(initialResponse.length == 5) + val endAction = JsonUtils.fromJson[SingleAction](initialResponse.last).endStreamAction + assert(endAction.refreshToken != null) + + val refreshResponse = readNDJson( + requestPath("/shares/share1/schemas/default/tables/table1/query"), + Some("POST"), + Some(s"""{"includeRefreshToken": true, "refreshToken": "${endAction.refreshToken}"}"""), + Some(2) + ).split("\n") + 
assert(refreshResponse.length == 5) + // protocol + assert(initialResponse(0) == refreshResponse(0)) + // metadata + assert(initialResponse(1) == refreshResponse(1)) + // files + val initialFiles = initialResponse.slice(2, 4).map(f => JsonUtils.fromJson[SingleAction](f).file) + val refreshedFiles = refreshResponse.slice(2, 4).map(f => JsonUtils.fromJson[SingleAction](f).file) + assert(initialFiles.map(_.id) sameElements refreshedFiles.map(_.id)) + } + + integrationTest("refresh query - exception") { + // invalid refresh token + assertHttpError( + url = requestPath("/shares/share1/schemas/default/tables/table1/query"), + method = "POST", + data = Some("""{"includeRefreshToken": true, "refreshToken": "foo"}"""), + expectedErrorCode = 400, + expectedErrorMessage = "Error decoding refresh token" + ) + + // invalid query parameters + assertHttpError( + url = requestPath("/shares/share8/schemas/default/tables/streaming_notnull_to_null/query"), + method = "POST", + data = Some("""{"version": 1, "includeRefreshToken": true}"""), + expectedErrorCode = 400, + expectedErrorMessage = "includeRefreshToken must be used in latest version query" + ) + assertHttpError( + url = requestPath("/shares/share1/schemas/default/tables/table1/query"), + method = "POST", + data = Some("""{"pageToken": "foo", "includeRefreshToken": true}"""), + expectedErrorCode = 400, + expectedErrorMessage = "includeRefreshToken must be used in the first page request" + ) + assertHttpError( + url = requestPath("/shares/share8/schemas/default/tables/streaming_notnull_to_null/query"), + method = "POST", + data = Some("""{"startingVersion": 1, "refreshToken": "foo"}"""), + expectedErrorCode = 400, + expectedErrorMessage = "refreshToken must be used in latest version query" + ) + assertHttpError( + url = requestPath("/shares/share1/schemas/default/tables/table1/query"), + method = "POST", + data = Some("""{"pageToken": "foo", "refreshToken": "foo"}"""), + expectedErrorCode = 400, + expectedErrorMessage = 
"refreshToken must be used in the first page request" + ) + + var response = readNDJson( + requestPath("/shares/share1/schemas/default/tables/table1/query"), + Some("POST"), + Some("""{"includeRefreshToken": true}"""), + Some(2) + ) + var lines = response.split("\n") + assert(lines.length == 5) + var endAction = JsonUtils.fromJson[SingleAction](lines.last).endStreamAction + assert(endAction.refreshToken != null) + assertHttpError( + url = requestPath("/shares/share2/schemas/default/tables/table2/query"), + method = "POST", + data = Some(s"""{"includeRefreshToken": true, "refreshToken": "${endAction.refreshToken}"}"""), + expectedErrorCode = 400, + expectedErrorMessage = "The table specified in the refresh token does not match the table being queried" + ) + + // refresh token expired + val updatedServerConfig = serverConfig.copy(refreshTokenTtlMs = 0) + server.stop().get() + server = DeltaSharingService.start(updatedServerConfig) + response = readNDJson( + requestPath("/shares/share1/schemas/default/tables/table1/query"), + Some("POST"), + Some("""{"includeRefreshToken": true}"""), + Some(2) + ) + lines = response.split("\n") + assert(lines.length == 5) + endAction = JsonUtils.fromJson[SingleAction](lines.last).endStreamAction + assert(endAction.refreshToken != null) + + assertHttpError( + url = requestPath("/shares/share1/schemas/default/tables/table1/query"), + method = "POST", + data = Some(s"""{"includeRefreshToken": true, "refreshToken": "${endAction.refreshToken}"}"""), + expectedErrorCode = 400, + expectedErrorMessage = "The refresh token has expired" + ) + + server.stop().get() + server = DeltaSharingService.start(serverConfig) + } + integrationTest("table2 - partitioned - /shares/{share}/schemas/{schema}/tables/{table}/metadata") { val response = readNDJson(requestPath("/shares/share2/schemas/default/tables/table2/metadata"), expectedTableVersion = Some(2)) val Array(protocol, metadata) = response.split("\n") diff --git 
a/server/src/universal/conf/delta-sharing-server.yaml.template b/server/src/universal/conf/delta-sharing-server.yaml.template index 7d79a5db6..e6ea8c766 100644 --- a/server/src/universal/conf/delta-sharing-server.yaml.template +++ b/server/src/universal/conf/delta-sharing-server.yaml.template @@ -61,3 +61,5 @@ evaluateJsonPredicateHints: false queryTablePageSizeLimit: 10000 # The TTL of the page token generated in queryTable/queryTableChanges API (in milliseconds). queryTablePageTokenTtlMs: 259200000 +# The TTL of the refresh token generated in queryTable API (in milliseconds). +refreshTokenTtlMs: 3600000