Skip to content

Commit

Permalink
Client side changes
Browse files Browse the repository at this point in the history
  • Loading branch information
linzhou-db committed Oct 10, 2024
1 parent 5284fc9 commit a68fe09
Show file tree
Hide file tree
Showing 9 changed files with 248 additions and 51 deletions.
177 changes: 145 additions & 32 deletions spark/src/main/scala/io/delta/sharing/spark/DeltaSharingClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ import io.delta.sharing.spark.util.{JsonUtils, RetryUtils, UnexpectedHttpStatus}

/** An interface to fetch Delta metadata from remote server. */
private[sharing] trait DeltaSharingClient {

protected var dsQueryId: Option[String] = None

protected def getDsQueryIdForLogging: String = {
s" for query($dsQueryId)."
}

def listAllTables(): Seq[Table]

def getTableVersion(table: Table, startingTimestamp: Option[String] = None): Long
Expand Down Expand Up @@ -103,14 +110,13 @@ private[spark] class DeltaSharingRestClient(
numRetries: Int = 10,
maxRetryDuration: Long = Long.MaxValue,
sslTrustAll: Boolean = false,
forStreaming: Boolean = false
forStreaming: Boolean = false,
endStreamActionEnabled: Boolean = true
) extends DeltaSharingClient with Logging {
import DeltaSharingRestClient._

@volatile private var created = false

private var queryId: Option[String] = None

private lazy val client = {
val clientBuilder: HttpClientBuilder = if (sslTrustAll) {
val sslBuilder = new SSLContextBuilder()
Expand Down Expand Up @@ -200,9 +206,15 @@ private[spark] class DeltaSharingRestClient(
val target =
getTargetUrl(s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/" +
s"$encodedTableName/version$encodedParam")
val (version, _) = getResponse(new HttpGet(target), true, true)
val (version, _) = getResponse(
new HttpGet(target),
allowNoContent = true,
fetchAsOneString = true,
setIncludeEndStreamAction = false
)
version.getOrElse {
throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
throw new IllegalStateException(s"Cannot find $RESPONSE_TABLE_VERSION_HEADER_KEY in the " +
s"header")
}
}

Expand All @@ -212,7 +224,9 @@ private[spark] class DeltaSharingRestClient(
val encodedTableName = URLEncoder.encode(table.name, "UTF-8")
val target = getTargetUrl(
s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/$encodedTableName/metadata")
val (version, lines) = getNDJson(target)
val (version, lines) = getNDJson(
target, requireVersion = true, setIncludeEndStreamAction = false
)
val protocol = JsonUtils.fromJson[SingleAction](lines(0)).protocol
checkProtocol(protocol)
val metadata = JsonUtils.fromJson[SingleAction](lines(1)).metaData
Expand Down Expand Up @@ -245,7 +259,7 @@ private[spark] class DeltaSharingRestClient(
val encodedTableName = URLEncoder.encode(table.name, "UTF-8")
val target = getTargetUrl(
s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/$encodedTableName/query")
val (version, lines) = getNDJson(
val (version, lines) = getNDJsonPost(
target,
QueryTableRequest(
predicates,
Expand All @@ -257,7 +271,8 @@ private[spark] class DeltaSharingRestClient(
jsonPredicateHints,
Some(includeRefreshToken),
refreshToken
)
),
setIncludeEndStreamAction = endStreamActionEnabled
)
val (filteredLines, endStreamAction) = maybeExtractEndStreamAction(lines)
val refreshTokenOpt = endStreamAction.flatMap { e =>
Expand Down Expand Up @@ -294,7 +309,7 @@ private[spark] class DeltaSharingRestClient(
val encodedTableName = URLEncoder.encode(table.name, "UTF-8")
val target = getTargetUrl(
s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/$encodedTableName/query")
val (version, lines) = getNDJson(
val (version, lines) = getNDJsonPost(
target,
QueryTableRequest(
/* predicateHint */ Nil,
Expand All @@ -306,15 +321,17 @@ private[spark] class DeltaSharingRestClient(
/* jsonPredicateHints */ None,
/* includeRefreshToken */ None,
/* refreshToken */ None
)
),
setIncludeEndStreamAction = endStreamActionEnabled
)
val protocol = JsonUtils.fromJson[SingleAction](lines(0)).protocol
val (filteredLines, _) = maybeExtractEndStreamAction(lines)
val protocol = JsonUtils.fromJson[SingleAction](filteredLines(0)).protocol
checkProtocol(protocol)
val metadata = JsonUtils.fromJson[SingleAction](lines(1)).metaData
val metadata = JsonUtils.fromJson[SingleAction](filteredLines(1)).metaData
val addFiles = ArrayBuffer[AddFileForCDF]()
val removeFiles = ArrayBuffer[RemoveFile]()
val additionalMetadatas = ArrayBuffer[Metadata]()
lines.drop(2).foreach { line =>
filteredLines.drop(2).foreach { line =>
val action = JsonUtils.fromJson[SingleAction](line).unwrap
action match {
case a: AddFileForCDF => addFiles.append(a)
Expand Down Expand Up @@ -344,16 +361,19 @@ private[spark] class DeltaSharingRestClient(

val target = getTargetUrl(
s"/shares/$encodedShare/schemas/$encodedSchema/tables/$encodedTable/changes?$encodedParams")
val (version, lines) = getNDJson(target, requireVersion = false)
val protocol = JsonUtils.fromJson[SingleAction](lines(0)).protocol
val (version, lines) = getNDJson(
target, requireVersion = false, setIncludeEndStreamAction = endStreamActionEnabled
)
val (filteredLines, _) = maybeExtractEndStreamAction(lines)
val protocol = JsonUtils.fromJson[SingleAction](filteredLines(0)).protocol
checkProtocol(protocol)
val metadata = JsonUtils.fromJson[SingleAction](lines(1)).metaData
val metadata = JsonUtils.fromJson[SingleAction](filteredLines(1)).metaData

val addFiles = ArrayBuffer[AddFileForCDF]()
val cdfFiles = ArrayBuffer[AddCDCFile]()
val removeFiles = ArrayBuffer[RemoveFile]()
val additionalMetadatas = ArrayBuffer[Metadata]()
lines.drop(2).foreach { line =>
filteredLines.drop(2).foreach { line =>
val action = JsonUtils.fromJson[SingleAction](line).unwrap
action match {
case c: AddCDCFile => cdfFiles.append(c)
Expand Down Expand Up @@ -398,30 +418,47 @@ private[spark] class DeltaSharingRestClient(
}.mkString("&")
}

private def getNDJson(target: String, requireVersion: Boolean = true): (Long, Seq[String]) = {
val (version, lines) = getResponse(new HttpGet(target))
private def getNDJson(
target: String,
requireVersion: Boolean,
setIncludeEndStreamAction: Boolean): (Long, Seq[String]) = {
val (version, lines) = getResponse(
new HttpGet(target), setIncludeEndStreamAction = setIncludeEndStreamAction
)
version.getOrElse {
if (requireVersion) {
throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
throw new IllegalStateException(s"Cannot find $RESPONSE_TABLE_VERSION_HEADER_KEY in the " +
s"header")
} else {
0L
}
} -> lines
}

private def getNDJson[T: Manifest](target: String, data: T): (Long, Seq[String]) = {
private def getNDJsonPost[T: Manifest](
target: String,
data: T,
setIncludeEndStreamAction: Boolean): (Long, Seq[String]) = {
val httpPost = new HttpPost(target)
val json = JsonUtils.toJson(data)
httpPost.setHeader("Content-type", "application/json")
httpPost.setEntity(new StringEntity(json, UTF_8))
val (version, lines) = getResponse(httpPost)
val (version, lines) = getResponse(
httpPost, setIncludeEndStreamAction = setIncludeEndStreamAction
)
version.getOrElse {
throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
throw new IllegalStateException(s"Cannot find $RESPONSE_TABLE_VERSION_HEADER_KEY in the " +
s"header")
} -> lines
}

private def getJson[R: Manifest](target: String): R = {
val (_, response) = getResponse(new HttpGet(target), false, true)
val (_, response) = getResponse(
new HttpGet(target),
allowNoContent = false,
fetchAsOneString = true,
setIncludeEndStreamAction = false
)
if (response.size != 1) {
throw new IllegalStateException(
"Unexpected response for target: " + target + ", response=" + response
Expand Down Expand Up @@ -452,7 +489,8 @@ private[spark] class DeltaSharingRestClient(
}
}

private[spark] def prepareHeaders(httpRequest: HttpRequestBase): HttpRequestBase = {
private[spark] def prepareHeaders(
httpRequest: HttpRequestBase, setIncludeEndStreamAction: Boolean): HttpRequestBase = {
val customeHeaders = profileProvider.getCustomHeaders
if (customeHeaders.contains(HttpHeaders.AUTHORIZATION)
|| customeHeaders.contains(HttpHeaders.USER_AGENT)) {
Expand All @@ -464,7 +502,7 @@ private[spark] class DeltaSharingRestClient(
val headers = Map(
HttpHeaders.AUTHORIZATION -> s"Bearer ${profileProvider.getProfile.bearerToken}",
HttpHeaders.USER_AGENT -> getUserAgent()
) ++ customeHeaders
) ++ customeHeaders ++ constructDeltaSharingCapabilities(setIncludeEndStreamAction)
headers.foreach(header => httpRequest.setHeader(header._1, header._2))

httpRequest
Expand All @@ -482,15 +520,16 @@ private[spark] class DeltaSharingRestClient(
private def getResponse(
httpRequest: HttpRequestBase,
allowNoContent: Boolean = false,
fetchAsOneString: Boolean = false
fetchAsOneString: Boolean = false,
setIncludeEndStreamAction: Boolean = false
): (Option[Long], Seq[String]) = {
// Reset queryId before calling RetryUtils, and before prepareHeaders.
queryId = Some(UUID.randomUUID().toString().split('-').head)
// Reset dsQueryId before calling RetryUtils, and before prepareHeaders.
dsQueryId = Some(UUID.randomUUID().toString().split('-').head)
RetryUtils.runWithExponentialBackoff(numRetries, maxRetryDuration) {
val profile = profileProvider.getProfile
val response = client.execute(
getHttpHost(profile.endpoint),
prepareHeaders(httpRequest),
prepareHeaders(httpRequest, setIncludeEndStreamAction),
HttpClientContext.create()
)
try {
Expand Down Expand Up @@ -541,7 +580,15 @@ private[spark] class DeltaSharingRestClient(
s"HTTP request failed with status: $status $responseToShow. $additionalErrorInfo",
statusCode)
}
Option(response.getFirstHeader("Delta-Table-Version")).map(_.getValue.toLong) -> lines
if (setIncludeEndStreamAction) {
val capabilities = Option(
response.getFirstHeader(DELTA_SHARING_CAPABILITIES_HEADER)
).map(_.getValue)
val capabilitiesMap = parseDeltaSharingCapabilities(capabilities)
checkEndStreamAction(capabilities, capabilitiesMap, lines)
}
Option(response.getFirstHeader(RESPONSE_TABLE_VERSION_HEADER_KEY)).map(
_.getValue.toLong) -> lines
} finally {
response.close()
}
Expand All @@ -560,7 +607,69 @@ private[spark] class DeltaSharingRestClient(
}

private def getQueryIdString: String = {
s"QueryId-${queryId.getOrElse("not_set")}"
s"QueryId-${dsQueryId.getOrElse("not_set")}"
}

private def checkEndStreamAction(
capabilities: Option[String],
capabilitiesMap: Map[String, String],
lines: Seq[String]): Unit = {
val includeEndStreamActionHeader = getRespondedIncludeEndStreamActionHeader(capabilitiesMap)
includeEndStreamActionHeader match {
case Some(true) =>
val lastLineAction = JsonUtils.fromJson[SingleAction](lines.last)
if (lastLineAction.endStreamAction == null) {
throw new MissingEndStreamActionException(s"Client sets " +
s"${DELTA_SHARING_INCLUDE_END_STREAM_ACTION}=true in the " +
s"header, server responded with the header set to true(${capabilities}, " +
s"and ${lines.size} lines, and last line parsed as " +
s"${lastLineAction.unwrap.getClass()}," + getDsQueryIdForLogging)
}
logInfo(
s"Successfully verified endStreamAction in the response" + getDsQueryIdForLogging
)
case Some(false) =>
logWarning(s"Client sets ${DELTA_SHARING_INCLUDE_END_STREAM_ACTION}=true in the " +
s"header, but the server responded with the header set to false(" +
s"${capabilities})," + getDsQueryIdForLogging
)
case None =>
logWarning(s"Client sets ${DELTA_SHARING_INCLUDE_END_STREAM_ACTION}=true in the" +
s" header, but server didn't respond with the header(${capabilities}), " +
getDsQueryIdForLogging
)
}
}

// includeEndStreamActionHeader indicates whether the last line is required to be an
// EndStreamAction, parsed from the response header.
private def getRespondedIncludeEndStreamActionHeader(
capabilitiesMap: Map[String, String]): Option[Boolean] = {
capabilitiesMap.get(DELTA_SHARING_INCLUDE_END_STREAM_ACTION).map(_.toBoolean)
}

private def parseDeltaSharingCapabilities(capabilities: Option[String]): Map[String, String] = {
if (capabilities.isEmpty) {
return Map.empty[String, String]
}
capabilities.get.toLowerCase().split(DELTA_SHARING_CAPABILITIES_DELIMITER)
.map(_.split("="))
.filter(_.size == 2)
.map { splits =>
(splits(0), splits(1))
}.toMap
}

// The value for delta-sharing-capabilities header, semicolon separated capabilities.
// Each capability is in the format of "key=value1,value2", values are separated by comma.
// Example: "capability1=value1;capability2=value3,value4,value5"
private def constructDeltaSharingCapabilities(
setIncludeEndStreamAction: Boolean): Map[String, String] = {
if (setIncludeEndStreamAction) {
Map(DELTA_SHARING_CAPABILITIES_HEADER -> s"$DELTA_SHARING_INCLUDE_END_STREAM_ACTION=true")
} else {
Map.empty[String, String]
}
}

def close(): Unit = {
Expand All @@ -578,6 +687,10 @@ private[spark] object DeltaSharingRestClient extends Logging {
val CURRENT = 1

val SPARK_STRUCTURED_STREAMING = "Delta-Sharing-SparkStructuredStreaming"
val RESPONSE_TABLE_VERSION_HEADER_KEY = "Delta-Table-Version"
val DELTA_SHARING_CAPABILITIES_HEADER = "delta-sharing-capabilities"
val DELTA_SHARING_INCLUDE_END_STREAM_ACTION = "includeendstreamaction"
val DELTA_SHARING_CAPABILITIES_DELIMITER = ";"

lazy val USER_AGENT = {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ package io.delta.sharing.spark

import org.apache.spark.sql.types.StructType

class MissingEndStreamActionException(message: String) extends IllegalStateException(message)


object DeltaSharingErrors {
def nonExistentDeltaSharingTable(tableId: String): Throwable = {
new IllegalStateException(s"Delta sharing table ${tableId} doesn't exist. " +
Expand Down
13 changes: 10 additions & 3 deletions spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ private[sharing] object RemoteDeltaLog {
val numRetries = ConfUtils.numRetries(sqlConf)
val maxRetryDurationMillis = ConfUtils.maxRetryDurationMillis(sqlConf)
val timeoutInSeconds = ConfUtils.timeoutInSeconds(sqlConf)
val endStreamActionEnabled = ConfUtils.includeEndStreamAction(sqlConf)

val clientClass =
sqlConf.getConfString("spark.delta.sharing.client.class",
Expand All @@ -172,13 +173,19 @@ private[sharing] object RemoteDeltaLog {
val client: DeltaSharingClient =
Class.forName(clientClass)
.getConstructor(classOf[DeltaSharingProfileProvider],
classOf[Int], classOf[Int], classOf[Long], classOf[Boolean], classOf[Boolean])
.newInstance(profileProvider,
classOf[Int],
classOf[Int],
classOf[Long],
classOf[Boolean],
classOf[Boolean],
classOf[Boolean]
).newInstance(profileProvider,
java.lang.Integer.valueOf(timeoutInSeconds),
java.lang.Integer.valueOf(numRetries),
java.lang.Long.valueOf(maxRetryDurationMillis),
java.lang.Boolean.valueOf(sslTrustAll),
java.lang.Boolean.valueOf(forStreaming))
java.lang.Boolean.valueOf(forStreaming),
java.lang.Boolean.valueOf(endStreamActionEnabled))
.asInstanceOf[DeltaSharingClient]
new RemoteDeltaLog(deltaSharingTable, new Path(path + getFormattedTimestampWithUUID), client)
}
Expand Down
11 changes: 11 additions & 0 deletions spark/src/main/scala/io/delta/sharing/spark/util/ConfUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ object ConfUtils {
val MAX_RETRY_DURATION_CONF = "spark.delta.sharing.network.maxRetryDuration"
val MAX_RETRY_DURATION_DEFAULT_MILLIS = 10L * 60L* 1000L /* 10 mins */

val INCLUDE_END_STREAM_ACTION_CONF = "spark.delta.sharing.query.includeEndStreamAction"
val INCLUDE_END_STREAM_ACTION_DEFAULT = "true"

val TIMEOUT_CONF = "spark.delta.sharing.network.timeout"
val TIMEOUT_DEFAULT = "320s"

Expand Down Expand Up @@ -65,6 +68,14 @@ object ConfUtils {
maxDur
}

def includeEndStreamAction(conf: Configuration): Boolean = {
conf.getBoolean(INCLUDE_END_STREAM_ACTION_CONF, INCLUDE_END_STREAM_ACTION_DEFAULT.toBoolean)
}

def includeEndStreamAction(conf: SQLConf): Boolean = {
conf.getConfString(INCLUDE_END_STREAM_ACTION_CONF, INCLUDE_END_STREAM_ACTION_DEFAULT).toBoolean
}

def timeoutInSeconds(conf: Configuration): Int = {
val timeoutStr = conf.get(TIMEOUT_CONF, TIMEOUT_DEFAULT)
toTimeInSeconds(timeoutStr, TIMEOUT_CONF)
Expand Down
Loading

0 comments on commit a68fe09

Please sign in to comment.