Skip to content

Commit

Permalink
Do not set includeEndStreamAction for getTableVersion/getTableMetadata
Browse files Browse the repository at this point in the history
  • Loading branch information
linzhou-db committed Oct 8, 2024
1 parent 7c15178 commit e262f9c
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,12 @@ class DeltaSharingRestClient(
val target =
getTargetUrl(s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/" +
s"$encodedTableName/version$encodedParam")
val (version, _, _) = getResponse(new HttpGet(target), true, true)
val (version, _, _) = getResponse(
new HttpGet(target),
allowNoContent = true,
fetchAsOneString = true,
setIncludeEndStreamAction = false
)
version.getOrElse {
throw new IllegalStateException(s"Cannot find " +
s"${RESPONSE_TABLE_VERSION_HEADER_KEY} in the header," + getDsQueryIdForLogging)
Expand Down Expand Up @@ -336,7 +341,7 @@ class DeltaSharingRestClient(
val target = getTargetUrl(
s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/$encodedTableName/metadata" +
s"$encodedParams")
val response = getNDJson(target, requireVersion = true, checkEndStreamActionHeader = false)
val response = getNDJson(target, requireVersion = true, setIncludeEndStreamAction = false)

checkRespondedFormat(
response.respondedFormat,
Expand Down Expand Up @@ -495,7 +500,7 @@ class DeltaSharingRestClient(
(version, respondedFormat, lines)
} else {
val response = getNDJsonPost(
target = target, data = request, checkEndStreamActionHeader = true
target = target, data = request, setIncludeEndStreamAction = true
)
val (filteredLines, _) = maybeExtractEndStreamAction(response.lines)
(response.version, response.respondedFormat, filteredLines)
Expand Down Expand Up @@ -551,7 +556,7 @@ class DeltaSharingRestClient(
getNDJsonWithAsync(table, targetUrl, request)
} else {
val response = getNDJsonPost(
target = targetUrl, data = request, checkEndStreamActionHeader = true
target = targetUrl, data = request, setIncludeEndStreamAction = true
)
(response.version, response.respondedFormat, response.lines, None)
}
Expand Down Expand Up @@ -607,8 +612,8 @@ class DeltaSharingRestClient(
expectedProtocol = protocol,
expectedMetadata = metadata,
pageNumber = numPages,
// EndStreamAction is not supported for async queries yet.
checkEndStreamActionHeader = !enableAsyncQuery
// Do not set EndStreamAction for async queries yet, and set it for sync queries.
setIncludeEndStreamAction = !enableAsyncQuery
)
allLines.appendAll(res._1)
endStreamAction = res._2
Expand Down Expand Up @@ -651,7 +656,7 @@ class DeltaSharingRestClient(
)
getCDFFilesByPage(target)
} else {
val response = getNDJson(target, requireVersion = false, checkEndStreamActionHeader = true)
val response = getNDJson(target, requireVersion = false, setIncludeEndStreamAction = true)
val (filteredLines, _) = maybeExtractEndStreamAction(response.lines)
(response.version, response.respondedFormat, filteredLines)
}
Expand Down Expand Up @@ -706,7 +711,7 @@ class DeltaSharingRestClient(

// Fetch first page
var updatedUrl = s"$targetUrl&maxFiles=$maxFilesPerReq"
val response = getNDJson(updatedUrl, requireVersion = false, checkEndStreamActionHeader = true)
val response = getNDJson(updatedUrl, requireVersion = false, setIncludeEndStreamAction = true)
var (filteredLines, endStreamAction) = maybeExtractEndStreamAction(response.lines)
if (endStreamAction.isEmpty) {
logWarning(
Expand Down Expand Up @@ -735,7 +740,7 @@ class DeltaSharingRestClient(
expectedProtocol = protocol,
expectedMetadata = metadata,
pageNumber = numPages,
checkEndStreamActionHeader = true
setIncludeEndStreamAction = true
)
allLines.appendAll(res._1)
endStreamAction = res._2
Expand Down Expand Up @@ -771,19 +776,19 @@ class DeltaSharingRestClient(
expectedProtocol: String,
expectedMetadata: String,
pageNumber: Int,
checkEndStreamActionHeader: Boolean): (Seq[String], Option[EndStreamAction]) = {
setIncludeEndStreamAction: Boolean): (Seq[String], Option[EndStreamAction]) = {
val start = System.currentTimeMillis()
val response = if (requestBody.isDefined) {
getNDJsonPost(
target = targetUrl,
data = requestBody.get,
checkEndStreamActionHeader = checkEndStreamActionHeader
setIncludeEndStreamAction = setIncludeEndStreamAction
)
} else {
getNDJson(
target = targetUrl,
requireVersion = false,
checkEndStreamActionHeader = checkEndStreamActionHeader)
setIncludeEndStreamAction = setIncludeEndStreamAction)
}
logInfo(s"Took ${System.currentTimeMillis() - start} to fetch ${pageNumber}th page " +
s"of ${response.lines.size} lines," + getDsQueryIdForLogging)
Expand Down Expand Up @@ -850,8 +855,10 @@ class DeltaSharingRestClient(
private def getNDJson(
target: String,
requireVersion: Boolean,
checkEndStreamActionHeader: Boolean): ParsedDeltaSharingResponse = {
val (version, capabilities, lines) = getResponse(new HttpGet(target))
setIncludeEndStreamAction: Boolean): ParsedDeltaSharingResponse = {
val (version, capabilities, lines) = getResponse(
new HttpGet(target), setIncludeEndStreamAction = setIncludeEndStreamAction
)
val (respondedFormat, includeEndStreamActionHeader) = getRespondedHeaders(capabilities)

val response = ParsedDeltaSharingResponse(
Expand All @@ -868,7 +875,7 @@ class DeltaSharingRestClient(
lines,
capabilities = capabilities
)
if (checkEndStreamActionHeader) {
if (setIncludeEndStreamAction) {
checkEndStreamAction(response)
}
response
Expand Down Expand Up @@ -896,7 +903,7 @@ class DeltaSharingRestClient(
pageToken = pageToken)

val response = getNDJsonPost(
target = target, data = request, checkEndStreamActionHeader = false
target = target, data = request, setIncludeEndStreamAction = false
)
(response.version, response.respondedFormat, response.lines)
}
Expand Down Expand Up @@ -937,7 +944,7 @@ class DeltaSharingRestClient(
request: QueryTableRequest): (Long, String, Seq[String], Option[String]) = {
// Initial query to get NDJson data
val response = getNDJsonPost(
target = target, data = request, checkEndStreamActionHeader = false
target = target, data = request, setIncludeEndStreamAction = false
)

// Check if the query is still pending
Expand Down Expand Up @@ -979,12 +986,14 @@ class DeltaSharingRestClient(
private def getNDJsonPost[T: Manifest](
target: String,
data: T,
checkEndStreamActionHeader: Boolean): ParsedDeltaSharingResponse = {
setIncludeEndStreamAction: Boolean): ParsedDeltaSharingResponse = {
val httpPost = new HttpPost(target)
val json = JsonUtils.toJson(data)
httpPost.setHeader("Content-type", "application/json")
httpPost.setEntity(new StringEntity(json, UTF_8))
val (version, capabilities, lines) = getResponse(httpPost)
val (version, capabilities, lines) = getResponse(
httpPost, setIncludeEndStreamAction = setIncludeEndStreamAction
)
val (respondedFormat, includeEndStreamActionHeader) = getRespondedHeaders(capabilities)

val response = ParsedDeltaSharingResponse(
Expand All @@ -998,7 +1007,7 @@ class DeltaSharingRestClient(
lines,
capabilities = capabilities
)
if (checkEndStreamActionHeader) {
if (setIncludeEndStreamAction) {
checkEndStreamAction(response)
}
response
Expand Down Expand Up @@ -1058,7 +1067,12 @@ class DeltaSharingRestClient(
}

private def getJson[R: Manifest](target: String): R = {
val (_, _, response) = getResponse(new HttpGet(target), false, true)
val (_, _, response) = getResponse(
new HttpGet(target),
allowNoContent = false,
fetchAsOneString = true,
setIncludeEndStreamAction = false
)
if (response.size != 1) {
throw new IllegalStateException(
s"Unexpected response for target:$target, response=$response" + getDsQueryIdForLogging
Expand All @@ -1082,7 +1096,8 @@ class DeltaSharingRestClient(
authCredentialProvider.isExpired()
}

private[client] def prepareHeaders(httpRequest: HttpRequestBase): HttpRequestBase = {
private[client] def prepareHeaders(
httpRequest: HttpRequestBase, setIncludeEndStreamAction: Boolean): HttpRequestBase = {
val customeHeaders = profileProvider.getCustomHeaders
if (customeHeaders.contains(HttpHeaders.AUTHORIZATION)
|| customeHeaders.contains(HttpHeaders.USER_AGENT)) {
Expand All @@ -1093,7 +1108,9 @@ class DeltaSharingRestClient(
}
val headers = Map(
HttpHeaders.USER_AGENT -> getUserAgent(),
DELTA_SHARING_CAPABILITIES_HEADER -> constructDeltaSharingCapabilities()
DELTA_SHARING_CAPABILITIES_HEADER -> constructDeltaSharingCapabilities(
setIncludeEndStreamAction
)
) ++ customeHeaders
headers.foreach(header => httpRequest.setHeader(header._1, header._2))
authCredentialProvider.addAuthHeader(httpRequest)
Expand All @@ -1113,15 +1130,16 @@ class DeltaSharingRestClient(
private def getResponse(
httpRequest: HttpRequestBase,
allowNoContent: Boolean = false,
fetchAsOneString: Boolean = false
fetchAsOneString: Boolean = false,
setIncludeEndStreamAction: Boolean = false
): (Option[Long], Option[String], Seq[String]) = {
// Reset dsQueryId before calling RetryUtils, and before prepareHeaders.
dsQueryId = Some(UUID.randomUUID().toString().split('-').head)
RetryUtils.runWithExponentialBackoff(numRetries, maxRetryDuration) {
val profile = profileProvider.getProfile
val response = client.execute(
getHttpHost(profile.endpoint),
prepareHeaders(httpRequest),
prepareHeaders(httpRequest, setIncludeEndStreamAction = setIncludeEndStreamAction),
HttpClientContext.create()
)
try {
Expand Down Expand Up @@ -1207,7 +1225,7 @@ class DeltaSharingRestClient(
// The value for delta-sharing-capabilities header, semicolon separated capabilities.
// Each capability is in the format of "key=value1,value2", values are separated by comma.
// Example: "capability1=value1;capability2=value3,value4,value5"
private def constructDeltaSharingCapabilities(): String = {
private def constructDeltaSharingCapabilities(setIncludeEndStreamAction: Boolean): String = {
var capabilities = Seq[String](s"${RESPONSE_FORMAT}=$responseFormat")
if (responseFormatSet.contains(RESPONSE_FORMAT_DELTA) && readerFeatures.nonEmpty) {
capabilities = capabilities :+ s"$READER_FEATURES=$readerFeatures"
Expand All @@ -1217,7 +1235,7 @@ class DeltaSharingRestClient(
capabilities = capabilities :+ s"$DELTA_SHARING_CAPABILITIES_ASYNC_READ=true"
}

if (includeEndStreamAction) {
if (includeEndStreamAction && setIncludeEndStreamAction) {
capabilities = capabilities :+ s"$DELTA_SHARING_INCLUDE_END_STREAM_ACTION=true"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ class DeltaSharingRestClientDeltaSuite extends DeltaSharingIntegrationTest {
val httpRequest = new HttpGet("random_url")

val client = new DeltaSharingRestClient(testProfileProvider)
var h = client.prepareHeaders(httpRequest).getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER)
var h = client.prepareHeaders(httpRequest, setIncludeEndStreamAction = false).getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER)
// scalastyle:off caselocale
assert(h.getValue.toLowerCase().contains(s"responseformat=${DeltaSharingRestClient.RESPONSE_FORMAT_PARQUET}"))

val deltaClient = new DeltaSharingRestClient(
testProfileProvider,
responseFormat = DeltaSharingRestClient.RESPONSE_FORMAT_DELTA
)
h = deltaClient.prepareHeaders(httpRequest).getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER)
h = deltaClient.prepareHeaders(httpRequest, setIncludeEndStreamAction = false).getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER)
// scalastyle:off caselocale
assert(h.getValue.toLowerCase().contains(s"responseformat=${DeltaSharingRestClient.RESPONSE_FORMAT_DELTA}"))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest {
testProfileProvider,
forStreaming = forStreaming,
includeEndStreamAction = includeEndStreamAction,
readerFeatures = "willBeIgnored").prepareHeaders(httpRequest)
readerFeatures = "willBeIgnored")
.prepareHeaders(httpRequest, setIncludeEndStreamAction = includeEndStreamAction)
checkUserAgent(client, forStreaming)
checkDeltaSharingCapabilities(client, "parquet", "", includeEndStreamAction)

Expand All @@ -123,7 +124,8 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest {
forStreaming = forStreaming,
includeEndStreamAction = includeEndStreamAction,
responseFormat = RESPONSE_FORMAT_DELTA,
readerFeatures = readerFeatures).prepareHeaders(httpRequest)
readerFeatures = readerFeatures)
.prepareHeaders(httpRequest, setIncludeEndStreamAction = includeEndStreamAction)
checkUserAgent(client, forStreaming)
checkDeltaSharingCapabilities(
client, "delta", s";$READER_FEATURES=$readerFeatures", includeEndStreamAction
Expand All @@ -134,7 +136,8 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest {
forStreaming = forStreaming,
includeEndStreamAction = includeEndStreamAction,
responseFormat = s"$RESPONSE_FORMAT_DELTA,$RESPONSE_FORMAT_PARQUET",
readerFeatures = readerFeatures).prepareHeaders(httpRequest)
readerFeatures = readerFeatures)
.prepareHeaders(httpRequest, setIncludeEndStreamAction = includeEndStreamAction)
checkUserAgent(client, forStreaming)
checkDeltaSharingCapabilities(
client, s"delta,parquet", s";$READER_FEATURES=$readerFeatures", includeEndStreamAction
Expand Down

0 comments on commit e262f9c

Please sign in to comment.