From 0b97eb44e7e4b05797562fe226e6ebaf177390a6 Mon Sep 17 00:00:00 2001 From: Lin Zhou Date: Wed, 9 Oct 2024 17:32:40 -0700 Subject: [PATCH] Backport client changes --- .../sharing/client/DeltaSharingClient.scala | 179 ++++++++++++++---- .../sharing/client/util/RetryUtils.scala | 3 + .../sharing/spark/DeltaSharingErrors.scala | 2 + .../DeltaSharingRestClientDeltaSuite.scala | 4 +- .../client/DeltaSharingRestClientSuite.scala | 80 +++++--- .../sharing/client/util/RetryUtilsSuite.scala | 2 + .../sharing/spark/DeltaSharingSuite.scala | 3 +- 7 files changed, 207 insertions(+), 66 deletions(-) diff --git a/client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala b/client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala index b6155e351..3c9b3695f 100644 --- a/client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala +++ b/client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.SparkSession import io.delta.sharing.client.model._ import io.delta.sharing.client.util.{ConfUtils, JsonUtils, RetryUtils, UnexpectedHttpStatus} +import io.delta.sharing.spark.MissingEndStreamActionException /** An interface to fetch Delta metadata from remote server. */ trait DeltaSharingClient { @@ -178,7 +179,8 @@ class DeltaSharingRestClient( asyncQueryMaxDuration: Long = 600000L ) extends DeltaSharingClient with Logging { - logError(s"DeltaSharingRestClient with enableAsyncQuery $enableAsyncQuery") + logInfo(s"DeltaSharingRestClient with endStreamActionEnabled: $endStreamActionEnabled, " + + s"enableAsyncQuery:$enableAsyncQuery") import DeltaSharingRestClient._ @@ -276,7 +278,12 @@ class DeltaSharingRestClient( val target = getTargetUrl(s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/" + s"$encodedTableName/version$encodedParam") - val (version, _, _) = getResponse(new HttpGet(target), true, true) + val (version, _, _) = getResponse( + new HttpGet(target), + allowNoContent = true, + fetchAsOneString = true, + setIncludeEndStreamAction = false + ) version.getOrElse { throw new IllegalStateException(s"Cannot find " + s"${RESPONSE_TABLE_VERSION_HEADER_KEY} in the header," + getDsQueryIdForLogging) @@ -310,7 +317,11 @@ class DeltaSharingRestClient( val target = getTargetUrl( s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/$encodedTableName/metadata" + s"$encodedParams") - val (version, respondedFormat, lines) = getNDJson(target) + val (version, respondedFormat, lines) = getNDJson( + target, + requireVersion = true, + setIncludeEndStreamAction = false + ) checkRespondedFormat( respondedFormat, @@ -457,7 +468,13 @@ class DeltaSharingRestClient( val (version, respondedFormat, lines, _) = getFilesByPage(table, target, request) (version, respondedFormat, lines) } else { - getNDJson(target, request) + val (version, respondedFormat, lines) = getNDJsonPost( + target, + request, + setIncludeEndStreamAction = true + ) + val (filteredLines, _) = maybeExtractEndStreamAction(lines) + (version, respondedFormat, filteredLines) } checkRespondedFormat( @@ -509,13 +526,17 @@ class DeltaSharingRestClient( val (version, respondedFormat, lines, queryIdOpt) = if (enableAsyncQuery) { getNDJsonWithAsync(table, targetUrl, request) } else { - val (version, respondedFormat, lines) = getNDJson(targetUrl, request) + val (version, respondedFormat, lines) = getNDJsonPost( + targetUrl, request, setIncludeEndStreamAction = endStreamActionEnabled + ) (version, respondedFormat, lines, None) } var (filteredLines, endStreamAction) = maybeExtractEndStreamAction(lines) if (endStreamAction.isEmpty) { - logWarning("EndStreamAction is not returned in the response" + getDsQueryIdForLogging) + logWarning( + "EndStreamAction is not returned in the paginated response" + getDsQueryIdForLogging + ) } val protocol = filteredLines(0) @@ -566,7 +587,9 @@ class DeltaSharingRestClient( allLines.appendAll(res._1) endStreamAction = res._2 if (endStreamAction.isEmpty) { - logWarning("EndStreamAction is not returned in the response" + getDsQueryIdForLogging) + logWarning( + "EndStreamAction is not returned in the paginated response" + getDsQueryIdForLogging + ) } // Throw an error if the first page is expiring before we get all pages if (minUrlExpirationTimestamp.exists(_ <= System.currentTimeMillis())) { @@ -601,7 +624,11 @@ class DeltaSharingRestClient( ) getCDFFilesByPage(target) } else { - getNDJson(target, requireVersion = false) + val (version, respondedFormat, lines) = getNDJson( + target, requireVersion = false, setIncludeEndStreamAction = endStreamActionEnabled + ) + val (filteredLines, _) = maybeExtractEndStreamAction(lines) + (version, respondedFormat, filteredLines) } checkRespondedFormat( @@ -654,10 +681,14 @@ class DeltaSharingRestClient( // Fetch first page var updatedUrl = s"$targetUrl&maxFiles=$maxFilesPerReq" - val (version, respondedFormat, lines) = getNDJson(updatedUrl, requireVersion = false) + val (version, respondedFormat, lines) = getNDJson( + updatedUrl, requireVersion = false, setIncludeEndStreamAction = endStreamActionEnabled + ) var (filteredLines, endStreamAction) = maybeExtractEndStreamAction(lines) if (endStreamAction.isEmpty) { - logWarning("EndStreamAction is not returned in the response" + getDsQueryIdForLogging) + logWarning( + "EndStreamAction is not returned in the paginated response" + getDsQueryIdForLogging + ) } val protocol = filteredLines(0) val metadata = filteredLines(1) @@ -685,7 +716,9 @@ class DeltaSharingRestClient( allLines.appendAll(res._1) endStreamAction = res._2 if (endStreamAction.isEmpty) { - logWarning("EndStreamAction is not returned in the response" + getDsQueryIdForLogging) + logWarning( + "EndStreamAction is not returned in the paginated response" + getDsQueryIdForLogging + ) } // Throw an error if the first page is expiring before we get all pages if (minUrlExpirationTimestamp.exists(_ <= System.currentTimeMillis())) { @@ -715,9 +748,13 @@ class DeltaSharingRestClient( pageNumber: Int): (Seq[String], Option[EndStreamAction]) = { val start = System.currentTimeMillis() val (version, respondedFormat, lines) = if (requestBody.isDefined) { - getNDJson(targetUrl, requestBody.get) + getNDJsonPost(targetUrl, requestBody.get, setIncludeEndStreamAction = endStreamActionEnabled) } else { - getNDJson(targetUrl, requireVersion = false) + getNDJson( + targetUrl, + requireVersion = false, + setIncludeEndStreamAction = endStreamActionEnabled + ) } logInfo(s"Took ${System.currentTimeMillis() - start} to fetch ${pageNumber}th page " + s"of ${lines.size} lines," + getDsQueryIdForLogging) @@ -781,8 +818,12 @@ class DeltaSharingRestClient( } private def getNDJson( - target: String, requireVersion: Boolean = true): (Long, String, Seq[String]) = { - val (version, capabilities, lines) = getResponse(new HttpGet(target)) + target: String, + requireVersion: Boolean, + setIncludeEndStreamAction: Boolean): (Long, String, Seq[String]) = { + val (version, capabilitiesMap, lines) = getResponse( + new HttpGet(target), setIncludeEndStreamAction = setIncludeEndStreamAction + ) ( version.getOrElse { if (requireVersion) { @@ -792,7 +833,7 @@ class DeltaSharingRestClient( 0L } }, - getRespondedFormat(capabilities), + getRespondedFormat(capabilitiesMap), lines ) } @@ -818,7 +859,7 @@ class DeltaSharingRestClient( maxFiles = maxFiles, pageToken = pageToken) - getNDJson(target, request) + getNDJsonPost(target, request, setIncludeEndStreamAction = false) } /* @@ -856,7 +897,9 @@ class DeltaSharingRestClient( target: String, request: QueryTableRequest): (Long, String, Seq[String], Option[String]) = { // Initial query to get NDJson data - val (initialVersion, initialRespondedFormat, initialLines) = getNDJson(target, request) + val (initialVersion, initialRespondedFormat, initialLines) = getNDJsonPost( + target = target, data = request, setIncludeEndStreamAction = false + ) // Check if the query is still pending var (lines, queryIdOpt, queryPending) = checkQueryPending(initialLines) @@ -894,28 +937,70 @@ class DeltaSharingRestClient( (version, respondedFormat, lines, queryIdOpt) } - - private def getNDJson[T: Manifest](target: String, data: T): (Long, String, Seq[String]) = { + private def getNDJsonPost[T: Manifest]( + target: String, + data: T, + setIncludeEndStreamAction: Boolean): (Long, String, Seq[String]) = { val httpPost = new HttpPost(target) val json = JsonUtils.toJson(data) httpPost.setHeader("Content-type", "application/json") httpPost.setEntity(new StringEntity(json, UTF_8)) - val (version, capabilities, lines) = getResponse(httpPost) + val (version, capabilitiesMap, lines) = getResponse( + httpPost, setIncludeEndStreamAction = setIncludeEndStreamAction + ) ( version.getOrElse { throw new IllegalStateException( "Cannot find Delta-Table-Version in the header," + getDsQueryIdForLogging) }, - getRespondedFormat(capabilities), + getRespondedFormat(capabilitiesMap), lines ) } - private def getRespondedFormat(capabilities: Option[String]): String = { - val capabilitiesMap = getDeltaSharingCapabilitiesMap(capabilities) + private def checkEndStreamAction( + capabilities: Option[String], + capabilitiesMap: Map[String, String], + lines: Seq[String]): Unit = { + val includeEndStreamActionHeader = getRespondedIncludeEndStreamActionHeader(capabilitiesMap) + includeEndStreamActionHeader match { + case Some(true) => + val lastLineAction = JsonUtils.fromJson[SingleAction](lines.last) + if (lastLineAction.endStreamAction == null) { + throw new MissingEndStreamActionException(s"Client sets " + + s"${DELTA_SHARING_INCLUDE_END_STREAM_ACTION}=true in the " + + s"header, server responded with the header set to true(${capabilities}, " + + s"and ${lines.size} lines, and last line parsed as " + + s"${lastLineAction.unwrap.getClass()}," + getDsQueryIdForLogging) + } + logInfo( + s"Successfully verified endStreamAction in the response" + getDsQueryIdForLogging + ) + case Some(false) => + logWarning(s"Client sets ${DELTA_SHARING_INCLUDE_END_STREAM_ACTION}=true in the " + + s"header, but the server responded with the header set to false(" + + s"${capabilities})," + getDsQueryIdForLogging + ) + case None => + logWarning(s"Client sets ${DELTA_SHARING_INCLUDE_END_STREAM_ACTION}=true in the" + + s" header, but server didn't respond with the header(${capabilities}), " + + s"for query($dsQueryId)." + ) + } + } + + private def getRespondedFormat(capabilitiesMap: Map[String, String]): String = { capabilitiesMap.get(RESPONSE_FORMAT).getOrElse(RESPONSE_FORMAT_PARQUET) } - private def getDeltaSharingCapabilitiesMap(capabilities: Option[String]): Map[String, String] = { + + // includeEndStreamActionHeader indicates whether the last line is required to be an + // EndStreamAction, parsed from the response header. + private def getRespondedIncludeEndStreamActionHeader( + capabilitiesMap: Map[String, String]): Option[Boolean] = { + capabilitiesMap.get(DELTA_SHARING_INCLUDE_END_STREAM_ACTION).map(_.toBoolean) + } + + private def parseDeltaSharingCapabilities(capabilities: Option[String]): Map[String, String] = { if (capabilities.isEmpty) { return Map.empty[String, String] } @@ -928,7 +1013,12 @@ class DeltaSharingRestClient( } private def getJson[R: Manifest](target: String): R = { - val (_, _, response) = getResponse(new HttpGet(target), false, true) + val (_, _, response) = getResponse( + new HttpGet(target), + allowNoContent = false, + fetchAsOneString = true, + setIncludeEndStreamAction = false + ) if (response.size != 1) { throw new IllegalStateException( s"Unexpected response for target: $target, response=$response," + getDsQueryIdForLogging @@ -959,7 +1049,8 @@ class DeltaSharingRestClient( } } - private[client] def prepareHeaders(httpRequest: HttpRequestBase): HttpRequestBase = { + private[client] def prepareHeaders( + httpRequest: HttpRequestBase, setIncludeEndStreamAction: Boolean): HttpRequestBase = { val customeHeaders = profileProvider.getCustomHeaders if (customeHeaders.contains(HttpHeaders.AUTHORIZATION) || customeHeaders.contains(HttpHeaders.USER_AGENT)) { @@ -971,7 +1062,9 @@ class DeltaSharingRestClient( val headers = Map( HttpHeaders.AUTHORIZATION -> s"Bearer ${profileProvider.getProfile.bearerToken}", HttpHeaders.USER_AGENT -> getUserAgent(), - DELTA_SHARING_CAPABILITIES_HEADER -> getDeltaSharingCapabilities() + DELTA_SHARING_CAPABILITIES_HEADER -> constructDeltaSharingCapabilities( + setIncludeEndStreamAction + ) ) ++ customeHeaders headers.foreach(header => httpRequest.setHeader(header._1, header._2)) @@ -990,15 +1083,16 @@ class DeltaSharingRestClient( private def getResponse( httpRequest: HttpRequestBase, allowNoContent: Boolean = false, - fetchAsOneString: Boolean = false - ): (Option[Long], Option[String], Seq[String]) = { + fetchAsOneString: Boolean = false, + setIncludeEndStreamAction: Boolean = false + ): (Option[Long], Map[String, String], Seq[String]) = { // Reset dsQueryId before calling RetryUtils, and before prepareHeaders. dsQueryId = Some(UUID.randomUUID().toString().split('-').head) RetryUtils.runWithExponentialBackoff(numRetries, maxRetryDuration) { val profile = profileProvider.getProfile val response = client.execute( getHttpHost(profile.endpoint), - prepareHeaders(httpRequest), + prepareHeaders(httpRequest, setIncludeEndStreamAction), HttpClientContext.create() ) try { @@ -1047,17 +1141,22 @@ class DeltaSharingRestClient( // Only show the last 100 lines in the error to keep it contained. val responseToShow = lines.drop(lines.size - 100).mkString("\n") throw new UnexpectedHttpStatus( - s"HTTP request failed with status: $status $responseToShow. $additionalErrorInfo" - + getDsQueryIdForLogging, + s"HTTP request failed with status: $status" + + Seq(getDsQueryIdForLogging, additionalErrorInfo, responseToShow).mkString(" "), statusCode) } + val capabilities = Option( + response.getFirstHeader(DELTA_SHARING_CAPABILITIES_HEADER) + ).map(_.getValue) + val capabilitiesMap = parseDeltaSharingCapabilities(capabilities) + if (setIncludeEndStreamAction) { + checkEndStreamAction(capabilities, capabilitiesMap, lines) + } ( Option( response.getFirstHeader(RESPONSE_TABLE_VERSION_HEADER_KEY) ).map(_.getValue.toLong), - Option( - response.getFirstHeader(DELTA_SHARING_CAPABILITIES_HEADER) - ).map(_.getValue), + capabilitiesMap, lines ) } finally { @@ -1084,7 +1183,7 @@ class DeltaSharingRestClient( // The value for delta-sharing-capabilities header, semicolon separated capabilities. // Each capability is in the format of "key=value1,value2", values are separated by comma. // Example: "capability1=value1;capability2=value3,value4,value5" - private def getDeltaSharingCapabilities(): String = { + private def constructDeltaSharingCapabilities(setIncludeEndStreamAction: Boolean): String = { var capabilities = Seq[String](s"${RESPONSE_FORMAT}=$responseFormat") if (responseFormatSet.contains(RESPONSE_FORMAT_DELTA) && readerFeatures.nonEmpty) { capabilities = capabilities :+ s"$READER_FEATURES=$readerFeatures" @@ -1094,6 +1193,10 @@ class DeltaSharingRestClient( capabilities = capabilities :+ s"$DELTA_SHARING_CAPABILITIES_ASYNC_READ=true" } + if (setIncludeEndStreamAction) { + capabilities = capabilities :+ s"$DELTA_SHARING_INCLUDE_END_STREAM_ACTION=true" + } + val cap = capabilities.mkString(DELTA_SHARING_CAPABILITIES_DELIMITER) cap } @@ -1116,10 +1219,10 @@ object DeltaSharingRestClient extends Logging { val RESPONSE_FORMAT = "responseformat" val READER_FEATURES = "readerfeatures" val DELTA_SHARING_CAPABILITIES_ASYNC_READ = "asyncquery" + val DELTA_SHARING_INCLUDE_END_STREAM_ACTION = "includeendstreamaction" val RESPONSE_FORMAT_DELTA = "delta" val RESPONSE_FORMAT_PARQUET = "parquet" val DELTA_SHARING_CAPABILITIES_DELIMITER = ";" - val QUERY_PENDING_TRUE = "pending" lazy val USER_AGENT = { try { diff --git a/client/src/main/scala/io/delta/sharing/client/util/RetryUtils.scala b/client/src/main/scala/io/delta/sharing/client/util/RetryUtils.scala index 4eb3a1c28..6ce6e17a4 100644 --- a/client/src/main/scala/io/delta/sharing/client/util/RetryUtils.scala +++ b/client/src/main/scala/io/delta/sharing/client/util/RetryUtils.scala @@ -22,6 +22,8 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging +import io.delta.sharing.spark.MissingEndStreamActionException + private[sharing] object RetryUtils extends Logging { // Expose it for testing @@ -70,6 +72,7 @@ private[sharing] object RetryUtils extends Logging { } else { false } + case _: MissingEndStreamActionException => true case _: java.net.SocketTimeoutException => true // do not retry on ConnectionClosedException because it can be caused by invalid json returned // from the delta sharing server. diff --git a/client/src/main/scala/io/delta/sharing/spark/DeltaSharingErrors.scala b/client/src/main/scala/io/delta/sharing/spark/DeltaSharingErrors.scala index 7f394f13d..6fd4db41d 100644 --- a/client/src/main/scala/io/delta/sharing/spark/DeltaSharingErrors.scala +++ b/client/src/main/scala/io/delta/sharing/spark/DeltaSharingErrors.scala @@ -18,6 +18,8 @@ package io.delta.sharing.spark import org.apache.spark.sql.types.StructType +class MissingEndStreamActionException(message: String) extends IllegalStateException(message) + object DeltaSharingErrors { def nonExistentDeltaSharingTable(tableId: String): Throwable = { new IllegalStateException(s"Delta sharing table ${tableId} doesn't exist. " + diff --git a/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientDeltaSuite.scala b/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientDeltaSuite.scala index 7adc6088a..ce6cd3476 100644 --- a/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientDeltaSuite.scala +++ b/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientDeltaSuite.scala @@ -39,7 +39,7 @@ class DeltaSharingRestClientDeltaSuite extends DeltaSharingIntegrationTest { val httpRequest = new HttpGet("random_url") val client = new DeltaSharingRestClient(testProfileProvider) - var h = client.prepareHeaders(httpRequest).getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER) + var h = client.prepareHeaders(httpRequest, setIncludeEndStreamAction = false).getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER) // scalastyle:off caselocale assert(h.getValue.toLowerCase().contains(s"responseformat=${DeltaSharingRestClient.RESPONSE_FORMAT_PARQUET}")) @@ -47,7 +47,7 @@ class DeltaSharingRestClientDeltaSuite extends DeltaSharingIntegrationTest { testProfileProvider, responseFormat = DeltaSharingRestClient.RESPONSE_FORMAT_DELTA ) - h = deltaClient.prepareHeaders(httpRequest).getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER) + h = deltaClient.prepareHeaders(httpRequest, setIncludeEndStreamAction = false).getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER) // scalastyle:off caselocale assert(h.getValue.toLowerCase().contains(s"responseformat=${DeltaSharingRestClient.RESPONSE_FORMAT_DELTA}")) } diff --git a/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientSuite.scala b/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientSuite.scala index ec5f356a3..3d7939eb1 100644 --- a/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientSuite.scala +++ b/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientSuite.scala @@ -84,36 +84,65 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { assert(h.contains(" java/")) } - def checkDeltaSharingCapabilities(request: HttpRequestBase, expected: String): Unit = { + def getEndStreamActionHeader(endStreamActionEnabled: Boolean): String = { + if (endStreamActionEnabled) { + s";$DELTA_SHARING_INCLUDE_END_STREAM_ACTION=true" + } else { + "" + } + } + + def checkDeltaSharingCapabilities( + request: HttpRequestBase, + responseFormat: String, + readerFeatures: String, + endStreamActionEnabled: Boolean): Unit = { + val expected = s"${RESPONSE_FORMAT}=$responseFormat$readerFeatures" + + getEndStreamActionHeader(endStreamActionEnabled) val h = request.getFirstHeader(DELTA_SHARING_CAPABILITIES_HEADER) assert(h.getValue == expected) } - var httpRequestBase = new DeltaSharingRestClient( - testProfileProvider, forStreaming = false, readerFeatures = "willBeIgnored").prepareHeaders(httpRequest) - checkUserAgent(httpRequestBase, false) - checkDeltaSharingCapabilities(httpRequestBase, "responseformat=parquet") + Seq( + (true, true), + (true, false), + (false, true), + (false, false) + ).foreach { case (forStreaming, endStreamActionEnabled) => + var client = new DeltaSharingRestClient( + testProfileProvider, + forStreaming = forStreaming, + endStreamActionEnabled = endStreamActionEnabled, + readerFeatures = "willBeIgnored") + .prepareHeaders(httpRequest, setIncludeEndStreamAction = endStreamActionEnabled) + checkUserAgent(client, forStreaming) + checkDeltaSharingCapabilities(client, "parquet", "", endStreamActionEnabled) - val readerFeatures = "deletionVectors,columnMapping,timestampNTZ" - httpRequestBase = new DeltaSharingRestClient( - testProfileProvider, - forStreaming = true, - responseFormat = RESPONSE_FORMAT_DELTA, - readerFeatures = readerFeatures).prepareHeaders(httpRequest) - checkUserAgent(httpRequestBase, true) - checkDeltaSharingCapabilities( - httpRequestBase, s"responseformat=delta;readerfeatures=$readerFeatures" - ) + val readerFeatures = "deletionVectors,columnMapping,timestampNTZ" + client = new DeltaSharingRestClient( + testProfileProvider, + forStreaming = forStreaming, + endStreamActionEnabled = endStreamActionEnabled, + responseFormat = RESPONSE_FORMAT_DELTA, + readerFeatures = readerFeatures) + .prepareHeaders(httpRequest, setIncludeEndStreamAction = endStreamActionEnabled) + checkUserAgent(client, forStreaming) + checkDeltaSharingCapabilities( + client, "delta", s";$READER_FEATURES=$readerFeatures", endStreamActionEnabled + ) - httpRequestBase = new DeltaSharingRestClient( - testProfileProvider, - forStreaming = true, - responseFormat = s"$RESPONSE_FORMAT_DELTA,$RESPONSE_FORMAT_PARQUET", - readerFeatures = readerFeatures).prepareHeaders(httpRequest) - checkUserAgent(httpRequestBase, true) - checkDeltaSharingCapabilities( - httpRequestBase, s"responseformat=delta,parquet;readerfeatures=$readerFeatures" - ) + client = new DeltaSharingRestClient( + testProfileProvider, + forStreaming = forStreaming, + endStreamActionEnabled = endStreamActionEnabled, + responseFormat = s"$RESPONSE_FORMAT_DELTA,$RESPONSE_FORMAT_PARQUET", + readerFeatures = readerFeatures) + .prepareHeaders(httpRequest, setIncludeEndStreamAction = endStreamActionEnabled) + checkUserAgent(client, forStreaming) + checkDeltaSharingCapabilities( + client, s"delta,parquet", s";$READER_FEATURES=$readerFeatures", endStreamActionEnabled + ) + } } integrationTest("listAllTables") { @@ -1100,7 +1129,8 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { false ) }.getMessage - assert(errorMessage.contains("""400 Bad Request {"errorCode":"RESOURCE_DOES_NOT_EXIST"""")) + assert(errorMessage.contains("""400 Bad Request for query""")) + assert(errorMessage.contains("""{"errorCode":"RESOURCE_DOES_NOT_EXIST"""")) assert(errorMessage.contains("table files missing")) } finally { client.close() diff --git a/client/src/test/scala/io/delta/sharing/client/util/RetryUtilsSuite.scala b/client/src/test/scala/io/delta/sharing/client/util/RetryUtilsSuite.scala index 297d0fb27..934147233 100644 --- a/client/src/test/scala/io/delta/sharing/client/util/RetryUtilsSuite.scala +++ b/client/src/test/scala/io/delta/sharing/client/util/RetryUtilsSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkFunSuite import io.delta.sharing.client.util.{RetryUtils, UnexpectedHttpStatus} import io.delta.sharing.client.util.RetryUtils._ +import io.delta.sharing.spark.MissingEndStreamActionException class RetryUtilsSuite extends SparkFunSuite { test("shouldRetry") { @@ -35,6 +36,7 @@ class RetryUtilsSuite extends SparkFunSuite { assert(shouldRetry(new IOException)) assert(shouldRetry(new java.net.SocketTimeoutException)) assert(!shouldRetry(new RuntimeException)) + assert(shouldRetry(new MissingEndStreamActionException("missing"))) } test("runWithExponentialBackoff") { diff --git a/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingSuite.scala b/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingSuite.scala index c7e9caba8..a808c43fc 100644 --- a/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingSuite.scala +++ b/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingSuite.scala @@ -448,7 +448,8 @@ class DeltaSharingSuite extends QueryTest with SharedSparkSession with DeltaShar .option("startingVersion", 0).load(tablePath) checkAnswer(df, Nil) } - assert (ex.getMessage.contains("""400 Bad Request {"errorCode":"RESOURCE_DOES_NOT_EXIST"""")) + assert(ex.getMessage.contains("""400 Bad Request""")) + assert(ex.getMessage.contains("""{"errorCode":"RESOURCE_DOES_NOT_EXIST"""")) } integrationTest("azure support") {