From 8906634dc9d47b501606f32afddeb6ec7ba06369 Mon Sep 17 00:00:00 2001 From: Lin Zhou <87341375+linzhou-db@users.noreply.github.com> Date: Wed, 30 Oct 2024 15:30:02 -0700 Subject: [PATCH] Backport includeEndStreamAction in DeltasharingService to branch-1.0 (#600) * Backport includeEndStreamAction in DeltasharingService to branch-1.0 * fix test --- .../sharing/client/DeltaSharingClient.scala | 254 +++++++++++---- .../delta/sharing/client/util/ConfUtils.scala | 11 + .../sharing/client/util/RetryUtils.scala | 3 + .../sharing/spark/DeltaSharingErrors.scala | 2 + .../sharing/spark/DeltaSharingSource.scala | 63 ++-- .../DeltaSharingRestClientDeltaSuite.scala | 4 +- .../client/DeltaSharingRestClientSuite.scala | 81 +++-- .../sharing/client/util/ConfUtilsSuite.scala | 20 ++ .../sharing/client/util/RetryUtilsSuite.scala | 2 + .../sharing/server/DeltaSharingService.scala | 45 ++- .../internal/DeltaSharedTableLoader.scala | 18 +- .../server/DeltaSharingServiceSuite.scala | 292 +++++++++++------- .../spark/DeltaSharingSourceSuite.scala | 8 +- .../sharing/spark/DeltaSharingSuite.scala | 3 +- .../spark/TestDeltaSharingClient.scala | 3 +- 15 files changed, 560 insertions(+), 249 deletions(-) diff --git a/client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala b/client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala index 11e9ad7b1..dc9bdb9bd 100644 --- a/client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala +++ b/client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala @@ -43,9 +43,21 @@ import org.apache.spark.sql.SparkSession import io.delta.sharing.client.model._ import io.delta.sharing.client.util.{ConfUtils, JsonUtils, RetryUtils, UnexpectedHttpStatus} +import io.delta.sharing.spark.MissingEndStreamActionException /** An interface to fetch Delta metadata from remote server. 
*/ trait DeltaSharingClient { + + protected var dsQueryId: Option[String] = None + + def getQueryId: String = { + dsQueryId.getOrElse("dsQueryIdNotSet") + } + + protected def getDsQueryIdForLogging: String = { + s" for query($dsQueryId)." + } + def listAllTables(): Seq[Table] def getTableVersion(table: Table, startingTimestamp: Option[String] = None): Long @@ -121,9 +133,12 @@ class DeltaSharingRestClient( responseFormat: String = DeltaSharingRestClient.RESPONSE_FORMAT_PARQUET, readerFeatures: String = "", queryTablePaginationEnabled: Boolean = false, - maxFilesPerReq: Int = 100000 + maxFilesPerReq: Int = 100000, + endStreamActionEnabled: Boolean = false ) extends DeltaSharingClient with Logging { + logInfo(s"DeltaSharingRestClient with endStreamActionEnabled: $endStreamActionEnabled.") + import DeltaSharingRestClient._ @volatile private var created = false @@ -131,8 +146,6 @@ class DeltaSharingRestClient( // Convert the responseFormat to a Seq to be used later. private val responseFormatSet = responseFormat.split(",").toSet - private var queryId: Option[String] = None - private lazy val client = { val clientBuilder: HttpClientBuilder = if (sslTrustAll) { val sslBuilder = new SSLContextBuilder() @@ -222,10 +235,15 @@ class DeltaSharingRestClient( val target = getTargetUrl(s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/" + s"$encodedTableName/version$encodedParam") - val (version, _, _) = getResponse(new HttpGet(target), true, true) + val (version, _, _) = getResponse( + new HttpGet(target), + allowNoContent = true, + fetchAsOneString = true, + setIncludeEndStreamAction = false + ) version.getOrElse { throw new IllegalStateException(s"Cannot find " + - s"${RESPONSE_TABLE_VERSION_HEADER_KEY} in the header") + s"${RESPONSE_TABLE_VERSION_HEADER_KEY} in the header," + getDsQueryIdForLogging) } } @@ -237,10 +255,10 @@ class DeltaSharingRestClient( private def checkRespondedFormat(respondedFormat: String, rpc: String, table: String): Unit = { if 
(!responseFormatSet.contains(respondedFormat)) { logError(s"RespondedFormat($respondedFormat) is different from requested " + - s"responseFormat($responseFormat) for $rpc for table $table.") + s"responseFormat($responseFormat) for $rpc for table $table," + getDsQueryIdForLogging) throw new IllegalArgumentException("The responseFormat returned from the delta sharing " + s"server doesn't match the requested responseFormat: respondedFormat($respondedFormat)" + - s" != requestedFormat($responseFormat).") + s" != requestedFormat($responseFormat)," + getDsQueryIdForLogging) } } @@ -256,7 +274,11 @@ class DeltaSharingRestClient( val target = getTargetUrl( s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/$encodedTableName/metadata" + s"$encodedParams") - val (version, respondedFormat, lines) = getNDJson(target) + val (version, respondedFormat, lines) = getNDJson( + target, + requireVersion = true, + setIncludeEndStreamAction = false + ) checkRespondedFormat( respondedFormat, @@ -272,7 +294,7 @@ class DeltaSharingRestClient( checkProtocol(protocol) val metadata = JsonUtils.fromJson[SingleAction](lines(1)).metaData if (lines.size != 2) { - throw new IllegalStateException("received more than two lines") + throw new IllegalStateException("received more than two lines," + getDsQueryIdForLogging) } DeltaTableMetadata(version, protocol, metadata, respondedFormat = respondedFormat) } @@ -281,7 +303,8 @@ class DeltaSharingRestClient( if (protocol.minReaderVersion > DeltaSharingProfile.CURRENT) { throw new IllegalArgumentException(s"The table requires a newer version" + s" ${protocol.minReaderVersion} to read. But the current release supports version " + - s"is ${DeltaSharingProfile.CURRENT} and below. Please upgrade to a newer release.") + s"is ${DeltaSharingProfile.CURRENT} and below. Please upgrade to a newer release." 
+ + getDsQueryIdForLogging) } } @@ -320,7 +343,11 @@ class DeltaSharingRestClient( ) getFilesByPage(target, request) } else { - val (version, respondedFormat, lines) = getNDJson(target, request) + val (version, respondedFormat, lines) = getNDJsonPost( + target, + request, + setIncludeEndStreamAction = endStreamActionEnabled + ) val (filteredLines, endStreamAction) = maybeExtractEndStreamAction(lines) val refreshTokenOpt = endStreamAction.flatMap { e => Option(e.refreshToken).flatMap { token => @@ -357,7 +384,7 @@ class DeltaSharingRestClient( if (action.file != null) { files.append(action.file) } else { - throw new IllegalStateException(s"Unexpected Line:${line}") + throw new IllegalStateException(s"Unexpected Line:${line}," + getDsQueryIdForLogging) } } DeltaTableFiles( @@ -396,12 +423,19 @@ class DeltaSharingRestClient( val (version, respondedFormat, lines) = if (queryTablePaginationEnabled) { logInfo( s"Making paginated queryTable from version $startingVersion requests for table " + - s"${table.share}.${table.schema}.${table.name} with maxFiles=$maxFilesPerReq" + s"${table.share}.${table.schema}.${table.name} with maxFiles=$maxFilesPerReq" + + getDsQueryIdForLogging ) val (version, respondedFormat, lines, _) = getFilesByPage(target, request) (version, respondedFormat, lines) } else { - getNDJson(target, request) + val (version, respondedFormat, lines) = getNDJsonPost( + target, + request, + setIncludeEndStreamAction = endStreamActionEnabled + ) + val (filteredLines, _) = maybeExtractEndStreamAction(lines) + (version, respondedFormat, filteredLines) } checkRespondedFormat( @@ -425,7 +459,8 @@ class DeltaSharingRestClient( case a: AddFileForCDF => addFiles.append(a) case r: RemoveFile => removeFiles.append(r) case m: Metadata => additionalMetadatas.append(m) - case _ => throw new IllegalStateException(s"Unexpected Line:${line}") + case _ => throw new IllegalStateException( + s"Unexpected Line:${line}," + getDsQueryIdForLogging) } } DeltaTableFiles( @@ -450,10 
+485,14 @@ class DeltaSharingRestClient( // Fetch first page var updatedRequest = request.copy(maxFiles = Some(maxFilesPerReq)) - val (version, respondedFormat, lines) = getNDJson(targetUrl, updatedRequest) + val (version, respondedFormat, lines) = getNDJsonPost( + targetUrl, updatedRequest, setIncludeEndStreamAction = endStreamActionEnabled + ) var (filteredLines, endStreamAction) = maybeExtractEndStreamAction(lines) if (endStreamAction.isEmpty) { - logWarning("EndStreamAction is not returned in the response for paginated query.") + logWarning( + "EndStreamAction is not returned in the paginated response" + getDsQueryIdForLogging + ) } val protocol = filteredLines(0) val metadata = filteredLines(1) @@ -492,17 +531,20 @@ class DeltaSharingRestClient( allLines.appendAll(res._1) endStreamAction = res._2 if (endStreamAction.isEmpty) { - logWarning("EndStreamAction is not returned in the response for paginated query.") + logWarning( + "EndStreamAction is not returned in the paginated response" + getDsQueryIdForLogging + ) } // Throw an error if the first page is expiring before we get all pages if (minUrlExpirationTimestamp.exists(_ <= System.currentTimeMillis())) { - throw new IllegalStateException("Unable to fetch all pages before minimum url expiration.") + throw new IllegalStateException( + "Unable to fetch all pages before minimum url expiration," + getDsQueryIdForLogging) } } // TODO: remove logging once changes are rolled out logInfo(s"Took ${System.currentTimeMillis() - start} ms to query $numPages pages " + - s"of ${allLines.size} files") + s"of ${allLines.size} files," + getDsQueryIdForLogging) (version, respondedFormat, allLines.toSeq, refreshToken) } @@ -521,11 +563,16 @@ class DeltaSharingRestClient( // TODO: remove logging once changes are rolled out logInfo( s"Making paginated queryTableChanges requests for table " + - s"${table.share}.${table.schema}.${table.name} with maxFiles=$maxFilesPerReq" + s"${table.share}.${table.schema}.${table.name} with 
maxFiles=$maxFilesPerReq," + + getDsQueryIdForLogging ) getCDFFilesByPage(target) } else { - getNDJson(target, requireVersion = false) + val (version, respondedFormat, lines) = getNDJson( + target, requireVersion = false, setIncludeEndStreamAction = endStreamActionEnabled + ) + val (filteredLines, _) = maybeExtractEndStreamAction(lines) + (version, respondedFormat, filteredLines) } checkRespondedFormat( @@ -553,7 +600,8 @@ class DeltaSharingRestClient( case a: AddFileForCDF => addFiles.append(a) case r: RemoveFile => removeFiles.append(r) case m: Metadata => additionalMetadatas.append(m) - case _ => throw new IllegalStateException(s"Unexpected Line:${line}") + case _ => throw new IllegalStateException( + s"Unexpected Line:${line}, " + getDsQueryIdForLogging) } } DeltaTableFiles( @@ -577,10 +625,14 @@ class DeltaSharingRestClient( // Fetch first page var updatedUrl = s"$targetUrl&maxFiles=$maxFilesPerReq" - val (version, respondedFormat, lines) = getNDJson(updatedUrl, requireVersion = false) + val (version, respondedFormat, lines) = getNDJson( + updatedUrl, requireVersion = false, setIncludeEndStreamAction = endStreamActionEnabled + ) var (filteredLines, endStreamAction) = maybeExtractEndStreamAction(lines) if (endStreamAction.isEmpty) { - logWarning("EndStreamAction is not returned in the response for paginated query.") + logWarning( + "EndStreamAction is not returned in the paginated response" + getDsQueryIdForLogging + ) } val protocol = filteredLines(0) val metadata = filteredLines(1) @@ -608,18 +660,21 @@ class DeltaSharingRestClient( allLines.appendAll(res._1) endStreamAction = res._2 if (endStreamAction.isEmpty) { - logWarning("EndStreamAction is not returned in the response for paginated query.") + logWarning( + "EndStreamAction is not returned in the paginated response" + getDsQueryIdForLogging + ) } // Throw an error if the first page is expiring before we get all pages if (minUrlExpirationTimestamp.exists(_ <= System.currentTimeMillis())) { - throw new 
IllegalStateException("Unable to fetch all pages before minimum url expiration.") + throw new IllegalStateException( + "Unable to fetch all pages before minimum url expiration," + getDsQueryIdForLogging) } } // TODO: remove logging once changes are rolled out logInfo( s"Took ${System.currentTimeMillis() - start} ms to query $numPages pages " + - s"of ${allLines.size} files" + s"of ${allLines.size} files," + getDsQueryIdForLogging ) (version, respondedFormat, allLines.toSeq) } @@ -637,12 +692,16 @@ class DeltaSharingRestClient( pageNumber: Int): (Seq[String], Option[EndStreamAction]) = { val start = System.currentTimeMillis() val (version, respondedFormat, lines) = if (requestBody.isDefined) { - getNDJson(targetUrl, requestBody.get) + getNDJsonPost(targetUrl, requestBody.get, setIncludeEndStreamAction = endStreamActionEnabled) } else { - getNDJson(targetUrl, requireVersion = false) + getNDJson( + targetUrl, + requireVersion = false, + setIncludeEndStreamAction = endStreamActionEnabled + ) } logInfo(s"Took ${System.currentTimeMillis() - start} to fetch ${pageNumber}th page " + - s"of ${lines.size} lines.") + s"of ${lines.size} lines," + getDsQueryIdForLogging) // Validate that version/format/protocol/metadata in the response don't change across pages if (version != expectedVersion || @@ -653,9 +712,9 @@ class DeltaSharingRestClient( |Received inconsistent version/format/protocol/metadata across pages. |Expected: version $expectedVersion, $expectedRespondedFormat, |$expectedProtocol, $expectedMetadata. 
Actual: version $version, - |$respondedFormat, ${lines(0)}, ${lines(1)}""".stripMargin + |$respondedFormat, ${lines},$getDsQueryIdForLogging""".stripMargin logError(s"Error while fetching next page files at url $targetUrl " + - s"with body(${JsonUtils.toJson(requestBody.orNull)}: $errorMsg)") + s"with body(${JsonUtils.toJson(requestBody.orNull)}: $errorMsg).") throw new IllegalStateException(errorMsg) } @@ -703,42 +762,90 @@ class DeltaSharingRestClient( } private def getNDJson( - target: String, requireVersion: Boolean = true): (Long, String, Seq[String]) = { - val (version, capabilities, lines) = getResponse(new HttpGet(target)) + target: String, + requireVersion: Boolean, + setIncludeEndStreamAction: Boolean): (Long, String, Seq[String]) = { + val (version, capabilitiesMap, lines) = getResponse( + new HttpGet(target), setIncludeEndStreamAction = setIncludeEndStreamAction + ) ( version.getOrElse { if (requireVersion) { throw new IllegalStateException(s"Cannot find " + - s"${RESPONSE_TABLE_VERSION_HEADER_KEY} in the header") + s"${RESPONSE_TABLE_VERSION_HEADER_KEY} in the header," + getDsQueryIdForLogging) } else { 0L } }, - getRespondedFormat(capabilities), + getRespondedFormat(capabilitiesMap), lines ) } - private def getNDJson[T: Manifest](target: String, data: T): (Long, String, Seq[String]) = { + private def getNDJsonPost[T: Manifest]( + target: String, + data: T, + setIncludeEndStreamAction: Boolean): (Long, String, Seq[String]) = { val httpPost = new HttpPost(target) val json = JsonUtils.toJson(data) httpPost.setHeader("Content-type", "application/json") httpPost.setEntity(new StringEntity(json, UTF_8)) - val (version, capabilities, lines) = getResponse(httpPost) + val (version, capabilitiesMap, lines) = getResponse( + httpPost, setIncludeEndStreamAction = setIncludeEndStreamAction + ) ( version.getOrElse { - throw new IllegalStateException("Cannot find Delta-Table-Version in the header") + throw new IllegalStateException( + "Cannot find Delta-Table-Version 
in the header," + getDsQueryIdForLogging) }, - getRespondedFormat(capabilities), + getRespondedFormat(capabilitiesMap), lines ) } - private def getRespondedFormat(capabilities: Option[String]): String = { - val capabilitiesMap = getDeltaSharingCapabilitiesMap(capabilities) + private def checkEndStreamAction( + capabilities: Option[String], + capabilitiesMap: Map[String, String], + lines: Seq[String]): Unit = { + val includeEndStreamActionHeader = getRespondedIncludeEndStreamActionHeader(capabilitiesMap) + includeEndStreamActionHeader match { + case Some(true) => + val lastLineAction = lines.lastOption.map(JsonUtils.fromJson[SingleAction](_)) + if (lastLineAction.forall(_.endStreamAction == null)) { + throw new MissingEndStreamActionException(s"Client sets " + + s"${DELTA_SHARING_INCLUDE_END_STREAM_ACTION}=true in the " + + s"header, server responded with the header set to true(${capabilities}, " + + s"and ${lines.size} lines, and last line parsed as " + + s"${lastLineAction.map(_.unwrap.getClass())}," + getDsQueryIdForLogging) + } + logInfo( + s"Successfully verified endStreamAction in the response" + getDsQueryIdForLogging + ) + case Some(false) => + logWarning(s"Client sets ${DELTA_SHARING_INCLUDE_END_STREAM_ACTION}=true in the " + + s"header, but the server responded with the header set to false(" + + s"${capabilities})," + getDsQueryIdForLogging + ) + case None => + logWarning(s"Client sets ${DELTA_SHARING_INCLUDE_END_STREAM_ACTION}=true in the" + + s" header, but server didn't respond with the header(${capabilities})," + + getDsQueryIdForLogging + ) + } + } + + private def getRespondedFormat(capabilitiesMap: Map[String, String]): String = { capabilitiesMap.get(RESPONSE_FORMAT).getOrElse(RESPONSE_FORMAT_PARQUET) } - private def getDeltaSharingCapabilitiesMap(capabilities: Option[String]): Map[String, String] = { + + // includeEndStreamActionHeader indicates whether the last line is required to be an + // EndStreamAction, parsed from the response header. 
+ private def getRespondedIncludeEndStreamActionHeader( + capabilitiesMap: Map[String, String]): Option[Boolean] = { + capabilitiesMap.get(DELTA_SHARING_INCLUDE_END_STREAM_ACTION).map(_.toBoolean) + } + + private def parseDeltaSharingCapabilities(capabilities: Option[String]): Map[String, String] = { if (capabilities.isEmpty) { return Map.empty[String, String] } @@ -751,10 +858,15 @@ class DeltaSharingRestClient( } private def getJson[R: Manifest](target: String): R = { - val (_, _, response) = getResponse(new HttpGet(target), false, true) + val (_, _, response) = getResponse( + new HttpGet(target), + allowNoContent = false, + fetchAsOneString = true, + setIncludeEndStreamAction = false + ) if (response.size != 1) { throw new IllegalStateException( - "Unexpected response for target: " + target + ", response=" + response + s"Unexpected response for target: $target, response=$response," + getDsQueryIdForLogging ) } JsonUtils.fromJson[R](response(0)) @@ -782,7 +894,8 @@ class DeltaSharingRestClient( } } - private[client] def prepareHeaders(httpRequest: HttpRequestBase): HttpRequestBase = { + private[client] def prepareHeaders( + httpRequest: HttpRequestBase, setIncludeEndStreamAction: Boolean): HttpRequestBase = { val customeHeaders = profileProvider.getCustomHeaders if (customeHeaders.contains(HttpHeaders.AUTHORIZATION) || customeHeaders.contains(HttpHeaders.USER_AGENT)) { @@ -794,7 +907,9 @@ class DeltaSharingRestClient( val headers = Map( HttpHeaders.AUTHORIZATION -> s"Bearer ${profileProvider.getProfile.bearerToken}", HttpHeaders.USER_AGENT -> getUserAgent(), - DELTA_SHARING_CAPABILITIES_HEADER -> getDeltaSharingCapabilities() + DELTA_SHARING_CAPABILITIES_HEADER -> constructDeltaSharingCapabilities( + setIncludeEndStreamAction + ) ) ++ customeHeaders headers.foreach(header => httpRequest.setHeader(header._1, header._2)) @@ -813,15 +928,16 @@ class DeltaSharingRestClient( private def getResponse( httpRequest: HttpRequestBase, allowNoContent: Boolean = false, - 
fetchAsOneString: Boolean = false - ): (Option[Long], Option[String], Seq[String]) = { - // Reset queryId before calling RetryUtils, and before prepareHeaders. - queryId = Some(UUID.randomUUID().toString().split('-').head) + fetchAsOneString: Boolean = false, + setIncludeEndStreamAction: Boolean = false + ): (Option[Long], Map[String, String], Seq[String]) = { + // Reset dsQueryId before calling RetryUtils, and before prepareHeaders. + dsQueryId = Some(UUID.randomUUID().toString().split('-').head) RetryUtils.runWithExponentialBackoff(numRetries, maxRetryDuration) { val profile = profileProvider.getProfile val response = client.execute( getHttpHost(profile.endpoint), - prepareHeaders(httpRequest), + prepareHeaders(httpRequest, setIncludeEndStreamAction), HttpClientContext.create() ) try { @@ -849,7 +965,8 @@ class DeltaSharingRestClient( } } catch { case e: org.apache.http.ConnectionClosedException => - val error = s"Request to delta sharing server failed due to ${e}." + val error = s"Request to delta sharing server failed" + getDsQueryIdForLogging + + s", last line:[${lineBuffer.lastOption.getOrElse("")}], due to ${e}." logError(error) lineBuffer += error lineBuffer.toList @@ -869,16 +986,22 @@ class DeltaSharingRestClient( // Only show the last 100 lines in the error to keep it contained. val responseToShow = lines.drop(lines.size - 100).mkString("\n") throw new UnexpectedHttpStatus( - s"HTTP request failed with status: $status $responseToShow. 
$additionalErrorInfo", + s"HTTP request failed with status: $status" + + Seq(getDsQueryIdForLogging, additionalErrorInfo, responseToShow).mkString(" "), statusCode) } + val capabilities = Option( + response.getFirstHeader(DELTA_SHARING_CAPABILITIES_HEADER) + ).map(_.getValue) + val capabilitiesMap = parseDeltaSharingCapabilities(capabilities) + if (setIncludeEndStreamAction) { + checkEndStreamAction(capabilities, capabilitiesMap, lines) + } ( Option( response.getFirstHeader(RESPONSE_TABLE_VERSION_HEADER_KEY) ).map(_.getValue.toLong), - Option( - response.getFirstHeader(DELTA_SHARING_CAPABILITIES_HEADER) - ).map(_.getValue), + capabilitiesMap, lines ) } finally { @@ -899,17 +1022,20 @@ class DeltaSharingRestClient( } private def getQueryIdString: String = { - s"QueryId-${queryId.getOrElse("not_set")}" + s"QueryId-${dsQueryId.getOrElse("not_set")}" } // The value for delta-sharing-capabilities header, semicolon separated capabilities. // Each capability is in the format of "key=value1,value2", values are separated by comma. 
// Example: "capability1=value1;capability2=value3,value4,value5" - private def getDeltaSharingCapabilities(): String = { + private def constructDeltaSharingCapabilities(setIncludeEndStreamAction: Boolean): String = { var capabilities = Seq[String](s"${RESPONSE_FORMAT}=$responseFormat") if (responseFormatSet.contains(RESPONSE_FORMAT_DELTA) && readerFeatures.nonEmpty) { capabilities = capabilities :+ s"$READER_FEATURES=$readerFeatures" } + if (setIncludeEndStreamAction) { + capabilities = capabilities :+ s"$DELTA_SHARING_INCLUDE_END_STREAM_ACTION=true" + } capabilities.mkString(DELTA_SHARING_CAPABILITIES_DELIMITER) } @@ -930,6 +1056,7 @@ object DeltaSharingRestClient extends Logging { val RESPONSE_TABLE_VERSION_HEADER_KEY = "Delta-Table-Version" val RESPONSE_FORMAT = "responseformat" val READER_FEATURES = "readerfeatures" + val DELTA_SHARING_INCLUDE_END_STREAM_ACTION = "includeendstreamaction" val RESPONSE_FORMAT_DELTA = "delta" val RESPONSE_FORMAT_PARQUET = "parquet" val DELTA_SHARING_CAPABILITIES_DELIMITER = ";" @@ -1018,6 +1145,7 @@ object DeltaSharingRestClient extends Logging { val timeoutInSeconds = ConfUtils.timeoutInSeconds(sqlConf) val queryTablePaginationEnabled = ConfUtils.queryTablePaginationEnabled(sqlConf) val maxFilesPerReq = ConfUtils.maxFilesPerQueryRequest(sqlConf) + val endStreamActionEnabled = ConfUtils.includeEndStreamAction(sqlConf) val clientClass = ConfUtils.clientClass(sqlConf) Class.forName(clientClass) @@ -1031,7 +1159,8 @@ object DeltaSharingRestClient extends Logging { classOf[String], classOf[String], classOf[Boolean], - classOf[Int] + classOf[Int], + classOf[Boolean] ).newInstance(profileProvider, java.lang.Integer.valueOf(timeoutInSeconds), java.lang.Integer.valueOf(numRetries), @@ -1041,7 +1170,8 @@ object DeltaSharingRestClient extends Logging { responseFormat, readerFeatures, java.lang.Boolean.valueOf(queryTablePaginationEnabled), - java.lang.Integer.valueOf(maxFilesPerReq) + java.lang.Integer.valueOf(maxFilesPerReq), + 
java.lang.Boolean.valueOf(endStreamActionEnabled) ).asInstanceOf[DeltaSharingClient] } } diff --git a/client/src/main/scala/io/delta/sharing/client/util/ConfUtils.scala b/client/src/main/scala/io/delta/sharing/client/util/ConfUtils.scala index 0b1b1eb4d..d17be41a2 100644 --- a/client/src/main/scala/io/delta/sharing/client/util/ConfUtils.scala +++ b/client/src/main/scala/io/delta/sharing/client/util/ConfUtils.scala @@ -30,6 +30,9 @@ object ConfUtils { val MAX_RETRY_DURATION_CONF = "spark.delta.sharing.network.maxRetryDuration" val MAX_RETRY_DURATION_DEFAULT_MILLIS = 10L * 60L* 1000L /* 10 mins */ + val INCLUDE_END_STREAM_ACTION_CONF = "spark.delta.sharing.query.includeEndStreamAction" + val INCLUDE_END_STREAM_ACTION_DEFAULT = "false" + val TIMEOUT_CONF = "spark.delta.sharing.network.timeout" val TIMEOUT_DEFAULT = "320s" @@ -111,6 +114,14 @@ object ConfUtils { maxDur } + def includeEndStreamAction(conf: Configuration): Boolean = { + conf.getBoolean(INCLUDE_END_STREAM_ACTION_CONF, INCLUDE_END_STREAM_ACTION_DEFAULT.toBoolean) + } + + def includeEndStreamAction(conf: SQLConf): Boolean = { + conf.getConfString(INCLUDE_END_STREAM_ACTION_CONF, INCLUDE_END_STREAM_ACTION_DEFAULT).toBoolean + } + def timeoutInSeconds(conf: Configuration): Int = { val timeoutStr = conf.get(TIMEOUT_CONF, TIMEOUT_DEFAULT) toTimeInSeconds(timeoutStr, TIMEOUT_CONF) diff --git a/client/src/main/scala/io/delta/sharing/client/util/RetryUtils.scala b/client/src/main/scala/io/delta/sharing/client/util/RetryUtils.scala index 4eb3a1c28..6ce6e17a4 100644 --- a/client/src/main/scala/io/delta/sharing/client/util/RetryUtils.scala +++ b/client/src/main/scala/io/delta/sharing/client/util/RetryUtils.scala @@ -22,6 +22,8 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging +import io.delta.sharing.spark.MissingEndStreamActionException + private[sharing] object RetryUtils extends Logging { // Expose it for testing @@ -70,6 +72,7 @@ private[sharing] object RetryUtils extends Logging { } 
else { false } + case _: MissingEndStreamActionException => true case _: java.net.SocketTimeoutException => true // do not retry on ConnectionClosedException because it can be caused by invalid json returned // from the delta sharing server. diff --git a/client/src/main/scala/io/delta/sharing/spark/DeltaSharingErrors.scala b/client/src/main/scala/io/delta/sharing/spark/DeltaSharingErrors.scala index 7f394f13d..6fd4db41d 100644 --- a/client/src/main/scala/io/delta/sharing/spark/DeltaSharingErrors.scala +++ b/client/src/main/scala/io/delta/sharing/spark/DeltaSharingErrors.scala @@ -18,6 +18,8 @@ package io.delta.sharing.spark import org.apache.spark.sql.types.StructType +class MissingEndStreamActionException(message: String) extends IllegalStateException(message) + object DeltaSharingErrors { def nonExistentDeltaSharingTable(tableId: String): Throwable = { new IllegalStateException(s"Delta sharing table ${tableId} doesn't exist. " + diff --git a/client/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala b/client/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala index cb4beef28..4f033bc40 100644 --- a/client/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala +++ b/client/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala @@ -18,6 +18,7 @@ package io.delta.sharing.spark // scalastyle:off import.ordering.noEmptyLine import java.lang.ref.WeakReference +import java.util.UUID import scala.collection.mutable.ArrayBuffer @@ -100,6 +101,8 @@ case class DeltaSharingSource( assert(deltaLog.client.getForStreaming, "forStreaming must be true for client in DeltaSharingSource.") + private val sourceId = Some(UUID.randomUUID().toString().split('-').head) + // The snapshot that's used to construct the dataframe, constructed when source is initialized. // Use latest snapshot instead of snapshot at startingVersion, to allow easy recovery from // failures on schema incompatibility. 
@@ -148,10 +151,12 @@ case class DeltaSharingSource( val intervalSeconds = ConfUtils.MINIMUM_TABLE_VERSION_INTERVAL_SECONDS.max( ConfUtils.streamingQueryTableVersionIntervalSeconds(spark.sessionState.conf) ) - logInfo(s"Configured queryTableVersionIntervalSeconds:${intervalSeconds}.") + logInfo(s"Configured queryTableVersionIntervalSeconds:${intervalSeconds}," + + getTableInfoForLogging) if (intervalSeconds < ConfUtils.MINIMUM_TABLE_VERSION_INTERVAL_SECONDS) { throw new IllegalArgumentException(s"QUERY_TABLE_VERSION_INTERVAL_MILLIS($intervalSeconds) " + - s"must not be less than ${ConfUtils.MINIMUM_TABLE_VERSION_INTERVAL_SECONDS} seconds.") + s"must not be less than ${ConfUtils.MINIMUM_TABLE_VERSION_INTERVAL_SECONDS} seconds," + + getTableInfoForLogging) } intervalSeconds * 1000 } @@ -166,6 +171,13 @@ case class DeltaSharingSource( TableRefreshResult(Map.empty[String, String], None, None) } + private lazy val getTableInfoForLogging: String = + s" for table(id:$tableId, name:${deltaLog.table.toString}, source:$sourceId)" + + private def getQueryIdForLogging: String = { + s", with queryId(${deltaLog.client.getQueryId})" + } + // Check the latest table version from the delta sharing server through the client.getTableVersion // RPC. Adding a minimum interval of QUERY_TABLE_VERSION_INTERVAL_MILLIS between two consecutive // rpcs to avoid traffic jam on the delta sharing server. 
@@ -174,13 +186,14 @@ case class DeltaSharingSource( if (lastGetVersionTimestamp == -1 || (currentTimeMillis - lastGetVersionTimestamp) >= QUERY_TABLE_VERSION_INTERVAL_MILLIS) { val serverVersion = deltaLog.client.getTableVersion(deltaLog.table) - logInfo(s"Got table version $serverVersion from Delta Sharing Server.") + logInfo(s"Got table version $serverVersion from Delta Sharing Server," + + getTableInfoForLogging) if (serverVersion < 0) { throw new IllegalStateException(s"Delta Sharing Server returning negative table version:" + - s"$serverVersion.") + s"$serverVersion," + getTableInfoForLogging) } else if (serverVersion < latestTableVersion) { logWarning(s"Delta Sharing Server returning smaller table version:$serverVersion < " + - s"$latestTableVersion.") + s"$latestTableVersion, " + getTableInfoForLogging) } latestTableVersion = serverVersion lastGetVersionTimestamp = currentTimeMillis @@ -246,7 +259,7 @@ case class DeltaSharingSource( s"$fromVersion, $fromIndex, $isStartingVersion) is not included in sortedFetchedFiles[" + s"(${headFile.version}, ${headFile.index}, ${headFile.isSnapshot}) to " + s"(${lastFile.version}, ${lastFile.index}, ${lastFile.isSnapshot})], " + - s"for table(id:$tableId, name:${deltaLog.table.toString})") + getTableInfoForLogging) sortedFetchedFiles = Seq.empty } else { return @@ -265,7 +278,7 @@ case class DeltaSharingSource( logInfo(s"Reducing ending version for delta sharing rpc from currentLatestVersion(" + s"$currentLatestVersion) to endingVersionForQuery($endingVersionForQuery), fromVersion:" + s"$fromVersion, maxVersionsPerRpc:$maxVersionsPerRpc, " + - s"for table(id:$tableId, name:${deltaLog.table.toString})." 
+ getTableInfoForLogging ) } @@ -332,7 +345,7 @@ case class DeltaSharingSource( ): Unit = { synchronized { logInfo(s"Refreshing sortedFetchedFiles(size: ${sortedFetchedFiles.size}) with newIdToUrl(" + - s"size: ${newIdToUrl.size}), for table(id:$tableId, name:${deltaLog.table.toString}).") + s"size: ${newIdToUrl.size})," + getTableInfoForLogging + getQueryIdForLogging) lastQueryTableTimestamp = queryTimestamp minUrlExpirationTimestamp = newMinUrlExpiration if (!CachedTableManager.INSTANCE.isValidUrlExpirationTime(minUrlExpirationTimestamp)) { @@ -351,7 +364,7 @@ case class DeltaSharingSource( val newUrl = newIdToUrl.getOrElse( indexedFile.add.id, throw new IllegalStateException(s"cannot find url for id ${indexedFile.add.id} " + - s"when refreshing table ${deltaLog.path}") + s"when refreshing table ${deltaLog.path}," + getTableInfoForLogging) ) indexedFile.add.copy(url = newUrl) }, @@ -362,7 +375,7 @@ case class DeltaSharingSource( val newUrl = newIdToUrl.getOrElse( indexedFile.remove.id, throw new IllegalStateException(s"cannot find url for id ${indexedFile.remove.id} " + - s"when refreshing table ${deltaLog.path}") + s"when refreshing table ${deltaLog.path}," + getTableInfoForLogging) ) indexedFile.remove.copy(url = newUrl) }, @@ -373,7 +386,7 @@ case class DeltaSharingSource( val newUrl = newIdToUrl.getOrElse( indexedFile.cdc.id, throw new IllegalStateException(s"cannot find url for id ${indexedFile.cdc.id} " + - s"when refreshing table ${deltaLog.path}") + s"when refreshing table ${deltaLog.path}," + getTableInfoForLogging) ) indexedFile.cdc.copy(url = newUrl) }, @@ -382,7 +395,7 @@ case class DeltaSharingSource( ) } logInfo(s"Refreshed ${numUrlsRefreshed} urls in sortedFetchedFiles(size: " + - s"${sortedFetchedFiles.size}).") + s"${sortedFetchedFiles.size})," + getTableInfoForLogging) } } @@ -408,8 +421,8 @@ case class DeltaSharingSource( isStartingVersion: Boolean, endingVersionForQuery: Long): Unit = { logInfo(s"Fetching files with fromVersion($fromVersion), 
fromIndex($fromIndex), " + - s"isStartingVersion($isStartingVersion), endingVersionForQuery($endingVersionForQuery), " + - s"for table(id:$tableId, name:${deltaLog.table.toString})." + s"isStartingVersion($isStartingVersion), endingVersionForQuery($endingVersionForQuery)," + + getTableInfoForLogging ) resetGlobalTimestamp() if (isStartingVersion) { @@ -456,7 +469,7 @@ case class DeltaSharingSource( val numFiles = tableFiles.files.size logInfo( s"Fetched ${numFiles} files for table version ${tableFiles.version} from" + - " delta sharing server." + s" delta sharing server," + getTableInfoForLogging + getQueryIdForLogging ) tableFiles.files.sortWith(fileActionCompareFunc).zipWithIndex.foreach { case (file, index) if (index > fromIndex) => @@ -510,11 +523,13 @@ case class DeltaSharingSource( TableRefreshResult(idToUrl, minUrlExpiration, None) } - val allAddFiles = validateCommitAndFilterAddFiles(tableFiles).groupBy(a => a.version) + val filteredAddFiles = validateCommitAndFilterAddFiles(tableFiles) + val allAddFiles = filteredAddFiles.groupBy(a => a.version) logInfo( - s"Fetched and filtered ${allAddFiles.size} files from startingVersion " + + s"Fetched ${tableFiles.addFiles.size} files, filtered ${filteredAddFiles.size} " + + s"in ${allAddFiles.size} versions from startingVersion " + s"${fromVersion} to endingVersion ${endingVersionForQuery} from " + - "delta sharing server." 
+ s"delta sharing server," + getTableInfoForLogging + getQueryIdForLogging ) for (v <- fromVersion to endingVersionForQuery) { val vAddFiles = allAddFiles.getOrElse(v, ArrayBuffer[AddFileForCDF]()) @@ -555,7 +570,7 @@ case class DeltaSharingSource( endingVersionForQuery: Long): Unit = { logInfo(s"Fetching CDF files with fromVersion($fromVersion), fromIndex($fromIndex), " + s"endingVersionForQuery($endingVersionForQuery), " + - s"for table(id:$tableId, name:${deltaLog.table.toString}).") + getTableInfoForLogging) resetGlobalTimestamp() val tableFiles = deltaLog.client.getCDFFiles( deltaLog.table, @@ -824,7 +839,7 @@ case class DeltaSharingSource( case cdf: AddCDCFile => cdfFiles.append(cdf) case add: AddFileForCDF => addFiles.append(add) case remove: RemoveFile => removeFiles.append(remove) - case f => throw new IllegalStateException(s"Unexpected File:${f}") + case f => throw new IllegalStateException(s"Unexpected File:${f}," + getTableInfoForLogging) } } @@ -1000,8 +1015,8 @@ case class DeltaSharingSource( } override def getBatch(startOffsetOption: Option[Offset], end: Offset): DataFrame = { - logInfo(s"getBatch with startOffsetOption($startOffsetOption) and end($end), " + - s"for table(id:$tableId, name:${deltaLog.table.toString})") + logInfo(s"getBatch with startOffsetOption($startOffsetOption) and end($end)," + + getTableInfoForLogging) val endOffset = DeltaSharingSourceOffset(tableId, end) val (startVersion, startIndex, isStartingVersion, startSourceVersion) = if ( @@ -1028,7 +1043,7 @@ case class DeltaSharingSource( val startOffset = DeltaSharingSourceOffset(tableId, startOffsetOption.get) if (startOffset == endOffset) { logInfo(s"startOffset($startOffset) is the same as endOffset($endOffset) in getBatch, " + - s"for table(id:$tableId, name:${deltaLog.table.toString})") + getTableInfoForLogging) previousOffset = endOffset // This happens only if we recover from a failure and `MicroBatchExecution` tries to
call // us with the previous offsets. The returned DataFrame will be dropped immediately, so we @@ -1128,7 +1143,7 @@ case class DeltaSharingSource( } else if (options.startingTimestamp.isDefined) { val version = deltaLog.client.getTableVersion(deltaLog.table, options.startingTimestamp) logInfo(s"Got table version $version for timestamp ${options.startingTimestamp} " + - s"from Delta Sharing Server.") + s"from Delta Sharing Server," + getTableInfoForLogging) Some(version) } else { None diff --git a/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientDeltaSuite.scala b/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientDeltaSuite.scala index 7adc6088a..ce6cd3476 100644 --- a/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientDeltaSuite.scala +++ b/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientDeltaSuite.scala @@ -39,7 +39,7 @@ class DeltaSharingRestClientDeltaSuite extends DeltaSharingIntegrationTest { val httpRequest = new HttpGet("random_url") val client = new DeltaSharingRestClient(testProfileProvider) - var h = client.prepareHeaders(httpRequest).getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER) + var h = client.prepareHeaders(httpRequest, setIncludeEndStreamAction = false).getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER) // scalastyle:off caselocale assert(h.getValue.toLowerCase().contains(s"responseformat=${DeltaSharingRestClient.RESPONSE_FORMAT_PARQUET}")) @@ -47,7 +47,7 @@ class DeltaSharingRestClientDeltaSuite extends DeltaSharingIntegrationTest { testProfileProvider, responseFormat = DeltaSharingRestClient.RESPONSE_FORMAT_DELTA ) - h = deltaClient.prepareHeaders(httpRequest).getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER) + h = deltaClient.prepareHeaders(httpRequest, setIncludeEndStreamAction = false).getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER) // scalastyle:off caselocale 
assert(h.getValue.toLowerCase().contains(s"responseformat=${DeltaSharingRestClient.RESPONSE_FORMAT_DELTA}")) } diff --git a/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientSuite.scala b/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientSuite.scala index 2f3e6e3fc..a1676bdeb 100644 --- a/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientSuite.scala +++ b/client/src/test/scala/io/delta/sharing/client/DeltaSharingRestClientSuite.scala @@ -84,36 +84,66 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { assert(h.contains(" java/")) } - def checkDeltaSharingCapabilities(request: HttpRequestBase, expected: String): Unit = { + def getEndStreamActionHeader(endStreamActionEnabled: Boolean): String = { + if (endStreamActionEnabled) { + s";$DELTA_SHARING_INCLUDE_END_STREAM_ACTION=true" + } else { + "" + } + } + + def checkDeltaSharingCapabilities( + request: HttpRequestBase, + responseFormat: String, + readerFeatures: String, + endStreamActionEnabled: Boolean): Unit = { + val expected = s"${RESPONSE_FORMAT}=$responseFormat$readerFeatures" + + getEndStreamActionHeader(endStreamActionEnabled) val h = request.getFirstHeader(DELTA_SHARING_CAPABILITIES_HEADER) assert(h.getValue == expected) } - var httpRequestBase = new DeltaSharingRestClient( - testProfileProvider, forStreaming = false, readerFeatures = "willBeIgnored").prepareHeaders(httpRequest) - checkUserAgent(httpRequestBase, false) - checkDeltaSharingCapabilities(httpRequestBase, "responseformat=parquet") + Seq( + (true, true), + (true, false), + (false, true), + (false, false) + ).foreach { case (forStreaming, endStreamActionEnabled) => + val httpRequest = new HttpGet("random_url") + var request = new DeltaSharingRestClient( + testProfileProvider, + forStreaming = forStreaming, + endStreamActionEnabled = endStreamActionEnabled, + readerFeatures = "willBeIgnored") + .prepareHeaders(httpRequest, setIncludeEndStreamAction = endStreamActionEnabled) 
+ checkUserAgent(request, forStreaming) + checkDeltaSharingCapabilities(request, "parquet", "", endStreamActionEnabled) - val readerFeatures = "deletionVectors,columnMapping,timestampNTZ" - httpRequestBase = new DeltaSharingRestClient( - testProfileProvider, - forStreaming = true, - responseFormat = RESPONSE_FORMAT_DELTA, - readerFeatures = readerFeatures).prepareHeaders(httpRequest) - checkUserAgent(httpRequestBase, true) - checkDeltaSharingCapabilities( - httpRequestBase, s"responseformat=delta;readerfeatures=$readerFeatures" - ) + val readerFeatures = "deletionVectors,columnMapping,timestampNTZ" + request = new DeltaSharingRestClient( + testProfileProvider, + forStreaming = forStreaming, + endStreamActionEnabled = endStreamActionEnabled, + responseFormat = RESPONSE_FORMAT_DELTA, + readerFeatures = readerFeatures) + .prepareHeaders(httpRequest, setIncludeEndStreamAction = endStreamActionEnabled) + checkUserAgent(request, forStreaming) + checkDeltaSharingCapabilities( + request, "delta", s";$READER_FEATURES=$readerFeatures", endStreamActionEnabled + ) - httpRequestBase = new DeltaSharingRestClient( - testProfileProvider, - forStreaming = true, - responseFormat = s"$RESPONSE_FORMAT_DELTA,$RESPONSE_FORMAT_PARQUET", - readerFeatures = readerFeatures).prepareHeaders(httpRequest) - checkUserAgent(httpRequestBase, true) - checkDeltaSharingCapabilities( - httpRequestBase, s"responseformat=delta,parquet;readerfeatures=$readerFeatures" - ) + request = new DeltaSharingRestClient( + testProfileProvider, + forStreaming = forStreaming, + endStreamActionEnabled = endStreamActionEnabled, + responseFormat = s"$RESPONSE_FORMAT_DELTA,$RESPONSE_FORMAT_PARQUET", + readerFeatures = readerFeatures) + .prepareHeaders(httpRequest, setIncludeEndStreamAction = endStreamActionEnabled) + checkUserAgent(request, forStreaming) + checkDeltaSharingCapabilities( + request, s"delta,parquet", s";$READER_FEATURES=$readerFeatures", endStreamActionEnabled + ) + } } integrationTest("listAllTables") { 
@@ -1007,7 +1037,8 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { false ) }.getMessage - assert(errorMessage.contains("""400 Bad Request {"errorCode":"RESOURCE_DOES_NOT_EXIST"""")) + assert(errorMessage.contains("""400 Bad Request for query""")) + assert(errorMessage.contains("""{"errorCode":"RESOURCE_DOES_NOT_EXIST"""")) assert(errorMessage.contains("table files missing")) } finally { client.close() diff --git a/client/src/test/scala/io/delta/sharing/client/util/ConfUtilsSuite.scala b/client/src/test/scala/io/delta/sharing/client/util/ConfUtilsSuite.scala index e4f93bb99..8d15aba84 100644 --- a/client/src/test/scala/io/delta/sharing/client/util/ConfUtilsSuite.scala +++ b/client/src/test/scala/io/delta/sharing/client/util/ConfUtilsSuite.scala @@ -83,6 +83,26 @@ class ConfUtilsSuite extends SparkFunSuite { }.getMessage.contains(TIMEOUT_CONF) } + test("includeEndStreamAction") { + assert(includeEndStreamAction(newConf()) == false) + assert( + includeEndStreamAction(newConf(Map(INCLUDE_END_STREAM_ACTION_CONF -> "false"))) == false + ) + assert(includeEndStreamAction(newConf(Map(INCLUDE_END_STREAM_ACTION_CONF -> "true"))) == true) + assert(includeEndStreamAction(newConf(Map(INCLUDE_END_STREAM_ACTION_CONF -> "rdm"))) == false) + + assert(includeEndStreamAction(newSqlConf()) == false) + assert( + includeEndStreamAction(newSqlConf(Map(INCLUDE_END_STREAM_ACTION_CONF -> "true"))) == true + ) + assert( + includeEndStreamAction(newSqlConf(Map(INCLUDE_END_STREAM_ACTION_CONF -> "false"))) == false + ) + intercept[IllegalArgumentException] { + includeEndStreamAction(newSqlConf(Map(INCLUDE_END_STREAM_ACTION_CONF -> "random"))) + }.getMessage.contains(INCLUDE_END_STREAM_ACTION_CONF) + } + test("maxConnections") { assert(maxConnections(newConf()) == MAX_CONNECTION_DEFAULT) assert(maxConnections(newConf(Map(MAX_CONNECTION_CONF -> "100"))) == 100) diff --git a/client/src/test/scala/io/delta/sharing/client/util/RetryUtilsSuite.scala 
b/client/src/test/scala/io/delta/sharing/client/util/RetryUtilsSuite.scala index 297d0fb27..934147233 100644 --- a/client/src/test/scala/io/delta/sharing/client/util/RetryUtilsSuite.scala +++ b/client/src/test/scala/io/delta/sharing/client/util/RetryUtilsSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkFunSuite import io.delta.sharing.client.util.{RetryUtils, UnexpectedHttpStatus} import io.delta.sharing.client.util.RetryUtils._ +import io.delta.sharing.spark.MissingEndStreamActionException class RetryUtilsSuite extends SparkFunSuite { test("shouldRetry") { @@ -35,6 +36,7 @@ class RetryUtilsSuite extends SparkFunSuite { assert(shouldRetry(new IOException)) assert(shouldRetry(new java.net.SocketTimeoutException)) assert(!shouldRetry(new RuntimeException)) + assert(shouldRetry(new MissingEndStreamActionException("missing"))) } test("runWithExponentialBackoff") { diff --git a/server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala b/server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala index 769273120..25216a921 100644 --- a/server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala +++ b/server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala @@ -273,7 +273,7 @@ class DeltaSharingService(serverConfig: ServerConfig) { if (headerString == null) { return Map.empty[String, String] } - headerString.toLowerCase().split(";") + headerString.toLowerCase().split(DELTA_SHARING_CAPABILITIES_DELIMITER) .map(_.split("=")) .filter(_.size == 2) .map { splits => @@ -321,7 +321,8 @@ class DeltaSharingService(serverConfig: ServerConfig) { pageToken = None, includeRefreshToken = false, refreshToken = None, - responseFormatSet = responseFormatSet) + responseFormatSet = responseFormatSet, + includeEndStreamAction = false) streamingOutput(Some(queryResult.version), queryResult.responseFormat, queryResult.actions) } @@ -396,6 +397,7 @@ class DeltaSharingService(serverConfig: ServerConfig) { } } val responseFormatSet = 
getResponseFormatSet(capabilitiesMap) + val includeEndStreamAction = getRequestEndStreamAction(capabilitiesMap) val queryResult = deltaSharedTableLoader.loadTable(tableConfig).query( includeFiles = true, request.predicateHints, @@ -409,7 +411,8 @@ class DeltaSharingService(serverConfig: ServerConfig) { request.pageToken, request.includeRefreshToken.getOrElse(false), request.refreshToken, - responseFormatSet = responseFormatSet) + responseFormatSet = responseFormatSet, + includeEndStreamAction = includeEndStreamAction) if (queryResult.version < tableConfig.startVersion) { throw new DeltaSharingIllegalArgumentException( s"You can only query table data since version ${tableConfig.startVersion}." @@ -417,7 +420,12 @@ class DeltaSharingService(serverConfig: ServerConfig) { } logger.info(s"Took ${System.currentTimeMillis - start} ms to load the table " + s"and sign ${queryResult.actions.length - 2} urls for table $share/$schema/$table") - streamingOutput(Some(queryResult.version), queryResult.responseFormat, queryResult.actions) + streamingOutput( + Some(queryResult.version), + queryResult.responseFormat, + queryResult.actions, + includeEndStreamAction = includeEndStreamAction + ) } // scalastyle:off argcount @@ -451,6 +459,7 @@ class DeltaSharingService(serverConfig: ServerConfig) { } val responseFormatSet = getResponseFormatSet(capabilitiesMap) + val includeEndStreamAction = getRequestEndStreamAction(capabilitiesMap) val queryResult = deltaSharedTableLoader.loadTable(tableConfig).queryCDF( getCdfOptionsMap( Option(startingVersion), @@ -461,21 +470,34 @@ class DeltaSharingService(serverConfig: ServerConfig) { includeHistoricalMetadata = Try(includeHistoricalMetadata.toBoolean).getOrElse(false), Option(maxFiles).map(_.toInt), Option(pageToken), - responseFormatSet = responseFormatSet + responseFormatSet = responseFormatSet, + includeEndStreamAction = includeEndStreamAction ) logger.info(s"Took ${System.currentTimeMillis - start} ms to load the table cdf " + s"and sign 
${queryResult.actions.length - 2} urls for table $share/$schema/$table") - streamingOutput(Some(queryResult.version), queryResult.responseFormat, queryResult.actions) + streamingOutput( + Some(queryResult.version), + queryResult.responseFormat, + queryResult.actions, + includeEndStreamAction = includeEndStreamAction + ) } private def streamingOutput( version: Option[Long], responseFormat: String, - actions: Seq[Object]): HttpResponse = { + actions: Seq[Object], + includeEndStreamAction: Boolean = false): HttpResponse = { + var capabilities = Seq[String](s"${DELTA_SHARING_RESPONSE_FORMAT}=$responseFormat") + if (includeEndStreamAction) { + capabilities = capabilities :+ s"$DELTA_SHARING_INCLUDE_END_STREAM_ACTION=true" + } + val dsCapHeader = capabilities.mkString(DELTA_SHARING_CAPABILITIES_DELIMITER) + val headers = if (version.isDefined) { createHeadersBuilderForTableVersion(version.get) .set(HttpHeaderNames.CONTENT_TYPE, DELTA_TABLE_METADATA_CONTENT_TYPE) - .set(DELTA_SHARING_CAPABILITIES_HEADER, s"$DELTA_SHARING_RESPONSE_FORMAT=$responseFormat") + .set(DELTA_SHARING_CAPABILITIES_HEADER, dsCapHeader) .build() } else { ResponseHeaders.builder(200) @@ -503,6 +525,8 @@ object DeltaSharingService { val DELTA_TABLE_METADATA_CONTENT_TYPE = "application/x-ndjson; charset=utf-8" val DELTA_SHARING_CAPABILITIES_HEADER = "delta-sharing-capabilities" val DELTA_SHARING_RESPONSE_FORMAT = "responseformat" + val DELTA_SHARING_INCLUDE_END_STREAM_ACTION = "includeendstreamaction" + val DELTA_SHARING_CAPABILITIES_DELIMITER = ";" private val parser = { val parser = ArgumentParsers @@ -628,6 +652,11 @@ object DeltaSharingService { ).split(",").toSet } + private[server] def getRequestEndStreamAction( + headerCapabilities: Map[String, String]): Boolean = { + headerCapabilities.get(DELTA_SHARING_INCLUDE_END_STREAM_ACTION).exists("true".equalsIgnoreCase) + } + def main(args: Array[String]): Unit = { val ns = parser.parseArgsOrFail(args) val serverConfigPath = ns.getString("config") diff --git
a/server/src/main/scala/io/delta/standalone/internal/DeltaSharedTableLoader.scala b/server/src/main/scala/io/delta/standalone/internal/DeltaSharedTableLoader.scala index 4a8d94f63..a504a6374 100644 --- a/server/src/main/scala/io/delta/standalone/internal/DeltaSharedTableLoader.scala +++ b/server/src/main/scala/io/delta/standalone/internal/DeltaSharedTableLoader.scala @@ -380,7 +380,8 @@ class DeltaSharedTable( pageToken: Option[String], includeRefreshToken: Boolean, refreshToken: Option[String], - responseFormatSet: Set[String]): QueryResult = withClassLoader { + responseFormatSet: Set[String], + includeEndStreamAction: Boolean): QueryResult = withClassLoader { // scalastyle:on argcount // TODO Support `limitHint` if (Seq(version, timestamp, startingVersion).filter(_.isDefined).size >= 2) { @@ -463,7 +464,8 @@ class DeltaSharedTable( maxFiles, pageTokenOpt, queryParamChecksum, - responseFormat + responseFormat, + includeEndStreamAction ) } else if (includeFiles) { val ts = if (isVersionQuery) { @@ -555,7 +557,7 @@ class DeltaSharedTable( // For backwards compatibility, return an `endStreamAction` object only when // `includeRefreshToken` is true or `maxFiles` is specified filteredFiles ++ { - if (includeRefreshToken || maxFiles.isDefined) { + if (includeRefreshToken || maxFiles.isDefined || includeEndStreamAction) { Seq(getEndStreamAction(nextPageTokenStr, minUrlExpirationTimestamp, refreshTokenStr)) } else { Nil @@ -575,7 +577,8 @@ class DeltaSharedTable( maxFilesOpt: Option[Int], pageTokenOpt: Option[QueryTablePageToken], queryParamChecksum: String, - responseFormat: String + responseFormat: String, + includeEndStreamAction: Boolean ): Seq[Object] = { // For subsequent page calls, instead of using the current latestVersion, use latestVersion in // the pageToken (which is equal to the latestVersion when the first page call is received), @@ -695,7 +698,7 @@ class DeltaSharedTable( } // Return an `endStreamAction` object only when `maxFiles` is specified for // 
backwards compatibility. - if (maxFilesOpt.isDefined) { + if (maxFilesOpt.isDefined || includeEndStreamAction) { actions.append(getEndStreamAction(null, minUrlExpirationTimestamp)) } actions.toSeq @@ -706,7 +709,8 @@ class DeltaSharedTable( includeHistoricalMetadata: Boolean = false, maxFiles: Option[Int], pageToken: Option[String], - responseFormatSet: Set[String] = Set(DeltaSharedTable.RESPONSE_FORMAT_PARQUET) + responseFormatSet: Set[String] = Set(DeltaSharedTable.RESPONSE_FORMAT_PARQUET), + includeEndStreamAction: Boolean ): QueryResult = withClassLoader { // Step 1: validate pageToken if it's specified lazy val queryParamChecksum = computeChecksum( @@ -865,7 +869,7 @@ class DeltaSharedTable( } // Return an `endStreamAction` object only when `maxFiles` is specified for // backwards compatibility. - if (maxFiles.isDefined) { + if (maxFiles.isDefined || includeEndStreamAction) { actions.append(getEndStreamAction(null, minUrlExpirationTimestamp)) } QueryResult(start, actions.toSeq, responseFormat) diff --git a/server/src/test/scala/io/delta/sharing/server/DeltaSharingServiceSuite.scala b/server/src/test/scala/io/delta/sharing/server/DeltaSharingServiceSuite.scala index 22b504099..6533abea9 100644 --- a/server/src/test/scala/io/delta/sharing/server/DeltaSharingServiceSuite.scala +++ b/server/src/test/scala/io/delta/sharing/server/DeltaSharingServiceSuite.scala @@ -32,6 +32,7 @@ import org.apache.commons.io.IOUtils import org.scalatest.{BeforeAndAfterAll, FunSuite} import scalapb.json4s.JsonFormat +import io.delta.sharing.server.DeltaSharingService.DELTA_SHARING_INCLUDE_END_STREAM_ACTION import io.delta.sharing.server.config.ServerConfig import io.delta.sharing.server.model._ import io.delta.sharing.server.protocol._ @@ -121,14 +122,16 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { method: Option[String] = None, data: Option[String] = None, expectedTableVersion: Option[Long] = None, - responseFormat: String = RESPONSE_FORMAT_PARQUET): 
String = { + responseFormat: String = RESPONSE_FORMAT_PARQUET, + includeEndStreamAction: Boolean = false): String = { readHttpContent( url, method, data, responseFormat, expectedTableVersion, - "application/x-ndjson; charset=utf-8" + "application/x-ndjson; charset=utf-8", + includeEndStreamAction = includeEndStreamAction ) } @@ -139,12 +142,22 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { data: Option[String] = None, responseFormat: String, expectedTableVersion: Option[Long] = None, - expectedContentType: String): String = { + expectedContentType: String, + includeEndStreamAction: Boolean = false): String = { val connection = new URL(url).openConnection().asInstanceOf[HttpsURLConnection] connection.setRequestProperty("Authorization", s"Bearer ${TestResource.testAuthorizationToken}") + val deltaSharingCapabilities = Seq.newBuilder[String] if (responseFormat == RESPONSE_FORMAT_DELTA) { - connection.setRequestProperty("delta-sharing-capabilities", "responseFormat=delta") + deltaSharingCapabilities += s"responseFormat=$responseFormat" } + if (includeEndStreamAction) { + deltaSharingCapabilities += s"$DELTA_SHARING_INCLUDE_END_STREAM_ACTION=true" + } + val result = deltaSharingCapabilities.result() + if (result.length > 0) { + connection.setRequestProperty("delta-sharing-capabilities", result.mkString(";")) + } + method.foreach(connection.setRequestMethod) data.foreach { d => connection.setDoOutput(true) @@ -168,8 +181,11 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { s"Incorrect content type: $contentType. 
Error: $content") if (expectedTableVersion.isDefined) { val responseCapabilities = connection.getHeaderField("delta-sharing-capabilities") - assert(responseCapabilities == s"responseformat=$responseFormat", - s"Incorrect response format: $responseCapabilities") + var expectedHeader = s"responseformat=$responseFormat" + if (includeEndStreamAction) { + expectedHeader += s";$DELTA_SHARING_INCLUDE_END_STREAM_ACTION=true" + } + assert(responseCapabilities == expectedHeader, s"Incorrect header: $responseCapabilities") } val deltaTableVersion = connection.getHeaderField("Delta-Table-Version") expectedTableVersion.foreach { v => @@ -541,93 +557,114 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { } integrationTest("table1 - non partitioned - /shares/{share}/schemas/{schema}/tables/{table}/query") { - Seq( - RESPONSE_FORMAT_PARQUET, - RESPONSE_FORMAT_DELTA, - s"$RESPONSE_FORMAT_DELTA,$RESPONSE_FORMAT_PARQUET", - s"$RESPONSE_FORMAT_PARQUET,$RESPONSE_FORMAT_DELTA" - ).foreach { responseFormat => - val respondedFormat = if (responseFormat == RESPONSE_FORMAT_DELTA) { - RESPONSE_FORMAT_DELTA - } else { - RESPONSE_FORMAT_PARQUET - } - val p = - s""" - |{ - | "predicateHints": [ - | "date = CAST('2021-04-28' AS DATE)" - | ] - |} - |""".stripMargin - val response = readNDJson(requestPath("/shares/share1/schemas/default/tables/table1/query"), Some("POST"), Some(p), Some(2), respondedFormat) - val lines = response.split("\n") - val protocol = lines(0) - val metadata = lines(1) - if (responseFormat == RESPONSE_FORMAT_DELTA) { - val responseProtocol = JsonUtils.fromJson[DeltaResponseSingleAction](protocol).protocol - assert(responseProtocol.deltaProtocol.minReaderVersion == 1) - - // unable to construct the delta action because the cases classes like AddFile/Metadata - // are private to io.delta.standalone.internal. - // So we only verify a couple important fields. 
- val responseMetadata = JsonUtils.fromJson[DeltaResponseSingleAction](metadata).metaData - assert(responseMetadata.deltaMetadata.id == "ed96aa41-1d81-4b7f-8fb5-846878b4b0cf") - - val actualFiles = lines.drop(2).map(f => JsonUtils.fromJson[DeltaResponseSingleAction](f).file) - assert(actualFiles(0).id == "061cb3683a467066995f8cdaabd8667d") - assert(actualFiles(0).deltaSingleAction.add != null) - assert(actualFiles(1).id == "e268cbf70dbaa6143e7e9fa3e2d3b00e") - assert(actualFiles(1).deltaSingleAction.add != null) - assert(actualFiles.count(_.expirationTimestamp > System.currentTimeMillis()) == 2) - verifyPreSignedUrl(actualFiles(0).deltaSingleAction.add.path, 781) - verifyPreSignedUrl(actualFiles(1).deltaSingleAction.add.path, 781) - } else { - val expectedProtocol = Protocol(minReaderVersion = 1).wrap - assert(expectedProtocol == JsonUtils.fromJson[SingleAction](protocol)) - val expectedMetadata = Metadata( - id = "ed96aa41-1d81-4b7f-8fb5-846878b4b0cf", - format = Format(), - schemaString = """{"type":"struct","fields":[{"name":"eventTime","type":"timestamp","nullable":true,"metadata":{}},{"name":"date","type":"date","nullable":true,"metadata":{}}]}""", - partitionColumns = Nil).wrap - assert(expectedMetadata == JsonUtils.fromJson[SingleAction](metadata)) - val files = lines.drop(2) - val actualFiles = files.map(f => JsonUtils.fromJson[SingleAction](f).file) - assert(actualFiles.size == 2) - val expectedFiles = Seq( - AddFile( - url = actualFiles(0).url, - expirationTimestamp = actualFiles(0).expirationTimestamp, - id = "061cb3683a467066995f8cdaabd8667d", - partitionValues = Map.empty, - size = 781, - stats = """{"numRecords":1,"minValues":{"eventTime":"2021-04-28T06:32:22.421Z","date":"2021-04-28"},"maxValues":{"eventTime":"2021-04-28T06:32:22.421Z","date":"2021-04-28"},"nullCount":{"eventTime":0,"date":0}}""" - ), - AddFile( - url = actualFiles(1).url, - expirationTimestamp = actualFiles(1).expirationTimestamp, - id = "e268cbf70dbaa6143e7e9fa3e2d3b00e", - 
partitionValues = Map.empty, - size = 781, - stats = """{"numRecords":1,"minValues":{"eventTime":"2021-04-28T06:32:02.070Z","date":"2021-04-28"},"maxValues":{"eventTime":"2021-04-28T06:32:02.070Z","date":"2021-04-28"},"nullCount":{"eventTime":0,"date":0}}""" - ) + Seq(true, false).foreach { includeEndStreamAction => + Seq( + RESPONSE_FORMAT_PARQUET, + RESPONSE_FORMAT_DELTA, + s"$RESPONSE_FORMAT_DELTA,$RESPONSE_FORMAT_PARQUET", + s"$RESPONSE_FORMAT_PARQUET,$RESPONSE_FORMAT_DELTA" + ).foreach { responseFormat => + val respondedFormat = if (responseFormat == RESPONSE_FORMAT_DELTA) { + RESPONSE_FORMAT_DELTA + } else { + RESPONSE_FORMAT_PARQUET + } + val p = + s""" + |{ + | "predicateHints": [ + | "date = CAST('2021-04-28' AS DATE)" + | ] + |} + |""".stripMargin + val response = readNDJson( + requestPath("/shares/share1/schemas/default/tables/table1/query"), + Some("POST"), + Some(p), + Some(2), + respondedFormat, + includeEndStreamAction = includeEndStreamAction ) - assert(actualFiles.count(_.expirationTimestamp != null) == 2) - assert(expectedFiles == actualFiles.toList) - verifyPreSignedUrl(actualFiles(0).url, 781) - verifyPreSignedUrl(actualFiles(1).url, 781) + var lines = response.split("\n") + val protocol = lines(0) + val metadata = lines(1) + if (includeEndStreamAction) { + val endAction = JsonUtils.fromJson[SingleAction](lines.last).endStreamAction + assert(endAction != null) + assert(endAction.minUrlExpirationTimestamp != null) + lines = lines.dropRight(1) + } + if (responseFormat == RESPONSE_FORMAT_DELTA) { + val responseProtocol = JsonUtils.fromJson[DeltaResponseSingleAction](protocol).protocol + assert(responseProtocol.deltaProtocol.minReaderVersion == 1) + + // unable to construct the delta action because the cases classes like AddFile/Metadata + // are private to io.delta.standalone.internal. + // So we only verify a couple important fields. 
+ val responseMetadata = JsonUtils.fromJson[DeltaResponseSingleAction](metadata).metaData + assert(responseMetadata.deltaMetadata.id == "ed96aa41-1d81-4b7f-8fb5-846878b4b0cf") + + val actualFiles = lines.drop(2).map(f => JsonUtils.fromJson[DeltaResponseSingleAction](f).file) + assert(actualFiles(0).id == "061cb3683a467066995f8cdaabd8667d") + assert(actualFiles(0).deltaSingleAction.add != null) + assert(actualFiles(1).id == "e268cbf70dbaa6143e7e9fa3e2d3b00e") + assert(actualFiles(1).deltaSingleAction.add != null) + assert(actualFiles.count(_.expirationTimestamp > System.currentTimeMillis()) == 2) + verifyPreSignedUrl(actualFiles(0).deltaSingleAction.add.path, 781) + verifyPreSignedUrl(actualFiles(1).deltaSingleAction.add.path, 781) + } else { + val expectedProtocol = Protocol(minReaderVersion = 1).wrap + assert(expectedProtocol == JsonUtils.fromJson[SingleAction](protocol)) + val expectedMetadata = Metadata( + id = "ed96aa41-1d81-4b7f-8fb5-846878b4b0cf", + format = Format(), + schemaString = """{"type":"struct","fields":[{"name":"eventTime","type":"timestamp","nullable":true,"metadata":{}},{"name":"date","type":"date","nullable":true,"metadata":{}}]}""", + partitionColumns = Nil).wrap + assert(expectedMetadata == JsonUtils.fromJson[SingleAction](metadata)) + val files = lines.drop(2) + val actualFiles = files.map(f => JsonUtils.fromJson[SingleAction](f).file) + assert(actualFiles.size == 2) + val expectedFiles = Seq( + AddFile( + url = actualFiles(0).url, + expirationTimestamp = actualFiles(0).expirationTimestamp, + id = "061cb3683a467066995f8cdaabd8667d", + partitionValues = Map.empty, + size = 781, + stats = """{"numRecords":1,"minValues":{"eventTime":"2021-04-28T06:32:22.421Z","date":"2021-04-28"},"maxValues":{"eventTime":"2021-04-28T06:32:22.421Z","date":"2021-04-28"},"nullCount":{"eventTime":0,"date":0}}""" + ), + AddFile( + url = actualFiles(1).url, + expirationTimestamp = actualFiles(1).expirationTimestamp, + id = "e268cbf70dbaa6143e7e9fa3e2d3b00e", + 
partitionValues = Map.empty, + size = 781, + stats = """{"numRecords":1,"minValues":{"eventTime":"2021-04-28T06:32:02.070Z","date":"2021-04-28"},"maxValues":{"eventTime":"2021-04-28T06:32:02.070Z","date":"2021-04-28"},"nullCount":{"eventTime":0,"date":0}}""" + ) + ) + assert(actualFiles.count(_.expirationTimestamp != null) == 2) + assert(expectedFiles == actualFiles.toList) + verifyPreSignedUrl(actualFiles(0).url, 781) + verifyPreSignedUrl(actualFiles(1).url, 781) + } } } } integrationTest("table1 - non partitioned - paginated query") { - Seq(RESPONSE_FORMAT_PARQUET, RESPONSE_FORMAT_DELTA).foreach { responseFormat => + Seq( + (RESPONSE_FORMAT_PARQUET, true), + (RESPONSE_FORMAT_PARQUET, false), + (RESPONSE_FORMAT_DELTA, true), + (RESPONSE_FORMAT_DELTA, false) + ).foreach { case (responseFormat, includeEndStreamAction) => var response = readNDJson( requestPath("/shares/share1/schemas/default/tables/table1/query"), Some("POST"), Some("""{"maxFiles": 1}"""), Some(2), - responseFormat + responseFormat, + includeEndStreamAction = includeEndStreamAction ) var lines = response.split("\n") assert(lines.length == 4) @@ -716,31 +753,34 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { } integrationTest("refresh query returns the same set of files as initial query") { - val initialResponse = readNDJson( - requestPath("/shares/share1/schemas/default/tables/table1/query"), - Some("POST"), - Some("""{"includeRefreshToken": true}"""), - Some(2) - ).split("\n") - assert(initialResponse.length == 5) - val endAction = JsonUtils.fromJson[SingleAction](initialResponse.last).endStreamAction - assert(endAction.refreshToken != null) + Seq(true, false).foreach { includeEndStreamAction => + val initialResponse = readNDJson( + requestPath("/shares/share1/schemas/default/tables/table1/query"), + Some("POST"), + Some("""{"includeRefreshToken": true}"""), + Some(2), + includeEndStreamAction = includeEndStreamAction + ).split("\n") + assert(initialResponse.length == 5) 
+ val endAction = JsonUtils.fromJson[SingleAction](initialResponse.last).endStreamAction + assert(endAction.refreshToken != null) - val refreshResponse = readNDJson( - requestPath("/shares/share1/schemas/default/tables/table1/query"), - Some("POST"), - Some(s"""{"includeRefreshToken": true, "refreshToken": "${endAction.refreshToken}"}"""), - Some(2) - ).split("\n") - assert(refreshResponse.length == 5) - // protocol - assert(initialResponse(0) == refreshResponse(0)) - // metadata - assert(initialResponse(1) == refreshResponse(1)) - // files - val initialFiles = initialResponse.slice(2, 4).map(f => JsonUtils.fromJson[SingleAction](f).file) - val refreshedFiles = refreshResponse.slice(2, 4).map(f => JsonUtils.fromJson[SingleAction](f).file) - assert(initialFiles.map(_.id) sameElements refreshedFiles.map(_.id)) + val refreshResponse = readNDJson( + requestPath("/shares/share1/schemas/default/tables/table1/query"), + Some("POST"), + Some(s"""{"includeRefreshToken": true, "refreshToken": "${endAction.refreshToken}"}"""), + Some(2) + ).split("\n") + assert(refreshResponse.length == 5) + // protocol + assert(initialResponse(0) == refreshResponse(0)) + // metadata + assert(initialResponse(1) == refreshResponse(1)) + // files + val initialFiles = initialResponse.slice(2, 4).map(f => JsonUtils.fromJson[SingleAction](f).file) + val refreshedFiles = refreshResponse.slice(2, 4).map(f => JsonUtils.fromJson[SingleAction](f).file) + assert(initialFiles.map(_.id) sameElements refreshedFiles.map(_.id)) + } } integrationTest("refresh query - exception") { @@ -2253,20 +2293,37 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { integrationTest("cdf_table_cdf_enabled_changes - query table changes") { Seq( - RESPONSE_FORMAT_PARQUET, - RESPONSE_FORMAT_DELTA, - s"$RESPONSE_FORMAT_DELTA,$RESPONSE_FORMAT_PARQUET", - s"$RESPONSE_FORMAT_PARQUET,$RESPONSE_FORMAT_DELTA" - ).foreach { responseFormat => + (RESPONSE_FORMAT_PARQUET, true), + (RESPONSE_FORMAT_PARQUET, false), 
+ (RESPONSE_FORMAT_DELTA, true), + (RESPONSE_FORMAT_DELTA, false), + (s"$RESPONSE_FORMAT_DELTA,$RESPONSE_FORMAT_PARQUET", true), + (s"$RESPONSE_FORMAT_DELTA,$RESPONSE_FORMAT_PARQUET", false), + (s"$RESPONSE_FORMAT_PARQUET,$RESPONSE_FORMAT_DELTA", true), + (s"$RESPONSE_FORMAT_PARQUET,$RESPONSE_FORMAT_DELTA", false) + ).foreach { case (responseFormat, includeEndStreamAction) => val respondedFormat = if (responseFormat == RESPONSE_FORMAT_DELTA) { RESPONSE_FORMAT_DELTA } else { RESPONSE_FORMAT_PARQUET } - val response = readNDJson(requestPath(s"/shares/share8/schemas/default/tables/cdf_table_cdf_enabled/changes?startingVersion=0&endingVersion=3"), Some("GET"), None, Some(0), respondedFormat) - val lines = response.split("\n") + val response = readNDJson( + requestPath(s"/shares/share8/schemas/default/tables/cdf_table_cdf_enabled/changes?startingVersion=0&endingVersion=3"), + Some("GET"), + None, + Some(0), + respondedFormat, + includeEndStreamAction = includeEndStreamAction + ) + var lines = response.split("\n") val protocol = lines(0) val metadata = lines(1) + if (includeEndStreamAction) { + val endAction = JsonUtils.fromJson[SingleAction](lines.last).endStreamAction + assert(endAction != null) + assert(endAction.minUrlExpirationTimestamp != null) + lines = lines.dropRight(1) + } if (responseFormat == RESPONSE_FORMAT_DELTA) { val responseProtocol = JsonUtils.fromJson[DeltaResponseSingleAction](protocol).protocol assert(responseProtocol.deltaProtocol.minReaderVersion == 1) @@ -2343,13 +2400,18 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { integrationTest("cdf_table_cdf_enabled_changes - paginated query table changes") { // version 1: 3 adds // version 2: 1 cdc - Seq(RESPONSE_FORMAT_PARQUET, RESPONSE_FORMAT_DELTA).foreach { responseFormat => + Seq( + (RESPONSE_FORMAT_PARQUET, true), + (RESPONSE_FORMAT_PARQUET, false), + (RESPONSE_FORMAT_DELTA, true), + (RESPONSE_FORMAT_DELTA, false)).foreach { case (responseFormat, includeEndStreamAction) 
=> var response = readNDJson( requestPath("/shares/share8/schemas/default/tables/cdf_table_cdf_enabled/changes?startingVersion=0&endingVersion=2&maxFiles=2"), Some("GET"), None, Some(0), - responseFormat + responseFormat, + includeEndStreamAction = includeEndStreamAction ) var lines = response.split("\n") assert(lines.length == 5) diff --git a/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingSourceSuite.scala b/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingSourceSuite.scala index e849dc1bd..b892f092d 100644 --- a/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingSourceSuite.scala +++ b/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingSourceSuite.scala @@ -248,20 +248,20 @@ class DeltaSharingSourceSuite extends QueryTest /** * Test spark config of query table version interval */ - integrationTest("query table version interval cannot be less than 30 seconds") { + integrationTest("query table version interval cannot be less than 10 seconds") { spark.sessionState.conf.setConfString( "spark.delta.sharing.streaming.queryTableVersionIntervalSeconds", - "29" + "9" ) val message = intercept[Exception] { val query = spark.readStream.format("deltaSharing").load(tablePath) .writeStream.format("console").start() query.processAllAvailable() }.getMessage - assert(message.contains("must not be less than 30 seconds.")) + assert(message.contains("must not be less than 10 seconds.")) spark.sessionState.conf.setConfString( "spark.delta.sharing.streaming.queryTableVersionIntervalSeconds", - "30" + "10" ) } diff --git a/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingSuite.scala b/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingSuite.scala index 293856c0b..ec45c8015 100644 --- a/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingSuite.scala +++ b/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingSuite.scala @@ -448,7 +448,8 @@ class DeltaSharingSuite extends QueryTest with SharedSparkSession with DeltaShar 
.option("startingVersion", 0).load(tablePath) checkAnswer(df, Nil) } - assert (ex.getMessage.contains("""400 Bad Request {"errorCode":"RESOURCE_DOES_NOT_EXIST"""")) + assert(ex.getMessage.contains("""400 Bad Request""")) + assert(ex.getMessage.contains("""{"errorCode":"RESOURCE_DOES_NOT_EXIST"""")) } integrationTest("azure support") { diff --git a/spark/src/test/scala/io/delta/sharing/spark/TestDeltaSharingClient.scala b/spark/src/test/scala/io/delta/sharing/spark/TestDeltaSharingClient.scala index bdf63fc2a..98e3f2431 100644 --- a/spark/src/test/scala/io/delta/sharing/spark/TestDeltaSharingClient.scala +++ b/spark/src/test/scala/io/delta/sharing/spark/TestDeltaSharingClient.scala @@ -46,7 +46,8 @@ class TestDeltaSharingClient( responseFormat: String = DeltaSharingOptions.RESPONSE_FORMAT_PARQUET, readerFeatures: String = "", queryTablePaginationEnabled: Boolean = false, - maxFilesPerReq: Int = 10000 + maxFilesPerReq: Int = 10000, + endStreamActionEnabled: Boolean = false ) extends DeltaSharingClient { import DeltaSharingOptions.RESPONSE_FORMAT_PARQUET