Skip to content

Commit

Permalink
Backport client changes
Browse files Browse the repository at this point in the history
  • Loading branch information
linzhou-db committed Oct 10, 2024
1 parent 8f75b13 commit 0b97eb4
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 66 deletions.
179 changes: 141 additions & 38 deletions client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import org.apache.spark.sql.SparkSession

import io.delta.sharing.client.model._
import io.delta.sharing.client.util.{ConfUtils, JsonUtils, RetryUtils, UnexpectedHttpStatus}
import io.delta.sharing.spark.MissingEndStreamActionException

/** An interface to fetch Delta metadata from remote server. */
trait DeltaSharingClient {
Expand Down Expand Up @@ -178,7 +179,8 @@ class DeltaSharingRestClient(
asyncQueryMaxDuration: Long = 600000L
) extends DeltaSharingClient with Logging {

logError(s"DeltaSharingRestClient with enableAsyncQuery $enableAsyncQuery")
logInfo(s"DeltaSharingRestClient with endStreamActionEnabled: $endStreamActionEnabled, " +
s"enableAsyncQuery:$enableAsyncQuery")

import DeltaSharingRestClient._

Expand Down Expand Up @@ -276,7 +278,12 @@ class DeltaSharingRestClient(
val target =
getTargetUrl(s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/" +
s"$encodedTableName/version$encodedParam")
val (version, _, _) = getResponse(new HttpGet(target), true, true)
val (version, _, _) = getResponse(
new HttpGet(target),
allowNoContent = true,
fetchAsOneString = true,
setIncludeEndStreamAction = false
)
version.getOrElse {
throw new IllegalStateException(s"Cannot find " +
s"${RESPONSE_TABLE_VERSION_HEADER_KEY} in the header," + getDsQueryIdForLogging)
Expand Down Expand Up @@ -310,7 +317,11 @@ class DeltaSharingRestClient(
val target = getTargetUrl(
s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/$encodedTableName/metadata" +
s"$encodedParams")
val (version, respondedFormat, lines) = getNDJson(target)
val (version, respondedFormat, lines) = getNDJson(
target,
requireVersion = true,
setIncludeEndStreamAction = false
)

checkRespondedFormat(
respondedFormat,
Expand Down Expand Up @@ -457,7 +468,13 @@ class DeltaSharingRestClient(
val (version, respondedFormat, lines, _) = getFilesByPage(table, target, request)
(version, respondedFormat, lines)
} else {
getNDJson(target, request)
val (version, respondedFormat, lines) = getNDJsonPost(
target,
request,
setIncludeEndStreamAction = true
)
val (filteredLines, _) = maybeExtractEndStreamAction(lines)
(version, respondedFormat, filteredLines)
}

checkRespondedFormat(
Expand Down Expand Up @@ -509,13 +526,17 @@ class DeltaSharingRestClient(
val (version, respondedFormat, lines, queryIdOpt) = if (enableAsyncQuery) {
getNDJsonWithAsync(table, targetUrl, request)
} else {
val (version, respondedFormat, lines) = getNDJson(targetUrl, request)
val (version, respondedFormat, lines) = getNDJsonPost(
targetUrl, request, setIncludeEndStreamAction = endStreamActionEnabled
)
(version, respondedFormat, lines, None)
}

var (filteredLines, endStreamAction) = maybeExtractEndStreamAction(lines)
if (endStreamAction.isEmpty) {
logWarning("EndStreamAction is not returned in the response" + getDsQueryIdForLogging)
logWarning(
"EndStreamAction is not returned in the paginated response" + getDsQueryIdForLogging
)
}

val protocol = filteredLines(0)
Expand Down Expand Up @@ -566,7 +587,9 @@ class DeltaSharingRestClient(
allLines.appendAll(res._1)
endStreamAction = res._2
if (endStreamAction.isEmpty) {
logWarning("EndStreamAction is not returned in the response" + getDsQueryIdForLogging)
logWarning(
"EndStreamAction is not returned in the paginated response" + getDsQueryIdForLogging
)
}
// Throw an error if the first page is expiring before we get all pages
if (minUrlExpirationTimestamp.exists(_ <= System.currentTimeMillis())) {
Expand Down Expand Up @@ -601,7 +624,11 @@ class DeltaSharingRestClient(
)
getCDFFilesByPage(target)
} else {
getNDJson(target, requireVersion = false)
val (version, respondedFormat, lines) = getNDJson(
target, requireVersion = false, setIncludeEndStreamAction = endStreamActionEnabled
)
val (filteredLines, _) = maybeExtractEndStreamAction(lines)
(version, respondedFormat, filteredLines)
}

checkRespondedFormat(
Expand Down Expand Up @@ -654,10 +681,14 @@ class DeltaSharingRestClient(

// Fetch first page
var updatedUrl = s"$targetUrl&maxFiles=$maxFilesPerReq"
val (version, respondedFormat, lines) = getNDJson(updatedUrl, requireVersion = false)
val (version, respondedFormat, lines) = getNDJson(
updatedUrl, requireVersion = false, setIncludeEndStreamAction = endStreamActionEnabled
)
var (filteredLines, endStreamAction) = maybeExtractEndStreamAction(lines)
if (endStreamAction.isEmpty) {
logWarning("EndStreamAction is not returned in the response" + getDsQueryIdForLogging)
logWarning(
"EndStreamAction is not returned in the paginated response" + getDsQueryIdForLogging
)
}
val protocol = filteredLines(0)
val metadata = filteredLines(1)
Expand Down Expand Up @@ -685,7 +716,9 @@ class DeltaSharingRestClient(
allLines.appendAll(res._1)
endStreamAction = res._2
if (endStreamAction.isEmpty) {
logWarning("EndStreamAction is not returned in the response" + getDsQueryIdForLogging)
logWarning(
"EndStreamAction is not returned in the paginated response" + getDsQueryIdForLogging
)
}
// Throw an error if the first page is expiring before we get all pages
if (minUrlExpirationTimestamp.exists(_ <= System.currentTimeMillis())) {
Expand Down Expand Up @@ -715,9 +748,13 @@ class DeltaSharingRestClient(
pageNumber: Int): (Seq[String], Option[EndStreamAction]) = {
val start = System.currentTimeMillis()
val (version, respondedFormat, lines) = if (requestBody.isDefined) {
getNDJson(targetUrl, requestBody.get)
getNDJsonPost(targetUrl, requestBody.get, setIncludeEndStreamAction = endStreamActionEnabled)
} else {
getNDJson(targetUrl, requireVersion = false)
getNDJson(
targetUrl,
requireVersion = false,
setIncludeEndStreamAction = endStreamActionEnabled
)
}
logInfo(s"Took ${System.currentTimeMillis() - start} to fetch ${pageNumber}th page " +
s"of ${lines.size} lines," + getDsQueryIdForLogging)
Expand Down Expand Up @@ -781,8 +818,12 @@ class DeltaSharingRestClient(
}

private def getNDJson(
target: String, requireVersion: Boolean = true): (Long, String, Seq[String]) = {
val (version, capabilities, lines) = getResponse(new HttpGet(target))
target: String,
requireVersion: Boolean,
setIncludeEndStreamAction: Boolean): (Long, String, Seq[String]) = {
val (version, capabilitiesMap, lines) = getResponse(
new HttpGet(target), setIncludeEndStreamAction = setIncludeEndStreamAction
)
(
version.getOrElse {
if (requireVersion) {
Expand All @@ -792,7 +833,7 @@ class DeltaSharingRestClient(
0L
}
},
getRespondedFormat(capabilities),
getRespondedFormat(capabilitiesMap),
lines
)
}
Expand All @@ -818,7 +859,7 @@ class DeltaSharingRestClient(
maxFiles = maxFiles,
pageToken = pageToken)

getNDJson(target, request)
getNDJsonPost(target, request, setIncludeEndStreamAction = false)
}

/*
Expand Down Expand Up @@ -856,7 +897,9 @@ class DeltaSharingRestClient(
target: String,
request: QueryTableRequest): (Long, String, Seq[String], Option[String]) = {
// Initial query to get NDJson data
val (initialVersion, initialRespondedFormat, initialLines) = getNDJson(target, request)
val (initialVersion, initialRespondedFormat, initialLines) = getNDJsonPost(
target = target, data = request, setIncludeEndStreamAction = false
)

// Check if the query is still pending
var (lines, queryIdOpt, queryPending) = checkQueryPending(initialLines)
Expand Down Expand Up @@ -894,28 +937,70 @@ class DeltaSharingRestClient(
(version, respondedFormat, lines, queryIdOpt)
}


private def getNDJson[T: Manifest](target: String, data: T): (Long, String, Seq[String]) = {
private def getNDJsonPost[T: Manifest](
target: String,
data: T,
setIncludeEndStreamAction: Boolean): (Long, String, Seq[String]) = {
val httpPost = new HttpPost(target)
val json = JsonUtils.toJson(data)
httpPost.setHeader("Content-type", "application/json")
httpPost.setEntity(new StringEntity(json, UTF_8))
val (version, capabilities, lines) = getResponse(httpPost)
val (version, capabilitiesMap, lines) = getResponse(
httpPost, setIncludeEndStreamAction = setIncludeEndStreamAction
)
(
version.getOrElse {
throw new IllegalStateException(
"Cannot find Delta-Table-Version in the header," + getDsQueryIdForLogging)
},
getRespondedFormat(capabilities),
getRespondedFormat(capabilitiesMap),
lines
)
}

private def getRespondedFormat(capabilities: Option[String]): String = {
val capabilitiesMap = getDeltaSharingCapabilitiesMap(capabilities)
/**
 * Validates the response against the `includeendstreamaction` capability that the client
 * requested in the request header.
 *
 * The responded header value is read from `capabilitiesMap`:
 *  - set to true: the last response line must parse to a `SingleAction` whose
 *    `endStreamAction` is non-null, otherwise a `MissingEndStreamActionException` is thrown.
 *    An empty response is treated the same way, since it has no last line at all.
 *  - set to false, or not present: only a warning is logged.
 *
 * @param capabilities    raw value of the responded capabilities header, used for logging only.
 * @param capabilitiesMap parsed key/value pairs of the responded capabilities header.
 * @param lines           the response lines.
 * @throws MissingEndStreamActionException if the server responded with the header set to true
 *                                         but the response does not end with an EndStreamAction.
 */
private def checkEndStreamAction(
    capabilities: Option[String],
    capabilitiesMap: Map[String, String],
    lines: Seq[String]): Unit = {
  val includeEndStreamActionHeader = getRespondedIncludeEndStreamActionHeader(capabilitiesMap)
  includeEndStreamActionHeader match {
    case Some(true) =>
      // Builds the exception for both failure modes so that the message stays identical
      // except for the description of what was found in the last position.
      def missingEndStreamAction(lastLineDescription: String): MissingEndStreamActionException = {
        new MissingEndStreamActionException(s"Client sets " +
          s"${DELTA_SHARING_INCLUDE_END_STREAM_ACTION}=true in the " +
          s"header, server responded with the header set to true(${capabilities}, " +
          s"and ${lines.size} lines, and last line parsed as " +
          s"${lastLineDescription}," + getDsQueryIdForLogging)
      }
      lines.lastOption match {
        case None =>
          // `lines.last` would throw a bare NoSuchElementException here; an empty body
          // is reported as a missing EndStreamAction instead.
          throw missingEndStreamAction("nothing because the response is empty")
        case Some(lastLine) =>
          val lastLineAction = JsonUtils.fromJson[SingleAction](lastLine)
          if (lastLineAction.endStreamAction == null) {
            throw missingEndStreamAction(lastLineAction.unwrap.getClass().toString)
          }
      }
      logInfo(
        s"Successfully verified endStreamAction in the response" + getDsQueryIdForLogging
      )
    case Some(false) =>
      logWarning(s"Client sets ${DELTA_SHARING_INCLUDE_END_STREAM_ACTION}=true in the " +
        s"header, but the server responded with the header set to false(" +
        s"${capabilities})," + getDsQueryIdForLogging
      )
    case None =>
      logWarning(s"Client sets ${DELTA_SHARING_INCLUDE_END_STREAM_ACTION}=true in the" +
        s" header, but server didn't respond with the header(${capabilities}), " +
        s"for query($dsQueryId)."
      )
  }
}

/**
 * Returns the response format advertised in the responded capabilities, falling back to
 * `RESPONSE_FORMAT_PARQUET` when the `RESPONSE_FORMAT` key is not present.
 *
 * @param capabilitiesMap parsed key/value pairs of the responded capabilities header.
 */
private def getRespondedFormat(capabilitiesMap: Map[String, String]): String =
  capabilitiesMap.getOrElse(RESPONSE_FORMAT, RESPONSE_FORMAT_PARQUET)
private def getDeltaSharingCapabilitiesMap(capabilities: Option[String]): Map[String, String] = {

// includeEndStreamActionHeader indicates whether the last line is required to be an
// EndStreamAction, parsed from the response header.
// Returns None when the key is absent, and also when its value is not a valid boolean:
// `String.toBoolean` throws IllegalArgumentException on such input, so a malformed value is
// treated as "header not responded" rather than failing the whole request.
private def getRespondedIncludeEndStreamActionHeader(
    capabilitiesMap: Map[String, String]): Option[Boolean] = {
  capabilitiesMap
    .get(DELTA_SHARING_INCLUDE_END_STREAM_ACTION)
    .flatMap(value => scala.util.Try(value.trim.toBoolean).toOption)
}

private def parseDeltaSharingCapabilities(capabilities: Option[String]): Map[String, String] = {
if (capabilities.isEmpty) {
return Map.empty[String, String]
}
Expand All @@ -928,7 +1013,12 @@ class DeltaSharingRestClient(
}

private def getJson[R: Manifest](target: String): R = {
val (_, _, response) = getResponse(new HttpGet(target), false, true)
val (_, _, response) = getResponse(
new HttpGet(target),
allowNoContent = false,
fetchAsOneString = true,
setIncludeEndStreamAction = false
)
if (response.size != 1) {
throw new IllegalStateException(
s"Unexpected response for target: $target, response=$response," + getDsQueryIdForLogging
Expand Down Expand Up @@ -959,7 +1049,8 @@ class DeltaSharingRestClient(
}
}

private[client] def prepareHeaders(httpRequest: HttpRequestBase): HttpRequestBase = {
private[client] def prepareHeaders(
httpRequest: HttpRequestBase, setIncludeEndStreamAction: Boolean): HttpRequestBase = {
val customeHeaders = profileProvider.getCustomHeaders
if (customeHeaders.contains(HttpHeaders.AUTHORIZATION)
|| customeHeaders.contains(HttpHeaders.USER_AGENT)) {
Expand All @@ -971,7 +1062,9 @@ class DeltaSharingRestClient(
val headers = Map(
HttpHeaders.AUTHORIZATION -> s"Bearer ${profileProvider.getProfile.bearerToken}",
HttpHeaders.USER_AGENT -> getUserAgent(),
DELTA_SHARING_CAPABILITIES_HEADER -> getDeltaSharingCapabilities()
DELTA_SHARING_CAPABILITIES_HEADER -> constructDeltaSharingCapabilities(
setIncludeEndStreamAction
)
) ++ customeHeaders
headers.foreach(header => httpRequest.setHeader(header._1, header._2))

Expand All @@ -990,15 +1083,16 @@ class DeltaSharingRestClient(
private def getResponse(
httpRequest: HttpRequestBase,
allowNoContent: Boolean = false,
fetchAsOneString: Boolean = false
): (Option[Long], Option[String], Seq[String]) = {
fetchAsOneString: Boolean = false,
setIncludeEndStreamAction: Boolean = false
): (Option[Long], Map[String, String], Seq[String]) = {
// Reset dsQueryId before calling RetryUtils, and before prepareHeaders.
dsQueryId = Some(UUID.randomUUID().toString().split('-').head)
RetryUtils.runWithExponentialBackoff(numRetries, maxRetryDuration) {
val profile = profileProvider.getProfile
val response = client.execute(
getHttpHost(profile.endpoint),
prepareHeaders(httpRequest),
prepareHeaders(httpRequest, setIncludeEndStreamAction),
HttpClientContext.create()
)
try {
Expand Down Expand Up @@ -1047,17 +1141,22 @@ class DeltaSharingRestClient(
// Only show the last 100 lines in the error to keep it contained.
val responseToShow = lines.drop(lines.size - 100).mkString("\n")
throw new UnexpectedHttpStatus(
s"HTTP request failed with status: $status $responseToShow. $additionalErrorInfo"
+ getDsQueryIdForLogging,
s"HTTP request failed with status: $status" +
Seq(getDsQueryIdForLogging, additionalErrorInfo, responseToShow).mkString(" "),
statusCode)
}
val capabilities = Option(
response.getFirstHeader(DELTA_SHARING_CAPABILITIES_HEADER)
).map(_.getValue)
val capabilitiesMap = parseDeltaSharingCapabilities(capabilities)
if (setIncludeEndStreamAction) {
checkEndStreamAction(capabilities, capabilitiesMap, lines)
}
(
Option(
response.getFirstHeader(RESPONSE_TABLE_VERSION_HEADER_KEY)
).map(_.getValue.toLong),
Option(
response.getFirstHeader(DELTA_SHARING_CAPABILITIES_HEADER)
).map(_.getValue),
capabilitiesMap,
lines
)
} finally {
Expand All @@ -1084,7 +1183,7 @@ class DeltaSharingRestClient(
// The value for delta-sharing-capabilities header, semicolon separated capabilities.
// Each capability is in the format of "key=value1,value2", values are separated by comma.
// Example: "capability1=value1;capability2=value3,value4,value5"
private def getDeltaSharingCapabilities(): String = {
private def constructDeltaSharingCapabilities(setIncludeEndStreamAction: Boolean): String = {
var capabilities = Seq[String](s"${RESPONSE_FORMAT}=$responseFormat")
if (responseFormatSet.contains(RESPONSE_FORMAT_DELTA) && readerFeatures.nonEmpty) {
capabilities = capabilities :+ s"$READER_FEATURES=$readerFeatures"
Expand All @@ -1094,6 +1193,10 @@ class DeltaSharingRestClient(
capabilities = capabilities :+ s"$DELTA_SHARING_CAPABILITIES_ASYNC_READ=true"
}

if (setIncludeEndStreamAction) {
capabilities = capabilities :+ s"$DELTA_SHARING_INCLUDE_END_STREAM_ACTION=true"
}

val cap = capabilities.mkString(DELTA_SHARING_CAPABILITIES_DELIMITER)
cap
}
Expand All @@ -1116,10 +1219,10 @@ object DeltaSharingRestClient extends Logging {
val RESPONSE_FORMAT = "responseformat"
val READER_FEATURES = "readerfeatures"
val DELTA_SHARING_CAPABILITIES_ASYNC_READ = "asyncquery"
val DELTA_SHARING_INCLUDE_END_STREAM_ACTION = "includeendstreamaction"
val RESPONSE_FORMAT_DELTA = "delta"
val RESPONSE_FORMAT_PARQUET = "parquet"
val DELTA_SHARING_CAPABILITIES_DELIMITER = ";"
val QUERY_PENDING_TRUE = "pending"

lazy val USER_AGENT = {
try {
Expand Down
Loading

0 comments on commit 0b97eb4

Please sign in to comment.